Skip to content

Commit

Permalink
GRPC client and server for ExecutorService
Browse files Browse the repository at this point in the history
  • Loading branch information
pelletier committed Jan 29, 2024
1 parent e1d3904 commit 5d24447
Show file tree
Hide file tree
Showing 4 changed files with 231 additions and 7 deletions.
39 changes: 34 additions & 5 deletions src/dispatch/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
"""

import ring.coroutine.v1.coroutine_pb2

import os
import fastapi
from fastapi.responses import PlainTextResponse
import fastapi.responses
import google.protobuf.wrappers_pb2


def configure(
Expand All @@ -23,7 +26,6 @@ def configure(
mount_path: The path to mount Dispatch programmable endpoints at.
"""
api_key = api_key or os.environ.get("DISPATCH_API_KEY")
api_url = api_url or "https://api.stealthrocket.cloud"

if not app:
raise ValueError("app is required")
Expand All @@ -34,10 +36,37 @@ def configure(
if not mount_path:
raise ValueError("mount_path is required")

dispatch_app = fastapi.FastAPI()
dispatch_app = _new_app()

app.mount(mount_path, dispatch_app)


class GRPCResponse(fastapi.Response):
media_type = "application/grpc+proto"


def _new_app():
app = fastapi.FastAPI()

@dispatch_app.get("/", response_class=PlainTextResponse)
@app.get("/", response_class=fastapi.responses.PlainTextResponse)
def read_root():
return "ok"

app.mount(mount_path, dispatch_app)
@app.post("/ring.coroutine.v1.ExecutorService/Execute", response_class=GRPCResponse)
async def execute(request: fastapi.Request):
data: bytes = await request.body()

req = ring.coroutine.v1.coroutine_pb2.ExecuteRequest.FromString(data)

# TODO: unpack any
input = google.protobuf.wrappers_pb2.StringValue
req.input.Unpack(input)

resp = ring.coroutine.v1.coroutine_pb2.ExecuteResponse(
coroutine_uri=req.coroutine_uri,
coroutine_version=req.coroutine_version,
)

return resp.SerializeToString()

return app
157 changes: 157 additions & 0 deletions tests/executor_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import httpx
import grpc
import ring.coroutine.v1.coroutine_pb2
import ring.coroutine.v1.coroutine_pb2_grpc


class UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
def __init__(self, client, method, request_serializer, response_deserializer):
self.client = client
self.method = method
self.request_serializer = request_serializer
self.response_deserializer = response_deserializer

def __call__(
self,
request,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None,
compression=None,
):
"""Synchronously invokes the underlying RPC.
Args:
request: The request value for the RPC.
timeout: An optional duration of time in seconds to allow
for the RPC.
metadata: Optional :term:`metadata` to be transmitted to the
service-side of the RPC.
credentials: An optional CallCredentials for the RPC. Only valid for
secure Channel.
wait_for_ready: An optional flag to enable :term:`wait_for_ready` mechanism.
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip.
Returns:
The response value for the RPC.
Raises:
RpcError: Indicating that the RPC terminated with non-OK status. The
raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details.
"""

response = self.client.post(
self.method,
content=self.request_serializer(request),
headers={"Content-Type": "application/grpc+proto"},
)
response.raise_for_status()
return self.response_deserializer(response.content)

def with_call(
self,
request,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None,
compression=None,
):
raise NotImplementedError()

def future(
self,
request,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None,
compression=None,
):
raise NotImplementedError()


class HttpxGrpcChannel(grpc.Channel):
def __init__(self, http_client: httpx.Client):
self.http_client = http_client

def subscribe(self, callback, try_to_connect=False):
raise NotImplementedError()

def unsubscribe(self, callback):
raise NotImplementedError()

def unary_unary(self, method, request_serializer=None, response_deserializer=None):
"""Creates a UnaryUnaryMultiCallable for a unary-unary method.
Args:
method: The name of the RPC method.
request_serializer: Optional :term:`serializer` for serializing the request
message. Request goes unserialized in case None is passed.
response_deserializer: Optional :term:`deserializer` for deserializing the
response message. Response goes undeserialized in case None
is passed.
Returns:
A UnaryUnaryMultiCallable value for the named unary-unary method.
"""
return UnaryUnaryMultiCallable(
self.http_client, method, request_serializer, response_deserializer
)

def unary_stream(self, method, request_serializer=None, response_deserializer=None):
raise NotImplementedError()

def stream_unary(self, method, request_serializer=None, response_deserializer=None):
raise NotImplementedError()

def stream_stream(
self, method, request_serializer=None, response_deserializer=None
):
raise NotImplementedError()

def close(self):
raise NotImplementedError()

def __enter__(self):
raise NotImplementedError()

def __exit__(self, exc_type, exc_val, exc_tb):
raise NotImplementedError()


def client(
http_client: httpx.Client,
) -> ring.coroutine.v1.coroutine_pb2_grpc.ExecutorServiceStub:
channel = HttpxGrpcChannel(http_client)
return ring.coroutine.v1.coroutine_pb2_grpc.ExecutorServiceStub(channel)


# class GrpcHttpxClient:
# """Client for the ring.coroutine.v1.ExecutorService gRPC service over an
# httpx client.
# """

# def __init__(self, http_client: httpx.Client):
# self.http_client = http_client

# def execute(self, request: ring.coroutine.v1.coroutine_pb2.ExecuteRequest) -> ring.coroutine.v1.coroutine_pb2.ExecuteResponse:
# """Execute a coroutine.

# Args:
# request: The request to execute.

# Returns:
# The response from the coroutine.
# """

# response = self.http_client.post(
# "/ring.coroutine.v1.ExecutorService/Execute",
# content=request.SerializeToString(),
# headers={"Content-Type": "application/grpc+proto"},
# )
# response.raise_for_status()
# return ring.coroutine.v1.coroutine_pb2.ExecuteResponse.FromString(response.content)
41 changes: 40 additions & 1 deletion tests/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
import fastapi
from fastapi.testclient import TestClient

import ring.coroutine.v1.coroutine_pb2
from . import executor_service


class TestFastAPI(unittest.TestCase):
def test_fastapi(self):
def test_configure(self):
app = fastapi.FastAPI()

dispatch.fastapi.configure(app, api_key="test-key")
Expand All @@ -25,3 +28,39 @@ def read_root():
resp = client.get("/dispatch/")
self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.text, "ok")

def test_configure_no_app(self):
with self.assertRaises(ValueError):
dispatch.fastapi.configure(None, api_key="test-key")

def test_configure_no_api_key(self):
app = fastapi.FastAPI()
with self.assertRaises(ValueError):
dispatch.fastapi.configure(app, api_key=None)

def test_configure_no_api_url(self):
app = fastapi.FastAPI()
with self.assertRaises(ValueError):
dispatch.fastapi.configure(app, api_key="test-key", api_url=None)

def test_configure_no_mount_path(self):
app = fastapi.FastAPI()
with self.assertRaises(ValueError):
dispatch.fastapi.configure(app, api_key="test-key", mount_path=None)

def test_fastapi_empty_request(self):
app = dispatch.fastapi._new_app()
http_client = TestClient(app)

client = executor_service.client(http_client)

req = ring.coroutine.v1.coroutine_pb2.ExecuteRequest(
coroutine_uri="my-cool-coroutine",
coroutine_version="1",
)

resp = client.Execute(req)

self.assertIsInstance(resp, ring.coroutine.v1.coroutine_pb2.ExecuteResponse)
self.assertEqual(resp.coroutine_uri, req.coroutine_uri)
self.assertEqual(resp.coroutine_version, req.coroutine_version)
1 change: 0 additions & 1 deletion tests/test_ring.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def tearDown(self):
self.thread_pool.shutdown(wait=True, cancel_futures=True)

def test_ring(self):

request = dispatch.http.v1.http_pb2.Request(
url="https://www.google.com", method="GET"
)
Expand Down

0 comments on commit 5d24447

Please sign in to comment.