diff --git a/examples/auto_retry/test_app.py b/examples/auto_retry/test_app.py index abde35a7..719862a8 100644 --- a/examples/auto_retry/test_app.py +++ b/examples/auto_retry/test_app.py @@ -8,6 +8,8 @@ from fastapi.testclient import TestClient +import dispatch.sdk.v1.status_pb2 as status_pb + from ... import function_service from ...test_client import ServerTest @@ -36,4 +38,13 @@ def test_foo(self): server.execute(app_client) - self.assertEqual(len(servicer.responses), 1) + # Seed(2) used in the app outputs 0, 0, 0, 2, 1, 5. So we expect 6 + # calls, including 5 retries. + for i in range(6): + server.execute(app_client) + self.assertEqual(len(servicer.responses), 6) + + statuses = [r["response"].status for r in servicer.responses] + self.assertEqual( + statuses, [status_pb.STATUS_TEMPORARY_ERROR] * 5 + [status_pb.STATUS_OK] + ) diff --git a/tests/dispatch_service.py b/tests/dispatch_service.py index 40ad174f..9a07354c 100644 --- a/tests/dispatch_service.py +++ b/tests/dispatch_service.py @@ -7,6 +7,7 @@ import dispatch.sdk.v1.dispatch_pb2_grpc as dispatch_grpc import dispatch.sdk.v1.function_pb2 as function_pb import dispatch.sdk.v1.function_pb2_grpc as function_grpc +import dispatch.sdk.v1.status_pb2 as status_pb from dispatch import Client, DispatchID _test_auth_token = "THIS_IS_A_TEST_AUTH_TOKEN" @@ -18,7 +19,7 @@ def __init__(self): self._next_dispatch_id = 1 self._pending_calls = [] - self.responses = {} # indexed by dispatch ID + self.responses = [] self.dispatched_calls = [] def _make_dispatch_id(self) -> DispatchID: @@ -53,7 +54,10 @@ def Dispatch(self, request: dispatch_pb.DispatchRequest, context): def execute(self, client: function_grpc.FunctionServiceStub): """Synchronously execute all pending function calls.""" - while len(self._pending_calls) > 0: + + _next_pending_calls = [] + + for entry in self._pending_calls: entry = self._pending_calls.pop(0) call: call_pb.Call = entry["call"] @@ -63,7 +67,38 @@ def execute(self, client: function_grpc.FunctionServiceStub): ) resp = client.Run(req) - self.responses[entry["dispatch_id"]] = resp + self.responses.append( + {"dispatch_id": entry["dispatch_id"], "response": resp} + ) + + if self._should_retry_status(resp.status): + _next_pending_calls.append( + {"dispatch_id": self._make_dispatch_id(), "call": call} + ) + + self._pending_calls = _next_pending_calls + + def response_for(self, dispatch_id: DispatchID) -> function_pb.RunResponse | None: + for entry in self.responses: + if entry["dispatch_id"] == dispatch_id: + return entry["response"] + return None + + def _should_retry_status(self, status: status_pb.Status) -> bool: + match status: + case ( + status_pb.STATUS_THROTTLED + | status_pb.STATUS_TIMEOUT + | status_pb.STATUS_TEMPORARY_ERROR + | status_pb.STATUS_INCOMPATIBLE_STATE + | status_pb.STATUS_DNS_ERROR + | status_pb.STATUS_TCP_ERROR + | status_pb.STATUS_TLS_ERROR + | status_pb.STATUS_HTTP_ERROR + ): + return True + case _: + return False class ServerTest: diff --git a/tests/test_full.py b/tests/test_full.py index f1248ed1..a95e0077 100644 --- a/tests/test_full.py +++ b/tests/test_full.py @@ -66,5 +66,5 @@ def my_function(name: str) -> str: self.execute() # Validate results. - resp = self.servicer.responses[dispatch_id] + resp = self.servicer.response_for(dispatch_id) self.assertEqual(any_unpickle(resp.exit.result.output), "Hello world: 52")