diff --git a/src/dispatch/proto.py b/src/dispatch/proto.py index f7dddfb..ffe6a10 100644 --- a/src/dispatch/proto.py +++ b/src/dispatch/proto.py @@ -452,7 +452,10 @@ def _any_unpickle(any: google.protobuf.any_pb2.Any) -> Any: any.Unpack(b) return pickle.loads(b.value) - raise InvalidArgumentError("unsupported pickled value container") + elif not any.type_url and not any.value: + return None + + raise InvalidArgumentError(f"unsupported pickled value container: {any.type_url}") def _pb_any_pickle(value: Any) -> google.protobuf.any_pb2.Any: diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index 149b817..25e6829 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -21,6 +21,7 @@ from dispatch.fastapi import Dispatch from dispatch.function import Arguments, Error, Function, Input, Output from dispatch.proto import _any_unpickle as any_unpickle +from dispatch.proto import _pb_any_pickle as any_pickle from dispatch.sdk.v1 import call_pb2 as call_pb from dispatch.sdk.v1 import function_pb2 as function_pb from dispatch.signature import ( @@ -91,23 +92,17 @@ async def my_function(input: Input) -> Output: ) client = create_endpoint_client(app) - pickled = pickle.dumps("Hello World!") - input_any = google.protobuf.any_pb2.Any() - input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=pickled)) req = function_pb.RunRequest( function=my_function.name, - input=input_any, + input=any_pickle("Hello World!"), ) resp = client.run(req) self.assertIsInstance(resp, function_pb.RunResponse) - resp.exit.result.output.Unpack( - output_bytes := google.protobuf.wrappers_pb2.BytesValue() - ) - output = pickle.loads(output_bytes.value) + output = any_unpickle(resp.exit.result.output) self.assertEqual(output, "You told me: 'Hello World!' (12 characters)") @@ -229,10 +224,8 @@ def execute( req = function_pb.RunRequest(function=func.name) if input is not None: - input_bytes = pickle.dumps(input) - input_any = google.protobuf.any_pb2.Any() - input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=input_bytes)) - req.input.CopyFrom(input_any) + any = any_pickle(input) + req.input.CopyFrom(any) if state is not None: req.poll_result.coroutine_state = state if calls is not None: diff --git a/tests/test_flask.py b/tests/test_flask.py index f681a17..ae6d431 100644 --- a/tests/test_flask.py +++ b/tests/test_flask.py @@ -19,6 +19,7 @@ from dispatch.flask import Dispatch from dispatch.function import Arguments, Error, Function, Input, Output from dispatch.proto import _any_unpickle as any_unpickle +from dispatch.proto import _pb_any_pickle as any_pickle from dispatch.sdk.v1 import call_pb2 as call_pb from dispatch.sdk.v1 import function_pb2 as function_pb from dispatch.signature import ( @@ -56,23 +57,16 @@ async def my_function(input: Input) -> Output: ) client = create_endpoint_client(app) - pickled = pickle.dumps("Hello World!") - input_any = google.protobuf.any_pb2.Any() - input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=pickled)) req = function_pb.RunRequest( - function=my_function.name, - input=input_any, + function=my_function.name, input=any_pickle("Hello World!") ) resp = client.run(req) self.assertIsInstance(resp, function_pb.RunResponse) - resp.exit.result.output.Unpack( - output_bytes := google.protobuf.wrappers_pb2.BytesValue() - ) - output = pickle.loads(output_bytes.value) + output = any_unpickle(resp.exit.result.output) self.assertEqual(output, "You told me: 'Hello World!' (12 characters)") diff --git a/tests/test_http.py b/tests/test_http.py index 9ac4e1a..56aa8e3 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -19,6 +19,7 @@ from dispatch.function import Arguments, Error, Function, Input, Output, Registry from dispatch.http import Dispatch from dispatch.proto import _any_unpickle as any_unpickle +from dispatch.proto import _pb_any_pickle as any_pickle from dispatch.sdk.v1 import call_pb2 as call_pb from dispatch.sdk.v1 import function_pb2 as function_pb from dispatch.signature import parse_verification_key, public_key_from_pem @@ -91,22 +92,14 @@ async def my_function(input: Input) -> Output: http_client = dispatch.test.httpx.Client(httpx.Client(base_url=self.endpoint)) client = EndpointClient(http_client) - pickled = pickle.dumps("Hello World!") - input_any = google.protobuf.any_pb2.Any() - input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=pickled)) - req = function_pb.RunRequest( - function=my_function.name, - input=input_any, + function=my_function.name, input=any_pickle("Hello World!") ) resp = client.run(req) self.assertIsInstance(resp, function_pb.RunResponse) - resp.exit.result.output.Unpack( - output_bytes := google.protobuf.wrappers_pb2.BytesValue() - ) - output = pickle.loads(output_bytes.value) + output = any_unpickle(resp.exit.result.output) self.assertEqual(output, "You told me: 'Hello World!' (12 characters)")