Skip to content

Commit

Permalink
Fix a few more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
chriso committed Jun 13, 2024
1 parent b7acacf commit ef2bedc
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 36 deletions.
12 changes: 3 additions & 9 deletions tests/test_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from dispatch.experimental.durable.registry import clear_functions
from dispatch.function import Arguments, Error, Function, Input, Output, Registry
from dispatch.proto import _any_unpickle as any_unpickle
from dispatch.proto import _pb_any_pickle as any_pickle
from dispatch.sdk.v1 import call_pb2 as call_pb
from dispatch.sdk.v1 import function_pb2 as function_pb
from dispatch.signature import parse_verification_key, public_key_from_pem
Expand Down Expand Up @@ -109,22 +110,15 @@ async def my_function(input: Input) -> Output:
http_client = dispatch.test.httpx.Client(httpx.Client(base_url=self.endpoint))
client = EndpointClient(http_client)

pickled = pickle.dumps("Hello World!")
input_any = google.protobuf.any_pb2.Any()
input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=pickled))

req = function_pb.RunRequest(
function=my_function.name,
input=input_any,
input=any_pickle("Hello World!"),
)

resp = client.run(req)

self.assertIsInstance(resp, function_pb.RunResponse)

resp.exit.result.output.Unpack(
output_bytes := google.protobuf.wrappers_pb2.BytesValue()
)
output = pickle.loads(output_bytes.value)
output = any_unpickle(resp.exit.result.output)

self.assertEqual(output, "You told me: 'Hello World!' (12 characters)")
47 changes: 20 additions & 27 deletions tests/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,8 @@ def execute(
any = any_pickle(input)
req.input.CopyFrom(any)
if state is not None:
req.poll_result.coroutine_state = state
any = any_pickle(state)
req.poll_result.typed_coroutine_state.CopyFrom(any)
if calls is not None:
for c in calls:
req.poll_result.results.append(c)
Expand All @@ -247,10 +248,6 @@ def proto_call(self, call: call_pb.Call) -> call_pb.CallResult:
resp = self.client.run(req)
self.assertIsInstance(resp, function_pb.RunResponse)

# Assert the response is terminal. Good enough until the test client can
# orchestrate coroutines.
self.assertTrue(len(resp.poll.coroutine_state) == 0)

resp.exit.result.correlation_id = call.correlation_id
return resp.exit.result

Expand Down Expand Up @@ -317,9 +314,10 @@ async def my_function(input: Input) -> Output:
return Output.value("not reached")

resp = self.execute(my_function, input="cool stuff")
self.assertEqual(b"42", resp.poll.coroutine_state)
state = any_unpickle(resp.poll.typed_coroutine_state)
self.assertEqual(b"42", state)

resp = self.execute(my_function, state=resp.poll.coroutine_state)
resp = self.execute(my_function, state=state)
self.assertEqual("ValueError", resp.exit.result.error.type)
self.assertEqual(
"This input is for a resumed coroutine", resp.exit.result.error.message
Expand Down Expand Up @@ -360,32 +358,29 @@ async def coroutine3(input: Input) -> Output:
if input.is_first_call:
counter = input.input
else:
(counter,) = struct.unpack("@i", input.coroutine_state)
counter = input.coroutine_state
counter -= 1
if counter <= 0:
return Output.value("done")
coroutine_state = struct.pack("@i", counter)
return Output.poll(coroutine_state=coroutine_state)
return Output.poll(coroutine_state=counter)

# first call
resp = self.execute(coroutine3, input=4)
state = resp.poll.coroutine_state
self.assertTrue(len(state) > 0)
state = any_unpickle(resp.poll.typed_coroutine_state)
self.assertEqual(state, 3)

# resume, state = 3
resp = self.execute(coroutine3, state=state)
state = resp.poll.coroutine_state
self.assertTrue(len(state) > 0)
state = any_unpickle(resp.poll.typed_coroutine_state)
self.assertEqual(state, 2)

# resume, state = 2
resp = self.execute(coroutine3, state=state)
state = resp.poll.coroutine_state
self.assertTrue(len(state) > 0)
state = any_unpickle(resp.poll.typed_coroutine_state)
self.assertEqual(state, 1)

# resume, state = 1
resp = self.execute(coroutine3, state=state)
state = resp.poll.coroutine_state
self.assertTrue(len(state) == 0)
out = response_output(resp)
self.assertEqual(out, "done")

Expand All @@ -399,18 +394,18 @@ async def coroutine_main(input: Input) -> Output:
if input.is_first_call:
text: str = input.input
return Output.poll(
coroutine_state=text.encode(),
coroutine_state=text,
calls=[coro_compute_len._build_primitive_call(text)],
)
text = input.coroutine_state.decode()
text = input.coroutine_state
length = input.call_results[0].output
return Output.value(f"length={length} text='{text}'")

resp = self.execute(coroutine_main, input="cool stuff")

# main saved some state
state = resp.poll.coroutine_state
self.assertTrue(len(state) > 0)
state = any_unpickle(resp.poll.typed_coroutine_state)
self.assertEqual(state, "cool stuff")
# main asks for 1 call to compute_len
self.assertEqual(len(resp.poll.calls), 1)
call = resp.poll.calls[0]
Expand All @@ -426,7 +421,6 @@ async def coroutine_main(input: Input) -> Output:
# resume main with the result
resp = self.execute(coroutine_main, state=state, calls=[resp2])
# validate the final result
self.assertTrue(len(resp.poll.coroutine_state) == 0)
out = response_output(resp)
self.assertEqual("length=10 text='cool stuff'", out)

Expand All @@ -440,7 +434,7 @@ async def coroutine_main(input: Input) -> Output:
if input.is_first_call:
text: str = input.input
return Output.poll(
coroutine_state=text.encode(),
coroutine_state=text,
calls=[coro_compute_len._build_primitive_call(text)],
)
error = input.call_results[0].error
Expand All @@ -452,8 +446,8 @@ async def coroutine_main(input: Input) -> Output:
resp = self.execute(coroutine_main, input="cool stuff")

# main saved some state
state = resp.poll.coroutine_state
self.assertTrue(len(state) > 0)
state = any_unpickle(resp.poll.typed_coroutine_state)
self.assertEqual(state, "cool stuff")
# main asks for 1 call to compute_len
self.assertEqual(len(resp.poll.calls), 1)
call = resp.poll.calls[0]
Expand All @@ -466,7 +460,6 @@ async def coroutine_main(input: Input) -> Output:
# resume main with the result
resp = self.execute(coroutine_main, state=state, calls=[resp2])
# validate the final result
self.assertTrue(len(resp.poll.coroutine_state) == 0)
out = response_output(resp)
self.assertEqual(out, "msg=Dead type='type'")

Expand Down

0 comments on commit ef2bedc

Please sign in to comment.