From 709f92fb2f3b6e898f3e78bc02eb6a9f58461f57 Mon Sep 17 00:00:00 2001 From: Julian LaNeve Date: Thu, 5 Oct 2023 13:42:24 -0400 Subject: [PATCH] Ensure filtering with manifest loading works with single model --- .gitignore | 3 ++ cosmos/dbt/selector.py | 63 +++++++++++++++++++++-------- dev/dags/cosmos_manifest_example.py | 40 ++++++++++++++++++ tests/dbt/test_selector.py | 18 +++++++++ 4 files changed, 107 insertions(+), 17 deletions(-) create mode 100644 dev/dags/cosmos_manifest_example.py diff --git a/.gitignore b/.gitignore index e65d1f716..5991c231c 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,9 @@ # these files get autogenerated docs/profiles/* +# dbt_packages is a directory that gets created when you run dbt deps +dbt_packages/ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/cosmos/dbt/selector.py b/cosmos/dbt/selector.py index fc1048220..e10f5b9b2 100644 --- a/cosmos/dbt/selector.py +++ b/cosmos/dbt/selector.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING +from cosmos.constants import DbtResourceType from cosmos.exceptions import CosmosValueError from cosmos.log import get_logger @@ -76,6 +77,9 @@ def load_from_statement(self, statement: str) -> None: self.other.append(item) logger.warning("Unsupported select statement: %s", item) + def __repr__(self) -> str: + return f"SelectorConfig(paths={self.paths}, tags={self.tags}, config={self.config}, other={self.other})" + def select_nodes_ids_by_intersection(nodes: dict[str, DbtNode], config: SelectorConfig) -> set[str]: """ @@ -88,30 +92,55 @@ def select_nodes_ids_by_intersection(nodes: dict[str, DbtNode], config: Selector https://docs.getdbt.com/reference/node-selection/syntax https://docs.getdbt.com/reference/node-selection/yaml-selectors """ + if config.is_empty: + return set(nodes.keys()) + selected_nodes = set() + visited_nodes = set() - if not config.is_empty: - for node_id, node in nodes.items(): - if config.tags and not (sorted(node.tags) == sorted(config.tags)): - continue + def should_include_node(node_id: str, node: DbtNode) -> bool: + "Checks if a single node should be included. Only runs once per node with caching." + if node_id in visited_nodes: + return node_id in selected_nodes - supported_node_config = {key: value for key, value in node.config.items() if key in SUPPORTED_CONFIG} - config_tag = config.config.get("tags") - if config.config: - if config_tag and config_tag not in supported_node_config.get("tags", []): - continue + visited_nodes.add(node_id) - # Remove 'tags' as they've already been filtered for - config_copy = copy.deepcopy(config.config) - config_copy.pop("tags", None) - supported_node_config.pop("tags", None) + if config.tags: + if not (set(config.tags) == set(node.tags)): + return False - if not (config_copy.items() <= supported_node_config.items()): - continue + node_config = {key: value for key, value in node.config.items() if key in SUPPORTED_CONFIG} + config_tags = config.config.get("tags") + if config_tags and config_tags not in node_config.get("tags", []): + return False - if config.paths and not (set(config.paths).issubset(set(node.file_path.parents))): - continue + # Remove 'tags' as they've already been filtered for + config_copy = copy.deepcopy(config.config) + config_copy.pop("tags", None) + node_config.pop("tags", None) + + if not (config_copy.items() <= node_config.items()): + return False + + if config.paths: + for filter_path in config.paths: + if filter_path in node.file_path.parents or filter_path == node.file_path: + return True + + # if it's a test coming from a schema.yml file, check the model's file_path + if node.resource_type == DbtResourceType.TEST and node.file_path.name == "schema.yml": + # try to get the corresponding model from node.depends_on + if len(node.depends_on) == 1: + model_node = nodes.get(node.depends_on[0]) + if model_node: + return should_include_node(node.depends_on[0], model_node) + + return False + + return True + for node_id, node in nodes.items(): + if should_include_node(node_id, node): selected_nodes.add(node_id) return selected_nodes diff --git a/dev/dags/cosmos_manifest_example.py b/dev/dags/cosmos_manifest_example.py new file mode 100644 index 000000000..ddf5f1b10 --- /dev/null +++ b/dev/dags/cosmos_manifest_example.py @@ -0,0 +1,40 @@ +""" +An example DAG that uses Cosmos to render a dbt project. +""" + +import os +from datetime import datetime +from pathlib import Path + +from cosmos import DbtDag, ProjectConfig, ProfileConfig, RenderConfig, LoadMode +from cosmos.profiles import PostgresUserPasswordProfileMapping + +DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt" +DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) + +profile_config = ProfileConfig( + profile_name="default", + target_name="dev", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="airflow_db", + profile_args={"schema": "public"}, + ), +) + +# [START local_example] +cosmos_manifest_example = DbtDag( + # dbt/cosmos-specific parameters + project_config=ProjectConfig( + DBT_ROOT_PATH / "jaffle_shop", + manifest_path=DBT_ROOT_PATH / "jaffle_shop" / "target" / "manifest.json", + ), + profile_config=profile_config, + render_config=RenderConfig(load_method=LoadMode.DBT_MANIFEST, select=["path:models/customers.sql"]), + operator_args={"install_deps": True}, + # normal dag parameters + schedule_interval="@daily", + start_date=datetime(2023, 1, 1), + catchup=False, + dag_id="cosmos_manifest_example", +) +# [END local_example] diff --git a/tests/dbt/test_selector.py b/tests/dbt/test_selector.py index 7c6ff3292..8e3fc8c61 100644 --- a/tests/dbt/test_selector.py +++ b/tests/dbt/test_selector.py @@ -212,3 +212,21 @@ def test_select_nodes_by_exclude_union_config_test_tags(): child_node.unique_id: child_node, } assert selected == expected + + +def test_select_nodes_by_path_dir(): + selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["path:gen3/models"]) + expected = { + child_node.unique_id: child_node, + grandchild_1_test_node.unique_id: grandchild_1_test_node, + grandchild_2_test_node.unique_id: grandchild_2_test_node, + } + assert selected == expected + + +def test_select_nodes_by_path_file(): + selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["path:gen2/models/parent.sql"]) + expected = { + parent_node.unique_id: parent_node, + } + assert selected == expected