From 33f8716e0b4d888cbe6ce6d55d6efa640a01f6c1 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Tue, 2 Apr 2024 14:38:12 +0900 Subject: [PATCH 1/3] add special Reset class to implement tail calls Signed-off-by: Achille Roussel --- src/dispatch/function.py | 10 +++++++++- src/dispatch/proto.py | 18 ++++++++++++++++-- src/dispatch/scheduler.py | 27 +++++++++++++++++++++++++-- tests/dispatch/test_scheduler.py | 23 ++++++++++++++++++++++- 4 files changed, 72 insertions(+), 6 deletions(-) diff --git a/src/dispatch/function.py b/src/dispatch/function.py index f401f950..c2fbfcb8 100644 --- a/src/dispatch/function.py +++ b/src/dispatch/function.py @@ -27,7 +27,7 @@ import dispatch.sdk.v1.dispatch_pb2_grpc as dispatch_grpc from dispatch.experimental.durable import durable from dispatch.id import DispatchID -from dispatch.proto import Arguments, Call, Error, Input, Output +from dispatch.proto import Arguments, Call, Error, Input, Output, TailCall from dispatch.scheduler import OneShotScheduler logger = logging.getLogger(__name__) @@ -157,6 +157,14 @@ def build_call( ) +class Reset(TailCall): + """The current coroutine is aborted and scheduling reset to be replaced with + the call embedded in this exception.""" + + def __init__(self, func: Function[P, T], *args: P.args, **kwargs: P.kwargs): + super().__init__(call=func.build_call(*args, correlation_id=None, **kwargs)) + + class Registry: """Registry of functions.""" diff --git a/src/dispatch/proto.py b/src/dispatch/proto.py index 1a9d5c5e..d297e64a 100644 --- a/src/dispatch/proto.py +++ b/src/dispatch/proto.py @@ -20,6 +20,20 @@ from dispatch.sdk.v1 import status_pb2 as status_pb from dispatch.status import Status, status_for_error, status_for_output + +class TailCall(Exception): + """The current coroutine is aborted and scheduling reset to be replaced with + the call embedded in this exception.""" + + call: Call + status: Status + + def __init__(self, call: Call, status: Status = Status.TEMPORARY_ERROR): + super().__init__(f"reset coroutine to ${call.function}") + self.call = call + self.status = status + + # Most types in this package are thin wrappers around the various protobuf # messages of dispatch.sdk.v1. They provide some safeguards and ergonomics. @@ -175,10 +189,10 @@ def error(cls, error: Error) -> Output: return cls.exit(result=CallResult.from_error(error), status=error.status) @classmethod - def tail_call(cls, tail_call: Call) -> Output: + def tail_call(cls, tail_call: Call, status: Status = Status.OK) -> Output: """Terminally exit the function, and instruct the orchestrator to tail call the specified function.""" - return cls.exit(tail_call=tail_call) + return cls.exit(tail_call=tail_call, status=status) @classmethod def exit( diff --git a/src/dispatch/scheduler.py b/src/dispatch/scheduler.py index bd69575a..4299ec27 100644 --- a/src/dispatch/scheduler.py +++ b/src/dispatch/scheduler.py @@ -20,7 +20,7 @@ from dispatch.coroutine import AllDirective, AnyDirective, AnyException, RaceDirective from dispatch.error import IncompatibleStateError from dispatch.experimental.durable.function import DurableCoroutine, DurableGenerator -from dispatch.proto import Call, Error, Input, Output +from dispatch.proto import Call, Error, Input, Output, TailCall from dispatch.status import Status logger = logging.getLogger(__name__) @@ -30,6 +30,14 @@ CorrelationID: TypeAlias = int +@dataclass +class Reset(Exception): + """A reset exception which discards the current coroutine and schedules + the provided call as argument instead""" + + def __init__(self): ... + + @dataclass class CoroutineResult: """The result from running a coroutine to completion.""" @@ -37,6 +45,8 @@ class CoroutineResult: coroutine_id: CoroutineID value: Optional[Any] = None error: Optional[Exception] = None + call: Optional[Call] = None + status: Status = Status.OK @dataclass @@ -438,6 +448,10 @@ def _run(self, input: Input) -> Output: coroutine_result: Optional[CoroutineResult] = None try: coroutine_yield = coroutine.run() + except TailCall as tc: + coroutine_result = CoroutineResult( + coroutine_id=coroutine.id, call=tc.call, status=tc.status + ) except StopIteration as e: coroutine_result = CoroutineResult( coroutine_id=coroutine.id, value=e.value @@ -450,7 +464,11 @@ def _run(self, input: Input) -> Output: # Handle coroutines that return or raise. if coroutine_result is not None: - if coroutine_result.error is not None: + if coroutine_result.call is not None: + logger.debug( + "%s reset to %s", coroutine, coroutine_result.call.function + ) + elif coroutine_result.error is not None: logger.debug("%s raised %s", coroutine, coroutine_result.error) else: logger.debug("%s returned %s", coroutine, coroutine_result.value) @@ -463,6 +481,11 @@ def _run(self, input: Input) -> Output: return Output.error( Error.from_exception(coroutine_result.error) ) + if coroutine_result.call is not None: + return Output.tail_call( + tail_call=coroutine_result.call, + status=coroutine_result.status, + ) return Output.value(coroutine_result.value) # Otherwise, notify the parent of the result. diff --git a/tests/dispatch/test_scheduler.py b/tests/dispatch/test_scheduler.py index 2576da97..40da5f4b 100644 --- a/tests/dispatch/test_scheduler.py +++ b/tests/dispatch/test_scheduler.py @@ -3,7 +3,7 @@ from dispatch.coroutine import AnyException, any, call, gather, race from dispatch.experimental.durable import durable -from dispatch.proto import Call, CallResult, Error, Input, Output +from dispatch.proto import Arguments, Call, CallResult, Error, Input, Output, TailCall from dispatch.proto import _any_unpickle as any_unpickle from dispatch.scheduler import ( AllFuture, @@ -372,6 +372,21 @@ async def main(): output = self.start(main) self.assert_exit_result_error(output, ValueError, "oops") + def test_raise_reset(self): + @durable + async def main(x: int, y: int): + raise TailCall( + call=Call( + function="main", input=Arguments((), {"x": x + 1, "y": y + 2}) + ) + ) + + output = self.start(main, x=1, y=2) + self.assert_exit_tail_call( + output, + tail_call=Call(function="main", input=Arguments((), {"x": 2, "y": 4})), + ) + def test_min_max_results_clamping(self): @durable async def main(): @@ -457,6 +472,12 @@ def assert_exit_result_error( self.assertEqual(str(error), message) return error + def assert_exit_tail_call(self, output: Output, tail_call: Call): + exit = self.assert_exit(output) + self.assertFalse(exit.HasField("result")) + self.assertTrue(exit.HasField("tail_call")) + self.assertEqual(tail_call._as_proto(), exit.tail_call) + def assert_poll(self, output: Output) -> poll_pb.Poll: response = output._message if response.HasField("exit"): From b7fc88cd4b0da7ec90731bc3ce295f87796646ed Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Tue, 2 Apr 2024 14:43:39 +0900 Subject: [PATCH 2/3] export Reset type Signed-off-by: Achille Roussel --- src/dispatch/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/dispatch/__init__.py b/src/dispatch/__init__.py index 45a45e75..78c4be57 100644 --- a/src/dispatch/__init__.py +++ b/src/dispatch/__init__.py @@ -4,7 +4,7 @@ import dispatch.integrations from dispatch.coroutine import all, any, call, gather, race -from dispatch.function import DEFAULT_API_URL, Client, Registry +from dispatch.function import DEFAULT_API_URL, Client, Registry, Reset from dispatch.id import DispatchID from dispatch.proto import Call, Error, Input, Output from dispatch.status import Status @@ -17,6 +17,7 @@ "Output", "Call", "Error", + "Reset", "Status", "call", "gather", From efbd8356996d326286d150319180a51d25408780 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Tue, 2 Apr 2024 14:44:55 +0900 Subject: [PATCH 3/3] remove development leftovers Signed-off-by: Achille Roussel --- src/dispatch/scheduler.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/dispatch/scheduler.py b/src/dispatch/scheduler.py index 4299ec27..ada9b2a3 100644 --- a/src/dispatch/scheduler.py +++ b/src/dispatch/scheduler.py @@ -30,14 +30,6 @@ CorrelationID: TypeAlias = int -@dataclass -class Reset(Exception): - """A reset exception which discards the current coroutine and schedules - the provided call as argument instead""" - - def __init__(self): ... - - @dataclass class CoroutineResult: """The result from running a coroutine to completion."""