Skip to content

Commit

Permalink
Allow for a public URL + propagate to URIs
Browse files Browse the repository at this point in the history
  • Loading branch information
pelletier committed Feb 1, 2024
1 parent 04e6f81 commit d09ce78
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 32 deletions.
16 changes: 7 additions & 9 deletions src/dispatch/coroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@

from __future__ import annotations
import enum
from typing import Any
import pickle
from typing import Any, Callable
from dataclasses import dataclass

import google.protobuf.message
import pickle

from ring.coroutine.v1 import coroutine_pb2
from ring.status.v1 import status_pb2

Expand Down Expand Up @@ -77,15 +79,16 @@ class Coroutine:
"""Callable wrapper around a function meant to be used throughout the
Dispatch Python SDK."""

def __init__(self, func):
def __init__(self, uri: str, func: Callable[[Input], Output]):
self._uri = uri
self._func = func

def __call__(self, *args, **kwargs):
return self._func(*args, **kwargs)

@property
def uri(self) -> str:
return self._func.__qualname__
return self._uri

def call_with(self, input: Any, correlation_id: int | None = None) -> Call:
"""Create a Call of this coroutine with the provided input. Useful to
Expand Down Expand Up @@ -362,8 +365,3 @@ def _pb_any_pickle(x: Any) -> google.protobuf.any_pb2.Any:
pb_any = google.protobuf.any_pb2.Any()
pb_any.Pack(pb_bytes)
return pb_any


def _coroutine_uri_to_qualname(coroutine_uri: str) -> str:
# TODO: fix this when we decide on the format of coroutine URIs.
return coroutine_uri.split("/")[-1]
51 changes: 37 additions & 14 deletions src/dispatch/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,21 @@ def read_root():
my_cool_coroutine.call()
"""

import ring.coroutine.v1.coroutine_pb2
from collections.abc import Callable
from typing import Any
import os
from typing import Any, Dict
from collections.abc import Callable

import fastapi
import fastapi.responses
from httpx import _urlparse

import ring.coroutine.v1.coroutine_pb2
import dispatch.coroutine


def configure(
app: fastapi.FastAPI,
public_url: str,
api_key: None | str = None,
):
"""Configure the FastAPI app to use Dispatch programmable endpoints.
Expand All @@ -39,6 +43,8 @@ def configure(
app: The FastAPI app to configure.
api_key: Dispatch API key to use for authentication. Uses the value of
the DISPATCH_API_KEY environment variable by default.
public_url: Full URL of the application the dispatch programmable
endpoint will be running on.
Raises:
ValueError: If any of the required arguments are missing.
Expand All @@ -47,19 +53,26 @@ def configure(

if not app:
raise ValueError("app is required")
if not public_url:
raise ValueError("public_url is required")
if not api_key:
raise ValueError("api_key is required")

dispatch_app = _new_app()
parsed_url = _urlparse.urlparse(public_url)
if not parsed_url.netloc or not parsed_url.scheme:
raise ValueError("public_url must be a full URL with protocol and domain")

dispatch_app = _new_app(public_url)

app.__setattr__("dispatch_coroutine", dispatch_app.dispatch_coroutine)
app.mount("/ring.coroutine.v1.ExecutorService", dispatch_app)


class _DispatchAPI(fastapi.FastAPI):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._coroutines = {}
def __init__(self, public_url: str):
super().__init__()
self._coroutines: Dict[str, dispatch.coroutine.Coroutine] = {}
self._public_url = _urlparse.urlparse(public_url)

def dispatch_coroutine(self):
"""Register a coroutine with the Dispatch programmable endpoints.
Expand All @@ -73,7 +86,9 @@ def dispatch_coroutine(self):
"""

def wrap(func: Callable[[dispatch.coroutine.Input], dispatch.coroutine.Output]):
coro = dispatch.coroutine.Coroutine(func)
name = func.__qualname__
uri = str(self._public_url.copy_with(fragment="function=" + name))
coro = dispatch.coroutine.Coroutine(uri, func)
if coro.uri in self._coroutines:
raise ValueError(f"Coroutine {coro.uri} already registered")
self._coroutines[coro.uri] = coro
Expand All @@ -86,9 +101,8 @@ class _GRPCResponse(fastapi.Response):
media_type = "application/grpc+proto"


def _new_app():
app = _DispatchAPI()
app._coroutines = {}
def _new_app(public_url: str):
app = _DispatchAPI(public_url)

@app.post(
# The endpoint for execution is hardcoded at the moment. If the service
Expand All @@ -112,9 +126,18 @@ async def execute(request: fastapi.Request):

# TODO: be more graceful. This will crash if the coroutine is not found,
# and the coroutine version is not taken into account.
coroutine = app._coroutines[
dispatch.coroutine._coroutine_uri_to_qualname(req.coroutine_uri)
]

uri = req.coroutine_uri

coroutine = app._coroutines.get(uri, None)
if coroutine is None:
# TODO: integrate with logging
print("Coroutine not found:")
print(" uri:", uri)
print("Available coroutines:")
for k in app._coroutines:
print(" ", k)
raise KeyError(f"coroutine '{uri}' not available on this system")

coro_input = dispatch.coroutine.Input(req)

Expand Down
4 changes: 1 addition & 3 deletions tests/task_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ def execute(self, client: coroutine_grpc.ExecutorServiceStub):
pending task left.
"""
while True:
if len(self.pending_tasks) == 0:
return
while len(self.pending_tasks) > 0:
entry = self.pending_tasks.pop(0)
task = entry["task"]

Expand Down
25 changes: 20 additions & 5 deletions tests/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ class TestFastAPI(unittest.TestCase):
def test_configure(self):
app = fastapi.FastAPI()

dispatch.fastapi.configure(app, api_key="test-key")
dispatch.fastapi.configure(
app, api_key="test-key", public_url="https://127.0.0.1:9999"
)

@app.get("/")
def read_root():
Expand All @@ -34,16 +36,27 @@ def read_root():

def test_configure_no_app(self):
with self.assertRaises(ValueError):
dispatch.fastapi.configure(None, api_key="test-key")
dispatch.fastapi.configure(
None, api_key="test-key", public_url="http://127.0.0.1:9999"
)

def test_configure_no_api_key(self):
app = fastapi.FastAPI()
with self.assertRaises(ValueError):
dispatch.fastapi.configure(app, api_key=None)
dispatch.fastapi.configure(
app, api_key=None, public_url="http://127.0.0.1:9999"
)

def test_configure_no_public_url(self):
app = fastapi.FastAPI()
with self.assertRaises(ValueError):
dispatch.fastapi.configure(app, api_key="test", public_url="")

def test_fastapi_simple_request(self):
app = fastapi.FastAPI()
dispatch.fastapi.configure(app, api_key="test-key")
dispatch.fastapi.configure(
app, api_key="test-key", public_url="http://127.0.0.1:9999/"
)

@app.dispatch_coroutine()
def my_cool_coroutine(input: Input) -> Output:
Expand Down Expand Up @@ -86,7 +99,9 @@ def response_output(resp: coroutine_pb2.ExecuteResponse) -> Any:
class TestCoroutine(unittest.TestCase):
def setUp(self):
self.app = fastapi.FastAPI()
dispatch.fastapi.configure(self.app, api_key="test-key")
dispatch.fastapi.configure(
self.app, api_key="test-key", public_url="https://127.0.0.1:9999"
)
http_client = TestClient(self.app)
self.client = executor_service.client(http_client)

Expand Down
4 changes: 3 additions & 1 deletion tests/test_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
class TestFullFastapi(unittest.TestCase):
def setUp(self):
self.app = fastapi.FastAPI()
dispatch.fastapi.configure(self.app, api_key="test-key")
dispatch.fastapi.configure(
self.app, api_key="test-key", public_url="http://test"
)
http_client = TestClient(self.app)
self.app_client = executor_service.client(http_client)
self.server = ServerTest()
Expand Down

0 comments on commit d09ce78

Please sign in to comment.