diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index caf992b78..d03af5f1f 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -83,7 +83,7 @@ def create_test_task_metadata( def create_task_metadata( - node: DbtNode, execution_mode: ExecutionMode, args: dict[str, Any], use_task_group: bool = False + node: DbtNode, execution_mode: ExecutionMode, args: dict[str, Any], use_name_as_task_id_prefix: bool = True ) -> TaskMetadata | None: """ Create the metadata that will be used to instantiate the Airflow Task used to run the Dbt node. @@ -106,9 +106,9 @@ def create_task_metadata( if hasattr(node.resource_type, "value") and node.resource_type in dbt_resource_to_class: if node.resource_type == DbtResourceType.MODEL: - task_id = f"{node.name}_run" - - if use_task_group is True: + if use_name_as_task_id_prefix: + task_id = f"{node.name}_run" + else: task_id = "run" else: task_id = f"{node.name}_{node.resource_type.value}" @@ -167,17 +167,14 @@ def build_airflow_graph( # The exception are the test nodes, since it would be too slow to run test tasks individually. # If test_behaviour=="after_each", each model task will be bundled with a test task, using TaskGroup for node_id, node in nodes.items(): - use_task_group = ( - node.resource_type == DbtResourceType.MODEL - and test_behavior == TestBehavior.AFTER_EACH - and node.has_test is True - ) task_meta = create_task_metadata( - node=node, execution_mode=execution_mode, args=task_args, use_task_group=use_task_group + node=node, + execution_mode=execution_mode, + args=task_args, + use_name_as_task_id_prefix=test_behavior != TestBehavior.AFTER_EACH, ) - if task_meta and node.resource_type != DbtResourceType.TEST: - if use_task_group is True: + if node.resource_type == DbtResourceType.MODEL and test_behavior == TestBehavior.AFTER_EACH: with TaskGroup(dag=dag, group_id=node.name, parent_group=task_group) as model_task_group: task = create_airflow_task(task_meta, dag, task_group=model_task_group) test_meta = create_test_task_metadata( diff --git a/cosmos/dbt/graph.py b/cosmos/dbt/graph.py index 1ad6c1737..81a81efb4 100644 --- a/cosmos/dbt/graph.py +++ b/cosmos/dbt/graph.py @@ -50,7 +50,6 @@ class DbtNode: file_path: Path tags: list[str] = field(default_factory=lambda: []) config: dict[str, Any] = field(default_factory=lambda: {}) - has_test: bool = False class DbtGraph: @@ -265,8 +264,6 @@ def load_via_dbt_ls(self) -> None: self.nodes = nodes self.filtered_nodes = nodes - self.update_node_dependency() - logger.info("Total nodes: %i", len(self.nodes)) logger.info("Total filtered nodes: %i", len(self.nodes)) @@ -311,8 +308,6 @@ def load_via_custom_parser(self) -> None: project_dir=self.project.dir, nodes=nodes, select=self.select, exclude=self.exclude ) - self.update_node_dependency() - logger.info("Total nodes: %i", len(self.nodes)) logger.info("Total filtered nodes: %i", len(self.nodes)) @@ -342,28 +337,11 @@ def load_from_dbt_manifest(self) -> None: tags=node_dict["tags"], config=node_dict["config"], ) - nodes[node.unique_id] = node self.nodes = nodes self.filtered_nodes = select_nodes( project_dir=self.project.dir, nodes=nodes, select=self.select, exclude=self.exclude ) - - self.update_node_dependency() - logger.info("Total nodes: %i", len(self.nodes)) logger.info("Total filtered nodes: %i", len(self.nodes)) - - def update_node_dependency(self) -> None: - """ - This will update the property `has_text` if node has `dbt` test - - Updates in-place: - * self.filtered_nodes - """ - for _, node in self.filtered_nodes.items(): - if node.resource_type == DbtResourceType.TEST: - for node_id in node.depends_on: - if node_id in self.filtered_nodes: - self.filtered_nodes[node_id].has_test = True diff --git a/tests/airflow/test_graph.py b/tests/airflow/test_graph.py index bd3777209..7b539bb5b 100644 --- a/tests/airflow/test_graph.py +++ b/tests/airflow/test_graph.py @@ -36,7 +36,6 @@ file_path=SAMPLE_PROJ_PATH / "gen2/models/parent.sql", tags=["has_child"], config={"materialized": "view"}, - has_test=True, ) test_parent_node = DbtNode( name="test_parent", unique_id="test_parent", resource_type=DbtResourceType.TEST, depends_on=["parent"], file_path="" @@ -50,8 +49,15 @@ tags=["nightly"], config={"materialized": "table"}, ) +test_child_node = DbtNode( + name="test_child", + unique_id="test_child", + resource_type=DbtResourceType.TEST, + depends_on=["child"], + file_path="", +) -sample_nodes_list = [parent_seed, parent_node, test_parent_node, child_node] +sample_nodes_list = [parent_seed, parent_node, test_parent_node, child_node, test_child_node] sample_nodes = {node.unique_id: node for node in sample_nodes_list} @@ -87,18 +93,21 @@ def test_build_airflow_graph_with_after_each(): "seed_parent_seed", "parent.run", "parent.test", - "child_run", + "child.run", + "child.test", ] - assert topological_sort == expected_sort task_groups = dag.task_group_dict - assert len(task_groups) == 1 + assert len(task_groups) == 2 assert task_groups["parent"].upstream_task_ids == {"seed_parent_seed"} assert list(task_groups["parent"].children.keys()) == ["parent.run", "parent.test"] + assert task_groups["child"].upstream_task_ids == {"parent.test"} + assert list(task_groups["child"].children.keys()) == ["child.run", "child.test"] + assert len(dag.leaves) == 1 - assert dag.leaves[0].task_id == "child_run" + assert dag.leaves[0].task_id == "child.test" @pytest.mark.skipif( @@ -222,7 +231,7 @@ def test_create_task_metadata_model(caplog): assert metadata.arguments == {"models": "my_model"} -def test_create_task_metadata_model_use_task_group(caplog): +def test_create_task_metadata_model_use_name_as_task_id_prefix(caplog): child_node = DbtNode( name="my_model", unique_id="my_folder.my_model", @@ -232,12 +241,14 @@ def test_create_task_metadata_model_use_task_group(caplog): tags=[], config={}, ) - metadata = create_task_metadata(child_node, execution_mode=ExecutionMode.LOCAL, args={}, use_task_group=True) + metadata = create_task_metadata( + child_node, execution_mode=ExecutionMode.LOCAL, args={}, use_name_as_task_id_prefix=False + ) assert metadata.id == "run" -@pytest.mark.parametrize("use_task_group", (None, True, False)) -def test_create_task_metadata_seed(caplog, use_task_group): +@pytest.mark.parametrize("use_name_as_task_id_prefix", (None, True, False)) +def test_create_task_metadata_seed(caplog, use_name_as_task_id_prefix): sample_node = DbtNode( name="my_seed", unique_id="my_folder.my_seed", @@ -247,14 +258,14 @@ def test_create_task_metadata_seed(caplog, use_task_group): tags=[], config={}, ) - if use_task_group is None: + if use_name_as_task_id_prefix is None: metadata = create_task_metadata(sample_node, execution_mode=ExecutionMode.DOCKER, args={}) else: metadata = create_task_metadata( sample_node, execution_mode=ExecutionMode.DOCKER, args={}, - use_task_group=use_task_group, + use_name_as_task_id_prefix=use_name_as_task_id_prefix, ) assert metadata.id == "my_seed_seed" assert metadata.operator_class == "cosmos.operators.docker.DbtSeedDockerOperator" diff --git a/tests/dbt/test_graph.py b/tests/dbt/test_graph.py index 317dce0bb..2547a4885 100644 --- a/tests/dbt/test_graph.py +++ b/tests/dbt/test_graph.py @@ -361,32 +361,3 @@ def test_load_via_load_via_custom_parser(pipeline_name): assert dbt_graph.nodes == dbt_graph.filtered_nodes # the custom parser does not add dbt test nodes assert len(dbt_graph.nodes) == 8 - - -@patch("cosmos.dbt.graph.DbtGraph.update_node_dependency", return_value=None) -def test_update_node_dependency_called(mock_update_node_dependency): - dbt_project = DbtProject(name="jaffle_shop", root_dir=DBT_PROJECTS_ROOT_DIR, manifest_path=SAMPLE_MANIFEST) - dbt_graph = DbtGraph(project=dbt_project) - dbt_graph.load() - - assert mock_update_node_dependency.called - - -def test_update_node_dependency_target_exist(): - dbt_project = DbtProject(name="jaffle_shop", root_dir=DBT_PROJECTS_ROOT_DIR, manifest_path=SAMPLE_MANIFEST) - dbt_graph = DbtGraph(project=dbt_project) - dbt_graph.load() - - for _, nodes in dbt_graph.nodes.items(): - if nodes.resource_type == DbtResourceType.TEST: - for node_id in nodes.depends_on: - assert dbt_graph.nodes[node_id].has_test is True - - -def test_update_node_dependency_test_not_exist(): - dbt_project = DbtProject(name="jaffle_shop", root_dir=DBT_PROJECTS_ROOT_DIR, manifest_path=SAMPLE_MANIFEST) - dbt_graph = DbtGraph(project=dbt_project, exclude=["config.materialized:test"]) - dbt_graph.load_from_dbt_manifest() - - for _, nodes in dbt_graph.filtered_nodes.items(): - assert nodes.has_test is False