Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
Signed-off-by: Achille Roussel <[email protected]>
  • Loading branch information
achille-roussel committed Apr 22, 2024
1 parent 6fb1400 commit 6e14930
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 81 deletions.
14 changes: 10 additions & 4 deletions src/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,21 +91,27 @@ def run(init: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:


@contextmanager
def serve(address: str = os.environ.get("DISPATCH_ENDPOINT_ADDR", "localhost:8000")):
def serve(
address: str = os.environ.get("DISPATCH_ENDPOINT_ADDR", "localhost:8000"),
poll_interval: float = 0.5,
):
"""Returns a context manager managing the operation of a Disaptch server
running on the given address. The server is initialized before the context
manager yields, then runs forever until the the program is interrupted.
Args:
address: The address to bind the server to. Defaults to the value of the
DISPATCH_ENDPOINT_ADDR environment variable, or 'localhost:8000' if it
wasn't set.
DISPATCH_ENDPOINT_ADDR environment variable, or 'localhost:8000' if
it wasn't set.
poll_interval: Poll for shutdown every poll_interval seconds.
Defaults to 0.5 seconds.
"""
parsed_url = urlsplit("//" + address)
server_address = (parsed_url.hostname or "", parsed_url.port or 0)
server = ThreadingHTTPServer(server_address, Dispatch(default_registry()))
try:
yield server
server.serve_forever()
server.serve_forever(poll_interval=poll_interval)
finally:
server.server_close()
84 changes: 84 additions & 0 deletions tests/dispatch/signature/test_signature.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import base64
import os
import unittest
from datetime import datetime, timedelta
from unittest import mock

from http_message_signatures import HTTPMessageSigner
from http_message_signatures._algorithms import ED25519

from dispatch.signature import (
CaseInsensitiveDict,
Ed25519PublicKey,
InvalidSignature,
Request,
parse_verification_key,
sign_request,
verify_request,
)
Expand All @@ -33,6 +38,18 @@
"""
)

public_key2_pem = """-----BEGIN PUBLIC KEY-----
MCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=
-----END PUBLIC KEY-----
"""
public_key2_pem2 = """-----BEGIN PUBLIC KEY-----
MCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=
-----END PUBLIC KEY-----
"""
public_key2 = public_key_from_pem(public_key2_pem)
public_key2_bytes = public_key2.public_bytes_raw()
public_key2_b64 = base64.b64encode(public_key2_bytes)


class TestSignature(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -125,3 +142,70 @@ def test_known_signature(self):
ValueError, "public key 'test-key-ed25519' not available"
):
verify_request(request, public_key, max_age=timedelta(weeks=9000))

@mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": public_key2_pem})
def test_parse_verification_key_env_pem_str(self):
verification_key = parse_verification_key(None)
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes)

@mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": public_key2_pem2})
def test_parse_verification_key_env_pem_escaped_newline_str(self):
verification_key = parse_verification_key(None)
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes)

@mock.patch.dict(
os.environ, {"DISPATCH_VERIFICATION_KEY": public_key2_b64.decode()}
)
def test_parse_verification_key_env_b64_str(self):
verification_key = parse_verification_key(None)
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes)

def test_parse_verification_key_none(self):
# The verification key is optional. Both Dispatch(verification_key=...) and
# DISPATCH_VERIFICATION_KEY may be omitted/None.
verification_key = parse_verification_key(None)
self.assertIsNone(verification_key)

def test_parse_verification_key_ed25519publickey(self):
verification_key = parse_verification_key(public_key2)
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes)

def test_parse_verification_key_pem_str(self):
verification_key = parse_verification_key(public_key2_pem)
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes)

def test_parse_verification_key_pem_escaped_newline_str(self):
verification_key = parse_verification_key(public_key2_pem2)
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes)

def test_parse_verification_key_pem_bytes(self):
verification_key = parse_verification_key(public_key2_pem.encode())
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes)

def test_parse_verification_key_b64_str(self):
verification_key = parse_verification_key(public_key2_b64.decode())
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes)

def test_parse_verification_key_b64_bytes(self):
verification_key = parse_verification_key(public_key2_b64)
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes)

def test_parse_verification_key_invalid(self):
with self.assertRaisesRegex(ValueError, "invalid verification key 'foo'"):
parse_verification_key("foo")

@mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": "foo"})
def test_parse_verification_key_invalid_env(self):
with self.assertRaisesRegex(
ValueError, "invalid DISPATCH_VERIFICATION_KEY 'foo'"
):
parse_verification_key(None)
71 changes: 0 additions & 71 deletions tests/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,6 @@
from dispatch.status import Status
from dispatch.test import EndpointClient

public_key_pem = "-----BEGIN PUBLIC KEY-----\nMCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=\n-----END PUBLIC KEY-----"
public_key_pem2 = "-----BEGIN PUBLIC KEY-----\\nMCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=\\n-----END PUBLIC KEY-----"
public_key = public_key_from_pem(public_key_pem)
public_key_bytes = public_key.public_bytes_raw()
public_key_b64 = base64.b64encode(public_key_bytes)


def create_dispatch_instance(app, endpoint):
return Dispatch(
Expand Down Expand Up @@ -107,71 +101,6 @@ def my_function(input: Input) -> Output:

self.assertEqual(output, "You told me: 'Hello World!' (12 characters)")

@mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": public_key_pem})
def test_parse_verification_key_env_pem_str(self):
verification_key = parse_verification_key(None)
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)

@mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": public_key_pem2})
def test_parse_verification_key_env_pem_escaped_newline_str(self):
verification_key = parse_verification_key(None)
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)

@mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": public_key_b64.decode()})
def test_parse_verification_key_env_b64_str(self):
verification_key = parse_verification_key(None)
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)

def test_parse_verification_key_none(self):
# The verification key is optional. Both Dispatch(verification_key=...) and
# DISPATCH_VERIFICATION_KEY may be omitted/None.
verification_key = parse_verification_key(None)
self.assertIsNone(verification_key)

def test_parse_verification_key_ed25519publickey(self):
verification_key = parse_verification_key(public_key)
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)

def test_parse_verification_key_pem_str(self):
verification_key = parse_verification_key(public_key_pem)
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)

def test_parse_verification_key_pem_escaped_newline_str(self):
verification_key = parse_verification_key(public_key_pem2)
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)

def test_parse_verification_key_pem_bytes(self):
verification_key = parse_verification_key(public_key_pem.encode())
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)

def test_parse_verification_key_b64_str(self):
verification_key = parse_verification_key(public_key_b64.decode())
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)

def test_parse_verification_key_b64_bytes(self):
verification_key = parse_verification_key(public_key_b64)
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)

def test_parse_verification_key_invalid(self):
with self.assertRaisesRegex(ValueError, "invalid verification key 'foo'"):
parse_verification_key("foo")

@mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": "foo"})
def test_parse_verification_key_invalid_env(self):
with self.assertRaisesRegex(
ValueError, "invalid DISPATCH_VERIFICATION_KEY 'foo'"
):
parse_verification_key(None)


def response_output(resp: function_pb.RunResponse) -> Any:
return any_unpickle(resp.exit.result.output)
Expand Down
60 changes: 54 additions & 6 deletions tests/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,19 @@
import struct
import threading
import unittest
from http.server import HTTPServer
from typing import Any
from unittest import mock

import fastapi
import google.protobuf.any_pb2
import google.protobuf.wrappers_pb2
import httpx
from http.server import HTTPServer
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey

from dispatch.experimental.durable.registry import clear_functions
from dispatch.http import Dispatch
from dispatch.function import Arguments, Error, Function, Input, Output, Registry
from dispatch.http import Dispatch
from dispatch.proto import _any_unpickle as any_unpickle
from dispatch.sdk.v1 import call_pb2 as call_pb
from dispatch.sdk.v1 import function_pb2 as function_pb
Expand All @@ -30,6 +30,8 @@
public_key_bytes = public_key.public_bytes_raw()
public_key_b64 = base64.b64encode(public_key_bytes)

from datetime import datetime


def create_dispatch_instance(endpoint: str):
return Dispatch(
Expand All @@ -43,11 +45,14 @@ def create_dispatch_instance(endpoint: str):

class TestHTTP(unittest.TestCase):
def setUp(self):
self.server_address = ('127.0.0.1', 9999)
self.server_address = ("127.0.0.1", 9999)
self.endpoint = f"http://{self.server_address[0]}:{self.server_address[1]}"
self.dispatch = create_dispatch_instance(self.endpoint)
self.client = httpx.Client(timeout=1.0)
self.server = HTTPServer(self.server_address, create_dispatch_instance(self.endpoint))
self.thread = threading.Thread(target=self.server.serve_forever)
self.server = HTTPServer(self.server_address, self.dispatch)
self.thread = threading.Thread(
target=lambda: self.server.serve_forever(poll_interval=0.05)
)
self.thread.start()

def tearDown(self):
Expand All @@ -56,7 +61,50 @@ def tearDown(self):
self.client.close()
self.server.server_close()

def test_Dispatch_defaults(self):
def test_content_length_missing(self):
resp = self.client.post(f"{self.endpoint}/dispatch.sdk.v1.FunctionService/Run")
body = resp.read()
self.assertEqual(resp.status_code, 400)
self.assertEqual(
body, b'{"code":"invalid_argument","message":"content length is required"}'
)

def test_content_length_too_large(self):
resp = self.client.post(
f"{self.endpoint}/dispatch.sdk.v1.FunctionService/Run",
data=b"a" * 16_000_001,
)
body = resp.read()
self.assertEqual(resp.status_code, 400)
self.assertEqual(
body, b'{"code":"invalid_argument","message":"content length is too large"}'
)

def test_simple_request(self):
@self.dispatch.registry.primitive_function
def my_function(input: Input) -> Output:
return Output.value(
f"You told me: '{input.input}' ({len(input.input)} characters)"
)

client = EndpointClient.from_url(self.endpoint)

pickled = pickle.dumps("Hello World!")
input_any = google.protobuf.any_pb2.Any()
input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=pickled))

req = function_pb.RunRequest(
function=my_function.name,
input=input_any,
)

resp = client.run(req)

self.assertIsInstance(resp, function_pb.RunResponse)

resp.exit.result.output.Unpack(
output_bytes := google.protobuf.wrappers_pb2.BytesValue()
)
output = pickle.loads(output_bytes.value)

self.assertEqual(output, "You told me: 'Hello World!' (12 characters)")

0 comments on commit 6e14930

Please sign in to comment.