Skip to content

Commit

Permalink
Merge pull request #12 from stealthrocket/client
Browse files Browse the repository at this point in the history
Client
  • Loading branch information
pelletier authored Feb 1, 2024
2 parents 53f8760 + a482fdc commit f70ce87
Show file tree
Hide file tree
Showing 8 changed files with 444 additions and 79 deletions.
153 changes: 153 additions & 0 deletions src/dispatch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,155 @@
"""The Dispatch SDK for Python.
"""

from __future__ import annotations
import pickle
import os
from urllib.parse import urlparse
from functools import cached_property
from collections.abc import Iterable
from typing import Any, TypeAlias
from dataclasses import dataclass

import grpc
import google.protobuf

import ring.record.v1.record_pb2 as record_pb
import ring.task.v1.service_pb2 as service
import ring.task.v1.service_pb2_grpc as service_grpc
import dispatch.coroutine


__all__ = ["Client", "TaskID", "TaskInput", "TaskDef"]


@dataclass(frozen=True, repr=False)
class TaskID:
"""Unique task identifier in Dispatch.
It should be treated as an opaque value.
"""

partition_number: int
block_id: int
record_offset: int
record_size: int

@classmethod
def _from_proto(cls, proto: record_pb.ID) -> TaskID:
return cls(
partition_number=proto.partition_number,
block_id=proto.block_id,
record_offset=proto.record_offset,
record_size=proto.record_size,
)

def _to_proto(self) -> record_pb.ID:
return record_pb.ID(
partition_number=self.partition_number,
block_id=self.block_id,
record_offset=self.record_offset,
record_size=self.record_size,
)

def __str__(self) -> str:
parts = [
self.partition_number,
self.block_id,
self.record_offset,
self.record_size,
]
return "".join("{:08x}".format(a) for a in parts)

def __repr__(self) -> str:
return f"TaskID({self})"


@dataclass(frozen=True)
class TaskInput:
"""Definition of a task to be created on Dispatch.
Attributes:
coroutine_uri: The URI of the coroutine to execute.
input: The input to pass to the coroutine. If the input is a protobuf
message, it will be wrapped in a google.protobuf.Any message. If the
input is not a protobuf message, it will be pickled and wrapped in a
google.protobuf.Any message.
"""

coroutine_uri: str
input: Any


TaskDef: TypeAlias = TaskInput | dispatch.coroutine.Call
"""Definition of a task to be created on Dispatch.
Can be either a TaskInput or a Call. TaskInput can be created manually, likely
to call a coroutine outside the current code base. Call is created by the
`dispatch.coroutine` module and is used to call a coroutine defined in the
current code base.
"""


def _taskdef_to_proto(taskdef: TaskDef) -> service.CreateTaskInput:
input = taskdef.input
match input:
case google.protobuf.any_pb2.Any():
input_any = input
case google.protobuf.message.Message():
input_any = google.protobuf.any_pb2.Any()
input_any.Pack(input)
case _:
pickled = pickle.dumps(input)
input_any = google.protobuf.any_pb2.Any()
input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=pickled))
return service.CreateTaskInput(coroutine_uri=taskdef.coroutine_uri, input=input_any)


class Client:
"""Client for the Dispatch API."""

def __init__(
self, api_key: None | str = None, api_url="https://api.stealthrocket.cloud"
):
"""Create a new Dispatch client.
Args:
api_key: Dispatch API key to use for authentication. Uses the value of
the DISPATCH_API_KEY environment variable by default.
api_url: The URL of the Dispatch API to use. Defaults to the public
Dispatch API.
Raises:
ValueError: if the API key is missing.
"""
if not api_key:
api_key = os.environ.get("DISPATCH_API_KEY")
if not api_key:
raise ValueError("api_key is required")

result = urlparse(api_url)
match result.scheme:
case "http":
creds = grpc.local_channel_credentials()
case "https":
creds = grpc.ssl_channel_credentials()
case _:
raise ValueError(f"Invalid API scheme: '{result.scheme}'")

call_creds = grpc.access_token_call_credentials(api_key)
creds = grpc.composite_channel_credentials(creds, call_creds)
channel = grpc.secure_channel(result.netloc, creds)

self._stub = service_grpc.ServiceStub(channel)

def create_tasks(self, tasks: Iterable[TaskDef]) -> Iterable[TaskID]:
"""Create tasks on Dispatch using the provided inputs.
Returns:
The ID of the created tasks, in the same order as the inputs.
"""
req = service.CreateTasksRequest()
for task in tasks:
req.tasks.append(_taskdef_to_proto(task))
resp = self._stub.CreateTasks(req)
return [TaskID._from_proto(x.id) for x in resp.tasks]
16 changes: 7 additions & 9 deletions src/dispatch/coroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@

from __future__ import annotations
import enum
from typing import Any
import pickle
from typing import Any, Callable
from dataclasses import dataclass

import google.protobuf.message
import pickle

from ring.coroutine.v1 import coroutine_pb2
from ring.status.v1 import status_pb2

Expand Down Expand Up @@ -77,15 +79,16 @@ class Coroutine:
"""Callable wrapper around a function meant to be used throughout the
Dispatch Python SDK."""

def __init__(self, func):
def __init__(self, uri: str, func: Callable[[Input], Output]):
self._uri = uri
self._func = func

def __call__(self, *args, **kwargs):
return self._func(*args, **kwargs)

@property
def uri(self) -> str:
return self._func.__qualname__
return self._uri

def call_with(self, input: Any, correlation_id: int | None = None) -> Call:
"""Create a Call of this coroutine with the provided input. Useful to
Expand Down Expand Up @@ -362,8 +365,3 @@ def _pb_any_pickle(x: Any) -> google.protobuf.any_pb2.Any:
pb_any = google.protobuf.any_pb2.Any()
pb_any.Pack(pb_bytes)
return pb_any


def _coroutine_uri_to_qualname(coroutine_uri: str) -> str:
# TODO: fix this when we decide on the format of coroutine URIs.
return coroutine_uri.split("/")[-1]
52 changes: 37 additions & 15 deletions src/dispatch/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,21 @@ def read_root():
my_cool_coroutine.call()
"""

import ring.coroutine.v1.coroutine_pb2
from collections.abc import Callable
from typing import Any
import os
from typing import Any, Dict
from collections.abc import Callable

import fastapi
import fastapi.responses
import google.protobuf.wrappers_pb2
from httpx import _urlparse

import ring.coroutine.v1.coroutine_pb2
import dispatch.coroutine


def configure(
app: fastapi.FastAPI,
public_url: str,
api_key: None | str = None,
):
"""Configure the FastAPI app to use Dispatch programmable endpoints.
Expand All @@ -40,6 +43,8 @@ def configure(
app: The FastAPI app to configure.
api_key: Dispatch API key to use for authentication. Uses the value of
the DISPATCH_API_KEY environment variable by default.
public_url: Full URL of the application the dispatch programmable
endpoint will be running on.
Raises:
ValueError: If any of the required arguments are missing.
Expand All @@ -48,19 +53,26 @@ def configure(

if not app:
raise ValueError("app is required")
if not public_url:
raise ValueError("public_url is required")
if not api_key:
raise ValueError("api_key is required")

dispatch_app = _new_app()
parsed_url = _urlparse.urlparse(public_url)
if not parsed_url.netloc or not parsed_url.scheme:
raise ValueError("public_url must be a full URL with protocol and domain")

dispatch_app = _new_app(public_url)

app.__setattr__("dispatch_coroutine", dispatch_app.dispatch_coroutine)
app.mount("/ring.coroutine.v1.ExecutorService", dispatch_app)


class _DispatchAPI(fastapi.FastAPI):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._coroutines = {}
def __init__(self, public_url: str):
super().__init__()
self._coroutines: Dict[str, dispatch.coroutine.Coroutine] = {}
self._public_url = _urlparse.urlparse(public_url)

def dispatch_coroutine(self):
"""Register a coroutine with the Dispatch programmable endpoints.
Expand All @@ -74,7 +86,9 @@ def dispatch_coroutine(self):
"""

def wrap(func: Callable[[dispatch.coroutine.Input], dispatch.coroutine.Output]):
coro = dispatch.coroutine.Coroutine(func)
name = func.__qualname__
uri = str(self._public_url.copy_with(fragment="function=" + name))
coro = dispatch.coroutine.Coroutine(uri, func)
if coro.uri in self._coroutines:
raise ValueError(f"Coroutine {coro.uri} already registered")
self._coroutines[coro.uri] = coro
Expand All @@ -87,9 +101,8 @@ class _GRPCResponse(fastapi.Response):
media_type = "application/grpc+proto"


def _new_app():
app = _DispatchAPI()
app._coroutines = {}
def _new_app(public_url: str):
app = _DispatchAPI(public_url)

@app.post(
# The endpoint for execution is hardcoded at the moment. If the service
Expand All @@ -113,9 +126,18 @@ async def execute(request: fastapi.Request):

# TODO: be more graceful. This will crash if the coroutine is not found,
# and the coroutine version is not taken into account.
coroutine = app._coroutines[
dispatch.coroutine._coroutine_uri_to_qualname(req.coroutine_uri)
]

uri = req.coroutine_uri

coroutine = app._coroutines.get(uri, None)
if coroutine is None:
# TODO: integrate with logging
print("Coroutine not found:")
print(" uri:", uri)
print("Available coroutines:")
for k in app._coroutines:
print(" ", k)
raise KeyError(f"coroutine '{uri}' not available on this system")

coro_input = dispatch.coroutine.Input(req)

Expand Down
Loading

0 comments on commit f70ce87

Please sign in to comment.