Skip to content

Commit

Permalink
Add UpdateHandle. Polling not yet implemented.
Browse files Browse the repository at this point in the history
  • Loading branch information
Sushisource committed Oct 13, 2023
1 parent c2281c5 commit 84b909c
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 31 deletions.
189 changes: 161 additions & 28 deletions temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1624,14 +1624,14 @@ async def update(
rpc_metadata: Mapping[str, str] = {},
rpc_timeout: Optional[timedelta] = None,
) -> Any:
"""Send an update request to the workflow.
"""Send an update request to the workflow and wait for it to complete.
This will target the workflow with :py:attr:`run_id` if present. To use a
different run ID, create a new handle with via :py:meth:`Client.get_workflow_handle`.
.. warning::
Handles created as a result of :py:meth:`Client.start_workflow` will
signal the latest workflow with the same workflow ID even if it is
WorkflowHandles created as a result of :py:meth:`Client.start_workflow` will
send updates to the latest workflow with the same workflow ID even if it is
unrelated to the started workflow.
Args:
Expand All @@ -1645,6 +1645,55 @@ async def update(
client-level RPC metadata keys.
rpc_timeout: Optional RPC deadline to set for the RPC call.
Raises:
RPCError: There was some issue sending the update to the workflow.
"""
handle = await self.start_update(
update,
arg,
args=args,
id=id,
wait_for_stage=temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_COMPLETED,
result_type=result_type,
rpc_metadata=rpc_metadata,
rpc_timeout=rpc_timeout,
)
return await handle.result()

async def start_update(
self,
update: Union[str, Callable],
arg: Any = temporalio.common._arg_unset,
*,
args: Sequence[Any] = [],
id: Optional[str] = None,
wait_for_stage: temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage = temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_ADMITTED,
result_type: Optional[Type] = None,
rpc_metadata: Mapping[str, str] = {},
rpc_timeout: Optional[timedelta] = None,
) -> WorkflowUpdateHandle:
"""Send an update request to the workflow and return a handle to it.
This will target the workflow with :py:attr:`run_id` if present. To use a
different run ID, create a new handle with via :py:meth:`Client.get_workflow_handle`.
.. warning::
WorkflowHandles created as a result of :py:meth:`Client.start_workflow` will
send updates to the latest workflow with the same workflow ID even if it is
unrelated to the started workflow.
Args:
update: Update function or name on the workflow.
arg: Single argument to the update.
args: Multiple arguments to the update. Cannot be set if arg is.
id: ID of the update. If not set, the server will set a UUID as the ID.
wait_for_stage: Specifies at what point in the update request life cycle this request should return.
result_type: For string updates, this can set the specific result
type hint to deserialize into.
rpc_metadata: Headers used on the RPC call. Keys here override
client-level RPC metadata keys.
rpc_timeout: Optional RPC deadline to set for the RPC call.
Raises:
RPCError: There was some issue sending the update to the workflow.
"""
Expand All @@ -1665,7 +1714,7 @@ async def update(
else:
update_name = str(update)

return await self._client._impl.update_workflow(
return await self._client._impl.start_workflow_update(
UpdateWorkflowInput(
id=self._id,
run_id=self._run_id,
Expand All @@ -1676,6 +1725,7 @@ async def update(
ret_type=ret_type,
rpc_metadata=rpc_metadata,
rpc_timeout=rpc_timeout,
wait_for_stage=wait_for_stage,
)
)

Expand Down Expand Up @@ -2763,7 +2813,9 @@ class ScheduleActionStartWorkflow(ScheduleAction):
headers: Optional[Mapping[str, temporalio.api.common.v1.Payload]] = None

@staticmethod
def _from_proto(info: temporalio.api.workflow.v1.NewWorkflowExecutionInfo) -> ScheduleActionStartWorkflow: # type: ignore[override]
def _from_proto(
info: temporalio.api.workflow.v1.NewWorkflowExecutionInfo,
) -> ScheduleActionStartWorkflow: # type: ignore[override]
return ScheduleActionStartWorkflow("<unset>", raw_info=info)

# Overload for no-param workflow
Expand Down Expand Up @@ -3731,6 +3783,82 @@ async def __anext__(self) -> ScheduleListDescription:
return ret


class WorkflowUpdateHandle:
"""Handle for a workflow update execution request."""

def __init__(
self,
client: Client,
id: str,
name: str,
workflow_id: str,
*,
run_id: Optional[str] = None,
result_type: Optional[Type] = None,
):
self._client = client
self._id = id
self._name = name
self._workflow_id = workflow_id
self._run_id = run_id
self._result_type = result_type
self._known_result = None

@property
def id(self) -> str:
"""ID of this Update request"""
return self._id

@property
def name(self) -> str:
"""The name of the Update being invoked"""
return self._name

@property
def workflow_id(self) -> str:
"""The ID of the Workflow targeted by this Update"""
return self._workflow_id

@property
def run_id(self) -> Optional[str]:
"""If specified, the specific run of the Workflow targeted by this Update"""
return self._run_id

async def result(
self,
*,
timeout: Optional[timedelta] = None,
rpc_metadata: Mapping[str, str] = None,
) -> Any:
outcome: temporalio.api.update.v1.Outcome
if self._known_result is not None:
outcome = self._known_result
else:
# TODO: This
raise NotImplementedError

if outcome.HasField("failure"):
raise WorkflowUpdateFailedError(
self.id,
self.name,
await self._client.data_converter.decode_failure(outcome.failure.cause),
)
if not outcome.success.payloads:
return None
type_hints = [self._result_type] if self._result_type else None
results = await self._client.data_converter.decode(
outcome.success.payloads, type_hints
)
if not results:
return None
elif len(results) > 1:
warnings.warn(f"Expected single update result, got {len(results)}")
return results[0]

def _set_known_result(self, result: temporalio.api.update.v1.Outcome) -> None:
self._known_result = result


class WorkflowFailureError(temporalio.exceptions.TemporalError):
"""Error that occurs when a workflow is unsuccessful."""

Expand Down Expand Up @@ -3939,13 +4067,14 @@ class TerminateWorkflowInput:
class UpdateWorkflowInput:
"""Input for :py:meth:`OutboundInterceptor.update_workflow`."""

# TODO: Wait policy

id: str
run_id: Optional[str]
update_id: str
update: str
args: Sequence[Any]
wait_for_stage: Optional[
temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage
]
headers: Mapping[str, temporalio.api.common.v1.Payload]
# Type may be absent
ret_type: Optional[Type]
Expand Down Expand Up @@ -4197,9 +4326,11 @@ async def terminate_workflow(self, input: TerminateWorkflowInput) -> None:
"""Called for every :py:meth:`WorkflowHandle.terminate` call."""
await self.next.terminate_workflow(input)

async def update_workflow(self, input: UpdateWorkflowInput) -> Any:
async def start_workflow_update(
self, input: UpdateWorkflowInput
) -> WorkflowUpdateHandle:
"""Called for every :py:meth:`WorkflowHandle.signal` call."""
return await self.next.update_workflow(input)
return await self.next.start_workflow_update(input)

### Async activity calls

Expand Down Expand Up @@ -4523,7 +4654,14 @@ async def terminate_workflow(self, input: TerminateWorkflowInput) -> None:
req, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout
)

async def update_workflow(self, input: UpdateWorkflowInput) -> Any:
async def start_workflow_update(
self, input: UpdateWorkflowInput
) -> WorkflowUpdateHandle:
wait_policy = (
temporalio.api.update.v1.WaitPolicy(lifecycle_stage=input.wait_for_stage)
if input.wait_for_stage is not None
else None
)
req = temporalio.api.workflowservice.v1.UpdateWorkflowExecutionRequest(
namespace=self._client.namespace,
workflow_execution=temporalio.api.common.v1.WorkflowExecution(
Expand All @@ -4539,6 +4677,7 @@ async def update_workflow(self, input: UpdateWorkflowInput) -> Any:
name=input.update,
),
),
wait_policy=wait_policy,
)
if input.args:
req.request.input.args.payloads.extend(
Expand All @@ -4559,25 +4698,19 @@ async def update_workflow(self, input: UpdateWorkflowInput) -> Any:
raise WorkflowUpdateFailedError(input.id, input.update, err.cause)
else:
raise
if resp.outcome.HasField("failure"):
raise WorkflowUpdateFailedError(
input.id,
input.update,
await self._client.data_converter.decode_failure(
resp.outcome.failure.cause
),
)
if not resp.outcome.success.payloads:
return None
type_hints = [input.ret_type] if input.ret_type else None
results = await self._client.data_converter.decode(
resp.outcome.success.payloads, type_hints

update_handle = WorkflowUpdateHandle(
client=self._client,
id=input.update_id,
name=input.update,
workflow_id=input.id,
run_id=input.run_id,
result_type=input.ret_type,
)
if not results:
return None
elif len(results) > 1:
warnings.warn(f"Expected single update result, got {len(results)}")
return results[0]
if resp.HasField("outcome"):
update_handle._set_known_result(resp.outcome)

return update_handle

### Async activity calls

Expand Down
9 changes: 6 additions & 3 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
WorkflowHandle,
WorkflowQueryFailedError,
WorkflowQueryRejectedError,
WorkflowUpdateHandle,
_history_from_json,
)
from temporalio.common import RetryPolicy
Expand Down Expand Up @@ -401,9 +402,11 @@ async def terminate_workflow(self, input: TerminateWorkflowInput) -> None:
self._parent.traces.append(("terminate_workflow", input))
return await super().terminate_workflow(input)

async def update_workflow(self, input: UpdateWorkflowInput) -> Any:
self._parent.traces.append(("update_workflow", input))
return await super().update_workflow(input)
async def start_workflow_update(
self, input: UpdateWorkflowInput
) -> WorkflowUpdateHandle:
self._parent.traces.append(("start_workflow_update", input))
return await super().start_workflow_update(input)


async def test_interceptor(client: Client, worker: ExternalWorker):
Expand Down

0 comments on commit 84b909c

Please sign in to comment.