Skip to content

Commit

Permalink
Merge pull request #145 from stealthrocket/raise-reset
Browse files Browse the repository at this point in the history
add special Reset class to implement tail calls
  • Loading branch information
achille-roussel authored Apr 2, 2024
2 parents 4463c21 + efbd835 commit 34b6f47
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 7 deletions.
3 changes: 2 additions & 1 deletion src/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import dispatch.integrations
from dispatch.coroutine import all, any, call, gather, race
from dispatch.function import DEFAULT_API_URL, Client, Registry
from dispatch.function import DEFAULT_API_URL, Client, Registry, Reset
from dispatch.id import DispatchID
from dispatch.proto import Call, Error, Input, Output
from dispatch.status import Status
Expand All @@ -17,6 +17,7 @@
"Output",
"Call",
"Error",
"Reset",
"Status",
"call",
"gather",
Expand Down
10 changes: 9 additions & 1 deletion src/dispatch/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import dispatch.sdk.v1.dispatch_pb2_grpc as dispatch_grpc
from dispatch.experimental.durable import durable
from dispatch.id import DispatchID
from dispatch.proto import Arguments, Call, Error, Input, Output
from dispatch.proto import Arguments, Call, Error, Input, Output, TailCall
from dispatch.scheduler import OneShotScheduler

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -157,6 +157,14 @@ def build_call(
)


class Reset(TailCall):
"""The current coroutine is aborted and scheduling reset to be replaced with
the call embedded in this exception."""

def __init__(self, func: Function[P, T], *args: P.args, **kwargs: P.kwargs):
super().__init__(call=func.build_call(*args, correlation_id=None, **kwargs))


class Registry:
"""Registry of functions."""

Expand Down
18 changes: 16 additions & 2 deletions src/dispatch/proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,20 @@
from dispatch.sdk.v1 import status_pb2 as status_pb
from dispatch.status import Status, status_for_error, status_for_output


class TailCall(Exception):
"""The current coroutine is aborted and scheduling reset to be replaced with
the call embedded in this exception."""

call: Call
status: Status

def __init__(self, call: Call, status: Status = Status.TEMPORARY_ERROR):
super().__init__(f"reset coroutine to ${call.function}")
self.call = call
self.status = status


# Most types in this package are thin wrappers around the various protobuf
# messages of dispatch.sdk.v1. They provide some safeguards and ergonomics.

Expand Down Expand Up @@ -175,10 +189,10 @@ def error(cls, error: Error) -> Output:
return cls.exit(result=CallResult.from_error(error), status=error.status)

@classmethod
def tail_call(cls, tail_call: Call) -> Output:
def tail_call(cls, tail_call: Call, status: Status = Status.OK) -> Output:
"""Terminally exit the function, and instruct the orchestrator to
tail call the specified function."""
return cls.exit(tail_call=tail_call)
return cls.exit(tail_call=tail_call, status=status)

@classmethod
def exit(
Expand Down
19 changes: 17 additions & 2 deletions src/dispatch/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from dispatch.coroutine import AllDirective, AnyDirective, AnyException, RaceDirective
from dispatch.error import IncompatibleStateError
from dispatch.experimental.durable.function import DurableCoroutine, DurableGenerator
from dispatch.proto import Call, Error, Input, Output
from dispatch.proto import Call, Error, Input, Output, TailCall
from dispatch.status import Status

logger = logging.getLogger(__name__)
Expand All @@ -37,6 +37,8 @@ class CoroutineResult:
coroutine_id: CoroutineID
value: Optional[Any] = None
error: Optional[Exception] = None
call: Optional[Call] = None
status: Status = Status.OK


@dataclass
Expand Down Expand Up @@ -438,6 +440,10 @@ def _run(self, input: Input) -> Output:
coroutine_result: Optional[CoroutineResult] = None
try:
coroutine_yield = coroutine.run()
except TailCall as tc:
coroutine_result = CoroutineResult(
coroutine_id=coroutine.id, call=tc.call, status=tc.status
)
except StopIteration as e:
coroutine_result = CoroutineResult(
coroutine_id=coroutine.id, value=e.value
Expand All @@ -450,7 +456,11 @@ def _run(self, input: Input) -> Output:

# Handle coroutines that return or raise.
if coroutine_result is not None:
if coroutine_result.error is not None:
if coroutine_result.call is not None:
logger.debug(
"%s reset to %s", coroutine, coroutine_result.call.function
)
elif coroutine_result.error is not None:
logger.debug("%s raised %s", coroutine, coroutine_result.error)
else:
logger.debug("%s returned %s", coroutine, coroutine_result.value)
Expand All @@ -463,6 +473,11 @@ def _run(self, input: Input) -> Output:
return Output.error(
Error.from_exception(coroutine_result.error)
)
if coroutine_result.call is not None:
return Output.tail_call(
tail_call=coroutine_result.call,
status=coroutine_result.status,
)
return Output.value(coroutine_result.value)

# Otherwise, notify the parent of the result.
Expand Down
23 changes: 22 additions & 1 deletion tests/dispatch/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from dispatch.coroutine import AnyException, any, call, gather, race
from dispatch.experimental.durable import durable
from dispatch.proto import Call, CallResult, Error, Input, Output
from dispatch.proto import Arguments, Call, CallResult, Error, Input, Output, TailCall
from dispatch.proto import _any_unpickle as any_unpickle
from dispatch.scheduler import (
AllFuture,
Expand Down Expand Up @@ -372,6 +372,21 @@ async def main():
output = self.start(main)
self.assert_exit_result_error(output, ValueError, "oops")

def test_raise_reset(self):
@durable
async def main(x: int, y: int):
raise TailCall(
call=Call(
function="main", input=Arguments((), {"x": x + 1, "y": y + 2})
)
)

output = self.start(main, x=1, y=2)
self.assert_exit_tail_call(
output,
tail_call=Call(function="main", input=Arguments((), {"x": 2, "y": 4})),
)

def test_min_max_results_clamping(self):
@durable
async def main():
Expand Down Expand Up @@ -457,6 +472,12 @@ def assert_exit_result_error(
self.assertEqual(str(error), message)
return error

def assert_exit_tail_call(self, output: Output, tail_call: Call):
exit = self.assert_exit(output)
self.assertFalse(exit.HasField("result"))
self.assertTrue(exit.HasField("tail_call"))
self.assertEqual(tail_call._as_proto(), exit.tail_call)

def assert_poll(self, output: Output) -> poll_pb.Poll:
response = output._message
if response.HasField("exit"):
Expand Down

0 comments on commit 34b6f47

Please sign in to comment.