Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid pickling primitive values and proto messages #179

Merged
merged 7 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 170 additions & 0 deletions src/dispatch/any.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
from __future__ import annotations

import pickle
from datetime import datetime, timedelta, timezone
from typing import Any

import google.protobuf.any_pb2
import google.protobuf.duration_pb2
import google.protobuf.empty_pb2
import google.protobuf.message
import google.protobuf.struct_pb2
import google.protobuf.timestamp_pb2
import google.protobuf.wrappers_pb2
from google.protobuf import descriptor_pool, message_factory

from dispatch.sdk.python.v1 import pickled_pb2 as pickled_pb

# Inclusive bounds of a signed 64-bit integer (-2**63 and 2**63 - 1).
# marshal_any() uses them to decide whether an int fits in an Int64Value.
INT64_MIN = -9223372036854775808
INT64_MAX = 9223372036854775807


def marshal_any(value: Any) -> google.protobuf.any_pb2.Any:
    """Pack a Python value into a ``google.protobuf.Any`` message.

    Values are mapped to protobuf messages where a natural mapping exists,
    and pickled otherwise:

    * ``None`` -> ``Empty``
    * ``bool`` -> ``BoolValue``
    * ``int`` within the int64 range -> ``Int64Value``
    * ``float`` -> ``DoubleValue``
    * ``str`` -> ``StringValue``
    * ``bytes`` -> ``BytesValue``
    * ``datetime`` -> ``Timestamp``
    * ``timedelta`` -> ``Duration``
    * ``list`` / ``dict`` -> ``struct_pb2.Value`` when ``as_struct_value``
      accepts the contents
    * protobuf messages are packed as-is
    * anything else is wrapped in a ``Pickled`` message via ``pickle.dumps``

    Args:
        value: The Python value (or protobuf message) to pack.

    Returns:
        An ``Any`` message wrapping the converted value.
    """
    if value is None:
        value = google.protobuf.empty_pb2.Empty()
    elif isinstance(value, bool):
        # bool must be tested before int because bool is a subclass of int.
        value = google.protobuf.wrappers_pb2.BoolValue(value=value)
    elif isinstance(value, int) and INT64_MIN <= value <= INT64_MAX:
        # To keep things simple, serialize all integers as int64 on the wire.
        # For larger integers, fall through and use pickle.
        value = google.protobuf.wrappers_pb2.Int64Value(value=value)
    elif isinstance(value, float):
        value = google.protobuf.wrappers_pb2.DoubleValue(value=value)
    elif isinstance(value, str):
        value = google.protobuf.wrappers_pb2.StringValue(value=value)
    elif isinstance(value, bytes):
        value = google.protobuf.wrappers_pb2.BytesValue(value=value)
    elif isinstance(value, datetime):
        # Note: datetime only supports microsecond granularity.
        # Timestamp nanos always count forward from `seconds`, so `seconds`
        # must be the floor of the timestamp. Dropping the microseconds
        # before converting yields that floor exactly; truncating the
        # fractional timestamp with int() would round pre-epoch values
        # toward zero and shift them forward by one second.
        seconds = int(value.replace(microsecond=0).timestamp())
        nanos = value.microsecond * 1000
        value = google.protobuf.timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos)
    elif isinstance(value, timedelta):
        # Note: timedelta only supports microsecond granularity.
        # timedelta normalizes negative durations to negative days plus
        # non-negative seconds/microseconds, whereas Duration requires the
        # seconds and nanos to share a sign. FromTimedelta handles that.
        duration = google.protobuf.duration_pb2.Duration()
        duration.FromTimedelta(value)
        value = duration

    if isinstance(value, (list, dict)):
        try:
            value = as_struct_value(value)
        except (ValueError, OverflowError):
            # The container holds something a struct Value cannot carry
            # (including integers too large for a double); fall through
            # and pickle the original container instead.
            pass

    if not isinstance(value, google.protobuf.message.Message):
        value = pickled_pb.Pickled(pickled_value=pickle.dumps(value))

    packed = google.protobuf.any_pb2.Any()
    if value.DESCRIPTOR.full_name.startswith("dispatch.sdk."):
        # Dispatch SDK messages are packed under the buf.build type URL prefix.
        packed.Pack(value, type_url_prefix="buf.build/stealthrocket/dispatch-proto/")
    else:
        packed.Pack(value)

    return packed


def unmarshal_any(any: google.protobuf.any_pb2.Any) -> Any:
    """Unpack a ``google.protobuf.Any`` message into a Python value.

    This is the inverse of ``marshal_any``. The message type is resolved
    through the default descriptor pool, and well-known types are converted
    back to native Python values. Messages of any other type are returned
    as protobuf message instances.

    Args:
        any: The ``Any`` message to unpack. (The parameter name shadows the
            ``any`` builtin; it is kept for interface compatibility.)

    Returns:
        The decoded Python value, ``None`` for an unset ``Any``, or the
        unpacked protobuf message when no native conversion applies.
    """
    if not any.type_url and not any.value:
        # An unset Any has no type to look up in the descriptor pool.
        # Treat it as "no value", matching the behavior of the legacy
        # unpickling helper this function replaces.
        return None

    pool = descriptor_pool.Default()
    msg_descriptor = pool.FindMessageTypeByName(any.TypeName())
    proto = message_factory.GetMessageClass(msg_descriptor)()
    any.Unpack(proto)

    if isinstance(proto, pickled_pb.Pickled):
        # NOTE(security): pickle.loads can execute arbitrary code. This is
        # only safe when the payload comes from a trusted source.
        return pickle.loads(proto.pickled_value)

    if isinstance(proto, google.protobuf.empty_pb2.Empty):
        return None

    scalar_wrappers = (
        google.protobuf.wrappers_pb2.BoolValue,
        google.protobuf.wrappers_pb2.Int32Value,
        google.protobuf.wrappers_pb2.Int64Value,
        google.protobuf.wrappers_pb2.UInt32Value,
        google.protobuf.wrappers_pb2.UInt64Value,
        google.protobuf.wrappers_pb2.FloatValue,
        google.protobuf.wrappers_pb2.DoubleValue,
        google.protobuf.wrappers_pb2.StringValue,
    )
    if isinstance(proto, scalar_wrappers):
        return proto.value

    if isinstance(proto, google.protobuf.wrappers_pb2.BytesValue):
        # NOTE(security): same pickle caveat as above. Also note that raw
        # bytes which happen to form a valid pickle stream are returned as
        # the unpickled object rather than as bytes.
        try:
            # Assume it's the legacy container for pickled values.
            return pickle.loads(proto.value)
        except Exception:
            # Otherwise, return the literal bytes.
            return proto.value

    if isinstance(proto, google.protobuf.timestamp_pb2.Timestamp):
        return proto.ToDatetime(tzinfo=timezone.utc)

    if isinstance(proto, google.protobuf.duration_pb2.Duration):
        return proto.ToTimedelta()

    if isinstance(proto, google.protobuf.struct_pb2.Value):
        return from_struct_value(proto)

    return proto


def as_struct_value(value: Any) -> google.protobuf.struct_pb2.Value:
    """Convert a JSON-like Python value into a ``struct_pb2.Value``.

    Supported inputs are ``None``, ``bool``, ``int``, ``float``, ``str``,
    and ``list`` / ``dict`` containers (with ``str`` keys) of supported
    values. Numbers are stored in the double-typed ``number_value`` field,
    so integers come back as floats when decoded.

    Args:
        value: The Python value to convert.

    Returns:
        The equivalent ``struct_pb2.Value`` message.

    Raises:
        ValueError: If the value (or a nested value) has an unsupported
            type, a dict has a non-string key, or an integer cannot be
            represented exactly as a double.
    """
    if value is None:
        null_value = google.protobuf.struct_pb2.NullValue.NULL_VALUE
        return google.protobuf.struct_pb2.Value(null_value=null_value)

    if isinstance(value, bool):
        # bool must be tested before int because bool is a subclass of int.
        return google.protobuf.struct_pb2.Value(bool_value=value)

    if isinstance(value, int):
        try:
            number = float(value)
        except OverflowError as err:
            # Integers beyond the double range cannot be stored at all.
            raise ValueError("unsupported value") from err
        if number != value:
            # Python compares int and float exactly, so a mismatch means the
            # conversion lost precision. Reject the value rather than store
            # a silently corrupted number.
            raise ValueError("unsupported value")
        return google.protobuf.struct_pb2.Value(number_value=number)

    if isinstance(value, float):
        return google.protobuf.struct_pb2.Value(number_value=float(value))

    if isinstance(value, str):
        return google.protobuf.struct_pb2.Value(string_value=value)

    if isinstance(value, list):
        list_value = google.protobuf.struct_pb2.ListValue(
            values=[as_struct_value(item) for item in value]
        )
        return google.protobuf.struct_pb2.Value(list_value=list_value)

    if isinstance(value, dict):
        # Validate every key before converting any value, so that a bad key
        # is always reported ahead of a bad value.
        if not all(isinstance(key, str) for key in value):
            raise ValueError("unsupported object key")

        struct_value = google.protobuf.struct_pb2.Struct(
            fields={key: as_struct_value(item) for key, item in value.items()}
        )
        return google.protobuf.struct_pb2.Value(struct_value=struct_value)

    raise ValueError("unsupported value")


def from_struct_value(value: google.protobuf.struct_pb2.Value) -> Any:
    """Convert a ``struct_pb2.Value`` back into a native Python value.

    Null becomes ``None``; bool, number and string values are returned
    directly; list and struct values are converted recursively into
    ``list`` and ``dict`` objects.

    Raises:
        RuntimeError: If none of the value's fields is set.
    """
    # The fields of struct_pb2.Value live in a oneof named "kind", so at
    # most one of them is set.
    kind = value.WhichOneof("kind")

    if kind == "null_value":
        return None

    if kind in ("bool_value", "number_value", "string_value"):
        return getattr(value, kind)

    if kind == "list_value":
        return [from_struct_value(item) for item in value.list_value.values]

    if kind == "struct_value":
        return {
            name: from_struct_value(field)
            for name, field in value.struct_value.fields.items()
        }

    raise RuntimeError(f"invalid struct_pb2.Value: {value}")
64 changes: 9 additions & 55 deletions src/dispatch/proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import tblib # type: ignore[import-untyped]
from google.protobuf import descriptor_pool, duration_pb2, message_factory

from dispatch.any import marshal_any, unmarshal_any
from dispatch.error import IncompatibleStateError, InvalidArgumentError
from dispatch.id import DispatchID
from dispatch.sdk.python.v1 import pickled_pb2 as pickled_pb
Expand Down Expand Up @@ -78,11 +79,11 @@ def __init__(self, req: function_pb.RunRequest):

self._has_input = req.HasField("input")
if self._has_input:
self._input = _pb_any_unpack(req.input)
self._input = unmarshal_any(req.input)
else:
if req.poll_result.coroutine_state:
raise IncompatibleStateError # coroutine_state is deprecated
self._coroutine_state = _any_unpickle(req.poll_result.typed_coroutine_state)
self._coroutine_state = unmarshal_any(req.poll_result.typed_coroutine_state)
self._call_results = [
CallResult._from_proto(r) for r in req.poll_result.results
]
Expand Down Expand Up @@ -141,7 +142,7 @@ def from_input_arguments(cls, function: str, *args, **kwargs):
return Input(
req=function_pb.RunRequest(
function=function,
input=_pb_any_pickle(input),
input=marshal_any(input),
)
)

Expand All @@ -157,7 +158,7 @@ def from_poll_results(
req=function_pb.RunRequest(
function=function,
poll_result=poll_pb.PollResult(
typed_coroutine_state=_pb_any_pickle(coroutine_state),
typed_coroutine_state=marshal_any(coroutine_state),
results=[result._as_proto() for result in call_results],
error=error._as_proto() if error else None,
),
Expand Down Expand Up @@ -241,7 +242,7 @@ def poll(
else None
)
poll = poll_pb.Poll(
typed_coroutine_state=_pb_any_pickle(coroutine_state),
typed_coroutine_state=marshal_any(coroutine_state),
min_results=min_results,
max_results=max_results,
max_wait=max_wait,
Expand Down Expand Up @@ -279,7 +280,7 @@ class Call:
correlation_id: Optional[int] = None

def _as_proto(self) -> call_pb.Call:
input_bytes = _pb_any_pickle(self.input)
input_bytes = marshal_any(self.input)
return call_pb.Call(
correlation_id=self.correlation_id,
endpoint=self.endpoint,
Expand All @@ -301,7 +302,7 @@ def _as_proto(self) -> call_pb.CallResult:
output_any = None
error_proto = None
if self.output is not None:
output_any = _pb_any_pickle(self.output)
output_any = marshal_any(self.output)
if self.error is not None:
error_proto = self.error._as_proto()

Expand All @@ -317,7 +318,7 @@ def _from_proto(cls, proto: call_pb.CallResult) -> CallResult:
output = None
error = None
if proto.HasField("output"):
output = _any_unpickle(proto.output)
output = unmarshal_any(proto.output)
if proto.HasField("error"):
error = Error._from_proto(proto.error)

Expand Down Expand Up @@ -438,50 +439,3 @@ def _as_proto(self) -> error_pb.Error:
return error_pb.Error(
type=self.type, message=self.message, value=value, traceback=self.traceback
)


def _any_unpickle(any: google.protobuf.any_pb2.Any) -> Any:
    """Unpickle the value carried by an ``Any`` container.

    Accepts a ``Pickled`` message or the legacy ``BytesValue`` container.
    An ``Any`` with neither a type URL nor a value yields ``None``.

    Raises:
        InvalidArgumentError: If the container type is not supported.
    """
    if any.Is(pickled_pb.Pickled.DESCRIPTOR):
        container = pickled_pb.Pickled()
        any.Unpack(container)
        return pickle.loads(container.pickled_value)

    # Legacy container: the pickle stream was stored in a BytesValue.
    if any.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR):
        legacy = google.protobuf.wrappers_pb2.BytesValue()
        any.Unpack(legacy)
        return pickle.loads(legacy.value)

    if any.type_url or any.value:
        raise InvalidArgumentError(
            f"unsupported pickled value container: {any.type_url}"
        )

    return None


def _pb_any_pickle(value: Any) -> google.protobuf.any_pb2.Any:
    """Pickle ``value`` and pack it, wrapped in a ``Pickled`` message, into an ``Any``."""
    payload = pickle.dumps(value)
    container = pickled_pb.Pickled(pickled_value=payload)

    packed = google.protobuf.any_pb2.Any()
    packed.Pack(container, type_url_prefix="buf.build/stealthrocket/dispatch-proto/")
    return packed


def _pb_any_unpack(any: google.protobuf.any_pb2.Any) -> Any:
    """Unpack an ``Any`` into a Python value or a protobuf message.

    ``Pickled`` containers are unpickled. ``BytesValue`` containers are
    unpickled when possible and otherwise returned as raw bytes. Any other
    type is resolved through the default descriptor pool and returned as
    the unpacked message.
    """
    if any.Is(pickled_pb.Pickled.DESCRIPTOR):
        container = pickled_pb.Pickled()
        any.Unpack(container)
        return pickle.loads(container.pickled_value)

    if any.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR):
        wrapper = google.protobuf.wrappers_pb2.BytesValue()
        any.Unpack(wrapper)
        try:
            # Assume it's the legacy container for pickled values.
            return pickle.loads(wrapper.value)
        except Exception:
            # Otherwise, return the literal bytes.
            return wrapper.value

    # Generic path: instantiate the message class named by the type URL.
    descriptor = descriptor_pool.Default().FindMessageTypeByName(any.TypeName())
    message = message_factory.GetMessageClass(descriptor)()
    any.Unpack(message)
    return message
Loading