Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support both asyncio and blocking modes with different abstractions #178

Merged
merged 1 commit into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@

import dispatch.integrations
from dispatch.coroutine import all, any, call, gather, race
from dispatch.function import AsyncFunction as Function
from dispatch.function import (
Batch,
Client,
ClientError,
Function,
Registry,
Reset,
default_registry,
Expand Down
File renamed without changes.
108 changes: 108 additions & 0 deletions src/dispatch/asyncio/fastapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""Integration of Dispatch functions with FastAPI for handlers using asyncio.

Example:

import fastapi
from dispatch.asyncio.fastapi import Dispatch
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FastAPI wraps Starlette which is an ASGI framework. Should the asynchronous interface be the default for FastAPI (and should the user opt-in to the synchronous/blocking API)?

Copy link
Contributor Author

@achille-roussel achille-roussel Jun 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I asked myself the same question... I know FastAPI is natively async, but their home page shows examples with non-async code (https://fastapi.tiangolo.com/), which makes me think it's probably a popular default.

The package name with asyncio is also self-explanatory, but using something like dispatch.blocking.fastapi would somewhat look like inventing a new name that may not be as clear to developers unfamiliar with the concept.

Tough call 🤷


app = fastapi.FastAPI()
dispatch = Dispatch(app)

@dispatch.function
def my_function():
return "Hello World!"

@app.get("/")
async def read_root():
await my_function.dispatch()
"""

import logging
from typing import Optional, Union

import fastapi
import fastapi.responses

from dispatch.function import Registry
from dispatch.http import (
AsyncFunctionService,
FunctionServiceError,
validate_content_length,
)
from dispatch.signature import Ed25519PublicKey, parse_verification_key

logger = logging.getLogger(__name__)


class Dispatch(AsyncFunctionService):
"""A Dispatch instance, powered by FastAPI."""

def __init__(
self,
app: fastapi.FastAPI,
registry: Optional[Registry] = None,
verification_key: Optional[Union[Ed25519PublicKey, str, bytes]] = None,
):
"""Initialize a Dispatch endpoint, and integrate it into a FastAPI app.

It mounts a sub-app that implements the Dispatch gRPC interface.

Args:
app: The FastAPI app to configure.

registry: A registry of functions to expose. If omitted, the default
registry is used.

verification_key: Key to use when verifying signed requests. Uses
the value of the DISPATCH_VERIFICATION_KEY environment variable
if omitted. The environment variable is expected to carry an
Ed25519 public key in base64 or PEM format.
If not set, request signature verification is disabled (a warning
will be logged by the constructor).

Raises:
ValueError: If any of the required arguments are missing.
"""
if not app:
raise ValueError(
"missing FastAPI app as first argument of the Dispatch constructor"
)
super().__init__(registry, verification_key)
function_service = fastapi.FastAPI()

@function_service.exception_handler(FunctionServiceError)
async def on_error(request: fastapi.Request, exc: FunctionServiceError):
# https://connectrpc.com/docs/protocol/#error-end-stream
return fastapi.responses.JSONResponse(
status_code=exc.status,
content={"code": exc.code, "message": exc.message},
)

@function_service.post(
# The endpoint for execution is hardcoded at the moment. If the service
# gains more endpoints, this should be turned into a dynamic dispatch
# like the official gRPC server does.
"/Run",
)
async def run(request: fastapi.Request):
valid, reason = validate_content_length(
int(request.headers.get("content-length", 0))
)
if not valid:
raise FunctionServiceError(400, "invalid_argument", reason)

# Raw request body bytes are only available through the underlying
# starlette Request object's body method, which returns an awaitable,
# forcing execute() to be async.
data: bytes = await request.body()

content = await self.run(
str(request.url),
request.method,
request.headers,
await request.body(),
)

return fastapi.Response(content=content, media_type="application/proto")

app.mount("/dispatch.sdk.v1.FunctionService", function_service)
4 changes: 2 additions & 2 deletions src/dispatch/experimental/lambda_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ def handler(event, context):
from awslambdaric.lambda_context import LambdaContext

from dispatch.function import Registry
from dispatch.http import FunctionService
from dispatch.http import BlockingFunctionService
from dispatch.proto import Input
from dispatch.sdk.v1 import function_pb2 as function_pb
from dispatch.status import Status

logger = logging.getLogger(__name__)


class Dispatch(FunctionService):
class Dispatch(BlockingFunctionService):
def __init__(
self,
registry: Optional[Registry] = None,
Expand Down
95 changes: 17 additions & 78 deletions src/dispatch/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,90 +15,29 @@ def my_function():
@app.get("/")
def read_root():
my_function.dispatch()
"""
"""

import logging
from typing import Optional, Union
from typing import Any, Callable, Coroutine, TypeVar, overload

import fastapi
import fastapi.responses
from typing_extensions import ParamSpec

from dispatch.function import Registry
from dispatch.http import FunctionService, FunctionServiceError, validate_content_length
from dispatch.signature import Ed25519PublicKey, parse_verification_key
from dispatch.asyncio.fastapi import Dispatch as AsyncDispatch
from dispatch.function import BlockingFunction

logger = logging.getLogger(__name__)
__all__ = ["Dispatch", "AsyncDispatch"]

P = ParamSpec("P")
T = TypeVar("T")

class Dispatch(FunctionService):
"""A Dispatch instance, powered by FastAPI."""

def __init__(
self,
app: fastapi.FastAPI,
registry: Optional[Registry] = None,
verification_key: Optional[Union[Ed25519PublicKey, str, bytes]] = None,
):
"""Initialize a Dispatch endpoint, and integrate it into a FastAPI app.
class Dispatch(AsyncDispatch):
@overload # type: ignore
def function(self, func: Callable[P, T]) -> BlockingFunction[P, T]: ...

It mounts a sub-app that implements the Dispatch gRPC interface.
@overload # type: ignore
def function(
self, func: Callable[P, Coroutine[Any, Any, T]]
) -> BlockingFunction[P, T]: ...

Args:
app: The FastAPI app to configure.

registry: A registry of functions to expose. If omitted, the default
registry is used.

verification_key: Key to use when verifying signed requests. Uses
the value of the DISPATCH_VERIFICATION_KEY environment variable
if omitted. The environment variable is expected to carry an
Ed25519 public key in base64 or PEM format.
If not set, request signature verification is disabled (a warning
will be logged by the constructor).

Raises:
ValueError: If any of the required arguments are missing.
"""
if not app:
raise ValueError(
"missing FastAPI app as first argument of the Dispatch constructor"
)
super().__init__(registry, verification_key)
function_service = fastapi.FastAPI()

@function_service.exception_handler(FunctionServiceError)
async def on_error(request: fastapi.Request, exc: FunctionServiceError):
# https://connectrpc.com/docs/protocol/#error-end-stream
return fastapi.responses.JSONResponse(
status_code=exc.status,
content={"code": exc.code, "message": exc.message},
)

@function_service.post(
# The endpoint for execution is hardcoded at the moment. If the service
# gains more endpoints, this should be turned into a dynamic dispatch
# like the official gRPC server does.
"/Run",
)
async def run(request: fastapi.Request):
valid, reason = validate_content_length(
int(request.headers.get("content-length", 0))
)
if not valid:
raise FunctionServiceError(400, "invalid_argument", reason)

# Raw request body bytes are only available through the underlying
# starlette Request object's body method, which returns an awaitable,
# forcing execute() to be async.
data: bytes = await request.body()

content = await self.run(
str(request.url),
request.method,
request.headers,
await request.body(),
)

return fastapi.Response(content=content, media_type="application/proto")

app.mount("/dispatch.sdk.v1.FunctionService", function_service)
def function(self, func):
return BlockingFunction(super().function(func))
8 changes: 6 additions & 2 deletions src/dispatch/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,17 @@ def read_root():
from flask import Flask, make_response, request

from dispatch.function import Registry
from dispatch.http import FunctionService, FunctionServiceError, validate_content_length
from dispatch.http import (
BlockingFunctionService,
FunctionServiceError,
validate_content_length,
)
from dispatch.signature import Ed25519PublicKey, parse_verification_key

logger = logging.getLogger(__name__)


class Dispatch(FunctionService):
class Dispatch(BlockingFunctionService):
"""A Dispatch instance, powered by Flask."""

def __init__(
Expand Down
54 changes: 44 additions & 10 deletions src/dispatch/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Optional,
Tuple,
TypeVar,
Union,
overload,
)
from urllib.parse import urlparse
Expand Down Expand Up @@ -111,7 +112,7 @@ def _build_primitive_call(
)


class Function(PrimitiveFunction, Generic[P, T]):
class AsyncFunction(PrimitiveFunction, Generic[P, T]):
"""Callable wrapper around a function meant to be used throughout the
Dispatch Python SDK.
"""
Expand Down Expand Up @@ -157,7 +158,7 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, T]:
else:
return self._call_dispatch(*args, **kwargs)

def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID:
async def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID:
"""Dispatch an asynchronous call to the function without
waiting for a result.

Expand All @@ -171,7 +172,7 @@ def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID:
Returns:
DispatchID: ID of the dispatched call.
"""
return asyncio.run(self._primitive_dispatch(Arguments(args, kwargs)))
return await self._primitive_dispatch(Arguments(args, kwargs))

def build_call(self, *args: P.args, **kwargs: P.kwargs) -> Call:
"""Create a Call for this function with the provided input. Useful to
Expand All @@ -187,11 +188,38 @@ def build_call(self, *args: P.args, **kwargs: P.kwargs) -> Call:
return self._build_primitive_call(Arguments(args, kwargs))


class BlockingFunction(Generic[P, T]):
"""BlockingFunction is like Function but exposes a blocking API instead of
functions that use asyncio.

Applications typically don't create instances of BlockingFunction directly,
and instead use decorators from packages that provide integrations with
Python frameworks.
"""

def __init__(self, func: AsyncFunction[P, T]):
self._func = func

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
return asyncio.run(self._func(*args, **kwargs))

def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID:
return asyncio.run(self._func.dispatch(*args, **kwargs))

def build_call(self, *args: P.args, **kwargs: P.kwargs) -> Call:
return self._func.build_call(*args, **kwargs)


class Reset(TailCall):
"""The current coroutine is aborted and scheduling reset to be replaced with
the call embedded in this exception."""

def __init__(self, func: Function[P, T], *args: P.args, **kwargs: P.kwargs):
def __init__(
self,
func: Union[AsyncFunction[P, T], BlockingFunction[P, T]],
*args: P.args,
**kwargs: P.kwargs,
):
super().__init__(call=func.build_call(*args, **kwargs))


Expand Down Expand Up @@ -267,10 +295,12 @@ def endpoint(self, value: str):
self._endpoint = value

@overload
def function(self, func: Callable[P, Coroutine[Any, Any, T]]) -> Function[P, T]: ...
def function(
self, func: Callable[P, Coroutine[Any, Any, T]]
) -> AsyncFunction[P, T]: ...

@overload
def function(self, func: Callable[P, T]) -> Function[P, T]: ...
def function(self, func: Callable[P, T]) -> AsyncFunction[P, T]: ...

def function(self, func):
"""Decorator that registers functions."""
Expand All @@ -283,7 +313,9 @@ def function(self, func):
logger.info("registering coroutine: %s", name)
return self._register_coroutine(name, func)

def _register_function(self, name: str, func: Callable[P, T]) -> Function[P, T]:
def _register_function(
self, name: str, func: Callable[P, T]
) -> AsyncFunction[P, T]:
func = durable(func)

@wraps(func)
Expand All @@ -296,7 +328,7 @@ async def asyncio_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:

def _register_coroutine(
self, name: str, func: Callable[P, Coroutine[Any, Any, T]]
) -> Function[P, T]:
) -> AsyncFunction[P, T]:
logger.info("registering coroutine: %s", name)
func = durable(func)

Expand All @@ -307,7 +339,7 @@ async def primitive_func(input: Input) -> Output:
primitive_func.__qualname__ = f"{name}_primitive"
durable_primitive_func = durable(primitive_func)

wrapped_func = Function[P, T](
wrapped_func = AsyncFunction[P, T](
self,
name,
durable_primitive_func,
Expand Down Expand Up @@ -555,7 +587,9 @@ def __init__(self, client: Client):
self.client = client
self.calls = []

def add(self, func: Function[P, T], *args: P.args, **kwargs: P.kwargs) -> Batch:
def add(
self, func: AsyncFunction[P, T], *args: P.args, **kwargs: P.kwargs
) -> Batch:
"""Add a call to the specified function to the batch."""
return self.add_call(func.build_call(*args, **kwargs))

Expand Down
Loading