From 6fb14006af181314c35c781fd6c7f008088e2b5a Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Mon, 22 Apr 2024 13:46:58 -0700 Subject: [PATCH] add dispatch.serve function Signed-off-by: Achille Roussel --- README.md | 2 +- src/dispatch/__init__.py | 41 ++++++++++++++++++++++++++++++++-------- src/dispatch/http.py | 8 -------- tests/test_http.py | 7 ++----- 4 files changed, 36 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index 5673c44b..a9745647 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,7 @@ import dispatch def greet(msg: str): print(f"Hello, ${msg}!") -greet.dispatch('World') +dispatch.run(lambda: greet.dispatch('World')) ``` Obviously, this is just an example, a real application would perform much more diff --git a/src/dispatch/__init__.py b/src/dispatch/__init__.py index a5284396..a544a9d7 100644 --- a/src/dispatch/__init__.py +++ b/src/dispatch/__init__.py @@ -4,6 +4,7 @@ import os from concurrent import futures +from contextlib import contextmanager from http.server import ThreadingHTTPServer from typing import Any, Callable, Coroutine, Optional, TypeVar, overload from urllib.parse import urlsplit @@ -37,6 +38,7 @@ "gather", "race", "run", + "serve", ] @@ -46,7 +48,7 @@ _registry: Optional[Registry] = None -def _default_registry(): +def default_registry(): global _registry if not _registry: _registry = Registry() @@ -62,10 +64,10 @@ def function(func: Callable[P, T]) -> Function[P, T]: ... def function(func): - return _default_registry().function(func) + return default_registry().function(func) -def run(port: str = os.environ.get("DISPATCH_ENDPOINT_ADDR", "localhost:8000")): +def run(init: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: """Run the default dispatch server on the given port. The default server uses a function registry where functions tagged by the `@dispatch.function` decorator are registered. @@ -75,12 +77,35 @@ def run(port: str = os.environ.get("DISPATCH_ENDPOINT_ADDR", "localhost:8000")): to the Dispatch bridge API. Args: - port: The address to bind the server to. Defaults to the value of the + entrypoint: The entrypoint function to run. Defaults to a no-op function. + + args: Positional arguments to pass to the entrypoint. + + kwargs: Keyword arguments to pass to the entrypoint. + + Returns: + The return value of the entrypoint function. + """ + with serve(): + return init(*args, **kwargs) + + +@contextmanager +def serve(address: str = os.environ.get("DISPATCH_ENDPOINT_ADDR", "localhost:8000")): + """Returns a context manager managing the operation of a Disaptch server + running on the given address. The server is initialized before the context + manager yields, then runs forever until the the program is interrupted. + + Args: + address: The address to bind the server to. Defaults to the value of the DISPATCH_ENDPOINT_ADDR environment variable, or 'localhost:8000' if it wasn't set. """ - print(f"Starting Dispatch server on {port}") - parsed_url = urlsplit("//" + port) + parsed_url = urlsplit("//" + address) server_address = (parsed_url.hostname or "", parsed_url.port or 0) - server = ThreadingHTTPServer(server_address, Dispatch(_default_registry())) - server.serve_forever() + server = ThreadingHTTPServer(server_address, Dispatch(default_registry())) + try: + yield server + server.serve_forever() + finally: + server.server_close() diff --git a/src/dispatch/http.py b/src/dispatch/http.py index 052e3091..0635fc3f 100644 --- a/src/dispatch/http.py +++ b/src/dispatch/http.py @@ -1,7 +1,5 @@ """Integration of Dispatch functions with http.""" -from datetime import datetime - import logging import os from datetime import timedelta @@ -66,9 +64,7 @@ def __init__( self.registry = registry self.verification_key = verification_key self.error_content_type = "application/json" - print(datetime.now(), "INITIALIZING FUNCTION SERVICE") super().__init__(request, client_address, server) - print(datetime.now(), "DONE HANDLING REQUEST") def send_error_response_invalid_argument(self, message: str): self.send_error_response(400, "invalid_argument", message) @@ -91,9 +87,7 @@ def send_error_response(self, status: int, code: str, message: str): self.send_header("Content-Type", self.error_content_type) self.send_header("Content-Length", str(len(body))) self.end_headers() - print(datetime.now(), "SENDING ERROR RESPONSE") self.wfile.write(body) - print(datetime.now(), f"SERVER IS DONE {len(body)}") def do_POST(self): if self.path != "/dispatch.sdk.v1.FunctionService/Run": @@ -112,7 +106,6 @@ def do_POST(self): return data: bytes = self.rfile.read(content_length) - print(datetime.now(), f"RECEIVED POST REQUEST: {self.path} {len(data)} {self.request_version} {self.headers}") logger.debug("handling run request with %d byte body", len(data)) if self.verification_key is not None: @@ -150,7 +143,6 @@ def do_POST(self): ) return - print(datetime.now(), "running function '%s'", req.function) try: output = func._primitive_call(Input(req)) except Exception: diff --git a/tests/test_http.py b/tests/test_http.py index 9d6e0b7c..ad623e02 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -57,9 +57,6 @@ def tearDown(self): self.server.server_close() def test_Dispatch_defaults(self): - print("POST REQUEST", f"{self.endpoint}/dispatch.sdk.v1.FunctionService/Run") resp = self.client.post(f"{self.endpoint}/dispatch.sdk.v1.FunctionService/Run") - print(resp.status_code) - print("CLIENT RESPONSE!", resp.headers) - #body = resp.read() - #self.assertEqual(resp.status_code, 400) + body = resp.read() + self.assertEqual(resp.status_code, 400)