From 11a97d1ab2ebfe8c973bf396b1e14077ec611e52 Mon Sep 17 00:00:00 2001 From: Chad Retz Date: Mon, 20 May 2024 15:23:47 -0500 Subject: [PATCH] Required wait update stage, update polling improvements, and other update changes (#521) Fixes #484 Fixes #424 Fixes #485 Fixes #514 --- .github/workflows/ci.yml | 1 + temporalio/client.py | 238 ++++++++++++++++-------- temporalio/worker/_workflow_instance.py | 6 +- tests/worker/test_workflow.py | 128 ++++++++++++- 4 files changed, 294 insertions(+), 79 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0856c2fd..20e75c4e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -105,3 +105,4 @@ jobs: python-repo-path: ${{github.event.pull_request.head.repo.full_name}} version: ${{github.event.pull_request.head.ref}} version-is-repo-ref: true + features-repo-ref: python-update-updates diff --git a/temporalio/client.py b/temporalio/client.py index 98197b64..1730c7d5 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -1155,8 +1155,9 @@ def id(self) -> str: @property def run_id(self) -> Optional[str]: - """Run ID used for :py:meth:`signal` and :py:meth:`query` calls if - present to ensure the query or signal happen on this exact run. + """Run ID used for :py:meth:`signal`, :py:meth:`query`, and + :py:meth:`update` calls if present to ensure the signal/query/update + happen on this exact run. This is only created via :py:meth:`Client.get_workflow_handle`. :py:meth:`Client.start_workflow` will not set this value. @@ -1843,7 +1844,7 @@ async def execute_update( update: Update function or name on the workflow. arg: Single argument to the update. args: Multiple arguments to the update. Cannot be set if arg is. - id: ID of the update. If not set, the server will set a UUID as the ID. + id: ID of the update. If not set, the default is a new UUID. result_type: For string updates, this can set the specific result type hint to deserialize into. rpc_metadata: Headers used on the RPC call. 
Keys here override @@ -1858,8 +1859,8 @@ async def execute_update( update, arg, args=args, + wait_for_stage=WorkflowUpdateStage.COMPLETED, id=id, - wait_for_stage=temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_COMPLETED, result_type=result_type, rpc_metadata=rpc_metadata, rpc_timeout=rpc_timeout, @@ -1872,6 +1873,7 @@ async def start_update( self, update: temporalio.workflow.UpdateMethodMultiParam[[SelfType], LocalReturnType], *, + wait_for_stage: WorkflowUpdateStage, id: Optional[str] = None, rpc_metadata: Mapping[str, str] = {}, rpc_timeout: Optional[timedelta] = None, @@ -1887,6 +1889,7 @@ async def start_update( ], arg: ParamType, *, + wait_for_stage: WorkflowUpdateStage, id: Optional[str] = None, rpc_metadata: Mapping[str, str] = {}, rpc_timeout: Optional[timedelta] = None, @@ -1902,6 +1905,7 @@ async def start_update( ], *, args: MultiParamSpec.args, + wait_for_stage: WorkflowUpdateStage, id: Optional[str] = None, rpc_metadata: Mapping[str, str] = {}, rpc_timeout: Optional[timedelta] = None, @@ -1915,6 +1919,7 @@ async def start_update( update: str, arg: Any = temporalio.common._arg_unset, *, + wait_for_stage: WorkflowUpdateStage, args: Sequence[Any] = [], id: Optional[str] = None, result_type: Optional[Type] = None, @@ -1928,6 +1933,7 @@ async def start_update( update: Union[str, Callable], arg: Any = temporalio.common._arg_unset, *, + wait_for_stage: WorkflowUpdateStage, args: Sequence[Any] = [], id: Optional[str] = None, result_type: Optional[Type] = None, @@ -1950,8 +1956,11 @@ async def start_update( Args: update: Update function or name on the workflow. arg: Single argument to the update. + wait_for_stage: Required stage to wait until returning. ADMITTED is + not currently supported. See https://docs.temporal.io/workflows#update + for more details. args: Multiple arguments to the update. Cannot be set if arg is. - id: ID of the update. If not set, the server will set a UUID as the ID. + id: ID of the update. If not set, the default is a new UUID. result_type: For string updates, this can set the specific result type hint to deserialize into. rpc_metadata: Headers used on the RPC call. 
Keys here override @@ -1964,9 +1973,9 @@ async def start_update( return await self._start_update( update, arg, + wait_for_stage=wait_for_stage, args=args, id=id, - wait_for_stage=temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_ACCEPTED, result_type=result_type, rpc_metadata=rpc_metadata, rpc_timeout=rpc_timeout, @@ -1977,13 +1986,15 @@ async def _start_update( update: Union[str, Callable], arg: Any = temporalio.common._arg_unset, *, + wait_for_stage: WorkflowUpdateStage, args: Sequence[Any] = [], id: Optional[str] = None, - wait_for_stage: temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.ValueType = temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_ADMITTED, result_type: Optional[Type] = None, rpc_metadata: Mapping[str, str] = {}, rpc_timeout: Optional[timedelta] = None, ) -> WorkflowUpdateHandle[Any]: + if wait_for_stage == WorkflowUpdateStage.ADMITTED: + raise ValueError("ADMITTED wait stage not supported") update_name: str ret_type = result_type if isinstance(update, temporalio.workflow.UpdateMethodMultiParam): @@ -2011,6 +2022,68 @@ async def _start_update( ) ) + def get_update_handle( + self, + id: str, + *, + workflow_run_id: Optional[str] = None, + result_type: Optional[Type] = None, + ) -> WorkflowUpdateHandle[Any]: + """Get a handle for an update. The handle can be used to wait on the + update result. + + Users may prefer the more typesafe :py:meth:`get_update_handle_for` + which accepts an update definition. + + .. warning:: + This API is experimental + + Args: + id: Update ID to get a handle to. + workflow_run_id: Run ID to tie the handle to. If this is not set, + the :py:attr:`run_id` will be used. + result_type: The result type to deserialize into if known. + + Returns: + The update handle. + """ + return WorkflowUpdateHandle( + self._client, + id, + self._id, + workflow_run_id=workflow_run_id or self._run_id, + result_type=result_type, + ) + + def get_update_handle_for( + self, + update: temporalio.workflow.UpdateMethodMultiParam[Any, LocalReturnType], + id: str, + *, + workflow_run_id: Optional[str] = None, + ) -> WorkflowUpdateHandle[LocalReturnType]: + """Get a typed handle for an update. The handle can be used to wait on + the update result. + + This is the same as :py:meth:`get_update_handle` but typed. The result + type is taken from the given update definition. + + .. warning:: + This API is experimental + + Args: + update: The update method to use for typing the handle. + id: Update ID to get a handle to. + workflow_run_id: Run ID to tie the handle to. If this is not set, + the :py:attr:`run_id` will be used. + + Returns: + The update handle. + """ + return self.get_update_handle( + id, workflow_run_id=workflow_run_id, result_type=update._defn.ret_type + ) + @dataclass(frozen=True) class AsyncActivityIDReference: @@ -4235,15 +4308,38 @@ async def result( WorkflowUpdateFailedError: If the update failed RPCError: Update result could not be fetched for some other reason.
""" - if self._known_outcome is not None: - outcome = self._known_outcome - return await _update_outcome_to_result( - outcome, - self.id, - self._client.data_converter, - self._result_type, + # Poll until outcome reached + await self._poll_until_outcome( + rpc_metadata=rpc_metadata, rpc_timeout=rpc_timeout + ) + + # Convert outcome to failure or value + assert self._known_outcome + if self._known_outcome.HasField("failure"): + raise WorkflowUpdateFailedError( + await self._client.data_converter.decode_failure( + self._known_outcome.failure + ), ) + if not self._known_outcome.success.payloads: + return None # type: ignore + type_hints = [self._result_type] if self._result_type else None + results = await self._client.data_converter.decode( + self._known_outcome.success.payloads, type_hints + ) + if not results: + return None # type: ignore + elif len(results) > 1: + warnings.warn(f"Expected single update result, got {len(results)}") + return results[0] + async def _poll_until_outcome( + self, + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + ) -> None: + if self._known_outcome: + return req = temporalio.api.workflowservice.v1.PollWorkflowExecutionUpdateRequest( namespace=self._client.namespace, update_ref=temporalio.api.update.v1.UpdateRef( @@ -4259,27 +4355,33 @@ async def result( ), ) - # Continue polling as long as we have either an empty response, or an *rpc* timeout + # Continue polling as long as we have no outcome while True: - try: - res = ( - await self._client.workflow_service.poll_workflow_execution_update( - req, - retry=True, - metadata=rpc_metadata, - timeout=rpc_timeout, - ) - ) - if res.HasField("outcome"): - return await _update_outcome_to_result( - res.outcome, - self.id, - self._client.data_converter, - self._result_type, - ) - except RPCError as err: - if err.status != RPCStatusCode.DEADLINE_EXCEEDED: - raise + res = await self._client.workflow_service.poll_workflow_execution_update( + req, + retry=True, + metadata=rpc_metadata, + timeout=rpc_timeout, + ) + if res.HasField("outcome"): + self._known_outcome = res.outcome + return + + +class WorkflowUpdateStage(IntEnum): + """Stage to wait for workflow update to reach before returning from + ``start_update``. 
+ """ + + ADMITTED = int( + temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_ADMITTED + ) + ACCEPTED = int( + temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_ACCEPTED + ) + COMPLETED = int( + temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_COMPLETED + ) class WorkflowFailureError(temporalio.exceptions.TemporalError): @@ -4508,9 +4610,7 @@ class StartWorkflowUpdateInput: update_id: Optional[str] update: str args: Sequence[Any] - wait_for_stage: Optional[ - temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.ValueType - ] + wait_for_stage: WorkflowUpdateStage headers: Mapping[str, temporalio.api.common.v1.Payload] ret_type: Optional[Type] rpc_metadata: Mapping[str, str] @@ -5125,11 +5225,7 @@ async def terminate_workflow(self, input: TerminateWorkflowInput) -> None: async def start_workflow_update( self, input: StartWorkflowUpdateInput ) -> WorkflowUpdateHandle[Any]: - wait_policy = ( - temporalio.api.update.v1.WaitPolicy(lifecycle_stage=input.wait_for_stage) - if input.wait_for_stage is not None - else None - ) + # Build request req = temporalio.api.workflowservice.v1.UpdateWorkflowExecutionRequest( namespace=self._client.namespace, workflow_execution=temporalio.api.common.v1.WorkflowExecution( @@ -5138,14 +5234,18 @@ async def start_workflow_update( ), request=temporalio.api.update.v1.Request( meta=temporalio.api.update.v1.Meta( - update_id=input.update_id or "", + update_id=input.update_id or str(uuid.uuid4()), identity=self._client.identity, ), input=temporalio.api.update.v1.Input( name=input.update, ), ), - wait_policy=wait_policy, + wait_policy=temporalio.api.update.v1.WaitPolicy( + lifecycle_stage=temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.ValueType( + input.wait_for_stage + ) + ), ) if input.args: req.request.input.args.payloads.extend( @@ -5155,25 +5255,36 @@ async def start_workflow_update( temporalio.common._apply_headers( input.headers, req.request.input.header.fields ) - try: + + # Repeatedly try to invoke start until the update reaches user-provided + # wait stage or is at least ACCEPTED (as of the time of this writing, + # the user cannot specify sooner than ACCEPTED) + resp: temporalio.api.workflowservice.v1.UpdateWorkflowExecutionResponse + while True: resp = await self._client.workflow_service.update_workflow_execution( req, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout ) - except RPCError as err: - raise + if ( + resp.stage >= req.wait_policy.lifecycle_stage + or resp.stage + >= temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_ACCEPTED + ): + break - determined_id = resp.update_ref.update_id - update_handle: WorkflowUpdateHandle[Any] = WorkflowUpdateHandle( + # Build the handle. If the user's wait stage is COMPLETED, make sure we + # poll for result. 
+ handle: WorkflowUpdateHandle[Any] = WorkflowUpdateHandle( client=self._client, - id=determined_id, + id=req.request.meta.update_id, workflow_id=input.id, workflow_run_id=input.run_id, result_type=input.ret_type, ) if resp.HasField("outcome"): - update_handle._known_outcome = resp.outcome - - return update_handle + handle._known_outcome = resp.outcome + if input.wait_for_stage == WorkflowUpdateStage.COMPLETED: + await handle._poll_until_outcome() + return handle ### Async activity calls @@ -5700,27 +5811,6 @@ def _fix_history_enum(prefix: str, parent: Dict[str, Any], *attrs: str) -> None: _fix_history_enum(prefix, child_item, *attrs[1:]) -async def _update_outcome_to_result( - outcome: temporalio.api.update.v1.Outcome, - id: str, - converter: temporalio.converter.DataConverter, - rtype: Optional[Type], -) -> Any: - if outcome.HasField("failure"): - raise WorkflowUpdateFailedError( - await converter.decode_failure(outcome.failure), - ) - if not outcome.success.payloads: - return None - type_hints = [rtype] if rtype else None - results = await converter.decode(outcome.success.payloads, type_hints) - if not results: - return None - elif len(results) > 1: - warnings.warn(f"Expected single update result, got {len(results)}") - return results[0] - - @dataclass(frozen=True) class WorkerBuildIdVersionSets: """Represents the sets of compatible Build ID versions associated with some Task Queue, as diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 949160e0..34c997ad 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -302,15 +302,15 @@ def activate( activation_err: Optional[Exception] = None try: - # Split into job sets with patches, then signals, then non-queries, then - # queries + # Split into job sets with patches, then signals + updates, then + # non-queries, then queries job_sets: List[ List[temporalio.bridge.proto.workflow_activation.WorkflowActivationJob] ] = [[], [], [], []] for job in act.jobs: if job.HasField("notify_has_patch"): job_sets[0].append(job) - elif job.HasField("signal_workflow"): + elif job.HasField("signal_workflow") or job.HasField("do_update"): job_sets[1].append(job) elif not job.HasField("query_workflow"): job_sets[2].append(job) diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index f7b11c7c..20e8685b 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -54,6 +54,7 @@ WorkflowQueryFailedError, WorkflowUpdateFailedError, WorkflowUpdateHandle, + WorkflowUpdateStage, ) from temporalio.common import ( RawValue, @@ -4174,6 +4175,124 @@ async def test_workflow_update_task_fails(client: Client, env: WorkflowEnvironme assert bad_validator_fail_ct == 2 +@workflow.defn +class ImmediatelyCompleteUpdateAndWorkflow: + def __init__(self) -> None: + self._got_update = "no" + + @workflow.run + async def run(self) -> str: + return "workflow-done" + + @workflow.update + async def update(self) -> str: + self._got_update = "yes" + return "update-done" + + @workflow.query + def got_update(self) -> str: + return self._got_update + + +async def test_workflow_update_before_worker_start( + client: Client, env: WorkflowEnvironment +): + if env.supports_time_skipping: + pytest.skip( + "Java test server: https://github.com/temporalio/sdk-java/issues/1903" + ) + # In order to confirm that all started workflows get updates before the + # workflow completes, this test will start a workflow and start an update. 
+ # Only then will it start the worker to process both in the task. The + # workflow and update should both succeed properly. This also invokes a + # query to confirm update mutation. We do this with the cache off to confirm + # replay behavior. + + # Start workflow + task_queue = f"tq-{uuid.uuid4()}" + handle = await client.start_workflow( + ImmediatelyCompleteUpdateAndWorkflow.run, + id=f"wf-{uuid.uuid4()}", + task_queue=task_queue, + ) + + # Execute update in background + update_task = asyncio.create_task( + handle.execute_update(ImmediatelyCompleteUpdateAndWorkflow.update) + ) + + # Start no-cache worker on the task queue + async with new_worker( + client, + ImmediatelyCompleteUpdateAndWorkflow, + task_queue=task_queue, + max_cached_workflows=0, + ): + # Confirm workflow completed as expected + assert "workflow-done" == await handle.result() + assert "update-done" == await update_task + assert "yes" == await handle.query( + ImmediatelyCompleteUpdateAndWorkflow.got_update + ) + + +@workflow.defn +class UpdateSeparateHandleWorkflow: + def __init__(self) -> None: + self._complete = False + self._complete_update = False + + @workflow.run + async def run(self) -> str: + await workflow.wait_condition(lambda: self._complete) + return "workflow-done" + + @workflow.update + async def update(self) -> str: + await workflow.wait_condition(lambda: self._complete_update) + self._complete = True + return "update-done" + + @workflow.signal + async def signal(self) -> None: + self._complete_update = True + + +async def test_workflow_update_separate_handle( + client: Client, env: WorkflowEnvironment +): + if env.supports_time_skipping: + pytest.skip( + "Java test server: https://github.com/temporalio/sdk-java/issues/1903" + ) + async with new_worker(client, UpdateSeparateHandleWorkflow) as worker: + # Start the workflow + handle = await client.start_workflow( + UpdateSeparateHandleWorkflow.run, + id=f"wf-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Start an update waiting on accepted + update_handle_1 = await handle.start_update( + UpdateSeparateHandleWorkflow.update, + wait_for_stage=WorkflowUpdateStage.ACCEPTED, + ) + + # Create another handle and have them both wait for update complete + update_handle_2 = client.get_workflow_handle( + handle.id, run_id=handle.result_run_id + ).get_update_handle_for(UpdateSeparateHandleWorkflow.update, update_handle_1.id) + update_handle_task1 = asyncio.create_task(update_handle_1.result()) + update_handle_task2 = asyncio.create_task(update_handle_2.result()) + + # Signal completion and confirm all completed as expected + await handle.signal(UpdateSeparateHandleWorkflow.signal) + assert "update-done" == await update_handle_task1 + assert "update-done" == await update_handle_task2 + assert "workflow-done" == await handle.result() + + @workflow.defn class TimeoutSupportWorkflow: @workflow.run @@ -4432,7 +4551,10 @@ async def assert_scenario( update_handle: Optional[WorkflowUpdateHandle[Any]] = None if update_scenario: update_handle = await handle.start_update( - workflow.update, update_scenario, id="my-update-1" + workflow.update, + update_scenario, + wait_for_stage=WorkflowUpdateStage.ACCEPTED, + id="my-update-1", ) # Expect task or exception fail @@ -4448,7 +4570,9 @@ async def has_expected_task_fail() -> bool: return True return False - await assert_eq_eventually(True, has_expected_task_fail) + await assert_eq_eventually( + True, has_expected_task_fail, timeout=timedelta(seconds=20) + ) else: with pytest.raises(TemporalError) as err: # Update does not 
throw on non-determinism, the workflow
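
For reference, a minimal usage sketch of the client API surface this patch settles on: `start_update` with the now-required `wait_for_stage`, the `WorkflowUpdateStage` enum, and update handles re-fetched by ID. The workflow name `GreetingWorkflow`, its `update` handler, the task queue name, and the server address below are hypothetical placeholders; only the client calls come from this change, and a worker registering the workflow is assumed to be running elsewhere.

import asyncio
import uuid

from temporalio.client import Client, WorkflowUpdateStage


async def main() -> None:
    # Hypothetical server address; any reachable Temporal server works.
    client = await Client.connect("localhost:7233")

    # "GreetingWorkflow" and its "update" handler are assumed to be defined
    # and registered with a worker polling "my-task-queue".
    handle = await client.start_workflow(
        "GreetingWorkflow",
        id=f"wf-{uuid.uuid4()}",
        task_queue="my-task-queue",
    )

    # start_update now requires wait_for_stage. ACCEPTED returns once the
    # update is accepted by the workflow, without waiting for its result;
    # COMPLETED would wait for (and internally poll for) the outcome.
    update_handle = await handle.start_update(
        "update",
        "some-arg",
        wait_for_stage=WorkflowUpdateStage.ACCEPTED,
    )

    # A second handle to the same update can be built later from its ID,
    # e.g. in another process. result() polls until the outcome is known.
    same_update = handle.get_update_handle(update_handle.id)
    print(await same_update.result())


if __name__ == "__main__":
    asyncio.run(main())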