Skip to content

Commit

Permalink
interoperability with asyncio
Browse files Browse the repository at this point in the history
Signed-off-by: Achille Roussel <[email protected]>
  • Loading branch information
achille-roussel committed Jun 4, 2024
1 parent daedb94 commit 1c7bef8
Show file tree
Hide file tree
Showing 10 changed files with 152 additions and 112 deletions.
4 changes: 3 additions & 1 deletion src/dispatch/experimental/lambda_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def handler(event, context):
dispatch.handle(event, context, entrypoint="entrypoint")
"""

import asyncio
import base64
import json
import logging
Expand Down Expand Up @@ -92,7 +93,8 @@ def handle(

input = Input(req)
try:
output = func._primitive_call(input)
with asyncio.Runner() as runner:
output = runner.run(func._primitive_call(input))
except Exception:
logger.error("function '%s' fatal error", req.function, exc_info=True)
raise # FIXME
Expand Down
6 changes: 1 addition & 5 deletions src/dispatch/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,7 @@ async def execute(request: fastapi.Request):
# forcing execute() to be async.
data: bytes = await request.body()

loop = asyncio.get_running_loop()

content = await loop.run_in_executor(
None,
function_service_run,
content = await function_service_run(
str(request.url),
request.method,
request.headers,
Expand Down
20 changes: 12 additions & 8 deletions src/dispatch/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def read_root():
my_function.dispatch()
"""

import asyncio
import logging
from typing import Optional, Union

Expand Down Expand Up @@ -89,14 +90,17 @@ def _handle_error(self, exc: FunctionServiceError):
def _execute(self):
data: bytes = request.get_data(cache=False)

content = function_service_run(
request.url,
request.method,
dict(request.headers),
data,
self,
self._verification_key,
)
with asyncio.Runner() as runner:
content = runner.run(
function_service_run(
request.url,
request.method,
dict(request.headers),
data,
self,
self._verification_key,
),
)

res = make_response(content)
res.content_type = "application/proto"
Expand Down
24 changes: 13 additions & 11 deletions src/dispatch/function.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

import asyncio
import inspect
import logging
import os
from functools import wraps
from types import CoroutineType
from typing import (
Any,
Awaitable,
Callable,
Coroutine,
Dict,
Expand All @@ -33,7 +35,7 @@
logger = logging.getLogger(__name__)


PrimitiveFunctionType: TypeAlias = Callable[[Input], Output]
PrimitiveFunctionType: TypeAlias = Callable[[Input], Awaitable[Output]]
"""A primitive function is a function that accepts a dispatch.proto.Input
and unconditionally returns a dispatch.proto.Output. It must not raise
exceptions.
Expand Down Expand Up @@ -70,8 +72,8 @@ def endpoint(self, value: str):
def name(self) -> str:
return self._name

def _primitive_call(self, input: Input) -> Output:
return self._primitive_func(input)
async def _primitive_call(self, input: Input) -> Output:
return await self._primitive_func(input)

def _primitive_dispatch(self, input: Any = None) -> DispatchID:
[dispatch_id] = self._client.dispatch([self._build_primitive_call(input)])
Expand Down Expand Up @@ -226,6 +228,7 @@ def function(self, func: Callable[P, T]) -> Function[P, T]: ...
def function(self, func):
"""Decorator that registers functions."""
name = func.__qualname__

if not inspect.iscoroutinefunction(func):
logger.info("registering function: %s", name)
return self._register_function(name, func)
Expand All @@ -237,23 +240,22 @@ def _register_function(self, name: str, func: Callable[P, T]) -> Function[P, T]:
func = durable(func)

@wraps(func)
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
return func(*args, **kwargs)

async_wrapper.__qualname__ = f"{name}_async"
async def asyncio_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, func, *args, **kwargs)

return self._register_coroutine(name, async_wrapper)
asyncio_wrapper.__qualname__ = f"{name}_asyncio"
return self._register_coroutine(name, asyncio_wrapper)

def _register_coroutine(
self, name: str, func: Callable[P, Coroutine[Any, Any, T]]
) -> Function[P, T]:
logger.info("registering coroutine: %s", name)

func = durable(func)

@wraps(func)
def primitive_func(input: Input) -> Output:
return OneShotScheduler(func).run(input)
async def primitive_func(input: Input) -> Output:
return await OneShotScheduler(func).run(input)

primitive_func.__qualname__ = f"{name}_primitive"
durable_primitive_func = durable(primitive_func)
Expand Down
24 changes: 14 additions & 10 deletions src/dispatch/http.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Integration of Dispatch functions with http."""

import asyncio
import logging
import os
from datetime import timedelta
Expand Down Expand Up @@ -120,14 +121,17 @@ def do_POST(self):
url = self.requestline # TODO: need full URL

try:
content = function_service_run(
url,
method,
dict(self.headers),
data,
self.registry,
self.verification_key,
)
with asyncio.Runner() as runner:
content = runner.run(
function_service_run(
url,
method,
dict(self.headers),
data,
self.registry,
self.verification_key,
)
)
except FunctionServiceError as e:
return self.send_error_response(e.status, e.code, e.message)

Expand All @@ -137,7 +141,7 @@ def do_POST(self):
self.wfile.write(content)


def function_service_run(
async def function_service_run(
url: str,
method: str,
headers: Mapping[str, str],
Expand Down Expand Up @@ -184,7 +188,7 @@ def function_service_run(
logger.info("running function '%s'", req.function)

try:
output = func._primitive_call(input)
output = await func._primitive_call(input)
except Exception:
# This indicates that an exception was raised in a primitive
# function. Primitive functions must catch exceptions, categorize
Expand Down
107 changes: 62 additions & 45 deletions src/dispatch/scheduler.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
import pickle
import sys
Expand Down Expand Up @@ -266,9 +267,7 @@ class State:
ready: List[Coroutine]
next_coroutine_id: int
next_call_id: int

prev_callers: List[Coroutine]

outstanding_calls: int


Expand Down Expand Up @@ -327,9 +326,9 @@ def __init__(
version,
)

def run(self, input: Input) -> Output:
async def run(self, input: Input) -> Output:
try:
return self._run(input)
return await self._run(input)
except Exception as e:
logger.exception(
"unexpected exception occurred during coroutine scheduling"
Expand Down Expand Up @@ -374,7 +373,7 @@ def _rebuild_state(self, input: Input):
logger.warning("state is incompatible", exc_info=True)
raise IncompatibleStateError from e

def _run(self, input: Input) -> Output:
async def _run(self, input: Input) -> Output:
if input.is_first_call:
state = self._init_state(input)
else:
Expand Down Expand Up @@ -430,19 +429,30 @@ def _run(self, input: Input) -> Output:
)

pending_calls: List[Call] = []
while state.ready:
coroutine = state.ready.pop(0)
logger.debug("running %s", coroutine)

assert coroutine.id not in state.suspended
coroutine_yield = run_coroutine(state, coroutine, pending_calls)

if coroutine_yield is not None:
if isinstance(coroutine_yield, Output):
return coroutine_yield
raise RuntimeError(
f"coroutine unexpectedly yielded '{coroutine_yield}'"
asyncio_tasks: List[asyncio.Task] = []

while state.ready or asyncio_tasks:
for coroutine in state.ready:
logger.debug("running %s", coroutine)
assert coroutine.id not in state.suspended
asyncio_tasks.append(
asyncio.create_task(run_coroutine(state, coroutine, pending_calls))
)
state.ready.clear()

done, pending = await asyncio.wait(
asyncio_tasks, return_when=asyncio.FIRST_COMPLETED
)
asyncio_tasks = list(pending)

for task in done:
coroutine_result = task.result()
if coroutine_result is None:
continue
for task in asyncio_tasks:
task.cancel()
await asyncio.gather(*asyncio_tasks, return_exceptions=True)
return coroutine_result

# Serialize coroutines and scheduler state.
logger.debug("serializing state")
Expand Down Expand Up @@ -472,40 +482,47 @@ def _run(self, input: Input) -> Output:
)


def run_coroutine(state: State, coroutine: Coroutine, pending_calls: List[Call]):
coroutine_yield = None
coroutine_result: Optional[CoroutineResult] = None
try:
coroutine_yield = coroutine.run()
except TailCall as tc:
coroutine_result = CoroutineResult(
coroutine_id=coroutine.id, call=tc.call, status=tc.status
)
except StopIteration as e:
coroutine_result = CoroutineResult(coroutine_id=coroutine.id, value=e.value)
except Exception as e:
coroutine_result = CoroutineResult(coroutine_id=coroutine.id, error=e)
logger.debug(
f"@dispatch.function: '{coroutine}' raised an exception", exc_info=e
)
async def run_coroutine(state: State, coroutine: Coroutine, pending_calls: List[Call]):
return await make_coroutine(state, coroutine, pending_calls)


@coroutine
def make_coroutine(state: State, coroutine: Coroutine, pending_calls: List[Call]):
while True:
coroutine_yield = None
coroutine_result: Optional[CoroutineResult] = None
try:
coroutine_yield = coroutine.run()
except TailCall as tc:
coroutine_result = CoroutineResult(
coroutine_id=coroutine.id, call=tc.call, status=tc.status
)
except StopIteration as e:
coroutine_result = CoroutineResult(coroutine_id=coroutine.id, value=e.value)
except Exception as e:
coroutine_result = CoroutineResult(coroutine_id=coroutine.id, error=e)
logger.debug(
f"@dispatch.function: '{coroutine}' raised an exception", exc_info=e
)

if coroutine_result is not None:
return set_coroutine_result(state, coroutine, coroutine_result)

if coroutine_result is not None:
return set_coroutine_result(state, coroutine, coroutine_result)
logger.debug("%s yielded %s", coroutine, coroutine_yield)
logger.debug("%s yielded %s", coroutine, coroutine_yield)

if isinstance(coroutine_yield, Call):
return set_coroutine_call(state, coroutine, coroutine_yield, pending_calls)
if isinstance(coroutine_yield, Call):
return set_coroutine_call(state, coroutine, coroutine_yield, pending_calls)

if isinstance(coroutine_yield, AllDirective):
return set_coroutine_all(state, coroutine, coroutine_yield.awaitables)
if isinstance(coroutine_yield, AllDirective):
return set_coroutine_all(state, coroutine, coroutine_yield.awaitables)

if isinstance(coroutine_yield, AnyDirective):
return set_coroutine_any(state, coroutine, coroutine_yield.awaitables)
if isinstance(coroutine_yield, AnyDirective):
return set_coroutine_any(state, coroutine, coroutine_yield.awaitables)

if isinstance(coroutine_yield, RaceDirective):
return set_coroutine_race(state, coroutine, coroutine_yield.awaitables)
if isinstance(coroutine_yield, RaceDirective):
return set_coroutine_race(state, coroutine, coroutine_yield.awaitables)

return coroutine_yield
yield coroutine_yield


def set_coroutine_result(
Expand Down
Loading

0 comments on commit 1c7bef8

Please sign in to comment.