diff --git a/README.md b/README.md index 4e82dad2..1bcd2f71 100644 --- a/README.md +++ b/README.md @@ -722,7 +722,8 @@ While running in a workflow, in addition to features documented elsewhere, the f #### Exceptions -* Workflows can raise exceptions to fail the workflow or the "workflow task" (i.e. suspend the workflow retrying). +* Workflows/updates can raise exceptions to fail the workflow or the "workflow task" (i.e. suspend the workflow + in a retrying state). * Exceptions that are instances of `temporalio.exceptions.FailureError` will fail the workflow with that exception * For failing the workflow explicitly with a user exception, use `temporalio.exceptions.ApplicationError`. This can be marked non-retryable or include details as needed. @@ -732,6 +733,13 @@ While running in a workflow, in addition to features documented elsewhere, the f fixed. This is helpful for bad code or other non-predictable exceptions. To actually fail the workflow, use an `ApplicationError` as mentioned above. +This default can be changed by providing a list of exception types to `workflow_failure_exception_types` when creating a +`Worker` or `failure_exception_types` on the `@workflow.defn` decorator. If a workflow-thrown exception is an instance +of any type in either list, it will fail the workflow instead of the task. This means a value of `[Exception]` will +cause every exception to fail the workflow instead of the task. Also, as a special case, if +`temporalio.workflow.NondeterminismError` (or any superclass of it) is set, non-deterministic exceptions will fail the +workflow. WARNING: These settings are experimental. 
+ #### External Workflows * `workflow.get_external_workflow_handle()` inside a workflow returns a handle to interact with another workflow diff --git a/temporalio/bridge/src/worker.rs b/temporalio/bridge/src/worker.rs index 151d7ce9..f388d8ea 100644 --- a/temporalio/bridge/src/worker.rs +++ b/temporalio/bridge/src/worker.rs @@ -2,10 +2,13 @@ use prost::Message; use pyo3::exceptions::{PyException, PyRuntimeError, PyValueError}; use pyo3::prelude::*; use pyo3::types::{PyBytes, PyTuple}; +use std::collections::HashMap; +use std::collections::HashSet; use std::sync::Arc; use std::time::Duration; use temporal_sdk_core::api::errors::{PollActivityError, PollWfError}; use temporal_sdk_core::replay::{HistoryForReplay, ReplayWorkerInput}; +use temporal_sdk_core_api::errors::WorkflowErrorType; use temporal_sdk_core_api::Worker; use temporal_sdk_core_protos::coresdk::workflow_completion::WorkflowActivationCompletion; use temporal_sdk_core_protos::coresdk::{ActivityHeartbeat, ActivityTaskCompletion}; @@ -45,6 +48,8 @@ pub struct WorkerConfig { max_task_queue_activities_per_second: Option<f64>, graceful_shutdown_period_millis: u64, use_worker_versioning: bool, + nondeterminism_as_workflow_fail: bool, + nondeterminism_as_workflow_fail_for_types: HashSet<String>, } macro_rules! enter_sync { @@ -234,6 +239,22 @@ impl TryFrom<WorkerConfig> for temporal_sdk_core::WorkerConfig { // always set it even if 0. 
.graceful_shutdown_period(Duration::from_millis(conf.graceful_shutdown_period_millis)) .use_worker_versioning(conf.use_worker_versioning) + .workflow_failure_errors(if conf.nondeterminism_as_workflow_fail { + HashSet::from([WorkflowErrorType::Nondeterminism]) + } else { + HashSet::new() + }) + .workflow_types_to_failure_errors( + conf.nondeterminism_as_workflow_fail_for_types + .iter() + .map(|s| { + ( + s.to_owned(), + HashSet::from([WorkflowErrorType::Nondeterminism]), + ) + }) + .collect::<HashMap<String, HashSet<WorkflowErrorType>>>(), + ) .build() .map_err(|err| PyValueError::new_err(format!("Invalid worker config: {}", err))) } diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index 578d3fa6..842bee3c 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -6,7 +6,16 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Sequence, Tuple +from typing import ( + TYPE_CHECKING, + Awaitable, + Callable, + List, + Optional, + Sequence, + Set, + Tuple, +) import google.protobuf.internal.containers from typing_extensions import TypeAlias @@ -48,6 +57,8 @@ class WorkerConfig: max_task_queue_activities_per_second: Optional[float] graceful_shutdown_period_millis: int use_worker_versioning: bool + nondeterminism_as_workflow_fail: bool + nondeterminism_as_workflow_fail_for_types: Set[str] class Worker: diff --git a/temporalio/worker/_replayer.py b/temporalio/worker/_replayer.py index 05ca56f1..d7d540fe 100644 --- a/temporalio/worker/_replayer.py +++ b/temporalio/worker/_replayer.py @@ -43,6 +43,7 @@ def __init__( interceptors: Sequence[Interceptor] = [], build_id: Optional[str] = None, identity: Optional[str] = None, + workflow_failure_exception_types: Sequence[Type[BaseException]] = [], debug_mode: bool = False, runtime: Optional[temporalio.runtime.Runtime] = None, disable_safe_workflow_eviction: bool = False, @@ -66,6 +67,7 @@ def __init__( interceptors=interceptors, 
build_id=build_id, identity=identity, + workflow_failure_exception_types=workflow_failure_exception_types, debug_mode=debug_mode, runtime=runtime, disable_safe_workflow_eviction=disable_safe_workflow_eviction, @@ -153,35 +155,6 @@ async def workflow_replay_iterator( An async iterator that returns replayed workflow results as they are replayed. """ - # Create bridge worker - task_queue = f"replay-{self._config['build_id']}" - runtime = self._config["runtime"] or temporalio.runtime.Runtime.default() - bridge_worker, pusher = temporalio.bridge.worker.Worker.for_replay( - runtime._core_runtime, - temporalio.bridge.worker.WorkerConfig( - namespace=self._config["namespace"], - task_queue=task_queue, - build_id=self._config["build_id"] or load_default_build_id(), - identity_override=self._config["identity"], - # All values below are ignored but required by Core - max_cached_workflows=2, - max_outstanding_workflow_tasks=2, - max_outstanding_activities=1, - max_outstanding_local_activities=1, - max_concurrent_workflow_task_polls=1, - nonsticky_to_sticky_poll_ratio=1, - max_concurrent_activity_task_polls=1, - no_remote_activities=True, - sticky_queue_schedule_to_start_timeout_millis=1000, - max_heartbeat_throttle_interval_millis=1000, - default_heartbeat_throttle_interval_millis=1000, - max_activities_per_second=None, - max_task_queue_activities_per_second=None, - graceful_shutdown_period_millis=0, - use_worker_versioning=False, - ), - ) - try: last_replay_failure: Optional[Exception] last_replay_complete = asyncio.Event() @@ -212,29 +185,62 @@ def on_eviction_hook( last_replay_failure = None last_replay_complete.set() - # Start the worker - workflow_worker_task = asyncio.create_task( - _WorkflowWorker( - bridge_worker=lambda: bridge_worker, + # Create worker referencing bridge worker + bridge_worker: temporalio.bridge.worker.Worker + task_queue = f"replay-{self._config['build_id']}" + runtime = self._config["runtime"] or temporalio.runtime.Runtime.default() + 
workflow_worker = _WorkflowWorker( + bridge_worker=lambda: bridge_worker, + namespace=self._config["namespace"], + task_queue=task_queue, + workflows=self._config["workflows"], + workflow_task_executor=self._config["workflow_task_executor"], + workflow_runner=self._config["workflow_runner"], + unsandboxed_workflow_runner=self._config["unsandboxed_workflow_runner"], + data_converter=self._config["data_converter"], + interceptors=self._config["interceptors"], + workflow_failure_exception_types=self._config[ + "workflow_failure_exception_types" + ], + debug_mode=self._config["debug_mode"], + metric_meter=runtime.metric_meter, + on_eviction_hook=on_eviction_hook, + disable_eager_activity_execution=False, + disable_safe_eviction=self._config["disable_safe_workflow_eviction"], + ) + # Create bridge worker + bridge_worker, pusher = temporalio.bridge.worker.Worker.for_replay( + runtime._core_runtime, + temporalio.bridge.worker.WorkerConfig( namespace=self._config["namespace"], task_queue=task_queue, - workflows=self._config["workflows"], - workflow_task_executor=self._config["workflow_task_executor"], - workflow_runner=self._config["workflow_runner"], - unsandboxed_workflow_runner=self._config[ - "unsandboxed_workflow_runner" - ], - data_converter=self._config["data_converter"], - interceptors=self._config["interceptors"], - debug_mode=self._config["debug_mode"], - metric_meter=runtime.metric_meter, - on_eviction_hook=on_eviction_hook, - disable_eager_activity_execution=False, - disable_safe_eviction=self._config[ - "disable_safe_workflow_eviction" - ], - ).run() + build_id=self._config["build_id"] or load_default_build_id(), + identity_override=self._config["identity"], + # Need to tell core whether we want to consider all + # non-determinism exceptions as workflow fail, and whether we do + # per workflow type + nondeterminism_as_workflow_fail=workflow_worker.nondeterminism_as_workflow_fail(), + 
nondeterminism_as_workflow_fail_for_types=workflow_worker.nondeterminism_as_workflow_fail_for_types(), + # All values below are ignored but required by Core + max_cached_workflows=2, + max_outstanding_workflow_tasks=2, + max_outstanding_activities=1, + max_outstanding_local_activities=1, + max_concurrent_workflow_task_polls=1, + nonsticky_to_sticky_poll_ratio=1, + max_concurrent_activity_task_polls=1, + no_remote_activities=True, + sticky_queue_schedule_to_start_timeout_millis=1000, + max_heartbeat_throttle_interval_millis=1000, + default_heartbeat_throttle_interval_millis=1000, + max_activities_per_second=None, + max_task_queue_activities_per_second=None, + graceful_shutdown_period_millis=0, + use_worker_versioning=False, + ), ) + # Start worker + workflow_worker_task = asyncio.create_task(workflow_worker.run()) # Yield iterator async def replay_iterator() -> AsyncIterator[WorkflowReplayResult]: @@ -301,6 +307,7 @@ class ReplayerConfig(TypedDict, total=False): interceptors: Sequence[Interceptor] build_id: Optional[str] identity: Optional[str] + workflow_failure_exception_types: Sequence[Type[BaseException]] debug_mode: bool runtime: Optional[temporalio.runtime.Runtime] disable_safe_workflow_eviction: bool diff --git a/temporalio/worker/_worker.py b/temporalio/worker/_worker.py index 7d40a28e..107229e5 100644 --- a/temporalio/worker/_worker.py +++ b/temporalio/worker/_worker.py @@ -73,6 +73,7 @@ def __init__( max_activities_per_second: Optional[float] = None, max_task_queue_activities_per_second: Optional[float] = None, graceful_shutdown_timeout: timedelta = timedelta(), + workflow_failure_exception_types: Sequence[Type[BaseException]] = [], shared_state_manager: Optional[SharedStateManager] = None, debug_mode: bool = False, disable_eager_activity_execution: bool = False, @@ -167,6 +168,13 @@ def __init__( graceful_shutdown_timeout: Amount of time after shutdown is called that activities are given to complete before their tasks are cancelled. 
+ workflow_failure_exception_types: The types of exceptions that, if a + workflow-thrown exception extends, will cause the + workflow/update to fail instead of suspending the workflow via + task failure. These are applied in addition to ones set on the + ``workflow.defn`` decorator. If ``Exception`` is set, it + effectively will fail a workflow/update in all user exception + cases. WARNING: This setting is experimental. shared_state_manager: Used for obtaining cross-process friendly synchronization primitives. This is required for non-async activities where the activity_executor is not a @@ -258,6 +266,7 @@ def __init__( max_activities_per_second=max_activities_per_second, max_task_queue_activities_per_second=max_task_queue_activities_per_second, graceful_shutdown_timeout=graceful_shutdown_timeout, + workflow_failure_exception_types=workflow_failure_exception_types, shared_state_manager=shared_state_manager, debug_mode=debug_mode, disable_eager_activity_execution=disable_eager_activity_execution, @@ -309,6 +318,7 @@ def __init__( unsandboxed_workflow_runner=unsandboxed_workflow_runner, data_converter=client_config["data_converter"], interceptors=interceptors, + workflow_failure_exception_types=workflow_failure_exception_types, debug_mode=debug_mode, disable_eager_activity_execution=disable_eager_activity_execution, metric_meter=runtime.metric_meter, @@ -366,6 +376,14 @@ def __init__( 1000 * graceful_shutdown_timeout.total_seconds() ), use_worker_versioning=use_worker_versioning, + # Need to tell core whether we want to consider all + # non-determinism exceptions as workflow fail, and whether we do + # per workflow type + nondeterminism_as_workflow_fail=self._workflow_worker is not None + and self._workflow_worker.nondeterminism_as_workflow_fail(), + nondeterminism_as_workflow_fail_for_types=self._workflow_worker.nondeterminism_as_workflow_fail_for_types() + if self._workflow_worker + else set(), ), ) @@ -605,6 +623,7 @@ class WorkerConfig(TypedDict, total=False): 
max_activities_per_second: Optional[float] max_task_queue_activities_per_second: Optional[float] graceful_shutdown_timeout: timedelta + workflow_failure_exception_types: Sequence[Type[BaseException]] shared_state_manager: Optional[SharedStateManager] debug_mode: bool disable_eager_activity_execution: bool diff --git a/temporalio/worker/_workflow.py b/temporalio/worker/_workflow.py index e4ffd4d7..634f59a4 100644 --- a/temporalio/worker/_workflow.py +++ b/temporalio/worker/_workflow.py @@ -7,7 +7,7 @@ import logging import os from datetime import timezone -from typing import Callable, Dict, List, MutableMapping, Optional, Sequence, Type +from typing import Callable, Dict, List, MutableMapping, Optional, Sequence, Set, Type import temporalio.activity import temporalio.api.common.v1 @@ -52,6 +52,7 @@ def __init__( unsandboxed_workflow_runner: WorkflowRunner, data_converter: temporalio.converter.DataConverter, interceptors: Sequence[Interceptor], + workflow_failure_exception_types: Sequence[Type[BaseException]], debug_mode: bool, disable_eager_activity_execution: bool, metric_meter: temporalio.common.MetricMeter, @@ -89,6 +90,7 @@ def __init__( self._extern_functions.update( **_WorkflowExternFunctions(__temporal_get_metric_meter=lambda: metric_meter) ) + self._workflow_failure_exception_types = workflow_failure_exception_types self._running_workflows: Dict[str, WorkflowInstance] = {} self._disable_eager_activity_execution = disable_eager_activity_execution self._on_eviction_hook = on_eviction_hook @@ -104,6 +106,11 @@ def __init__( # Keep track of workflows that could not be evicted self._could_not_evict_count = 0 + # Set the worker-level failure exception types into the runner + workflow_runner.set_worker_level_failure_exception_types( + workflow_failure_exception_types + ) + # Validate and build workflow dict self._workflows: Dict[str, temporalio.workflow._Definition] = {} self._dynamic_workflow: Optional[temporalio.workflow._Definition] = None @@ -389,8 +396,25 @@ 
def _create_workflow_instance( randomness_seed=start.randomness_seed, extern_functions=self._extern_functions, disable_eager_activity_execution=self._disable_eager_activity_execution, + worker_level_failure_exception_types=self._workflow_failure_exception_types, ) if defn.sandboxed: return self._workflow_runner.create_instance(det) else: return self._unsandboxed_workflow_runner.create_instance(det) + + def nondeterminism_as_workflow_fail(self) -> bool: + return any( + issubclass(temporalio.workflow.NondeterminismError, typ) + for typ in self._workflow_failure_exception_types + ) + + def nondeterminism_as_workflow_fail_for_types(self) -> Set[str]: + return set( + k + for k, v in self._workflows.items() + if any( + issubclass(temporalio.workflow.NondeterminismError, typ) + for typ in v.failure_exception_types + ) + ) diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 39536121..4a8493d6 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -107,6 +107,17 @@ def create_instance(self, det: WorkflowInstanceDetails) -> WorkflowInstance: """ raise NotImplementedError + def set_worker_level_failure_exception_types( + self, types: Sequence[Type[BaseException]] + ) -> None: + """Set worker-level failure exception types that will be used to + validate in the sandbox when calling ``prepare_workflow``. + + Args: + types: Exception types. 
+ """ + pass + @dataclass(frozen=True) class WorkflowInstanceDetails: @@ -120,6 +131,7 @@ class WorkflowInstanceDetails: randomness_seed: int extern_functions: Mapping[str, Callable] disable_eager_activity_execution: bool + worker_level_failure_exception_types: Sequence[Type[BaseException]] class WorkflowInstance(ABC): @@ -177,6 +189,9 @@ def __init__(self, det: WorkflowInstanceDetails) -> None: self._info = det.info self._extern_functions = det.extern_functions self._disable_eager_activity_execution = det.disable_eager_activity_execution + self._worker_level_failure_exception_types = ( + det.worker_level_failure_exception_types + ) self._primary_task: Optional[asyncio.Task[None]] = None self._time_ns = 0 self._cancel_requested = False @@ -315,10 +330,10 @@ def activate( # index). self._run_once(check_conditions=index == 1 or index == 2) except Exception as err: - # We want failure errors during activation, like those that can - # happen during payload conversion, to fail the workflow not the + # We want some errors during activation, like those that can happen + # during payload conversion, to be able to fail the workflow not the # task - if isinstance(err, temporalio.exceptions.FailureError): + if self._is_workflow_failure_exception(err): try: self._set_workflow_failure(err) except Exception as inner_err: @@ -515,12 +530,10 @@ async def run_update() -> None: if isinstance(err, temporalio.workflow.ReadOnlyContextError): self._current_activation_error = err return - # Temporal errors always fail the update. Other errors fail it during validation, but the task during - # handling. - if ( - isinstance(err, temporalio.exceptions.FailureError) - or not past_validation - ): + # Validation failures are always update failures. We reuse + # workflow failure logic to decide task failure vs update + # failure after validation. 
+ if not past_validation or self._is_workflow_failure_exception(err): if command is None: command = self._add_command() command.update_response.protocol_instance_id = ( @@ -1549,6 +1562,19 @@ def _convert_payloads( except Exception as err: raise RuntimeError("Failed decoding arguments") from err + def _is_workflow_failure_exception(self, err: BaseException) -> bool: + # An exception is a failure instead of a task fail if it's already a + # failure error or if it is an instance of any of the failure types in + # the worker or workflow-level setting + return ( + isinstance(err, temporalio.exceptions.FailureError) + or any(isinstance(err, typ) for typ in self._defn.failure_exception_types) + or any( + isinstance(err, typ) + for typ in self._worker_level_failure_exception_types + ) + ) + def _next_seq(self, type: str) -> int: seq = self._curr_seqs.get(type, 0) + 1 self._curr_seqs[type] = seq @@ -1705,14 +1731,14 @@ async def _run_top_level_workflow_function(self, coro: Awaitable[None]) -> None: err ): self._add_command().cancel_workflow_execution.SetInParent() - elif isinstance(err, temporalio.exceptions.FailureError): + elif self._is_workflow_failure_exception(err): # All other failure errors fail the workflow self._set_workflow_failure(err) else: # All other exceptions fail the task self._current_activation_error = err - def _set_workflow_failure(self, err: temporalio.exceptions.FailureError) -> None: + def _set_workflow_failure(self, err: BaseException) -> None: # All other failure errors fail the workflow failure = self._add_command().fail_workflow_execution.failure failure.SetInParent() diff --git a/temporalio/worker/workflow_sandbox/_in_sandbox.py b/temporalio/worker/workflow_sandbox/_in_sandbox.py index d918d19a..3091cef1 100644 --- a/temporalio/worker/workflow_sandbox/_in_sandbox.py +++ b/temporalio/worker/workflow_sandbox/_in_sandbox.py @@ -6,7 +6,7 @@ import dataclasses import logging -from typing import Type +from typing import Any, Type import 
temporalio.bridge.proto.workflow_activation import temporalio.bridge.proto.workflow_completion @@ -39,7 +39,37 @@ def __init__( # class. We can't use the definition that was given to us because it has # type hints and references to outside-of-sandbox types. new_defn = temporalio.workflow._Definition.must_from_class(workflow_class) - new_instance_details = dataclasses.replace(instance_details, defn=new_defn) + + # Also, we have to re-import the worker-level exception types, because + # some exceptions are not passthrough and therefore our issubclass fails + # because it'll be comparing out-of-sandbox types with in-sandbox types + exception_types = instance_details.worker_level_failure_exception_types + if exception_types: + # Copy first, then add in-sandbox types appended + exception_types = list(exception_types) + # Try to re-import each + for typ in instance_details.worker_level_failure_exception_types: + try: + class_hier = typ.__qualname__.split(".") + module = __import__(typ.__module__, fromlist=[class_hier[0]]) + reimported_type: Any = module + for name in class_hier: + reimported_type = getattr(reimported_type, name) + if not issubclass(reimported_type, BaseException): + raise TypeError( + f"Final imported type of {reimported_type} does not extend BaseException" + ) + exception_types.append(reimported_type) + except Exception as err: + raise TypeError( + f"Failed to re-import workflow exception failure type {typ} in sandbox" + ) from err + + new_instance_details = dataclasses.replace( + instance_details, + defn=new_defn, + worker_level_failure_exception_types=exception_types, + ) # Instantiate the runner and the instance self.instance = runner_class().create_instance(new_instance_details) diff --git a/temporalio/worker/workflow_sandbox/_runner.py b/temporalio/worker/workflow_sandbox/_runner.py index 86b32400..aa882b9e 100644 --- a/temporalio/worker/workflow_sandbox/_runner.py +++ b/temporalio/worker/workflow_sandbox/_runner.py @@ -7,7 +7,7 @@ from __future__ 
import annotations from datetime import datetime, timedelta, timezone -from typing import Any, Type +from typing import Any, Sequence, Type import temporalio.bridge.proto.workflow_activation import temporalio.bridge.proto.workflow_completion @@ -68,6 +68,7 @@ def __init__( super().__init__() self._runner_class = runner_class self._restrictions = restrictions + self._worker_level_failure_exception_types: Sequence[type[BaseException]] = [] def prepare_workflow(self, defn: temporalio.workflow._Definition) -> None: """Implements :py:meth:`WorkflowRunner.prepare_workflow`.""" @@ -83,6 +84,7 @@ def prepare_workflow(self, defn: temporalio.workflow._Definition) -> None: randomness_seed=-1, extern_functions={}, disable_eager_activity_execution=False, + worker_level_failure_exception_types=self._worker_level_failure_exception_types, ), ) @@ -90,6 +92,12 @@ def create_instance(self, det: WorkflowInstanceDetails) -> WorkflowInstance: """Implements :py:meth:`WorkflowRunner.create_instance`.""" return _Instance(det, self._runner_class, self._restrictions) + def set_worker_level_failure_exception_types( + self, types: Sequence[type[BaseException]] + ) -> None: + """Implements :py:meth:`WorkflowRunner.set_worker_level_failure_exception_types`.""" + self._worker_level_failure_exception_types = types + # Implements in_sandbox._ExternEnvironment. Some of these calls are called from # within the sandbox. diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 230c5b76..91d2486c 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -83,7 +83,10 @@ def defn(cls: ClassType) -> ClassType: @overload def defn( - *, name: Optional[str] = None, sandboxed: bool = True + *, + name: Optional[str] = None, + sandboxed: bool = True, + failure_exception_types: Sequence[Type[BaseException]] = [], ) -> Callable[[ClassType], ClassType]: ... 
@@ -101,6 +104,7 @@ def defn( name: Optional[str] = None, sandboxed: bool = True, dynamic: bool = False, + failure_exception_types: Sequence[Type[BaseException]] = [], ): """Decorator for workflow classes. @@ -116,6 +120,12 @@ def defn( dynamic: If true, this activity will be dynamic. Dynamic workflows have to accept a single 'Sequence[RawValue]' parameter. This cannot be set to true if name is present. + failure_exception_types: The types of exceptions that, if a + workflow-thrown exception extends, will cause the workflow/update to + fail instead of suspending the workflow via task failure. These are + applied in addition to ones set on the worker constructor. If + ``Exception`` is set, it effectively will fail a workflow/update in + all user exception cases. WARNING: This setting is experimental. """ def decorator(cls: ClassType) -> ClassType: @@ -124,6 +134,7 @@ def decorator(cls: ClassType) -> ClassType: cls, workflow_name=name or cls.__name__ if not dynamic else None, sandboxed=sandboxed, + failure_exception_types=failure_exception_types, ) return cls @@ -1162,6 +1173,7 @@ class _Definition: queries: Mapping[Optional[str], _QueryDefinition] updates: Mapping[Optional[str], _UpdateDefinition] sandboxed: bool + failure_exception_types: Sequence[Type[BaseException]] # Types loaded on post init if both are None arg_types: Optional[List[Type]] = None ret_type: Optional[Type] = None @@ -1200,7 +1212,11 @@ def must_from_run_fn(fn: Callable[..., Awaitable[Any]]) -> _Definition: @staticmethod def _apply_to_class( - cls: Type, *, workflow_name: Optional[str], sandboxed: bool + cls: Type, + *, + workflow_name: Optional[str], + sandboxed: bool, + failure_exception_types: Sequence[Type[BaseException]], ) -> None: # Check it's not being doubly applied if _Definition.from_class(cls): @@ -1323,6 +1339,7 @@ def _apply_to_class( queries=queries, updates=updates, sandboxed=sandboxed, + failure_exception_types=failure_exception_types, ) setattr(cls, 
"__temporal_workflow_definition", defn) setattr(run_fn, "__temporal_workflow_definition", defn) diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py index b8d30be6..9199cb25 100644 --- a/tests/helpers/__init__.py +++ b/tests/helpers/__init__.py @@ -25,6 +25,7 @@ def new_worker( task_queue: Optional[str] = None, workflow_runner: WorkflowRunner = SandboxedWorkflowRunner(), max_cached_workflows: int = 1000, + workflow_failure_exception_types: Sequence[Type[BaseException]] = [], **kwargs, ) -> Worker: return Worker( @@ -34,6 +35,7 @@ def new_worker( activities=activities, workflow_runner=workflow_runner, max_cached_workflows=max_cached_workflows, + workflow_failure_exception_types=workflow_failure_exception_types, **kwargs, ) diff --git a/tests/test_workflow.py b/tests/test_workflow.py index f37a7374..510fa18d 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -118,6 +118,7 @@ def test_workflow_defn_good(): ), }, sandboxed=True, + failure_exception_types=[], ) diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 3eac77e7..f60fc55d 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -12,6 +12,7 @@ from contextlib import contextmanager from dataclasses import dataclass from datetime import datetime, timedelta, timezone +from enum import IntEnum from typing import ( Any, Awaitable, @@ -22,6 +23,7 @@ Optional, Sequence, Tuple, + Type, Union, cast, ) @@ -51,6 +53,7 @@ WorkflowHandle, WorkflowQueryFailedError, WorkflowUpdateFailedError, + WorkflowUpdateHandle, ) from temporalio.common import ( RawValue, @@ -73,6 +76,7 @@ ApplicationError, CancelledError, ChildWorkflowError, + TemporalError, TimeoutError, WorkflowAlreadyStartedError, ) @@ -4253,3 +4257,257 @@ async def test_workflow_current_build_id_appropriately_set( assert bid == "1.1" await worker.shutdown() + + +class FailureTypesScenario(IntEnum): + THROW_CUSTOM_EXCEPTION = 1 + CAUSE_NON_DETERMINISM = 2 + WAIT_FOREVER = 3 + + +class 
FailureTypesCustomException(Exception): + ... + + +class FailureTypesWorkflowBase(ABC): + async def run(self, scenario: FailureTypesScenario) -> None: + await self._apply_scenario(scenario) + + @workflow.signal + async def signal(self, scenario: FailureTypesScenario) -> None: + await self._apply_scenario(scenario) + + @workflow.update + async def update(self, scenario: FailureTypesScenario) -> None: + # We have to rollover the task so the task failure isn't treated as + # non-acceptance + await asyncio.sleep(0.01) + await self._apply_scenario(scenario) + + async def _apply_scenario(self, scenario: FailureTypesScenario) -> None: + if scenario == FailureTypesScenario.THROW_CUSTOM_EXCEPTION: + raise FailureTypesCustomException("Intentional exception") + elif scenario == FailureTypesScenario.CAUSE_NON_DETERMINISM: + if not workflow.unsafe.is_replaying(): + await asyncio.sleep(0.01) + elif scenario == FailureTypesScenario.WAIT_FOREVER: + await workflow.wait_condition(lambda: False) + + +@workflow.defn +class FailureTypesUnconfiguredWorkflow(FailureTypesWorkflowBase): + @workflow.run + async def run(self, scenario: FailureTypesScenario) -> None: + await super().run(scenario) + + +@workflow.defn( + failure_exception_types=[FailureTypesCustomException, workflow.NondeterminismError] +) +class FailureTypesConfiguredExplicitlyWorkflow(FailureTypesWorkflowBase): + @workflow.run + async def run(self, scenario: FailureTypesScenario) -> None: + await super().run(scenario) + + +@workflow.defn(failure_exception_types=[Exception]) +class FailureTypesConfiguredInheritedWorkflow(FailureTypesWorkflowBase): + @workflow.run + async def run(self, scenario: FailureTypesScenario) -> None: + await super().run(scenario) + + +async def test_workflow_failure_types_configured( + client: Client, env: WorkflowEnvironment +): + if env.supports_time_skipping: + pytest.skip( + "Java test server: https://github.com/temporalio/sdk-java/issues/1903" + ) + + # Asserter for a single scenario + async def 
assert_scenario( + workflow: Type[FailureTypesWorkflowBase], + *, + expect_task_fail: bool, + fail_message_contains: str, + worker_level_failure_exception_type: Optional[Type[Exception]] = None, + workflow_scenario: Optional[FailureTypesScenario] = None, + signal_scenario: Optional[FailureTypesScenario] = None, + update_scenario: Optional[FailureTypesScenario] = None, + ) -> None: + logging.debug( + f"Asserting scenario %s", + { + "workflow": workflow, + "expect_task_fail": expect_task_fail, + "fail_message_contains": fail_message_contains, + "worker_level_failure_exception_type": worker_level_failure_exception_type, + "workflow_scenario": workflow_scenario, + "signal_scenario": signal_scenario, + "update_scenario": update_scenario, + }, + ) + async with new_worker( + client, + workflow, + max_cached_workflows=0, + workflow_failure_exception_types=[worker_level_failure_exception_type] + if worker_level_failure_exception_type + else [], + ) as worker: + # Start workflow + handle = await client.start_workflow( + workflow.run, + workflow_scenario or FailureTypesScenario.WAIT_FOREVER, + id=f"wf-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + if signal_scenario: + await handle.signal(workflow.signal, signal_scenario) + update_handle: Optional[WorkflowUpdateHandle[Any]] = None + if update_scenario: + update_handle = await handle.start_update( + workflow.update, update_scenario, id="my-update-1" + ) + + # Expect task or exception fail + if expect_task_fail: + + async def has_expected_task_fail() -> bool: + async for e in handle.fetch_history_events(): + if ( + e.HasField("workflow_task_failed_event_attributes") + and fail_message_contains + in e.workflow_task_failed_event_attributes.failure.message + ): + return True + return False + + await assert_eq_eventually(True, has_expected_task_fail) + else: + with pytest.raises(TemporalError) as err: + # Update does not throw on non-determinism, the workflow + # does instead + if ( + update_handle + and update_scenario + 
== FailureTypesScenario.THROW_CUSTOM_EXCEPTION + ): + await update_handle.result() + else: + await handle.result() + assert isinstance(err.value.cause, ApplicationError) + assert fail_message_contains in err.value.cause.message + + # Run a scenario + async def run_scenario( + workflow: Type[FailureTypesWorkflowBase], + scenario: FailureTypesScenario, + *, + expect_task_fail: bool = False, + worker_level_failure_exception_type: Optional[Type[Exception]] = None, + ) -> None: + # Run for workflow, signal, and update + fail_message_contains = ( + "Intentional exception" + if scenario == FailureTypesScenario.THROW_CUSTOM_EXCEPTION + else "Nondeterminism" + ) + await assert_scenario( + workflow, + expect_task_fail=expect_task_fail, + fail_message_contains=fail_message_contains, + worker_level_failure_exception_type=worker_level_failure_exception_type, + workflow_scenario=scenario, + ) + await assert_scenario( + workflow, + expect_task_fail=expect_task_fail, + fail_message_contains=fail_message_contains, + worker_level_failure_exception_type=worker_level_failure_exception_type, + signal_scenario=scenario, + ) + await assert_scenario( + workflow, + expect_task_fail=expect_task_fail, + fail_message_contains=fail_message_contains, + worker_level_failure_exception_type=worker_level_failure_exception_type, + update_scenario=scenario, + ) + + # Run all tasks concurrently + await asyncio.gather( + # When unconfigured completely, confirm task fails as normal + run_scenario( + FailureTypesUnconfiguredWorkflow, + FailureTypesScenario.THROW_CUSTOM_EXCEPTION, + expect_task_fail=True, + ), + run_scenario( + FailureTypesUnconfiguredWorkflow, + FailureTypesScenario.CAUSE_NON_DETERMINISM, + expect_task_fail=True, + ), + # When configured at the worker level explicitly, confirm not task fail + # but rather expected exceptions + run_scenario( + FailureTypesUnconfiguredWorkflow, + FailureTypesScenario.THROW_CUSTOM_EXCEPTION, + worker_level_failure_exception_type=FailureTypesCustomException, 
+ ), + run_scenario( + FailureTypesUnconfiguredWorkflow, + FailureTypesScenario.CAUSE_NON_DETERMINISM, + worker_level_failure_exception_type=workflow.NondeterminismError, + ), + # When configured at the worker level inherited + run_scenario( + FailureTypesUnconfiguredWorkflow, + FailureTypesScenario.THROW_CUSTOM_EXCEPTION, + worker_level_failure_exception_type=Exception, + ), + run_scenario( + FailureTypesUnconfiguredWorkflow, + FailureTypesScenario.CAUSE_NON_DETERMINISM, + worker_level_failure_exception_type=Exception, + ), + # When configured at the workflow level explicitly + run_scenario( + FailureTypesConfiguredExplicitlyWorkflow, + FailureTypesScenario.THROW_CUSTOM_EXCEPTION, + ), + run_scenario( + FailureTypesConfiguredExplicitlyWorkflow, + FailureTypesScenario.CAUSE_NON_DETERMINISM, + ), + # When configured at the workflow level inherited + run_scenario( + FailureTypesConfiguredInheritedWorkflow, + FailureTypesScenario.THROW_CUSTOM_EXCEPTION, + ), + run_scenario( + FailureTypesConfiguredInheritedWorkflow, + FailureTypesScenario.CAUSE_NON_DETERMINISM, + ), + ) + + +@workflow.defn(failure_exception_types=[Exception]) +class FailOnBadInputWorkflow: + @workflow.run + async def run(self, param: str) -> None: + pass + + +async def test_workflow_fail_on_bad_input(client: Client): + with pytest.raises(WorkflowFailureError) as err: + async with new_worker(client, FailOnBadInputWorkflow) as worker: + await client.execute_workflow( + "FailOnBadInputWorkflow", + 123, + id=f"wf-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + assert isinstance(err.value.cause, ApplicationError) + assert "Failed decoding arguments" in err.value.cause.message