diff --git a/src/dispatch/proto.py b/src/dispatch/proto.py index 90f6ec4..ffe6a10 100644 --- a/src/dispatch/proto.py +++ b/src/dispatch/proto.py @@ -12,7 +12,9 @@ import tblib # type: ignore[import-untyped] from google.protobuf import descriptor_pool, duration_pb2, message_factory +from dispatch.error import IncompatibleStateError, InvalidArgumentError from dispatch.id import DispatchID +from dispatch.sdk.python.v1 import pickled_pb2 as pickled_pb from dispatch.sdk.v1 import call_pb2 as call_pb from dispatch.sdk.v1 import error_pb2 as error_pb from dispatch.sdk.v1 import exit_pb2 as exit_pb @@ -77,18 +79,11 @@ def __init__(self, req: function_pb.RunRequest): self._has_input = req.HasField("input") if self._has_input: - if req.input.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR): - input_pb = google.protobuf.wrappers_pb2.BytesValue() - req.input.Unpack(input_pb) - input_bytes = input_pb.value - try: - self._input = pickle.loads(input_bytes) - except Exception as e: - self._input = input_bytes - else: - self._input = _pb_any_unpack(req.input) + self._input = _pb_any_unpack(req.input) else: - self._coroutine_state = req.poll_result.coroutine_state + if req.poll_result.coroutine_state: + raise IncompatibleStateError # coroutine_state is deprecated + self._coroutine_state = _any_unpickle(req.poll_result.typed_coroutine_state) self._call_results = [ CallResult._from_proto(r) for r in req.poll_result.results ] @@ -155,7 +150,7 @@ def from_input_arguments(cls, function: str, *args, **kwargs): def from_poll_results( cls, function: str, - coroutine_state: Optional[bytes], + coroutine_state: Any, call_results: List[CallResult], error: Optional[Error] = None, ): @@ -163,7 +158,7 @@ def from_poll_results( req=function_pb.RunRequest( function=function, poll_result=poll_pb.PollResult( - coroutine_state=coroutine_state, + typed_coroutine_state=_pb_any_pickle(coroutine_state), results=[result._as_proto() for result in call_results], error=error._as_proto() if error else None, ), @@ -232,7 +227,7 @@ def exit( @classmethod def poll( cls, - coroutine_state: Optional[bytes] = None, + coroutine_state: Any = None, calls: Optional[List[Call]] = None, min_results: int = 1, max_results: int = 10, @@ -247,7 +242,7 @@ def poll( else None ) poll = poll_pb.Poll( - coroutine_state=coroutine_state, + typed_coroutine_state=_pb_any_pickle(coroutine_state), min_results=min_results, max_results=max_results, max_wait=max_wait, @@ -447,21 +442,47 @@ def _as_proto(self) -> error_pb.Error: def _any_unpickle(any: google.protobuf.any_pb2.Any) -> Any: - any.Unpack(value_bytes := google.protobuf.wrappers_pb2.BytesValue()) - return pickle.loads(value_bytes.value) + if any.Is(pickled_pb.Pickled.DESCRIPTOR): + p = pickled_pb.Pickled() + any.Unpack(p) + return pickle.loads(p.pickled_value) + + elif any.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR): # legacy container + b = google.protobuf.wrappers_pb2.BytesValue() + any.Unpack(b) + return pickle.loads(b.value) + + elif not any.type_url and not any.value: + return None + + raise InvalidArgumentError(f"unsupported pickled value container: {any.type_url}") + + +def _pb_any_pickle(value: Any) -> google.protobuf.any_pb2.Any: + p = pickled_pb.Pickled(pickled_value=pickle.dumps(value)) + any = google.protobuf.any_pb2.Any() + any.Pack(p, type_url_prefix="buf.build/stealthrocket/dispatch-proto/") + return any -def _pb_any_pickle(x: Any) -> google.protobuf.any_pb2.Any: - value_bytes = pickle.dumps(x) - pb_bytes = google.protobuf.wrappers_pb2.BytesValue(value=value_bytes) - pb_any = google.protobuf.any_pb2.Any() - pb_any.Pack(pb_bytes) - return pb_any +def _pb_any_unpack(any: google.protobuf.any_pb2.Any) -> Any: + if any.Is(pickled_pb.Pickled.DESCRIPTOR): + p = pickled_pb.Pickled() + any.Unpack(p) + return pickle.loads(p.pickled_value) + elif any.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR): + b = google.protobuf.wrappers_pb2.BytesValue() + any.Unpack(b) + try: + # Assume it's the legacy container for pickled values. + return pickle.loads(b.value) + except Exception as e: + # Otherwise, return the literal bytes. + return b.value -def _pb_any_unpack(x: google.protobuf.any_pb2.Any) -> Any: pool = descriptor_pool.Default() - msg_descriptor = pool.FindMessageTypeByName(x.TypeName()) + msg_descriptor = pool.FindMessageTypeByName(any.TypeName()) proto = message_factory.GetMessageClass(msg_descriptor)() - x.Unpack(proto) + any.Unpack(proto) return proto diff --git a/src/dispatch/scheduler.py b/src/dispatch/scheduler.py index 921f513..75b2e54 100644 --- a/src/dispatch/scheduler.py +++ b/src/dispatch/scheduler.py @@ -357,11 +357,9 @@ def _init_state(self, input: Input) -> State: ) def _rebuild_state(self, input: Input): - logger.debug( - "resuming scheduler with %d bytes of state", len(input.coroutine_state) - ) + logger.info("resuming main coroutine") try: - state = pickle.loads(input.coroutine_state) + state = input.coroutine_state if not isinstance(state, State): raise ValueError("invalid state") if state.version != self.version: @@ -369,7 +367,7 @@ def _rebuild_state(self, input: Input): f"version mismatch: '{state.version}' vs. current '{self.version}'" ) return state - except (pickle.PickleError, ValueError) as e: + except ValueError as e: logger.warning("state is incompatible", exc_info=True) raise IncompatibleStateError from e @@ -454,32 +452,24 @@ async def _run(self, input: Input) -> Output: await asyncio.gather(*asyncio_tasks, return_exceptions=True) return coroutine_result - # Serialize coroutines and scheduler state. - logger.debug("serializing state") + # Yield to Dispatch. + logger.debug("yielding to Dispatch with %d call(s)", len(pending_calls)) try: - serialized_state = pickle.dumps(state) + return Output.poll( + coroutine_state=state, + calls=pending_calls, + min_results=max(1, min(state.outstanding_calls, self.poll_min_results)), + max_results=max(1, min(state.outstanding_calls, self.poll_max_results)), + max_wait_seconds=self.poll_max_wait_seconds, + ) except pickle.PickleError as e: logger.exception("state could not be serialized") return Output.error(Error.from_exception(e, status=Status.PERMANENT_ERROR)) - - # Close coroutines before yielding. - for suspended in state.suspended.values(): - suspended.coroutine.close() - state.suspended = {} - - # Yield to Dispatch. - logger.debug( - "yielding to Dispatch with %d call(s) and %d bytes of state", - len(pending_calls), - len(serialized_state), - ) - return Output.poll( - coroutine_state=serialized_state, - calls=pending_calls, - min_results=max(1, min(state.outstanding_calls, self.poll_min_results)), - max_results=max(1, min(state.outstanding_calls, self.poll_max_results)), - max_wait_seconds=self.poll_max_wait_seconds, - ) + finally: + # Close coroutines. + for suspended in state.suspended.values(): + suspended.coroutine.close() + state.suspended = {} async def run_coroutine(state: State, coroutine: Coroutine, pending_calls: List[Call]): diff --git a/src/dispatch/sdk/python/__init__.py b/src/dispatch/sdk/python/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/dispatch/sdk/python/v1/__init__.py b/src/dispatch/sdk/python/v1/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/dispatch/sdk/python/v1/pickled_pb2.py b/src/dispatch/sdk/python/v1/pickled_pb2.py new file mode 100644 index 0000000..a844920 --- /dev/null +++ b/src/dispatch/sdk/python/v1/pickled_pb2.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: dispatch/sdk/python/v1/pickled.proto +# Protobuf Python Version: 4.25.2 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n$dispatch/sdk/python/v1/pickled.proto\x12\x16\x64ispatch.sdk.python.v1".\n\x07Pickled\x12#\n\rpickled_value\x18\x01 \x01(\x0cR\x0cpickledValueB\xa5\x01\n\x1a\x63om.dispatch.sdk.python.v1B\x0cPickledProtoP\x01\xa2\x02\x03\x44SP\xaa\x02\x16\x44ispatch.Sdk.Python.V1\xca\x02\x16\x44ispatch\\Sdk\\Python\\V1\xe2\x02"Dispatch\\Sdk\\Python\\V1\\GPBMetadata\xea\x02\x19\x44ispatch::Sdk::Python::V1b\x06proto3' +) + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages( + DESCRIPTOR, "dispatch.sdk.python.v1.pickled_pb2", _globals +) +if _descriptor._USE_C_DESCRIPTORS == False: + _globals["DESCRIPTOR"]._options = None + _globals["DESCRIPTOR"]._serialized_options = ( + b'\n\032com.dispatch.sdk.python.v1B\014PickledProtoP\001\242\002\003DSP\252\002\026Dispatch.Sdk.Python.V1\312\002\026Dispatch\\Sdk\\Python\\V1\342\002"Dispatch\\Sdk\\Python\\V1\\GPBMetadata\352\002\031Dispatch::Sdk::Python::V1' + ) + _globals["_PICKLED"]._serialized_start = 64 + _globals["_PICKLED"]._serialized_end = 110 +# @@protoc_insertion_point(module_scope) diff --git a/src/dispatch/sdk/python/v1/pickled_pb2.pyi b/src/dispatch/sdk/python/v1/pickled_pb2.pyi new file mode 100644 index 0000000..d30207f --- /dev/null +++ b/src/dispatch/sdk/python/v1/pickled_pb2.pyi @@ -0,0 +1,13 @@ +from typing import ClassVar as _ClassVar +from typing import Optional as _Optional + +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message + +DESCRIPTOR: _descriptor.FileDescriptor + +class Pickled(_message.Message): + __slots__ = ("pickled_value",) + PICKLED_VALUE_FIELD_NUMBER: _ClassVar[int] + pickled_value: bytes + def __init__(self, pickled_value: _Optional[bytes] = ...) -> None: ... diff --git a/src/dispatch/sdk/python/v1/pickled_pb2_grpc.py b/src/dispatch/sdk/python/v1/pickled_pb2_grpc.py new file mode 100644 index 0000000..8a93939 --- /dev/null +++ b/src/dispatch/sdk/python/v1/pickled_pb2_grpc.py @@ -0,0 +1,3 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc diff --git a/src/dispatch/sdk/v1/call_pb2.py b/src/dispatch/sdk/v1/call_pb2.py index 1a44f4d..1a08f27 100644 --- a/src/dispatch/sdk/v1/call_pb2.py +++ b/src/dispatch/sdk/v1/call_pb2.py @@ -20,7 +20,7 @@ from dispatch.sdk.v1 import error_pb2 as dispatch_dot_sdk_dot_v1_dot_error__pb2 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1a\x64ispatch/sdk/v1/call.proto\x12\x0f\x64ispatch.sdk.v1\x1a\x1b\x62uf/validate/validate.proto\x1a\x1b\x64ispatch/sdk/v1/error.proto\x1a\x19google/protobuf/any.proto\x1a\x1egoogle/protobuf/duration.proto"\x87\x02\n\x04\x43\x61ll\x12%\n\x0e\x63orrelation_id\x18\x01 \x01(\x04R\rcorrelationId\x12$\n\x08\x65ndpoint\x18\x02 \x01(\tB\x08\xbaH\x05r\x03\x88\x01\x01R\x08\x65ndpoint\x12\x41\n\x08\x66unction\x18\x03 \x01(\tB%\xbaH"r\x1d\x32\x1b^[a-zA-Z_][a-zA-Z0-9_<>.]*$\xc8\x01\x01R\x08\x66unction\x12*\n\x05input\x18\x04 \x01(\x0b\x32\x14.google.protobuf.AnyR\x05input\x12\x43\n\nexpiration\x18\x05 \x01(\x0b\x32\x19.google.protobuf.DurationB\x08\xbaH\x05\xaa\x01\x02\x32\x00R\nexpiration"\xb0\x01\n\nCallResult\x12%\n\x0e\x63orrelation_id\x18\x01 \x01(\x04R\rcorrelationId\x12,\n\x06output\x18\x02 \x01(\x0b\x32\x14.google.protobuf.AnyR\x06output\x12,\n\x05\x65rror\x18\x03 \x01(\x0b\x32\x16.dispatch.sdk.v1.ErrorR\x05\x65rror\x12\x1f\n\x0b\x64ispatch_id\x18\x04 \x01(\tR\ndispatchIdB~\n\x13\x63om.dispatch.sdk.v1B\tCallProtoP\x01\xa2\x02\x03\x44SX\xaa\x02\x0f\x44ispatch.Sdk.V1\xca\x02\x0f\x44ispatch\\Sdk\\V1\xe2\x02\x1b\x44ispatch\\Sdk\\V1\\GPBMetadata\xea\x02\x11\x44ispatch::Sdk::V1b\x06proto3' + b'\n\x1a\x64ispatch/sdk/v1/call.proto\x12\x0f\x64ispatch.sdk.v1\x1a\x1b\x62uf/validate/validate.proto\x1a\x1b\x64ispatch/sdk/v1/error.proto\x1a\x19google/protobuf/any.proto\x1a\x1egoogle/protobuf/duration.proto"\xa1\x02\n\x04\x43\x61ll\x12%\n\x0e\x63orrelation_id\x18\x01 \x01(\x04R\rcorrelationId\x12$\n\x08\x65ndpoint\x18\x02 \x01(\tB\x08\xbaH\x05r\x03\x88\x01\x01R\x08\x65ndpoint\x12\x41\n\x08\x66unction\x18\x03 \x01(\tB%\xbaH"r\x1d\x32\x1b^[a-zA-Z_][a-zA-Z0-9_<>.]*$\xc8\x01\x01R\x08\x66unction\x12*\n\x05input\x18\x04 \x01(\x0b\x32\x14.google.protobuf.AnyR\x05input\x12\x43\n\nexpiration\x18\x05 \x01(\x0b\x32\x19.google.protobuf.DurationB\x08\xbaH\x05\xaa\x01\x02\x32\x00R\nexpiration\x12\x18\n\x07version\x18\x06 \x01(\tR\x07version"\xb0\x01\n\nCallResult\x12%\n\x0e\x63orrelation_id\x18\x01 \x01(\x04R\rcorrelationId\x12,\n\x06output\x18\x02 \x01(\x0b\x32\x14.google.protobuf.AnyR\x06output\x12,\n\x05\x65rror\x18\x03 \x01(\x0b\x32\x16.dispatch.sdk.v1.ErrorR\x05\x65rror\x12\x1f\n\x0b\x64ispatch_id\x18\x04 \x01(\tR\ndispatchIdB~\n\x13\x63om.dispatch.sdk.v1B\tCallProtoP\x01\xa2\x02\x03\x44SX\xaa\x02\x0f\x44ispatch.Sdk.V1\xca\x02\x0f\x44ispatch\\Sdk\\V1\xe2\x02\x1b\x44ispatch\\Sdk\\V1\\GPBMetadata\xea\x02\x11\x44ispatch::Sdk::V1b\x06proto3' ) _globals = globals() @@ -46,7 +46,7 @@ "expiration" ]._serialized_options = b"\272H\005\252\001\0022\000" _globals["_CALL"]._serialized_start = 165 - _globals["_CALL"]._serialized_end = 428 - _globals["_CALLRESULT"]._serialized_start = 431 - _globals["_CALLRESULT"]._serialized_end = 607 + _globals["_CALL"]._serialized_end = 454 + _globals["_CALLRESULT"]._serialized_start = 457 + _globals["_CALLRESULT"]._serialized_end = 633 # @@protoc_insertion_point(module_scope) diff --git a/src/dispatch/sdk/v1/call_pb2.pyi b/src/dispatch/sdk/v1/call_pb2.pyi index 8989e1c..8e3c6d3 100644 --- a/src/dispatch/sdk/v1/call_pb2.pyi +++ b/src/dispatch/sdk/v1/call_pb2.pyi @@ -14,17 +14,26 @@ from dispatch.sdk.v1 import error_pb2 as _error_pb2 DESCRIPTOR: _descriptor.FileDescriptor class Call(_message.Message): - __slots__ = ("correlation_id", "endpoint", "function", "input", "expiration") + __slots__ = ( + "correlation_id", + "endpoint", + "function", + "input", + "expiration", + "version", + ) CORRELATION_ID_FIELD_NUMBER: _ClassVar[int] ENDPOINT_FIELD_NUMBER: _ClassVar[int] FUNCTION_FIELD_NUMBER: _ClassVar[int] INPUT_FIELD_NUMBER: _ClassVar[int] EXPIRATION_FIELD_NUMBER: _ClassVar[int] + VERSION_FIELD_NUMBER: _ClassVar[int] correlation_id: int endpoint: str function: str input: _any_pb2.Any expiration: _duration_pb2.Duration + version: str def __init__( self, correlation_id: _Optional[int] = ..., @@ -32,6 +41,7 @@ class Call(_message.Message): function: _Optional[str] = ..., input: _Optional[_Union[_any_pb2.Any, _Mapping]] = ..., expiration: _Optional[_Union[_duration_pb2.Duration, _Mapping]] = ..., + version: _Optional[str] = ..., ) -> None: ... class CallResult(_message.Message): diff --git a/src/dispatch/sdk/v1/dispatch_pb2.py b/src/dispatch/sdk/v1/dispatch_pb2.py index 793b592..17adc55 100644 --- a/src/dispatch/sdk/v1/dispatch_pb2.py +++ b/src/dispatch/sdk/v1/dispatch_pb2.py @@ -17,7 +17,7 @@ from dispatch.sdk.v1 import call_pb2 as dispatch_dot_sdk_dot_v1_dot_call__pb2 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b"\n\x1e\x64ispatch/sdk/v1/dispatch.proto\x12\x0f\x64ispatch.sdk.v1\x1a\x1b\x62uf/validate/validate.proto\x1a\x1a\x64ispatch/sdk/v1/call.proto\"\x9d\x03\n\x0f\x44ispatchRequest\x12+\n\x05\x63\x61lls\x18\x01 \x03(\x0b\x32\x15.dispatch.sdk.v1.CallR\x05\x63\x61lls:\xdc\x02\xbaH\xd8\x02\x1as\n(dispatch.request.calls.endpoint.nonempty\x12\x1d\x43\x61ll endpoint cannot be empty\x1a(this.calls.all(call, has(call.endpoint))\x1a\xe0\x01\n&dispatch.request.calls.endpoint.scheme\x12)Call endpoint must be a http or https URL\x1a\x8a\x01this.calls.all(call, call.endpoint.startsWith('http://') || call.endpoint.startsWith('https://') || call.endpoint.startsWith('bridge://'))\"5\n\x10\x44ispatchResponse\x12!\n\x0c\x64ispatch_ids\x18\x01 \x03(\tR\x0b\x64ispatchIds2d\n\x0f\x44ispatchService\x12Q\n\x08\x44ispatch\x12 .dispatch.sdk.v1.DispatchRequest\x1a!.dispatch.sdk.v1.DispatchResponse\"\x00\x42\x82\x01\n\x13\x63om.dispatch.sdk.v1B\rDispatchProtoP\x01\xa2\x02\x03\x44SX\xaa\x02\x0f\x44ispatch.Sdk.V1\xca\x02\x0f\x44ispatch\\Sdk\\V1\xe2\x02\x1b\x44ispatch\\Sdk\\V1\\GPBMetadata\xea\x02\x11\x44ispatch::Sdk::V1b\x06proto3" + b"\n\x1e\x64ispatch/sdk/v1/dispatch.proto\x12\x0f\x64ispatch.sdk.v1\x1a\x1b\x62uf/validate/validate.proto\x1a\x1a\x64ispatch/sdk/v1/call.proto\"\xea\x03\n\x0f\x44ispatchRequest\x12+\n\x05\x63\x61lls\x18\x01 \x03(\x0b\x32\x15.dispatch.sdk.v1.CallR\x05\x63\x61lls:\xa9\x03\xbaH\xa5\x03\x1as\n(dispatch.request.calls.endpoint.nonempty\x12\x1d\x43\x61ll endpoint cannot be empty\x1a(this.calls.all(call, has(call.endpoint))\x1a\xad\x02\n&dispatch.request.calls.endpoint.scheme\x12HCall endpoint must be a http, https or a bridge URL or an AWS Lambda ARN\x1a\xb8\x01this.calls.all(call, call.endpoint.startsWith('http://') || call.endpoint.startsWith('https://') || call.endpoint.startsWith('bridge://') || call.endpoint.startsWith('arn:aws:lambda'))\"5\n\x10\x44ispatchResponse\x12!\n\x0c\x64ispatch_ids\x18\x01 \x03(\tR\x0b\x64ispatchIds2d\n\x0f\x44ispatchService\x12Q\n\x08\x44ispatch\x12 .dispatch.sdk.v1.DispatchRequest\x1a!.dispatch.sdk.v1.DispatchResponse\"\x00\x42\x82\x01\n\x13\x63om.dispatch.sdk.v1B\rDispatchProtoP\x01\xa2\x02\x03\x44SX\xaa\x02\x0f\x44ispatch.Sdk.V1\xca\x02\x0f\x44ispatch\\Sdk\\V1\xe2\x02\x1b\x44ispatch\\Sdk\\V1\\GPBMetadata\xea\x02\x11\x44ispatch::Sdk::V1b\x06proto3" ) _globals = globals() @@ -32,12 +32,12 @@ ) _globals["_DISPATCHREQUEST"]._options = None _globals["_DISPATCHREQUEST"]._serialized_options = ( - b"\272H\330\002\032s\n(dispatch.request.calls.endpoint.nonempty\022\035Call endpoint cannot be empty\032(this.calls.all(call, has(call.endpoint))\032\340\001\n&dispatch.request.calls.endpoint.scheme\022)Call endpoint must be a http or https URL\032\212\001this.calls.all(call, call.endpoint.startsWith('http://') || call.endpoint.startsWith('https://') || call.endpoint.startsWith('bridge://'))" + b"\272H\245\003\032s\n(dispatch.request.calls.endpoint.nonempty\022\035Call endpoint cannot be empty\032(this.calls.all(call, has(call.endpoint))\032\255\002\n&dispatch.request.calls.endpoint.scheme\022HCall endpoint must be a http, https or a bridge URL or an AWS Lambda ARN\032\270\001this.calls.all(call, call.endpoint.startsWith('http://') || call.endpoint.startsWith('https://') || call.endpoint.startsWith('bridge://') || call.endpoint.startsWith('arn:aws:lambda'))" ) _globals["_DISPATCHREQUEST"]._serialized_start = 109 - _globals["_DISPATCHREQUEST"]._serialized_end = 522 - _globals["_DISPATCHRESPONSE"]._serialized_start = 524 - _globals["_DISPATCHRESPONSE"]._serialized_end = 577 - _globals["_DISPATCHSERVICE"]._serialized_start = 579 - _globals["_DISPATCHSERVICE"]._serialized_end = 679 + _globals["_DISPATCHREQUEST"]._serialized_end = 599 + _globals["_DISPATCHRESPONSE"]._serialized_start = 601 + _globals["_DISPATCHRESPONSE"]._serialized_end = 654 + _globals["_DISPATCHSERVICE"]._serialized_start = 656 + _globals["_DISPATCHSERVICE"]._serialized_end = 756 # @@protoc_insertion_point(module_scope) diff --git a/src/dispatch/sdk/v1/poll_pb2.py b/src/dispatch/sdk/v1/poll_pb2.py index a404c05..9cfc253 100644 --- a/src/dispatch/sdk/v1/poll_pb2.py +++ b/src/dispatch/sdk/v1/poll_pb2.py @@ -13,6 +13,7 @@ _sym_db = _symbol_database.Default() +from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 from google.protobuf import duration_pb2 as google_dot_protobuf_dot_duration__pb2 from buf.validate import validate_pb2 as buf_dot_validate_dot_validate__pb2 @@ -20,7 +21,7 @@ from dispatch.sdk.v1 import error_pb2 as dispatch_dot_sdk_dot_v1_dot_error__pb2 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b"\n\x1a\x64ispatch/sdk/v1/poll.proto\x12\x0f\x64ispatch.sdk.v1\x1a\x1b\x62uf/validate/validate.proto\x1a\x1a\x64ispatch/sdk/v1/call.proto\x1a\x1b\x64ispatch/sdk/v1/error.proto\x1a\x1egoogle/protobuf/duration.proto\"\x81\x02\n\x04Poll\x12'\n\x0f\x63oroutine_state\x18\x01 \x01(\x0cR\x0e\x63oroutineState\x12+\n\x05\x63\x61lls\x18\x02 \x03(\x0b\x32\x15.dispatch.sdk.v1.CallR\x05\x63\x61lls\x12\x43\n\x08max_wait\x18\x03 \x01(\x0b\x32\x19.google.protobuf.DurationB\r\xbaH\n\xaa\x01\x04\x32\x02\x08\x01\xc8\x01\x01R\x07maxWait\x12.\n\x0bmax_results\x18\x04 \x01(\x05\x42\r\xbaH\n\x1a\x05\x18\xe8\x07(\x01\xc8\x01\x01R\nmaxResults\x12.\n\x0bmin_results\x18\x05 \x01(\x05\x42\r\xbaH\n\x1a\x05\x18\xe8\x07(\x01\xc8\x01\x01R\nminResults\"\x9a\x01\n\nPollResult\x12'\n\x0f\x63oroutine_state\x18\x01 \x01(\x0cR\x0e\x63oroutineState\x12\x35\n\x07results\x18\x02 \x03(\x0b\x32\x1b.dispatch.sdk.v1.CallResultR\x07results\x12,\n\x05\x65rror\x18\x03 \x01(\x0b\x32\x16.dispatch.sdk.v1.ErrorR\x05\x65rrorB~\n\x13\x63om.dispatch.sdk.v1B\tPollProtoP\x01\xa2\x02\x03\x44SX\xaa\x02\x0f\x44ispatch.Sdk.V1\xca\x02\x0f\x44ispatch\\Sdk\\V1\xe2\x02\x1b\x44ispatch\\Sdk\\V1\\GPBMetadata\xea\x02\x11\x44ispatch::Sdk::V1b\x06proto3" + b'\n\x1a\x64ispatch/sdk/v1/poll.proto\x12\x0f\x64ispatch.sdk.v1\x1a\x1b\x62uf/validate/validate.proto\x1a\x1a\x64ispatch/sdk/v1/call.proto\x1a\x1b\x64ispatch/sdk/v1/error.proto\x1a\x19google/protobuf/any.proto\x1a\x1egoogle/protobuf/duration.proto"\xd8\x02\n\x04Poll\x12)\n\x0f\x63oroutine_state\x18\x01 \x01(\x0cH\x00R\x0e\x63oroutineState\x12J\n\x15typed_coroutine_state\x18\x06 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\x13typedCoroutineState\x12+\n\x05\x63\x61lls\x18\x02 \x03(\x0b\x32\x15.dispatch.sdk.v1.CallR\x05\x63\x61lls\x12\x43\n\x08max_wait\x18\x03 \x01(\x0b\x32\x19.google.protobuf.DurationB\r\xbaH\n\xaa\x01\x04\x32\x02\x08\x01\xc8\x01\x01R\x07maxWait\x12.\n\x0bmax_results\x18\x04 \x01(\x05\x42\r\xbaH\n\x1a\x05\x18\xe8\x07(\x01\xc8\x01\x01R\nmaxResults\x12.\n\x0bmin_results\x18\x05 \x01(\x05\x42\r\xbaH\n\x1a\x05\x18\xe8\x07(\x01\xc8\x01\x01R\nminResultsB\x07\n\x05state"\xf1\x01\n\nPollResult\x12)\n\x0f\x63oroutine_state\x18\x01 \x01(\x0cH\x00R\x0e\x63oroutineState\x12J\n\x15typed_coroutine_state\x18\x04 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\x13typedCoroutineState\x12\x35\n\x07results\x18\x02 \x03(\x0b\x32\x1b.dispatch.sdk.v1.CallResultR\x07results\x12,\n\x05\x65rror\x18\x03 \x01(\x0b\x32\x16.dispatch.sdk.v1.ErrorR\x05\x65rrorB\x07\n\x05stateB~\n\x13\x63om.dispatch.sdk.v1B\tPollProtoP\x01\xa2\x02\x03\x44SX\xaa\x02\x0f\x44ispatch.Sdk.V1\xca\x02\x0f\x44ispatch\\Sdk\\V1\xe2\x02\x1b\x44ispatch\\Sdk\\V1\\GPBMetadata\xea\x02\x11\x44ispatch::Sdk::V1b\x06proto3' ) _globals = globals() @@ -45,8 +46,8 @@ _globals["_POLL"].fields_by_name[ "min_results" ]._serialized_options = b"\272H\n\032\005\030\350\007(\001\310\001\001" - _globals["_POLL"]._serialized_start = 166 - _globals["_POLL"]._serialized_end = 423 - _globals["_POLLRESULT"]._serialized_start = 426 - _globals["_POLLRESULT"]._serialized_end = 580 + _globals["_POLL"]._serialized_start = 193 + _globals["_POLL"]._serialized_end = 537 + _globals["_POLLRESULT"]._serialized_start = 540 + _globals["_POLLRESULT"]._serialized_end = 781 # @@protoc_insertion_point(module_scope) diff --git a/src/dispatch/sdk/v1/poll_pb2.pyi b/src/dispatch/sdk/v1/poll_pb2.pyi index cb46596..b132517 100644 --- a/src/dispatch/sdk/v1/poll_pb2.pyi +++ b/src/dispatch/sdk/v1/poll_pb2.pyi @@ -4,6 +4,7 @@ from typing import Mapping as _Mapping from typing import Optional as _Optional from typing import Union as _Union +from google.protobuf import any_pb2 as _any_pb2 from google.protobuf import descriptor as _descriptor from google.protobuf import duration_pb2 as _duration_pb2 from google.protobuf import message as _message @@ -16,13 +17,22 @@ from dispatch.sdk.v1 import error_pb2 as _error_pb2 DESCRIPTOR: _descriptor.FileDescriptor class Poll(_message.Message): - __slots__ = ("coroutine_state", "calls", "max_wait", "max_results", "min_results") + __slots__ = ( + "coroutine_state", + "typed_coroutine_state", + "calls", + "max_wait", + "max_results", + "min_results", + ) COROUTINE_STATE_FIELD_NUMBER: _ClassVar[int] + TYPED_COROUTINE_STATE_FIELD_NUMBER: _ClassVar[int] CALLS_FIELD_NUMBER: _ClassVar[int] MAX_WAIT_FIELD_NUMBER: _ClassVar[int] MAX_RESULTS_FIELD_NUMBER: _ClassVar[int] MIN_RESULTS_FIELD_NUMBER: _ClassVar[int] coroutine_state: bytes + typed_coroutine_state: _any_pb2.Any calls: _containers.RepeatedCompositeFieldContainer[_call_pb2.Call] max_wait: _duration_pb2.Duration max_results: int @@ -30,6 +40,7 @@ class Poll(_message.Message): def __init__( self, coroutine_state: _Optional[bytes] = ..., + typed_coroutine_state: _Optional[_Union[_any_pb2.Any, _Mapping]] = ..., calls: _Optional[_Iterable[_Union[_call_pb2.Call, _Mapping]]] = ..., max_wait: _Optional[_Union[_duration_pb2.Duration, _Mapping]] = ..., max_results: _Optional[int] = ..., @@ -37,16 +48,19 @@ class Poll(_message.Message): ) -> None: ... class PollResult(_message.Message): - __slots__ = ("coroutine_state", "results", "error") + __slots__ = ("coroutine_state", "typed_coroutine_state", "results", "error") COROUTINE_STATE_FIELD_NUMBER: _ClassVar[int] + TYPED_COROUTINE_STATE_FIELD_NUMBER: _ClassVar[int] RESULTS_FIELD_NUMBER: _ClassVar[int] ERROR_FIELD_NUMBER: _ClassVar[int] coroutine_state: bytes + typed_coroutine_state: _any_pb2.Any results: _containers.RepeatedCompositeFieldContainer[_call_pb2.CallResult] error: _error_pb2.Error def __init__( self, coroutine_state: _Optional[bytes] = ..., + typed_coroutine_state: _Optional[_Union[_any_pb2.Any, _Mapping]] = ..., results: _Optional[_Iterable[_Union[_call_pb2.CallResult, _Mapping]]] = ..., error: _Optional[_Union[_error_pb2.Error, _Mapping]] = ..., ) -> None: ... diff --git a/src/dispatch/test/service.py b/src/dispatch/test/service.py index 195c4d1..ac23738 100644 --- a/src/dispatch/test/service.py +++ b/src/dispatch/test/service.py @@ -8,6 +8,7 @@ from typing import Dict, List, Optional, Set, Tuple import grpc +from google.protobuf import any_pb2 as any_pb from typing_extensions import TypeAlias import dispatch.sdk.v1.call_pb2 as call_pb @@ -210,7 +211,7 @@ def dispatch_calls(self): parent_id=request.parent_dispatch_id, root_id=request.root_dispatch_id, function=request.function, - coroutine_state=response.poll.coroutine_state, + typed_coroutine_state=response.poll.typed_coroutine_state, waiting={}, results={}, ) @@ -282,7 +283,7 @@ def dispatch_calls(self): root_dispatch_id=poller.root_id, function=poller.function, poll_result=poll_pb.PollResult( - coroutine_state=poller.coroutine_state, + typed_coroutine_state=poller.typed_coroutine_state, results=poller.results.values(), ), ) @@ -354,7 +355,7 @@ class Poller: function: str - coroutine_state: bytes + typed_coroutine_state: any_pb.Any # TODO: support max_wait/min_results/max_results waiting: Dict[DispatchID, call_pb.Call] diff --git a/tests/dispatch/test_scheduler.py b/tests/dispatch/test_scheduler.py index a88c4b7..554b633 100644 --- a/tests/dispatch/test_scheduler.py +++ b/tests/dispatch/test_scheduler.py @@ -449,7 +449,7 @@ def resume( poll = self.assert_poll(prev_output) input = Input.from_poll_results( main.__qualname__, - poll.coroutine_state, + any_unpickle(poll.typed_coroutine_state), call_results, Error.from_exception(poll_error) if poll_error else None, ) diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index c0cce39..1f53543 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -20,6 +20,7 @@ from dispatch.experimental.durable.registry import clear_functions from dispatch.function import Arguments, Error, Function, Input, Output, Registry from dispatch.proto import _any_unpickle as any_unpickle +from dispatch.proto import _pb_any_pickle as any_pickle from dispatch.sdk.v1 import call_pb2 as call_pb from dispatch.sdk.v1 import function_pb2 as function_pb from dispatch.signature import parse_verification_key, public_key_from_pem @@ -109,22 +110,15 @@ async def my_function(input: Input) -> Output: http_client = dispatch.test.httpx.Client(httpx.Client(base_url=self.endpoint)) client = EndpointClient(http_client) - pickled = pickle.dumps("Hello World!") - input_any = google.protobuf.any_pb2.Any() - input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=pickled)) - req = function_pb.RunRequest( function=my_function.name, - input=input_any, + input=any_pickle("Hello World!"), ) resp = client.run(req) self.assertIsInstance(resp, function_pb.RunResponse) - resp.exit.result.output.Unpack( - output_bytes := google.protobuf.wrappers_pb2.BytesValue() - ) - output = pickle.loads(output_bytes.value) + output = any_unpickle(resp.exit.result.output) self.assertEqual(output, "You told me: 'Hello World!' (12 characters)") diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index 149b817..e7f4e52 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -21,6 +21,7 @@ from dispatch.fastapi import Dispatch from dispatch.function import Arguments, Error, Function, Input, Output from dispatch.proto import _any_unpickle as any_unpickle +from dispatch.proto import _pb_any_pickle as any_pickle from dispatch.sdk.v1 import call_pb2 as call_pb from dispatch.sdk.v1 import function_pb2 as function_pb from dispatch.signature import ( @@ -91,23 +92,17 @@ async def my_function(input: Input) -> Output: ) client = create_endpoint_client(app) - pickled = pickle.dumps("Hello World!") - input_any = google.protobuf.any_pb2.Any() - input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=pickled)) req = function_pb.RunRequest( function=my_function.name, - input=input_any, + input=any_pickle("Hello World!"), ) resp = client.run(req) self.assertIsInstance(resp, function_pb.RunResponse) - resp.exit.result.output.Unpack( - output_bytes := google.protobuf.wrappers_pb2.BytesValue() - ) - output = pickle.loads(output_bytes.value) + output = any_unpickle(resp.exit.result.output) self.assertEqual(output, "You told me: 'Hello World!' (12 characters)") @@ -229,12 +224,11 @@ def execute( req = function_pb.RunRequest(function=func.name) if input is not None: - input_bytes = pickle.dumps(input) - input_any = google.protobuf.any_pb2.Any() - input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=input_bytes)) - req.input.CopyFrom(input_any) + any = any_pickle(input) + req.input.CopyFrom(any) if state is not None: - req.poll_result.coroutine_state = state + any = any_pickle(state) + req.poll_result.typed_coroutine_state.CopyFrom(any) if calls is not None: for c in calls: req.poll_result.results.append(c) @@ -254,10 +248,6 @@ def proto_call(self, call: call_pb.Call) -> call_pb.CallResult: resp = self.client.run(req) self.assertIsInstance(resp, function_pb.RunResponse) - # Assert the response is terminal. Good enough until the test client can - # orchestrate coroutines. - self.assertTrue(len(resp.poll.coroutine_state) == 0) - resp.exit.result.correlation_id = call.correlation_id return resp.exit.result @@ -324,9 +314,10 @@ async def my_function(input: Input) -> Output: return Output.value("not reached") resp = self.execute(my_function, input="cool stuff") - self.assertEqual(b"42", resp.poll.coroutine_state) + state = any_unpickle(resp.poll.typed_coroutine_state) + self.assertEqual(b"42", state) - resp = self.execute(my_function, state=resp.poll.coroutine_state) + resp = self.execute(my_function, state=state) self.assertEqual("ValueError", resp.exit.result.error.type) self.assertEqual( "This input is for a resumed coroutine", resp.exit.result.error.message @@ -367,32 +358,29 @@ async def coroutine3(input: Input) -> Output: if input.is_first_call: counter = input.input else: - (counter,) = struct.unpack("@i", input.coroutine_state) + counter = input.coroutine_state counter -= 1 if counter <= 0: return Output.value("done") - coroutine_state = struct.pack("@i", counter) - return Output.poll(coroutine_state=coroutine_state) + return Output.poll(coroutine_state=counter) # first call resp = self.execute(coroutine3, input=4) - state = resp.poll.coroutine_state - self.assertTrue(len(state) > 0) + state = any_unpickle(resp.poll.typed_coroutine_state) + self.assertEqual(state, 3) # resume, state = 3 resp = self.execute(coroutine3, state=state) - state = resp.poll.coroutine_state - self.assertTrue(len(state) > 0) + state = any_unpickle(resp.poll.typed_coroutine_state) + self.assertEqual(state, 2) # resume, state = 2 resp = self.execute(coroutine3, state=state) - state = resp.poll.coroutine_state - self.assertTrue(len(state) > 0) + state = any_unpickle(resp.poll.typed_coroutine_state) + self.assertEqual(state, 1) # resume, state = 1 resp = self.execute(coroutine3, state=state) - state = resp.poll.coroutine_state - self.assertTrue(len(state) == 0) out = response_output(resp) self.assertEqual(out, "done") @@ -406,18 +394,18 @@ async def coroutine_main(input: Input) -> Output: if input.is_first_call: text: str = input.input return Output.poll( - coroutine_state=text.encode(), + coroutine_state=text, calls=[coro_compute_len._build_primitive_call(text)], ) - text = input.coroutine_state.decode() + text = input.coroutine_state length = input.call_results[0].output return Output.value(f"length={length} text='{text}'") resp = self.execute(coroutine_main, input="cool stuff") # main saved some state - state = resp.poll.coroutine_state - self.assertTrue(len(state) > 0) + state = any_unpickle(resp.poll.typed_coroutine_state) + self.assertEqual(state, "cool stuff") # main asks for 1 call to compute_len self.assertEqual(len(resp.poll.calls), 1) call = resp.poll.calls[0] @@ -433,7 +421,6 @@ async def coroutine_main(input: Input) -> Output: # resume main with the result resp = self.execute(coroutine_main, state=state, calls=[resp2]) # validate the final result - self.assertTrue(len(resp.poll.coroutine_state) == 0) out = response_output(resp) self.assertEqual("length=10 text='cool stuff'", out) @@ -447,7 +434,7 @@ async def coroutine_main(input: Input) -> Output: if input.is_first_call: text: str = input.input return Output.poll( - coroutine_state=text.encode(), + coroutine_state=text, calls=[coro_compute_len._build_primitive_call(text)], ) error = input.call_results[0].error @@ -459,8 +446,8 @@ async def coroutine_main(input: Input) -> Output: resp = self.execute(coroutine_main, input="cool stuff") # main saved some state - state = resp.poll.coroutine_state - self.assertTrue(len(state) > 0) + state = any_unpickle(resp.poll.typed_coroutine_state) + self.assertEqual(state, "cool stuff") # main asks for 1 call to compute_len self.assertEqual(len(resp.poll.calls), 1) call = resp.poll.calls[0] @@ -473,7 +460,6 @@ async def coroutine_main(input: Input) -> Output: # resume main with the result resp = self.execute(coroutine_main, state=state, calls=[resp2]) # validate the final result - self.assertTrue(len(resp.poll.coroutine_state) == 0) out = response_output(resp) self.assertEqual(out, "msg=Dead type='type'") diff --git a/tests/test_flask.py b/tests/test_flask.py index f681a17..ae6d431 100644 --- a/tests/test_flask.py +++ b/tests/test_flask.py @@ -19,6 +19,7 @@ from dispatch.flask import Dispatch from dispatch.function import Arguments, Error, Function, Input, Output from dispatch.proto import _any_unpickle as any_unpickle +from dispatch.proto import _pb_any_pickle as any_pickle from dispatch.sdk.v1 import call_pb2 as call_pb from dispatch.sdk.v1 import function_pb2 as function_pb from dispatch.signature import ( @@ -56,23 +57,16 @@ async def my_function(input: Input) -> Output: ) client = create_endpoint_client(app) - pickled = pickle.dumps("Hello World!") - input_any = google.protobuf.any_pb2.Any() - input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=pickled)) req = function_pb.RunRequest( - function=my_function.name, - input=input_any, + function=my_function.name, input=any_pickle("Hello World!") ) resp = client.run(req) self.assertIsInstance(resp, function_pb.RunResponse) - resp.exit.result.output.Unpack( - output_bytes := google.protobuf.wrappers_pb2.BytesValue() - ) - output = pickle.loads(output_bytes.value) + output = any_unpickle(resp.exit.result.output) self.assertEqual(output, "You told me: 'Hello World!' (12 characters)") diff --git a/tests/test_http.py b/tests/test_http.py index 9ac4e1a..56aa8e3 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -19,6 +19,7 @@ from dispatch.function import Arguments, Error, Function, Input, Output, Registry from dispatch.http import Dispatch from dispatch.proto import _any_unpickle as any_unpickle +from dispatch.proto import _pb_any_pickle as any_pickle from dispatch.sdk.v1 import call_pb2 as call_pb from dispatch.sdk.v1 import function_pb2 as function_pb from dispatch.signature import parse_verification_key, public_key_from_pem @@ -91,22 +92,14 @@ async def my_function(input: Input) -> Output: http_client = dispatch.test.httpx.Client(httpx.Client(base_url=self.endpoint)) client = EndpointClient(http_client) - pickled = pickle.dumps("Hello World!") - input_any = google.protobuf.any_pb2.Any() - input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=pickled)) - req = function_pb.RunRequest( - function=my_function.name, - input=input_any, + function=my_function.name, input=any_pickle("Hello World!") ) resp = client.run(req) self.assertIsInstance(resp, function_pb.RunResponse) - resp.exit.result.output.Unpack( - output_bytes := google.protobuf.wrappers_pb2.BytesValue() - ) - output = pickle.loads(output_bytes.value) + output = any_unpickle(resp.exit.result.output) self.assertEqual(output, "You told me: 'Hello World!' (12 characters)")