Skip to content

Commit

Permalink
Merge pull request #76 from stealthrocket/better-test-retry
Browse files Browse the repository at this point in the history
Make retry example test more accurate
  • Loading branch information
pelletier authored Feb 20, 2024
2 parents 693542d + 1ff20c3 commit 1486980
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 5 deletions.
13 changes: 12 additions & 1 deletion examples/auto_retry/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from fastapi.testclient import TestClient

import dispatch.sdk.v1.status_pb2 as status_pb

from ... import function_service
from ...test_client import ServerTest

Expand Down Expand Up @@ -36,4 +38,13 @@ def test_foo(self):

server.execute(app_client)

self.assertEqual(len(servicer.responses), 1)
# Seed(2) used in the app outputs 0, 0, 0, 2, 1, 5. So we expect 6
# calls, including 5 retries.
for i in range(6):
server.execute(app_client)
self.assertEqual(len(servicer.responses), 6)

statuses = [r["response"].status for r in servicer.responses]
self.assertEqual(
statuses, [status_pb.STATUS_TEMPORARY_ERROR] * 5 + [status_pb.STATUS_OK]
)
41 changes: 38 additions & 3 deletions tests/dispatch_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import dispatch.sdk.v1.dispatch_pb2_grpc as dispatch_grpc
import dispatch.sdk.v1.function_pb2 as function_pb
import dispatch.sdk.v1.function_pb2_grpc as function_grpc
import dispatch.sdk.v1.status_pb2 as status_pb
from dispatch import Client, DispatchID

_test_auth_token = "THIS_IS_A_TEST_AUTH_TOKEN"
Expand All @@ -18,7 +19,7 @@ def __init__(self):
self._next_dispatch_id = 1
self._pending_calls = []

self.responses = {} # indexed by dispatch ID
self.responses = []
self.dispatched_calls = []

def _make_dispatch_id(self) -> DispatchID:
Expand Down Expand Up @@ -53,7 +54,10 @@ def Dispatch(self, request: dispatch_pb.DispatchRequest, context):

def execute(self, client: function_grpc.FunctionServiceStub):
"""Synchronously execute all pending function calls."""
while len(self._pending_calls) > 0:

_next_pending_calls = []

for entry in self._pending_calls:
entry = self._pending_calls.pop(0)
call: call_pb.Call = entry["call"]

Expand All @@ -63,7 +67,38 @@ def execute(self, client: function_grpc.FunctionServiceStub):
)

resp = client.Run(req)
self.responses[entry["dispatch_id"]] = resp
self.responses.append(
{"dispatch_id": entry["dispatch_id"], "response": resp}
)

if self._should_retry_status(resp.status):
_next_pending_calls.append(
{"dispatch_id": self._make_dispatch_id(), "call": call}
)

self._pending_calls = _next_pending_calls

def response_for(self, dispatch_id: DispatchID) -> function_pb.RunResponse | None:
for entry in self.responses:
if entry["dispatch_id"] == dispatch_id:
return entry["response"]
return None

def _should_retry_status(self, status: status_pb.Status) -> bool:
match status:
case (
status_pb.STATUS_THROTTLED
| status_pb.STATUS_TIMEOUT
| status_pb.STATUS_TEMPORARY_ERROR
| status_pb.STATUS_INCOMPATIBLE_STATE
| status_pb.STATUS_DNS_ERROR
| status_pb.STATUS_TCP_ERROR
| status_pb.STATUS_TLS_ERROR
| status_pb.STATUS_HTTP_ERROR
):
return True
case _:
return False


class ServerTest:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,5 @@ def my_function(name: str) -> str:
self.execute()

# Validate results.
resp = self.servicer.responses[dispatch_id]
resp = self.servicer.response_for(dispatch_id)
self.assertEqual(any_unpickle(resp.exit.result.output), "Hello world: 52")

0 comments on commit 1486980

Please sign in to comment.