diff --git a/temporalio/client.py b/temporalio/client.py index 95c18143..becde146 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -8,6 +8,7 @@ import inspect import json import re +import sys import uuid import warnings from abc import ABC, abstractmethod @@ -1668,7 +1669,7 @@ async def start_update( *, args: Sequence[Any] = [], id: Optional[str] = None, - wait_for_stage: temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage = temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_ADMITTED, + wait_for_stage: temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.ValueType = temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_ADMITTED, result_type: Optional[Type] = None, rpc_metadata: Mapping[str, str] = {}, rpc_timeout: Optional[timedelta] = None, @@ -3803,7 +3804,7 @@ def __init__( self._workflow_id = workflow_id self._run_id = run_id self._result_type = result_type - self._known_result = None + self._known_result: Optional[temporalio.api.update.v1.Outcome] = None @property def id(self) -> str: @@ -3829,7 +3830,7 @@ async def result( self, *, timeout: Optional[timedelta] = None, - rpc_metadata: Mapping[str, str] = None, + rpc_metadata: Mapping[str, str] = {}, rpc_timeout: Optional[timedelta] = None, ) -> Any: """Wait for and return the result of the update. The result may already be known in which case no call is made. @@ -4084,7 +4085,7 @@ class UpdateWorkflowInput: update: str args: Sequence[Any] wait_for_stage: Optional[ - temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage + temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.ValueType ] headers: Mapping[str, temporalio.api.common.v1.Payload] ret_type: Optional[Type] @@ -4724,9 +4725,7 @@ async def start_workflow_update( # If the status is INVALID_ARGUMENT, we can assume it's an update # failed error if err.status == RPCStatusCode.INVALID_ARGUMENT: - raise WorkflowUpdateFailedError( - input.workflow_id, input.update, err.cause - ) + raise WorkflowUpdateFailedError(input.workflow_id, input.update, err) else: raise @@ -4760,7 +4759,9 @@ async def poll_workflow_update(self, input: PollUpdateWorkflowInput) -> Any: ) try: # Wait for at most the *overall* timeout - async with asyncio.timeout(input.timeout.total_seconds()): + async with asyncio.timeout( + input.timeout.total_seconds() if input.timeout else sys.float_info.max + ): # Continue polling as long as we have either an empty response, or an *rpc* timeout while True: try: diff --git a/temporalio/contrib/opentelemetry.py b/temporalio/contrib/opentelemetry.py index b1e1df6b..bf963ca9 100644 --- a/temporalio/contrib/opentelemetry.py +++ b/temporalio/contrib/opentelemetry.py @@ -25,6 +25,7 @@ import opentelemetry.trace import opentelemetry.trace.propagation.tracecontext import opentelemetry.util.types +from client import PollUpdateWorkflowInput, WorkflowUpdateHandle from typing_extensions import Protocol, TypeAlias, TypedDict import temporalio.activity @@ -244,16 +245,25 @@ async def signal_workflow( ): return await super().signal_workflow(input) - async def update_workflow( + async def start_workflow_update( self, input: temporalio.client.UpdateWorkflowInput - ) -> Any: + ) -> WorkflowUpdateHandle: + with self.root._start_as_current_span( + f"StartWorkflowUpdate:{input.update}", + attributes={"temporalWorkflowID": input.workflow_id}, + input=input, + kind=opentelemetry.trace.SpanKind.CLIENT, + ): + return await super().start_workflow_update(input) + + async def poll_workflow_update(self, input: PollUpdateWorkflowInput) -> Any: with self.root._start_as_current_span( - f"UpdateWorkflow:{input.update}", + f"PollWorkflowUpdate:{input.update}", attributes={"temporalWorkflowID": input.workflow_id}, input=input, kind=opentelemetry.trace.SpanKind.CLIENT, ): - return await super().update_workflow(input) + return await super().poll_workflow_update(input) class _TracingActivityInboundInterceptor(temporalio.worker.ActivityInboundInterceptor): diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 20db27f3..46279b1b 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -452,6 +452,7 @@ async def run_update( accpetance_command: temporalio.bridge.proto.workflow_commands.WorkflowCommand, ) -> None: command = accpetance_command + assert defn is not None try: if defn.validator is not None: # Run the validator @@ -459,7 +460,7 @@ async def run_update( # Accept the update command.update_response.accepted.SetInParent() - command = None + command = None # type: ignore # Run the handler success = await self._inbound.handle_update_handler(handler_input) diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 5380babf..4e1d448b 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -811,7 +811,8 @@ def _update_validator( update_def: _UpdateDefinition, fn: Optional[Callable[..., None]] = None ): """Decorator for a workflow update validator method.""" - update_def.set_validator(fn) + if fn is not None: + update_def.set_validator(fn) def upsert_search_attributes(attributes: temporalio.common.SearchAttributes) -> None: @@ -1375,7 +1376,9 @@ def bind_fn(self, obj: Any) -> Callable[..., Any]: return _bind_method(obj, self.fn) def bind_validator(self, obj: Any) -> Callable[..., Any]: - return _bind_method(obj, self.validator) + if self.validator is not None: + return _bind_method(obj, self.validator) + return lambda *args, **kwargs: None def set_validator(self, validator: Callable[..., None]) -> None: # TODO: Verify arg types are the same diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 3f24d530..e9851445 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -87,6 +87,8 @@ def test_workflow_defn_good(): name="base_query", fn=GoodDefnBase.base_query, is_method=True ), }, + # TODO: Add + updates={}, sandboxed=True, ) diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 1ae41a5f..6fded7d5 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -3541,7 +3541,7 @@ async def last_event_async(self, an_arg: str) -> str: @workflow.update async def runs_activity(self, name: str) -> str: - act = workflow.start_activity_method( + act = workflow.start_activity( say_hello, name, schedule_to_close_timeout=timedelta(seconds=5) ) act.cancel() @@ -3565,7 +3565,9 @@ async def runs_activity(self, name: str) -> str: async def test_workflow_update_handlers(client: Client): - async with new_worker(client, UpdateHandlersWorkflow) as worker: + async with new_worker( + client, UpdateHandlersWorkflow, activities=[say_hello] + ) as worker: handle = await client.start_workflow( UpdateHandlersWorkflow.run, id=f"update-handlers-workflow-{uuid.uuid4()}",