Skip to content

Commit

Permalink
Linting / mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
Sushisource committed Oct 16, 2023
1 parent d1e0681 commit 3ba297c
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 17 deletions.
17 changes: 9 additions & 8 deletions temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import inspect
import json
import re
import sys
import uuid
import warnings
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -1668,7 +1669,7 @@ async def start_update(
*,
args: Sequence[Any] = [],
id: Optional[str] = None,
wait_for_stage: temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage = temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_ADMITTED,
wait_for_stage: temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.ValueType = temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_ADMITTED,
result_type: Optional[Type] = None,
rpc_metadata: Mapping[str, str] = {},
rpc_timeout: Optional[timedelta] = None,
Expand Down Expand Up @@ -3803,7 +3804,7 @@ def __init__(
self._workflow_id = workflow_id
self._run_id = run_id
self._result_type = result_type
self._known_result = None
self._known_result: Optional[temporalio.api.update.v1.Outcome] = None

@property
def id(self) -> str:
Expand All @@ -3829,7 +3830,7 @@ async def result(
self,
*,
timeout: Optional[timedelta] = None,
rpc_metadata: Mapping[str, str] = None,
rpc_metadata: Mapping[str, str] = {},
rpc_timeout: Optional[timedelta] = None,
) -> Any:
"""Wait for and return the result of the update. The result may already be known in which case no call is made.
Expand Down Expand Up @@ -4084,7 +4085,7 @@ class UpdateWorkflowInput:
update: str
args: Sequence[Any]
wait_for_stage: Optional[
temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage
temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.ValueType
]
headers: Mapping[str, temporalio.api.common.v1.Payload]
ret_type: Optional[Type]
Expand Down Expand Up @@ -4724,9 +4725,7 @@ async def start_workflow_update(
# If the status is INVALID_ARGUMENT, we can assume it's an update
# failed error
if err.status == RPCStatusCode.INVALID_ARGUMENT:
raise WorkflowUpdateFailedError(
input.workflow_id, input.update, err.cause
)
raise WorkflowUpdateFailedError(input.workflow_id, input.update, err)
else:
raise

Expand Down Expand Up @@ -4760,7 +4759,9 @@ async def poll_workflow_update(self, input: PollUpdateWorkflowInput) -> Any:
)
try:
# Wait for at most the *overall* timeout
async with asyncio.timeout(input.timeout.total_seconds()):
async with asyncio.timeout(
input.timeout.total_seconds() if input.timeout else sys.float_info.max
):
# Continue polling as long as we have either an empty response, or an *rpc* timeout
while True:
try:
Expand Down
18 changes: 14 additions & 4 deletions temporalio/contrib/opentelemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import opentelemetry.trace
import opentelemetry.trace.propagation.tracecontext
import opentelemetry.util.types
from client import PollUpdateWorkflowInput, WorkflowUpdateHandle
from typing_extensions import Protocol, TypeAlias, TypedDict

import temporalio.activity
Expand Down Expand Up @@ -244,16 +245,25 @@ async def signal_workflow(
):
return await super().signal_workflow(input)

async def update_workflow(
async def start_workflow_update(
self, input: temporalio.client.UpdateWorkflowInput
) -> Any:
) -> WorkflowUpdateHandle:
with self.root._start_as_current_span(
f"StartWorkflowUpdate:{input.update}",
attributes={"temporalWorkflowID": input.workflow_id},
input=input,
kind=opentelemetry.trace.SpanKind.CLIENT,
):
return await super().start_workflow_update(input)

async def poll_workflow_update(self, input: PollUpdateWorkflowInput) -> Any:
with self.root._start_as_current_span(
f"UpdateWorkflow:{input.update}",
f"PollWorkflowUpdate:{input.update}",
attributes={"temporalWorkflowID": input.workflow_id},
input=input,
kind=opentelemetry.trace.SpanKind.CLIENT,
):
return await super().update_workflow(input)
return await super().poll_workflow_update(input)


class _TracingActivityInboundInterceptor(temporalio.worker.ActivityInboundInterceptor):
Expand Down
3 changes: 2 additions & 1 deletion temporalio/worker/_workflow_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,14 +452,15 @@ async def run_update(
accpetance_command: temporalio.bridge.proto.workflow_commands.WorkflowCommand,
) -> None:
command = accpetance_command
assert defn is not None
try:
if defn.validator is not None:
# Run the validator
await self._inbound.handle_update_validator(handler_input)

# Accept the update
command.update_response.accepted.SetInParent()
command = None
command = None # type: ignore

# Run the handler
success = await self._inbound.handle_update_handler(handler_input)
Expand Down
7 changes: 5 additions & 2 deletions temporalio/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,8 @@ def _update_validator(
update_def: _UpdateDefinition, fn: Optional[Callable[..., None]] = None
):
"""Decorator for a workflow update validator method."""
update_def.set_validator(fn)
if fn is not None:
update_def.set_validator(fn)


def upsert_search_attributes(attributes: temporalio.common.SearchAttributes) -> None:
Expand Down Expand Up @@ -1375,7 +1376,9 @@ def bind_fn(self, obj: Any) -> Callable[..., Any]:
return _bind_method(obj, self.fn)

def bind_validator(self, obj: Any) -> Callable[..., Any]:
return _bind_method(obj, self.validator)
if self.validator is not None:
return _bind_method(obj, self.validator)
return lambda *args, **kwargs: None

def set_validator(self, validator: Callable[..., None]) -> None:
# TODO: Verify arg types are the same
Expand Down
2 changes: 2 additions & 0 deletions tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ def test_workflow_defn_good():
name="base_query", fn=GoodDefnBase.base_query, is_method=True
),
},
# TODO: Add
updates={},
sandboxed=True,
)

Expand Down
6 changes: 4 additions & 2 deletions tests/worker/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3541,7 +3541,7 @@ async def last_event_async(self, an_arg: str) -> str:

@workflow.update
async def runs_activity(self, name: str) -> str:
act = workflow.start_activity_method(
act = workflow.start_activity(
say_hello, name, schedule_to_close_timeout=timedelta(seconds=5)
)
act.cancel()
Expand All @@ -3565,7 +3565,9 @@ async def runs_activity(self, name: str) -> str:


async def test_workflow_update_handlers(client: Client):
async with new_worker(client, UpdateHandlersWorkflow) as worker:
async with new_worker(
client, UpdateHandlersWorkflow, activities=[say_hello]
) as worker:
handle = await client.start_workflow(
UpdateHandlersWorkflow.run,
id=f"update-handlers-workflow-{uuid.uuid4()}",
Expand Down

0 comments on commit 3ba297c

Please sign in to comment.