Skip to content

Commit

Permalink
Merge pull request #177 from dispatchrun/bytes-containers
Browse files Browse the repository at this point in the history
Wrap pickled values in dispatch.sdk.python.v1 container
  • Loading branch information
chriso authored Jun 13, 2024
2 parents 6661a09 + 9fb79ef commit 4a344ad
Show file tree
Hide file tree
Showing 18 changed files with 195 additions and 143 deletions.
73 changes: 47 additions & 26 deletions src/dispatch/proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
import tblib # type: ignore[import-untyped]
from google.protobuf import descriptor_pool, duration_pb2, message_factory

from dispatch.error import IncompatibleStateError, InvalidArgumentError
from dispatch.id import DispatchID
from dispatch.sdk.python.v1 import pickled_pb2 as pickled_pb
from dispatch.sdk.v1 import call_pb2 as call_pb
from dispatch.sdk.v1 import error_pb2 as error_pb
from dispatch.sdk.v1 import exit_pb2 as exit_pb
Expand Down Expand Up @@ -77,18 +79,11 @@ def __init__(self, req: function_pb.RunRequest):

self._has_input = req.HasField("input")
if self._has_input:
if req.input.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR):
input_pb = google.protobuf.wrappers_pb2.BytesValue()
req.input.Unpack(input_pb)
input_bytes = input_pb.value
try:
self._input = pickle.loads(input_bytes)
except Exception as e:
self._input = input_bytes
else:
self._input = _pb_any_unpack(req.input)
self._input = _pb_any_unpack(req.input)
else:
self._coroutine_state = req.poll_result.coroutine_state
if req.poll_result.coroutine_state:
raise IncompatibleStateError # coroutine_state is deprecated
self._coroutine_state = _any_unpickle(req.poll_result.typed_coroutine_state)
self._call_results = [
CallResult._from_proto(r) for r in req.poll_result.results
]
Expand Down Expand Up @@ -155,15 +150,15 @@ def from_input_arguments(cls, function: str, *args, **kwargs):
def from_poll_results(
cls,
function: str,
coroutine_state: Optional[bytes],
coroutine_state: Any,
call_results: List[CallResult],
error: Optional[Error] = None,
):
return Input(
req=function_pb.RunRequest(
function=function,
poll_result=poll_pb.PollResult(
coroutine_state=coroutine_state,
typed_coroutine_state=_pb_any_pickle(coroutine_state),
results=[result._as_proto() for result in call_results],
error=error._as_proto() if error else None,
),
Expand Down Expand Up @@ -232,7 +227,7 @@ def exit(
@classmethod
def poll(
cls,
coroutine_state: Optional[bytes] = None,
coroutine_state: Any = None,
calls: Optional[List[Call]] = None,
min_results: int = 1,
max_results: int = 10,
Expand All @@ -247,7 +242,7 @@ def poll(
else None
)
poll = poll_pb.Poll(
coroutine_state=coroutine_state,
typed_coroutine_state=_pb_any_pickle(coroutine_state),
min_results=min_results,
max_results=max_results,
max_wait=max_wait,
Expand Down Expand Up @@ -447,21 +442,47 @@ def _as_proto(self) -> error_pb.Error:


def _any_unpickle(any: google.protobuf.any_pb2.Any) -> Any:
any.Unpack(value_bytes := google.protobuf.wrappers_pb2.BytesValue())
return pickle.loads(value_bytes.value)
if any.Is(pickled_pb.Pickled.DESCRIPTOR):
p = pickled_pb.Pickled()
any.Unpack(p)
return pickle.loads(p.pickled_value)

elif any.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR): # legacy container
b = google.protobuf.wrappers_pb2.BytesValue()
any.Unpack(b)
return pickle.loads(b.value)

elif not any.type_url and not any.value:
return None

raise InvalidArgumentError(f"unsupported pickled value container: {any.type_url}")


def _pb_any_pickle(value: Any) -> google.protobuf.any_pb2.Any:
    """Pickle ``value`` and wrap it in a ``Pickled`` message inside an ``Any``.

    Args:
        value: Object to serialize with ``pickle.dumps``.

    Returns:
        A protobuf ``Any`` holding a ``pickled_pb.Pickled`` message whose
        ``pickled_value`` field carries the pickled bytes. The message is
        packed with the ``buf.build/stealthrocket/dispatch-proto/`` type URL
        prefix.
    """
    payload = pickle.dumps(value)
    container = pickled_pb.Pickled(pickled_value=payload)
    packed = google.protobuf.any_pb2.Any()
    packed.Pack(container, type_url_prefix="buf.build/stealthrocket/dispatch-proto/")
    return packed


def _pb_any_pickle(x: Any) -> google.protobuf.any_pb2.Any:
    """Pickle ``x`` and pack the bytes into an ``Any`` via a ``BytesValue``.

    Args:
        x: Object to serialize with ``pickle.dumps``.

    Returns:
        A protobuf ``Any`` holding a ``BytesValue`` whose ``value`` is the
        pickled bytes, packed with the default type URL prefix.
    """
    pickled = pickle.dumps(x)
    wrapper = google.protobuf.wrappers_pb2.BytesValue(value=pickled)
    result = google.protobuf.any_pb2.Any()
    result.Pack(wrapper)
    return result
def _pb_any_unpack(any: google.protobuf.any_pb2.Any) -> Any:
if any.Is(pickled_pb.Pickled.DESCRIPTOR):
p = pickled_pb.Pickled()
any.Unpack(p)
return pickle.loads(p.pickled_value)

elif any.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR):
b = google.protobuf.wrappers_pb2.BytesValue()
any.Unpack(b)
try:
# Assume it's the legacy container for pickled values.
return pickle.loads(b.value)
except Exception as e:
# Otherwise, return the literal bytes.
return b.value

def _pb_any_unpack(x: google.protobuf.any_pb2.Any) -> Any:
pool = descriptor_pool.Default()
msg_descriptor = pool.FindMessageTypeByName(x.TypeName())
msg_descriptor = pool.FindMessageTypeByName(any.TypeName())
proto = message_factory.GetMessageClass(msg_descriptor)()
x.Unpack(proto)
any.Unpack(proto)
return proto
44 changes: 17 additions & 27 deletions src/dispatch/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,19 +357,17 @@ def _init_state(self, input: Input) -> State:
)

def _rebuild_state(self, input: Input):
logger.debug(
"resuming scheduler with %d bytes of state", len(input.coroutine_state)
)
logger.info("resuming main coroutine")
try:
state = pickle.loads(input.coroutine_state)
state = input.coroutine_state
if not isinstance(state, State):
raise ValueError("invalid state")
if state.version != self.version:
raise ValueError(
f"version mismatch: '{state.version}' vs. current '{self.version}'"
)
return state
except (pickle.PickleError, ValueError) as e:
except ValueError as e:
logger.warning("state is incompatible", exc_info=True)
raise IncompatibleStateError from e

Expand Down Expand Up @@ -454,32 +452,24 @@ async def _run(self, input: Input) -> Output:
await asyncio.gather(*asyncio_tasks, return_exceptions=True)
return coroutine_result

# Serialize coroutines and scheduler state.
logger.debug("serializing state")
# Yield to Dispatch.
logger.debug("yielding to Dispatch with %d call(s)", len(pending_calls))
try:
serialized_state = pickle.dumps(state)
return Output.poll(
coroutine_state=state,
calls=pending_calls,
min_results=max(1, min(state.outstanding_calls, self.poll_min_results)),
max_results=max(1, min(state.outstanding_calls, self.poll_max_results)),
max_wait_seconds=self.poll_max_wait_seconds,
)
except pickle.PickleError as e:
logger.exception("state could not be serialized")
return Output.error(Error.from_exception(e, status=Status.PERMANENT_ERROR))

# Close coroutines before yielding.
for suspended in state.suspended.values():
suspended.coroutine.close()
state.suspended = {}

# Yield to Dispatch.
logger.debug(
"yielding to Dispatch with %d call(s) and %d bytes of state",
len(pending_calls),
len(serialized_state),
)
return Output.poll(
coroutine_state=serialized_state,
calls=pending_calls,
min_results=max(1, min(state.outstanding_calls, self.poll_min_results)),
max_results=max(1, min(state.outstanding_calls, self.poll_max_results)),
max_wait_seconds=self.poll_max_wait_seconds,
)
finally:
# Close coroutines.
for suspended in state.suspended.values():
suspended.coroutine.close()
state.suspended = {}


async def run_coroutine(state: State, coroutine: Coroutine, pending_calls: List[Call]):
Expand Down
Empty file.
Empty file.
32 changes: 32 additions & 0 deletions src/dispatch/sdk/python/v1/pickled_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 13 additions & 0 deletions src/dispatch/sdk/python/v1/pickled_pb2.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import ClassVar as _ClassVar
from typing import Optional as _Optional

from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message

DESCRIPTOR: _descriptor.FileDescriptor

class Pickled(_message.Message):
    # Type stub for the Pickled protobuf message, which declares a single
    # bytes field named ``pickled_value``.
    __slots__ = ("pickled_value",)

    # Class-level constant holding the field number of ``pickled_value``.
    PICKLED_VALUE_FIELD_NUMBER: _ClassVar[int]

    # The message's only field.
    pickled_value: bytes

    def __init__(
        self,
        pickled_value: _Optional[bytes] = ...,
    ) -> None: ...
3 changes: 3 additions & 0 deletions src/dispatch/sdk/python/v1/pickled_pb2_grpc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
8 changes: 4 additions & 4 deletions src/dispatch/sdk/v1/call_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 11 additions & 1 deletion src/dispatch/sdk/v1/call_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,34 @@ from dispatch.sdk.v1 import error_pb2 as _error_pb2
DESCRIPTOR: _descriptor.FileDescriptor

class Call(_message.Message):
__slots__ = ("correlation_id", "endpoint", "function", "input", "expiration")
__slots__ = (
"correlation_id",
"endpoint",
"function",
"input",
"expiration",
"version",
)
CORRELATION_ID_FIELD_NUMBER: _ClassVar[int]
ENDPOINT_FIELD_NUMBER: _ClassVar[int]
FUNCTION_FIELD_NUMBER: _ClassVar[int]
INPUT_FIELD_NUMBER: _ClassVar[int]
EXPIRATION_FIELD_NUMBER: _ClassVar[int]
VERSION_FIELD_NUMBER: _ClassVar[int]
correlation_id: int
endpoint: str
function: str
input: _any_pb2.Any
expiration: _duration_pb2.Duration
version: str
def __init__(
self,
correlation_id: _Optional[int] = ...,
endpoint: _Optional[str] = ...,
function: _Optional[str] = ...,
input: _Optional[_Union[_any_pb2.Any, _Mapping]] = ...,
expiration: _Optional[_Union[_duration_pb2.Duration, _Mapping]] = ...,
version: _Optional[str] = ...,
) -> None: ...

class CallResult(_message.Message):
Expand Down
14 changes: 7 additions & 7 deletions src/dispatch/sdk/v1/dispatch_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 4a344ad

Please sign in to comment.