Skip to content

Commit

Permalink
port http tests to generic dispatch test suite
Browse files Browse the repository at this point in the history
Signed-off-by: Achille Roussel <[email protected]>
  • Loading branch information
achille-roussel committed Jun 12, 2024
1 parent ea8cd41 commit ed50efc
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 207 deletions.
12 changes: 11 additions & 1 deletion src/dispatch/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ def read_root():
import fastapi.responses

from dispatch.function import Registry
from dispatch.http import FunctionServiceError, function_service_run
from dispatch.http import (
FunctionServiceError,
function_service_run,
validate_content_length,
)
from dispatch.signature import Ed25519PublicKey, parse_verification_key

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -97,6 +101,12 @@ async def on_error(request: fastapi.Request, exc: FunctionServiceError):
"/Run",
)
async def execute(request: fastapi.Request):
valid, reason = validate_content_length(
int(request.headers.get("content-length", 0))
)
if not valid:
raise FunctionServiceError(400, "invalid_argument", reason)

# Raw request body bytes are only available through the underlying
# starlette Request object's body method, which returns an awaitable,
# forcing execute() to be async.
Expand Down
10 changes: 9 additions & 1 deletion src/dispatch/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ def read_root():
from flask import Flask, make_response, request

from dispatch.function import Registry
from dispatch.http import FunctionServiceError, function_service_run
from dispatch.http import (
FunctionServiceError,
function_service_run,
validate_content_length,
)
from dispatch.signature import Ed25519PublicKey, parse_verification_key

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -87,6 +91,10 @@ def _handle_error(self, exc: FunctionServiceError):
return {"code": exc.code, "message": exc.message}, exc.status

def _execute(self):
valid, reason = validate_content_length(request.content_length or 0)
if not valid:
return {"code": "invalid_argument", "message": reason}, 400

data: bytes = request.get_data(cache=False)

content = asyncio.run(
Expand Down
33 changes: 17 additions & 16 deletions src/dispatch/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
from datetime import timedelta
from http.server import BaseHTTPRequestHandler
from typing import Iterable, List, Mapping, Optional, Union
from typing import Iterable, List, Mapping, Optional, Tuple, Union

from aiohttp import web
from http_message_signatures import InvalidSignature
Expand Down Expand Up @@ -34,6 +34,16 @@ def __init__(self, status, code, message):
self.message = message


def validate_content_length(content_length: int) -> Tuple[bool, str]:
if content_length == 0:
return False, "content length is required"
if content_length < 0:
return False, "content length is negative"
if content_length > 16_000_000:
return False, "content length is too large"
return True, ""


class FunctionService(BaseHTTPRequestHandler):

def __init__(
Expand Down Expand Up @@ -78,14 +88,9 @@ def do_POST(self):
return

content_length = int(self.headers.get("Content-Length", 0))
if content_length == 0:
self.send_error_response_invalid_argument("content length is required")
return
if content_length < 0:
self.send_error_response_invalid_argument("content length is negative")
return
if content_length > 16_000_000:
self.send_error_response_invalid_argument("content length is too large")
valid, reason = validate_content_length(content_length)
if not valid:
self.send_error_response_invalid_argument(reason)
return

data: bytes = self.rfile.read(content_length)
Expand Down Expand Up @@ -229,13 +234,9 @@ async def function_service_run_handler(
function_registry: Registry,
verification_key: Optional[Ed25519PublicKey],
) -> web.Response:
content_length = request.content_length
if content_length is None or content_length == 0:
return make_error_response_invalid_argument("content length is required")
if content_length < 0:
return make_error_response_invalid_argument("content length is negative")
if content_length > 16_000_000:
return make_error_response_invalid_argument("content length is too large")
valid, reason = validate_content_length(request.content_length or 0)
if not valid:
return make_error_response_invalid_argument(reason)

data: bytes = await request.read()
try:
Expand Down
60 changes: 60 additions & 0 deletions src/dispatch/test/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import json
import threading
import unittest
from datetime import datetime, timedelta
Expand Down Expand Up @@ -289,6 +290,16 @@ def wrapper(self: T):
return wrapper


def aiotest(
fn: Callable[["TestCase"], Coroutine[Any, Any, None]]
) -> Callable[["TestCase"], None]:
@wraps(fn)
def wrapper(self):
self.loop.run_until_complete(fn(self))

return wrapper


class TestCase(unittest.TestCase):

def dispatch_test_init(self, api_key: str, api_url: str) -> BaseRegistry:
Expand Down Expand Up @@ -325,6 +336,42 @@ def tearDown(self):
self.loop.run_until_complete(self.loop.shutdown_asyncgens())
self.loop.close()

# TODO: let's figure out how to get rid of this global registry
# state at some point, which forces tests to be run sequentially.
dispatch.experimental.durable.registry.clear_functions()

@aiotest
async def test_content_length_missing(self):
async with aiohttp.ClientSession(
request_class=ClientRequestContentLengthMissing
) as session:
async with await session.post(
f"{self.dispatch.endpoint}/dispatch.sdk.v1.FunctionService/Run",
) as resp:
data = await resp.read()
print(data)
assert resp.status == 400
assert json.loads(data) == {
"code": "invalid_argument",
"message": "content length is required",
}

@aiotest
async def test_content_length_too_large(self):
async with aiohttp.ClientSession(
request_class=ClientRequestContentLengthTooLarge
) as session:
async with await session.post(
f"{self.dispatch.endpoint}/dispatch.sdk.v1.FunctionService/Run",
) as resp:
data = await resp.read()
print(data)
assert resp.status == 400
assert json.loads(data) == {
"code": "invalid_argument",
"message": "content length is too large",
}

def test_simple_end_to_end(self):
@self.dispatch.function
def my_function(name: str) -> str:
Expand All @@ -335,3 +382,16 @@ async def test():
assert msg == "Hello world: 52"

self.loop.run_until_complete(test())


class ClientRequestContentLengthMissing(aiohttp.ClientRequest):
def update_headers(self, skip_auto_headers):
super().update_headers(skip_auto_headers)
if "Content-Length" in self.headers:
del self.headers["Content-Length"]


class ClientRequestContentLengthTooLarge(aiohttp.ClientRequest):
def update_headers(self, skip_auto_headers):
super().update_headers(skip_auto_headers)
self.headers["Content-Length"] = "16000001"
14 changes: 6 additions & 8 deletions tests/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,26 +33,24 @@
public_key_from_pem,
)
from dispatch.status import Status
from dispatch.test import Client, DispatchServer, DispatchService, EndpointClient
from dispatch.test import EndpointClient
from dispatch.test.fastapi import http_client


class TestFastAPI(dispatch.test.TestCase):

def dispatch_test_init(self, api_key: str, api_url: str) -> Dispatch:
host = "localhost"
port = 56789
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(("localhost", 0))
sock.listen(128)

(host, port) = sock.getsockname()

app = FastAPI()
reg = Dispatch(
app, endpoint=f"http://{host}:{port}", api_key=api_key, api_url=api_url
)

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind((host, port))
sock.listen(128)

config = uvicorn.Config(app, host=host, port=port)
self.sockets = [sock]
self.uvicorn = uvicorn.Server(config)
Expand Down
Loading

0 comments on commit ed50efc

Please sign in to comment.