diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index d04e0f54..e4f69b7f 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -599,8 +599,6 @@ def _apply_query_workflow( ) -> None: # Wrap entire bunch of work in a task async def run_query() -> None: - command = self._add_command() - command.respond_to_query.query_id = job.query_id try: with self._as_read_only(): # Named query or dynamic @@ -632,11 +630,13 @@ async def run_query() -> None: raise ValueError( f"Expected 1 result payload, got {len(result_payloads)}" ) - command.respond_to_query.succeeded.response.CopyFrom( - result_payloads[0] - ) + command = self._add_command() + command.respond_to_query.query_id = job.query_id + command.respond_to_query.succeeded.response.CopyFrom(result_payloads[0]) except Exception as err: try: + command = self._add_command() + command.respond_to_query.query_id = job.query_id self._failure_converter.to_failure( err, self._payload_converter, @@ -1427,7 +1427,7 @@ async def run_activity() -> Any: await asyncio.sleep( err.backoff.backoff_duration.ToTimedelta().total_seconds() ) - handle._apply_schedule_command(self._add_command(), err.backoff) + handle._apply_schedule_command(err.backoff) # We have to put the handle back on the pending activity # dict with its new seq self._pending_activities[handle._seq] = handle @@ -1437,19 +1437,22 @@ async def run_activity() -> Any: # Create the handle and set as pending handle = _ActivityHandle(self, input, run_activity()) - handle._apply_schedule_command(self._add_command()) + handle._apply_schedule_command() self._pending_activities[handle._seq] = handle return handle async def _outbound_signal_child_workflow( self, input: SignalChildWorkflowInput ) -> None: + payloads = ( + self._payload_converter.to_payloads(input.args) if input.args else None + ) command = self._add_command() v = command.signal_external_workflow_execution v.child_workflow_id = input.child_workflow_id v.signal_name = input.signal - if input.args: - v.args.extend(self._payload_converter.to_payloads(input.args)) + if payloads: + v.args.extend(payloads) if input.headers: temporalio.common._apply_headers(input.headers, v.headers) await self._signal_external_workflow(command) @@ -1457,6 +1460,9 @@ async def _outbound_signal_child_workflow( async def _outbound_signal_external_workflow( self, input: SignalExternalWorkflowInput ) -> None: + payloads = ( + self._payload_converter.to_payloads(input.args) if input.args else None + ) command = self._add_command() v = command.signal_external_workflow_execution v.workflow_execution.namespace = input.namespace @@ -1464,8 +1470,8 @@ async def _outbound_signal_external_workflow( if input.workflow_run_id: v.workflow_execution.run_id = input.workflow_run_id v.signal_name = input.signal - if input.args: - v.args.extend(self._payload_converter.to_payloads(input.args)) + if payloads: + v.args.extend(payloads) if input.headers: temporalio.common._apply_headers(input.headers, v.headers) await self._signal_external_workflow(command) @@ -1510,7 +1516,7 @@ async def run_child() -> Any: handle = _ChildWorkflowHandle( self, self._next_seq("child_workflow"), input, run_child() ) - handle._apply_start_command(self._add_command()) + handle._apply_start_command() self._pending_child_workflows[handle._seq] = handle # Wait on start before returning @@ -1761,7 +1767,7 @@ async def _run_top_level_workflow_function(self, coro: Awaitable[None]) -> None: await coro except _ContinueAsNewError as err: logger.debug("Workflow requested continue as new") - err._apply_command(self._add_command()) + err._apply_command() except (Exception, asyncio.CancelledError) as err: # During tear down we can ignore exceptions. Technically the # command-adding done later would throw a not-in-workflow exception @@ -1776,7 +1782,7 @@ async def _run_top_level_workflow_function(self, coro: Awaitable[None]) -> None: # Handle continue as new if isinstance(err, _ContinueAsNewError): logger.debug("Workflow requested continue as new") - err._apply_command(self._add_command()) + err._apply_command() return logger.debug( @@ -2261,11 +2267,18 @@ def _resolve_backoff( def _apply_schedule_command( self, - command: temporalio.bridge.proto.workflow_commands.WorkflowCommand, local_backoff: Optional[ temporalio.bridge.proto.activity_result.DoBackoff ] = None, ) -> None: + # Convert arguments before creating command in case it raises error + payloads = ( + self._instance._payload_converter.to_payloads(self._input.args) + if self._input.args + else None + ) + + command = self._instance._add_command() # TODO(cretz): Why can't MyPy infer this? v: Union[ temporalio.bridge.proto.workflow_commands.ScheduleActivity, @@ -2280,10 +2293,8 @@ def _apply_schedule_command( v.activity_type = self._input.activity if self._input.headers: temporalio.common._apply_headers(self._input.headers, v.headers) - if self._input.args: - v.arguments.extend( - self._instance._payload_converter.to_payloads(self._input.args) - ) + if payloads: + v.arguments.extend(payloads) if self._input.schedule_to_close_timeout: v.schedule_to_close_timeout.FromTimedelta( self._input.schedule_to_close_timeout @@ -2403,20 +2414,23 @@ def _resolve_failure(self, err: BaseException) -> None: # future self._result_fut.set_result(None) - def _apply_start_command( - self, - command: temporalio.bridge.proto.workflow_commands.WorkflowCommand, - ) -> None: + def _apply_start_command(self) -> None: + # Convert arguments before creating command in case it raises error + payloads = ( + self._instance._payload_converter.to_payloads(self._input.args) + if self._input.args + else None + ) + + command = self._instance._add_command() v = command.start_child_workflow_execution v.seq = self._seq v.namespace = self._instance._info.namespace v.workflow_id = self._input.id v.workflow_type = self._input.workflow v.task_queue = self._input.task_queue or self._instance._info.task_queue - if self._input.args: - v.input.extend( - self._instance._payload_converter.to_payloads(self._input.args) - ) + if payloads: + v.input.extend(payloads) if self._input.execution_timeout: v.workflow_execution_timeout.FromTimedelta(self._input.execution_timeout) if self._input.run_timeout: @@ -2520,19 +2534,31 @@ def __init__( self._instance = instance self._input = input - def _apply_command( - self, command: temporalio.bridge.proto.workflow_commands.WorkflowCommand - ) -> None: + def _apply_command(self) -> None: + # Convert arguments before creating command in case it raises error + payloads = ( + self._instance._payload_converter.to_payloads(self._input.args) + if self._input.args + else None + ) + memo_payloads = ( + { + k: self._instance._payload_converter.to_payloads([val])[0] + for k, val in self._input.memo.items() + } + if self._input.memo + else None + ) + + command = self._instance._add_command() v = command.continue_as_new_workflow_execution v.SetInParent() if self._input.workflow: v.workflow_type = self._input.workflow if self._input.task_queue: v.task_queue = self._input.task_queue - if self._input.args: - v.arguments.extend( - self._instance._payload_converter.to_payloads(self._input.args) - ) + if payloads: + v.arguments.extend(payloads) if self._input.run_timeout: v.workflow_run_timeout.FromTimedelta(self._input.run_timeout) if self._input.task_timeout: @@ -2541,11 +2567,9 @@ def _apply_command( temporalio.common._apply_headers(self._input.headers, v.headers) if self._input.retry_policy: self._input.retry_policy.apply_to_proto(v.retry_policy) - if self._input.memo: - for k, val in self._input.memo.items(): - v.memo[k].CopyFrom( - self._instance._payload_converter.to_payloads([val])[0] - ) + if memo_payloads: + for k, val in memo_payloads.items(): + v.memo[k].CopyFrom(val) if self._input.search_attributes: _encode_search_attributes( self._input.search_attributes, v.search_attributes diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 23399137..353579e6 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -3350,15 +3350,27 @@ async def test_workflow_optional_param(client: Client): class ExceptionRaisingPayloadConverter(DefaultPayloadConverter): - bad_str = "bad-payload-str" + bad_outbound_str = "bad-outbound-payload-str" + bad_inbound_str = "bad-inbound-payload-str" + + def to_payloads(self, values: Sequence[Any]) -> List[Payload]: + if any( + value == ExceptionRaisingPayloadConverter.bad_outbound_str + for value in values + ): + raise ApplicationError("Intentional outbound converter failure") + return super().to_payloads(values) def from_payloads( self, payloads: Sequence[Payload], type_hints: Optional[List] = None ) -> List[Any]: # Check if any payloads contain the bad data for payload in payloads: - if ExceptionRaisingPayloadConverter.bad_str.encode() in payload.data: - raise ApplicationError("Intentional converter failure") + if ( + ExceptionRaisingPayloadConverter.bad_inbound_str.encode() + in payload.data + ): + raise ApplicationError("Intentional inbound converter failure") return super().from_payloads(payloads, type_hints) @@ -3383,12 +3395,46 @@ async def test_exception_raising_converter_param(client: Client): with pytest.raises(WorkflowFailureError) as err: await client.execute_workflow( ExceptionRaisingConverterWorkflow.run, - ExceptionRaisingPayloadConverter.bad_str, + ExceptionRaisingPayloadConverter.bad_inbound_str, id=f"workflow-{uuid.uuid4()}", task_queue=worker.task_queue, ) assert isinstance(err.value.cause, ApplicationError) - assert "Intentional converter failure" in str(err.value.cause) + assert "Intentional inbound converter failure" in str(err.value.cause) + + +@workflow.defn +class ActivityOutboundConversionFailureWorkflow: + @workflow.run + async def run(self) -> None: + await workflow.execute_activity( + "some-activity", + ExceptionRaisingPayloadConverter.bad_outbound_str, + start_to_close_timeout=timedelta(seconds=10), + ) + + +async def test_workflow_activity_outbound_conversion_failure(client: Client): + # This test used to fail because we created commands _before_ we attempted + # to convert the arguments thereby causing half-built commands to get sent + # to the server. + + # Clone the client but change the data converter to use our converter + config = client.config() + config["data_converter"] = dataclasses.replace( + config["data_converter"], + payload_converter_class=ExceptionRaisingPayloadConverter, + ) + client = Client(**config) + async with new_worker(client, ActivityOutboundConversionFailureWorkflow) as worker: + with pytest.raises(WorkflowFailureError) as err: + await client.execute_workflow( + ActivityOutboundConversionFailureWorkflow.run, + id=f"wf-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + assert isinstance(err.value.cause, ApplicationError) + assert "Intentional outbound converter failure" in str(err.value.cause) @dataclass