diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index f507b03ac..a18f148a3 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -1,5 +1,6 @@ from __future__ import annotations +from datetime import timedelta from typing import Any, Callable, Union from airflow.models import BaseOperator @@ -166,6 +167,8 @@ def create_task_metadata( task_id = f"{node.name}_run" if use_task_group is True: task_id = "run" + if "cosmos_task_timeout" in node.config.keys(): + args["execution_timeout"] = timedelta(seconds=int(node.config["cosmos_task_timeout"])) elif node.resource_type == DbtResourceType.SOURCE: if (source_rendering_behavior == SourceRenderingBehavior.NONE) or ( source_rendering_behavior == SourceRenderingBehavior.WITH_TESTS_OR_FRESHNESS diff --git a/docs/configuration/index.rst b/docs/configuration/index.rst index f6e60f61b..39895548e 100644 --- a/docs/configuration/index.rst +++ b/docs/configuration/index.rst @@ -27,3 +27,4 @@ Cosmos offers a number of configuration options to customize its behavior. For m Compiled SQL Logging Caching + Task Timeout diff --git a/docs/configuration/task-timeout.rst b/docs/configuration/task-timeout.rst new file mode 100644 index 000000000..16dd844ca --- /dev/null +++ b/docs/configuration/task-timeout.rst @@ -0,0 +1,18 @@ +.. _task-timeout: + +Task Timeout +================ + +In Airflow, the ``execution_timeout`` parameter allows you to set the maximum runtime for a Task. +In Cosmos, you can apply an ``execution_timeout`` to each dbt model task by specifying a ``cosmos_task_timeout`` in the model’s configuration, which sets a runtime threshold to trigger a timeout error if exceeded. + +Example: + +.. code-block:: yaml + + version: 2 + + models: + - name: my_model + config: + - cosmos_task_timeout: 600 # Specify in seconds. diff --git a/tests/airflow/test_graph.py b/tests/airflow/test_graph.py index 1bd8cab35..d2fd11690 100644 --- a/tests/airflow/test_graph.py +++ b/tests/airflow/test_graph.py @@ -1,5 +1,5 @@ import os -from datetime import datetime +from datetime import datetime, timedelta from pathlib import Path from unittest.mock import MagicMock, patch @@ -568,6 +568,22 @@ def test_create_task_metadata_snapshot(caplog): assert metadata.arguments == {"models": "my_snapshot"} +def test_create_task_metadata_timeout(): + sample_node = DbtNode( + unique_id=f"{DbtResourceType.MODEL.value}.my_folder.my_model", + resource_type=DbtResourceType.MODEL, + depends_on=[], + file_path="", + tags=[], + config={"cosmos_task_timeout": 1}, + ) + metadata = create_task_metadata( + sample_node, execution_mode=ExecutionMode.LOCAL, args={}, dbt_dag_task_group_identifier="" + ) + assert "execution_timeout" in metadata.arguments + assert metadata.arguments["execution_timeout"] == timedelta(seconds=1) + + @pytest.mark.parametrize( "node_type,node_unique_id,test_indirect_selection,additional_arguments", [