Skip to content

Commit

Permalink
Latest review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Sushisource committed Oct 25, 2023
1 parent 6670ee5 commit 1731fcf
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 24 deletions.
File renamed without changes.
32 changes: 11 additions & 21 deletions temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1708,6 +1708,7 @@ async def execute_update(
rpc_timeout: Optional RPC deadline to set for the RPC call.
Raises:
WorkflowUpdateFailedError: If the update failed
RPCError: There was some issue sending the update to the workflow.
"""
handle = await self._start_update(
Expand Down Expand Up @@ -3928,7 +3929,6 @@ def __init__(
self,
client: Client,
id: str,
name: str,
workflow_id: str,
*,
workflow_run_id: Optional[str] = None,
Expand All @@ -3941,22 +3941,16 @@ def __init__(
"""
self._client = client
self._id = id
self._name = name
self._workflow_id = workflow_id
self._workflow_run_id = workflow_run_id
self._result_type = result_type
self._known_result: Optional[temporalio.api.update.v1.Outcome] = None
self._known_outcome: Optional[temporalio.api.update.v1.Outcome] = None

@property
def id(self) -> str:
"""ID of this Update request."""
return self._id

@property
def name(self) -> str:
"""The name of the Update being invoked."""
return self._name

@property
def workflow_id(self) -> str:
"""The ID of the Workflow targeted by this Update."""
Expand All @@ -3974,8 +3968,9 @@ async def result(
rpc_metadata: Mapping[str, str] = {},
rpc_timeout: Optional[timedelta] = None,
) -> LocalReturnType:
"""Wait for and return the result of the update. The result may already be known in which case no call is made.
Otherwise the result will be polled for until returned, or until the provided timeout is reached, if specified.
"""Wait for and return the result of the update. The result may already be known in which case no network call
is made. Otherwise the result will be polled for until returned, or until the provided timeout is reached, if
specified.
Args:
timeout: Optional timeout specifying maximum wait time for the result.
Expand All @@ -3984,15 +3979,15 @@ async def result(
overall timeout has been reached.
Raises:
WorkflowUpdateFailedError: If the update failed
TimeoutError: The specified timeout was reached when waiting for the update result.
RPCError: Update result could not be fetched for some other reason.
"""
if self._known_result is not None:
outcome = self._known_result
if self._known_outcome is not None:
outcome = self._known_outcome
return await _update_outcome_to_result(
outcome,
self.id,
self.name,
self._client.data_converter,
self._result_type,
)
Expand All @@ -4002,7 +3997,6 @@ async def result(
self.workflow_id,
self.workflow_run_id,
self.id,
self.name,
timeout,
self._result_type,
rpc_metadata,
Expand Down Expand Up @@ -4238,7 +4232,6 @@ class PollWorkflowUpdateInput:
workflow_id: str
run_id: Optional[str]
update_id: str
update: str
timeout: Optional[timedelta]
ret_type: Optional[Type]
rpc_metadata: Mapping[str, str]
Expand Down Expand Up @@ -4491,7 +4484,7 @@ async def terminate_workflow(self, input: TerminateWorkflowInput) -> None:

async def start_workflow_update(
self, input: StartWorkflowUpdateInput
) -> WorkflowUpdateHandle:
) -> WorkflowUpdateHandle[Any]:
"""Called for every :py:meth:`WorkflowHandle.update` and :py:meth:`WorkflowHandle.start_update` call."""
return await self.next.start_workflow_update(input)

Expand Down Expand Up @@ -4823,7 +4816,7 @@ async def terminate_workflow(self, input: TerminateWorkflowInput) -> None:

async def start_workflow_update(
self, input: StartWorkflowUpdateInput
) -> WorkflowUpdateHandle:
) -> WorkflowUpdateHandle[Any]:
wait_policy = (
temporalio.api.update.v1.WaitPolicy(lifecycle_stage=input.wait_for_stage)
if input.wait_for_stage is not None
Expand Down Expand Up @@ -4865,13 +4858,12 @@ async def start_workflow_update(
update_handle: WorkflowUpdateHandle[Any] = WorkflowUpdateHandle(
client=self._client,
id=determined_id,
name=input.update,
workflow_id=input.id,
workflow_run_id=input.run_id,
result_type=input.ret_type,
)
if resp.HasField("outcome"):
update_handle._known_result = resp.outcome
update_handle._known_outcome = resp.outcome

return update_handle

Expand Down Expand Up @@ -4905,7 +4897,6 @@ async def poll_loop():
return await _update_outcome_to_result(
res.outcome,
input.update_id,
input.update,
self._client.data_converter,
input.ret_type,
)
Expand Down Expand Up @@ -5449,7 +5440,6 @@ def _fix_history_enum(prefix: str, parent: Dict[str, Any], *attrs: str) -> None:
async def _update_outcome_to_result(
outcome: temporalio.api.update.v1.Outcome,
id: str,
name: str,
converter: temporalio.converter.DataConverter,
rtype: Optional[Type],
) -> Any:
Expand Down
2 changes: 1 addition & 1 deletion temporalio/contrib/opentelemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ async def signal_workflow(

async def start_workflow_update(
self, input: temporalio.client.StartWorkflowUpdateInput
) -> temporalio.client.WorkflowUpdateHandle:
) -> temporalio.client.WorkflowUpdateHandle[Any]:
with self.root._start_as_current_span(
f"StartWorkflowUpdate:{input.update}",
attributes={"temporalWorkflowID": input.id},
Expand Down
4 changes: 2 additions & 2 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,13 +405,13 @@ async def terminate_workflow(self, input: TerminateWorkflowInput) -> None:

async def start_workflow_update(
self, input: StartWorkflowUpdateInput
) -> WorkflowUpdateHandle:
) -> WorkflowUpdateHandle[Any]:
self._parent.traces.append(("start_workflow_update", input))
return await super().start_workflow_update(input)

async def poll_workflow_update(
self, input: PollWorkflowUpdateInput
) -> WorkflowUpdateHandle:
) -> WorkflowUpdateHandle[Any]:
self._parent.traces.append(("poll_workflow_update", input))
return await super().poll_workflow_update(input)

Expand Down

0 comments on commit 1731fcf

Please sign in to comment.