diff --git a/examples/auto_retry.py b/examples/auto_retry.py index 5eff8bd..fe932c9 100644 --- a/examples/auto_retry.py +++ b/examples/auto_retry.py @@ -12,6 +12,7 @@ def third_party_api_call(x): # Simulate a third-party API call that fails. print(f"Simulating third-party API call with {x}") if x < 3: + print("RAISE EXCEPTION") raise requests.RequestException("Simulated failure") else: return "SUCCESS" @@ -19,9 +20,10 @@ def third_party_api_call(x): # Use the `dispatch.function` decorator to declare a stateful function. @dispatch.function -def application(): +def auto_retry(): x = rng.randint(0, 5) return third_party_api_call(x) -dispatch.run(application()) +if __name__ == "__main__": + print(dispatch.run(auto_retry())) diff --git a/examples/fanout.py b/examples/fanout.py index 856d496..883b31a 100644 --- a/examples/fanout.py +++ b/examples/fanout.py @@ -42,4 +42,5 @@ async def fanout(): return await reduce_stargazers(repos) -print(dispatch.run(fanout())) +if __name__ == "__main__": + print(dispatch.run(fanout())) diff --git a/examples/getting_started.py b/examples/getting_started.py index 38ad9cf..17ff7f8 100644 --- a/examples/getting_started.py +++ b/examples/getting_started.py @@ -3,7 +3,6 @@ import dispatch -# Use the `dispatch.function` decorator declare a stateful function. @dispatch.function def publish(url, payload): r = requests.post(url, data=payload) @@ -11,7 +10,10 @@ def publish(url, payload): return r.text -# Use the `dispatch.run` function to run the function with automatic error -# handling and retries. -res = dispatch.run(publish("https://httpstat.us/200", {"hello": "world"})) -print(res) +@dispatch.function +async def getting_started(): + return await publish("https://httpstat.us/200", {"hello": "world"}) + + +if __name__ == "__main__": + print(dispatch.run(getting_started())) diff --git a/examples/github_stats.py b/examples/github_stats.py index c1dc6db..0636882 100644 --- a/examples/github_stats.py +++ b/examples/github_stats.py @@ -1,19 +1,3 @@ -"""Github repository stats example. - -This example demonstrates how to use async functions orchestrated by Dispatch. - -Make sure to follow the setup instructions at -https://docs.dispatch.run/dispatch/stateful-functions/getting-started/ - -Run with: - -uvicorn app:app - - -Logs will show a pipeline of functions being called and their results. - -""" - import httpx import dispatch @@ -31,21 +15,21 @@ def get_gh_api(url): @dispatch.function -async def get_repo_info(repo_owner, repo_name): +def get_repo_info(repo_owner, repo_name): url = f"https://api.github.com/repos/{repo_owner}/{repo_name}" repo_info = get_gh_api(url) return repo_info @dispatch.function -async def get_contributors(repo_info): +def get_contributors(repo_info): url = repo_info["contributors_url"] contributors = get_gh_api(url) return contributors @dispatch.function -async def main(): +async def github_stats(): repo_info = await get_repo_info("dispatchrun", "coroutine") print( f"""Repository: {repo_info['full_name']} @@ -57,5 +41,5 @@ async def main(): if __name__ == "__main__": - contributors = dispatch.run(main()) + contributors = dispatch.run(github_stats()) print(f"Contributors: {len(contributors)}") diff --git a/examples/test_examples.py b/examples/test_examples.py new file mode 100644 index 0000000..4d63a51 --- /dev/null +++ b/examples/test_examples.py @@ -0,0 +1,29 @@ +import dispatch.test + +from .auto_retry import auto_retry +from .fanout import fanout +from .getting_started import getting_started +from .github_stats import github_stats + + +@dispatch.test.function +async def test_auto_retry(): + assert await auto_retry() == "SUCCESS" + + +@dispatch.test.function +async def test_fanout(): + contributors = await fanout() + assert len(contributors) >= 15 + assert "achille-roussel" in contributors + + +@dispatch.test.function +async def test_getting_started(): + assert await getting_started() == "200 OK" + + +@dispatch.test.function +async def test_github_stats(): + contributors = await github_stats() + assert len(contributors) >= 6 diff --git a/src/dispatch/scheduler.py b/src/dispatch/scheduler.py index b0cbb09..c6dfdc5 100644 --- a/src/dispatch/scheduler.py +++ b/src/dispatch/scheduler.py @@ -32,7 +32,10 @@ CoroutineID: TypeAlias = int CorrelationID: TypeAlias = int -_in_function_call = contextvars.ContextVar("dispatch.scheduler.in_function_call", default=False) +_in_function_call = contextvars.ContextVar( + "dispatch.scheduler.in_function_call", default=False +) + def in_function_call() -> bool: return bool(_in_function_call.get()) @@ -523,7 +526,11 @@ def make_coroutine(state: State, coroutine: Coroutine, pending_calls: List[Call] if isinstance(coroutine_yield, RaceDirective): return set_coroutine_race(state, coroutine, coroutine_yield.awaitables) - yield coroutine_yield + try: + yield coroutine_yield + except Exception as e: + coroutine_result = CoroutineResult(coroutine_id=coroutine.id, error=e) + return set_coroutine_result(state, coroutine, coroutine_result) def set_coroutine_result( diff --git a/src/dispatch/test/__init__.py b/src/dispatch/test/__init__.py index 6650234..7250c96 100644 --- a/src/dispatch/test/__init__.py +++ b/src/dispatch/test/__init__.py @@ -28,7 +28,17 @@ from dispatch.sdk.v1.error_pb2 import Error from dispatch.sdk.v1.function_pb2 import RunRequest, RunResponse from dispatch.sdk.v1.poll_pb2 import PollResult -from dispatch.sdk.v1.status_pb2 import STATUS_OK +from dispatch.sdk.v1.status_pb2 import ( + STATUS_DNS_ERROR, + STATUS_HTTP_ERROR, + STATUS_INCOMPATIBLE_STATE, + STATUS_OK, + STATUS_TCP_ERROR, + STATUS_TEMPORARY_ERROR, + STATUS_THROTTLED, + STATUS_TIMEOUT, + STATUS_TLS_ERROR, +) from .client import EndpointClient from .server import DispatchServer @@ -183,7 +193,18 @@ def make_request(call: Call) -> RunRequest: res = await self.run(call.endpoint, req) if res.status != STATUS_OK: - # TODO: emulate retries etc... + if res.status in ( + STATUS_TIMEOUT, + STATUS_THROTTLED, + STATUS_TEMPORARY_ERROR, + STATUS_INCOMPATIBLE_STATE, + STATUS_DNS_ERROR, + STATUS_TCP_ERROR, + STATUS_TLS_ERROR, + STATUS_HTTP_ERROR, + ): + continue # emulate retries, without backoff for now + if ( res.HasField("exit") and res.exit.HasField("result") @@ -263,14 +284,19 @@ async def main(coro: Coroutine[Any, Any, None]) -> None: api = Service() app = Dispatch(reg) try: + print("Starting bakend") async with Server(api) as backend: + print("Starting server") async with Server(app) as server: # Here we break through the abstraction layers a bit, it's not # ideal but it works for now. reg.client.api_url.value = backend.url reg.endpoint = server.url + print("BACKEND:", backend.url) + print("SERVER:", server.url) await coro finally: + print("DONE!") await api.close() # TODO: let's figure out how to get rid of this global registry # state at some point, which forces tests to be run sequentially.