diff --git a/cosmos/dbt/parser/project.py b/cosmos/dbt/parser/project.py index ad6e78e44..15a515287 100644 --- a/cosmos/dbt/parser/project.py +++ b/cosmos/dbt/parser/project.py @@ -174,11 +174,16 @@ def __post_init__(self) -> None: first_arg = base_node.args[0] # if it contains vars, render the value of the var if isinstance(first_arg, jinja2.nodes.Concat): +<<<<<<< HEAD value = '' for node in first_arg.nodes: +======= + value = "" + for node in base_node.args[0].nodes: +>>>>>>> d71cdf07ea2f6f076ecca68e9b3040829f6e5dba if isinstance(node, jinja2.nodes.Const): value += node.value - elif isinstance(node, jinja2.nodes.Call) and node.node.name == 'var': + elif isinstance(node, jinja2.nodes.Call) and node.node.name == "var": value += var_args[node.args[0].value] config.upstream_models.add(value) elif isinstance(first_arg, jinja2.nodes.Const): @@ -249,7 +254,7 @@ class DbtProject: models_dir: Path = field(init=False) snapshots_dir: Path = field(init=False) seeds_dir: Path = field(init=False) - + operator_args: Dict[str, Any] = field(default_factory=dict) def __post_init__(self) -> None: diff --git a/tests/dbt/parser/test_project.py b/tests/dbt/parser/test_project.py index b8ef51f1a..5fb00c40c 100644 --- a/tests/dbt/parser/test_project.py +++ b/tests/dbt/parser/test_project.py @@ -198,5 +198,10 @@ def test_dbtmodelconfig_with_vars(tmp_path): path_with_sources = tmp_path / "customers_with_sources.sql" path_with_sources.write_text(model_with_sources_sql) - dbt_model = DbtModel(name="some_name", type=DbtModelType.DBT_MODEL, path=path_with_sources, operator_args={"vars": {"country_code": "us"}}) + dbt_model = DbtModel( + name="some_name", + type=DbtModelType.DBT_MODEL, + path=path_with_sources, + operator_args={"vars": {"country_code": "us"}}, + ) assert "stg_customers_us" in dbt_model.config.upstream_models