diff --git a/src/dispatch/__init__.py b/src/dispatch/__init__.py index 08ad5fe..ef0ebe9 100644 --- a/src/dispatch/__init__.py +++ b/src/dispatch/__init__.py @@ -5,7 +5,7 @@ import asyncio import os from http.server import ThreadingHTTPServer -from typing import Any, Callable, Coroutine, Optional, TypeVar, overload +from typing import Any, Awaitable, Callable, Coroutine, Optional, TypeVar, overload from urllib.parse import urlsplit from typing_extensions import ParamSpec, TypeAlias @@ -96,7 +96,7 @@ async def main(coro: Coroutine[Any, Any, T], addr: Optional[str] = None) -> T: def run(coro: Coroutine[Any, Any, T], addr: Optional[str] = None) -> T: - """Run the default dispatch server. The default server uses a function + """Run the default Dispatch server. The default server uses a function registry where functions tagged by the `@dispatch.function` decorator are registered. @@ -119,9 +119,27 @@ def run(coro: Coroutine[Any, Any, T], addr: Optional[str] = None) -> T: return asyncio.run(main(coro, addr)) -def run_forever(): - """Run the default dispatch server forever.""" - return run(asyncio.Event().wait()) +def run_forever( + coro: Optional[Coroutine[Any, Any, T]] = None, addr: Optional[str] = None +): + """Run the default Dispatch server forever. + + Args: + coro: A coroutine to optionally run as the entrypoint. + + addr: The address to bind the server to. If not provided, the server + will bind to the address specified by the `DISPATCH_ENDPOINT_ADDR` + environment variable. If the environment variable is not set, the + server will bind to `localhost:8000`. + """ + wait = asyncio.Event().wait() + coro = chain(coro, wait) if coro is not None else wait + return run(coro=coro, addr=addr) + + +async def chain(*awaitables: Awaitable[Any]): + for a in awaitables: + await a def batch() -> Batch: