diff --git a/src/dispatch/__init__.py b/src/dispatch/__init__.py index 47495b34..cae6f76c 100644 --- a/src/dispatch/__init__.py +++ b/src/dispatch/__init__.py @@ -1,2 +1,155 @@ """The Dispatch SDK for Python. """ + +from __future__ import annotations +import pickle +import os +from urllib.parse import urlparse +from functools import cached_property +from collections.abc import Iterable +from typing import Any, TypeAlias +from dataclasses import dataclass + +import grpc +import google.protobuf + +import ring.record.v1.record_pb2 as record_pb +import ring.task.v1.service_pb2 as service +import ring.task.v1.service_pb2_grpc as service_grpc +import dispatch.coroutine + + +__all__ = ["Client", "TaskID", "TaskInput", "TaskDef"] + + +@dataclass(frozen=True, repr=False) +class TaskID: + """Unique task identifier in Dispatch. + + It should be treated as an opaque value. + """ + + partition_number: int + block_id: int + record_offset: int + record_size: int + + @classmethod + def _from_proto(cls, proto: record_pb.ID) -> TaskID: + return cls( + partition_number=proto.partition_number, + block_id=proto.block_id, + record_offset=proto.record_offset, + record_size=proto.record_size, + ) + + def _to_proto(self) -> record_pb.ID: + return record_pb.ID( + partition_number=self.partition_number, + block_id=self.block_id, + record_offset=self.record_offset, + record_size=self.record_size, + ) + + def __str__(self) -> str: + parts = [ + self.partition_number, + self.block_id, + self.record_offset, + self.record_size, + ] + return "".join("{:08x}".format(a) for a in parts) + + def __repr__(self) -> str: + return f"TaskID({self})" + + +@dataclass(frozen=True) +class TaskInput: + """Definition of a task to be created on Dispatch. + + Attributes: + coroutine_uri: The URI of the coroutine to execute. + input: The input to pass to the coroutine. If the input is a protobuf + message, it will be wrapped in a google.protobuf.Any message. 
If the + input is not a protobuf message, it will be pickled and wrapped in a + google.protobuf.Any message. + """ + + coroutine_uri: str + input: Any + + +TaskDef: TypeAlias = TaskInput | dispatch.coroutine.Call +"""Definition of a task to be created on Dispatch. + +Can be either a TaskInput or a Call. TaskInput can be created manually, likely +to call a coroutine outside the current code base. Call is created by the +`dispatch.coroutine` module and is used to call a coroutine defined in the +current code base. +""" + + +def _taskdef_to_proto(taskdef: TaskDef) -> service.CreateTaskInput: + input = taskdef.input + match input: + case google.protobuf.any_pb2.Any(): + input_any = input + case google.protobuf.message.Message(): + input_any = google.protobuf.any_pb2.Any() + input_any.Pack(input) + case _: + pickled = pickle.dumps(input) + input_any = google.protobuf.any_pb2.Any() + input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=pickled)) + return service.CreateTaskInput(coroutine_uri=taskdef.coroutine_uri, input=input_any) + + +class Client: + """Client for the Dispatch API.""" + + def __init__( + self, api_key: None | str = None, api_url="https://api.stealthrocket.cloud" + ): + """Create a new Dispatch client. + + Args: + api_key: Dispatch API key to use for authentication. Uses the value of + the DISPATCH_API_KEY environment variable by default. + api_url: The URL of the Dispatch API to use. Defaults to the public + Dispatch API. + + Raises: + ValueError: if the API key is missing. 
+ """ + if not api_key: + api_key = os.environ.get("DISPATCH_API_KEY") + if not api_key: + raise ValueError("api_key is required") + + result = urlparse(api_url) + match result.scheme: + case "http": + creds = grpc.local_channel_credentials() + case "https": + creds = grpc.ssl_channel_credentials() + case _: + raise ValueError(f"Invalid API scheme: '{result.scheme}'") + + call_creds = grpc.access_token_call_credentials(api_key) + creds = grpc.composite_channel_credentials(creds, call_creds) + channel = grpc.secure_channel(result.netloc, creds) + + self._stub = service_grpc.ServiceStub(channel) + + def create_tasks(self, tasks: Iterable[TaskDef]) -> Iterable[TaskID]: + """Create tasks on Dispatch using the provided inputs. + + Returns: + The ID of the created tasks, in the same order as the inputs. + """ + req = service.CreateTasksRequest() + for task in tasks: + req.tasks.append(_taskdef_to_proto(task)) + resp = self._stub.CreateTasks(req) + return [TaskID._from_proto(x.id) for x in resp.tasks] diff --git a/src/dispatch/coroutine.py b/src/dispatch/coroutine.py index 43ca7202..b360e5ab 100644 --- a/src/dispatch/coroutine.py +++ b/src/dispatch/coroutine.py @@ -12,10 +12,12 @@ from __future__ import annotations import enum -from typing import Any +import pickle +from typing import Any, Callable from dataclasses import dataclass + import google.protobuf.message -import pickle + from ring.coroutine.v1 import coroutine_pb2 from ring.status.v1 import status_pb2 @@ -77,7 +79,8 @@ class Coroutine: """Callable wrapper around a function meant to be used throughout the Dispatch Python SDK.""" - def __init__(self, func): + def __init__(self, uri: str, func: Callable[[Input], Output]): + self._uri = uri self._func = func def __call__(self, *args, **kwargs): @@ -85,7 +88,7 @@ def __call__(self, *args, **kwargs): @property def uri(self) -> str: - return self._func.__qualname__ + return self._uri def call_with(self, input: Any, correlation_id: int | None = None) -> Call: """Create 
a Call of this coroutine with the provided input. Useful to @@ -362,8 +365,3 @@ def _pb_any_pickle(x: Any) -> google.protobuf.any_pb2.Any: pb_any = google.protobuf.any_pb2.Any() pb_any.Pack(pb_bytes) return pb_any - - -def _coroutine_uri_to_qualname(coroutine_uri: str) -> str: - # TODO: fix this when we decide on the format of coroutine URIs. - return coroutine_uri.split("/")[-1] diff --git a/src/dispatch/fastapi.py b/src/dispatch/fastapi.py index 38d4aaa1..f0929973 100644 --- a/src/dispatch/fastapi.py +++ b/src/dispatch/fastapi.py @@ -17,18 +17,21 @@ def read_root(): my_cool_coroutine.call() """ -import ring.coroutine.v1.coroutine_pb2 -from collections.abc import Callable -from typing import Any import os +from typing import Any, Dict +from collections.abc import Callable + import fastapi import fastapi.responses -import google.protobuf.wrappers_pb2 +from httpx import _urlparse + +import ring.coroutine.v1.coroutine_pb2 import dispatch.coroutine def configure( app: fastapi.FastAPI, + public_url: str, api_key: None | str = None, ): """Configure the FastAPI app to use Dispatch programmable endpoints. @@ -40,6 +43,8 @@ def configure( app: The FastAPI app to configure. api_key: Dispatch API key to use for authentication. Uses the value of the DISPATCH_API_KEY environment variable by default. + public_url: Full URL of the application the dispatch programmable + endpoint will be running on. Raises: ValueError: If any of the required arguments are missing. 
@@ -48,19 +53,26 @@ def configure( if not app: raise ValueError("app is required") + if not public_url: + raise ValueError("public_url is required") if not api_key: raise ValueError("api_key is required") - dispatch_app = _new_app() + parsed_url = _urlparse.urlparse(public_url) + if not parsed_url.netloc or not parsed_url.scheme: + raise ValueError("public_url must be a full URL with protocol and domain") + + dispatch_app = _new_app(public_url) app.__setattr__("dispatch_coroutine", dispatch_app.dispatch_coroutine) app.mount("/ring.coroutine.v1.ExecutorService", dispatch_app) class _DispatchAPI(fastapi.FastAPI): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._coroutines = {} + def __init__(self, public_url: str): + super().__init__() + self._coroutines: Dict[str, dispatch.coroutine.Coroutine] = {} + self._public_url = _urlparse.urlparse(public_url) def dispatch_coroutine(self): """Register a coroutine with the Dispatch programmable endpoints. @@ -74,7 +86,9 @@ def dispatch_coroutine(self): """ def wrap(func: Callable[[dispatch.coroutine.Input], dispatch.coroutine.Output]): - coro = dispatch.coroutine.Coroutine(func) + name = func.__qualname__ + uri = str(self._public_url.copy_with(fragment="function=" + name)) + coro = dispatch.coroutine.Coroutine(uri, func) if coro.uri in self._coroutines: raise ValueError(f"Coroutine {coro.uri} already registered") self._coroutines[coro.uri] = coro @@ -87,9 +101,8 @@ class _GRPCResponse(fastapi.Response): media_type = "application/grpc+proto" -def _new_app(): - app = _DispatchAPI() - app._coroutines = {} +def _new_app(public_url: str): + app = _DispatchAPI(public_url) @app.post( # The endpoint for execution is hardcoded at the moment. If the service @@ -113,9 +126,18 @@ async def execute(request: fastapi.Request): # TODO: be more graceful. This will crash if the coroutine is not found, # and the coroutine version is not taken into account. 
- coroutine = app._coroutines[ - dispatch.coroutine._coroutine_uri_to_qualname(req.coroutine_uri) - ] + + uri = req.coroutine_uri + + coroutine = app._coroutines.get(uri, None) + if coroutine is None: + # TODO: integrate with logging + print("Coroutine not found:") + print(" uri:", uri) + print("Available coroutines:") + for k in app._coroutines: + print(" ", k) + raise KeyError(f"coroutine '{uri}' not available on this system") coro_input = dispatch.coroutine.Input(req) diff --git a/tests/task_service.py b/tests/task_service.py new file mode 100644 index 00000000..f83779ed --- /dev/null +++ b/tests/task_service.py @@ -0,0 +1,105 @@ +import concurrent.futures.thread + +import grpc + +import ring.task.v1.service_pb2 as service_pb +import ring.task.v1.service_pb2_grpc as service_grpc +from ring.coroutine.v1 import coroutine_pb2_grpc as coroutine_grpc +from ring.coroutine.v1 import coroutine_pb2 as coroutine_pb +from dispatch import Client, TaskInput, TaskID + + +_test_auth_token = "THIS_IS_A_TEST_AUTH_TOKEN" + + +class FakeRing(service_grpc.ServiceServicer): + def __init__(self): + super().__init__() + self.current_partition = 1 + self.current_block_id = 1 + self.current_offset = 0 + + self.created_tasks = [] + self.responses = {} # indexed by task id + + self.pending_tasks = [] + + def _validate_authentication(self, context: grpc.ServicerContext): + expected = f"Bearer {_test_auth_token}" + for key, value in context.invocation_metadata(): + if key == "authorization": + if value == expected: + return + context.abort( + grpc.StatusCode.UNAUTHENTICATED, + f"Invalid authorization header. 
Expected '{expected}', got '{value!r}'", + ) + context.abort(grpc.StatusCode.UNAUTHENTICATED, "Missing authorization header") + + def CreateTasks(self, request: service_pb.CreateTasksRequest, context): + self._validate_authentication(context) + + resp = service_pb.CreateTasksResponse() + + for t in request.tasks: + id = TaskID( + partition_number=self.current_partition, + block_id=self.current_block_id, + record_offset=self.current_offset, + record_size=1, + ) + self.current_offset += 1 + self.created_tasks.append({"id": id, "task": t}) + self.pending_tasks.append({"id": id, "task": t}) + resp.tasks.append(service_pb.CreateTaskOutput(id=id._to_proto())) + self.current_block_id += 1 + + return resp + + def execute(self, client: coroutine_grpc.ExecutorServiceStub): + """Synchronously execute all the pending tasks until there is no + pending task left. + + """ + while len(self.pending_tasks) > 0: + entry = self.pending_tasks.pop(0) + task = entry["task"] + + req = coroutine_pb.ExecuteRequest( + coroutine_uri=task.coroutine_uri, input=task.input + ) + + resp = client.Execute(req) + self.responses[entry["id"]] = resp + + +class ServerTest: + """Server test is a test fixture that starts a fake task service server and + provides a client setup to talk to it. + + Instantiate in a setUp() method and call stop() in a tearDown() method. 
+ + """ + + def __init__(self): + self.thread_pool = concurrent.futures.thread.ThreadPoolExecutor() + self.server = grpc.server(self.thread_pool) + + port = self.server.add_insecure_port("127.0.0.1:0") + + self.servicer = FakeRing() + + service_grpc.add_ServiceServicer_to_server(self.servicer, self.server) + self.server.start() + + self.client = Client( + api_key=_test_auth_token, api_url=f"http://127.0.0.1:{port}" + ) + + def stop(self): + self.server.stop(0) + self.server.wait_for_termination() + self.thread_pool.shutdown(wait=True, cancel_futures=True) + + def execute(self, client: coroutine_grpc.ExecutorServiceStub): + return self.servicer.execute(client) diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 00000000..23e915b4 --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,63 @@ +import unittest + +from google.protobuf import wrappers_pb2, any_pb2 + +from dispatch import Client, TaskInput, TaskID +from dispatch.coroutine import _any_unpickle as any_unpickle +from .task_service import ServerTest + + +class TestClient(unittest.TestCase): + def setUp(self): + self.server = ServerTest() + # shortcuts + self.servicer = self.server.servicer + self.client = self.server.client + + def tearDown(self): + self.server.stop() + + def test_create_one_task_pickle(self): + results = self.client.create_tasks( + [TaskInput(coroutine_uri="my-cool-coroutine", input=42)] + ) + self.assertEqual(len(results), 1) + id = results[0] + self.assertTrue(id.partition_number != 0) + self.assertTrue(id.block_id != 0) + + created_tasks = self.servicer.created_tasks + self.assertEqual(len(created_tasks), 1) + entry = created_tasks[0] + self.assertEqual(entry["id"], id) + task = entry["task"] + self.assertEqual(task.coroutine_uri, "my-cool-coroutine") + self.assertEqual(any_unpickle(task.input), 42) + + def test_create_one_task_proto(self): + proto = wrappers_pb2.Int32Value(value=42) + results = self.client.create_tasks( + 
[TaskInput(coroutine_uri="my-cool-coroutine", input=proto)] + ) + id = results[0] + created_tasks = self.servicer.created_tasks + entry = created_tasks[0] + task = entry["task"] + # proto has been wrapped in an Any + x = wrappers_pb2.Int32Value() + task.input.Unpack(x) + self.assertEqual(x, proto) + + def test_create_one_task_proto_any(self): + proto = wrappers_pb2.Int32Value(value=42) + proto_any = any_pb2.Any() + proto_any.Pack(proto) + results = self.client.create_tasks( + [TaskInput(coroutine_uri="my-cool-coroutine", input=proto_any)] + ) + id = results[0] + created_tasks = self.servicer.created_tasks + entry = created_tasks[0] + task = entry["task"] + # the input was already an Any: it must be forwarded unmodified, not re-packed + self.assertEqual(task.input, proto_any) diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index ec48a790..eaf221d6 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -15,7 +15,9 @@ class TestFastAPI(unittest.TestCase): def test_configure(self): app = fastapi.FastAPI() - dispatch.fastapi.configure(app, api_key="test-key") + dispatch.fastapi.configure( + app, api_key="test-key", public_url="https://127.0.0.1:9999" + ) @app.get("/") def read_root(): @@ -34,16 +36,27 @@ def read_root(): def test_configure_no_app(self): with self.assertRaises(ValueError): - dispatch.fastapi.configure(None, api_key="test-key") + dispatch.fastapi.configure( + None, api_key="test-key", public_url="http://127.0.0.1:9999" + ) def test_configure_no_api_key(self): app = fastapi.FastAPI() with self.assertRaises(ValueError): - dispatch.fastapi.configure(app, api_key=None) + dispatch.fastapi.configure( + app, api_key=None, public_url="http://127.0.0.1:9999" + ) + + def test_configure_no_public_url(self): + app = fastapi.FastAPI() + with self.assertRaises(ValueError): + dispatch.fastapi.configure(app, api_key="test", public_url="") def test_fastapi_simple_request(self): app = fastapi.FastAPI() - dispatch.fastapi.configure(app, api_key="test-key") + dispatch.fastapi.configure( + app, 
api_key="test-key", public_url="http://127.0.0.1:9999/" + ) @app.dispatch_coroutine() def my_cool_coroutine(input: Input) -> Output: @@ -86,7 +99,9 @@ def response_output(resp: coroutine_pb2.ExecuteResponse) -> Any: class TestCoroutine(unittest.TestCase): def setUp(self): self.app = fastapi.FastAPI() - dispatch.fastapi.configure(self.app, api_key="test-key") + dispatch.fastapi.configure( + self.app, api_key="test-key", public_url="https://127.0.0.1:9999" + ) http_client = TestClient(self.app) self.client = executor_service.client(http_client) diff --git a/tests/test_full.py b/tests/test_full.py new file mode 100644 index 00000000..1f51c915 --- /dev/null +++ b/tests/test_full.py @@ -0,0 +1,59 @@ +import unittest + +import fastapi +from fastapi.testclient import TestClient + +from dispatch import Client, TaskInput, TaskID +from dispatch.coroutine import Input, Output, Error, Status +from dispatch.coroutine import _any_unpickle as any_unpickle +import dispatch.fastapi +from .task_service import ServerTest +from . import executor_service + + +class TestFullFastapi(unittest.TestCase): + def setUp(self): + self.app = fastapi.FastAPI() + dispatch.fastapi.configure( + self.app, api_key="test-key", public_url="http://test" + ) + http_client = TestClient(self.app) + self.app_client = executor_service.client(http_client) + self.server = ServerTest() + # shortcuts + self.client = self.server.client + self.servicer = self.server.servicer + + def tearDown(self): + self.server.stop() + + def execute_tasks(self): + self.server.execute(self.app_client) + + def test_simple_end_to_end(self): + # The FastAPI server. + @self.app.dispatch_coroutine() + def my_cool_coroutine(input: Input) -> Output: + return Output.value(f"Hello world: {input.input}") + + # The client. + [task_id] = self.client.create_tasks( + [TaskInput(coroutine_uri=my_cool_coroutine.uri, input=52)] + ) + + # Simulate execution for testing purposes. + self.execute_tasks() + + # Validate results. 
+ resp = self.servicer.responses[task_id] + self.assertEqual(any_unpickle(resp.exit.result.output), "Hello world: 52") + + def test_simple_call_with(self): + @self.app.dispatch_coroutine() + def my_cool_coroutine(input: Input) -> Output: + return Output.value(f"Hello world: {input.input}") + + [task_id] = self.client.create_tasks([my_cool_coroutine.call_with(52)]) + self.execute_tasks() + resp = self.servicer.responses[task_id] + self.assertEqual(any_unpickle(resp.exit.result.output), "Hello world: 52") diff --git a/tests/test_ring.py b/tests/test_ring.py deleted file mode 100644 index 4acb14b1..00000000 --- a/tests/test_ring.py +++ /dev/null @@ -1,50 +0,0 @@ -import unittest -import concurrent.futures.thread - -import dispatch -import ring.task.v1.service_pb2 as service_pb -import ring.task.v1.service_pb2_grpc as service_grpc -import dispatch.http.v1.http_pb2 -import grpc -import google.protobuf.any_pb2 - - -class FakeRing(service_grpc.ServiceServicer): - def CreateTasks(self, request, context): - return service_pb.CreateTasksResponse() - - -class TestRing(unittest.TestCase): - def setUp(self): - self.thread_pool = concurrent.futures.thread.ThreadPoolExecutor() - self.server = grpc.server(self.thread_pool) - - port = self.server.add_insecure_port("127.0.0.1:0") - - servicer = FakeRing() - - service_grpc.add_ServiceServicer_to_server(servicer, self.server) - self.server.start() - - channel = grpc.insecure_channel(f"127.0.0.1:{port}") - self.ring_stub = service_grpc.ServiceStub(channel) - - def tearDown(self): - self.server.stop(0) - self.server.wait_for_termination() - self.thread_pool.shutdown(wait=True, cancel_futures=True) - - def test_ring(self): - request = dispatch.http.v1.http_pb2.Request( - url="https://www.google.com", method="GET" - ) - - input = google.protobuf.any_pb2.Any() - input.Pack(request) - - create_task_input = service_pb.CreateTaskInput( - coroutine_uri="arn:aws:lambda:us-west-2:012345678912:function:dispatch-http", - input=input, - ) - req = 
service_pb.CreateTasksRequest(tasks=[create_task_input]) - resp = self.ring_stub.CreateTasks(req)