diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 861a33778..53f80df2f 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -76,6 +76,12 @@ repos:
         name: mypy-python
         additional_dependencies: [types-PyYAML, types-attrs, attrs, types-requests, types-python-dateutil, apache-airflow]
         files: ^cosmos
+  - repo: https://github.com/pycqa/flake8
+    rev: 6.1.0
+    hooks:
+      - id: flake8
+        entry: pflake8
+        additional_dependencies: [pyproject-flake8]
 
 ci:
   autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
diff --git a/cosmos/dbt/graph.py b/cosmos/dbt/graph.py
index c41ee0713..89658811f 100644
--- a/cosmos/dbt/graph.py
+++ b/cosmos/dbt/graph.py
@@ -53,6 +53,64 @@ class DbtNode:
     has_test: bool = False
 
 
+def create_symlinks(dbt_project_path: Path, tmp_dir: Path) -> None:
+    """Helper function to create symlinks to the dbt project files."""
+    ignore_paths = (DBT_LOG_DIR_NAME, DBT_TARGET_DIR_NAME, "dbt_packages", "profiles.yml")
+    for child_name in os.listdir(dbt_project_path):
+        if child_name not in ignore_paths:
+            os.symlink(dbt_project_path / child_name, tmp_dir / child_name)
+
+
+def run_command(command: list[str], tmp_dir: Path, env_vars: dict[str, str]) -> str:
+    """Run a command in a subprocess, returning the stdout."""
+    logger.info("Running command: `%s`", " ".join(command))
+    logger.info("Environment variable keys: %s", env_vars.keys())
+    process = Popen(
+        command,
+        stdout=PIPE,
+        stderr=PIPE,
+        cwd=tmp_dir,
+        universal_newlines=True,
+        env=env_vars,
+    )
+    stdout, stderr = process.communicate()
+    returncode = process.returncode
+
+    if 'Run "dbt deps" to install package dependencies' in stdout and command[1] == "ls":
+        raise CosmosLoadDbtException(
+            "Unable to run dbt ls command due to missing dbt_packages. Set render_config.dbt_deps=True."
+        )
+
+    if returncode or "Error" in stdout:
+        details = stderr or stdout
+        raise CosmosLoadDbtException(f"Unable to run {command} due to the error:\n{details}")
+
+    return stdout
+
+
+def parse_dbt_ls_output(dbt_project_path: Path, ls_stdout: str) -> dict[str, DbtNode]:
+    """Parses the output of `dbt ls` into a dictionary of `DbtNode` instances."""
+    nodes = {}
+    for line in ls_stdout.split("\n"):
+        try:
+            node_dict = json.loads(line.strip())
+        except json.decoder.JSONDecodeError:
+            logger.debug("Skipped dbt ls line: %s", line)
+        else:
+            node = DbtNode(
+                name=node_dict.get("alias", node_dict["name"]),
+                unique_id=node_dict["unique_id"],
+                resource_type=DbtResourceType(node_dict["resource_type"]),
+                depends_on=node_dict.get("depends_on", {}).get("nodes", []),
+                file_path=dbt_project_path / node_dict["original_file_path"],
+                tags=node_dict["tags"],
+                config=node_dict["config"],
+            )
+            nodes[node.unique_id] = node
+            logger.debug("Parsed dbt resource `%s` of type `%s`", node.unique_id, node.resource_type)
+    return nodes
+
+
 class DbtGraph:
     """
     A dbt project graph (represented by `nodes` and `filtered_nodes`).
@@ -131,6 +189,31 @@ def load(
         else:
             load_method[method]()
 
+    def run_dbt_ls(self, dbt_project_path: Path, tmp_dir: Path, env_vars: dict[str, str]) -> dict[str, DbtNode]:
+        """Runs the `dbt ls` command and returns the parsed nodes."""
+        ls_command = [self.dbt_cmd, "ls", "--output", "json"]
+
+        if self.exclude:
+            ls_command.extend(["--exclude", *self.exclude])
+
+        if self.select:
+            ls_command.extend(["--select", *self.select])
+
+        ls_command.extend(self.local_flags)
+
+        stdout = run_command(ls_command, tmp_dir, env_vars)
+
+        logger.debug("dbt ls output: %s", stdout)
+        log_filepath = self.log_dir / DBT_LOG_FILENAME
+        logger.debug("dbt logs available in: %s", log_filepath)
+        if log_filepath.exists():
+            with open(log_filepath) as logfile:
+                for line in logfile:
+                    logger.debug(line.strip())
+
+        nodes = parse_dbt_ls_output(dbt_project_path, stdout)
+        return nodes
+
     def load_via_dbt_ls(self) -> None:
         """
         This is the most accurate way of loading `dbt` projects and filtering them out, since it uses the `dbt` command
@@ -148,6 +231,8 @@ def load_via_dbt_ls(self) -> None:
             self.project.project_name,
             self.project.dbt_project_path,
         )
+        if self.project.dbt_project_path is None:
+            raise CosmosLoadDbtException("Unable to load a dbt project via dbt ls without a project path.")
 
         if not self.project.dbt_project_path or not self.profile_config:
             raise CosmosLoadDbtException("Unable to load dbt project without project files and a profile config")
@@ -155,28 +240,21 @@ def load_via_dbt_ls(self) -> None:
         if not shutil.which(self.dbt_cmd):
             raise CosmosLoadDbtException(f"Unable to find the dbt executable: {self.dbt_cmd}")
 
-        with self.profile_config.ensure_profile(use_mock_values=True) as profile_values:
-            (profile_path, env_vars) = profile_values
-            env = os.environ.copy()
-            env.update(env_vars)
+        with tempfile.TemporaryDirectory() as tmpdir:
+            logger.info(
+                "Content of the dbt project dir <%s>: `%s`",
+                self.project.dbt_project_path,
+                os.listdir(self.project.dbt_project_path),
+            )
+            tmpdir_path = Path(tmpdir)
+            create_symlinks(self.project.dbt_project_path, tmpdir_path)
 
-            with tempfile.TemporaryDirectory() as tmpdir:
-                logger.info(
-                    "Content of the dbt project dir <%s>: `%s`",
-                    self.project.dbt_project_path,
-                    os.listdir(self.project.dbt_project_path),
-                )
-                logger.info("Creating symlinks from %s to `%s`", self.project.dbt_project_path, tmpdir)
-                # We create symbolic links to the original directory files and directories.
-                # This allows us to run the dbt command from within the temporary directory, outputting any necessary
-                # artifact and also allow us to run `dbt deps`
-                tmpdir_path = Path(tmpdir)
-                ignore_paths = (DBT_LOG_DIR_NAME, DBT_TARGET_DIR_NAME, "dbt_packages", "profiles.yml")
-                for child_name in os.listdir(self.project.dbt_project_path):
-                    if child_name not in ignore_paths:
-                        os.symlink(self.project.dbt_project_path / child_name, tmpdir_path / child_name)
-
-                local_flags = [
+            with self.profile_config.ensure_profile(use_mock_values=True) as profile_values:
+                (profile_path, env_vars) = profile_values
+                env = os.environ.copy()
+                env.update(env_vars)
+
+                self.local_flags = [
                     "--project-dir",
                     str(tmpdir),
                     "--profiles-dir",
@@ -186,92 +264,18 @@ def load_via_dbt_ls(self) -> None:
                     "--target",
                     self.profile_config.target_name,
                 ]
-                log_dir = Path(env.get(DBT_LOG_PATH_ENVVAR) or tmpdir_path / DBT_LOG_DIR_NAME)
-                target_dir = Path(env.get(DBT_TARGET_PATH_ENVVAR) or tmpdir_path / DBT_TARGET_DIR_NAME)
-                env[DBT_LOG_PATH_ENVVAR] = str(log_dir)
-                env[DBT_TARGET_PATH_ENVVAR] = str(target_dir)
+                self.log_dir = Path(env.get(DBT_LOG_PATH_ENVVAR) or tmpdir_path / DBT_LOG_DIR_NAME)
+                self.target_dir = Path(env.get(DBT_TARGET_PATH_ENVVAR) or tmpdir_path / DBT_TARGET_DIR_NAME)
+                env[DBT_LOG_PATH_ENVVAR] = str(self.log_dir)
+                env[DBT_TARGET_PATH_ENVVAR] = str(self.target_dir)
 
                 if self.dbt_deps:
                     deps_command = [self.dbt_cmd, "deps"]
-                    deps_command.extend(local_flags)
-                    logger.info("Running command: `%s`", " ".join(deps_command))
-                    logger.info("Environment variable keys: %s", env.keys())
-                    process = Popen(
-                        deps_command,
-                        stdout=PIPE,
-                        stderr=PIPE,
-                        cwd=tmpdir,
-                        universal_newlines=True,
-                        env=env,
-                    )
-                    stdout, stderr = process.communicate()
-                    returncode = process.returncode
+                    deps_command.extend(self.local_flags)
+                    stdout = run_command(deps_command, tmpdir_path, env)
                     logger.debug("dbt deps output: %s", stdout)
 
-                    if returncode or "Error" in stdout:
-                        details = stderr or stdout
-                        raise CosmosLoadDbtException(f"Unable to run dbt deps command due to the error:\n{details}")
-
-                ls_command = [self.dbt_cmd, "ls", "--output", "json"]
-
-                if self.exclude:
-                    ls_command.extend(["--exclude", *self.exclude])
-
-                if self.select:
-                    ls_command.extend(["--select", *self.select])
-
-                ls_command.extend(local_flags)
-
-                logger.info("Running command: `%s`", " ".join(ls_command))
-                logger.info("Environment variable keys: %s", env.keys())
-
-                process = Popen(
-                    ls_command,
-                    stdout=PIPE,
-                    stderr=PIPE,
-                    cwd=tmpdir,
-                    universal_newlines=True,
-                    env=env,
-                )
-
-                stdout, stderr = process.communicate()
-                returncode = process.returncode
-
-                logger.debug("dbt output: %s", stdout)
-                log_filepath = log_dir / DBT_LOG_FILENAME
-                logger.debug("dbt logs available in: %s", log_filepath)
-                if log_filepath.exists():
-                    with open(log_filepath) as logfile:
-                        for line in logfile:
-                            logger.debug(line.strip())
-
-                if 'Run "dbt deps" to install package dependencies' in stdout:
-                    raise CosmosLoadDbtException(
-                        "Unable to run dbt ls command due to missing dbt_packages. Set render_config.dbt_deps=True."
-                    )
-
-                if returncode or "Error" in stdout:
-                    details = stderr or stdout
-                    raise CosmosLoadDbtException(f"Unable to run dbt ls command due to the error:\n{details}")
-
-                nodes = {}
-                for line in stdout.split("\n"):
-                    try:
-                        node_dict = json.loads(line.strip())
-                    except json.decoder.JSONDecodeError:
-                        logger.debug("Skipped dbt ls line: %s", line)
-                    else:
-                        node = DbtNode(
-                            name=node_dict.get("alias", node_dict["name"]),
-                            unique_id=node_dict["unique_id"],
-                            resource_type=DbtResourceType(node_dict["resource_type"]),
-                            depends_on=node_dict.get("depends_on", {}).get("nodes", []),
-                            file_path=self.project.dbt_project_path / node_dict["original_file_path"],
-                            tags=node_dict["tags"],
-                            config=node_dict["config"],
-                        )
-                        nodes[node.unique_id] = node
-                        logger.debug("Parsed dbt resource `%s` of type `%s`", node.unique_id, node.resource_type)
+                nodes = self.run_dbt_ls(self.project.dbt_project_path, tmpdir_path, env)
 
                 self.nodes = nodes
                 self.filtered_nodes = nodes
diff --git a/cosmos/dbt/parser/project.py b/cosmos/dbt/parser/project.py
index e154bb0ed..278b1a0f7 100644
--- a/cosmos/dbt/parser/project.py
+++ b/cosmos/dbt/parser/project.py
@@ -137,17 +137,14 @@ def __post_init__(self) -> None:
         """
         Parses the file and extracts metadata (dependencies, tags, etc)
         """
-        # first, get an empty config
-        config = DbtModelConfig()
-        var_args: Dict[str, Any] = self.operator_args.get("vars", {})
+        if self.type == DbtModelType.DBT_SEED or self.type == DbtModelType.DBT_TEST:
+            return
 
-        if self.type == DbtModelType.DBT_MODEL:
-            # get the code from the file
-            code = self.path.read_text()
+        config = DbtModelConfig()
+        self.var_args: Dict[str, Any] = self.operator_args.get("vars", {})
+        code = self.path.read_text()
 
-        # we remove first and last line if the code is a snapshot
-        elif self.type == DbtModelType.DBT_SNAPSHOT:
-            code = self.path.read_text()
+        if self.type == DbtModelType.DBT_SNAPSHOT:
             snapshot_name = code.split("{%")[1]
             snapshot_name = snapshot_name.split("%}")[0]
             snapshot_name = snapshot_name.split(" ")[2]
@@ -156,56 +153,74 @@ def __post_init__(self) -> None:
             code = code.split("%}")[1]
             code = code.split("{%")[0]
 
-        elif self.type == DbtModelType.DBT_SEED or self.type == DbtModelType.DBT_TEST:
-            return
-
         if self.path.suffix == PYTHON_FILE_SUFFIX:
             config.upstream_models = config.upstream_models.union(set(extract_python_file_upstream_requirements(code)))
         else:
-            # get the dependencies
-            env = jinja2.Environment()
-            jinja2_ast = env.parse(code)
-            # iterate over the jinja nodes to extract info
-            for base_node in jinja2_ast.find_all(jinja2.nodes.Call):
-                if hasattr(base_node.node, "name"):
-                    try:
-                        # check we have a ref - this indicates a dependency
-                        if base_node.node.name == "ref":
-                            # if it is, get the first argument
-                            first_arg = base_node.args[0]
-                            # if it contains vars, render the value of the var
-                            if isinstance(first_arg, jinja2.nodes.Concat):
-                                value = ""
-                                for node in first_arg.nodes:
-                                    if isinstance(node, jinja2.nodes.Const):
-                                        value += node.value
-                                    elif (
-                                        isinstance(node, jinja2.nodes.Call)
-                                        and isinstance(node.node, jinja2.nodes.Name)
-                                        and isinstance(node.args[0], jinja2.nodes.Const)
-                                        and node.node.name == "var"
-                                    ):
-                                        value += var_args[node.args[0].value]
-                                config.upstream_models.add(value)
-                            elif isinstance(first_arg, jinja2.nodes.Const):
-                                # and add it to the config
-                                config.upstream_models.add(first_arg.value)
-
-                        # check if we have a config - this could contain tags
-                        if base_node.node.name == "config":
-                            # if it is, check if any kwargs are tags
-                            for kwarg in base_node.kwargs:
-                                for selector in self.config.config_types:
-                                    extracted_config = self._extract_config(kwarg=kwarg, config_name=selector)
-                                    config.config_selectors |= (
-                                        set(extracted_config) if isinstance(extracted_config, (str, List)) else set()
-                                    )
-                    except KeyError as e:
-                        logger.warning(f"Could not add upstream model for config in {self.path}: {e}")
-
-        # set the config and set the parsed file flag to true
+            upstream_models, extracted_config = self.extract_sql_file_requirements(code)
+            config.upstream_models = config.upstream_models.union(set(upstream_models))
+            config.config_selectors |= extracted_config
+
         self.config = config
 
+    def extract_sql_file_requirements(self, code: str) -> tuple[list[str], set[str]]:
+        """Extracts upstream models and config selectors from a dbt sql file."""
+        # get the dependencies
+        env = jinja2.Environment()
+        jinja2_ast = env.parse(code)
+        upstream_models = []
+        config_selectors = set()
+        # iterate over the jinja nodes to extract info
+        for base_node in jinja2_ast.find_all(jinja2.nodes.Call):
+            if hasattr(base_node.node, "name"):
+                try:
+                    # check we have a ref - this indicates a dependency
+                    if base_node.node.name == "ref":
+                        upstream_model = self._parse_jinja_ref_node(base_node)
+                        if upstream_model:
+                            upstream_models.append(upstream_model)
+                    # check if we have a config - this could contain tags
+                    if base_node.node.name == "config":
+                        config_selectors |= self._parse_jinja_config_node(base_node)
+                except KeyError as e:
+                    logger.warning(f"Could not add upstream model for config in {self.path}: {e}")
+
+        return upstream_models, config_selectors
+
+    def _parse_jinja_ref_node(self, base_node: jinja2.nodes.Call) -> str | None:
+        """Parses a jinja ref node."""
+        # get the first argument
+        first_arg = base_node.args[0]
+        value = None
+        # if it contains vars, render the value of the var
+        if isinstance(first_arg, jinja2.nodes.Concat):
+            value = ""
+            for node in first_arg.nodes:
+                if isinstance(node, jinja2.nodes.Const):
+                    value += node.value
+                elif (
+                    isinstance(node, jinja2.nodes.Call)
+                    and isinstance(node.node, jinja2.nodes.Name)
+                    and isinstance(node.args[0], jinja2.nodes.Const)
+                    and node.node.name == "var"
+                ):
+                    value += self.var_args[node.args[0].value]
+        elif isinstance(first_arg, jinja2.nodes.Const):
+            # and add it to the config
+            value = first_arg.value
+
+        return value
+
+    def _parse_jinja_config_node(self, base_node: jinja2.nodes.Call) -> set[str]:
+        """Parses a jinja config node."""
+        # check if any kwargs are tags
+        selector_config = set()
+        for kwarg in base_node.kwargs:
+            for config_name in self.config.config_types:
+                if hasattr(kwarg, "key") and kwarg.key == config_name:
+                    extracted_config = self._extract_config(kwarg, config_name)
+                    selector_config |= set(extracted_config) if isinstance(extracted_config, (str, List)) else set()
+        return selector_config
+
     # TODO following needs coverage:
     def _extract_config(self, kwarg: Any, config_name: str) -> Any:
         if hasattr(kwarg, "key") and kwarg.key == config_name:
@@ -354,47 +369,21 @@ def _handle_config_file(self, path: Path) -> None:
         if not config_dict:
             return
 
-        for model in config_dict.get("models", []):
-            model_name = model.get("name")
+        for model_config in config_dict.get("models", []):
+            model_name = model_config.get("name")
 
             # if the model doesn't exist, we can't do anything
             if not model_name:
                 continue
 
-            # tests
-            for column in model.get("columns", []):
-                for test in column.get("tests", []):
-                    if not column.get("name"):
-                        continue
-
-                    # Get the test name
-                    if not isinstance(test, str):
-                        test = list(test.keys())[0]
-
-                    test_model = DbtModel(
-                        name=f"{test}_{column['name']}_{model_name}",
-                        type=DbtModelType.DBT_TEST,
-                        path=path,
-                        operator_args=self.operator_args,
-                        config=DbtModelConfig(upstream_models=set({model_name})),
-                    )
-
-                    self.tests[test_model.name] = test_model
+            model_tests = self._extract_model_tests(model_name, model_config, path)
+            self.tests.update(model_tests)
 
             # config_selectors
             if model_name not in self.models:
                 continue
 
-            config_selectors = []
-            for selector in DbtModelConfig.config_types:
-                config_value = model.get("config", {}).get(selector)
-                if config_value:
-                    if isinstance(config_value, str):
-                        config_selectors.append(f"{selector}:{config_value}")
-                    else:
-                        for item in config_value:
-                            if item:
-                                config_selectors.append(f"{selector}:{item}")
+            config_selectors = self._extract_config_selectors(model_config)
 
             # dbt default ensures "materialized:view" is set for all models if nothing is specified so that it will
             # work in a select/exclude list
@@ -407,3 +396,40 @@ def _handle_config_file(self, path: Path) -> None:
             # then, get the model and merge the configs
             model = self.models[model_name]
             model.config = model.config + DbtModelConfig(config_selectors=set(config_selectors))
+
+    def _extract_model_tests(
+        self, model_name: str, model_config: dict[str, list[dict[str, dict[str, list[str]]]]], path: Path
+    ) -> dict[str, DbtModel]:
+        """Extracts tests from a dbt config file model."""
+        tests = {}
+        for column in model_config.get("columns", []):
+            for test in column.get("tests", []):
+                if not column.get("name"):
+                    continue
+                # Get the test name
+                if not isinstance(test, str):
+                    test = list(test.keys())[0]
+
+                test_model = DbtModel(
+                    name=f"{test}_{column['name']}_{model_name}",
+                    type=DbtModelType.DBT_TEST,
+                    path=path,
+                    operator_args=self.operator_args,
+                    config=DbtModelConfig(upstream_models=set({model_name})),
+                )
+                tests[test_model.name] = test_model
+        return tests
+
+    def _extract_config_selectors(self, model_config: dict[str, dict[str, str | list[str]]]) -> list[str]:
+        """Extracts config selectors from a dbt config file model."""
+        config_selectors = []
+        for selector in DbtModelConfig.config_types:
+            config_value = model_config.get("config", {}).get(selector)
+            if config_value:
+                if isinstance(config_value, str):
+                    config_selectors.append(f"{selector}:{config_value}")
+                else:
+                    for item in config_value:
+                        if item:
+                            config_selectors.append(f"{selector}:{item}")
+        return config_selectors
diff --git a/cosmos/dbt/selector.py b/cosmos/dbt/selector.py
index f195f225b..c7316dc75 100644
--- a/cosmos/dbt/selector.py
+++ b/cosmos/dbt/selector.py
@@ -2,7 +2,7 @@
 from pathlib import Path
 
 import copy
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
 from cosmos.constants import DbtResourceType
 from cosmos.exceptions import CosmosValueError
@@ -84,72 +84,96 @@ def __repr__(self) -> str:
         return f"SelectorConfig(paths={self.paths}, tags={self.tags}, config={self.config}, other={self.other})"
 
 
-def select_nodes_ids_by_intersection(nodes: dict[str, DbtNode], config: SelectorConfig) -> set[str]:
+class NodeSelector:
     """
-    Return a list of node ids which matches the configuration defined in config.
+    Class to select nodes based on a selector config.
 
     :param nodes: Dictionary mapping dbt nodes (node.unique_id to node)
     :param config: User-defined select statements
-
-    References:
-    https://docs.getdbt.com/reference/node-selection/syntax
-    https://docs.getdbt.com/reference/node-selection/yaml-selectors
     """
 
-    if config.is_empty:
-        return set(nodes.keys())
-
-    selected_nodes = set()
-    visited_nodes = set()
+    def __init__(self, nodes: dict[str, DbtNode], config: SelectorConfig) -> None:
+        self.nodes = nodes
+        self.config = config
+
+    def select_nodes_ids_by_intersection(self) -> set[str]:
+        """
+        Return a list of node ids which matches the configuration defined in config.
+
+        References:
+        https://docs.getdbt.com/reference/node-selection/syntax
+        https://docs.getdbt.com/reference/node-selection/yaml-selectors
+        """
+        if self.config.is_empty:
+            return set(self.nodes.keys())
+
+        self.selected_nodes: set[str] = set()
+        self.visited_nodes: set[str] = set()
+
+        for node_id, node in self.nodes.items():
+            if self._should_include_node(node_id, node):
+                self.selected_nodes.add(node_id)
+
+        return self.selected_nodes
 
-    def should_include_node(node_id: str, node: DbtNode) -> bool:
+    def _should_include_node(self, node_id: str, node: DbtNode) -> bool:
         "Checks if a single node should be included. Only runs once per node with caching."
-        if node_id in visited_nodes:
-            return node_id in selected_nodes
+        if node_id in self.visited_nodes:
+            return node_id in self.selected_nodes
 
-        visited_nodes.add(node_id)
+        self.visited_nodes.add(node_id)
 
         if node.resource_type == DbtResourceType.TEST:
-            node.tags = getattr(nodes.get(node.depends_on[0]), "tags", [])
+            node.tags = getattr(self.nodes.get(node.depends_on[0]), "tags", [])
 
-        if config.tags:
-            if not (set(config.tags) <= set(node.tags)):
-                return False
+        if not self._is_tags_subset(node):
+            return False
 
         node_config = {key: value for key, value in node.config.items() if key in SUPPORTED_CONFIG}
-        config_tags = config.config.get("tags")
-        if config_tags and config_tags not in node_config.get("tags", []):
+
+        if not self._is_config_subset(node_config):
             return False
 
         # Remove 'tags' as they've already been filtered for
-        config_copy = copy.deepcopy(config.config)
+        config_copy = copy.deepcopy(self.config.config)
         config_copy.pop("tags", None)
         node_config.pop("tags", None)
 
         if not (config_copy.items() <= node_config.items()):
             return False
 
-        if config.paths:
-            for filter_path in config.paths:
-                if filter_path in node.file_path.parents or filter_path == node.file_path:
-                    return True
+        if self.config.paths and not self._is_path_matching(node):
+            return False
 
-            # if it's a test coming from a schema.yml file, check the model's file_path
-            if node.resource_type == DbtResourceType.TEST and node.file_path.name == "schema.yml":
-                # try to get the corresponding model from node.depends_on
-                if len(node.depends_on) == 1:
-                    model_node = nodes.get(node.depends_on[0])
-                    if model_node:
-                        return should_include_node(node.depends_on[0], model_node)
+        return True
 
+    def _is_tags_subset(self, node: DbtNode) -> bool:
+        """Checks if the config's tags are a subset of the node's tags."""
+        if not (set(self.config.tags) <= set(node.tags)):
             return False
+        return True
 
+    def _is_config_subset(self, node_config: dict[str, Any]) -> bool:
+        """Checks if the config's tag selectors are present in the node's config tags."""
+        config_tags = self.config.config.get("tags")
+        if config_tags and config_tags not in node_config.get("tags", []):
+            return False
         return True
 
-    for node_id, node in nodes.items():
-        if should_include_node(node_id, node):
-            selected_nodes.add(node_id)
+    def _is_path_matching(self, node: DbtNode) -> bool:
+        """Checks if the node's file path matches any of the config's paths."""
+        for filter_path in self.config.paths:
+            if filter_path in node.file_path.parents or filter_path == node.file_path:
+                return True
 
-    return selected_nodes
+        # if it's a test coming from a schema.yml file, check the model's file_path
+        if node.resource_type == DbtResourceType.TEST and node.file_path.name == "schema.yml":
+            # try to get the corresponding model from node.depends_on
+            if len(node.depends_on) == 1:
+                model_node = self.nodes.get(node.depends_on[0])
+                if model_node:
+                    return self._should_include_node(node.depends_on[0], model_node)
+        return False
 
 
 def retrieve_by_label(statement_list: list[str], label: str) -> set[str]:
@@ -204,7 +228,8 @@ def select_nodes(
 
     for statement in select:
         config = SelectorConfig(project_dir, statement)
-        select_ids = select_nodes_ids_by_intersection(nodes, config)
+        node_selector = NodeSelector(nodes, config)
+        select_ids = node_selector.select_nodes_ids_by_intersection()
         subset_ids = subset_ids.union(set(select_ids))
 
     if select:
@@ -215,7 +240,8 @@ def select_nodes(
     exclude_ids: set[str] = set()
     for statement in exclude:
         config = SelectorConfig(project_dir, statement)
-        exclude_ids = exclude_ids.union(set(select_nodes_ids_by_intersection(nodes, config)))
+        node_selector = NodeSelector(nodes, config)
+        exclude_ids = exclude_ids.union(set(node_selector.select_nodes_ids_by_intersection()))
 
     subset_ids = set(nodes_ids) - set(exclude_ids)
     return {id_: nodes[id_] for id_ in subset_ids}
diff --git a/pyproject.toml b/pyproject.toml
index 204ef771b..47e770de2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -250,3 +250,7 @@ line-length = 120
 
 [tool.distutils.bdist_wheel]
 universal = true
+
+[tool.flake8]
+max-complexity = 10
+select = "C"
diff --git a/tests/dbt/test_graph.py b/tests/dbt/test_graph.py
index b36bf92fb..b9215aafa 100644
--- a/tests/dbt/test_graph.py
+++ b/tests/dbt/test_graph.py
@@ -7,7 +7,15 @@
 
 from cosmos.config import ProfileConfig, ProjectConfig
 from cosmos.constants import DbtResourceType, ExecutionMode
-from cosmos.dbt.graph import CosmosLoadDbtException, DbtGraph, LoadMode
+from cosmos.dbt.graph import (
+    CosmosLoadDbtException,
+    DbtGraph,
+    DbtNode,
+    LoadMode,
+    create_symlinks,
+    run_command,
+    parse_dbt_ls_output,
+)
 from cosmos.profiles import PostgresUserPasswordProfileMapping
 
 DBT_PROJECTS_ROOT_DIR = Path(__file__).parent.parent.parent / "dev/dags/dbt"
@@ -407,12 +415,10 @@ def test_load_via_dbt_ls_with_non_zero_returncode(mock_popen):
             ),
         ),
     )
-    with pytest.raises(CosmosLoadDbtException) as err_info:
+    expected = r"Unable to run \['.+dbt', 'deps', .*\] due to the error:\nSome stderr message"
+    with pytest.raises(CosmosLoadDbtException, match=expected):
         dbt_graph.load_via_dbt_ls()
 
-    expected = "Unable to run dbt deps command due to the error:\nSome stderr message"
-    assert err_info.value.args[0] == expected
-
 
 @pytest.mark.integration
 @patch("cosmos.dbt.graph.Popen.communicate", return_value=("Some Runtime Error", ""))
@@ -430,11 +436,10 @@ def test_load_via_dbt_ls_with_runtime_error_in_stdout(mock_popen_communicate):
             ),
         ),
     )
-    with pytest.raises(CosmosLoadDbtException) as err_info:
+    expected = r"Unable to run \['.+dbt', 'deps', .*\] due to the error:\nSome Runtime Error"
+    with pytest.raises(CosmosLoadDbtException, match=expected):
         dbt_graph.load_via_dbt_ls()
 
-    expected = "Unable to run dbt deps command due to the error:\nSome Runtime Error"
-    assert err_info.value.args[0] == expected
     mock_popen_communicate.assert_called_once()
 
 
@@ -574,3 +579,59 @@ def test_load_dbt_ls_and_manifest_with_model_version(load_method):
         "model.jaffle_shop.stg_orders.v1",
         "model.jaffle_shop.stg_payments",
     } == set(dbt_graph.nodes["model.jaffle_shop.orders"].depends_on)
+
+
+def test_create_symlinks(tmp_path):
+    """Tests that symlinks are created for expected files in the dbt project directory."""
+    tmp_dir = tmp_path / "dbt-project"
+    tmp_dir.mkdir()
+
+    create_symlinks(DBT_PROJECTS_ROOT_DIR / "jaffle_shop", tmp_dir)
+    for child in tmp_dir.iterdir():
+        assert child.is_symlink()
+        assert child.name not in ("logs", "target", "profiles.yml", "dbt_packages")
+
+
+@pytest.mark.parametrize(
+    "stdout,returncode",
+    [
+        ("all good", None),
+        pytest.param("fail", 599, marks=pytest.mark.xfail(raises=CosmosLoadDbtException)),
+        pytest.param("Error", None, marks=pytest.mark.xfail(raises=CosmosLoadDbtException)),
+    ],
+)
+@patch("cosmos.dbt.graph.Popen")
+def test_run_command(mock_popen, stdout, returncode):
+    fake_command = ["fake", "command"]
+    fake_dir = Path("fake_dir")
+    env_vars = {"fake": "env_var"}
+
+    mock_popen.return_value.communicate.return_value = (stdout, "")
+    mock_popen.return_value.returncode = returncode
+
+    return_value = run_command(fake_command, fake_dir, env_vars)
+    args, kwargs = mock_popen.call_args
+    assert args[0] == fake_command
+    assert kwargs["cwd"] == fake_dir
+    assert kwargs["env"] == env_vars
+
+    assert return_value == stdout
+
+
+def test_parse_dbt_ls_output():
+    fake_ls_stdout = '{"resource_type": "model", "name": "fake-name", "original_file_path": "fake-file-path.sql", "unique_id": "fake-unique-id", "tags": [], "config": {}}'
+
+    expected_nodes = {
+        "fake-unique-id": DbtNode(
+            name="fake-name",
+            unique_id="fake-unique-id",
+            resource_type=DbtResourceType.MODEL,
+            file_path=Path("fake-project/fake-file-path.sql"),
+            tags=[],
+            config={},
+            depends_on=[],
+        ),
+    }
+    nodes = parse_dbt_ls_output(Path("fake-project"), fake_ls_stdout)
+
+    assert expected_nodes == nodes
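---
Usage note: the helpers this patch extracts (`create_symlinks`, `run_command`, `parse_dbt_ls_output`)
are plain module-level functions and can be composed the same way `load_via_dbt_ls` does. Below is a
minimal sketch of that flow, not part of the patch; the `jaffle_shop` project location, the bare `dbt`
invocation, and the omission of the profile flags (`--profiles-dir`, `--profile`, `--target`) are
illustrative assumptions.

    import os
    import tempfile
    from pathlib import Path

    from cosmos.dbt.graph import create_symlinks, parse_dbt_ls_output, run_command

    project_path = Path("jaffle_shop")  # hypothetical project location

    with tempfile.TemporaryDirectory() as tmpdir:
        tmpdir_path = Path(tmpdir)
        # symlink the project into the scratch dir so dbt writes artifacts there,
        # leaving the original project untouched
        create_symlinks(project_path, tmpdir_path)
        # run_command raises CosmosLoadDbtException on a non-zero exit code or an
        # "Error" marker in stdout; profile flags are omitted here for brevity
        stdout = run_command(
            ["dbt", "ls", "--output", "json", "--project-dir", tmpdir],
            tmpdir_path,
            os.environ.copy(),
        )
        # keys are dbt unique_ids, values are DbtNode instances
        nodes = parse_dbt_ls_output(project_path, stdout)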