diff --git a/src/dispatch/any.py b/src/dispatch/any.py index fff8166..f4d8d33 100644 --- a/src/dispatch/any.py +++ b/src/dispatch/any.py @@ -1,11 +1,14 @@ from __future__ import annotations import pickle +from datetime import UTC, datetime, timedelta from typing import Any import google.protobuf.any_pb2 +import google.protobuf.duration_pb2 import google.protobuf.empty_pb2 import google.protobuf.message +import google.protobuf.timestamp_pb2 import google.protobuf.wrappers_pb2 from google.protobuf import descriptor_pool, message_factory @@ -30,6 +33,16 @@ def marshal_any(value: Any) -> google.protobuf.any_pb2.Any: value = google.protobuf.wrappers_pb2.StringValue(value=value) elif isinstance(value, bytes): value = google.protobuf.wrappers_pb2.BytesValue(value=value) + elif isinstance(value, datetime): + # Note: datetime only supports microsecond granularity + seconds = int(value.timestamp()) + nanos = value.microsecond * 1000 + value = google.protobuf.timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos) + elif isinstance(value, timedelta): + # Note: timedelta only supports microsecond granularity + seconds = int(value.total_seconds()) + nanos = value.microseconds * 1000 + value = google.protobuf.duration_pb2.Duration(seconds=seconds, nanos=nanos) if not isinstance(value, google.protobuf.message.Message): value = pickled_pb.Pickled(pickled_value=pickle.dumps(value)) @@ -76,5 +89,9 @@ def unmarshal_any(any: google.protobuf.any_pb2.Any) -> Any: except Exception as e: # Otherwise, return the literal bytes. return proto.value + elif isinstance(proto, google.protobuf.timestamp_pb2.Timestamp): + return proto.ToDatetime(tzinfo=UTC) + elif isinstance(proto, google.protobuf.duration_pb2.Duration): + return proto.ToTimedelta() return proto diff --git a/tests/dispatch/test_any.py b/tests/dispatch/test_any.py index c60f1e9..9db4436 100644 --- a/tests/dispatch/test_any.py +++ b/tests/dispatch/test_any.py @@ -1,5 +1,5 @@ import pickle -from datetime import datetime, timedelta +from datetime import UTC, datetime, timedelta from dispatch.any import INT64_MAX, INT64_MIN, marshal_any, unmarshal_any from dispatch.sdk.v1 import error_pb2 as error_pb @@ -70,16 +70,16 @@ def test_unmarshal_bytes(): def test_unmarshal_timestamp(): - ts = datetime.fromtimestamp( - 1719372909.641448 - ) # datetime.datetime(2024, 6, 26, 13, 35, 9, 641448) + ts = datetime.fromtimestamp(1719372909.641448, UTC) boxed = marshal_any(ts) + assert "type.googleapis.com/google.protobuf.Timestamp" == boxed.type_url assert ts == unmarshal_any(boxed) def test_unmarshal_duration(): d = timedelta(seconds=1, microseconds=1234) boxed = marshal_any(d) + assert "type.googleapis.com/google.protobuf.Duration" == boxed.type_url assert d == unmarshal_any(boxed)