diff --git a/src/dispatch/experimental/lambda_handler.py b/src/dispatch/experimental/lambda_handler.py index 7d15978..5e27147 100644 --- a/src/dispatch/experimental/lambda_handler.py +++ b/src/dispatch/experimental/lambda_handler.py @@ -18,6 +18,7 @@ def handler(event, context): dispatch.handle(event, context, entrypoint="entrypoint") """ +import asyncio import base64 import json import logging @@ -92,7 +93,8 @@ def handle( input = Input(req) try: - output = func._primitive_call(input) + with asyncio.Runner() as runner: + output = runner.run(func._primitive_call(input)) except Exception: logger.error("function '%s' fatal error", req.function, exc_info=True) raise # FIXME diff --git a/src/dispatch/fastapi.py b/src/dispatch/fastapi.py index 206e7bc..0901849 100644 --- a/src/dispatch/fastapi.py +++ b/src/dispatch/fastapi.py @@ -102,11 +102,7 @@ async def execute(request: fastapi.Request): # forcing execute() to be async. data: bytes = await request.body() - loop = asyncio.get_running_loop() - - content = await loop.run_in_executor( - None, - function_service_run, + content = await function_service_run( str(request.url), request.method, request.headers, diff --git a/src/dispatch/flask.py b/src/dispatch/flask.py index cbaded6..0e719f9 100644 --- a/src/dispatch/flask.py +++ b/src/dispatch/flask.py @@ -17,6 +17,7 @@ def read_root(): my_function.dispatch() """ +import asyncio import logging from typing import Optional, Union @@ -89,14 +90,17 @@ def _handle_error(self, exc: FunctionServiceError): def _execute(self): data: bytes = request.get_data(cache=False) - content = function_service_run( - request.url, - request.method, - dict(request.headers), - data, - self, - self._verification_key, - ) + with asyncio.Runner() as runner: + content = runner.run( + function_service_run( + request.url, + request.method, + dict(request.headers), + data, + self, + self._verification_key, + ), + ) res = make_response(content) res.content_type = "application/proto" diff --git a/src/dispatch/function.py b/src/dispatch/function.py index d7b4702..e9f7a32 100644 --- a/src/dispatch/function.py +++ b/src/dispatch/function.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import inspect import logging import os @@ -7,6 +8,7 @@ from types import CoroutineType from typing import ( Any, + Awaitable, Callable, Coroutine, Dict, @@ -33,7 +35,7 @@ logger = logging.getLogger(__name__) -PrimitiveFunctionType: TypeAlias = Callable[[Input], Output] +PrimitiveFunctionType: TypeAlias = Callable[[Input], Awaitable[Output]] """A primitive function is a function that accepts a dispatch.proto.Input and unconditionally returns a dispatch.proto.Output. It must not raise exceptions. @@ -70,8 +72,8 @@ def endpoint(self, value: str): def name(self) -> str: return self._name - def _primitive_call(self, input: Input) -> Output: - return self._primitive_func(input) + async def _primitive_call(self, input: Input) -> Output: + return await self._primitive_func(input) def _primitive_dispatch(self, input: Any = None) -> DispatchID: [dispatch_id] = self._client.dispatch([self._build_primitive_call(input)]) @@ -226,6 +228,7 @@ def function(self, func: Callable[P, T]) -> Function[P, T]: ... def function(self, func): """Decorator that registers functions.""" name = func.__qualname__ + if not inspect.iscoroutinefunction(func): logger.info("registering function: %s", name) return self._register_function(name, func) @@ -237,23 +240,22 @@ def _register_function(self, name: str, func: Callable[P, T]) -> Function[P, T]: func = durable(func) @wraps(func) - async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: - return func(*args, **kwargs) - - async_wrapper.__qualname__ = f"{name}_async" + async def asyncio_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, func, *args, **kwargs) - return self._register_coroutine(name, async_wrapper) + asyncio_wrapper.__qualname__ = f"{name}_asyncio" + return self._register_coroutine(name, asyncio_wrapper) def _register_coroutine( self, name: str, func: Callable[P, Coroutine[Any, Any, T]] ) -> Function[P, T]: logger.info("registering coroutine: %s", name) - func = durable(func) @wraps(func) - def primitive_func(input: Input) -> Output: - return OneShotScheduler(func).run(input) + async def primitive_func(input: Input) -> Output: + return await OneShotScheduler(func).run(input) primitive_func.__qualname__ = f"{name}_primitive" durable_primitive_func = durable(primitive_func) diff --git a/src/dispatch/http.py b/src/dispatch/http.py index d772cb1..4c46e1d 100644 --- a/src/dispatch/http.py +++ b/src/dispatch/http.py @@ -1,5 +1,6 @@ """Integration of Dispatch functions with http.""" +import asyncio import logging import os from datetime import timedelta @@ -120,14 +121,17 @@ def do_POST(self): url = self.requestline # TODO: need full URL try: - content = function_service_run( - url, - method, - dict(self.headers), - data, - self.registry, - self.verification_key, - ) + with asyncio.Runner() as runner: + content = runner.run( + function_service_run( + url, + method, + dict(self.headers), + data, + self.registry, + self.verification_key, + ) + ) except FunctionServiceError as e: return self.send_error_response(e.status, e.code, e.message) @@ -137,7 +141,7 @@ def do_POST(self): self.wfile.write(content) -def function_service_run( +async def function_service_run( url: str, method: str, headers: Mapping[str, str], @@ -184,7 +188,7 @@ def function_service_run( logger.info("running function '%s'", req.function) try: - output = func._primitive_call(input) + output = await func._primitive_call(input) except Exception: # This indicates that an exception was raised in a primitive # function. Primitive functions must catch exceptions, categorize diff --git a/src/dispatch/scheduler.py b/src/dispatch/scheduler.py index fbae021..921f513 100644 --- a/src/dispatch/scheduler.py +++ b/src/dispatch/scheduler.py @@ -1,3 +1,4 @@ +import asyncio import logging import pickle import sys @@ -266,9 +267,7 @@ class State: ready: List[Coroutine] next_coroutine_id: int next_call_id: int - prev_callers: List[Coroutine] - outstanding_calls: int @@ -327,9 +326,9 @@ def __init__( version, ) - def run(self, input: Input) -> Output: + async def run(self, input: Input) -> Output: try: - return self._run(input) + return await self._run(input) except Exception as e: logger.exception( "unexpected exception occurred during coroutine scheduling" @@ -374,7 +373,7 @@ def _rebuild_state(self, input: Input): logger.warning("state is incompatible", exc_info=True) raise IncompatibleStateError from e - def _run(self, input: Input) -> Output: + async def _run(self, input: Input) -> Output: if input.is_first_call: state = self._init_state(input) else: @@ -430,19 +429,30 @@ def _run(self, input: Input) -> Output: ) pending_calls: List[Call] = [] - while state.ready: - coroutine = state.ready.pop(0) - logger.debug("running %s", coroutine) - - assert coroutine.id not in state.suspended - coroutine_yield = run_coroutine(state, coroutine, pending_calls) - - if coroutine_yield is not None: - if isinstance(coroutine_yield, Output): - return coroutine_yield - raise RuntimeError( - f"coroutine unexpectedly yielded '{coroutine_yield}'" + asyncio_tasks: List[asyncio.Task] = [] + + while state.ready or asyncio_tasks: + for coroutine in state.ready: + logger.debug("running %s", coroutine) + assert coroutine.id not in state.suspended + asyncio_tasks.append( + asyncio.create_task(run_coroutine(state, coroutine, pending_calls)) ) + state.ready.clear() + + done, pending = await asyncio.wait( + asyncio_tasks, return_when=asyncio.FIRST_COMPLETED + ) + asyncio_tasks = list(pending) + + for task in done: + coroutine_result = task.result() + if coroutine_result is None: + continue + for task in asyncio_tasks: + task.cancel() + await asyncio.gather(*asyncio_tasks, return_exceptions=True) + return coroutine_result # Serialize coroutines and scheduler state. logger.debug("serializing state") @@ -472,40 +482,47 @@ def _run(self, input: Input) -> Output: ) -def run_coroutine(state: State, coroutine: Coroutine, pending_calls: List[Call]): - coroutine_yield = None - coroutine_result: Optional[CoroutineResult] = None - try: - coroutine_yield = coroutine.run() - except TailCall as tc: - coroutine_result = CoroutineResult( - coroutine_id=coroutine.id, call=tc.call, status=tc.status - ) - except StopIteration as e: - coroutine_result = CoroutineResult(coroutine_id=coroutine.id, value=e.value) - except Exception as e: - coroutine_result = CoroutineResult(coroutine_id=coroutine.id, error=e) - logger.debug( - f"@dispatch.function: '{coroutine}' raised an exception", exc_info=e - ) +async def run_coroutine(state: State, coroutine: Coroutine, pending_calls: List[Call]): + return await make_coroutine(state, coroutine, pending_calls) + + +@coroutine +def make_coroutine(state: State, coroutine: Coroutine, pending_calls: List[Call]): + while True: + coroutine_yield = None + coroutine_result: Optional[CoroutineResult] = None + try: + coroutine_yield = coroutine.run() + except TailCall as tc: + coroutine_result = CoroutineResult( + coroutine_id=coroutine.id, call=tc.call, status=tc.status + ) + except StopIteration as e: + coroutine_result = CoroutineResult(coroutine_id=coroutine.id, value=e.value) + except Exception as e: + coroutine_result = CoroutineResult(coroutine_id=coroutine.id, error=e) + logger.debug( + f"@dispatch.function: '{coroutine}' raised an exception", exc_info=e + ) + + if coroutine_result is not None: + return set_coroutine_result(state, coroutine, coroutine_result) - if coroutine_result is not None: - return set_coroutine_result(state, coroutine, coroutine_result) - logger.debug("%s yielded %s", coroutine, coroutine_yield) + logger.debug("%s yielded %s", coroutine, coroutine_yield) - if isinstance(coroutine_yield, Call): - return set_coroutine_call(state, coroutine, coroutine_yield, pending_calls) + if isinstance(coroutine_yield, Call): + return set_coroutine_call(state, coroutine, coroutine_yield, pending_calls) - if isinstance(coroutine_yield, AllDirective): - return set_coroutine_all(state, coroutine, coroutine_yield.awaitables) + if isinstance(coroutine_yield, AllDirective): + return set_coroutine_all(state, coroutine, coroutine_yield.awaitables) - if isinstance(coroutine_yield, AnyDirective): - return set_coroutine_any(state, coroutine, coroutine_yield.awaitables) + if isinstance(coroutine_yield, AnyDirective): + return set_coroutine_any(state, coroutine, coroutine_yield.awaitables) - if isinstance(coroutine_yield, RaceDirective): - return set_coroutine_race(state, coroutine, coroutine_yield.awaitables) + if isinstance(coroutine_yield, RaceDirective): + return set_coroutine_race(state, coroutine, coroutine_yield.awaitables) - return coroutine_yield + yield coroutine_yield def set_coroutine_result( diff --git a/tests/dispatch/test_scheduler.py b/tests/dispatch/test_scheduler.py index 4df24c3..1276f4c 100644 --- a/tests/dispatch/test_scheduler.py +++ b/tests/dispatch/test_scheduler.py @@ -1,3 +1,4 @@ +import asyncio import unittest from typing import Any, Callable, List, Optional, Set, Type @@ -53,6 +54,12 @@ async def raises_error(): class TestOneShotScheduler(unittest.TestCase): + def setUp(self): + self.runner = asyncio.Runner() + + def tearDown(self): + self.runner.close() + def test_main_return(self): @durable async def main(): @@ -105,7 +112,7 @@ async def main(*functions): self.assert_poll_call_functions(output, ["foo", "bar", "baz"]) - def test_depth_first_run(self): + def test_depth_run(self): @durable async def main(): return await gather( @@ -116,10 +123,15 @@ async def main(): ) output = self.start(main) - + # In this test, the output is deterministic, but it does not follow the + # order in which the coroutines are declared due to interleaving of the + # asyncio event loop. + # + # Note that the order could change between Python versions, so we might + # choose to remove this test, or adapt it in the future. self.assert_poll_call_functions( output, - ["a", "b", "c", "d", "e", "f", "g", "h"], + ["d", "h", "e", "f", "g", "a", "b", "c"], min_results=1, max_results=8, ) @@ -331,7 +343,8 @@ async def c_then_d(): @durable async def main(c_then_d): return await gather( - call_concurrently("a", "b"), + call_one("a"), + call_one("b"), c_then_d(), ) @@ -362,7 +375,7 @@ async def main(c_then_d): ], ) - self.assert_exit_result_value(output, [[a_result, b_result], c_result + 100]) + self.assert_exit_result_value(output, [a_result, b_result, c_result + 100]) def test_raise_indirect(self): @durable @@ -417,12 +430,14 @@ def start( **kwargs: Any, ) -> Output: input = Input.from_input_arguments(main.__qualname__, *args, **kwargs) - return OneShotScheduler( - main, - poll_min_results=poll_min_results, - poll_max_results=poll_max_results, - poll_max_wait_seconds=poll_max_wait_seconds, - ).run(input) + return self.runner.run( + OneShotScheduler( + main, + poll_min_results=poll_min_results, + poll_max_results=poll_max_results, + poll_max_wait_seconds=poll_max_wait_seconds, + ).run(input) + ) def resume( self, @@ -438,7 +453,7 @@ def resume( call_results, Error.from_exception(poll_error) if poll_error else None, ) - return OneShotScheduler(main).run(input) + return self.runner.run(OneShotScheduler(main).run(input)) def assert_exit(self, output: Output) -> exit_pb.Exit: response = output._message diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index dee353f..149b817 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -85,7 +85,7 @@ def test_fastapi_simple_request(self): dispatch = create_dispatch_instance(app, endpoint="http://127.0.0.1:9999/") @dispatch.primitive_function - def my_function(input: Input) -> Output: + async def my_function(input: Input) -> Output: return Output.value( f"You told me: '{input.input}' ({len(input.input)} characters)" ) @@ -263,7 +263,7 @@ def proto_call(self, call: call_pb.Call) -> call_pb.CallResult: def test_no_input(self): @self.dispatch.primitive_function - def my_function(input: Input) -> Output: + async def my_function(input: Input) -> Output: return Output.value("Hello World!") resp = self.execute(my_function) @@ -282,7 +282,7 @@ def test_missing_coroutine(self): def test_string_input(self): @self.dispatch.primitive_function - def my_function(input: Input) -> Output: + async def my_function(input: Input) -> Output: return Output.value(f"You sent '{input.input}'") resp = self.execute(my_function, input="cool stuff") @@ -291,7 +291,7 @@ def my_function(input: Input) -> Output: def test_error_on_access_state_in_first_call(self): @self.dispatch.primitive_function - def my_function(input: Input) -> Output: + async def my_function(input: Input) -> Output: try: print(input.coroutine_state) except ValueError: @@ -310,7 +310,7 @@ def my_function(input: Input) -> Output: def test_error_on_access_input_in_second_call(self): @self.dispatch.primitive_function - def my_function(input: Input) -> Output: + async def my_function(input: Input) -> Output: if input.is_first_call: return Output.poll(coroutine_state=b"42") try: @@ -334,22 +334,22 @@ def my_function(input: Input) -> Output: def test_duplicate_coro(self): @self.dispatch.primitive_function - def my_function(input: Input) -> Output: + async def my_function(input: Input) -> Output: return Output.value("Do one thing") with self.assertRaises(ValueError): @self.dispatch.primitive_function - def my_function(input: Input) -> Output: + async def my_function(input: Input) -> Output: return Output.value("Do something else") def test_two_simple_coroutines(self): @self.dispatch.primitive_function - def echoroutine(input: Input) -> Output: + async def echoroutine(input: Input) -> Output: return Output.value(f"Echo: '{input.input}'") @self.dispatch.primitive_function - def len_coroutine(input: Input) -> Output: + async def len_coroutine(input: Input) -> Output: return Output.value(f"Length: {len(input.input)}") data = "cool stuff" @@ -363,7 +363,7 @@ def len_coroutine(input: Input) -> Output: def test_coroutine_with_state(self): @self.dispatch.primitive_function - def coroutine3(input: Input) -> Output: + async def coroutine3(input: Input) -> Output: if input.is_first_call: counter = input.input else: @@ -398,11 +398,11 @@ def coroutine3(input: Input) -> Output: def test_coroutine_poll(self): @self.dispatch.primitive_function - def coro_compute_len(input: Input) -> Output: + async def coro_compute_len(input: Input) -> Output: return Output.value(len(input.input)) @self.dispatch.primitive_function - def coroutine_main(input: Input) -> Output: + async def coroutine_main(input: Input) -> Output: if input.is_first_call: text: str = input.input return Output.poll( @@ -439,11 +439,11 @@ def coroutine_main(input: Input) -> Output: def test_coroutine_poll_error(self): @self.dispatch.primitive_function - def coro_compute_len(input: Input) -> Output: + async def coro_compute_len(input: Input) -> Output: return Output.error(Error(Status.PERMANENT_ERROR, "type", "Dead")) @self.dispatch.primitive_function - def coroutine_main(input: Input) -> Output: + async def coroutine_main(input: Input) -> Output: if input.is_first_call: text: str = input.input return Output.poll( @@ -479,7 +479,7 @@ def coroutine_main(input: Input) -> Output: def test_coroutine_error(self): @self.dispatch.primitive_function - def mycoro(input: Input) -> Output: + async def mycoro(input: Input) -> Output: return Output.error(Error(Status.PERMANENT_ERROR, "sometype", "dead")) resp = self.execute(mycoro) @@ -488,7 +488,7 @@ def mycoro(input: Input) -> Output: def test_coroutine_expected_exception(self): @self.dispatch.primitive_function - def mycoro(input: Input) -> Output: + async def mycoro(input: Input) -> Output: try: 1 / 0 except ZeroDivisionError as e: @@ -513,7 +513,7 @@ def mycoro(): def test_specific_status(self): @self.dispatch.primitive_function - def mycoro(input: Input) -> Output: + async def mycoro(input: Input) -> Output: return Output.error(Error(Status.THROTTLED, "foo", "bar")) resp = self.execute(mycoro) @@ -527,7 +527,7 @@ def other_coroutine(value: Any) -> str: return f"Hello {value}" @self.dispatch.primitive_function - def mycoro(input: Input) -> Output: + async def mycoro(input: Input) -> Output: return Output.tail_call(other_coroutine._build_primitive_call(42)) resp = self.call(mycoro) diff --git a/tests/test_flask.py b/tests/test_flask.py index dbf5932..f681a17 100644 --- a/tests/test_flask.py +++ b/tests/test_flask.py @@ -50,7 +50,7 @@ def test_flask(self): dispatch = create_dispatch_instance(app, endpoint="http://127.0.0.1:9999/") @dispatch.primitive_function - def my_function(input: Input) -> Output: + async def my_function(input: Input) -> Output: return Output.value( f"You told me: '{input.input}' ({len(input.input)} characters)" ) diff --git a/tests/test_http.py b/tests/test_http.py index c5f0c0f..597722d 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -83,7 +83,7 @@ def test_content_length_too_large(self): def test_simple_request(self): @self.dispatch.registry.primitive_function - def my_function(input: Input) -> Output: + async def my_function(input: Input) -> Output: return Output.value( f"You told me: '{input.input}' ({len(input.input)} characters)" )