Allow users to customize how DbtResource should be rendered in Airflow
tatiana committed Sep 1, 2023
1 parent 001a7af commit 7b47977
Showing 6 changed files with 84 additions and 40 deletions.
87 changes: 58 additions & 29 deletions cosmos/airflow/graph.py
@@ -2,16 +2,15 @@

from typing import Any, Callable

from airflow.models import BaseOperator
from airflow.models.dag import DAG
from airflow.utils.task_group import TaskGroup

from cosmos.constants import DbtResourceType, TestBehavior, ExecutionMode
from cosmos.constants import DbtResourceType, TestBehavior, ExecutionMode, TESTABLE_DBT_RESOURCES
from cosmos.core.airflow import get_airflow_task as create_airflow_task
from cosmos.core.graph.entities import Task as TaskMetadata
from cosmos.dataset import get_dbt_dataset
from cosmos.dbt.graph import DbtNode
from cosmos.log import get_logger
from airflow.models import BaseOperator


logger = get_logger(__name__)
@@ -81,7 +80,10 @@ def create_task_metadata(node: DbtNode, execution_mode: ExecutionMode, args: dic
        )
        return task_metadata
    else:
        logger.warning(f"Unsupported resource type {node.resource_type} (node {node.unique_id}).")
        logger.warning(
            f"Unavailable conversion function for <{node.resource_type}> (node <{node.unique_id}>). "
            "Define a converter function using render_config.dbt_resource_converter."
        )
        return None


Expand Down Expand Up @@ -117,17 +119,51 @@ def create_test_task_metadata(
    )


def generate_task_or_group(
    dag: DAG,
    task_group: TaskGroup | None,
    node: DbtNode,
    execution_mode: ExecutionMode,
    task_args: dict[str, Any],
    test_behavior: TestBehavior,
    on_warning_callback: Callable[..., Any] | None,
    **kwargs: Any,
) -> BaseOperator | TaskGroup | None:
    task_or_group: BaseOperator | TaskGroup | None = None
    task_meta = create_task_metadata(node=node, execution_mode=execution_mode, args=task_args)

    # In most cases, we'll map one DBT node to one Airflow task
    # The exception is the test nodes, since it would be too slow to run test tasks individually.
    # If test_behavior=="after_each", each model task will be bundled with a test task, using TaskGroup
    if task_meta and node.resource_type != DbtResourceType.TEST:
        if node.resource_type in TESTABLE_DBT_RESOURCES and test_behavior == TestBehavior.AFTER_EACH:
            with TaskGroup(dag=dag, group_id=node.name, parent_group=task_group) as model_task_group:
                task = create_airflow_task(task_meta, dag, task_group=model_task_group)
                test_meta = create_test_task_metadata(
                    f"{node.name}_test",
                    execution_mode,
                    task_args=task_args,
                    model_name=node.name,
                    on_warning_callback=on_warning_callback,
                )
                test_task = create_airflow_task(test_meta, dag, task_group=model_task_group)
                task >> test_task
                task_or_group = model_task_group
        else:
            task_or_group = create_airflow_task(task_meta, dag, task_group=task_group)
    return task_or_group


def build_airflow_graph(
    nodes: dict[str, DbtNode],
    dag: DAG,  # Airflow-specific - parent DAG where to associate tasks and (optional) task groups
    execution_mode: ExecutionMode,  # Cosmos-specific - decides which class to use
    task_args: dict[str, Any],  # Cosmos/DBT - used to instantiate tasks
    test_behavior: TestBehavior,  # Cosmos-specific: how to inject tests to Airflow DAG
    dbt_project_name: str,  # DBT / Cosmos - used to name test task if mode is after_all
    conn_id: str,  # Cosmos, dataset URI
    task_group: TaskGroup | None = None,
    on_warning_callback: Callable[..., Any] | None = None,  # argument specific to the DBT test command
    emit_datasets: bool = True,  # Cosmos
    dbt_resource_converter: dict[DbtResourceType, Callable[..., Any]] | None = None,
) -> None:
"""
Instantiate dbt `nodes` as Airflow tasks within the given `task_group` (optional) or `dag` (mandatory).
Expand Down Expand Up @@ -159,35 +195,28 @@ def build_airflow_graph(
and “test_results” of type List.
:param emit_datasets: Decides if Cosmos should add outlets to model classes or not.
"""
    dbt_resource_converter = dbt_resource_converter or {}
    tasks_map = {}
    task_or_group: TaskGroup | BaseOperator

    # In most cases, we'll map one DBT node to one Airflow task
    # The exception is the test nodes, since it would be too slow to run test tasks individually.
    # If test_behavior=="after_each", each model task will be bundled with a test task, using TaskGroup
    for node_id, node in nodes.items():
        task_meta = create_task_metadata(node=node, execution_mode=execution_mode, args=task_args)
        if emit_datasets:
            task_args["outlets"] = [get_dbt_dataset(conn_id, dbt_project_name, node.name)]
        if task_meta and node.resource_type != DbtResourceType.TEST:
            if node.resource_type == DbtResourceType.MODEL and test_behavior == TestBehavior.AFTER_EACH:
                with TaskGroup(dag=dag, group_id=node.name, parent_group=task_group) as model_task_group:
                    task = create_airflow_task(task_meta, dag, task_group=model_task_group)
                    test_meta = create_test_task_metadata(
                        f"{node.name}_test",
                        execution_mode,
                        task_args=task_args,
                        model_name=node.name,
                        on_warning_callback=on_warning_callback,
                    )
                    test_task = create_airflow_task(test_meta, dag, task_group=model_task_group)
                    task >> test_task
                    task_or_group = model_task_group
            else:
                task_or_group = create_airflow_task(task_meta, dag, task_group=task_group)
        conversion_function = dbt_resource_converter.get(node.resource_type, generate_task_or_group)
        logger.info(f"Converting <{node.unique_id}> using <{conversion_function.__name__}>")
        task_or_group = conversion_function(
            dag=dag,
            task_group=task_group,
            dbt_project_name=dbt_project_name,
            execution_mode=execution_mode,
            task_args=task_args,
            test_behavior=test_behavior,
            on_warning_callback=on_warning_callback,
            node=node,
        )
        if task_or_group is not None:
            logger.info(f"Conversion of <{node.unique_id}> was successful!")
            tasks_map[node_id] = task_or_group

    # If test_behavior=="after_all", there will be one test task, run "by the end" of the DAG
    # If test_behavior=="after_all", there will be one test task, run by the end of the DAG
    # The end of a DAG is defined by the DAG leaf tasks (tasks which do not have downstream tasks)
    if test_behavior == TestBehavior.AFTER_ALL:
        task_args.pop("outlets", None)
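The dispatch above fixes the contract for custom converters: build_airflow_graph looks up the conversion function by resource type, defaults to generate_task_or_group, and always invokes it with the keyword arguments dag, task_group, dbt_project_name, execution_mode, task_args, test_behavior, on_warning_callback and node. A minimal sketch of a user-supplied converter under that contract — the BashOperator choice and task naming are illustrative, not part of this commit:

from airflow.operators.bash import BashOperator


def convert_seed_to_bash(dag, task_group, node, **kwargs):
    # Unused keyword arguments (execution_mode, task_args, test_behavior,
    # on_warning_callback, dbt_project_name) are absorbed by **kwargs.
    # Returning None instead would make build_airflow_graph skip the node.
    return BashOperator(
        dag=dag,
        task_group=task_group,
        task_id=f"{node.name}_seed",
        bash_command=f"echo 'would seed {node.name}'",
    )

Wired in via RenderConfig(dbt_resource_converter={DbtResourceType.SEED: convert_seed_to_bash}), this replaces the default rendering for seed nodes only; every other resource type still goes through generate_task_or_group.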
5 changes: 3 additions & 2 deletions cosmos/config.py
@@ -7,9 +7,9 @@
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Iterator
from typing import Any, Iterator, Callable

from cosmos.constants import TestBehavior, ExecutionMode, LoadMode
from cosmos.constants import DbtResourceType, TestBehavior, ExecutionMode, LoadMode
from cosmos.exceptions import CosmosValueError
from cosmos.log import get_logger
from cosmos.profiles import BaseProfileMapping
@@ -39,6 +39,7 @@ class RenderConfig:
    select: list[str] = field(default_factory=list)
    exclude: list[str] = field(default_factory=list)
    dbt_deps: bool = True
    dbt_resource_converter: dict[DbtResourceType, Callable[..., Any]] | None = None


@dataclass
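Because the mapping is keyed by DbtResourceType and consulted per node, the new field can also be used to filter nodes out of the rendered DAG: build_airflow_graph discards None results (see the `if task_or_group is not None` guard above), so a None-returning converter acts as a skip. A sketch, with SEED as an arbitrary example:

from cosmos import RenderConfig
from cosmos.constants import DbtResourceType


def skip_node(dag, task_group, node, **kwargs):
    # No operator is created; the node simply never lands in tasks_map.
    return None


render_config = RenderConfig(dbt_resource_converter={DbtResourceType.SEED: skip_node})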
5 changes: 5 additions & 0 deletions cosmos/constants.py
@@ -56,3 +56,8 @@ class DbtResourceType(Enum):
    SEED = "seed"
    TEST = "test"
    SOURCE = "source"


TESTABLE_DBT_RESOURCES = {
    DbtResourceType.MODEL
}  # TODO: extend with DbtResourceType.SOURCE, DbtResourceType.SNAPSHOT, DbtResourceType.SEED
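generate_task_or_group (graph.py above) consults this set to decide whether a node gets a companion test task under TestBehavior.AFTER_EACH, so until the TODO lands only model nodes are paired with tests:

from cosmos.constants import DbtResourceType, TESTABLE_DBT_RESOURCES

assert DbtResourceType.MODEL in TESTABLE_DBT_RESOURCES
assert DbtResourceType.SOURCE not in TESTABLE_DBT_RESOURCES  # pending the TODO above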
8 changes: 3 additions & 5 deletions cosmos/converter.py
@@ -105,7 +105,6 @@ def __init__(
    ) -> None:
        project_config.validate_project()

        emit_datasets = render_config.emit_datasets
        dbt_root_path = project_config.dbt_project_path.parent
        dbt_project_name = project_config.dbt_project_path.name
        dbt_models_dir = project_config.models_relative_path
@@ -119,10 +118,10 @@
        load_mode = render_config.load_method
        manifest_path = project_config.parsed_manifest_path
        dbt_executable_path = execution_config.dbt_executable_path
        dbt_resource_converter = render_config.dbt_resource_converter

        conn_id = "unknown"
        if profile_config and profile_config.profile_mapping:
            conn_id = profile_config.profile_mapping.conn_id
            pass

        profile_args = {}
        if profile_config.profile_mapping:
@@ -170,7 +169,6 @@ def __init__(
            task_args=task_args,
            test_behavior=test_behavior,
            dbt_project_name=dbt_project.name,
            conn_id=conn_id,
            on_warning_callback=on_warning_callback,
            emit_datasets=emit_datasets,
            dbt_resource_converter=dbt_resource_converter,
        )
12 changes: 11 additions & 1 deletion dev/dags/example_cosmos_sources.py
@@ -2,8 +2,10 @@
from datetime import datetime
from pathlib import Path

from cosmos import DbtDag, ProjectConfig, ProfileConfig
from airflow.operators.dummy import DummyOperator

from cosmos import DbtDag, ProjectConfig, ProfileConfig, RenderConfig
from cosmos.constants import DbtResourceType

DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt"
DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH))
@@ -26,13 +28,21 @@
    profiles_yml_filepath=(DBT_ROOT_PATH / "simple/profiles.yml"),
)


def convert_source(dag, task_group, node, **kwargs):
    return DummyOperator(dag=dag, task_group=task_group, task_id=f"{node.name}_source")


render_config = RenderConfig(dbt_resource_converter={DbtResourceType.SOURCE: convert_source})

# [START local_example]
example_cosmos_sources = DbtDag(
    # dbt/cosmos-specific parameters
    project_config=ProjectConfig(
        DBT_ROOT_PATH / "simple",
    ),
    profile_config=profile_config,
    render_config=render_config,
    operator_args={"append_env": True},
    # normal dag parameters
    schedule_interval="@daily",
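One caveat on the example: DummyOperator is deprecated in recent Airflow releases in favor of EmptyOperator (available from airflow.operators.empty since Airflow 2.3). A converter targeting newer Airflow would be otherwise identical:

from airflow.operators.empty import EmptyOperator


def convert_source(dag, task_group, node, **kwargs):
    return EmptyOperator(dag=dag, task_group=task_group, task_id=f"{node.name}_source")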
7 changes: 4 additions & 3 deletions tests/airflow/test_graph.py
@@ -88,7 +88,6 @@ def test_build_airflow_graph_with_after_each():
            task_args=task_args,
            test_behavior=TestBehavior.AFTER_EACH,
            dbt_project_name="astro_shop",
            conn_id="fake_conn",
        )
    topological_sort = [task.task_id for task in dag.topological_sort()]
    expected_sort = [
@@ -138,7 +137,6 @@ def test_build_airflow_graph_with_after_all():
            task_args=task_args,
            test_behavior=TestBehavior.AFTER_ALL,
            dbt_project_name="astro_shop",
            conn_id="fake_conn",
        )
    topological_sort = [task.task_id for task in dag.topological_sort()]
    expected_sort = ["seed_parent_seed", "parent_run", "child_run", "astro_shop_test"]
@@ -214,7 +212,10 @@ def test_create_task_metadata_unsupported(caplog):
    )
    response = create_task_metadata(child_node, execution_mode="", args={})
    assert response is None
    expected_msg = "Unsupported resource type unsupported (node unsupported)."
    expected_msg = (
        "Unavailable conversion function for <unsupported> (node <unsupported>). "
        "Define a converter function using render_config.dbt_resource_converter."
    )
    assert caplog.messages[0] == expected_msg


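The updated tests only cover the default path; a sketch of how the converter mapping itself could be exercised, assuming the module's existing imports and its sample_nodes fixture stay as used in the tests above (task_args trimmed for brevity):

def test_build_airflow_graph_with_custom_converter():
    seen = []

    def record_and_skip(dag, task_group, node, **kwargs):
        seen.append(node.unique_id)
        return None  # None results are dropped, so no task is created for the node

    with DAG("test-custom-converter", start_date=datetime(2022, 1, 1)) as dag:
        build_airflow_graph(
            nodes=sample_nodes,
            dag=dag,
            execution_mode=ExecutionMode.LOCAL,
            task_args={},
            test_behavior=TestBehavior.NONE,
            dbt_project_name="astro_shop",
            dbt_resource_converter={DbtResourceType.MODEL: record_and_skip},
        )
    assert seen  # model nodes were routed through the custom converter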
