Remove prefix of run and test tasks if test_behavior = TestBehavior.AFTER_EACH #524

Merged · 5 commits · Sep 11, 2023
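
As the updated test expectations below show, with `test_behavior = TestBehavior.AFTER_EACH` each model already renders inside a TaskGroup named after the model, so repeating the model name in the inner task ids produced redundant names like `parent.parent_run`. This PR drops the prefix in that case:

```python
# Rendered task ids under TestBehavior.AFTER_EACH, before -> after this PR
# (taken from the updated expectations in tests/airflow/test_graph.py):
#   parent.parent_run  -> parent.run
#   parent.parent_test -> parent.test
#   child.child_run    -> child.run
#   child.child_test   -> child.test
#   seed_parent_seed   -> seed_parent_seed  (seeds are not grouped, so unchanged)
```
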
31 changes: 23 additions & 8 deletions cosmos/airflow/graph.py
```diff
@@ -2,16 +2,15 @@

 from typing import Any, Callable

+from airflow.models import BaseOperator
 from airflow.models.dag import DAG
 from airflow.utils.task_group import TaskGroup

-from cosmos.constants import DbtResourceType, TestBehavior, ExecutionMode
+from cosmos.constants import DbtResourceType, ExecutionMode, TestBehavior
 from cosmos.core.airflow import get_airflow_task as create_airflow_task
 from cosmos.core.graph.entities import Task as TaskMetadata
 from cosmos.dbt.graph import DbtNode
 from cosmos.log import get_logger
-from airflow.models import BaseOperator


 logger = get_logger(__name__)

@@ -51,14 +50,18 @@ def calculate_leaves(tasks_ids: list[str], nodes: dict[str, DbtNode]) -> list[str]:
     return leaves


-def create_task_metadata(node: DbtNode, execution_mode: ExecutionMode, args: dict[str, Any]) -> TaskMetadata | None:
+def create_task_metadata(
+    node: DbtNode, execution_mode: ExecutionMode, args: dict[str, Any], use_name_as_task_id_prefix=True
+) -> TaskMetadata | None:
     """
     Create the metadata that will be used to instantiate the Airflow Task used to run the Dbt node.

     :param node: The dbt node which we desired to convert into an Airflow Task
     :param execution_mode: Where Cosmos should run each dbt task (e.g. ExecutionMode.LOCAL, ExecutionMode.KUBERNETES).
         Default is ExecutionMode.LOCAL.
     :param args: Arguments to be used to instantiate an Airflow Task
+    :param use_name_as_task_id_prefix: If resource_type is DbtResourceType.MODEL, determines whether to use
+        the node name as the task id prefix. If True, task_id = <node.name>_run; otherwise task_id = run.
     :returns: The metadata necessary to instantiate the source dbt node as an Airflow task.
     """
     dbt_resource_to_class = {
```
```diff
@@ -70,9 +73,16 @@ def create_task_metadata(node: DbtNode, execution_mode: ExecutionMode, args: dict[str, Any]) -> TaskMetadata | None:
         args = {**args, **{"models": node.name}}

     if hasattr(node.resource_type, "value") and node.resource_type in dbt_resource_to_class:
-        task_id_suffix = "run" if node.resource_type == DbtResourceType.MODEL else node.resource_type.value
+        if node.resource_type == DbtResourceType.MODEL:
+            if use_name_as_task_id_prefix:
+                task_id = f"{node.name}_run"
+            else:
+                task_id = "run"
+        else:
+            task_id = f"{node.name}_{node.resource_type.value}"

         task_metadata = TaskMetadata(
-            id=f"{node.name}_{task_id_suffix}",
+            id=task_id,
             operator_class=calculate_operator_class(
                 execution_mode=execution_mode, dbt_class=dbt_resource_to_class[node.resource_type]
             ),
```
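
The branching added above replaces the old one-line suffix rule. A minimal standalone sketch of the resulting naming behavior, with an invented helper `task_id_for` and a trimmed stand-in enum (illustration only, not the library code):

```python
from enum import Enum


class DbtResourceType(Enum):
    """Trimmed stand-in for cosmos.constants.DbtResourceType."""

    MODEL = "model"
    SEED = "seed"


def task_id_for(name: str, resource_type: DbtResourceType, use_name_as_task_id_prefix: bool = True) -> str:
    """Mirror the task-id branch above: only models may drop the name prefix."""
    if resource_type == DbtResourceType.MODEL:
        return f"{name}_run" if use_name_as_task_id_prefix else "run"
    return f"{name}_{resource_type.value}"


assert task_id_for("parent", DbtResourceType.MODEL) == "parent_run"  # default keeps the prefix
assert task_id_for("parent", DbtResourceType.MODEL, use_name_as_task_id_prefix=False) == "run"
assert task_id_for("my_seed", DbtResourceType.SEED) == "my_seed_seed"  # non-models keep <name>_<type>
```
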
```diff
@@ -157,13 +167,18 @@ def build_airflow_graph(
     # The exception are the test nodes, since it would be too slow to run test tasks individually.
     # If test_behaviour=="after_each", each model task will be bundled with a test task, using TaskGroup
     for node_id, node in nodes.items():
-        task_meta = create_task_metadata(node=node, execution_mode=execution_mode, args=task_args)
+        task_meta = create_task_metadata(
+            node=node,
+            execution_mode=execution_mode,
+            args=task_args,
+            use_name_as_task_id_prefix=test_behavior != TestBehavior.AFTER_EACH,
+        )
         if task_meta and node.resource_type != DbtResourceType.TEST:
             if node.resource_type == DbtResourceType.MODEL and test_behavior == TestBehavior.AFTER_EACH:
                 with TaskGroup(dag=dag, group_id=node.name, parent_group=task_group) as model_task_group:
                     task = create_airflow_task(task_meta, dag, task_group=model_task_group)
                     test_meta = create_test_task_metadata(
-                        f"{node.name}_test",
+                        "test",
                         execution_mode,
                         task_args=task_args,
                         model_name=node.name,
```
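
`use_name_as_task_id_prefix` is disabled exactly when AFTER_EACH wraps each model in a TaskGroup, because Airflow already prepends the group id to child task ids. A minimal plain-Airflow sketch of that prefixing, with hypothetical dag and task names and `EmptyOperator` standing in for the dbt operators:

```python
from datetime import datetime

from airflow.models import DAG
from airflow.operators.empty import EmptyOperator
from airflow.utils.task_group import TaskGroup

with DAG(dag_id="prefix_demo", start_date=datetime(2023, 1, 1), schedule=None) as dag:
    # group_id plays the role of the model name in build_airflow_graph
    with TaskGroup(group_id="parent") as model_task_group:
        run = EmptyOperator(task_id="run")
        test = EmptyOperator(task_id="test")
        run >> test

# TaskGroup prefixes its children with "<group_id>.", so the pre-PR ids would
# have rendered as the redundant "parent.parent_run" / "parent.parent_test".
assert run.task_id == "parent.run"
assert test.task_id == "parent.test"
```
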
54 changes: 39 additions & 15 deletions tests/airflow/test_graph.py
```diff
@@ -1,5 +1,5 @@
-from pathlib import Path
 from datetime import datetime
+from pathlib import Path
 from unittest.mock import patch

 import pytest
@@ -10,15 +10,14 @@
 from cosmos.airflow.graph import (
     build_airflow_graph,
     calculate_leaves,
+    calculate_operator_class,
     create_task_metadata,
     create_test_task_metadata,
-    calculate_operator_class,
 )
 from cosmos.config import ProfileConfig
-from cosmos.profiles import PostgresUserPasswordProfileMapping
-from cosmos.constants import ExecutionMode, DbtResourceType, TestBehavior
+from cosmos.constants import DbtResourceType, ExecutionMode, TestBehavior
 from cosmos.dbt.graph import DbtNode
-
+from cosmos.profiles import PostgresUserPasswordProfileMapping

 SAMPLE_PROJ_PATH = Path("/home/user/path/dbt-proj/")
```
```diff
@@ -92,23 +91,23 @@ def test_build_airflow_graph_with_after_each():
     topological_sort = [task.task_id for task in dag.topological_sort()]
     expected_sort = [
         "seed_parent_seed",
-        "parent.parent_run",
-        "parent.parent_test",
-        "child.child_run",
-        "child.child_test",
+        "parent.run",
+        "parent.test",
+        "child.run",
+        "child.test",
     ]
     assert topological_sort == expected_sort
     task_groups = dag.task_group_dict
     assert len(task_groups) == 2

     assert task_groups["parent"].upstream_task_ids == {"seed_parent_seed"}
-    assert list(task_groups["parent"].children.keys()) == ["parent.parent_run", "parent.parent_test"]
+    assert list(task_groups["parent"].children.keys()) == ["parent.run", "parent.test"]

-    assert task_groups["child"].upstream_task_ids == {"parent.parent_test"}
-    assert list(task_groups["child"].children.keys()) == ["child.child_run", "child.child_test"]
+    assert task_groups["child"].upstream_task_ids == {"parent.test"}
+    assert list(task_groups["child"].children.keys()) == ["child.run", "child.test"]

     assert len(dag.leaves) == 1
-    assert dag.leaves[0].task_id == "child.child_test"
+    assert dag.leaves[0].task_id == "child.test"


 @pytest.mark.skipif(
```
```diff
@@ -232,7 +231,24 @@ def test_create_task_metadata_model(caplog):
     assert metadata.arguments == {"models": "my_model"}


-def test_create_task_metadata_seed(caplog):
+def test_create_task_metadata_model_use_name_as_task_id_prefix(caplog):
+    child_node = DbtNode(
+        name="my_model",
+        unique_id="my_folder.my_model",
+        resource_type=DbtResourceType.MODEL,
+        depends_on=[],
+        file_path=Path(""),
+        tags=[],
+        config={},
+    )
+    metadata = create_task_metadata(
+        child_node, execution_mode=ExecutionMode.LOCAL, args={}, use_name_as_task_id_prefix=False
+    )
+    assert metadata.id == "run"
+
+
+@pytest.mark.parametrize("use_name_as_task_id_prefix", (None, True, False))
+def test_create_task_metadata_seed(caplog, use_name_as_task_id_prefix):
     sample_node = DbtNode(
         name="my_seed",
         unique_id="my_folder.my_seed",
@@ -242,7 +258,15 @@ def test_create_task_metadata_seed(caplog):
         tags=[],
         config={},
     )
-    metadata = create_task_metadata(sample_node, execution_mode=ExecutionMode.DOCKER, args={})
+    if use_name_as_task_id_prefix is None:
+        metadata = create_task_metadata(sample_node, execution_mode=ExecutionMode.DOCKER, args={})
+    else:
+        metadata = create_task_metadata(
+            sample_node,
+            execution_mode=ExecutionMode.DOCKER,
+            args={},
+            use_name_as_task_id_prefix=use_name_as_task_id_prefix,
+        )
     assert metadata.id == "my_seed_seed"
     assert metadata.operator_class == "cosmos.operators.docker.DbtSeedDockerOperator"
     assert metadata.arguments == {"models": "my_seed"}
```
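
For context, a sketch of how a user opts into this behavior through the public Cosmos entry points; the project path, connection id, and profile values are illustrative, and the constructor arguments are assumed from the Cosmos 1.x API rather than taken from this PR:

```python
from datetime import datetime

from cosmos import DbtDag, ProfileConfig, ProjectConfig, RenderConfig
from cosmos.constants import TestBehavior
from cosmos.profiles import PostgresUserPasswordProfileMapping

my_dbt_dag = DbtDag(
    dag_id="my_dbt_dag",
    start_date=datetime(2023, 1, 1),
    schedule=None,
    project_config=ProjectConfig("/home/user/path/dbt-proj/"),
    profile_config=ProfileConfig(
        profile_name="default",
        target_name="dev",
        profile_mapping=PostgresUserPasswordProfileMapping(
            conn_id="example_postgres", profile_args={"schema": "public"}
        ),
    ),
    # AFTER_EACH pairs each model's run task with a test task inside a
    # "<model>" TaskGroup; after this PR they render as "<model>.run" and
    # "<model>.test" instead of "<model>.<model>_run" etc.
    render_config=RenderConfig(test_behavior=TestBehavior.AFTER_EACH),
)
```
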