From d71cdf07ea2f6f076ecca68e9b3040829f6e5dba Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 31 Aug 2023 14:36:24 +0000 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20format?= =?UTF-8?q?=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cosmos/dbt/parser/project.py | 6 +++--- tests/dbt/parser/test_project.py | 7 ++++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/cosmos/dbt/parser/project.py b/cosmos/dbt/parser/project.py index 4a0d7fe33..99405bcbc 100644 --- a/cosmos/dbt/parser/project.py +++ b/cosmos/dbt/parser/project.py @@ -174,11 +174,11 @@ def __post_init__(self) -> None: first_arg = base_node.args[0] # if it contains vars, render the value of the var if isinstance(first_arg, jinja2.nodes.Concat): - value = '' + value = "" for node in base_node.args[0].nodes: if isinstance(node, jinja2.nodes.Const): value += node.value - elif isinstance(node, jinja2.nodes.Call) and node.node.name == 'var': + elif isinstance(node, jinja2.nodes.Call) and node.node.name == "var": value += var_args[node.args[0].value] config.upstream_models.add(value) elif isinstance(first_arg, jinja2.nodes.Const): @@ -249,7 +249,7 @@ class DbtProject: models_dir: Path = field(init=False) snapshots_dir: Path = field(init=False) seeds_dir: Path = field(init=False) - + operator_args: Dict[str, Any] = field(default_factory=dict) def __post_init__(self) -> None: diff --git a/tests/dbt/parser/test_project.py b/tests/dbt/parser/test_project.py index d4177e82e..a3e592b9d 100644 --- a/tests/dbt/parser/test_project.py +++ b/tests/dbt/parser/test_project.py @@ -198,5 +198,10 @@ def test_dbtmodelconfig_with_vars(tmp_path): path_with_sources = tmp_path / "customers_with_sources.sql" path_with_sources.write_text(model_with_sources_sql) - dbt_model = DbtModel(name="some_name", type=DbtModelType.DBT_MODEL, path=path_with_sources, operator_args={"vars": {"country_code": "us"}}) + dbt_model = DbtModel( + name="some_name", + type=DbtModelType.DBT_MODEL, + path=path_with_sources, + operator_args={"vars": {"country_code": "us"}}, + ) assert "stg_customers_us" in dbt_model.config.upstream_models