Skip to content

Commit

Permalink
Add var value to upstream model
Browse files: browse the repository at this point in the history
  • Loading branch information
dojinkimm committed Aug 31, 2023
1 parent 9387c45 commit b0c2d9b
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 17 deletions.
1 change: 1 addition & 0 deletions cosmos/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def __init__(
select=select,
dbt_cmd=dbt_executable_path,
profile_config=profile_config,
operator_args=operator_args,
)
dbt_graph.load(method=load_mode, execution_mode=execution_mode)

Expand Down
3 changes: 3 additions & 0 deletions cosmos/dbt/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,13 @@ def __init__(
select: list[str] | None = None,
dbt_cmd: str = get_system_dbt(),
profile_config: ProfileConfig | None = None,
operator_args: dict[str, Any] | None = None,
):
self.project = project
self.exclude = exclude or []
self.select = select or []
self.profile_config = profile_config
self.operator_args = operator_args or {}

# specific to loading using ls
self.dbt_cmd = dbt_cmd
Expand Down Expand Up @@ -253,6 +255,7 @@ def load_via_custom_parser(self) -> None:
dbt_models_dir=self.project.models_dir.stem if self.project.models_dir else None,
dbt_seeds_dir=self.project.seeds_dir.stem if self.project.seeds_dir else None,
project_name=self.project.name,
operator_args=operator_args,
)
nodes = {}
models = itertools.chain(project.models.items(), project.snapshots.items(), project.seeds.items())
Expand Down
53 changes: 36 additions & 17 deletions cosmos/dbt/parser/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ class DbtModel:
name: str
type: DbtModelType
path: Path
operator_args: Dict[str, str] = field(default_factory=dict)
config: DbtModelConfig = field(default_factory=DbtModelConfig)

def __post_init__(self) -> None:
Expand All @@ -137,6 +138,7 @@ def __post_init__(self) -> None:
"""
# first, get an empty config
config = DbtModelConfig()
var_args = self.operator_args.get("vars", {})

if self.type == DbtModelType.DBT_MODEL:
# get the code from the file
Expand Down Expand Up @@ -165,23 +167,35 @@ def __post_init__(self) -> None:
# iterate over the jinja nodes to extract info
for base_node in jinja2_ast.find_all(jinja2.nodes.Call):
if hasattr(base_node.node, "name"):
# check we have a ref - this indicates a dependency
if base_node.node.name == "ref":
# if it is, get the first argument
first_arg = base_node.args[0]
if isinstance(first_arg, jinja2.nodes.Const):
# and add it to the config
config.upstream_models.add(first_arg.value)

# check if we have a config - this could contain tags
if base_node.node.name == "config":
# if it is, check if any kwargs are tags
for kwarg in base_node.kwargs:
for selector in self.config.config_types:
extracted_config = self._extract_config(kwarg=kwarg, config_name=selector)
config.config_selectors |= (
set(extracted_config) if isinstance(extracted_config, (str, List)) else set()
)
try:
# check we have a ref - this indicates a dependency
if base_node.node.name == "ref":
# if it is, get the first argument
first_arg = base_node.args[0]
# if it contains vars, render the value of the var
if isinstance(first_arg, jinja2.nodes.Concat):
value = ''
for node in base_node.args[0].nodes:
if isinstance(node, jinja2.nodes.Const):
value += node.value
elif isinstance(node, jinja2.nodes.Call) and node.node.name == 'var':
value += var_args[node.args[0].value]
config.upstream_models.add(value)
elif isinstance(first_arg, jinja2.nodes.Const):
# and add it to the config
config.upstream_models.add(first_arg.value)

# check if we have a config - this could contain tags
if base_node.node.name == "config":
# if it is, check if any kwargs are tags
for kwarg in base_node.kwargs:
for selector in self.config.config_types:
extracted_config = self._extract_config(kwarg=kwarg, config_name=selector)
config.config_selectors |= (
set(extracted_config) if isinstance(extracted_config, (str, List)) else set()
)
except KeyError as e:
logger.warning(f"Could not add upstream model for config in {self.path}: {e}")

# set the config and set the parsed file flag to true
self.config = config
Expand Down Expand Up @@ -235,6 +249,8 @@ class DbtProject:
models_dir: Path = field(init=False)
snapshots_dir: Path = field(init=False)
seeds_dir: Path = field(init=False)

operator_args: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self) -> None:
"""
Expand Down Expand Up @@ -287,6 +303,7 @@ def _handle_csv_file(self, path: Path) -> None:
name=model_name,
type=DbtModelType.DBT_SEED,
path=path,
operator_args=self.operator_args,
)
# add the model to the project
self.seeds[model_name] = model
Expand All @@ -304,6 +321,7 @@ def _handle_sql_file(self, path: Path) -> None:
name=model_name,
type=DbtModelType.DBT_MODEL,
path=path,
operator_args=self.operator_args,
)
# add the model to the project
self.models[model.name] = model
Expand All @@ -313,6 +331,7 @@ def _handle_sql_file(self, path: Path) -> None:
name=model_name,
type=DbtModelType.DBT_SNAPSHOT,
path=path,
operator_args=self.operator_args,
)
# add the snapshot to the project
self.snapshots[model.name] = model
Expand Down

0 comments on commit b0c2d9b

Please sign in to comment.