Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix parsing test nodes when using the custom load method (LoadMethod.CUSTOM) #563

Merged
1 commit merged on Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion cosmos/dbt/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,9 @@ def load_via_custom_parser(self) -> None:
operator_args=self.operator_args,
)
nodes = {}
models = itertools.chain(project.models.items(), project.snapshots.items(), project.seeds.items())
models = itertools.chain(
project.models.items(), project.snapshots.items(), project.seeds.items(), project.tests.items()
)
for model_name, model in models:
config = {item.split(":")[0]: item.split(":")[-1] for item in model.config.config_selectors}
node = DbtNode(
Expand Down
39 changes: 32 additions & 7 deletions cosmos/dbt/parser/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@

class DbtModelType(Enum):
    """
    Represents type of dbt unit (model, snapshot, seed, test)
    """

    # Each member's value is the lowercase name of the dbt unit kind.
    DBT_MODEL = "model"
    DBT_SNAPSHOT = "snapshot"
    DBT_SEED = "seed"
    DBT_TEST = "test"


@dataclass
Expand Down Expand Up @@ -155,8 +156,8 @@
code = code.split("%}")[1]
code = code.split("{%")[0]

elif self.type == DbtModelType.DBT_SEED:
code = ""
elif self.type == DbtModelType.DBT_SEED or self.type == DbtModelType.DBT_TEST:
return

if self.path.suffix == PYTHON_FILE_SUFFIX:
config.upstream_models = config.upstream_models.union(set(extract_python_file_upstream_requirements(code)))
Expand Down Expand Up @@ -250,6 +251,7 @@
models: Dict[str, DbtModel] = field(default_factory=dict)
snapshots: Dict[str, DbtModel] = field(default_factory=dict)
seeds: Dict[str, DbtModel] = field(default_factory=dict)
tests: Dict[str, DbtModel] = field(default_factory=dict)
project_dir: Path = field(init=False)
models_dir: Path = field(init=False)
snapshots_dir: Path = field(init=False)
Expand Down Expand Up @@ -349,19 +351,42 @@
config_dict = yaml.safe_load(path.read_text())

# iterate over the models in the config
if not (config_dict and config_dict.get("models")):
if not config_dict:
return

for model in config_dict["models"]:
for model in config_dict.get("models", []):
model_name = model.get("name")

# if the model doesn't exist, we can't do anything
if model_name not in self.models:
if not model_name:
continue

# tests
for column in model.get("columns", []):
for test in column.get("tests", []):
if not column.get("name"):
continue

[Codecov / codecov/patch annotation] Check warning on line 368 in cosmos/dbt/parser/project.py (cosmos/dbt/parser/project.py#L368): added line #L368 was not covered by tests. (View check run for this annotation.)

# Get the test name
if not isinstance(test, str):
test = list(test.keys())[0]

test_model = DbtModel(
name=f"{test}_{column['name']}_{model_name}",
type=DbtModelType.DBT_TEST,
path=path,
operator_args=self.operator_args,
config=DbtModelConfig(upstream_models=set({model_name})),
)

self.tests[test_model.name] = test_model

# config_selectors
if model_name not in self.models:
continue

config_selectors = []
for selector in self.models[model_name].config.config_types:
for selector in DbtModelConfig.config_types:
config_value = model.get("config", {}).get(selector)
if config_value:
if isinstance(config_value, str):
Expand Down
17 changes: 17 additions & 0 deletions tests/dbt/parser/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# Sample files inside the jaffle_shop project, resolved relative to DBT_PROJECT_PATH.
SAMPLE_CSV_PATH = DBT_PROJECT_PATH / "jaffle_shop/seeds/raw_customers.csv"
SAMPLE_MODEL_SQL_PATH = DBT_PROJECT_PATH / "jaffle_shop/models/customers.sql"
# NOTE(review): named as a snapshot fixture but points at models/orders.sql — confirm this is intentional.
SAMPLE_SNAPSHOT_SQL_PATH = DBT_PROJECT_PATH / "jaffle_shop/models/orders.sql"
# schema.yml passed to DbtProject._handle_config_file in the config-file test.
SAMPLE_YML_PATH = DBT_PROJECT_PATH / "jaffle_shop/models/schema.yml"


def test_dbtproject__handle_csv_file():
Expand Down Expand Up @@ -64,6 +65,22 @@ def test_dbtproject__handle_sql_file_snapshot():
assert raw_customers.path == SAMPLE_SNAPSHOT_SQL_PATH


def test_dbtproject__handle_config_file():
dbt_project = DbtProject(
project_name="jaffle_shop",
dbt_root_path=DBT_PROJECT_PATH,
)
dbt_project.tests = {}

dbt_project._handle_config_file(SAMPLE_YML_PATH)

assert len(dbt_project.tests) == 12
assert "not_null_customer_id_customers" in dbt_project.tests
sample_test = dbt_project.tests["not_null_customer_id_customers"]
assert sample_test.type == DbtModelType.DBT_TEST
assert sample_test.path == SAMPLE_YML_PATH


def test_dbtproject__handle_config_file_empty_file():
with NamedTemporaryFile("w") as tmp_fp:
tmp_fp.flush()
Expand Down
3 changes: 1 addition & 2 deletions tests/dbt/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,7 @@ def test_load_via_load_via_custom_parser(pipeline_name):
dbt_graph.load_via_custom_parser()

assert dbt_graph.nodes == dbt_graph.filtered_nodes
# the custom parser does not add dbt test nodes
assert len(dbt_graph.nodes) == 8
assert len(dbt_graph.nodes) == 28


@patch("cosmos.dbt.graph.DbtGraph.update_node_dependency", return_value=None)
Expand Down
Loading