Skip to content

Commit

Permalink
Merge pull request #174 from dispatchrun/asyncio
Browse files Browse the repository at this point in the history
interoperability with asyncio (part 1)
  • Loading branch information
achille-roussel authored Jun 5, 2024
2 parents d0f9818 + 7ae5255 commit ffc358b
Show file tree
Hide file tree
Showing 11 changed files with 357 additions and 178 deletions.
113 changes: 113 additions & 0 deletions src/dispatch/asyncio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import asyncio
import functools
import inspect
import signal
import threading


class Runner:
"""Runner is a class similar to asyncio.Runner but that we use for backward
compatibility with Python 3.10 and earlier.
"""

def __init__(self):
self._loop = asyncio.new_event_loop()
self._interrupt_count = 0

def __enter__(self):
return self

def __exit__(self, *args, **kwargs):
self.close()

def close(self):
try:
loop = self._loop
_cancel_all_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
if hasattr(loop, "shutdown_default_executor"): # Python 3.9+
loop.run_until_complete(loop.shutdown_default_executor())
finally:
loop.close()

def get_loop(self):
return self._loop

def run(self, coro):
if not inspect.iscoroutine(coro):
raise ValueError("a coroutine was expected, got {!r}".format(coro))

try:
asyncio.get_running_loop()
except RuntimeError:
pass
else:
raise RuntimeError(
"Runner.run() cannot be called from a running event loop"
)

task = self._loop.create_task(coro)
sigint_handler = None

if (
threading.current_thread() is threading.main_thread()
and signal.getsignal(signal.SIGINT) is signal.default_int_handler
):
sigint_handler = functools.partial(self._on_sigint, main_task=task)
try:
signal.signal(signal.SIGINT, sigint_handler)
except ValueError:
# `signal.signal` may throw if `threading.main_thread` does
# not support signals (e.g. embedded interpreter with signals
# not registered - see gh-91880)
sigint_handler = None

self._interrupt_count = 0
try:
asyncio.set_event_loop(self._loop)
return self._loop.run_until_complete(task)
except asyncio.CancelledError:
if self._interrupt_count > 0:
uncancel = getattr(task, "uncancel", None)
if uncancel is not None and uncancel() == 0:
raise KeyboardInterrupt()
raise # CancelledError
finally:
asyncio.set_event_loop(None)
if (
sigint_handler is not None
and signal.getsignal(signal.SIGINT) is sigint_handler
):
signal.signal(signal.SIGINT, signal.default_int_handler)

def _on_sigint(self, signum, frame, main_task):
self._interrupt_count += 1
if self._interrupt_count == 1 and not main_task.done():
main_task.cancel()
# wakeup loop if it is blocked by select() with long timeout
self._loop.call_soon_threadsafe(lambda: None)
return
raise KeyboardInterrupt()


def _cancel_all_tasks(loop):
to_cancel = asyncio.all_tasks(loop)
if not to_cancel:
return

for task in to_cancel:
task.cancel()

loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))

for task in to_cancel:
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler(
{
"message": "unhandled exception during asyncio.run() shutdown",
"exception": task.exception(),
"task": task,
}
)
4 changes: 3 additions & 1 deletion src/dispatch/experimental/lambda_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def handler(event, context):

from awslambdaric.lambda_context import LambdaContext

from dispatch.asyncio import Runner
from dispatch.function import Registry
from dispatch.proto import Input
from dispatch.sdk.v1 import function_pb2 as function_pb
Expand Down Expand Up @@ -92,7 +93,8 @@ def handle(

input = Input(req)
try:
output = func._primitive_call(input)
with Runner() as runner:
output = runner.run(func._primitive_call(input))
except Exception:
logger.error("function '%s' fatal error", req.function, exc_info=True)
raise # FIXME
Expand Down
6 changes: 1 addition & 5 deletions src/dispatch/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,7 @@ async def execute(request: fastapi.Request):
# forcing execute() to be async.
data: bytes = await request.body()

loop = asyncio.get_running_loop()

content = await loop.run_in_executor(
None,
function_service_run,
content = await function_service_run(
str(request.url),
request.method,
request.headers,
Expand Down
20 changes: 12 additions & 8 deletions src/dispatch/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def read_root():

from flask import Flask, make_response, request

from dispatch.asyncio import Runner
from dispatch.function import Registry
from dispatch.http import FunctionServiceError, function_service_run
from dispatch.signature import Ed25519PublicKey, parse_verification_key
Expand Down Expand Up @@ -89,14 +90,17 @@ def _handle_error(self, exc: FunctionServiceError):
def _execute(self):
data: bytes = request.get_data(cache=False)

content = function_service_run(
request.url,
request.method,
dict(request.headers),
data,
self,
self._verification_key,
)
with Runner() as runner:
content = runner.run(
function_service_run(
request.url,
request.method,
dict(request.headers),
data,
self,
self._verification_key,
),
)

res = make_response(content)
res.content_type = "application/proto"
Expand Down
24 changes: 13 additions & 11 deletions src/dispatch/function.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

import asyncio
import inspect
import logging
import os
from functools import wraps
from types import CoroutineType
from typing import (
Any,
Awaitable,
Callable,
Coroutine,
Dict,
Expand All @@ -33,7 +35,7 @@
logger = logging.getLogger(__name__)


PrimitiveFunctionType: TypeAlias = Callable[[Input], Output]
PrimitiveFunctionType: TypeAlias = Callable[[Input], Awaitable[Output]]
"""A primitive function is a function that accepts a dispatch.proto.Input
and unconditionally returns a dispatch.proto.Output. It must not raise
exceptions.
Expand Down Expand Up @@ -70,8 +72,8 @@ def endpoint(self, value: str):
def name(self) -> str:
return self._name

def _primitive_call(self, input: Input) -> Output:
return self._primitive_func(input)
async def _primitive_call(self, input: Input) -> Output:
return await self._primitive_func(input)

def _primitive_dispatch(self, input: Any = None) -> DispatchID:
[dispatch_id] = self._client.dispatch([self._build_primitive_call(input)])
Expand Down Expand Up @@ -226,6 +228,7 @@ def function(self, func: Callable[P, T]) -> Function[P, T]: ...
def function(self, func):
"""Decorator that registers functions."""
name = func.__qualname__

if not inspect.iscoroutinefunction(func):
logger.info("registering function: %s", name)
return self._register_function(name, func)
Expand All @@ -237,23 +240,22 @@ def _register_function(self, name: str, func: Callable[P, T]) -> Function[P, T]:
func = durable(func)

@wraps(func)
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
return func(*args, **kwargs)

async_wrapper.__qualname__ = f"{name}_async"
async def asyncio_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, func, *args, **kwargs)

return self._register_coroutine(name, async_wrapper)
asyncio_wrapper.__qualname__ = f"{name}_asyncio"
return self._register_coroutine(name, asyncio_wrapper)

def _register_coroutine(
self, name: str, func: Callable[P, Coroutine[Any, Any, T]]
) -> Function[P, T]:
logger.info("registering coroutine: %s", name)

func = durable(func)

@wraps(func)
def primitive_func(input: Input) -> Output:
return OneShotScheduler(func).run(input)
async def primitive_func(input: Input) -> Output:
return await OneShotScheduler(func).run(input)

primitive_func.__qualname__ = f"{name}_primitive"
durable_primitive_func = durable(primitive_func)
Expand Down
24 changes: 14 additions & 10 deletions src/dispatch/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from http_message_signatures import InvalidSignature

from dispatch.asyncio import Runner
from dispatch.function import Registry
from dispatch.proto import Input
from dispatch.sdk.v1 import function_pb2 as function_pb
Expand Down Expand Up @@ -120,14 +121,17 @@ def do_POST(self):
url = self.requestline # TODO: need full URL

try:
content = function_service_run(
url,
method,
dict(self.headers),
data,
self.registry,
self.verification_key,
)
with Runner() as runner:
content = runner.run(
function_service_run(
url,
method,
dict(self.headers),
data,
self.registry,
self.verification_key,
)
)
except FunctionServiceError as e:
return self.send_error_response(e.status, e.code, e.message)

Expand All @@ -137,7 +141,7 @@ def do_POST(self):
self.wfile.write(content)


def function_service_run(
async def function_service_run(
url: str,
method: str,
headers: Mapping[str, str],
Expand Down Expand Up @@ -184,7 +188,7 @@ def function_service_run(
logger.info("running function '%s'", req.function)

try:
output = func._primitive_call(input)
output = await func._primitive_call(input)
except Exception:
# This indicates that an exception was raised in a primitive
# function. Primitive functions must catch exceptions, categorize
Expand Down
Loading

0 comments on commit ffc358b

Please sign in to comment.