Skip to content

Commit

Permalink
add dispatch.serve function
Browse files Browse the repository at this point in the history
Signed-off-by: Achille Roussel <[email protected]>
  • Loading branch information
achille-roussel committed Apr 22, 2024
1 parent 6376c6e commit 6fb1400
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 22 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ import dispatch
def greet(msg: str):
print(f"Hello, ${msg}!")

greet.dispatch('World')
dispatch.run(lambda: greet.dispatch('World'))
```

Obviously, this is just an example, a real application would perform much more
Expand Down
41 changes: 33 additions & 8 deletions src/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import os
from concurrent import futures
from contextlib import contextmanager
from http.server import ThreadingHTTPServer
from typing import Any, Callable, Coroutine, Optional, TypeVar, overload
from urllib.parse import urlsplit
Expand Down Expand Up @@ -37,6 +38,7 @@
"gather",
"race",
"run",
"serve",
]


Expand All @@ -46,7 +48,7 @@
_registry: Optional[Registry] = None


def _default_registry():
def default_registry():
global _registry
if not _registry:
_registry = Registry()
Expand All @@ -62,10 +64,10 @@ def function(func: Callable[P, T]) -> Function[P, T]: ...


def function(func):
return _default_registry().function(func)
return default_registry().function(func)


def run(port: str = os.environ.get("DISPATCH_ENDPOINT_ADDR", "localhost:8000")):
def run(init: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
"""Run the default dispatch server on the given port. The default server
uses a function registry where functions tagged by the `@dispatch.function`
decorator are registered.
Expand All @@ -75,12 +77,35 @@ def run(port: str = os.environ.get("DISPATCH_ENDPOINT_ADDR", "localhost:8000")):
to the Dispatch bridge API.
Args:
port: The address to bind the server to. Defaults to the value of the
entrypoint: The entrypoint function to run. Defaults to a no-op function.
args: Positional arguments to pass to the entrypoint.
kwargs: Keyword arguments to pass to the entrypoint.
Returns:
The return value of the entrypoint function.
"""
with serve():
return init(*args, **kwargs)


@contextmanager
def serve(address: str = os.environ.get("DISPATCH_ENDPOINT_ADDR", "localhost:8000")):
"""Returns a context manager managing the operation of a Disaptch server
running on the given address. The server is initialized before the context
manager yields, then runs forever until the the program is interrupted.
Args:
address: The address to bind the server to. Defaults to the value of the
DISPATCH_ENDPOINT_ADDR environment variable, or 'localhost:8000' if it
wasn't set.
"""
print(f"Starting Dispatch server on {port}")
parsed_url = urlsplit("//" + port)
parsed_url = urlsplit("//" + address)
server_address = (parsed_url.hostname or "", parsed_url.port or 0)
server = ThreadingHTTPServer(server_address, Dispatch(_default_registry()))
server.serve_forever()
server = ThreadingHTTPServer(server_address, Dispatch(default_registry()))
try:
yield server
server.serve_forever()
finally:
server.server_close()
8 changes: 0 additions & 8 deletions src/dispatch/http.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Integration of Dispatch functions with http."""

from datetime import datetime

import logging
import os
from datetime import timedelta
Expand Down Expand Up @@ -66,9 +64,7 @@ def __init__(
self.registry = registry
self.verification_key = verification_key
self.error_content_type = "application/json"
print(datetime.now(), "INITIALIZING FUNCTION SERVICE")
super().__init__(request, client_address, server)
print(datetime.now(), "DONE HANDLING REQUEST")

def send_error_response_invalid_argument(self, message: str):
self.send_error_response(400, "invalid_argument", message)
Expand All @@ -91,9 +87,7 @@ def send_error_response(self, status: int, code: str, message: str):
self.send_header("Content-Type", self.error_content_type)
self.send_header("Content-Length", str(len(body)))
self.end_headers()
print(datetime.now(), "SENDING ERROR RESPONSE")
self.wfile.write(body)
print(datetime.now(), f"SERVER IS DONE {len(body)}")

def do_POST(self):
if self.path != "/dispatch.sdk.v1.FunctionService/Run":
Expand All @@ -112,7 +106,6 @@ def do_POST(self):
return

data: bytes = self.rfile.read(content_length)
print(datetime.now(), f"RECEIVED POST REQUEST: {self.path} {len(data)} {self.request_version} {self.headers}")
logger.debug("handling run request with %d byte body", len(data))

if self.verification_key is not None:
Expand Down Expand Up @@ -150,7 +143,6 @@ def do_POST(self):
)
return

print(datetime.now(), "running function '%s'", req.function)
try:
output = func._primitive_call(Input(req))
except Exception:
Expand Down
7 changes: 2 additions & 5 deletions tests/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,6 @@ def tearDown(self):
self.server.server_close()

def test_Dispatch_defaults(self):
print("POST REQUEST", f"{self.endpoint}/dispatch.sdk.v1.FunctionService/Run")
resp = self.client.post(f"{self.endpoint}/dispatch.sdk.v1.FunctionService/Run")
print(resp.status_code)
print("CLIENT RESPONSE!", resp.headers)
#body = resp.read()
#self.assertEqual(resp.status_code, 400)
body = resp.read()
self.assertEqual(resp.status_code, 400)

0 comments on commit 6fb1400

Please sign in to comment.