diff --git a/src/dispatch/asyncio.py b/src/dispatch/asyncio.py new file mode 100644 index 0000000..9e9584f --- /dev/null +++ b/src/dispatch/asyncio.py @@ -0,0 +1,113 @@ +import asyncio +import functools +import inspect +import signal +import threading + + +class Runner: + """Runner is a class similar to asyncio.Runner but that we use for backward + compatibility with Python 3.10 and earlier. + """ + + def __init__(self): + self._loop = asyncio.new_event_loop() + self._interrupt_count = 0 + + def __enter__(self): + return self + + def __exit__(self, *args, **kwargs): + self.close() + + def close(self): + try: + loop = self._loop + _cancel_all_tasks(loop) + loop.run_until_complete(loop.shutdown_asyncgens()) + if hasattr(loop, "shutdown_default_executor"): # Python 3.9+ + loop.run_until_complete(loop.shutdown_default_executor()) + finally: + loop.close() + + def get_loop(self): + return self._loop + + def run(self, coro): + if not inspect.iscoroutine(coro): + raise ValueError("a coroutine was expected, got {!r}".format(coro)) + + try: + asyncio.get_running_loop() + except RuntimeError: + pass + else: + raise RuntimeError( + "Runner.run() cannot be called from a running event loop" + ) + + task = self._loop.create_task(coro) + sigint_handler = None + + if ( + threading.current_thread() is threading.main_thread() + and signal.getsignal(signal.SIGINT) is signal.default_int_handler + ): + sigint_handler = functools.partial(self._on_sigint, main_task=task) + try: + signal.signal(signal.SIGINT, sigint_handler) + except ValueError: + # `signal.signal` may throw if `threading.main_thread` does + # not support signals (e.g. embedded interpreter with signals + # not registered - see gh-91880) + sigint_handler = None + + self._interrupt_count = 0 + try: + asyncio.set_event_loop(self._loop) + return self._loop.run_until_complete(task) + except asyncio.CancelledError: + if self._interrupt_count > 0: + uncancel = getattr(task, "uncancel", None) + if uncancel is not None and uncancel() == 0: + raise KeyboardInterrupt() + raise # CancelledError + finally: + asyncio.set_event_loop(None) + if ( + sigint_handler is not None + and signal.getsignal(signal.SIGINT) is sigint_handler + ): + signal.signal(signal.SIGINT, signal.default_int_handler) + + def _on_sigint(self, signum, frame, main_task): + self._interrupt_count += 1 + if self._interrupt_count == 1 and not main_task.done(): + main_task.cancel() + # wakeup loop if it is blocked by select() with long timeout + self._loop.call_soon_threadsafe(lambda: None) + return + raise KeyboardInterrupt() + + +def _cancel_all_tasks(loop): + to_cancel = asyncio.all_tasks(loop) + if not to_cancel: + return + + for task in to_cancel: + task.cancel() + + loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True)) + + for task in to_cancel: + if task.cancelled(): + continue + if task.exception() is not None: + loop.call_exception_handler( + { + "message": "unhandled exception during asyncio.run() shutdown", + "exception": task.exception(), + "task": task, + } + ) diff --git a/src/dispatch/experimental/lambda_handler.py b/src/dispatch/experimental/lambda_handler.py index 7d15978..5e54a48 100644 --- a/src/dispatch/experimental/lambda_handler.py +++ b/src/dispatch/experimental/lambda_handler.py @@ -25,6 +25,7 @@ def handler(event, context): from awslambdaric.lambda_context import LambdaContext +from dispatch.asyncio import Runner from dispatch.function import Registry from dispatch.proto import Input from dispatch.sdk.v1 import function_pb2 as function_pb @@ -92,7 +93,8 @@ def handle( input = Input(req) try: - output = func._primitive_call(input) + with Runner() as runner: + output = runner.run(func._primitive_call(input)) except Exception: logger.error("function '%s' fatal error", req.function, exc_info=True) raise # FIXME diff --git a/src/dispatch/fastapi.py b/src/dispatch/fastapi.py index 206e7bc..0901849 100644 --- a/src/dispatch/fastapi.py +++ b/src/dispatch/fastapi.py @@ -102,11 +102,7 @@ async def execute(request: fastapi.Request): # forcing execute() to be async. data: bytes = await request.body() - loop = asyncio.get_running_loop() - - content = await loop.run_in_executor( - None, - function_service_run, + content = await function_service_run( str(request.url), request.method, request.headers, diff --git a/src/dispatch/flask.py b/src/dispatch/flask.py index cbaded6..5f48dd1 100644 --- a/src/dispatch/flask.py +++ b/src/dispatch/flask.py @@ -22,6 +22,7 @@ def read_root(): from flask import Flask, make_response, request +from dispatch.asyncio import Runner from dispatch.function import Registry from dispatch.http import FunctionServiceError, function_service_run from dispatch.signature import Ed25519PublicKey, parse_verification_key @@ -89,14 +90,17 @@ def _handle_error(self, exc: FunctionServiceError): def _execute(self): data: bytes = request.get_data(cache=False) - content = function_service_run( - request.url, - request.method, - dict(request.headers), - data, - self, - self._verification_key, - ) + with Runner() as runner: + content = runner.run( + function_service_run( + request.url, + request.method, + dict(request.headers), + data, + self, + self._verification_key, + ), + ) res = make_response(content) res.content_type = "application/proto" diff --git a/src/dispatch/function.py b/src/dispatch/function.py index d7b4702..e9f7a32 100644 --- a/src/dispatch/function.py +++ b/src/dispatch/function.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import inspect import logging import os @@ -7,6 +8,7 @@ from types import CoroutineType from typing import ( Any, + Awaitable, Callable, Coroutine, Dict, @@ -33,7 +35,7 @@ logger = logging.getLogger(__name__) -PrimitiveFunctionType: TypeAlias = Callable[[Input], Output] +PrimitiveFunctionType: TypeAlias = Callable[[Input], Awaitable[Output]] """A primitive function is a function that accepts a dispatch.proto.Input and unconditionally returns a dispatch.proto.Output. It must not raise exceptions. @@ -70,8 +72,8 @@ def endpoint(self, value: str): def name(self) -> str: return self._name - def _primitive_call(self, input: Input) -> Output: - return self._primitive_func(input) + async def _primitive_call(self, input: Input) -> Output: + return await self._primitive_func(input) def _primitive_dispatch(self, input: Any = None) -> DispatchID: [dispatch_id] = self._client.dispatch([self._build_primitive_call(input)]) @@ -226,6 +228,7 @@ def function(self, func: Callable[P, T]) -> Function[P, T]: ... def function(self, func): """Decorator that registers functions.""" name = func.__qualname__ + if not inspect.iscoroutinefunction(func): logger.info("registering function: %s", name) return self._register_function(name, func) @@ -237,23 +240,22 @@ def _register_function(self, name: str, func: Callable[P, T]) -> Function[P, T]: func = durable(func) @wraps(func) - async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: - return func(*args, **kwargs) - - async_wrapper.__qualname__ = f"{name}_async" + async def asyncio_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, func, *args, **kwargs) - return self._register_coroutine(name, async_wrapper) + asyncio_wrapper.__qualname__ = f"{name}_asyncio" + return self._register_coroutine(name, asyncio_wrapper) def _register_coroutine( self, name: str, func: Callable[P, Coroutine[Any, Any, T]] ) -> Function[P, T]: logger.info("registering coroutine: %s", name) - func = durable(func) @wraps(func) - def primitive_func(input: Input) -> Output: - return OneShotScheduler(func).run(input) + async def primitive_func(input: Input) -> Output: + return await OneShotScheduler(func).run(input) primitive_func.__qualname__ = f"{name}_primitive" durable_primitive_func = durable(primitive_func) diff --git a/src/dispatch/http.py b/src/dispatch/http.py index d772cb1..f155260 100644 --- a/src/dispatch/http.py +++ b/src/dispatch/http.py @@ -8,6 +8,7 @@ from http_message_signatures import InvalidSignature +from dispatch.asyncio import Runner from dispatch.function import Registry from dispatch.proto import Input from dispatch.sdk.v1 import function_pb2 as function_pb @@ -120,14 +121,17 @@ def do_POST(self): url = self.requestline # TODO: need full URL try: - content = function_service_run( - url, - method, - dict(self.headers), - data, - self.registry, - self.verification_key, - ) + with Runner() as runner: + content = runner.run( + function_service_run( + url, + method, + dict(self.headers), + data, + self.registry, + self.verification_key, + ) + ) except FunctionServiceError as e: return self.send_error_response(e.status, e.code, e.message) @@ -137,7 +141,7 @@ def do_POST(self): self.wfile.write(content) -def function_service_run( +async def function_service_run( url: str, method: str, headers: Mapping[str, str], @@ -184,7 +188,7 @@ def function_service_run( logger.info("running function '%s'", req.function) try: - output = func._primitive_call(input) + output = await func._primitive_call(input) except Exception: # This indicates that an exception was raised in a primitive # function. Primitive functions must catch exceptions, categorize diff --git a/src/dispatch/scheduler.py b/src/dispatch/scheduler.py index ab81220..921f513 100644 --- a/src/dispatch/scheduler.py +++ b/src/dispatch/scheduler.py @@ -1,7 +1,9 @@ +import asyncio import logging import pickle import sys from dataclasses import dataclass, field +from types import coroutine from typing import ( Any, Awaitable, @@ -265,9 +267,7 @@ class State: ready: List[Coroutine] next_coroutine_id: int next_call_id: int - prev_callers: List[Coroutine] - outstanding_calls: int @@ -326,9 +326,9 @@ def __init__( version, ) - def run(self, input: Input) -> Output: + async def run(self, input: Input) -> Output: try: - return self._run(input) + return await self._run(input) except Exception as e: logger.exception( "unexpected exception occurred during coroutine scheduling" @@ -373,7 +373,7 @@ def _rebuild_state(self, input: Input): logger.warning("state is incompatible", exc_info=True) raise IncompatibleStateError from e - def _run(self, input: Input) -> Output: + async def _run(self, input: Input) -> Output: if input.is_first_call: state = self._init_state(input) else: @@ -429,114 +429,30 @@ def _run(self, input: Input) -> Output: ) pending_calls: List[Call] = [] - while state.ready: - coroutine = state.ready.pop(0) - logger.debug("running %s", coroutine) - - assert coroutine.id not in state.suspended - - coroutine_yield = None - coroutine_result: Optional[CoroutineResult] = None - try: - coroutine_yield = coroutine.run() - except TailCall as tc: - coroutine_result = CoroutineResult( - coroutine_id=coroutine.id, call=tc.call, status=tc.status - ) - except StopIteration as e: - coroutine_result = CoroutineResult( - coroutine_id=coroutine.id, value=e.value - ) - except Exception as e: - logger.debug( - f"@dispatch.function: '{coroutine}' raised an exception", exc_info=e - ) - coroutine_result = CoroutineResult(coroutine_id=coroutine.id, error=e) - - # Handle coroutines that return or raise. - if coroutine_result is not None: - if coroutine_result.call is not None: - logger.debug( - "%s reset to %s", coroutine, coroutine_result.call.function - ) - elif coroutine_result.error is not None: - logger.debug("%s raised %s", coroutine, coroutine_result.error) - else: - logger.debug("%s returned %s", coroutine, coroutine_result.value) - - # If this is the main coroutine, we're done. - if coroutine.parent_id is None: - for suspended in state.suspended.values(): - suspended.coroutine.close() - if coroutine_result.error is not None: - return Output.error( - Error.from_exception(coroutine_result.error) - ) - if coroutine_result.call is not None: - return Output.tail_call( - tail_call=coroutine_result.call, - status=coroutine_result.status, - ) - return Output.value(coroutine_result.value) - - # Otherwise, notify the parent of the result. - try: - parent = state.suspended[coroutine.parent_id] - future = parent.result - assert future is not None - except (KeyError, AssertionError): - logger.warning("discarding %s", coroutine_result) - else: - future.add_result(coroutine_result) - if future.ready() and parent.id in state.suspended: - state.ready.insert(0, parent) - del state.suspended[parent.id] - logger.debug("parent %s is now ready", parent) - continue - - # Handle coroutines that yield. - logger.debug("%s yielded %s", coroutine, coroutine_yield) - if isinstance(coroutine_yield, Call): - call = coroutine_yield - call_id = state.next_call_id - state.next_call_id += 1 - call.correlation_id = correlation_id(coroutine.id, call_id) - logger.debug( - "enqueuing call %d (%s) for %s", - call_id, - call.function, - coroutine, + asyncio_tasks: List[asyncio.Task] = [] + + while state.ready or asyncio_tasks: + for coroutine in state.ready: + logger.debug("running %s", coroutine) + assert coroutine.id not in state.suspended + asyncio_tasks.append( + asyncio.create_task(run_coroutine(state, coroutine, pending_calls)) ) - pending_calls.append(call) - coroutine.result = CallFuture() - state.suspended[coroutine.id] = coroutine - state.prev_callers.append(coroutine) - state.outstanding_calls += 1 - - elif isinstance(coroutine_yield, AllDirective): - children = spawn_children(state, coroutine, coroutine_yield.awaitables) - - child_ids = [child.id for child in children] - coroutine.result = AllFuture(order=child_ids, waiting=set(child_ids)) - state.suspended[coroutine.id] = coroutine - - elif isinstance(coroutine_yield, AnyDirective): - children = spawn_children(state, coroutine, coroutine_yield.awaitables) - - child_ids = [child.id for child in children] - coroutine.result = AnyFuture(order=child_ids, waiting=set(child_ids)) - state.suspended[coroutine.id] = coroutine + state.ready.clear() - elif isinstance(coroutine_yield, RaceDirective): - children = spawn_children(state, coroutine, coroutine_yield.awaitables) - - coroutine.result = RaceFuture(waiting={child.id for child in children}) - state.suspended[coroutine.id] = coroutine + done, pending = await asyncio.wait( + asyncio_tasks, return_when=asyncio.FIRST_COMPLETED + ) + asyncio_tasks = list(pending) - else: - raise RuntimeError( - f"coroutine unexpectedly yielded '{coroutine_yield}'" - ) + for task in done: + coroutine_result = task.result() + if coroutine_result is None: + continue + for task in asyncio_tasks: + task.cancel() + await asyncio.gather(*asyncio_tasks, return_exceptions=True) + return coroutine_result # Serialize coroutines and scheduler state. logger.debug("serializing state") @@ -566,14 +482,142 @@ def _run(self, input: Input) -> Output: ) +async def run_coroutine(state: State, coroutine: Coroutine, pending_calls: List[Call]): + return await make_coroutine(state, coroutine, pending_calls) + + +@coroutine +def make_coroutine(state: State, coroutine: Coroutine, pending_calls: List[Call]): + while True: + coroutine_yield = None + coroutine_result: Optional[CoroutineResult] = None + try: + coroutine_yield = coroutine.run() + except TailCall as tc: + coroutine_result = CoroutineResult( + coroutine_id=coroutine.id, call=tc.call, status=tc.status + ) + except StopIteration as e: + coroutine_result = CoroutineResult(coroutine_id=coroutine.id, value=e.value) + except Exception as e: + coroutine_result = CoroutineResult(coroutine_id=coroutine.id, error=e) + logger.debug( + f"@dispatch.function: '{coroutine}' raised an exception", exc_info=e + ) + + if coroutine_result is not None: + return set_coroutine_result(state, coroutine, coroutine_result) + + logger.debug("%s yielded %s", coroutine, coroutine_yield) + + if isinstance(coroutine_yield, Call): + return set_coroutine_call(state, coroutine, coroutine_yield, pending_calls) + + if isinstance(coroutine_yield, AllDirective): + return set_coroutine_all(state, coroutine, coroutine_yield.awaitables) + + if isinstance(coroutine_yield, AnyDirective): + return set_coroutine_any(state, coroutine, coroutine_yield.awaitables) + + if isinstance(coroutine_yield, RaceDirective): + return set_coroutine_race(state, coroutine, coroutine_yield.awaitables) + + yield coroutine_yield + + +def set_coroutine_result( + state: State, coroutine: Coroutine, coroutine_result: CoroutineResult +): + if coroutine_result.call is not None: + logger.debug("%s reset to %s", coroutine, coroutine_result.call.function) + elif coroutine_result.error is not None: + logger.debug("%s raised %s", coroutine, coroutine_result.error) + else: + logger.debug("%s returned %s", coroutine, coroutine_result.value) + + # If this is the main coroutine, we're done. + if coroutine.parent_id is None: + for suspended in state.suspended.values(): + suspended.coroutine.close() + if coroutine_result.error is not None: + return Output.error(Error.from_exception(coroutine_result.error)) + if coroutine_result.call is not None: + return Output.tail_call( + tail_call=coroutine_result.call, status=coroutine_result.status + ) + return Output.value(coroutine_result.value) + + # Otherwise, notify the parent of the result. + try: + parent = state.suspended[coroutine.parent_id] + future = parent.result + assert future is not None + except (KeyError, AssertionError): + logger.warning("discarding %s", coroutine_result) + else: + future.add_result(coroutine_result) + if future.ready() and parent.id in state.suspended: + state.ready.insert(0, parent) + del state.suspended[parent.id] + logger.debug("parent %s is now ready", parent) + return + + +def set_coroutine_call( + state: State, coroutine: Coroutine, call: Call, pending_calls: List[Call] +): + call_id = state.next_call_id + state.next_call_id += 1 + call.correlation_id = correlation_id(coroutine.id, call_id) + logger.debug("enqueuing call %d (%s) for %s", call_id, call.function, coroutine) + pending_calls.append(call) + coroutine.result = CallFuture() + state.suspended[coroutine.id] = coroutine + state.prev_callers.append(coroutine) + state.outstanding_calls += 1 + return + + +def set_coroutine_all( + state: State, coroutine: Coroutine, awaitables: Tuple[Awaitable[Any], ...] +): + children = spawn_children(state, coroutine, awaitables) + child_ids = [child.id for child in children] + coroutine.result = AllFuture(order=child_ids, waiting=set(child_ids)) + state.suspended[coroutine.id] = coroutine + return + + +def set_coroutine_any( + state: State, coroutine: Coroutine, awaitables: Tuple[Awaitable[Any], ...] +): + children = spawn_children(state, coroutine, awaitables) + child_ids = [child.id for child in children] + coroutine.result = AnyFuture(order=child_ids, waiting=set(child_ids)) + state.suspended[coroutine.id] = coroutine + return + + +def set_coroutine_race( + state: State, coroutine: Coroutine, awaitables: Tuple[Awaitable[Any], ...] +): + children = spawn_children(state, coroutine, awaitables) + coroutine.result = RaceFuture(waiting={child.id for child in children}) + state.suspended[coroutine.id] = coroutine + return + + def spawn_children( state: State, coroutine: Coroutine, awaitables: Tuple[Awaitable[Any], ...] ) -> List[Coroutine]: children = [] + for awaitable in awaitables: g = awaitable.__await__() + if not isinstance(g, DurableGenerator): raise TypeError("awaitable is not a @dispatch.function") + child_id = state.next_coroutine_id state.next_coroutine_id += 1 child = Coroutine(id=child_id, parent_id=coroutine.id, coroutine=g) @@ -582,7 +626,6 @@ def spawn_children( # Prepend children to get a depth-first traversal of coroutines. state.ready = children + state.ready - return children diff --git a/tests/dispatch/test_scheduler.py b/tests/dispatch/test_scheduler.py index 4df24c3..a88c4b7 100644 --- a/tests/dispatch/test_scheduler.py +++ b/tests/dispatch/test_scheduler.py @@ -1,6 +1,7 @@ import unittest from typing import Any, Callable, List, Optional, Set, Type +from dispatch.asyncio import Runner from dispatch.coroutine import AnyException, any, call, gather, race from dispatch.experimental.durable import durable from dispatch.proto import Arguments, Call, CallResult, Error, Input, Output, TailCall @@ -53,6 +54,12 @@ async def raises_error(): class TestOneShotScheduler(unittest.TestCase): + def setUp(self): + self.runner = Runner() + + def tearDown(self): + self.runner.close() + def test_main_return(self): @durable async def main(): @@ -105,7 +112,7 @@ async def main(*functions): self.assert_poll_call_functions(output, ["foo", "bar", "baz"]) - def test_depth_first_run(self): + def test_depth_run(self): @durable async def main(): return await gather( @@ -116,10 +123,15 @@ async def main(): ) output = self.start(main) - + # In this test, the output is deterministic, but it does not follow the + # order in which the coroutines are declared due to interleaving of the + # asyncio event loop. + # + # Note that the order could change between Python versions, so we might + # choose to remove this test, or adapt it in the future. self.assert_poll_call_functions( output, - ["a", "b", "c", "d", "e", "f", "g", "h"], + ["d", "h", "e", "f", "g", "a", "b", "c"], min_results=1, max_results=8, ) @@ -331,7 +343,8 @@ async def c_then_d(): @durable async def main(c_then_d): return await gather( - call_concurrently("a", "b"), + call_one("a"), + call_one("b"), c_then_d(), ) @@ -362,7 +375,7 @@ async def main(c_then_d): ], ) - self.assert_exit_result_value(output, [[a_result, b_result], c_result + 100]) + self.assert_exit_result_value(output, [a_result, b_result, c_result + 100]) def test_raise_indirect(self): @durable @@ -417,12 +430,14 @@ def start( **kwargs: Any, ) -> Output: input = Input.from_input_arguments(main.__qualname__, *args, **kwargs) - return OneShotScheduler( - main, - poll_min_results=poll_min_results, - poll_max_results=poll_max_results, - poll_max_wait_seconds=poll_max_wait_seconds, - ).run(input) + return self.runner.run( + OneShotScheduler( + main, + poll_min_results=poll_min_results, + poll_max_results=poll_max_results, + poll_max_wait_seconds=poll_max_wait_seconds, + ).run(input) + ) def resume( self, @@ -438,7 +453,7 @@ def resume( call_results, Error.from_exception(poll_error) if poll_error else None, ) - return OneShotScheduler(main).run(input) + return self.runner.run(OneShotScheduler(main).run(input)) def assert_exit(self, output: Output) -> exit_pb.Exit: response = output._message diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index dee353f..149b817 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -85,7 +85,7 @@ def test_fastapi_simple_request(self): dispatch = create_dispatch_instance(app, endpoint="http://127.0.0.1:9999/") @dispatch.primitive_function - def my_function(input: Input) -> Output: + async def my_function(input: Input) -> Output: return Output.value( f"You told me: '{input.input}' ({len(input.input)} characters)" ) @@ -263,7 +263,7 @@ def proto_call(self, call: call_pb.Call) -> call_pb.CallResult: def test_no_input(self): @self.dispatch.primitive_function - def my_function(input: Input) -> Output: + async def my_function(input: Input) -> Output: return Output.value("Hello World!") resp = self.execute(my_function) @@ -282,7 +282,7 @@ def test_missing_coroutine(self): def test_string_input(self): @self.dispatch.primitive_function - def my_function(input: Input) -> Output: + async def my_function(input: Input) -> Output: return Output.value(f"You sent '{input.input}'") resp = self.execute(my_function, input="cool stuff") @@ -291,7 +291,7 @@ def my_function(input: Input) -> Output: def test_error_on_access_state_in_first_call(self): @self.dispatch.primitive_function - def my_function(input: Input) -> Output: + async def my_function(input: Input) -> Output: try: print(input.coroutine_state) except ValueError: @@ -310,7 +310,7 @@ def my_function(input: Input) -> Output: def test_error_on_access_input_in_second_call(self): @self.dispatch.primitive_function - def my_function(input: Input) -> Output: + async def my_function(input: Input) -> Output: if input.is_first_call: return Output.poll(coroutine_state=b"42") try: @@ -334,22 +334,22 @@ def my_function(input: Input) -> Output: def test_duplicate_coro(self): @self.dispatch.primitive_function - def my_function(input: Input) -> Output: + async def my_function(input: Input) -> Output: return Output.value("Do one thing") with self.assertRaises(ValueError): @self.dispatch.primitive_function - def my_function(input: Input) -> Output: + async def my_function(input: Input) -> Output: return Output.value("Do something else") def test_two_simple_coroutines(self): @self.dispatch.primitive_function - def echoroutine(input: Input) -> Output: + async def echoroutine(input: Input) -> Output: return Output.value(f"Echo: '{input.input}'") @self.dispatch.primitive_function - def len_coroutine(input: Input) -> Output: + async def len_coroutine(input: Input) -> Output: return Output.value(f"Length: {len(input.input)}") data = "cool stuff" @@ -363,7 +363,7 @@ def len_coroutine(input: Input) -> Output: def test_coroutine_with_state(self): @self.dispatch.primitive_function - def coroutine3(input: Input) -> Output: + async def coroutine3(input: Input) -> Output: if input.is_first_call: counter = input.input else: @@ -398,11 +398,11 @@ def coroutine3(input: Input) -> Output: def test_coroutine_poll(self): @self.dispatch.primitive_function - def coro_compute_len(input: Input) -> Output: + async def coro_compute_len(input: Input) -> Output: return Output.value(len(input.input)) @self.dispatch.primitive_function - def coroutine_main(input: Input) -> Output: + async def coroutine_main(input: Input) -> Output: if input.is_first_call: text: str = input.input return Output.poll( @@ -439,11 +439,11 @@ def coroutine_main(input: Input) -> Output: def test_coroutine_poll_error(self): @self.dispatch.primitive_function - def coro_compute_len(input: Input) -> Output: + async def coro_compute_len(input: Input) -> Output: return Output.error(Error(Status.PERMANENT_ERROR, "type", "Dead")) @self.dispatch.primitive_function - def coroutine_main(input: Input) -> Output: + async def coroutine_main(input: Input) -> Output: if input.is_first_call: text: str = input.input return Output.poll( @@ -479,7 +479,7 @@ def coroutine_main(input: Input) -> Output: def test_coroutine_error(self): @self.dispatch.primitive_function - def mycoro(input: Input) -> Output: + async def mycoro(input: Input) -> Output: return Output.error(Error(Status.PERMANENT_ERROR, "sometype", "dead")) resp = self.execute(mycoro) @@ -488,7 +488,7 @@ def mycoro(input: Input) -> Output: def test_coroutine_expected_exception(self): @self.dispatch.primitive_function - def mycoro(input: Input) -> Output: + async def mycoro(input: Input) -> Output: try: 1 / 0 except ZeroDivisionError as e: @@ -513,7 +513,7 @@ def mycoro(): def test_specific_status(self): @self.dispatch.primitive_function - def mycoro(input: Input) -> Output: + async def mycoro(input: Input) -> Output: return Output.error(Error(Status.THROTTLED, "foo", "bar")) resp = self.execute(mycoro) @@ -527,7 +527,7 @@ def other_coroutine(value: Any) -> str: return f"Hello {value}" @self.dispatch.primitive_function - def mycoro(input: Input) -> Output: + async def mycoro(input: Input) -> Output: return Output.tail_call(other_coroutine._build_primitive_call(42)) resp = self.call(mycoro) diff --git a/tests/test_flask.py b/tests/test_flask.py index dbf5932..f681a17 100644 --- a/tests/test_flask.py +++ b/tests/test_flask.py @@ -50,7 +50,7 @@ def test_flask(self): dispatch = create_dispatch_instance(app, endpoint="http://127.0.0.1:9999/") @dispatch.primitive_function - def my_function(input: Input) -> Output: + async def my_function(input: Input) -> Output: return Output.value( f"You told me: '{input.input}' ({len(input.input)} characters)" ) diff --git a/tests/test_http.py b/tests/test_http.py index c5f0c0f..597722d 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -83,7 +83,7 @@ def test_content_length_too_large(self): def test_simple_request(self): @self.dispatch.registry.primitive_function - def my_function(input: Input) -> Output: + async def my_function(input: Input) -> Output: return Output.value( f"You told me: '{input.input}' ({len(input.input)} characters)" )