Skip to content

Commit

Permalink
WIP: fix tests blocked reading http request body
Browse files Browse the repository at this point in the history
Signed-off-by: Achille Roussel <[email protected]>
  • Loading branch information
achille-roussel committed Apr 20, 2024
1 parent f3e982b commit 6376c6e
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 8 deletions.
6 changes: 6 additions & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Dockerfile
__pycache__
*.md
*.yaml
*.yml
dist/*
10 changes: 10 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
FROM python:3.12
WORKDIR /usr/src/dispatch-py

COPY pyproject.toml .
RUN python -m pip install -e .[dev]

COPY . .
RUN python -m pip install -e .[dev]

ENTRYPOINT ["python"]
5 changes: 3 additions & 2 deletions src/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import os
from concurrent import futures
from http.server import HTTPServer
from http.server import ThreadingHTTPServer
from typing import Any, Callable, Coroutine, Optional, TypeVar, overload
from urllib.parse import urlsplit

Expand Down Expand Up @@ -79,7 +79,8 @@ def run(port: str = os.environ.get("DISPATCH_ENDPOINT_ADDR", "localhost:8000")):
DISPATCH_ENDPOINT_ADDR environment variable, or 'localhost:8000' if it
wasn't set.
"""
print(f"Starting Dispatch server on {port}")
parsed_url = urlsplit("//" + port)
server_address = (parsed_url.hostname or "", parsed_url.port or 0)
server = HTTPServer(server_address, Dispatch(_default_registry()))
server = ThreadingHTTPServer(server_address, Dispatch(_default_registry()))
server.serve_forever()
28 changes: 24 additions & 4 deletions src/dispatch/http.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Integration of Dispatch functions with http."""

from datetime import datetime

import logging
import os
from datetime import timedelta
Expand Down Expand Up @@ -61,10 +63,12 @@ def __init__(
registry: Registry,
verification_key: Optional[Ed25519PublicKey] = None,
):
super().__init__(request, client_address, server)
self.registry = registry
self.verification_key = verification_key
self.error_content_type = "application/json"
print(datetime.now(), "INITIALIZING FUNCTION SERVICE")
super().__init__(request, client_address, server)
print(datetime.now(), "DONE HANDLING REQUEST")

def send_error_response_invalid_argument(self, message: str):
self.send_error_response(400, "invalid_argument", message)
Expand All @@ -82,17 +86,33 @@ def send_error_response_internal(self, message: str):
self.send_error_response(500, "internal", message)

def send_error_response(self, status: int, code: str, message: str):
body = f'{{"code":"{code}","message":"{message}"}}'.encode()
self.send_response(status)
self.send_header("Content-Type", self.error_content_type)
self.send_header("Content-Length", str(len(body)))
self.end_headers()
self.wfile.write(f'{{"code":"{code}","message":"{message}"}}'.encode())
print(datetime.now(), "SENDING ERROR RESPONSE")
self.wfile.write(body)
print(datetime.now(), f"SERVER IS DONE {len(body)}")

def do_POST(self):
if self.path != "/dispatch.sdk.v1.FunctionService/Run":
self.send_error_response_not_found("path not found")
return

data: bytes = self.rfile.read()
content_length = int(self.headers.get("Content-Length", 0))
if content_length == 0:
self.send_error_response_invalid_argument("content length is required")
return
if content_length < 0:
self.send_error_response_invalid_argument("content length is negative")
return
if content_length > 16_000_000:
self.send_error_response_invalid_argument("content length is too large")
return

data: bytes = self.rfile.read(content_length)
print(datetime.now(), f"RECEIVED POST REQUEST: {self.path} {len(data)} {self.request_version} {self.headers}")
logger.debug("handling run request with %d byte body", len(data))

if self.verification_key is not None:
Expand Down Expand Up @@ -130,7 +150,7 @@ def do_POST(self):
)
return

logger.info("running function '%s'", req.function)
print(datetime.now(), "running function '%s'", req.function)
try:
output = func._primitive_call(Input(req))
except Exception:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
from fastapi.testclient import TestClient

from dispatch.experimental.durable.registry import clear_functions
from dispatch.fastapi import Dispatch, parse_verification_key
from dispatch.fastapi import Dispatch
from dispatch.function import Arguments, Error, Function, Input, Output
from dispatch.proto import _any_unpickle as any_unpickle
from dispatch.sdk.v1 import call_pb2 as call_pb
from dispatch.sdk.v1 import function_pb2 as function_pb
from dispatch.signature import public_key_from_pem
from dispatch.signature import parse_verification_key, public_key_from_pem
from dispatch.status import Status
from dispatch.test import EndpointClient

Expand Down
65 changes: 65 additions & 0 deletions tests/test_http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import base64
import os
import pickle
import struct
import threading
import unittest
from typing import Any
from unittest import mock

import fastapi
import google.protobuf.any_pb2
import google.protobuf.wrappers_pb2
import httpx
from http.server import HTTPServer
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey

from dispatch.experimental.durable.registry import clear_functions
from dispatch.http import Dispatch
from dispatch.function import Arguments, Error, Function, Input, Output, Registry
from dispatch.proto import _any_unpickle as any_unpickle
from dispatch.sdk.v1 import call_pb2 as call_pb
from dispatch.sdk.v1 import function_pb2 as function_pb
from dispatch.signature import parse_verification_key, public_key_from_pem
from dispatch.status import Status
from dispatch.test import EndpointClient

public_key_pem = "-----BEGIN PUBLIC KEY-----\nMCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=\n-----END PUBLIC KEY-----"
public_key_pem2 = "-----BEGIN PUBLIC KEY-----\\nMCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=\\n-----END PUBLIC KEY-----"
public_key = public_key_from_pem(public_key_pem)
public_key_bytes = public_key.public_bytes_raw()
public_key_b64 = base64.b64encode(public_key_bytes)


def create_dispatch_instance(endpoint: str):
return Dispatch(
Registry(
endpoint=endpoint,
api_key="0000000000000000",
api_url="http://127.0.0.1:10000",
),
)


class TestHTTP(unittest.TestCase):
def setUp(self):
self.server_address = ('127.0.0.1', 9999)
self.endpoint = f"http://{self.server_address[0]}:{self.server_address[1]}"
self.client = httpx.Client(timeout=1.0)
self.server = HTTPServer(self.server_address, create_dispatch_instance(self.endpoint))
self.thread = threading.Thread(target=self.server.serve_forever)
self.thread.start()

def tearDown(self):
self.server.shutdown()
self.thread.join(timeout=1.0)
self.client.close()
self.server.server_close()

def test_Dispatch_defaults(self):
print("POST REQUEST", f"{self.endpoint}/dispatch.sdk.v1.FunctionService/Run")
resp = self.client.post(f"{self.endpoint}/dispatch.sdk.v1.FunctionService/Run")
print(resp.status_code)
print("CLIENT RESPONSE!", resp.headers)
#body = resp.read()
#self.assertEqual(resp.status_code, 400)

0 comments on commit 6376c6e

Please sign in to comment.