Skip to content

Commit

Permalink
Fix only create task group and test task only if model has test
Browse files Browse the repository at this point in the history
  • Loading branch information
raffifu committed Sep 18, 2023
1 parent 27d1945 commit 06f2b38
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 17 deletions.
8 changes: 6 additions & 2 deletions cosmos/airflow/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,14 @@ def build_airflow_graph(
node=node,
execution_mode=execution_mode,
args=task_args,
use_name_as_task_id_prefix=test_behavior != TestBehavior.AFTER_EACH,
use_name_as_task_id_prefix=not (test_behavior == TestBehavior.AFTER_EACH and node.has_test),
)
if task_meta and node.resource_type != DbtResourceType.TEST:
if node.resource_type == DbtResourceType.MODEL and test_behavior == TestBehavior.AFTER_EACH:
if (
node.resource_type == DbtResourceType.MODEL
and node.has_test is True
and test_behavior == TestBehavior.AFTER_EACH
):
with TaskGroup(dag=dag, group_id=node.name, parent_group=task_group) as model_task_group:
task = create_airflow_task(task_meta, dag, task_group=model_task_group)
test_meta = create_test_task_metadata(
Expand Down
22 changes: 22 additions & 0 deletions cosmos/dbt/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class DbtNode:
file_path: Path
tags: list[str] = field(default_factory=lambda: [])
config: dict[str, Any] = field(default_factory=lambda: {})
has_test: bool = False


class DbtGraph:
Expand Down Expand Up @@ -262,6 +263,8 @@ def load_via_dbt_ls(self) -> None:
self.nodes = nodes
self.filtered_nodes = nodes

self.update_node_dependency()

logger.info("Total nodes: %i", len(self.nodes))
logger.info("Total filtered nodes: %i", len(self.nodes))

Expand Down Expand Up @@ -306,6 +309,8 @@ def load_via_custom_parser(self) -> None:
project_dir=self.project.dir, nodes=nodes, select=self.select, exclude=self.exclude
)

self.update_node_dependency()

logger.info("Total nodes: %i", len(self.nodes))
logger.info("Total filtered nodes: %i", len(self.nodes))

Expand Down Expand Up @@ -335,11 +340,28 @@ def load_from_dbt_manifest(self) -> None:
tags=node_dict["tags"],
config=node_dict["config"],
)

nodes[node.unique_id] = node

self.nodes = nodes
self.filtered_nodes = select_nodes(
project_dir=self.project.dir, nodes=nodes, select=self.select, exclude=self.exclude
)

self.update_node_dependency()

logger.info("Total nodes: %i", len(self.nodes))
logger.info("Total filtered nodes: %i", len(self.nodes))

def update_node_dependency(self) -> None:
"""
This will update the property `has_text` if node has `dbt` test
Updates in-place:
* self.filtered_nodes
"""
for _, node in self.filtered_nodes.items():
if node.resource_type == DbtResourceType.TEST:
for node_id in node.depends_on:
if node_id in self.filtered_nodes:
self.filtered_nodes[node_id].has_test = True
21 changes: 6 additions & 15 deletions tests/airflow/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
file_path=SAMPLE_PROJ_PATH / "gen2/models/parent.sql",
tags=["has_child"],
config={"materialized": "view"},
has_test=True,
)
test_parent_node = DbtNode(
name="test_parent", unique_id="test_parent", resource_type=DbtResourceType.TEST, depends_on=["parent"], file_path=""
Expand All @@ -49,15 +50,8 @@
tags=["nightly"],
config={"materialized": "table"},
)
test_child_node = DbtNode(
name="test_child",
unique_id="test_child",
resource_type=DbtResourceType.TEST,
depends_on=["child"],
file_path="",
)

sample_nodes_list = [parent_seed, parent_node, test_parent_node, child_node, test_child_node]
sample_nodes_list = [parent_seed, parent_node, test_parent_node, child_node]
sample_nodes = {node.unique_id: node for node in sample_nodes_list}


Expand Down Expand Up @@ -93,21 +87,18 @@ def test_build_airflow_graph_with_after_each():
"seed_parent_seed",
"parent.run",
"parent.test",
"child.run",
"child.test",
"child_run",
]

assert topological_sort == expected_sort
task_groups = dag.task_group_dict
assert len(task_groups) == 2
assert len(task_groups) == 1

assert task_groups["parent"].upstream_task_ids == {"seed_parent_seed"}
assert list(task_groups["parent"].children.keys()) == ["parent.run", "parent.test"]

assert task_groups["child"].upstream_task_ids == {"parent.test"}
assert list(task_groups["child"].children.keys()) == ["child.run", "child.test"]

assert len(dag.leaves) == 1
assert dag.leaves[0].task_id == "child.test"
assert dag.leaves[0].task_id == "child_run"


@pytest.mark.skipif(
Expand Down
29 changes: 29 additions & 0 deletions tests/dbt/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,3 +312,32 @@ def test_load_via_load_via_custom_parser(pipeline_name):
assert dbt_graph.nodes == dbt_graph.filtered_nodes
# the custom parser does not add dbt test nodes
assert len(dbt_graph.nodes) == 8


@patch("cosmos.dbt.graph.DbtGraph.update_node_dependency", return_value=None)
def test_update_node_dependency_called(mock_update_node_dependency):
dbt_project = DbtProject(name="jaffle_shop", root_dir=DBT_PROJECTS_ROOT_DIR, manifest_path=SAMPLE_MANIFEST)
dbt_graph = DbtGraph(project=dbt_project)
dbt_graph.load()

assert mock_update_node_dependency.called


def test_update_node_dependency_target_exist():
dbt_project = DbtProject(name="jaffle_shop", root_dir=DBT_PROJECTS_ROOT_DIR, manifest_path=SAMPLE_MANIFEST)
dbt_graph = DbtGraph(project=dbt_project)
dbt_graph.load()

for _, nodes in dbt_graph.nodes.items():
if nodes.resource_type == DbtResourceType.TEST:
for node_id in nodes.depends_on:
assert dbt_graph.nodes[node_id].has_test is True


def test_update_node_dependency_test_not_exist():
dbt_project = DbtProject(name="jaffle_shop", root_dir=DBT_PROJECTS_ROOT_DIR, manifest_path=SAMPLE_MANIFEST)
dbt_graph = DbtGraph(project=dbt_project, exclude=["config.materialized:test"])
dbt_graph.load_from_dbt_manifest()

for _, nodes in dbt_graph.filtered_nodes.items():
assert nodes.has_test is False

0 comments on commit 06f2b38

Please sign in to comment.