diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index d03af5f1f..763812bbf 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -171,10 +171,14 @@ def build_airflow_graph( node=node, execution_mode=execution_mode, args=task_args, - use_name_as_task_id_prefix=test_behavior != TestBehavior.AFTER_EACH, + use_name_as_task_id_prefix=not (test_behavior == TestBehavior.AFTER_EACH and node.has_test), ) if task_meta and node.resource_type != DbtResourceType.TEST: - if node.resource_type == DbtResourceType.MODEL and test_behavior == TestBehavior.AFTER_EACH: + if ( + node.resource_type == DbtResourceType.MODEL + and node.has_test is True + and test_behavior == TestBehavior.AFTER_EACH + ): with TaskGroup(dag=dag, group_id=node.name, parent_group=task_group) as model_task_group: task = create_airflow_task(task_meta, dag, task_group=model_task_group) test_meta = create_test_task_metadata( diff --git a/cosmos/dbt/graph.py b/cosmos/dbt/graph.py index 07a55ad79..f83f490e2 100644 --- a/cosmos/dbt/graph.py +++ b/cosmos/dbt/graph.py @@ -50,6 +50,7 @@ class DbtNode: file_path: Path tags: list[str] = field(default_factory=lambda: []) config: dict[str, Any] = field(default_factory=lambda: {}) + has_test: bool = False class DbtGraph: @@ -262,6 +263,8 @@ def load_via_dbt_ls(self) -> None: self.nodes = nodes self.filtered_nodes = nodes + self.update_node_dependency() + logger.info("Total nodes: %i", len(self.nodes)) logger.info("Total filtered nodes: %i", len(self.nodes)) @@ -306,6 +309,8 @@ def load_via_custom_parser(self) -> None: project_dir=self.project.dir, nodes=nodes, select=self.select, exclude=self.exclude ) + self.update_node_dependency() + logger.info("Total nodes: %i", len(self.nodes)) logger.info("Total filtered nodes: %i", len(self.nodes)) @@ -335,11 +340,28 @@ def load_from_dbt_manifest(self) -> None: tags=node_dict["tags"], config=node_dict["config"], ) + nodes[node.unique_id] = node self.nodes = nodes self.filtered_nodes = select_nodes( 
project_dir=self.project.dir, nodes=nodes, select=self.select, exclude=self.exclude ) + + self.update_node_dependency() + logger.info("Total nodes: %i", len(self.nodes)) logger.info("Total filtered nodes: %i", len(self.nodes)) + + def update_node_dependency(self) -> None: + """ + This will update the property `has_test` if node has `dbt` test + + Updates in-place: + * self.filtered_nodes + """ + for _, node in self.filtered_nodes.items(): + if node.resource_type == DbtResourceType.TEST: + for node_id in node.depends_on: + if node_id in self.filtered_nodes: + self.filtered_nodes[node_id].has_test = True diff --git a/tests/airflow/test_graph.py b/tests/airflow/test_graph.py index 7b539bb5b..d05ee1c05 100644 --- a/tests/airflow/test_graph.py +++ b/tests/airflow/test_graph.py @@ -36,6 +36,7 @@ file_path=SAMPLE_PROJ_PATH / "gen2/models/parent.sql", tags=["has_child"], config={"materialized": "view"}, + has_test=True, ) test_parent_node = DbtNode( name="test_parent", unique_id="test_parent", resource_type=DbtResourceType.TEST, depends_on=["parent"], file_path="" @@ -49,15 +50,8 @@ tags=["nightly"], config={"materialized": "table"}, ) -test_child_node = DbtNode( - name="test_child", - unique_id="test_child", - resource_type=DbtResourceType.TEST, - depends_on=["child"], - file_path="", -) -sample_nodes_list = [parent_seed, parent_node, test_parent_node, child_node, test_child_node] +sample_nodes_list = [parent_seed, parent_node, test_parent_node, child_node] sample_nodes = {node.unique_id: node for node in sample_nodes_list} @@ -93,21 +87,18 @@ def test_build_airflow_graph_with_after_each(): "seed_parent_seed", "parent.run", "parent.test", - "child.run", - "child.test", + "child_run", ] + assert topological_sort == expected_sort task_groups = dag.task_group_dict - assert len(task_groups) == 2 + assert len(task_groups) == 1 assert task_groups["parent"].upstream_task_ids == {"seed_parent_seed"} assert list(task_groups["parent"].children.keys()) == ["parent.run", 
"parent.test"] - assert task_groups["child"].upstream_task_ids == {"parent.test"} - assert list(task_groups["child"].children.keys()) == ["child.run", "child.test"] - assert len(dag.leaves) == 1 - assert dag.leaves[0].task_id == "child.test" + assert dag.leaves[0].task_id == "child_run" @pytest.mark.skipif( diff --git a/tests/dbt/test_graph.py b/tests/dbt/test_graph.py index 3a2a3eaa9..4dbdeb411 100644 --- a/tests/dbt/test_graph.py +++ b/tests/dbt/test_graph.py @@ -312,3 +312,32 @@ def test_load_via_load_via_custom_parser(pipeline_name): assert dbt_graph.nodes == dbt_graph.filtered_nodes # the custom parser does not add dbt test nodes assert len(dbt_graph.nodes) == 8 + + +@patch("cosmos.dbt.graph.DbtGraph.update_node_dependency", return_value=None) +def test_update_node_dependency_called(mock_update_node_dependency): + dbt_project = DbtProject(name="jaffle_shop", root_dir=DBT_PROJECTS_ROOT_DIR, manifest_path=SAMPLE_MANIFEST) + dbt_graph = DbtGraph(project=dbt_project) + dbt_graph.load() + + assert mock_update_node_dependency.called + + +def test_update_node_dependency_target_exist(): + dbt_project = DbtProject(name="jaffle_shop", root_dir=DBT_PROJECTS_ROOT_DIR, manifest_path=SAMPLE_MANIFEST) + dbt_graph = DbtGraph(project=dbt_project) + dbt_graph.load() + + for _, nodes in dbt_graph.nodes.items(): + if nodes.resource_type == DbtResourceType.TEST: + for node_id in nodes.depends_on: + assert dbt_graph.nodes[node_id].has_test is True + + +def test_update_node_dependency_test_not_exist(): + dbt_project = DbtProject(name="jaffle_shop", root_dir=DBT_PROJECTS_ROOT_DIR, manifest_path=SAMPLE_MANIFEST) + dbt_graph = DbtGraph(project=dbt_project, exclude=["config.materialized:test"]) + dbt_graph.load_from_dbt_manifest() + + for _, nodes in dbt_graph.filtered_nodes.items(): + assert nodes.has_test is False