Skip to content

Commit

Permalink
support both asyncio and blocking modes with different abstractions
Browse files Browse the repository at this point in the history
Signed-off-by: Achille Roussel <[email protected]>
  • Loading branch information
achille-roussel committed Jun 18, 2024
1 parent d448489 commit e340938
Show file tree
Hide file tree
Showing 11 changed files with 220 additions and 122 deletions.
2 changes: 1 addition & 1 deletion src/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@

import dispatch.integrations
from dispatch.coroutine import all, any, call, gather, race
from dispatch.function import AsyncFunction as Function
from dispatch.function import (
Batch,
Client,
ClientError,
Function,
Registry,
Reset,
default_registry,
Expand Down
File renamed without changes.
108 changes: 108 additions & 0 deletions src/dispatch/asyncio/fastapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""Integration of Dispatch functions with FastAPI for handlers using asyncio.
Example:
import fastapi
from dispatch.asyncio.fastapi import Dispatch
app = fastapi.FastAPI()
dispatch = Dispatch(app)
@dispatch.function
def my_function():
return "Hello World!"
@app.get("/")
async def read_root():
await my_function.dispatch()
"""

import logging
from typing import Optional, Union

import fastapi
import fastapi.responses

from dispatch.function import Registry
from dispatch.http import (
AsyncFunctionService,
FunctionServiceError,
validate_content_length,
)
from dispatch.signature import Ed25519PublicKey, parse_verification_key

logger = logging.getLogger(__name__)


class Dispatch(AsyncFunctionService):
"""A Dispatch instance, powered by FastAPI."""

def __init__(
self,
app: fastapi.FastAPI,
registry: Optional[Registry] = None,
verification_key: Optional[Union[Ed25519PublicKey, str, bytes]] = None,
):
"""Initialize a Dispatch endpoint, and integrate it into a FastAPI app.
It mounts a sub-app that implements the Dispatch gRPC interface.
Args:
app: The FastAPI app to configure.
registry: A registry of functions to expose. If omitted, the default
registry is used.
verification_key: Key to use when verifying signed requests. Uses
the value of the DISPATCH_VERIFICATION_KEY environment variable
if omitted. The environment variable is expected to carry an
Ed25519 public key in base64 or PEM format.
If not set, request signature verification is disabled (a warning
will be logged by the constructor).
Raises:
ValueError: If any of the required arguments are missing.
"""
if not app:
raise ValueError(
"missing FastAPI app as first argument of the Dispatch constructor"
)
super().__init__(registry, verification_key)
function_service = fastapi.FastAPI()

@function_service.exception_handler(FunctionServiceError)
async def on_error(request: fastapi.Request, exc: FunctionServiceError):
# https://connectrpc.com/docs/protocol/#error-end-stream
return fastapi.responses.JSONResponse(
status_code=exc.status,
content={"code": exc.code, "message": exc.message},
)

@function_service.post(
# The endpoint for execution is hardcoded at the moment. If the service
# gains more endpoints, this should be turned into a dynamic dispatch
# like the official gRPC server does.
"/Run",
)
async def run(request: fastapi.Request):
valid, reason = validate_content_length(
int(request.headers.get("content-length", 0))
)
if not valid:
raise FunctionServiceError(400, "invalid_argument", reason)

# Raw request body bytes are only available through the underlying
# starlette Request object's body method, which returns an awaitable,
# forcing execute() to be async.
data: bytes = await request.body()

content = await self.run(
str(request.url),
request.method,
request.headers,
await request.body(),
)

return fastapi.Response(content=content, media_type="application/proto")

app.mount("/dispatch.sdk.v1.FunctionService", function_service)
4 changes: 2 additions & 2 deletions src/dispatch/experimental/lambda_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ def handler(event, context):
from awslambdaric.lambda_context import LambdaContext

from dispatch.function import Registry
from dispatch.http import FunctionService
from dispatch.http import BlockingFunctionService
from dispatch.proto import Input
from dispatch.sdk.v1 import function_pb2 as function_pb
from dispatch.status import Status

logger = logging.getLogger(__name__)


class Dispatch(FunctionService):
class Dispatch(BlockingFunctionService):
def __init__(
self,
registry: Optional[Registry] = None,
Expand Down
95 changes: 17 additions & 78 deletions src/dispatch/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,90 +15,29 @@ def my_function():
@app.get("/")
def read_root():
my_function.dispatch()
"""
"""

import logging
from typing import Optional, Union
from typing import Any, Callable, Coroutine, TypeVar, overload

import fastapi
import fastapi.responses
from typing_extensions import ParamSpec

from dispatch.function import Registry
from dispatch.http import FunctionService, FunctionServiceError, validate_content_length
from dispatch.signature import Ed25519PublicKey, parse_verification_key
from dispatch.asyncio.fastapi import Dispatch as AsyncDispatch
from dispatch.function import BlockingFunction

logger = logging.getLogger(__name__)
__all__ = ["Dispatch", "AsyncDispatch"]

P = ParamSpec("P")
T = TypeVar("T")

class Dispatch(FunctionService):
"""A Dispatch instance, powered by FastAPI."""

def __init__(
self,
app: fastapi.FastAPI,
registry: Optional[Registry] = None,
verification_key: Optional[Union[Ed25519PublicKey, str, bytes]] = None,
):
"""Initialize a Dispatch endpoint, and integrate it into a FastAPI app.
class Dispatch(AsyncDispatch):
@overload # type: ignore
def function(self, func: Callable[P, T]) -> BlockingFunction[P, T]: ...

It mounts a sub-app that implements the Dispatch gRPC interface.
@overload # type: ignore
def function(
self, func: Callable[P, Coroutine[Any, Any, T]]
) -> BlockingFunction[P, T]: ...

Args:
app: The FastAPI app to configure.
registry: A registry of functions to expose. If omitted, the default
registry is used.
verification_key: Key to use when verifying signed requests. Uses
the value of the DISPATCH_VERIFICATION_KEY environment variable
if omitted. The environment variable is expected to carry an
Ed25519 public key in base64 or PEM format.
If not set, request signature verification is disabled (a warning
will be logged by the constructor).
Raises:
ValueError: If any of the required arguments are missing.
"""
if not app:
raise ValueError(
"missing FastAPI app as first argument of the Dispatch constructor"
)
super().__init__(registry, verification_key)
function_service = fastapi.FastAPI()

@function_service.exception_handler(FunctionServiceError)
async def on_error(request: fastapi.Request, exc: FunctionServiceError):
# https://connectrpc.com/docs/protocol/#error-end-stream
return fastapi.responses.JSONResponse(
status_code=exc.status,
content={"code": exc.code, "message": exc.message},
)

@function_service.post(
# The endpoint for execution is hardcoded at the moment. If the service
# gains more endpoints, this should be turned into a dynamic dispatch
# like the official gRPC server does.
"/Run",
)
async def run(request: fastapi.Request):
valid, reason = validate_content_length(
int(request.headers.get("content-length", 0))
)
if not valid:
raise FunctionServiceError(400, "invalid_argument", reason)

# Raw request body bytes are only available through the underlying
# starlette Request object's body method, which returns an awaitable,
# forcing execute() to be async.
data: bytes = await request.body()

content = await self.run(
str(request.url),
request.method,
request.headers,
await request.body(),
)

return fastapi.Response(content=content, media_type="application/proto")

app.mount("/dispatch.sdk.v1.FunctionService", function_service)
def function(self, func):
return BlockingFunction(super().function(func))
8 changes: 6 additions & 2 deletions src/dispatch/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,17 @@ def read_root():
from flask import Flask, make_response, request

from dispatch.function import Registry
from dispatch.http import FunctionService, FunctionServiceError, validate_content_length
from dispatch.http import (
BlockingFunctionService,
FunctionServiceError,
validate_content_length,
)
from dispatch.signature import Ed25519PublicKey, parse_verification_key

logger = logging.getLogger(__name__)


class Dispatch(FunctionService):
class Dispatch(BlockingFunctionService):
"""A Dispatch instance, powered by Flask."""

def __init__(
Expand Down
54 changes: 44 additions & 10 deletions src/dispatch/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Optional,
Tuple,
TypeVar,
Union,
overload,
)
from urllib.parse import urlparse
Expand Down Expand Up @@ -111,7 +112,7 @@ def _build_primitive_call(
)


class Function(PrimitiveFunction, Generic[P, T]):
class AsyncFunction(PrimitiveFunction, Generic[P, T]):
"""Callable wrapper around a function meant to be used throughout the
Dispatch Python SDK.
"""
Expand Down Expand Up @@ -157,7 +158,7 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, T]:
else:
return self._call_dispatch(*args, **kwargs)

def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID:
async def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID:
"""Dispatch an asynchronous call to the function without
waiting for a result.
Expand All @@ -171,7 +172,7 @@ def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID:
Returns:
DispatchID: ID of the dispatched call.
"""
return asyncio.run(self._primitive_dispatch(Arguments(args, kwargs)))
return await self._primitive_dispatch(Arguments(args, kwargs))

def build_call(self, *args: P.args, **kwargs: P.kwargs) -> Call:
"""Create a Call for this function with the provided input. Useful to
Expand All @@ -187,11 +188,38 @@ def build_call(self, *args: P.args, **kwargs: P.kwargs) -> Call:
return self._build_primitive_call(Arguments(args, kwargs))


class BlockingFunction(Generic[P, T]):
"""BlockingFunction is like Function but exposes a blocking API instead of
functions that use asyncio.
Applications typically don't create instances of BlockingFunction directly,
and instead use decorators from packages that provide integrations with
Python frameworks.
"""

def __init__(self, func: AsyncFunction[P, T]):
self._func = func

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
return asyncio.run(self._func(*args, **kwargs))

def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID:
return asyncio.run(self._func.dispatch(*args, **kwargs))

def build_call(self, *args: P.args, **kwargs: P.kwargs) -> Call:
return self._func.build_call(*args, **kwargs)


class Reset(TailCall):
"""The current coroutine is aborted and scheduling reset to be replaced with
the call embedded in this exception."""

def __init__(self, func: Function[P, T], *args: P.args, **kwargs: P.kwargs):
def __init__(
self,
func: Union[AsyncFunction[P, T], BlockingFunction[P, T]],
*args: P.args,
**kwargs: P.kwargs,
):
super().__init__(call=func.build_call(*args, **kwargs))


Expand Down Expand Up @@ -267,10 +295,12 @@ def endpoint(self, value: str):
self._endpoint = value

@overload
def function(self, func: Callable[P, Coroutine[Any, Any, T]]) -> Function[P, T]: ...
def function(
self, func: Callable[P, Coroutine[Any, Any, T]]
) -> AsyncFunction[P, T]: ...

@overload
def function(self, func: Callable[P, T]) -> Function[P, T]: ...
def function(self, func: Callable[P, T]) -> AsyncFunction[P, T]: ...

def function(self, func):
"""Decorator that registers functions."""
Expand All @@ -283,7 +313,9 @@ def function(self, func):
logger.info("registering coroutine: %s", name)
return self._register_coroutine(name, func)

def _register_function(self, name: str, func: Callable[P, T]) -> Function[P, T]:
def _register_function(
self, name: str, func: Callable[P, T]
) -> AsyncFunction[P, T]:
func = durable(func)

@wraps(func)
Expand All @@ -296,7 +328,7 @@ async def asyncio_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:

def _register_coroutine(
self, name: str, func: Callable[P, Coroutine[Any, Any, T]]
) -> Function[P, T]:
) -> AsyncFunction[P, T]:
logger.info("registering coroutine: %s", name)
func = durable(func)

Expand All @@ -307,7 +339,7 @@ async def primitive_func(input: Input) -> Output:
primitive_func.__qualname__ = f"{name}_primitive"
durable_primitive_func = durable(primitive_func)

wrapped_func = Function[P, T](
wrapped_func = AsyncFunction[P, T](
self,
name,
durable_primitive_func,
Expand Down Expand Up @@ -555,7 +587,9 @@ def __init__(self, client: Client):
self.client = client
self.calls = []

def add(self, func: Function[P, T], *args: P.args, **kwargs: P.kwargs) -> Batch:
def add(
self, func: AsyncFunction[P, T], *args: P.args, **kwargs: P.kwargs
) -> Batch:
"""Add a call to the specified function to the batch."""
return self.add_call(func.build_call(*args, **kwargs))

Expand Down
Loading

0 comments on commit e340938

Please sign in to comment.