Skip to content

Commit

Permalink
Add first route
Browse files Browse the repository at this point in the history
  • Loading branch information
pelletier committed Jan 30, 2024
1 parent 68940f4 commit c15d708
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 15 deletions.
74 changes: 67 additions & 7 deletions src/dispatch/fastapi.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,25 @@
"""Integration of Dispatch programmable endpoints for FastAPI.
Example:
import fastapi
import dispatch.fastapi
app = fastapi.FastAPI()
dispatch.fastapi.configure(app, api_key="test-key")
@app.dispatch_coroutine()
def my_cool_coroutine():
return "Hello World!"
@app.get("/")
def read_root():
my_cool_coroutine.call()
"""

import ring.coroutine.v1.coroutine_pb2

from collections.abc import Callable
from typing import Any
import os
import fastapi
import fastapi.responses
Expand All @@ -18,7 +34,8 @@ def configure(
"""Configure the FastAPI app to use Dispatch programmable endpoints.
It mounts a sub-app at the given mount path that implements the Dispatch
interface.
interface. It also adds a a decorator named @app.dispatch_coroutine() to
register coroutines.
Args:
app: The FastAPI app to configure.
Expand All @@ -40,37 +57,80 @@ def configure(

dispatch_app = _new_app()

app.__setattr__("dispatch_coroutine", dispatch_app.dispatch_coroutine)
app.mount(mount_path, dispatch_app)


class _DispatchAPI(fastapi.FastAPI):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._coroutines = {}

def dispatch_coroutine(self):
"""Register a coroutine with the Dispatch programmable endpoints.
Args:
app: The FastAPI app to register the coroutine with.
coroutine: The coroutine to register.
Raises:
ValueError: If the coroutine is already registered.
"""

def wrap(coroutine: Callable[..., Any]):
if coroutine.__qualname__ in self._coroutines:
raise ValueError(
f"Coroutine {coroutine.__qualname__} already registered"
)
self._coroutines[coroutine.__qualname__] = coroutine
return coroutine

return wrap


class _GRPCResponse(fastapi.Response):
media_type = "application/grpc+proto"


def _coroutine_uri_to_function_name(coroutine_uri: str) -> str:
return coroutine_uri.split(":")[-1]
def _coroutine_uri_to_qualname(coroutine_uri: str) -> str:
return coroutine_uri.split("/")[-1]


def _new_app():
app = fastapi.FastAPI()
app = _DispatchAPI()
app._coroutines = {}

@app.get("/", response_class=fastapi.responses.PlainTextResponse)
def read_root():
return "ok"

@app.post("/ring.coroutine.v1.ExecutorService/Execute", response_class=_GRPCResponse)
@app.post(
"/ring.coroutine.v1.ExecutorService/Execute", response_class=_GRPCResponse
)
async def execute(request: fastapi.Request):
data: bytes = await request.body()

req = ring.coroutine.v1.coroutine_pb2.ExecuteRequest.FromString(data)

coroutine = app._coroutines[_coroutine_uri_to_qualname(req.coroutine_uri)]

# TODO: unpack any
input = google.protobuf.wrappers_pb2.StringValue
input = google.protobuf.wrappers_pb2.StringValue()
req.input.Unpack(input)

output = coroutine(input.value)

# TODO pack any
output_pb = google.protobuf.wrappers_pb2.StringValue(value=output)
output_any = google.protobuf.any_pb2.Any()
output_any.Pack(output_pb)

resp = ring.coroutine.v1.coroutine_pb2.ExecuteResponse(
coroutine_uri=req.coroutine_uri,
coroutine_version=req.coroutine_version,
exit=ring.coroutine.v1.coroutine_pb2.Exit(
result=ring.coroutine.v1.coroutine_pb2.Result(output=output_any)
),
)

return resp.SerializeToString()
Expand Down
24 changes: 16 additions & 8 deletions tests/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import dispatch.fastapi
import fastapi
from fastapi.testclient import TestClient

import google.protobuf.wrappers_pb2
import ring.coroutine.v1.coroutine_pb2
from . import executor_service

Expand Down Expand Up @@ -38,29 +38,37 @@ def test_configure_no_api_key(self):
with self.assertRaises(ValueError):
dispatch.fastapi.configure(app, api_key=None)

def test_configure_no_api_url(self):
app = fastapi.FastAPI()
with self.assertRaises(ValueError):
dispatch.fastapi.configure(app, api_key="test-key", api_url=None)

def test_configure_no_mount_path(self):
app = fastapi.FastAPI()
with self.assertRaises(ValueError):
dispatch.fastapi.configure(app, api_key="test-key", mount_path=None)

def test_fastapi_empty_request(self):
def test_fastapi_simple_request(self):
app = dispatch.fastapi._new_app()

@app.dispatch_coroutine()
def my_cool_coroutine(input):
return f"You told me: '{input}' ({len(input)} characters)"

http_client = TestClient(app)

client = executor_service.client(http_client)

input_any = google.protobuf.any_pb2.Any()
input_any.Pack(google.protobuf.wrappers_pb2.StringValue(value="Hello World!"))
req = ring.coroutine.v1.coroutine_pb2.ExecuteRequest(
coroutine_uri="my-cool-coroutine",
coroutine_uri=my_cool_coroutine.__qualname__,
coroutine_version="1",
input=input_any,
)

resp = client.Execute(req)

self.assertIsInstance(resp, ring.coroutine.v1.coroutine_pb2.ExecuteResponse)
self.assertEqual(resp.coroutine_uri, req.coroutine_uri)
self.assertEqual(resp.coroutine_version, req.coroutine_version)

resp.exit.result.output.Unpack(
output := google.protobuf.wrappers_pb2.StringValue()
)
self.assertEqual(output.value, "You told me: 'Hello World!' (12 characters)")

0 comments on commit c15d708

Please sign in to comment.