Skip to content

Commit

Permalink
add integration with aiohttp
Browse files Browse the repository at this point in the history
Signed-off-by: Achille Roussel <[email protected]>
  • Loading branch information
achille-roussel committed Jun 4, 2024
1 parent 7ae5255 commit 401b63a
Show file tree
Hide file tree
Showing 5 changed files with 284 additions and 18 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ readme = "README.md"
dynamic = ["version"]
requires-python = ">= 3.8"
dependencies = [
"aiohttp >= 3.9.4",
"grpcio >= 1.60.0",
"protobuf >= 4.24.0",
"types-protobuf >= 4.24.0.20240129",
Expand Down
131 changes: 131 additions & 0 deletions src/dispatch/aiohttp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from typing import Optional, Union

from aiohttp import web

from dispatch.function import Registry
from dispatch.http import (
FunctionServiceError,
function_service_run,
make_error_response_body,
)
from dispatch.signature import Ed25519PublicKey, parse_verification_key


class Dispatch(web.Application):
"""A Dispatch instance servicing as a http server."""

registry: Registry
verification_key: Optional[Ed25519PublicKey]

def __init__(
self,
registry: Registry,
verification_key: Optional[Union[Ed25519PublicKey, str, bytes]] = None,
):
"""Initialize a Dispatch application.
Args:
registry: The registry of functions to be serviced.
verification_key: The verification key to use for requests.
"""
super().__init__()
self.registry = registry
self.verification_key = parse_verification_key(verification_key)
self.add_routes(
[
web.post(
"/dispatch.sdk.v1.FunctionService/Run", self.handle_run_request
),
]
)

async def handle_run_request(self, request: web.Request) -> web.Response:
return await function_service_run_handler(
request, self.registry, self.verification_key
)


class Server:
host: str
port: int
app: Dispatch

_runner: web.AppRunner
_site: web.TCPSite

def __init__(self, host: str, port: int, app: Dispatch):
self.host = host
self.port = port
self.app = app

async def __aenter__(self):
await self.start()
return self

async def __aexit__(self, exc_type, exc_value, traceback):
await self.stop()

async def start(self):
self._runner = web.AppRunner(self.app)
await self._runner.setup()

self._site = web.TCPSite(self._runner, self.host, self.port)
await self._site.start()

async def stop(self):
await self._site.stop()
await self._runner.cleanup()


def make_error_response(status: int, code: str, message: str) -> web.Response:
body = make_error_response_body(code, message)
return web.Response(status=status, content_type="application/json", body=body)


def make_error_response_invalid_argument(message: str) -> web.Response:
return make_error_response(400, "invalid_argument", message)


def make_error_response_not_found(message: str) -> web.Response:
return make_error_response(404, "not_found", message)


def make_error_response_unauthenticated(message: str) -> web.Response:
return make_error_response(401, "unauthenticated", message)


def make_error_response_permission_denied(message: str) -> web.Response:
return make_error_response(403, "permission_denied", message)


def make_error_response_internal(message: str) -> web.Response:
return make_error_response(500, "internal", message)


async def function_service_run_handler(
request: web.Request,
function_registry: Registry,
verification_key: Optional[Ed25519PublicKey],
) -> web.Response:
content_length = request.content_length
if content_length is None or content_length == 0:
return make_error_response_invalid_argument("content length is required")
if content_length < 0:
return make_error_response_invalid_argument("content length is negative")
if content_length > 16_000_000:
return make_error_response_invalid_argument("content length is too large")

data: bytes = await request.read()
try:
content = await function_service_run(
str(request.url),
request.method,
dict(request.headers),
data,
function_registry,
verification_key,
)
except FunctionServiceError as e:
return make_error_response(e.status, e.code, e.message)
return web.Response(status=200, content_type="application/proto", body=content)
15 changes: 11 additions & 4 deletions src/dispatch/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@


class Dispatch:
"""A Dispatch instance to be serviced by a http server. The Dispatch class
acts as a factory for DispatchHandler objects, by capturing the variables
that would be shared between all DispatchHandler instances it created."""
"""A Dispatch instance servicing as a http server."""

registry: Registry
verification_key: Optional[Ed25519PublicKey]

def __init__(
self,
Expand All @@ -38,6 +39,8 @@ def __init__(
Args:
registry: The registry of functions to be serviced.
verification_key: The verification key to use for requests.
"""
self.registry = registry
self.verification_key = parse_verification_key(verification_key)
Expand Down Expand Up @@ -92,7 +95,7 @@ def send_error_response_internal(self, message: str):
self.send_error_response(500, "internal", message)

def send_error_response(self, status: int, code: str, message: str):
body = f'{{"code":"{code}","message":"{message}"}}'.encode()
body = make_error_response_body(code, message)
self.send_response(status)
self.send_header("Content-Type", self.error_content_type)
self.send_header("Content-Length", str(len(body)))
Expand Down Expand Up @@ -234,3 +237,7 @@ async def function_service_run(

logger.debug("finished handling run request with status %s", status.name)
return response.SerializeToString()


def make_error_response_body(code: str, message: str) -> bytes:
return f'{{"code":"{code}","message":"{message}"}}'.encode()
127 changes: 127 additions & 0 deletions tests/test_aiohttp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import asyncio
import base64
import os
import pickle
import struct
import threading
import unittest
from typing import Any, Tuple
from unittest import mock

import fastapi
import google.protobuf.any_pb2
import google.protobuf.wrappers_pb2
import httpx
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey

import dispatch.test.httpx
from dispatch.experimental.durable.registry import clear_functions
from dispatch.function import Arguments, Error, Function, Input, Output, Registry
from dispatch.asyncio import Runner
from dispatch.aiohttp import Dispatch, Server
from dispatch.proto import _any_unpickle as any_unpickle
from dispatch.sdk.v1 import call_pb2 as call_pb
from dispatch.sdk.v1 import function_pb2 as function_pb
from dispatch.signature import parse_verification_key, public_key_from_pem
from dispatch.status import Status
from dispatch.test import EndpointClient

public_key_pem = "-----BEGIN PUBLIC KEY-----\nMCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=\n-----END PUBLIC KEY-----"
public_key_pem2 = "-----BEGIN PUBLIC KEY-----\\nMCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=\\n-----END PUBLIC KEY-----"
public_key = public_key_from_pem(public_key_pem)
public_key_bytes = public_key.public_bytes_raw()
public_key_b64 = base64.b64encode(public_key_bytes)

from datetime import datetime


def run(runner: Runner, server: Server, ready: threading.Event):
try:
with runner:
runner.run(serve(server, ready))
except RuntimeError as e:
pass # silence errors triggered by stopping the loop after tests are done

async def serve(server: Server, ready: threading.Event):
async with server:
ready.set() # allow the test to continue after the server started
await asyncio.Event().wait()


class TestAIOHTTP(unittest.TestCase):
def setUp(self):
ready = threading.Event()
self.runner = Runner()

host = "127.0.0.1"
port = 9997

self.endpoint = f"http://{host}:{port}"
self.dispatch = Dispatch(
Registry(
endpoint=self.endpoint,
api_key="0000000000000000",
api_url="http://127.0.0.1:10000",
),
)

self.client = httpx.Client(timeout=1.0)
self.server = Server(host, port, self.dispatch)
self.thread = threading.Thread(target=lambda: run(self.runner, self.server, ready))
self.thread.start()
ready.wait()

def tearDown(self):
loop = self.runner.get_loop()
loop.call_soon_threadsafe(loop.stop)
self.thread.join(timeout=1.0)
self.client.close()

def test_content_length_missing(self):
resp = self.client.post(f"{self.endpoint}/dispatch.sdk.v1.FunctionService/Run")
body = resp.read()
self.assertEqual(resp.status_code, 400)
self.assertEqual(
body, b'{"code":"invalid_argument","message":"content length is required"}'
)

def test_content_length_too_large(self):
resp = self.client.post(
f"{self.endpoint}/dispatch.sdk.v1.FunctionService/Run",
data={"msg": "a" * 16_000_001},
)
body = resp.read()
self.assertEqual(resp.status_code, 400)
self.assertEqual(
body, b'{"code":"invalid_argument","message":"content length is too large"}'
)

def test_simple_request(self):
@self.dispatch.registry.primitive_function
async def my_function(input: Input) -> Output:
return Output.value(
f"You told me: '{input.input}' ({len(input.input)} characters)"
)

http_client = dispatch.test.httpx.Client(httpx.Client(base_url=self.endpoint))
client = EndpointClient(http_client)

pickled = pickle.dumps("Hello World!")
input_any = google.protobuf.any_pb2.Any()
input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=pickled))

req = function_pb.RunRequest(
function=my_function.name,
input=input_any,
)

resp = client.run(req)

self.assertIsInstance(resp, function_pb.RunResponse)

resp.exit.result.output.Unpack(
output_bytes := google.protobuf.wrappers_pb2.BytesValue()
)
output = pickle.loads(output_bytes.value)

self.assertEqual(output, "You told me: 'Hello World!' (12 characters)")
28 changes: 14 additions & 14 deletions tests/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import threading
import unittest
from http.server import HTTPServer
from typing import Any
from typing import Any, Tuple
from unittest import mock

import fastapi
Expand Down Expand Up @@ -34,21 +34,21 @@
from datetime import datetime


def create_dispatch_instance(endpoint: str):
return Dispatch(
Registry(
endpoint=endpoint,
api_key="0000000000000000",
api_url="http://127.0.0.1:10000",
),
)


class TestHTTP(unittest.TestCase):
def setUp(self):
self.server_address = ("127.0.0.1", 9999)
self.endpoint = f"http://{self.server_address[0]}:{self.server_address[1]}"
self.dispatch = create_dispatch_instance(self.endpoint)
host = "127.0.0.1"
port = 9999

self.server_address = (host, port)
self.endpoint = f"http://{host}:{port}"
self.dispatch = Dispatch(
Registry(
endpoint=self.endpoint,
api_key="0000000000000000",
api_url="http://127.0.0.1:10000",
),
)

self.client = httpx.Client(timeout=1.0)
self.server = HTTPServer(self.server_address, self.dispatch)
self.thread = threading.Thread(
Expand Down

0 comments on commit 401b63a

Please sign in to comment.