Skip to content

Commit

Permalink
Create commands after payload conversion (#591)
Browse files Browse the repository at this point in the history
Fixes #540
Fixes #564
  • Loading branch information
cretz authored Aug 9, 2024
1 parent 4b93d1a commit a5b9661
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 44 deletions.
102 changes: 63 additions & 39 deletions temporalio/worker/_workflow_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,8 +599,6 @@ def _apply_query_workflow(
) -> None:
# Wrap entire bunch of work in a task
async def run_query() -> None:
command = self._add_command()
command.respond_to_query.query_id = job.query_id
try:
with self._as_read_only():
# Named query or dynamic
Expand Down Expand Up @@ -632,11 +630,13 @@ async def run_query() -> None:
raise ValueError(
f"Expected 1 result payload, got {len(result_payloads)}"
)
command.respond_to_query.succeeded.response.CopyFrom(
result_payloads[0]
)
command = self._add_command()
command.respond_to_query.query_id = job.query_id
command.respond_to_query.succeeded.response.CopyFrom(result_payloads[0])
except Exception as err:
try:
command = self._add_command()
command.respond_to_query.query_id = job.query_id
self._failure_converter.to_failure(
err,
self._payload_converter,
Expand Down Expand Up @@ -1427,7 +1427,7 @@ async def run_activity() -> Any:
await asyncio.sleep(
err.backoff.backoff_duration.ToTimedelta().total_seconds()
)
handle._apply_schedule_command(self._add_command(), err.backoff)
handle._apply_schedule_command(err.backoff)
# We have to put the handle back on the pending activity
# dict with its new seq
self._pending_activities[handle._seq] = handle
Expand All @@ -1437,35 +1437,41 @@ async def run_activity() -> Any:

# Create the handle and set as pending
handle = _ActivityHandle(self, input, run_activity())
handle._apply_schedule_command(self._add_command())
handle._apply_schedule_command()
self._pending_activities[handle._seq] = handle
return handle

async def _outbound_signal_child_workflow(
self, input: SignalChildWorkflowInput
) -> None:
payloads = (
self._payload_converter.to_payloads(input.args) if input.args else None
)
command = self._add_command()
v = command.signal_external_workflow_execution
v.child_workflow_id = input.child_workflow_id
v.signal_name = input.signal
if input.args:
v.args.extend(self._payload_converter.to_payloads(input.args))
if payloads:
v.args.extend(payloads)
if input.headers:
temporalio.common._apply_headers(input.headers, v.headers)
await self._signal_external_workflow(command)

async def _outbound_signal_external_workflow(
self, input: SignalExternalWorkflowInput
) -> None:
payloads = (
self._payload_converter.to_payloads(input.args) if input.args else None
)
command = self._add_command()
v = command.signal_external_workflow_execution
v.workflow_execution.namespace = input.namespace
v.workflow_execution.workflow_id = input.workflow_id
if input.workflow_run_id:
v.workflow_execution.run_id = input.workflow_run_id
v.signal_name = input.signal
if input.args:
v.args.extend(self._payload_converter.to_payloads(input.args))
if payloads:
v.args.extend(payloads)
if input.headers:
temporalio.common._apply_headers(input.headers, v.headers)
await self._signal_external_workflow(command)
Expand Down Expand Up @@ -1510,7 +1516,7 @@ async def run_child() -> Any:
handle = _ChildWorkflowHandle(
self, self._next_seq("child_workflow"), input, run_child()
)
handle._apply_start_command(self._add_command())
handle._apply_start_command()
self._pending_child_workflows[handle._seq] = handle

# Wait on start before returning
Expand Down Expand Up @@ -1761,7 +1767,7 @@ async def _run_top_level_workflow_function(self, coro: Awaitable[None]) -> None:
await coro
except _ContinueAsNewError as err:
logger.debug("Workflow requested continue as new")
err._apply_command(self._add_command())
err._apply_command()
except (Exception, asyncio.CancelledError) as err:
# During tear down we can ignore exceptions. Technically the
# command-adding done later would throw a not-in-workflow exception
Expand All @@ -1776,7 +1782,7 @@ async def _run_top_level_workflow_function(self, coro: Awaitable[None]) -> None:
# Handle continue as new
if isinstance(err, _ContinueAsNewError):
logger.debug("Workflow requested continue as new")
err._apply_command(self._add_command())
err._apply_command()
return

logger.debug(
Expand Down Expand Up @@ -2261,11 +2267,18 @@ def _resolve_backoff(

def _apply_schedule_command(
self,
command: temporalio.bridge.proto.workflow_commands.WorkflowCommand,
local_backoff: Optional[
temporalio.bridge.proto.activity_result.DoBackoff
] = None,
) -> None:
# Convert arguments before creating the command, in case conversion raises an error
payloads = (
self._instance._payload_converter.to_payloads(self._input.args)
if self._input.args
else None
)

command = self._instance._add_command()
# TODO(cretz): Why can't MyPy infer this?
v: Union[
temporalio.bridge.proto.workflow_commands.ScheduleActivity,
Expand All @@ -2280,10 +2293,8 @@ def _apply_schedule_command(
v.activity_type = self._input.activity
if self._input.headers:
temporalio.common._apply_headers(self._input.headers, v.headers)
if self._input.args:
v.arguments.extend(
self._instance._payload_converter.to_payloads(self._input.args)
)
if payloads:
v.arguments.extend(payloads)
if self._input.schedule_to_close_timeout:
v.schedule_to_close_timeout.FromTimedelta(
self._input.schedule_to_close_timeout
Expand Down Expand Up @@ -2403,20 +2414,23 @@ def _resolve_failure(self, err: BaseException) -> None:
# future
self._result_fut.set_result(None)

def _apply_start_command(
self,
command: temporalio.bridge.proto.workflow_commands.WorkflowCommand,
) -> None:
def _apply_start_command(self) -> None:
# Convert arguments before creating the command, in case conversion raises an error
payloads = (
self._instance._payload_converter.to_payloads(self._input.args)
if self._input.args
else None
)

command = self._instance._add_command()
v = command.start_child_workflow_execution
v.seq = self._seq
v.namespace = self._instance._info.namespace
v.workflow_id = self._input.id
v.workflow_type = self._input.workflow
v.task_queue = self._input.task_queue or self._instance._info.task_queue
if self._input.args:
v.input.extend(
self._instance._payload_converter.to_payloads(self._input.args)
)
if payloads:
v.input.extend(payloads)
if self._input.execution_timeout:
v.workflow_execution_timeout.FromTimedelta(self._input.execution_timeout)
if self._input.run_timeout:
Expand Down Expand Up @@ -2520,19 +2534,31 @@ def __init__(
self._instance = instance
self._input = input

def _apply_command(
self, command: temporalio.bridge.proto.workflow_commands.WorkflowCommand
) -> None:
def _apply_command(self) -> None:
# Convert arguments before creating the command, in case conversion raises an error
payloads = (
self._instance._payload_converter.to_payloads(self._input.args)
if self._input.args
else None
)
memo_payloads = (
{
k: self._instance._payload_converter.to_payloads([val])[0]
for k, val in self._input.memo.items()
}
if self._input.memo
else None
)

command = self._instance._add_command()
v = command.continue_as_new_workflow_execution
v.SetInParent()
if self._input.workflow:
v.workflow_type = self._input.workflow
if self._input.task_queue:
v.task_queue = self._input.task_queue
if self._input.args:
v.arguments.extend(
self._instance._payload_converter.to_payloads(self._input.args)
)
if payloads:
v.arguments.extend(payloads)
if self._input.run_timeout:
v.workflow_run_timeout.FromTimedelta(self._input.run_timeout)
if self._input.task_timeout:
Expand All @@ -2541,11 +2567,9 @@ def _apply_command(
temporalio.common._apply_headers(self._input.headers, v.headers)
if self._input.retry_policy:
self._input.retry_policy.apply_to_proto(v.retry_policy)
if self._input.memo:
for k, val in self._input.memo.items():
v.memo[k].CopyFrom(
self._instance._payload_converter.to_payloads([val])[0]
)
if memo_payloads:
for k, val in memo_payloads.items():
v.memo[k].CopyFrom(val)
if self._input.search_attributes:
_encode_search_attributes(
self._input.search_attributes, v.search_attributes
Expand Down
56 changes: 51 additions & 5 deletions tests/worker/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3350,15 +3350,27 @@ async def test_workflow_optional_param(client: Client):


class ExceptionRaisingPayloadConverter(DefaultPayloadConverter):
bad_str = "bad-payload-str"
bad_outbound_str = "bad-outbound-payload-str"
bad_inbound_str = "bad-inbound-payload-str"

def to_payloads(self, values: Sequence[Any]) -> List[Payload]:
if any(
value == ExceptionRaisingPayloadConverter.bad_outbound_str
for value in values
):
raise ApplicationError("Intentional outbound converter failure")
return super().to_payloads(values)

def from_payloads(
self, payloads: Sequence[Payload], type_hints: Optional[List] = None
) -> List[Any]:
# Check if any payloads contain the bad data
for payload in payloads:
if ExceptionRaisingPayloadConverter.bad_str.encode() in payload.data:
raise ApplicationError("Intentional converter failure")
if (
ExceptionRaisingPayloadConverter.bad_inbound_str.encode()
in payload.data
):
raise ApplicationError("Intentional inbound converter failure")
return super().from_payloads(payloads, type_hints)


Expand All @@ -3383,12 +3395,46 @@ async def test_exception_raising_converter_param(client: Client):
with pytest.raises(WorkflowFailureError) as err:
await client.execute_workflow(
ExceptionRaisingConverterWorkflow.run,
ExceptionRaisingPayloadConverter.bad_str,
ExceptionRaisingPayloadConverter.bad_inbound_str,
id=f"workflow-{uuid.uuid4()}",
task_queue=worker.task_queue,
)
assert isinstance(err.value.cause, ApplicationError)
assert "Intentional converter failure" in str(err.value.cause)
assert "Intentional inbound converter failure" in str(err.value.cause)


@workflow.defn
class ActivityOutboundConversionFailureWorkflow:
@workflow.run
async def run(self) -> None:
await workflow.execute_activity(
"some-activity",
ExceptionRaisingPayloadConverter.bad_outbound_str,
start_to_close_timeout=timedelta(seconds=10),
)


async def test_workflow_activity_outbound_conversion_failure(client: Client):
# This test used to fail because we created commands _before_ we attempted
# to convert the arguments, thereby causing half-built commands to be sent
# to the server.

# Clone the client but change the data converter to use our converter
config = client.config()
config["data_converter"] = dataclasses.replace(
config["data_converter"],
payload_converter_class=ExceptionRaisingPayloadConverter,
)
client = Client(**config)
async with new_worker(client, ActivityOutboundConversionFailureWorkflow) as worker:
with pytest.raises(WorkflowFailureError) as err:
await client.execute_workflow(
ActivityOutboundConversionFailureWorkflow.run,
id=f"wf-{uuid.uuid4()}",
task_queue=worker.task_queue,
)
assert isinstance(err.value.cause, ApplicationError)
assert "Intentional outbound converter failure" in str(err.value.cause)


@dataclass
Expand Down

0 comments on commit a5b9661

Please sign in to comment.