diff --git a/src/dispatch/__init__.py b/src/dispatch/__init__.py index 45a45e75..78c4be57 100644 --- a/src/dispatch/__init__.py +++ b/src/dispatch/__init__.py @@ -4,7 +4,7 @@ import dispatch.integrations from dispatch.coroutine import all, any, call, gather, race -from dispatch.function import DEFAULT_API_URL, Client, Registry +from dispatch.function import DEFAULT_API_URL, Client, Registry, Reset from dispatch.id import DispatchID from dispatch.proto import Call, Error, Input, Output from dispatch.status import Status @@ -17,6 +17,7 @@ "Output", "Call", "Error", + "Reset", "Status", "call", "gather", diff --git a/src/dispatch/function.py b/src/dispatch/function.py index f401f950..c2fbfcb8 100644 --- a/src/dispatch/function.py +++ b/src/dispatch/function.py @@ -27,7 +27,7 @@ import dispatch.sdk.v1.dispatch_pb2_grpc as dispatch_grpc from dispatch.experimental.durable import durable from dispatch.id import DispatchID -from dispatch.proto import Arguments, Call, Error, Input, Output +from dispatch.proto import Arguments, Call, Error, Input, Output, TailCall from dispatch.scheduler import OneShotScheduler logger = logging.getLogger(__name__) @@ -157,6 +157,14 @@ def build_call( ) +class Reset(TailCall): + """The current coroutine is aborted and scheduling reset to be replaced with + the call embedded in this exception.""" + + def __init__(self, func: Function[P, T], *args: P.args, **kwargs: P.kwargs): + super().__init__(call=func.build_call(*args, correlation_id=None, **kwargs)) + + class Registry: """Registry of functions.""" diff --git a/src/dispatch/proto.py b/src/dispatch/proto.py index 1a9d5c5e..d297e64a 100644 --- a/src/dispatch/proto.py +++ b/src/dispatch/proto.py @@ -20,6 +20,20 @@ from dispatch.sdk.v1 import status_pb2 as status_pb from dispatch.status import Status, status_for_error, status_for_output + +class TailCall(Exception): + """The current coroutine is aborted and scheduling reset to be replaced with + the call embedded in this exception.""" + + call: Call + status: Status + + def __init__(self, call: Call, status: Status = Status.TEMPORARY_ERROR): + super().__init__(f"reset coroutine to ${call.function}") + self.call = call + self.status = status + + # Most types in this package are thin wrappers around the various protobuf # messages of dispatch.sdk.v1. They provide some safeguards and ergonomics. @@ -175,10 +189,10 @@ def error(cls, error: Error) -> Output: return cls.exit(result=CallResult.from_error(error), status=error.status) @classmethod - def tail_call(cls, tail_call: Call) -> Output: + def tail_call(cls, tail_call: Call, status: Status = Status.OK) -> Output: """Terminally exit the function, and instruct the orchestrator to tail call the specified function.""" - return cls.exit(tail_call=tail_call) + return cls.exit(tail_call=tail_call, status=status) @classmethod def exit( diff --git a/src/dispatch/scheduler.py b/src/dispatch/scheduler.py index bd69575a..ada9b2a3 100644 --- a/src/dispatch/scheduler.py +++ b/src/dispatch/scheduler.py @@ -20,7 +20,7 @@ from dispatch.coroutine import AllDirective, AnyDirective, AnyException, RaceDirective from dispatch.error import IncompatibleStateError from dispatch.experimental.durable.function import DurableCoroutine, DurableGenerator -from dispatch.proto import Call, Error, Input, Output +from dispatch.proto import Call, Error, Input, Output, TailCall from dispatch.status import Status logger = logging.getLogger(__name__) @@ -37,6 +37,8 @@ class CoroutineResult: coroutine_id: CoroutineID value: Optional[Any] = None error: Optional[Exception] = None + call: Optional[Call] = None + status: Status = Status.OK @dataclass @@ -438,6 +440,10 @@ def _run(self, input: Input) -> Output: coroutine_result: Optional[CoroutineResult] = None try: coroutine_yield = coroutine.run() + except TailCall as tc: + coroutine_result = CoroutineResult( + coroutine_id=coroutine.id, call=tc.call, status=tc.status + ) except StopIteration as e: coroutine_result = CoroutineResult( coroutine_id=coroutine.id, value=e.value @@ -450,7 +456,11 @@ def _run(self, input: Input) -> Output: # Handle coroutines that return or raise. if coroutine_result is not None: - if coroutine_result.error is not None: + if coroutine_result.call is not None: + logger.debug( + "%s reset to %s", coroutine, coroutine_result.call.function + ) + elif coroutine_result.error is not None: logger.debug("%s raised %s", coroutine, coroutine_result.error) else: logger.debug("%s returned %s", coroutine, coroutine_result.value) @@ -463,6 +473,11 @@ def _run(self, input: Input) -> Output: return Output.error( Error.from_exception(coroutine_result.error) ) + if coroutine_result.call is not None: + return Output.tail_call( + tail_call=coroutine_result.call, + status=coroutine_result.status, + ) return Output.value(coroutine_result.value) # Otherwise, notify the parent of the result. diff --git a/tests/dispatch/test_scheduler.py b/tests/dispatch/test_scheduler.py index 2576da97..40da5f4b 100644 --- a/tests/dispatch/test_scheduler.py +++ b/tests/dispatch/test_scheduler.py @@ -3,7 +3,7 @@ from dispatch.coroutine import AnyException, any, call, gather, race from dispatch.experimental.durable import durable -from dispatch.proto import Call, CallResult, Error, Input, Output +from dispatch.proto import Arguments, Call, CallResult, Error, Input, Output, TailCall from dispatch.proto import _any_unpickle as any_unpickle from dispatch.scheduler import ( AllFuture, @@ -372,6 +372,21 @@ async def main(): output = self.start(main) self.assert_exit_result_error(output, ValueError, "oops") + def test_raise_reset(self): + @durable + async def main(x: int, y: int): + raise TailCall( + call=Call( + function="main", input=Arguments((), {"x": x + 1, "y": y + 2}) + ) + ) + + output = self.start(main, x=1, y=2) + self.assert_exit_tail_call( + output, + tail_call=Call(function="main", input=Arguments((), {"x": 2, "y": 4})), + ) + def test_min_max_results_clamping(self): @durable async def main(): @@ -457,6 +472,12 @@ def assert_exit_result_error( self.assertEqual(str(error), message) return error + def assert_exit_tail_call(self, output: Output, tail_call: Call): + exit = self.assert_exit(output) + self.assertFalse(exit.HasField("result")) + self.assertTrue(exit.HasField("tail_call")) + self.assertEqual(tail_call._as_proto(), exit.tail_call) + def assert_poll(self, output: Output) -> poll_pb.Poll: response = output._message if response.HasField("exit"):