diff --git a/src/dispatch/fastapi.py b/src/dispatch/fastapi.py index a67453e0..bf89ab43 100644 --- a/src/dispatch/fastapi.py +++ b/src/dispatch/fastapi.py @@ -2,9 +2,12 @@ """ +import ring.coroutine.v1.coroutine_pb2 + import os import fastapi -from fastapi.responses import PlainTextResponse +import fastapi.responses +import google.protobuf.wrappers_pb2 def configure( @@ -23,7 +26,6 @@ def configure( mount_path: The path to mount Dispatch programmable endpoints at. """ api_key = api_key or os.environ.get("DISPATCH_API_KEY") - api_url = api_url or "https://api.stealthrocket.cloud" if not app: raise ValueError("app is required") @@ -34,10 +36,37 @@ def configure( if not mount_path: raise ValueError("mount_path is required") - dispatch_app = fastapi.FastAPI() + dispatch_app = _new_app() + + app.mount(mount_path, dispatch_app) + + +class GRPCResponse(fastapi.Response): + media_type = "application/grpc+proto" + + +def _new_app(): + app = fastapi.FastAPI() - @dispatch_app.get("/", response_class=PlainTextResponse) + @app.get("/", response_class=fastapi.responses.PlainTextResponse) def read_root(): return "ok" - app.mount(mount_path, dispatch_app) + @app.post("/ring.coroutine.v1.ExecutorService/Execute", response_class=GRPCResponse) + async def execute(request: fastapi.Request): + data: bytes = await request.body() + + req = ring.coroutine.v1.coroutine_pb2.ExecuteRequest.FromString(data) + + # TODO: unpack any + input = google.protobuf.wrappers_pb2.StringValue + req.input.Unpack(input) + + resp = ring.coroutine.v1.coroutine_pb2.ExecuteResponse( + coroutine_uri=req.coroutine_uri, + coroutine_version=req.coroutine_version, + ) + + return resp.SerializeToString() + + return app diff --git a/tests/executor_service.py b/tests/executor_service.py new file mode 100644 index 00000000..8bba0c5a --- /dev/null +++ b/tests/executor_service.py @@ -0,0 +1,157 @@ +import httpx +import grpc +import ring.coroutine.v1.coroutine_pb2 +import ring.coroutine.v1.coroutine_pb2_grpc + + +class UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): + def __init__(self, client, method, request_serializer, response_deserializer): + self.client = client + self.method = method + self.request_serializer = request_serializer + self.response_deserializer = response_deserializer + + def __call__( + self, + request, + timeout=None, + metadata=None, + credentials=None, + wait_for_ready=None, + compression=None, + ): + """Synchronously invokes the underlying RPC. + + Args: + request: The request value for the RPC. + timeout: An optional duration of time in seconds to allow + for the RPC. + metadata: Optional :term:`metadata` to be transmitted to the + service-side of the RPC. + credentials: An optional CallCredentials for the RPC. Only valid for + secure Channel. + wait_for_ready: An optional flag to enable :term:`wait_for_ready` mechanism. + compression: An element of grpc.compression, e.g. + grpc.compression.Gzip. + + Returns: + The response value for the RPC. + + Raises: + RpcError: Indicating that the RPC terminated with non-OK status. The + raised RpcError will also be a Call for the RPC affording the RPC's + metadata, status code, and details. + """ + + response = self.client.post( + self.method, + content=self.request_serializer(request), + headers={"Content-Type": "application/grpc+proto"}, + ) + response.raise_for_status() + return self.response_deserializer(response.content) + + def with_call( + self, + request, + timeout=None, + metadata=None, + credentials=None, + wait_for_ready=None, + compression=None, + ): + raise NotImplementedError() + + def future( + self, + request, + timeout=None, + metadata=None, + credentials=None, + wait_for_ready=None, + compression=None, + ): + raise NotImplementedError() + + +class HttpxGrpcChannel(grpc.Channel): + def __init__(self, http_client: httpx.Client): + self.http_client = http_client + + def subscribe(self, callback, try_to_connect=False): + raise NotImplementedError() + + def unsubscribe(self, callback): + raise NotImplementedError() + + def unary_unary(self, method, request_serializer=None, response_deserializer=None): + """Creates a UnaryUnaryMultiCallable for a unary-unary method. + + Args: + method: The name of the RPC method. + request_serializer: Optional :term:`serializer` for serializing the request + message. Request goes unserialized in case None is passed. + response_deserializer: Optional :term:`deserializer` for deserializing the + response message. Response goes undeserialized in case None + is passed. + + Returns: + A UnaryUnaryMultiCallable value for the named unary-unary method. + """ + return UnaryUnaryMultiCallable( + self.http_client, method, request_serializer, response_deserializer + ) + + def unary_stream(self, method, request_serializer=None, response_deserializer=None): + raise NotImplementedError() + + def stream_unary(self, method, request_serializer=None, response_deserializer=None): + raise NotImplementedError() + + def stream_stream( + self, method, request_serializer=None, response_deserializer=None + ): + raise NotImplementedError() + + def close(self): + raise NotImplementedError() + + def __enter__(self): + raise NotImplementedError() + + def __exit__(self, exc_type, exc_val, exc_tb): + raise NotImplementedError() + + +def client( + http_client: httpx.Client, +) -> ring.coroutine.v1.coroutine_pb2_grpc.ExecutorServiceStub: + channel = HttpxGrpcChannel(http_client) + return ring.coroutine.v1.coroutine_pb2_grpc.ExecutorServiceStub(channel) + + +# class GrpcHttpxClient: +# """Client for the ring.coroutine.v1.ExecutorService gRPC service over an +# httpx client. +# """ + +# def __init__(self, http_client: httpx.Client): +# self.http_client = http_client + +# def execute(self, request: ring.coroutine.v1.coroutine_pb2.ExecuteRequest) -> ring.coroutine.v1.coroutine_pb2.ExecuteResponse: +# """Execute a coroutine. + +# Args: +# request: The request to execute. + +# Returns: +# The response from the coroutine. +# """ + +# response = self.http_client.post( +# "/ring.coroutine.v1.ExecutorService/Execute", +# content=request.SerializeToString(), +# headers={"Content-Type": "application/grpc+proto"}, +# ) +# response.raise_for_status() +# return ring.coroutine.v1.coroutine_pb2.ExecuteResponse.FromString(response.content) diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index 348ba06d..2f469469 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -4,9 +4,12 @@ import fastapi from fastapi.testclient import TestClient +import ring.coroutine.v1.coroutine_pb2 +from . import executor_service + class TestFastAPI(unittest.TestCase): - def test_fastapi(self): + def test_configure(self): app = fastapi.FastAPI() dispatch.fastapi.configure(app, api_key="test-key") @@ -25,3 +28,39 @@ def read_root(): resp = client.get("/dispatch/") self.assertEqual(resp.status_code, 200) self.assertEqual(resp.text, "ok") + + def test_configure_no_app(self): + with self.assertRaises(ValueError): + dispatch.fastapi.configure(None, api_key="test-key") + + def test_configure_no_api_key(self): + app = fastapi.FastAPI() + with self.assertRaises(ValueError): + dispatch.fastapi.configure(app, api_key=None) + + def test_configure_no_api_url(self): + app = fastapi.FastAPI() + with self.assertRaises(ValueError): + dispatch.fastapi.configure(app, api_key="test-key", api_url=None) + + def test_configure_no_mount_path(self): + app = fastapi.FastAPI() + with self.assertRaises(ValueError): + dispatch.fastapi.configure(app, api_key="test-key", mount_path=None) + + def test_fastapi_empty_request(self): + app = dispatch.fastapi._new_app() + http_client = TestClient(app) + + client = executor_service.client(http_client) + + req = ring.coroutine.v1.coroutine_pb2.ExecuteRequest( + coroutine_uri="my-cool-coroutine", + coroutine_version="1", + ) + + resp = client.Execute(req) + + self.assertIsInstance(resp, ring.coroutine.v1.coroutine_pb2.ExecuteResponse) + self.assertEqual(resp.coroutine_uri, req.coroutine_uri) + self.assertEqual(resp.coroutine_version, req.coroutine_version) diff --git a/tests/test_ring.py b/tests/test_ring.py index fbdd7908..4acb14b1 100644 --- a/tests/test_ring.py +++ b/tests/test_ring.py @@ -35,7 +35,6 @@ def tearDown(self): self.thread_pool.shutdown(wait=True, cancel_futures=True) def test_ring(self): - request = dispatch.http.v1.http_pb2.Request( url="https://www.google.com", method="GET" )