diff --git a/pyproject.toml b/pyproject.toml index 2aff75b7..71acc6fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,18 +10,17 @@ dynamic = ["version"] requires-python = ">= 3.8" dependencies = [ "aiohttp >= 3.9.4", - "grpcio >= 1.60.0", "protobuf >= 4.24.0", "types-protobuf >= 4.24.0.20240129", - "grpc-stubs >= 1.53.0.5", "http-message-signatures >= 0.5.0", "tblib >= 3.0.0", "typing_extensions >= 4.10" ] [project.optional-dependencies] -fastapi = ["fastapi", "httpx"] +fastapi = ["fastapi"] flask = ["flask"] +httpx = ["httpx"] lambda = ["awslambdaric"] dev = [ @@ -60,17 +59,12 @@ profile = "black" src_paths = ["src"] [tool.coverage.run] -omit = ["*_pb2_grpc.py", "*_pb2.py", "tests/*", "examples/*", "src/buf/*"] +omit = ["*_pb2.py", "tests/*", "examples/*", "src/buf/*"] [tool.mypy] exclude = [ '^src/buf', '^tests/examples', - # mypy 1.10.0 reports false positives for these two files: - # src/dispatch/sdk/v1/function_pb2_grpc.py:74: error: Module has no attribute "experimental" [attr-defined] - # src/dispatch/sdk/v1/dispatch_pb2_grpc.py:80: error: Module has no attribute "experimental" [attr-defined] - '^src/dispatch/sdk/v1/function_pb2_grpc.py', - '^src/dispatch/sdk/v1/dispatch_pb2_grpc.py', ] [tool.pytest.ini_options] diff --git a/src/buf/validate/expression_pb2_grpc.py b/src/buf/validate/expression_pb2_grpc.py deleted file mode 100644 index 8a939394..00000000 --- a/src/buf/validate/expression_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc diff --git a/src/buf/validate/priv/private_pb2_grpc.py b/src/buf/validate/priv/private_pb2_grpc.py deleted file mode 100644 index 8a939394..00000000 --- a/src/buf/validate/priv/private_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc diff --git a/src/buf/validate/validate_pb2_grpc.py b/src/buf/validate/validate_pb2_grpc.py deleted file mode 100644 index 8a939394..00000000 --- a/src/buf/validate/validate_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc diff --git a/src/dispatch/function.py b/src/dispatch/function.py index dc8a54aa..54925813 100644 --- a/src/dispatch/function.py +++ b/src/dispatch/function.py @@ -333,7 +333,7 @@ def _register(self, name: str, wrapped_func: PrimitiveFunction): def batch(self) -> Batch: """Returns a Batch instance that can be used to build a set of calls to dispatch.""" - return self.client.batch() + return Batch(self.client) _registries: Dict[str, Registry] = {} diff --git a/src/dispatch/sdk/python/v1/pickled_pb2_grpc.py b/src/dispatch/sdk/python/v1/pickled_pb2_grpc.py deleted file mode 100644 index 8a939394..00000000 --- a/src/dispatch/sdk/python/v1/pickled_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc diff --git a/src/dispatch/sdk/v1/call_pb2_grpc.py b/src/dispatch/sdk/v1/call_pb2_grpc.py deleted file mode 100644 index 8a939394..00000000 --- a/src/dispatch/sdk/v1/call_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc diff --git a/src/dispatch/sdk/v1/dispatch_pb2_grpc.py b/src/dispatch/sdk/v1/dispatch_pb2_grpc.py deleted file mode 100644 index 793cfbd3..00000000 --- a/src/dispatch/sdk/v1/dispatch_pb2_grpc.py +++ /dev/null @@ -1,94 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc - -from dispatch.sdk.v1 import dispatch_pb2 as dispatch_dot_sdk_dot_v1_dot_dispatch__pb2 - - -class DispatchServiceStub(object): - """DispatchService is a service allowing the trigger of programmable endpoints - from a dispatch SDK. - """ - - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. - """ - self.Dispatch = channel.unary_unary( - "/dispatch.sdk.v1.DispatchService/Dispatch", - request_serializer=dispatch_dot_sdk_dot_v1_dot_dispatch__pb2.DispatchRequest.SerializeToString, - response_deserializer=dispatch_dot_sdk_dot_v1_dot_dispatch__pb2.DispatchResponse.FromString, - ) - - -class DispatchServiceServicer(object): - """DispatchService is a service allowing the trigger of programmable endpoints - from a dispatch SDK. - """ - - def Dispatch(self, request, context): - """Dispatch submits a list of asynchronous function calls to the service. - - The method does not wait for executions to complete before returning, - it only ensures that the creation was persisted, and returns unique - identifiers to represent the executions. - - The request contains a list of executions to be triggered; the method is - atomic, either all executions are recorded, or none and an error is - returned to explain the reason for the failure. - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") - - -def add_DispatchServiceServicer_to_server(servicer, server): - rpc_method_handlers = { - "Dispatch": grpc.unary_unary_rpc_method_handler( - servicer.Dispatch, - request_deserializer=dispatch_dot_sdk_dot_v1_dot_dispatch__pb2.DispatchRequest.FromString, - response_serializer=dispatch_dot_sdk_dot_v1_dot_dispatch__pb2.DispatchResponse.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - "dispatch.sdk.v1.DispatchService", rpc_method_handlers - ) - server.add_generic_rpc_handlers((generic_handler,)) - - -# This class is part of an EXPERIMENTAL API. -class DispatchService(object): - """DispatchService is a service allowing the trigger of programmable endpoints - from a dispatch SDK. - """ - - @staticmethod - def Dispatch( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.unary_unary( - request, - target, - "/dispatch.sdk.v1.DispatchService/Dispatch", - dispatch_dot_sdk_dot_v1_dot_dispatch__pb2.DispatchRequest.SerializeToString, - dispatch_dot_sdk_dot_v1_dot_dispatch__pb2.DispatchResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - ) diff --git a/src/dispatch/sdk/v1/error_pb2_grpc.py b/src/dispatch/sdk/v1/error_pb2_grpc.py deleted file mode 100644 index 8a939394..00000000 --- a/src/dispatch/sdk/v1/error_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc diff --git a/src/dispatch/sdk/v1/exit_pb2_grpc.py b/src/dispatch/sdk/v1/exit_pb2_grpc.py deleted file mode 100644 index 8a939394..00000000 --- a/src/dispatch/sdk/v1/exit_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc diff --git a/src/dispatch/sdk/v1/function_pb2_grpc.py b/src/dispatch/sdk/v1/function_pb2_grpc.py deleted file mode 100644 index 82193b36..00000000 --- a/src/dispatch/sdk/v1/function_pb2_grpc.py +++ /dev/null @@ -1,88 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc - -from dispatch.sdk.v1 import function_pb2 as dispatch_dot_sdk_dot_v1_dot_function__pb2 - - -class FunctionServiceStub(object): - """The FunctionService service is used to interface with programmable endpoints - exposing remote functions. - """ - - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. - """ - self.Run = channel.unary_unary( - "/dispatch.sdk.v1.FunctionService/Run", - request_serializer=dispatch_dot_sdk_dot_v1_dot_function__pb2.RunRequest.SerializeToString, - response_deserializer=dispatch_dot_sdk_dot_v1_dot_function__pb2.RunResponse.FromString, - ) - - -class FunctionServiceServicer(object): - """The FunctionService service is used to interface with programmable endpoints - exposing remote functions. - """ - - def Run(self, request, context): - """Run runs the function identified by the request, and returns a response - that either contains a result when the function completed, or a poll - directive and the associated coroutine state if the function was suspended. - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") - - -def add_FunctionServiceServicer_to_server(servicer, server): - rpc_method_handlers = { - "Run": grpc.unary_unary_rpc_method_handler( - servicer.Run, - request_deserializer=dispatch_dot_sdk_dot_v1_dot_function__pb2.RunRequest.FromString, - response_serializer=dispatch_dot_sdk_dot_v1_dot_function__pb2.RunResponse.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - "dispatch.sdk.v1.FunctionService", rpc_method_handlers - ) - server.add_generic_rpc_handlers((generic_handler,)) - - -# This class is part of an EXPERIMENTAL API. -class FunctionService(object): - """The FunctionService service is used to interface with programmable endpoints - exposing remote functions. - """ - - @staticmethod - def Run( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.unary_unary( - request, - target, - "/dispatch.sdk.v1.FunctionService/Run", - dispatch_dot_sdk_dot_v1_dot_function__pb2.RunRequest.SerializeToString, - dispatch_dot_sdk_dot_v1_dot_function__pb2.RunResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - ) diff --git a/src/dispatch/sdk/v1/poll_pb2_grpc.py b/src/dispatch/sdk/v1/poll_pb2_grpc.py deleted file mode 100644 index 8a939394..00000000 --- a/src/dispatch/sdk/v1/poll_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc diff --git a/src/dispatch/sdk/v1/status_pb2_grpc.py b/src/dispatch/sdk/v1/status_pb2_grpc.py deleted file mode 100644 index 8a939394..00000000 --- a/src/dispatch/sdk/v1/status_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc diff --git a/src/dispatch/test/__init__.py b/src/dispatch/test.py similarity index 99% rename from src/dispatch/test/__init__.py rename to src/dispatch/test.py index 7250c96d..2ad018d8 100644 --- a/src/dispatch/test/__init__.py +++ b/src/dispatch/test.py @@ -40,14 +40,7 @@ STATUS_TLS_ERROR, ) -from .client import EndpointClient -from .server import DispatchServer -from .service import DispatchService - __all__ = [ - "EndpointClient", - "DispatchServer", - "DispatchService", "function", "method", "main", diff --git a/src/dispatch/test/client.py b/src/dispatch/test/client.py deleted file mode 100644 index 6ff3ba88..00000000 --- a/src/dispatch/test/client.py +++ /dev/null @@ -1,155 +0,0 @@ -from datetime import datetime -from typing import Optional - -import grpc - -from dispatch.sdk.v1 import function_pb2 as function_pb -from dispatch.sdk.v1 import function_pb2_grpc as function_grpc -from dispatch.signature import ( - CaseInsensitiveDict, - Ed25519PrivateKey, - Request, - sign_request, -) -from dispatch.test.http import HttpClient - - -class EndpointClient: - """Test client for a Dispatch programmable endpoint. - - Note that this is different from dispatch.Client, which is a client - for the Dispatch API. The EndpointClient is a client similar to the one - that Dispatch itself would use to interact with an endpoint that provides - functions. - """ - - def __init__( - self, http_client: HttpClient, signing_key: Optional[Ed25519PrivateKey] = None - ): - """Initialize the client. - - Args: - http_client: Client to use to make HTTP requests. - signing_key: Optional Ed25519 private key to use to sign requests. - """ - channel = _HttpGrpcChannel(http_client, signing_key=signing_key) - self._stub = function_grpc.FunctionServiceStub(channel) - - def run(self, request: function_pb.RunRequest) -> function_pb.RunResponse: - """Send a run request to an endpoint and return its response. - - Args: - request: A FunctionService Run request. - - Returns: - RunResponse: the response from the endpoint. - """ - return self._stub.Run(request) - - -class _HttpGrpcChannel(grpc.Channel): - def __init__( - self, http_client: HttpClient, signing_key: Optional[Ed25519PrivateKey] = None - ): - self.http_client = http_client - self.signing_key = signing_key - - def subscribe(self, callback, try_to_connect=False): - raise NotImplementedError() - - def unsubscribe(self, callback): - raise NotImplementedError() - - def unary_unary(self, method, request_serializer=None, response_deserializer=None): - return _UnaryUnaryMultiCallable( - self.http_client, - method, - request_serializer, - response_deserializer, - self.signing_key, - ) - - def unary_stream(self, method, request_serializer=None, response_deserializer=None): - raise NotImplementedError() - - def stream_unary(self, method, request_serializer=None, response_deserializer=None): - raise NotImplementedError() - - def stream_stream( - self, method, request_serializer=None, response_deserializer=None - ): - raise NotImplementedError() - - def close(self): - raise NotImplementedError() - - def __enter__(self): - raise NotImplementedError() - - def __exit__(self, exc_type, exc_val, exc_tb): - raise NotImplementedError() - - -class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): - def __init__( - self, - client, - method, - request_serializer, - response_deserializer, - signing_key: Optional[Ed25519PrivateKey] = None, - ): - self.client = client - self.method = method - self.request_serializer = request_serializer - self.response_deserializer = response_deserializer - self.signing_key = signing_key - - def __call__( - self, - request, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None, - ): - url = self.client.url_for(self.method) # note: method==path in gRPC parlance - - request = Request( - method="POST", - url=url, - body=self.request_serializer(request), - headers=CaseInsensitiveDict({"Content-Type": "application/grpc+proto"}), - ) - - if self.signing_key is not None: - sign_request(request, self.signing_key, datetime.now()) - - response = self.client.post( - request.url, body=request.body, headers=request.headers - ) - response.raise_for_status() - return self.response_deserializer(response.body) - - def with_call( - self, - request, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None, - ): - raise NotImplementedError() - - def future( - self, - request, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None, - ): - raise NotImplementedError() diff --git a/src/dispatch/test/fastapi.py b/src/dispatch/test/fastapi.py deleted file mode 100644 index 381b1800..00000000 --- a/src/dispatch/test/fastapi.py +++ /dev/null @@ -1,10 +0,0 @@ -from fastapi import FastAPI -from fastapi.testclient import TestClient - -import dispatch.test.httpx -from dispatch.test.client import HttpClient - - -def http_client(app: FastAPI) -> HttpClient: - """Build a client for a FastAPI app.""" - return dispatch.test.httpx.Client(TestClient(app)) diff --git a/src/dispatch/test/flask.py b/src/dispatch/test/flask.py deleted file mode 100644 index e8cc3cbe..00000000 --- a/src/dispatch/test/flask.py +++ /dev/null @@ -1,46 +0,0 @@ -from typing import Mapping - -import werkzeug.test -from flask import Flask - -from dispatch.test.http import HttpClient, HttpResponse - - -def http_client(app: Flask) -> HttpClient: - """Build a client for a Flask app.""" - return Client(app.test_client()) - - -class Client(HttpClient): - def __init__(self, client: werkzeug.test.Client): - self.client = client - - def get(self, url: str, headers: Mapping[str, str] = {}) -> HttpResponse: - response = self.client.get(url, headers=headers.items()) - return Response(response) - - def post( - self, url: str, body: bytes, headers: Mapping[str, str] = {} - ) -> HttpResponse: - response = self.client.post(url, data=body, headers=headers.items()) - return Response(response) - - def url_for(self, path: str) -> str: - return "http://localhost" + path - - -class Response(HttpResponse): - def __init__(self, response): - self.response = response - - @property - def status_code(self): - return self.response.status_code - - @property - def body(self): - return self.response.data - - def raise_for_status(self): - if self.response.status_code // 100 != 2: - raise RuntimeError(f"HTTP status code {self.response.status_code}") diff --git a/src/dispatch/test/http.py b/src/dispatch/test/http.py deleted file mode 100644 index cf7ba9fa..00000000 --- a/src/dispatch/test/http.py +++ /dev/null @@ -1,34 +0,0 @@ -from dataclasses import dataclass -from typing import Mapping, Protocol - -import aiohttp - -from dispatch.function import Client as DefaultClient - - -@dataclass -class HttpResponse(Protocol): - status_code: int - body: bytes - - def raise_for_status(self): - """Raise an exception on non-2xx responses.""" - ... - - -class HttpClient(Protocol): - """Protocol for HTTP clients.""" - - def get(self, url: str, headers: Mapping[str, str] = {}) -> HttpResponse: - """Make a GET request.""" - ... - - def post( - self, url: str, body: bytes, headers: Mapping[str, str] = {} - ) -> HttpResponse: - """Make a POST request.""" - ... - - def url_for(self, path: str) -> str: - """Get the fully-qualified URL for a path.""" - ... diff --git a/src/dispatch/test/httpx.py b/src/dispatch/test/httpx.py deleted file mode 100644 index 9d9f7c52..00000000 --- a/src/dispatch/test/httpx.py +++ /dev/null @@ -1,39 +0,0 @@ -from typing import Mapping - -import httpx - -from dispatch.test.http import HttpClient, HttpResponse - - -class Client(HttpClient): - def __init__(self, client: httpx.Client): - self.client = client - - def get(self, url: str, headers: Mapping[str, str] = {}) -> HttpResponse: - response = self.client.get(url, headers=headers) - return Response(response) - - def post( - self, url: str, body: bytes, headers: Mapping[str, str] = {} - ) -> HttpResponse: - response = self.client.post(url, content=body, headers=headers) - return Response(response) - - def url_for(self, path: str) -> str: - return str(httpx.URL(self.client.base_url).join(path)) - - -class Response(HttpResponse): - def __init__(self, response: httpx.Response): - self.response = response - - @property - def status_code(self): - return self.response.status_code - - @property - def body(self): - return self.response.content - - def raise_for_status(self): - self.response.raise_for_status() diff --git a/src/dispatch/test/server.py b/src/dispatch/test/server.py deleted file mode 100644 index a2d022b8..00000000 --- a/src/dispatch/test/server.py +++ /dev/null @@ -1,61 +0,0 @@ -import concurrent.futures -import sys - -import grpc - -from dispatch.sdk.v1 import dispatch_pb2_grpc as dispatch_grpc - - -class DispatchServer: - """Test server for a Dispatch service. This is useful for testing - a mock version of Dispatch locally (e.g. see - dispatch.test.DispatchService). - - Args: - service: Dispatch service to serve. - hostname: Hostname to bind to. - port: Port to bind to, or 0 to bind to any available port. - """ - - def __init__( - self, - service: dispatch_grpc.DispatchServiceServicer, - hostname: str = "127.0.0.1", - port: int = 0, - ): - self._thread_pool = concurrent.futures.thread.ThreadPoolExecutor() - self._server = grpc.server(self._thread_pool) - - self._hostname = hostname - self._port = self._server.add_insecure_port(f"{hostname}:{port}") - - dispatch_grpc.add_DispatchServiceServicer_to_server(service, self._server) - - @property - def url(self): - """Returns the URL of the server.""" - return f"http://{self._hostname}:{self._port}" - - def start(self): - """Start the server.""" - self._server.start() - - def wait(self): - """Block until the server terminates.""" - self._server.wait_for_termination() - - def stop(self): - """Stop the server.""" - self._server.stop(0) - self._server.wait_for_termination() - if sys.version_info >= (3, 9): - self._thread_pool.shutdown(wait=True, cancel_futures=True) - else: - self._thread_pool.shutdown(wait=True) - - def __enter__(self): - self.start() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.stop() diff --git a/src/dispatch/test/service.py b/src/dispatch/test/service.py deleted file mode 100644 index ac23738b..00000000 --- a/src/dispatch/test/service.py +++ /dev/null @@ -1,362 +0,0 @@ -import enum -import logging -import os -import threading -import time -from collections import OrderedDict -from dataclasses import dataclass -from typing import Dict, List, Optional, Set, Tuple - -import grpc -from google.protobuf import any_pb2 as any_pb -from typing_extensions import TypeAlias - -import dispatch.sdk.v1.call_pb2 as call_pb -import dispatch.sdk.v1.dispatch_pb2 as dispatch_pb -import dispatch.sdk.v1.dispatch_pb2_grpc as dispatch_grpc -import dispatch.sdk.v1.function_pb2 as function_pb -import dispatch.sdk.v1.poll_pb2 as poll_pb -from dispatch.id import DispatchID -from dispatch.proto import CallResult, Error, Status -from dispatch.test import EndpointClient - -_default_retry_on_status = { - Status.THROTTLED, - Status.TIMEOUT, - Status.TEMPORARY_ERROR, - Status.DNS_ERROR, - Status.TCP_ERROR, - Status.TLS_ERROR, - Status.HTTP_ERROR, -} - - -logger = logging.getLogger(__name__) - - -RoundTrip: TypeAlias = Tuple[function_pb.RunRequest, function_pb.RunResponse] -"""A request to a Dispatch endpoint, and the response that was received.""" - - -class CallType(enum.Enum): - """Type of function call.""" - - CALL = 0 - RESUME = 1 - RETRY = 2 - - -class DispatchService(dispatch_grpc.DispatchServiceServicer): - """Test instance of Dispatch that provides the bare minimum - functionality required to test functions locally.""" - - def __init__( - self, - endpoint_client: EndpointClient, - api_key: Optional[str] = None, - retry_on_status: Optional[Set[Status]] = None, - collect_roundtrips: bool = False, - ): - """Initialize the Dispatch service. - - Args: - endpoint_client: Client to use to interact with the local Dispatch - endpoint (that provides the functions). - api_key: Expected API key on requests to the service. If omitted, the - value of the DISPATCH_API_KEY environment variable is used instead. - retry_on_status: Set of status codes to enable retries for. - collect_roundtrips: Enable collection of request/response round-trips - to the configured endpoint. - """ - super().__init__() - - self.endpoint_client = endpoint_client - - if api_key is None: - api_key = os.getenv("DISPATCH_API_KEY") - self.api_key = api_key - - if retry_on_status is None: - retry_on_status = _default_retry_on_status - self.retry_on_status = retry_on_status - - self._next_dispatch_id = 1 - - self.queue: List[Tuple[DispatchID, function_pb.RunRequest, CallType]] = [] - - self.pollers: Dict[DispatchID, Poller] = {} - self.parents: Dict[DispatchID, Poller] = {} - - self.roundtrips: OrderedDict[DispatchID, List[RoundTrip]] = OrderedDict() - self.collect_roundtrips = collect_roundtrips - - self._thread: Optional[threading.Thread] = None - self._stop_event = threading.Event() - self._work_signal = threading.Condition() - - def Dispatch(self, request: dispatch_pb.DispatchRequest, context): - """RPC handler for Dispatch requests. Requests are only queued for - processing here.""" - self._validate_authentication(context) - - resp = dispatch_pb.DispatchResponse() - - with self._work_signal: - for call in request.calls: - dispatch_id = self._make_dispatch_id() - logger.debug("enqueueing call to function: %s", call.function) - resp.dispatch_ids.append(dispatch_id) - run_request = function_pb.RunRequest( - function=call.function, - input=call.input, - dispatch_id=dispatch_id, - root_dispatch_id=dispatch_id, - ) - self.queue.append((dispatch_id, run_request, CallType.CALL)) - - self._work_signal.notify() - - return resp - - def _validate_authentication(self, context: grpc.ServicerContext): - expected = f"Bearer {self.api_key}" - for key, value in context.invocation_metadata(): - if key == "authorization": - if value == expected: - return - logger.warning( - "a client attempted to dispatch a function call with an incorrect API key. Is the client's DISPATCH_API_KEY correct?" - ) - context.abort( - grpc.StatusCode.UNAUTHENTICATED, - f"Invalid authorization header. Expected '{expected}', got {value!r}", - ) - context.abort(grpc.StatusCode.UNAUTHENTICATED, "Missing authorization header") - - def _make_dispatch_id(self) -> DispatchID: - dispatch_id = self._next_dispatch_id - self._next_dispatch_id += 1 - return "{:032x}".format(dispatch_id) - - def dispatch_calls(self): - """Synchronously dispatch pending function calls to the - configured endpoint.""" - _next_queue: List[Tuple[DispatchID, function_pb.RunRequest, CallType]] = [] - while self.queue: - dispatch_id, request, call_type = self.queue.pop(0) - - if call_type == CallType.CALL: - logger.info("calling function %s", request.function) - elif call_type == CallType.RESUME: - logger.info("resuming function %s", request.function) - elif call_type == CallType.RETRY: - logger.info("retrying function %s", request.function) - - try: - response = self.endpoint_client.run(request) - except: - logger.warning("call to function %s failed", request.function) - self.queue.extend(_next_queue) - self.queue.append((dispatch_id, request, CallType.RETRY)) - raise - - if self.collect_roundtrips: - try: - roundtrips = self.roundtrips[dispatch_id] - except KeyError: - roundtrips = [] - - roundtrips.append((request, response)) - self.roundtrips[dispatch_id] = roundtrips - - status = Status(response.status) - if status == Status.OK: - logger.info("call to function %s succeeded", request.function) - else: - exc = None - if response.HasField("exit"): - if response.exit.HasField("result"): - result = response.exit.result - if result.HasField("error"): - exc = Error._from_proto(result.error).to_exception() - - if exc is not None: - logger.warning( - "call to function %s failed (%s => %s: %s)", - request.function, - status, - exc.__class__.__name__, - str(exc), - ) - else: - logger.warning( - "call to function %s failed (%s)", - request.function, - status, - ) - - if status in self.retry_on_status: - _next_queue.append((dispatch_id, request, CallType.RETRY)) - - elif response.HasField("poll"): - assert not response.HasField("exit") - - logger.info("suspending function %s", request.function) - - logger.debug("registering poller %s", dispatch_id) - - assert dispatch_id not in self.pollers - poller = Poller( - id=dispatch_id, - parent_id=request.parent_dispatch_id, - root_id=request.root_dispatch_id, - function=request.function, - typed_coroutine_state=response.poll.typed_coroutine_state, - waiting={}, - results={}, - ) - self.pollers[dispatch_id] = poller - - for call in response.poll.calls: - child_dispatch_id = self._make_dispatch_id() - child_request = function_pb.RunRequest( - function=call.function, - input=call.input, - dispatch_id=child_dispatch_id, - parent_dispatch_id=request.dispatch_id, - root_dispatch_id=request.root_dispatch_id, - ) - - _next_queue.append( - (child_dispatch_id, child_request, CallType.CALL) - ) - self.parents[child_dispatch_id] = poller - poller.waiting[child_dispatch_id] = call - - else: - assert response.HasField("exit") - - if response.exit.HasField("tail_call"): - tail_call = response.exit.tail_call - logger.debug( - "enqueueing tail call for %s", - tail_call.function, - ) - tail_call_request = function_pb.RunRequest( - function=tail_call.function, - input=tail_call.input, - dispatch_id=request.dispatch_id, - parent_dispatch_id=request.parent_dispatch_id, - root_dispatch_id=request.root_dispatch_id, - ) - _next_queue.append((dispatch_id, tail_call_request, CallType.CALL)) - - elif dispatch_id in self.parents: - result = response.exit.result - poller = self.parents[dispatch_id] - logger.debug( - "notifying poller %s of call result %s", poller.id, dispatch_id - ) - - call = poller.waiting[dispatch_id] - result.correlation_id = call.correlation_id - poller.results[dispatch_id] = result - del self.parents[dispatch_id] - del poller.waiting[dispatch_id] - - logger.debug( - "poller %s has %d waiting and %d ready results", - poller.id, - len(poller.waiting), - len(poller.results), - ) - - if not poller.waiting: - logger.debug( - "poller %s is now ready; enqueueing delivery of %d call result(s)", - poller.id, - len(poller.results), - ) - poll_results_request = function_pb.RunRequest( - dispatch_id=poller.id, - parent_dispatch_id=poller.parent_id, - root_dispatch_id=poller.root_id, - function=poller.function, - poll_result=poll_pb.PollResult( - typed_coroutine_state=poller.typed_coroutine_state, - results=poller.results.values(), - ), - ) - del self.pollers[poller.id] - _next_queue.append( - (poller.id, poll_results_request, CallType.RESUME) - ) - - self.queue = _next_queue - - def start(self): - """Start starts a background thread to continuously dispatch calls to the - configured endpoint.""" - if self._thread is not None: - raise RuntimeError("service has already been started") - - self._stop_event.clear() - self._thread = threading.Thread(target=self._dispatch_continuously) - self._thread.start() - - def stop(self): - """Stop stops the background thread that's dispatching calls to - the configured endpoint.""" - self._stop_event.set() - with self._work_signal: - self._work_signal.notify() - if self._thread is not None: - self._thread.join() - self._thread = None - - def _dispatch_continuously(self): - while True: - with self._work_signal: - if not self.queue and not self._stop_event.is_set(): - self._work_signal.wait() - - if self._stop_event.is_set(): - break - - try: - self.dispatch_calls() - except Exception as e: - logger.exception(e) - - # Introduce an artificial delay before continuing with - # follow-up work (retries, dispatching nested calls). - # This serves two purposes. Firstly, this is just a mock - # Dispatch server providing the bare minimum of functionality. - # Since there's no adaptive concurrency control, and no backoff - # between call attempts, the mock server may busy-loop without - # some sort of delay. Secondly, a bit of latency mimics the - # latency you would see in a production system and makes the - # log output easier to parse. - time.sleep(0.15) - - def __enter__(self): - self.start() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.stop() - - -@dataclass -class Poller: - id: DispatchID - parent_id: DispatchID - root_id: DispatchID - - function: str - - typed_coroutine_state: any_pb.Any - # TODO: support max_wait/min_results/max_results - - waiting: Dict[DispatchID, call_pb.Call] - results: Dict[DispatchID, call_pb.CallResult] diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index be4964aa..554a032a 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -37,8 +37,6 @@ public_key_from_pem, ) from dispatch.status import Status -from dispatch.test import EndpointClient -from dispatch.test.fastapi import http_client class TestFastAPI(dispatch.test.TestCase): @@ -76,27 +74,3 @@ def dispatch_test_stop(self): loop = self.runner.get_loop() loop.call_soon_threadsafe(self.event.set) - -def create_dispatch_instance(app: fastapi.FastAPI, endpoint: str): - return Dispatch( - app, - registry=Registry( - name=__name__, - endpoint=endpoint, - client=Client( - api_key="0000000000000000", - api_url="http://127.0.0.1:10000", - ), - ), - ) - - -def create_endpoint_client( - app: fastapi.FastAPI, signing_key: Optional[Ed25519PrivateKey] = None -): - return EndpointClient(http_client(app), signing_key) - - -def response_output(resp: function_pb.RunResponse) -> Any: - return any_unpickle(resp.exit.result.output) -