diff --git a/pyzeebe/worker/task_router.py b/pyzeebe/worker/task_router.py index 7a250af8..111b46c5 100644 --- a/pyzeebe/worker/task_router.py +++ b/pyzeebe/worker/task_router.py @@ -22,12 +22,18 @@ async def default_exception_handler(e: Exception, job: Job) -> None: class ZeebeTaskRouter: - def __init__(self, before: Optional[List[TaskDecorator]] = None, after: Optional[List[TaskDecorator]] = None): + def __init__( + self, + before: Optional[List[TaskDecorator]] = None, + after: Optional[List[TaskDecorator]] = None, + exception_handler: ExceptionHandler = default_exception_handler, + ): """ Args: before (List[TaskDecorator]): Decorators to be performed before each task after (List[TaskDecorator]): Decorators to be performed after each task """ + self._default_exception_handler = exception_handler self._before: List[TaskDecorator] = before or [] self._after: List[TaskDecorator] = after or [] self.tasks: List[Task] = [] @@ -35,7 +41,7 @@ def __init__(self, before: Optional[List[TaskDecorator]] = None, after: Optional def task( self, task_type: str, - exception_handler: ExceptionHandler = default_exception_handler, + exception_handler: Optional[ExceptionHandler] = None, variables_to_fetch: Optional[List[str]] = None, timeout_ms: int = 10000, max_jobs_to_activate: int = 32, @@ -67,11 +73,12 @@ def task( DuplicateTaskTypeError: If a task from the router already exists in the worker NoVariableNameGivenError: When single_value is set, but no variable_name is given """ + _exception_handler = exception_handler or self._default_exception_handler def task_wrapper(task_function: Callable): config = TaskConfig( task_type, - exception_handler, + _exception_handler, timeout_ms, max_jobs_to_activate, max_running_jobs, diff --git a/tests/unit/worker/task_router_test.py b/tests/unit/worker/task_router_test.py index 6321555a..dc7d0e94 100644 --- a/tests/unit/worker/task_router_test.py +++ b/tests/unit/worker/task_router_test.py @@ -19,6 +19,16 @@ def test_get_task(router: ZeebeTaskRouter, task: Task): assert found_task == task +def test_task_inherits_exception_handler(router: ZeebeTaskRouter, task: Task): + router._default_exception_handler = str + router.task(task.type)(task.original_function) + + found_task = router.get_task(task.type) + found_handler = found_task.config.exception_handler + + assert found_handler == str + + def test_get_fake_task(router: ZeebeTaskRouter): with pytest.raises(TaskNotFoundError): router.get_task(str(uuid4()))