diff --git a/temporalio/client.py b/temporalio/client.py index 16d64496..10d1d1f2 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -1624,14 +1624,14 @@ async def update( rpc_metadata: Mapping[str, str] = {}, rpc_timeout: Optional[timedelta] = None, ) -> Any: - """Send an update request to the workflow. + """Send an update request to the workflow and wait for it to complete. This will target the workflow with :py:attr:`run_id` if present. To use a different run ID, create a new handle with via :py:meth:`Client.get_workflow_handle`. .. warning:: - Handles created as a result of :py:meth:`Client.start_workflow` will - signal the latest workflow with the same workflow ID even if it is + WorkflowHandles created as a result of :py:meth:`Client.start_workflow` will + send updates to the latest workflow with the same workflow ID even if it is unrelated to the started workflow. Args: @@ -1645,6 +1645,55 @@ async def update( client-level RPC metadata keys. rpc_timeout: Optional RPC deadline to set for the RPC call. + Raises: + RPCError: There was some issue sending the update to the workflow. + """ + handle = await self.start_update( + update, + arg, + args=args, + id=id, + wait_for_stage=temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_COMPLETED, + result_type=result_type, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + ) + return await handle.result() + + async def start_update( + self, + update: Union[str, Callable], + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: Optional[str] = None, + wait_for_stage: temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage = temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_ADMITTED, + result_type: Optional[Type] = None, + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + ) -> WorkflowUpdateHandle: + """Send an update request to the workflow and return a handle to it. + + This will target the workflow with :py:attr:`run_id` if present. To use a + different run ID, create a new handle with via :py:meth:`Client.get_workflow_handle`. + + .. warning:: + WorkflowHandles created as a result of :py:meth:`Client.start_workflow` will + send updates to the latest workflow with the same workflow ID even if it is + unrelated to the started workflow. + + Args: + update: Update function or name on the workflow. + arg: Single argument to the update. + args: Multiple arguments to the update. Cannot be set if arg is. + id: ID of the update. If not set, the server will set a UUID as the ID. + wait_for_stage: Specifies at what point in the update request life cycle this request should return. + result_type: For string updates, this can set the specific result + type hint to deserialize into. + rpc_metadata: Headers used on the RPC call. Keys here override + client-level RPC metadata keys. + rpc_timeout: Optional RPC deadline to set for the RPC call. + Raises: RPCError: There was some issue sending the update to the workflow. """ @@ -1665,7 +1714,7 @@ async def update( else: update_name = str(update) - return await self._client._impl.update_workflow( + return await self._client._impl.start_workflow_update( UpdateWorkflowInput( id=self._id, run_id=self._run_id, @@ -1676,6 +1725,7 @@ async def update( ret_type=ret_type, rpc_metadata=rpc_metadata, rpc_timeout=rpc_timeout, + wait_for_stage=wait_for_stage, ) ) @@ -2763,7 +2813,9 @@ class ScheduleActionStartWorkflow(ScheduleAction): headers: Optional[Mapping[str, temporalio.api.common.v1.Payload]] = None @staticmethod - def _from_proto(info: temporalio.api.workflow.v1.NewWorkflowExecutionInfo) -> ScheduleActionStartWorkflow: # type: ignore[override] + def _from_proto( + info: temporalio.api.workflow.v1.NewWorkflowExecutionInfo, + ) -> ScheduleActionStartWorkflow: # type: ignore[override] return ScheduleActionStartWorkflow("", raw_info=info) # Overload for no-param workflow @@ -3731,6 +3783,82 @@ async def __anext__(self) -> ScheduleListDescription: return ret +class WorkflowUpdateHandle: + """Handle for a workflow update execution request.""" + + def __init__( + self, + client: Client, + id: str, + name: str, + workflow_id: str, + *, + run_id: Optional[str] = None, + result_type: Optional[Type] = None, + ): + self._client = client + self._id = id + self._name = name + self._workflow_id = workflow_id + self._run_id = run_id + self._result_type = result_type + self._known_result = None + + @property + def id(self) -> str: + """ID of this Update request""" + return self._id + + @property + def name(self) -> str: + """The name of the Update being invoked""" + return self._name + + @property + def workflow_id(self) -> str: + """The ID of the Workflow targeted by this Update""" + return self._workflow_id + + @property + def run_id(self) -> Optional[str]: + """If specified, the specific run of the Workflow targeted by this Update""" + return self._run_id + + async def result( + self, + *, + timeout: Optional[timedelta] = None, + rpc_metadata: Mapping[str, str] = None, + ) -> Any: + outcome: temporalio.api.update.v1.Outcome + if self._known_result is not None: + outcome = self._known_result + else: + # TODO: This + raise NotImplementedError + + if outcome.HasField("failure"): + raise WorkflowUpdateFailedError( + self.id, + self.name, + await self._client.data_converter.decode_failure(outcome.failure.cause), + ) + if not outcome.success.payloads: + return None + type_hints = [self._result_type] if self._result_type else None + results = await self._client.data_converter.decode( + outcome.success.payloads, type_hints + ) + if not results: + return None + elif len(results) > 1: + warnings.warn(f"Expected single update result, got {len(results)}") + return results[0] + + def _set_known_result(self, result: temporalio.api.update.v1.Outcome) -> None: + self._known_result = result + + class WorkflowFailureError(temporalio.exceptions.TemporalError): """Error that occurs when a workflow is unsuccessful.""" @@ -3939,13 +4067,14 @@ class TerminateWorkflowInput: class UpdateWorkflowInput: """Input for :py:meth:`OutboundInterceptor.update_workflow`.""" - # TODO: Wait policy - id: str run_id: Optional[str] update_id: str update: str args: Sequence[Any] + wait_for_stage: Optional[ + temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage + ] headers: Mapping[str, temporalio.api.common.v1.Payload] # Type may be absent ret_type: Optional[Type] @@ -4197,9 +4326,11 @@ async def terminate_workflow(self, input: TerminateWorkflowInput) -> None: """Called for every :py:meth:`WorkflowHandle.terminate` call.""" await self.next.terminate_workflow(input) - async def update_workflow(self, input: UpdateWorkflowInput) -> Any: + async def start_workflow_update( + self, input: UpdateWorkflowInput + ) -> WorkflowUpdateHandle: """Called for every :py:meth:`WorkflowHandle.signal` call.""" - return await self.next.update_workflow(input) + return await self.next.start_workflow_update(input) ### Async activity calls @@ -4523,7 +4654,14 @@ async def terminate_workflow(self, input: TerminateWorkflowInput) -> None: req, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout ) - async def update_workflow(self, input: UpdateWorkflowInput) -> Any: + async def start_workflow_update( + self, input: UpdateWorkflowInput + ) -> WorkflowUpdateHandle: + wait_policy = ( + temporalio.api.update.v1.WaitPolicy(lifecycle_stage=input.wait_for_stage) + if input.wait_for_stage is not None + else None + ) req = temporalio.api.workflowservice.v1.UpdateWorkflowExecutionRequest( namespace=self._client.namespace, workflow_execution=temporalio.api.common.v1.WorkflowExecution( @@ -4539,6 +4677,7 @@ async def update_workflow(self, input: UpdateWorkflowInput) -> Any: name=input.update, ), ), + wait_policy=wait_policy, ) if input.args: req.request.input.args.payloads.extend( @@ -4559,25 +4698,19 @@ async def update_workflow(self, input: UpdateWorkflowInput) -> Any: raise WorkflowUpdateFailedError(input.id, input.update, err.cause) else: raise - if resp.outcome.HasField("failure"): - raise WorkflowUpdateFailedError( - input.id, - input.update, - await self._client.data_converter.decode_failure( - resp.outcome.failure.cause - ), - ) - if not resp.outcome.success.payloads: - return None - type_hints = [input.ret_type] if input.ret_type else None - results = await self._client.data_converter.decode( - resp.outcome.success.payloads, type_hints + + update_handle = WorkflowUpdateHandle( + client=self._client, + id=input.update_id, + name=input.update, + workflow_id=input.id, + run_id=input.run_id, + result_type=input.ret_type, ) - if not results: - return None - elif len(results) > 1: - warnings.warn(f"Expected single update result, got {len(results)}") - return results[0] + if resp.HasField("outcome"): + update_handle._set_known_result(resp.outcome) + + return update_handle ### Async activity calls diff --git a/tests/test_client.py b/tests/test_client.py index 255fbb51..a535beea 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -64,6 +64,7 @@ WorkflowHandle, WorkflowQueryFailedError, WorkflowQueryRejectedError, + WorkflowUpdateHandle, _history_from_json, ) from temporalio.common import RetryPolicy @@ -401,9 +402,11 @@ async def terminate_workflow(self, input: TerminateWorkflowInput) -> None: self._parent.traces.append(("terminate_workflow", input)) return await super().terminate_workflow(input) - async def update_workflow(self, input: UpdateWorkflowInput) -> Any: - self._parent.traces.append(("update_workflow", input)) - return await super().update_workflow(input) + async def start_workflow_update( + self, input: UpdateWorkflowInput + ) -> WorkflowUpdateHandle: + self._parent.traces.append(("start_workflow_update", input)) + return await super().start_workflow_update(input) async def test_interceptor(client: Client, worker: ExternalWorker):