From 516d006ca031e93f85af8765c65f18a11b7986b2 Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Tue, 25 Jun 2024 14:14:46 +1000 Subject: [PATCH 1/7] Don't pickle proto messages before wrapping as google.protobuf.Any --- src/dispatch/proto.py | 40 +++++++++++++--------------------------- 1 file changed, 13 insertions(+), 27 deletions(-) diff --git a/src/dispatch/proto.py b/src/dispatch/proto.py index 9af3631..4e95341 100644 --- a/src/dispatch/proto.py +++ b/src/dispatch/proto.py @@ -78,7 +78,7 @@ def __init__(self, req: function_pb.RunRequest): self._has_input = req.HasField("input") if self._has_input: - self._input = _pb_any_unpack(req.input) + self._input = _any_unpickle(req.input) else: if req.poll_result.coroutine_state: raise IncompatibleStateError # coroutine_state is deprecated @@ -141,7 +141,7 @@ def from_input_arguments(cls, function: str, *args, **kwargs): return Input( req=function_pb.RunRequest( function=function, - input=_pb_any_pickle(input), + input=_any_pickle(input), ) ) @@ -157,7 +157,7 @@ def from_poll_results( req=function_pb.RunRequest( function=function, poll_result=poll_pb.PollResult( - typed_coroutine_state=_pb_any_pickle(coroutine_state), + typed_coroutine_state=_any_pickle(coroutine_state), results=[result._as_proto() for result in call_results], error=error._as_proto() if error else None, ), @@ -241,7 +241,7 @@ def poll( else None ) poll = poll_pb.Poll( - typed_coroutine_state=_pb_any_pickle(coroutine_state), + typed_coroutine_state=_any_pickle(coroutine_state), min_results=min_results, max_results=max_results, max_wait=max_wait, @@ -279,7 +279,7 @@ class Call: correlation_id: Optional[int] = None def _as_proto(self) -> call_pb.Call: - input_bytes = _pb_any_pickle(self.input) + input_bytes = _any_pickle(self.input) return call_pb.Call( correlation_id=self.correlation_id, endpoint=self.endpoint, @@ -301,7 +301,7 @@ def _as_proto(self) -> call_pb.CallResult: output_any = None error_proto = None if self.output is not None: - output_any = _pb_any_pickle(self.output) + output_any = _any_pickle(self.output) if self.error is not None: error_proto = self.error._as_proto() @@ -440,31 +440,17 @@ def _as_proto(self) -> error_pb.Error: ) -def _any_unpickle(any: google.protobuf.any_pb2.Any) -> Any: - if any.Is(pickled_pb.Pickled.DESCRIPTOR): - p = pickled_pb.Pickled() - any.Unpack(p) - return pickle.loads(p.pickled_value) - - elif any.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR): # legacy container - b = google.protobuf.wrappers_pb2.BytesValue() - any.Unpack(b) - return pickle.loads(b.value) - - elif not any.type_url and not any.value: - return None - - raise InvalidArgumentError(f"unsupported pickled value container: {any.type_url}") - - -def _pb_any_pickle(value: Any) -> google.protobuf.any_pb2.Any: - p = pickled_pb.Pickled(pickled_value=pickle.dumps(value)) +def _any_pickle(value: Any) -> google.protobuf.any_pb2.Any: any = google.protobuf.any_pb2.Any() - any.Pack(p, type_url_prefix="buf.build/stealthrocket/dispatch-proto/") + if isinstance(value, google.protobuf.message.Message): + any.Pack(value) + else: + p = pickled_pb.Pickled(pickled_value=pickle.dumps(value)) + any.Pack(p, type_url_prefix="buf.build/stealthrocket/dispatch-proto/") return any -def _pb_any_unpack(any: google.protobuf.any_pb2.Any) -> Any: +def _any_unpickle(any: google.protobuf.any_pb2.Any) -> Any: if any.Is(pickled_pb.Pickled.DESCRIPTOR): p = pickled_pb.Pickled() any.Unpack(p) From f33acd1b019bdbac3573c458309df59f14852031 Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Wed, 26 Jun 2024 13:29:59 +1000 Subject: [PATCH 2/7] Move the helpers to a separate file --- src/dispatch/any.py | 44 ++++++++++++++++++++++++++++ src/dispatch/proto.py | 50 ++++++-------------------------- tests/dispatch/test_scheduler.py | 6 ++-- tests/test_fastapi.py | 1 - 4 files changed, 56 insertions(+), 45 deletions(-) create mode 100644 src/dispatch/any.py diff --git a/src/dispatch/any.py b/src/dispatch/any.py new file mode 100644 index 0000000..d885709 --- /dev/null +++ b/src/dispatch/any.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +import pickle +from typing import Any + +import google.protobuf.any_pb2 +import google.protobuf.message +import google.protobuf.wrappers_pb2 +from google.protobuf import descriptor_pool, message_factory + +from dispatch.sdk.python.v1 import pickled_pb2 as pickled_pb + + +def marshal_any(value: Any) -> google.protobuf.any_pb2.Any: + any = google.protobuf.any_pb2.Any() + if isinstance(value, google.protobuf.message.Message): + any.Pack(value) + else: + p = pickled_pb.Pickled(pickled_value=pickle.dumps(value)) + any.Pack(p, type_url_prefix="buf.build/stealthrocket/dispatch-proto/") + return any + + +def unmarshal_any(any: google.protobuf.any_pb2.Any) -> Any: + if any.Is(pickled_pb.Pickled.DESCRIPTOR): + p = pickled_pb.Pickled() + any.Unpack(p) + return pickle.loads(p.pickled_value) + + elif any.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR): + b = google.protobuf.wrappers_pb2.BytesValue() + any.Unpack(b) + try: + # Assume it's the legacy container for pickled values. + return pickle.loads(b.value) + except Exception as e: + # Otherwise, return the literal bytes. + return b.value + + pool = descriptor_pool.Default() + msg_descriptor = pool.FindMessageTypeByName(any.TypeName()) + proto = message_factory.GetMessageClass(msg_descriptor)() + any.Unpack(proto) + return proto diff --git a/src/dispatch/proto.py b/src/dispatch/proto.py index 4e95341..cd166ec 100644 --- a/src/dispatch/proto.py +++ b/src/dispatch/proto.py @@ -11,6 +11,7 @@ import tblib # type: ignore[import-untyped] from google.protobuf import descriptor_pool, duration_pb2, message_factory +from dispatch.any import marshal_any, unmarshal_any from dispatch.error import IncompatibleStateError, InvalidArgumentError from dispatch.id import DispatchID from dispatch.sdk.python.v1 import pickled_pb2 as pickled_pb @@ -78,11 +79,11 @@ def __init__(self, req: function_pb.RunRequest): self._has_input = req.HasField("input") if self._has_input: - self._input = _any_unpickle(req.input) + self._input = unmarshal_any(req.input) else: if req.poll_result.coroutine_state: raise IncompatibleStateError # coroutine_state is deprecated - self._coroutine_state = _any_unpickle(req.poll_result.typed_coroutine_state) + self._coroutine_state = unmarshal_any(req.poll_result.typed_coroutine_state) self._call_results = [ CallResult._from_proto(r) for r in req.poll_result.results ] @@ -141,7 +142,7 @@ def from_input_arguments(cls, function: str, *args, **kwargs): return Input( req=function_pb.RunRequest( function=function, - input=_any_pickle(input), + input=marshal_any(input), ) ) @@ -157,7 +158,7 @@ def from_poll_results( req=function_pb.RunRequest( function=function, poll_result=poll_pb.PollResult( - typed_coroutine_state=_any_pickle(coroutine_state), + typed_coroutine_state=marshal_any(coroutine_state), results=[result._as_proto() for result in call_results], error=error._as_proto() if error else None, ), @@ -241,7 +242,7 @@ def poll( else None ) poll = poll_pb.Poll( - typed_coroutine_state=_any_pickle(coroutine_state), + typed_coroutine_state=marshal_any(coroutine_state), min_results=min_results, max_results=max_results, max_wait=max_wait, @@ -279,7 +280,7 @@ class Call: correlation_id: Optional[int] = None def _as_proto(self) -> call_pb.Call: - input_bytes = _any_pickle(self.input) + input_bytes = marshal_any(self.input) return call_pb.Call( correlation_id=self.correlation_id, endpoint=self.endpoint, @@ -301,7 +302,7 @@ def _as_proto(self) -> call_pb.CallResult: output_any = None error_proto = None if self.output is not None: - output_any = _any_pickle(self.output) + output_any = marshal_any(self.output) if self.error is not None: error_proto = self.error._as_proto() @@ -317,7 +318,7 @@ def _from_proto(cls, proto: call_pb.CallResult) -> CallResult: output = None error = None if proto.HasField("output"): - output = _any_unpickle(proto.output) + output = unmarshal_any(proto.output) if proto.HasField("error"): error = Error._from_proto(proto.error) @@ -438,36 +439,3 @@ def _as_proto(self) -> error_pb.Error: return error_pb.Error( type=self.type, message=self.message, value=value, traceback=self.traceback ) - - -def _any_pickle(value: Any) -> google.protobuf.any_pb2.Any: - any = google.protobuf.any_pb2.Any() - if isinstance(value, google.protobuf.message.Message): - any.Pack(value) - else: - p = pickled_pb.Pickled(pickled_value=pickle.dumps(value)) - any.Pack(p, type_url_prefix="buf.build/stealthrocket/dispatch-proto/") - return any - - -def _any_unpickle(any: google.protobuf.any_pb2.Any) -> Any: - if any.Is(pickled_pb.Pickled.DESCRIPTOR): - p = pickled_pb.Pickled() - any.Unpack(p) - return pickle.loads(p.pickled_value) - - elif any.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR): - b = google.protobuf.wrappers_pb2.BytesValue() - any.Unpack(b) - try: - # Assume it's the legacy container for pickled values. - return pickle.loads(b.value) - except Exception as e: - # Otherwise, return the literal bytes. - return b.value - - pool = descriptor_pool.Default() - msg_descriptor = pool.FindMessageTypeByName(any.TypeName()) - proto = message_factory.GetMessageClass(msg_descriptor)() - any.Unpack(proto) - return proto diff --git a/tests/dispatch/test_scheduler.py b/tests/dispatch/test_scheduler.py index c15ed84..386c32e 100644 --- a/tests/dispatch/test_scheduler.py +++ b/tests/dispatch/test_scheduler.py @@ -3,10 +3,10 @@ import pytest +from dispatch.any import unmarshal_any from dispatch.coroutine import AnyException, any, call, gather, race from dispatch.experimental.durable import durable from dispatch.proto import Arguments, Call, CallResult, Error, Input, Output, TailCall -from dispatch.proto import _any_unpickle as any_unpickle from dispatch.scheduler import ( AllFuture, AnyFuture, @@ -464,7 +464,7 @@ async def resume( poll = assert_poll(prev_output) input = Input.from_poll_results( main.__qualname__, - any_unpickle(poll.typed_coroutine_state), + unmarshal_any(poll.typed_coroutine_state), call_results, Error.from_exception(poll_error) if poll_error else None, ) @@ -489,7 +489,7 @@ def assert_exit_result_value(output: Output, expect: Any): result = assert_exit_result(output) assert result.HasField("output") assert not result.HasField("error") - assert expect == any_unpickle(result.output) + assert expect == unmarshal_any(result.output) def assert_exit_result_error( diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index 97e9766..5cf717f 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -19,7 +19,6 @@ from dispatch.experimental.durable.registry import clear_functions from dispatch.fastapi import Dispatch from dispatch.function import Arguments, Client, Error, Input, Output, Registry -from dispatch.proto import _any_unpickle as any_unpickle from dispatch.sdk.v1 import call_pb2 as call_pb from dispatch.sdk.v1 import function_pb2 as function_pb from dispatch.signature import ( From ea3cf1ff6090b37e9b7977b78206386f74cf79e9 Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Wed, 26 Jun 2024 13:44:51 +1000 Subject: [PATCH 3/7] Tests --- src/dispatch/any.py | 11 ++++--- tests/dispatch/test_any.py | 65 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 4 deletions(-) create mode 100644 tests/dispatch/test_any.py diff --git a/src/dispatch/any.py b/src/dispatch/any.py index d885709..84dbcb2 100644 --- a/src/dispatch/any.py +++ b/src/dispatch/any.py @@ -12,12 +12,15 @@ def marshal_any(value: Any) -> google.protobuf.any_pb2.Any: + if not isinstance(value, google.protobuf.message.Message): + value = pickled_pb.Pickled(pickled_value=pickle.dumps(value)) + any = google.protobuf.any_pb2.Any() - if isinstance(value, google.protobuf.message.Message): - any.Pack(value) + if value.DESCRIPTOR.full_name.startswith("dispatch.sdk."): + any.Pack(value, type_url_prefix="buf.build/stealthrocket/dispatch-proto/") else: - p = pickled_pb.Pickled(pickled_value=pickle.dumps(value)) - any.Pack(p, type_url_prefix="buf.build/stealthrocket/dispatch-proto/") + any.Pack(value) + return any diff --git a/tests/dispatch/test_any.py b/tests/dispatch/test_any.py new file mode 100644 index 0000000..ee1d993 --- /dev/null +++ b/tests/dispatch/test_any.py @@ -0,0 +1,65 @@ +import pickle +from datetime import datetime, timedelta + +from dispatch.any import marshal_any, unmarshal_any +from dispatch.sdk.v1 import error_pb2 as error_pb + + +def test_unmarshal_none(): + boxed = marshal_any(None) + assert None == unmarshal_any(boxed) + + +def test_unmarshal_bool(): + boxed = marshal_any(True) + assert True == unmarshal_any(boxed) + + +def test_unmarshal_integer(): + boxed = marshal_any(1234) + assert 1234 == unmarshal_any(boxed) + + boxed = marshal_any(-1234) + assert -1234 == unmarshal_any(boxed) + + +def test_unmarshal_float(): + boxed = marshal_any(3.14) + assert 3.14 == unmarshal_any(boxed) + + +def test_unmarshal_string(): + boxed = marshal_any("foo") + assert "foo" == unmarshal_any(boxed) + + +def test_unmarshal_bytes(): + boxed = marshal_any(b"bar") + assert b"bar" == unmarshal_any(boxed) + + +def test_unmarshal_timestamp(): + ts = datetime.fromtimestamp( + 1719372909.641448 + ) # datetime.datetime(2024, 6, 26, 13, 35, 9, 641448) + boxed = marshal_any(ts) + assert ts == unmarshal_any(boxed) + + +def test_unmarshal_duration(): + d = timedelta(seconds=1, microseconds=1234) + boxed = marshal_any(d) + assert d == unmarshal_any(boxed) + + +def test_unmarshal_protobuf_message(): + message = error_pb.Error(type="internal", message="oops") + boxed = marshal_any(message) + + # Check the message isn't pickled (in which case the type_url would + # end with dispatch.sdk.python.v1.Pickled). + assert ( + "buf.build/stealthrocket/dispatch-proto/dispatch.sdk.v1.Error" == boxed.type_url + ) + + assert message == unmarshal_any(boxed) From 7eebe4c585d212764260d4449e8c99191066bc84 Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Wed, 26 Jun 2024 14:00:35 +1000 Subject: [PATCH 4/7] Prefer using built-in wrappers for primitive values --- src/dispatch/any.py | 61 +++++++++++++++++++++++++++++--------- tests/dispatch/test_any.py | 33 ++++++++++++++++++++- 2 files changed, 79 insertions(+), 15 deletions(-) diff --git a/src/dispatch/any.py b/src/dispatch/any.py index 84dbcb2..fff8166 100644 --- a/src/dispatch/any.py +++ b/src/dispatch/any.py @@ -4,14 +4,33 @@ from typing import Any import google.protobuf.any_pb2 +import google.protobuf.empty_pb2 import google.protobuf.message import google.protobuf.wrappers_pb2 from google.protobuf import descriptor_pool, message_factory from dispatch.sdk.python.v1 import pickled_pb2 as pickled_pb +INT64_MIN = -9223372036854775808 +INT64_MAX = 9223372036854775807 + def marshal_any(value: Any) -> google.protobuf.any_pb2.Any: + if value is None: + value = google.protobuf.empty_pb2.Empty() + elif isinstance(value, bool): + value = google.protobuf.wrappers_pb2.BoolValue(value=value) + elif isinstance(value, int) and INT64_MIN <= value <= INT64_MAX: + # To keep things simple, serialize all integers as int64 on the wire. + # For larger integers, fall through and use pickle. + value = google.protobuf.wrappers_pb2.Int64Value(value=value) + elif isinstance(value, float): + value = google.protobuf.wrappers_pb2.DoubleValue(value=value) + elif isinstance(value, str): + value = google.protobuf.wrappers_pb2.StringValue(value=value) + elif isinstance(value, bytes): + value = google.protobuf.wrappers_pb2.BytesValue(value=value) + if not isinstance(value, google.protobuf.message.Message): value = pickled_pb.Pickled(pickled_value=pickle.dumps(value)) @@ -25,23 +44,37 @@ def marshal_any(value: Any) -> google.protobuf.any_pb2.Any: def unmarshal_any(any: google.protobuf.any_pb2.Any) -> Any: - if any.Is(pickled_pb.Pickled.DESCRIPTOR): - p = pickled_pb.Pickled() - any.Unpack(p) - return pickle.loads(p.pickled_value) - - elif any.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR): - b = google.protobuf.wrappers_pb2.BytesValue() - any.Unpack(b) + pool = descriptor_pool.Default() + msg_descriptor = pool.FindMessageTypeByName(any.TypeName()) + proto = message_factory.GetMessageClass(msg_descriptor)() + any.Unpack(proto) + + if isinstance(proto, pickled_pb.Pickled): + return pickle.loads(proto.pickled_value) + elif isinstance(proto, google.protobuf.empty_pb2.Empty): + return None + elif isinstance(proto, google.protobuf.wrappers_pb2.BoolValue): + return proto.value + elif isinstance(proto, google.protobuf.wrappers_pb2.Int32Value): + return proto.value + elif isinstance(proto, google.protobuf.wrappers_pb2.Int64Value): + return proto.value + elif isinstance(proto, google.protobuf.wrappers_pb2.UInt32Value): + return proto.value + elif isinstance(proto, google.protobuf.wrappers_pb2.UInt64Value): + return proto.value + elif isinstance(proto, google.protobuf.wrappers_pb2.FloatValue): + return proto.value + elif isinstance(proto, google.protobuf.wrappers_pb2.DoubleValue): + return proto.value + elif isinstance(proto, google.protobuf.wrappers_pb2.StringValue): + return proto.value + elif isinstance(proto, google.protobuf.wrappers_pb2.BytesValue): try: # Assume it's the legacy container for pickled values. - return pickle.loads(b.value) + return pickle.loads(proto.value) except Exception as e: # Otherwise, return the literal bytes. - return b.value + return proto.value - pool = descriptor_pool.Default() - msg_descriptor = pool.FindMessageTypeByName(any.TypeName()) - proto = message_factory.GetMessageClass(msg_descriptor)() - any.Unpack(proto) return proto diff --git a/tests/dispatch/test_any.py b/tests/dispatch/test_any.py index ee1d993..c60f1e9 100644 --- a/tests/dispatch/test_any.py +++ b/tests/dispatch/test_any.py @@ -1,40 +1,71 @@ import pickle from datetime import datetime, timedelta -from dispatch.any import marshal_any, unmarshal_any +from dispatch.any import INT64_MAX, INT64_MIN, marshal_any, unmarshal_any from dispatch.sdk.v1 import error_pb2 as error_pb def test_unmarshal_none(): boxed = marshal_any(None) + assert "type.googleapis.com/google.protobuf.Empty" == boxed.type_url assert None == unmarshal_any(boxed) def test_unmarshal_bool(): boxed = marshal_any(True) + assert "type.googleapis.com/google.protobuf.BoolValue" == boxed.type_url assert True == unmarshal_any(boxed) def test_unmarshal_integer(): boxed = marshal_any(1234) + assert "type.googleapis.com/google.protobuf.Int64Value" == boxed.type_url assert 1234 == unmarshal_any(boxed) boxed = marshal_any(-1234) + assert "type.googleapis.com/google.protobuf.Int64Value" == boxed.type_url assert -1234 == unmarshal_any(boxed) +def test_unmarshal_int64_limits(): + boxed = marshal_any(INT64_MIN) + assert "type.googleapis.com/google.protobuf.Int64Value" == boxed.type_url + assert INT64_MIN == unmarshal_any(boxed) + + boxed = marshal_any(INT64_MAX) + assert "type.googleapis.com/google.protobuf.Int64Value" == boxed.type_url + assert INT64_MAX == unmarshal_any(boxed) + + boxed = marshal_any(INT64_MIN - 1) + assert ( + "buf.build/stealthrocket/dispatch-proto/dispatch.sdk.python.v1.Pickled" + == boxed.type_url + ) + assert INT64_MIN - 1 == unmarshal_any(boxed) + + boxed = marshal_any(INT64_MAX + 1) + assert ( + "buf.build/stealthrocket/dispatch-proto/dispatch.sdk.python.v1.Pickled" + == boxed.type_url + ) + assert INT64_MAX + 1 == unmarshal_any(boxed) + + def test_unmarshal_float(): boxed = marshal_any(3.14) + assert "type.googleapis.com/google.protobuf.DoubleValue" == boxed.type_url assert 3.14 == unmarshal_any(boxed) def test_unmarshal_string(): boxed = marshal_any("foo") + assert "type.googleapis.com/google.protobuf.StringValue" == boxed.type_url assert "foo" == unmarshal_any(boxed) def test_unmarshal_bytes(): boxed = marshal_any(b"bar") + assert "type.googleapis.com/google.protobuf.BytesValue" == boxed.type_url assert b"bar" == unmarshal_any(boxed) From fdbedc12b6cf449c95f73e4c37f2b8dc725bc567 Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Wed, 26 Jun 2024 14:14:17 +1000 Subject: [PATCH 5/7] Prefer using built-in wrappers for timestamps & durations --- src/dispatch/any.py | 17 +++++++++++++++++ tests/dispatch/test_any.py | 8 ++++---- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/src/dispatch/any.py b/src/dispatch/any.py index fff8166..f4d8d33 100644 --- a/src/dispatch/any.py +++ b/src/dispatch/any.py @@ -1,11 +1,14 @@ from __future__ import annotations import pickle +from datetime import UTC, datetime, timedelta from typing import Any import google.protobuf.any_pb2 +import google.protobuf.duration_pb2 import google.protobuf.empty_pb2 import google.protobuf.message +import google.protobuf.timestamp_pb2 import google.protobuf.wrappers_pb2 from google.protobuf import descriptor_pool, message_factory @@ -30,6 +33,16 @@ def marshal_any(value: Any) -> google.protobuf.any_pb2.Any: value = google.protobuf.wrappers_pb2.StringValue(value=value) elif isinstance(value, bytes): value = google.protobuf.wrappers_pb2.BytesValue(value=value) + elif isinstance(value, datetime): + # Note: datetime only supports microsecond granularity + seconds = int(value.timestamp()) + nanos = value.microsecond * 1000 + value = google.protobuf.timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos) + elif isinstance(value, timedelta): + # Note: timedelta only supports microsecond granularity + seconds = int(value.total_seconds()) + nanos = value.microseconds * 1000 + value = google.protobuf.duration_pb2.Duration(seconds=seconds, nanos=nanos) if not isinstance(value, google.protobuf.message.Message): value = pickled_pb.Pickled(pickled_value=pickle.dumps(value)) @@ -76,5 +89,9 @@ def unmarshal_any(any: google.protobuf.any_pb2.Any) -> Any: except Exception as e: # Otherwise, return the literal bytes. return proto.value + elif isinstance(proto, google.protobuf.timestamp_pb2.Timestamp): + return proto.ToDatetime(tzinfo=UTC) + elif isinstance(proto, google.protobuf.duration_pb2.Duration): + return proto.ToTimedelta() return proto diff --git a/tests/dispatch/test_any.py b/tests/dispatch/test_any.py index c60f1e9..9db4436 100644 --- a/tests/dispatch/test_any.py +++ b/tests/dispatch/test_any.py @@ -1,5 +1,5 @@ import pickle -from datetime import datetime, timedelta +from datetime import UTC, datetime, timedelta from dispatch.any import INT64_MAX, INT64_MIN, marshal_any, unmarshal_any from dispatch.sdk.v1 import error_pb2 as error_pb @@ -70,16 +70,16 @@ def test_unmarshal_bytes(): def test_unmarshal_timestamp(): - ts = datetime.fromtimestamp( - 1719372909.641448 - ) # datetime.datetime(2024, 6, 26, 13, 35, 9, 641448) + ts = datetime.fromtimestamp(1719372909.641448, UTC) boxed = marshal_any(ts) + assert "type.googleapis.com/google.protobuf.Timestamp" == boxed.type_url assert ts == unmarshal_any(boxed) def test_unmarshal_duration(): d = timedelta(seconds=1, microseconds=1234) boxed = marshal_any(d) + assert "type.googleapis.com/google.protobuf.Duration" == boxed.type_url assert d == unmarshal_any(boxed) From 856bf3e9608ee2f708db3dd78b554c3a8c850ddd Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Wed, 26 Jun 2024 14:40:08 +1000 Subject: [PATCH 6/7] Try to serialize lists/dicts as structpb --- src/dispatch/any.py | 73 ++++++++++++++++++++++++++++++++++++++ tests/dispatch/test_any.py | 15 ++++++++ 2 files changed, 88 insertions(+) diff --git a/src/dispatch/any.py b/src/dispatch/any.py index f4d8d33..92dda7d 100644 --- a/src/dispatch/any.py +++ b/src/dispatch/any.py @@ -8,6 +8,7 @@ import google.protobuf.duration_pb2 import google.protobuf.empty_pb2 import google.protobuf.message +import google.protobuf.struct_pb2 import google.protobuf.timestamp_pb2 import google.protobuf.wrappers_pb2 from google.protobuf import descriptor_pool, message_factory @@ -44,6 +45,12 @@ def marshal_any(value: Any) -> google.protobuf.any_pb2.Any: nanos = value.microseconds * 1000 value = google.protobuf.duration_pb2.Duration(seconds=seconds, nanos=nanos) + if isinstance(value, list) or isinstance(value, dict): + try: + value = as_struct_value(value) + except ValueError: + pass # fallthrough + if not isinstance(value, google.protobuf.message.Message): value = pickled_pb.Pickled(pickled_value=pickle.dumps(value)) @@ -64,24 +71,34 @@ def unmarshal_any(any: google.protobuf.any_pb2.Any) -> Any: if isinstance(proto, pickled_pb.Pickled): return pickle.loads(proto.pickled_value) + elif isinstance(proto, google.protobuf.empty_pb2.Empty): return None + elif isinstance(proto, google.protobuf.wrappers_pb2.BoolValue): return proto.value + elif isinstance(proto, google.protobuf.wrappers_pb2.Int32Value): return proto.value + elif isinstance(proto, google.protobuf.wrappers_pb2.Int64Value): return proto.value + elif isinstance(proto, google.protobuf.wrappers_pb2.UInt32Value): return proto.value + elif isinstance(proto, google.protobuf.wrappers_pb2.UInt64Value): return proto.value + elif isinstance(proto, google.protobuf.wrappers_pb2.FloatValue): return proto.value + elif isinstance(proto, google.protobuf.wrappers_pb2.DoubleValue): return proto.value + elif isinstance(proto, google.protobuf.wrappers_pb2.StringValue): return proto.value + elif isinstance(proto, google.protobuf.wrappers_pb2.BytesValue): try: # Assume it's the legacy container for pickled values. @@ -89,9 +106,65 @@ def unmarshal_any(any: google.protobuf.any_pb2.Any) -> Any: except Exception as e: # Otherwise, return the literal bytes. return proto.value + elif isinstance(proto, google.protobuf.timestamp_pb2.Timestamp): return proto.ToDatetime(tzinfo=UTC) + elif isinstance(proto, google.protobuf.duration_pb2.Duration): return proto.ToTimedelta() + elif isinstance(proto, google.protobuf.struct_pb2.Value): + return from_struct_value(proto) + return proto + + +def as_struct_value(value: Any) -> google.protobuf.struct_pb2.Value: + if value is None: + null_value = google.protobuf.struct_pb2.NullValue.NULL_VALUE + return google.protobuf.struct_pb2.Value(null_value=null_value) + + elif isinstance(value, bool): + return google.protobuf.struct_pb2.Value(bool_value=value) + + elif isinstance(value, int) or isinstance(value, float): + return google.protobuf.struct_pb2.Value(number_value=float(value)) + + elif isinstance(value, str): + return google.protobuf.struct_pb2.Value(string_value=value) + + elif isinstance(value, list): + list_value = google.protobuf.struct_pb2.ListValue( + values=[as_struct_value(v) for v in value] + ) + return google.protobuf.struct_pb2.Value(list_value=list_value) + + elif isinstance(value, dict): + for key in value.keys(): + if not isinstance(key, str): + raise ValueError("unsupported object key") + + struct_value = google.protobuf.struct_pb2.Struct( + fields={k: as_struct_value(v) for k, v in value.items()} + ) + return google.protobuf.struct_pb2.Value(struct_value=struct_value) + + raise ValueError("unsupported value") + + +def from_struct_value(value: google.protobuf.struct_pb2.Value) -> Any: + if value.HasField("null_value"): + return None + elif value.HasField("bool_value"): + return value.bool_value + elif value.HasField("number_value"): + return value.number_value + elif value.HasField("string_value"): + return value.string_value + elif value.HasField("list_value"): + + return [from_struct_value(v) for v in value.list_value.values] + elif value.HasField("struct_value"): + return {k: from_struct_value(v) for k, v in value.struct_value.fields.items()} + else: + raise RuntimeError(f"invalid struct_pb2.Value: {value}") diff --git a/tests/dispatch/test_any.py b/tests/dispatch/test_any.py index 9db4436..28fd2e5 100644 --- a/tests/dispatch/test_any.py +++ b/tests/dispatch/test_any.py @@ -94,3 +94,18 @@ def test_unmarshal_protobuf_message(): ) assert message == unmarshal_any(boxed) + + +def test_unmarshal_json_like(): + value = { + "null": None, + "bool": True, + "int": 11, + "float": 3.14, + "string": "foo", + "list": [None, "abc", 1.23], + "object": {"a": ["b", "c"]}, + } + boxed = marshal_any(value) + assert "type.googleapis.com/google.protobuf.Value" == boxed.type_url + assert value == unmarshal_any(boxed) From 6b283a4d134aad753785435bebb765a3ecc950b2 Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Wed, 26 Jun 2024 14:44:40 +1000 Subject: [PATCH 7/7] Avoid using UTC alias added in Python 3.11 --- src/dispatch/any.py | 4 ++-- tests/dispatch/test_any.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/dispatch/any.py b/src/dispatch/any.py index 92dda7d..4933990 100644 --- a/src/dispatch/any.py +++ b/src/dispatch/any.py @@ -1,7 +1,7 @@ from __future__ import annotations import pickle -from datetime import UTC, datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import Any import google.protobuf.any_pb2 @@ -108,7 +108,7 @@ def unmarshal_any(any: google.protobuf.any_pb2.Any) -> Any: return proto.value elif isinstance(proto, google.protobuf.timestamp_pb2.Timestamp): - return proto.ToDatetime(tzinfo=UTC) + return proto.ToDatetime(tzinfo=timezone.utc) elif isinstance(proto, google.protobuf.duration_pb2.Duration): return proto.ToTimedelta() diff --git a/tests/dispatch/test_any.py b/tests/dispatch/test_any.py index 28fd2e5..4cb23e7 100644 --- a/tests/dispatch/test_any.py +++ b/tests/dispatch/test_any.py @@ -1,5 +1,5 @@ import pickle -from datetime import UTC, datetime, timedelta +from datetime import datetime, timedelta, timezone from dispatch.any import INT64_MAX, INT64_MIN, marshal_any, unmarshal_any from dispatch.sdk.v1 import error_pb2 as error_pb @@ -70,7 +70,7 @@ def test_unmarshal_bytes(): def test_unmarshal_timestamp(): - ts = datetime.fromtimestamp(1719372909.641448, UTC) + ts = datetime.fromtimestamp(1719372909.641448, timezone.utc) boxed = marshal_any(ts) assert "type.googleapis.com/google.protobuf.Timestamp" == boxed.type_url assert ts == unmarshal_any(boxed)