Skip to content

Commit

Permalink
Sketch input API
Browse files Browse the repository at this point in the history
  • Loading branch information
pelletier committed Jan 30, 2024
1 parent c15d708 commit d117e28
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 12 deletions.
5 changes: 2 additions & 3 deletions src/dispatch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
import ring.task.v1.service_pb2_grpc as task_grpc

service = task_grpc.Service()
"""The Dispatch SDK for Python.
"""
44 changes: 44 additions & 0 deletions src/dispatch/coroutine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Dispatch coroutine interface.
"""

from typing import Any
from dataclasses import dataclass
import pickle


class Input:
"""The input to a coroutine.
Coroutines always take a single argument of type Input. If the coroutine is
started, it contains the input to the coroutine. If the coroutine is
resumed, it contains the saved state and response to any poll requests. Use
the is_first_call and is_resume properties to differentiate between the two
cases.
This class is intended to be used as read-only.
"""

# TODO: first implementation with a single Input type, but we should
# consider using some dynamic filling positional and keyword arguments.

def __init__(self, input: None | bytes, poll_response: None | Any):
# _has_input is used to tracked whether some bytes were provided, to
# differentiate with a pickled None.
self._has_input = input is not None
if input is not None:
self._input = pickle.loads(input) if len(input) > 0 else None

@property
def is_first_call(self) -> bool:
return self._has_input

@property
def is_resume(self) -> bool:
return not self.is_first_call

@property
def input(self) -> Any:
if not self._has_input:
raise ValueError("This input is for a resumed coroutine")
return self._input
22 changes: 17 additions & 5 deletions src/dispatch/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def read_root():
import fastapi
import fastapi.responses
import google.protobuf.wrappers_pb2
import dispatch.coroutine


def configure(
Expand Down Expand Up @@ -105,20 +106,31 @@ def read_root():
return "ok"

@app.post(
"/ring.coroutine.v1.ExecutorService/Execute", response_class=_GRPCResponse
# The endpoint for execution is hardcoded at the moment. If the service
# gains more endpoints, this should be turned into a dynamic dispatch
# like the official gRPC server does.
"/ring.coroutine.v1.ExecutorService/Execute",
response_class=_GRPCResponse,
)
async def execute(request: fastapi.Request):
# Raw request body bytes are only available through the underlying
# starlette Request object's body method, which returns an awaitable,
# forcing execute() to be async.
data: bytes = await request.body()

req = ring.coroutine.v1.coroutine_pb2.ExecuteRequest.FromString(data)

# TODO: be more graceful. This will crash if the coroutine is not found,
# and the coroutine version is not taken into account.
coroutine = app._coroutines[_coroutine_uri_to_qualname(req.coroutine_uri)]

# TODO: unpack any
input = google.protobuf.wrappers_pb2.StringValue()
req.input.Unpack(input)
input_bytes = google.protobuf.wrappers_pb2.BytesValue()
req.input.Unpack(input_bytes)

output = coroutine(input.value)
coro_input = dispatch.coroutine.Input(
input=input_bytes.value, poll_response=None
)
output = coroutine(coro_input)

# TODO pack any
output_pb = google.protobuf.wrappers_pb2.StringValue(value=output)
Expand Down
40 changes: 36 additions & 4 deletions tests/test_fastapi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pickle
import unittest
import dispatch
import dispatch.coroutine
import dispatch.fastapi
import fastapi
from fastapi.testclient import TestClient
Expand Down Expand Up @@ -47,15 +48,17 @@ def test_fastapi_simple_request(self):
app = dispatch.fastapi._new_app()

@app.dispatch_coroutine()
def my_cool_coroutine(input):
return f"You told me: '{input}' ({len(input)} characters)"
def my_cool_coroutine(input: dispatch.coroutine.Input):
return f"You told me: '{input.input}' ({len(input.input)} characters)"

http_client = TestClient(app)

client = executor_service.client(http_client)

pickled = pickle.dumps("Hello World!")
input_any = google.protobuf.any_pb2.Any()
input_any.Pack(google.protobuf.wrappers_pb2.StringValue(value="Hello World!"))
input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=pickled))

req = ring.coroutine.v1.coroutine_pb2.ExecuteRequest(
coroutine_uri=my_cool_coroutine.__qualname__,
coroutine_version="1",
Expand All @@ -72,3 +75,32 @@ def my_cool_coroutine(input):
output := google.protobuf.wrappers_pb2.StringValue()
)
self.assertEqual(output.value, "You told me: 'Hello World!' (12 characters)")


class TestCoroutine(unittest.TestCase):
def setUp(self):
self.app = dispatch.fastapi._new_app()
http_client = TestClient(self.app)
self.client = executor_service.client(http_client)

def execute(self, coroutine):
req = ring.coroutine.v1.coroutine_pb2.ExecuteRequest(
coroutine_uri=coroutine.__qualname__,
coroutine_version="1",
)
resp = self.client.Execute(req)
return resp

def test_no_input(self):
@self.app.dispatch_coroutine()
def my_cool_coroutine(input: dispatch.coroutine.Input):
return "Hello World!"

resp = self.execute(my_cool_coroutine)

self.assertIsInstance(resp, ring.coroutine.v1.coroutine_pb2.ExecuteResponse)

resp.exit.result.output.Unpack(
output := google.protobuf.wrappers_pb2.StringValue()
)
self.assertEqual(output.value, "Hello World!")

0 comments on commit d117e28

Please sign in to comment.