diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index c0cce39..1f53543 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -20,6 +20,7 @@ from dispatch.experimental.durable.registry import clear_functions from dispatch.function import Arguments, Error, Function, Input, Output, Registry from dispatch.proto import _any_unpickle as any_unpickle +from dispatch.proto import _pb_any_pickle as any_pickle from dispatch.sdk.v1 import call_pb2 as call_pb from dispatch.sdk.v1 import function_pb2 as function_pb from dispatch.signature import parse_verification_key, public_key_from_pem @@ -109,22 +110,15 @@ async def my_function(input: Input) -> Output: http_client = dispatch.test.httpx.Client(httpx.Client(base_url=self.endpoint)) client = EndpointClient(http_client) - pickled = pickle.dumps("Hello World!") - input_any = google.protobuf.any_pb2.Any() - input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=pickled)) - req = function_pb.RunRequest( function=my_function.name, - input=input_any, + input=any_pickle("Hello World!"), ) resp = client.run(req) self.assertIsInstance(resp, function_pb.RunResponse) - resp.exit.result.output.Unpack( - output_bytes := google.protobuf.wrappers_pb2.BytesValue() - ) - output = pickle.loads(output_bytes.value) + output = any_unpickle(resp.exit.result.output) self.assertEqual(output, "You told me: 'Hello World!' (12 characters)") diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index 25e6829..e7f4e52 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -227,7 +227,8 @@ def execute( any = any_pickle(input) req.input.CopyFrom(any) if state is not None: - req.poll_result.coroutine_state = state + any = any_pickle(state) + req.poll_result.typed_coroutine_state.CopyFrom(any) if calls is not None: for c in calls: req.poll_result.results.append(c) @@ -247,10 +248,6 @@ def proto_call(self, call: call_pb.Call) -> call_pb.CallResult: resp = self.client.run(req) self.assertIsInstance(resp, function_pb.RunResponse) - # Assert the response is terminal. Good enough until the test client can - # orchestrate coroutines. - self.assertTrue(len(resp.poll.coroutine_state) == 0) - resp.exit.result.correlation_id = call.correlation_id return resp.exit.result @@ -317,9 +314,10 @@ async def my_function(input: Input) -> Output: return Output.value("not reached") resp = self.execute(my_function, input="cool stuff") - self.assertEqual(b"42", resp.poll.coroutine_state) + state = any_unpickle(resp.poll.typed_coroutine_state) + self.assertEqual(b"42", state) - resp = self.execute(my_function, state=resp.poll.coroutine_state) + resp = self.execute(my_function, state=state) self.assertEqual("ValueError", resp.exit.result.error.type) self.assertEqual( "This input is for a resumed coroutine", resp.exit.result.error.message @@ -360,32 +358,29 @@ async def coroutine3(input: Input) -> Output: if input.is_first_call: counter = input.input else: - (counter,) = struct.unpack("@i", input.coroutine_state) + counter = input.coroutine_state counter -= 1 if counter <= 0: return Output.value("done") - coroutine_state = struct.pack("@i", counter) - return Output.poll(coroutine_state=coroutine_state) + return Output.poll(coroutine_state=counter) # first call resp = self.execute(coroutine3, input=4) - state = resp.poll.coroutine_state - self.assertTrue(len(state) > 0) + state = any_unpickle(resp.poll.typed_coroutine_state) + self.assertEqual(state, 3) # resume, state = 3 resp = self.execute(coroutine3, state=state) - state = resp.poll.coroutine_state - self.assertTrue(len(state) > 0) + state = any_unpickle(resp.poll.typed_coroutine_state) + self.assertEqual(state, 2) # resume, state = 2 resp = self.execute(coroutine3, state=state) - state = resp.poll.coroutine_state - self.assertTrue(len(state) > 0) + state = any_unpickle(resp.poll.typed_coroutine_state) + self.assertEqual(state, 1) # resume, state = 1 resp = self.execute(coroutine3, state=state) - state = resp.poll.coroutine_state - self.assertTrue(len(state) == 0) out = response_output(resp) self.assertEqual(out, "done") @@ -399,18 +394,18 @@ async def coroutine_main(input: Input) -> Output: if input.is_first_call: text: str = input.input return Output.poll( - coroutine_state=text.encode(), + coroutine_state=text, calls=[coro_compute_len._build_primitive_call(text)], ) - text = input.coroutine_state.decode() + text = input.coroutine_state length = input.call_results[0].output return Output.value(f"length={length} text='{text}'") resp = self.execute(coroutine_main, input="cool stuff") # main saved some state - state = resp.poll.coroutine_state - self.assertTrue(len(state) > 0) + state = any_unpickle(resp.poll.typed_coroutine_state) + self.assertEqual(state, "cool stuff") # main asks for 1 call to compute_len self.assertEqual(len(resp.poll.calls), 1) call = resp.poll.calls[0] @@ -426,7 +421,6 @@ async def coroutine_main(input: Input) -> Output: # resume main with the result resp = self.execute(coroutine_main, state=state, calls=[resp2]) # validate the final result - self.assertTrue(len(resp.poll.coroutine_state) == 0) out = response_output(resp) self.assertEqual("length=10 text='cool stuff'", out) @@ -440,7 +434,7 @@ async def coroutine_main(input: Input) -> Output: if input.is_first_call: text: str = input.input return Output.poll( - coroutine_state=text.encode(), + coroutine_state=text, calls=[coro_compute_len._build_primitive_call(text)], ) error = input.call_results[0].error @@ -452,8 +446,8 @@ async def coroutine_main(input: Input) -> Output: resp = self.execute(coroutine_main, input="cool stuff") # main saved some state - state = resp.poll.coroutine_state - self.assertTrue(len(state) > 0) + state = any_unpickle(resp.poll.typed_coroutine_state) + self.assertEqual(state, "cool stuff") # main asks for 1 call to compute_len self.assertEqual(len(resp.poll.calls), 1) call = resp.poll.calls[0] @@ -466,7 +460,6 @@ async def coroutine_main(input: Input) -> Output: # resume main with the result resp = self.execute(coroutine_main, state=state, calls=[resp2]) # validate the final result - self.assertTrue(len(resp.poll.coroutine_state) == 0) out = response_output(resp) self.assertEqual(out, "msg=Dead type='type'")