Skip to content

Commit

Permalink
First stateful coroutine
Browse files Browse the repository at this point in the history
  • Loading branch information
pelletier committed Jan 30, 2024
1 parent c41ea4b commit 046a1d3
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 28 deletions.
62 changes: 52 additions & 10 deletions src/dispatch/coroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
from __future__ import annotations
from typing import Any
from dataclasses import dataclass
import google.protobuf.message
import pickle
from ring.coroutine.v1 import coroutine_pb2


class Input:
Expand All @@ -32,12 +34,22 @@ class Input:
# TODO: first implementation with a single Input type, but we should
# consider using some dynamic filling positional and keyword arguments.

def __init__(self, input: None | bytes, poll_response: None | Any):
# _has_input is used to tracked whether some bytes were provided, to
# differentiate with a pickled None.
self._has_input = input is not None
if input is not None:
self._input = pickle.loads(input) if len(input) > 0 else None
def __init__(self, req: coroutine_pb2.ExecuteRequest):
self._has_input = req.HasField("input")
if self._has_input:
input_pb = google.protobuf.wrappers_pb2.BytesValue()
req.input.Unpack(input_pb)
input_bytes = input_pb.value
if len(input_bytes) > 0:
self._input = pickle.loads(input_bytes)
else:
self._input = None
else:
state_bytes = req.poll_response.state
if len(state_bytes) > 0:
self._state = pickle.loads(state_bytes)
else:
self._state = None

@property
def is_first_call(self) -> bool:
Expand All @@ -49,10 +61,16 @@ def is_resume(self) -> bool:

@property
def input(self) -> Any:
if not self._has_input:
if self.is_resume:
raise ValueError("This input is for a resumed coroutine")
return self._input

@property
def state(self) -> Any:
if self.is_first_call:
raise ValueError("This input is for a first coroutine call")
return self._state


class Output:
"""The output of a coroutine.
Expand All @@ -61,9 +79,33 @@ class Output:
to indicate the follow up action they need to take.
"""

def __init__(self, value: None | Any = None):
self._value = pickle.dumps(value)
def __init__(self, proto: coroutine_pb2.ExecuteResponse):
self._message = proto

@classmethod
def value(cls, value: Any) -> Output:
return Output(value=value)
"""Terminally exit the coroutine with the provided return value."""
output_any = _pb_any_pickle(value)
return Output(
coroutine_pb2.ExecuteResponse(
exit=coroutine_pb2.Exit(result=coroutine_pb2.Result(output=output_any))
)
)

@classmethod
def callback(cls, state: Any) -> Output:
"""Exit the coroutine instructing the orchestrator to call back this
coroutine with the provided state. The state will be made available in
Input.state."""
state_bytes = pickle.dumps(state)
return Output(
coroutine_pb2.ExecuteResponse(poll=coroutine_pb2.Poll(state=state_bytes))
)


def _pb_any_pickle(x: Any) -> google.protobuf.any_pb2.Any:
value_bytes = pickle.dumps(x)
pb_bytes = google.protobuf.wrappers_pb2.BytesValue(value=value_bytes)
pb_any = google.protobuf.any_pb2.Any()
pb_any.Pack(pb_bytes)
return pb_any
21 changes: 4 additions & 17 deletions src/dispatch/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,30 +124,17 @@ async def execute(request: fastapi.Request):
# and the coroutine version is not taken into account.
coroutine = app._coroutines[_coroutine_uri_to_qualname(req.coroutine_uri)]

input_bytes = google.protobuf.wrappers_pb2.BytesValue()
req.input.Unpack(input_bytes)

coro_input = dispatch.coroutine.Input(
input=input_bytes.value, poll_response=None
)
coro_input = dispatch.coroutine.Input(req)
output = coroutine(coro_input)

if not isinstance(output, dispatch.coroutine.Output):
raise ValueError(
f"coroutine output should be an instance of {dispatch.coroutine.Output}, not {type(output)}"
)

output_pb = google.protobuf.wrappers_pb2.BytesValue(value=output._value)
output_any = google.protobuf.any_pb2.Any()
output_any.Pack(output_pb)

resp = ring.coroutine.v1.coroutine_pb2.ExecuteResponse(
coroutine_uri=req.coroutine_uri,
coroutine_version=req.coroutine_version,
exit=ring.coroutine.v1.coroutine_pb2.Exit(
result=ring.coroutine.v1.coroutine_pb2.Result(output=output_any)
),
)
resp = output._message
resp.coroutine_uri = req.coroutine_uri
resp.coroutine_version = req.coroutine_version

return fastapi.Response(content=resp.SerializeToString())

Expand Down
39 changes: 38 additions & 1 deletion tests/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def setUp(self):
self.client = executor_service.client(http_client)

def execute(
self, coroutine, input=None
self, coroutine, input=None, state=None
) -> ring.coroutine.v1.coroutine_pb2.ExecuteResponse:
"""Test helper to invoke coroutines on the local server."""
req = ring.coroutine.v1.coroutine_pb2.ExecuteRequest(
Expand All @@ -110,6 +110,9 @@ def execute(
input_any = google.protobuf.any_pb2.Any()
input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=input_bytes))
req.input.CopyFrom(input_any)
if state is not None:
print("SENDING BACK STATE: ", state)
req.poll_response.state = state

resp = self.client.Execute(req)
self.assertIsInstance(resp, ring.coroutine.v1.coroutine_pb2.ExecuteResponse)
Expand Down Expand Up @@ -150,3 +153,37 @@ def len_coroutine(input: Input) -> Output:
resp = self.execute(len_coroutine, input=data)
out = response_output(resp)
self.assertEqual(out, "Length: 10")

def test_coroutine_with_state(self):
@self.app.dispatch_coroutine()
def coroutine3(input: Input) -> Output:
if input.is_first_call:
counter = input.input
else:
counter = input.state
counter -= 1
if counter <= 0:
return Output.value("done")
return Output.callback(state=counter)

# first call
resp = self.execute(coroutine3, input=4)
state = resp.poll.state
self.assertTrue(len(state) > 0)

# resume, state = 3
resp = self.execute(coroutine3, state=state)
state = resp.poll.state
self.assertTrue(len(state) > 0)

# resume, state = 2
resp = self.execute(coroutine3, state=state)
state = resp.poll.state
self.assertTrue(len(state) > 0)

# resume, state = 1
resp = self.execute(coroutine3, state=state)
state = resp.poll.state
self.assertTrue(len(state) == 0)
out = response_output(resp)
self.assertEqual(out, "done")

0 comments on commit 046a1d3

Please sign in to comment.