-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Achille Roussel <[email protected]>
- Loading branch information
1 parent
7ae5255
commit 401b63a
Showing
5 changed files
with
284 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
from typing import Optional, Union | ||
|
||
from aiohttp import web | ||
|
||
from dispatch.function import Registry | ||
from dispatch.http import ( | ||
FunctionServiceError, | ||
function_service_run, | ||
make_error_response_body, | ||
) | ||
from dispatch.signature import Ed25519PublicKey, parse_verification_key | ||
|
||
|
||
class Dispatch(web.Application): | ||
"""A Dispatch instance servicing as a http server.""" | ||
|
||
registry: Registry | ||
verification_key: Optional[Ed25519PublicKey] | ||
|
||
def __init__( | ||
self, | ||
registry: Registry, | ||
verification_key: Optional[Union[Ed25519PublicKey, str, bytes]] = None, | ||
): | ||
"""Initialize a Dispatch application. | ||
Args: | ||
registry: The registry of functions to be serviced. | ||
verification_key: The verification key to use for requests. | ||
""" | ||
super().__init__() | ||
self.registry = registry | ||
self.verification_key = parse_verification_key(verification_key) | ||
self.add_routes( | ||
[ | ||
web.post( | ||
"/dispatch.sdk.v1.FunctionService/Run", self.handle_run_request | ||
), | ||
] | ||
) | ||
|
||
async def handle_run_request(self, request: web.Request) -> web.Response: | ||
return await function_service_run_handler( | ||
request, self.registry, self.verification_key | ||
) | ||
|
||
|
||
class Server: | ||
host: str | ||
port: int | ||
app: Dispatch | ||
|
||
_runner: web.AppRunner | ||
_site: web.TCPSite | ||
|
||
def __init__(self, host: str, port: int, app: Dispatch): | ||
self.host = host | ||
self.port = port | ||
self.app = app | ||
|
||
async def __aenter__(self): | ||
await self.start() | ||
return self | ||
|
||
async def __aexit__(self, exc_type, exc_value, traceback): | ||
await self.stop() | ||
|
||
async def start(self): | ||
self._runner = web.AppRunner(self.app) | ||
await self._runner.setup() | ||
|
||
self._site = web.TCPSite(self._runner, self.host, self.port) | ||
await self._site.start() | ||
|
||
async def stop(self): | ||
await self._site.stop() | ||
await self._runner.cleanup() | ||
|
||
|
||
def make_error_response(status: int, code: str, message: str) -> web.Response: | ||
body = make_error_response_body(code, message) | ||
return web.Response(status=status, content_type="application/json", body=body) | ||
|
||
|
||
def make_error_response_invalid_argument(message: str) -> web.Response: | ||
return make_error_response(400, "invalid_argument", message) | ||
|
||
|
||
def make_error_response_not_found(message: str) -> web.Response: | ||
return make_error_response(404, "not_found", message) | ||
|
||
|
||
def make_error_response_unauthenticated(message: str) -> web.Response: | ||
return make_error_response(401, "unauthenticated", message) | ||
|
||
|
||
def make_error_response_permission_denied(message: str) -> web.Response: | ||
return make_error_response(403, "permission_denied", message) | ||
|
||
|
||
def make_error_response_internal(message: str) -> web.Response: | ||
return make_error_response(500, "internal", message) | ||
|
||
|
||
async def function_service_run_handler( | ||
request: web.Request, | ||
function_registry: Registry, | ||
verification_key: Optional[Ed25519PublicKey], | ||
) -> web.Response: | ||
content_length = request.content_length | ||
if content_length is None or content_length == 0: | ||
return make_error_response_invalid_argument("content length is required") | ||
if content_length < 0: | ||
return make_error_response_invalid_argument("content length is negative") | ||
if content_length > 16_000_000: | ||
return make_error_response_invalid_argument("content length is too large") | ||
|
||
data: bytes = await request.read() | ||
try: | ||
content = await function_service_run( | ||
str(request.url), | ||
request.method, | ||
dict(request.headers), | ||
data, | ||
function_registry, | ||
verification_key, | ||
) | ||
except FunctionServiceError as e: | ||
return make_error_response(e.status, e.code, e.message) | ||
return web.Response(status=200, content_type="application/proto", body=content) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
import asyncio | ||
import base64 | ||
import os | ||
import pickle | ||
import struct | ||
import threading | ||
import unittest | ||
from typing import Any, Tuple | ||
from unittest import mock | ||
|
||
import fastapi | ||
import google.protobuf.any_pb2 | ||
import google.protobuf.wrappers_pb2 | ||
import httpx | ||
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey | ||
|
||
import dispatch.test.httpx | ||
from dispatch.experimental.durable.registry import clear_functions | ||
from dispatch.function import Arguments, Error, Function, Input, Output, Registry | ||
from dispatch.asyncio import Runner | ||
from dispatch.aiohttp import Dispatch, Server | ||
from dispatch.proto import _any_unpickle as any_unpickle | ||
from dispatch.sdk.v1 import call_pb2 as call_pb | ||
from dispatch.sdk.v1 import function_pb2 as function_pb | ||
from dispatch.signature import parse_verification_key, public_key_from_pem | ||
from dispatch.status import Status | ||
from dispatch.test import EndpointClient | ||
|
||
public_key_pem = "-----BEGIN PUBLIC KEY-----\nMCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=\n-----END PUBLIC KEY-----" | ||
public_key_pem2 = "-----BEGIN PUBLIC KEY-----\\nMCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=\\n-----END PUBLIC KEY-----" | ||
public_key = public_key_from_pem(public_key_pem) | ||
public_key_bytes = public_key.public_bytes_raw() | ||
public_key_b64 = base64.b64encode(public_key_bytes) | ||
|
||
from datetime import datetime | ||
|
||
|
||
def run(runner: Runner, server: Server, ready: threading.Event): | ||
try: | ||
with runner: | ||
runner.run(serve(server, ready)) | ||
except RuntimeError as e: | ||
pass # silence errors triggered by stopping the loop after tests are done | ||
|
||
async def serve(server: Server, ready: threading.Event): | ||
async with server: | ||
ready.set() # allow the test to continue after the server started | ||
await asyncio.Event().wait() | ||
|
||
|
||
class TestAIOHTTP(unittest.TestCase): | ||
def setUp(self): | ||
ready = threading.Event() | ||
self.runner = Runner() | ||
|
||
host = "127.0.0.1" | ||
port = 9997 | ||
|
||
self.endpoint = f"http://{host}:{port}" | ||
self.dispatch = Dispatch( | ||
Registry( | ||
endpoint=self.endpoint, | ||
api_key="0000000000000000", | ||
api_url="http://127.0.0.1:10000", | ||
), | ||
) | ||
|
||
self.client = httpx.Client(timeout=1.0) | ||
self.server = Server(host, port, self.dispatch) | ||
self.thread = threading.Thread(target=lambda: run(self.runner, self.server, ready)) | ||
self.thread.start() | ||
ready.wait() | ||
|
||
def tearDown(self): | ||
loop = self.runner.get_loop() | ||
loop.call_soon_threadsafe(loop.stop) | ||
self.thread.join(timeout=1.0) | ||
self.client.close() | ||
|
||
def test_content_length_missing(self): | ||
resp = self.client.post(f"{self.endpoint}/dispatch.sdk.v1.FunctionService/Run") | ||
body = resp.read() | ||
self.assertEqual(resp.status_code, 400) | ||
self.assertEqual( | ||
body, b'{"code":"invalid_argument","message":"content length is required"}' | ||
) | ||
|
||
def test_content_length_too_large(self): | ||
resp = self.client.post( | ||
f"{self.endpoint}/dispatch.sdk.v1.FunctionService/Run", | ||
data={"msg": "a" * 16_000_001}, | ||
) | ||
body = resp.read() | ||
self.assertEqual(resp.status_code, 400) | ||
self.assertEqual( | ||
body, b'{"code":"invalid_argument","message":"content length is too large"}' | ||
) | ||
|
||
def test_simple_request(self): | ||
@self.dispatch.registry.primitive_function | ||
async def my_function(input: Input) -> Output: | ||
return Output.value( | ||
f"You told me: '{input.input}' ({len(input.input)} characters)" | ||
) | ||
|
||
http_client = dispatch.test.httpx.Client(httpx.Client(base_url=self.endpoint)) | ||
client = EndpointClient(http_client) | ||
|
||
pickled = pickle.dumps("Hello World!") | ||
input_any = google.protobuf.any_pb2.Any() | ||
input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=pickled)) | ||
|
||
req = function_pb.RunRequest( | ||
function=my_function.name, | ||
input=input_any, | ||
) | ||
|
||
resp = client.run(req) | ||
|
||
self.assertIsInstance(resp, function_pb.RunResponse) | ||
|
||
resp.exit.result.output.Unpack( | ||
output_bytes := google.protobuf.wrappers_pb2.BytesValue() | ||
) | ||
output = pickle.loads(output_bytes.value) | ||
|
||
self.assertEqual(output, "You told me: 'Hello World!' (12 characters)") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters