Skip to content

Commit

Permalink
Add tests & minor refactorings
Browse files Browse the repository at this point in the history
  • Loading branch information
pankajkoti committed Sep 29, 2024
1 parent b57c4ed commit 3f00cc9
Show file tree
Hide file tree
Showing 9 changed files with 314 additions and 23 deletions.
12 changes: 11 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ on:
push: # Run on pushes to the default branch
branches: [main]
pull_request_target: # Also run on pull requests originated from forks
branches: [main]
branches: [main,poc-dbt-compile-task]

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
Expand Down Expand Up @@ -176,6 +176,8 @@ jobs:
POSTGRES_DB: postgres
POSTGRES_SCHEMA: public
POSTGRES_PORT: 5432
AIRFLOW__COSMOS__REMOTE_TARGET_PATH: "s3://cosmos-remote-cache/target_compiled/"
AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID: aws_s3_conn

- name: Upload coverage to Github
uses: actions/upload-artifact@v4
Expand Down Expand Up @@ -248,6 +250,8 @@ jobs:
POSTGRES_DB: postgres
POSTGRES_SCHEMA: public
POSTGRES_PORT: 5432
AIRFLOW__COSMOS__REMOTE_TARGET_PATH: "s3://cosmos-remote-cache/target_compiled/"
AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID: aws_s3_conn

- name: Upload coverage to Github
uses: actions/upload-artifact@v4
Expand Down Expand Up @@ -316,6 +320,8 @@ jobs:
POSTGRES_DB: postgres
POSTGRES_SCHEMA: public
POSTGRES_PORT: 5432
AIRFLOW__COSMOS__REMOTE_TARGET_PATH: "s3://cosmos-remote-cache/target_compiled/"
AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID: aws_s3_conn

- name: Upload coverage to Github
uses: actions/upload-artifact@v4
Expand Down Expand Up @@ -393,6 +399,8 @@ jobs:
POSTGRES_DB: postgres
POSTGRES_SCHEMA: public
POSTGRES_PORT: 5432
AIRFLOW__COSMOS__REMOTE_TARGET_PATH: "s3://cosmos-remote-cache/target_compiled/"
AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID: aws_s3_conn

- name: Upload coverage to Github
uses: actions/upload-artifact@v4
Expand Down Expand Up @@ -537,6 +545,8 @@ jobs:
POSTGRES_DB: postgres
POSTGRES_SCHEMA: public
POSTGRES_PORT: 5432
AIRFLOW__COSMOS__REMOTE_TARGET_PATH: "s3://cosmos-remote-cache/target_compiled/"
AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID: aws_s3_conn

- name: Upload coverage to Github
uses: actions/upload-artifact@v4
Expand Down
23 changes: 2 additions & 21 deletions cosmos/operators/airflow_async.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
from typing import Any

from cosmos.operators.base import DbtCompileMixin
from cosmos.operators.local import (
DbtBuildLocalOperator,
DbtDepsLocalOperator,
DbtCompileLocalOperator,
DbtDocsAzureStorageLocalOperator,
DbtDocsCloudLocalOperator,
DbtDocsGCSLocalOperator,
DbtDocsLocalOperator,
DbtDocsS3LocalOperator,
DbtLocalBaseOperator,
DbtLSLocalOperator,
DbtRunLocalOperator,
DbtRunOperationLocalOperator,
Expand Down Expand Up @@ -56,10 +51,6 @@ class DbtDocsAirflowAsyncOperator(DbtDocsLocalOperator):
pass


class DbtDocsCloudAirflowAsyncOperator(DbtDocsCloudLocalOperator):
pass


class DbtDocsS3AirflowAsyncOperator(DbtDocsS3LocalOperator):
pass

Expand All @@ -72,15 +63,5 @@ class DbtDocsGCSAirflowAsyncOperator(DbtDocsGCSLocalOperator):
pass


class DbtDepsAirflowAsyncOperator(DbtDepsLocalOperator):
class DbtCompileAirflowAsyncOperator(DbtCompileLocalOperator):
pass


class DbtCompileAirflowAsyncOperator(DbtCompileMixin, DbtLocalBaseOperator):
"""
Executes a dbt core build command.
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
kwargs["should_upload_compiled_sql"] = True
super().__init__(*args, **kwargs)
7 changes: 7 additions & 0 deletions cosmos/operators/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
from cosmos.operators.base import (
AbstractDbtBaseOperator,
DbtBuildMixin,
DbtCompileMixin,
DbtLSMixin,
DbtRunMixin,
DbtRunOperationMixin,
Expand Down Expand Up @@ -948,3 +949,9 @@ def __init__(self, **kwargs: str) -> None:
raise DeprecationWarning(
"The DbtDepsOperator has been deprecated. " "Please use the `install_deps` flag in dbt_args instead."
)


class DbtCompileLocalOperator(DbtCompileMixin, DbtLocalBaseOperator):
def __init__(self, *args: Any, **kwargs: Any) -> None:
kwargs["should_upload_compiled_sql"] = True
super().__init__(*args, **kwargs)
36 changes: 36 additions & 0 deletions dev/dags/simple_dag_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import os
from datetime import datetime
from pathlib import Path

from cosmos import DbtDag, ExecutionConfig, ExecutionMode, ProfileConfig, ProjectConfig
from cosmos.profiles import PostgresUserPasswordProfileMapping

DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt"
DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH))

profile_config = ProfileConfig(
profile_name="default",
target_name="dev",
profile_mapping=PostgresUserPasswordProfileMapping(
conn_id="example_conn",
profile_args={"schema": "public"},
disable_event_tracking=True,
),
)

simple_dag_async = DbtDag(
# dbt/cosmos-specific parameters
project_config=ProjectConfig(
DBT_ROOT_PATH / "jaffle_shop",
),
profile_config=profile_config,
execution_config=ExecutionConfig(
execution_mode=ExecutionMode.AIRFLOW_ASYNC,
),
# normal dag parameters
schedule_interval=None,
start_date=datetime(2023, 1, 1),
catchup=False,
dag_id="simple_dag_async",
tags=["simple"],
)
36 changes: 36 additions & 0 deletions tests/airflow/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from cosmos.converter import airflow_kwargs
from cosmos.dbt.graph import DbtNode
from cosmos.profiles import PostgresUserPasswordProfileMapping
from cosmos.settings import dbt_compile_task_id

SAMPLE_PROJ_PATH = Path("/home/user/path/dbt-proj/")
SOURCE_RENDERING_BEHAVIOR = SourceRenderingBehavior(os.getenv("SOURCE_RENDERING_BEHAVIOR", "none"))
Expand Down Expand Up @@ -226,6 +227,41 @@ def test_build_airflow_graph_with_after_all():
assert dag.leaves[0].select == ["tag:some"]


@pytest.mark.integration
def test_build_airflow_graph_with_dbt_compile_task():
with DAG("test-id-dbt-compile", start_date=datetime(2022, 1, 1)) as dag:
task_args = {
"project_dir": SAMPLE_PROJ_PATH,
"conn_id": "fake_conn",
"profile_config": ProfileConfig(
profile_name="default",
target_name="default",
profile_mapping=PostgresUserPasswordProfileMapping(
conn_id="fake_conn",
profile_args={"schema": "public"},
),
),
}
render_config = RenderConfig(
select=["tag:some"],
test_behavior=TestBehavior.AFTER_ALL,
source_rendering_behavior=SOURCE_RENDERING_BEHAVIOR,
)
build_airflow_graph(
nodes=sample_nodes,
dag=dag,
execution_mode=ExecutionMode.AIRFLOW_ASYNC,
test_indirect_selection=TestIndirectSelection.EAGER,
task_args=task_args,
dbt_project_name="astro_shop",
render_config=render_config,
)

task_ids = [task.task_id for task in dag.tasks]
assert dbt_compile_task_id in task_ids
assert dbt_compile_task_id in dag.tasks[0].upstream_task_ids


def test_calculate_operator_class():
class_module_import_path = calculate_operator_class(execution_mode=ExecutionMode.KUBERNETES, dbt_class="DbtSeed")
assert class_module_import_path == "cosmos.operators.kubernetes.DbtSeedKubernetesOperator"
Expand Down
82 changes: 82 additions & 0 deletions tests/operators/test_airflow_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from cosmos.operators.airflow_async import (
DbtBuildAirflowAsyncOperator,
DbtCompileAirflowAsyncOperator,
DbtDocsAirflowAsyncOperator,
DbtDocsAzureStorageAirflowAsyncOperator,
DbtDocsGCSAirflowAsyncOperator,
DbtDocsS3AirflowAsyncOperator,
DbtLSAirflowAsyncOperator,
DbtRunAirflowAsyncOperator,
DbtRunOperationAirflowAsyncOperator,
DbtSeedAirflowAsyncOperator,
DbtSnapshotAirflowAsyncOperator,
DbtSourceAirflowAsyncOperator,
DbtTestAirflowAsyncOperator,
)
from cosmos.operators.local import (
DbtBuildLocalOperator,
DbtCompileLocalOperator,
DbtDocsAzureStorageLocalOperator,
DbtDocsGCSLocalOperator,
DbtDocsLocalOperator,
DbtDocsS3LocalOperator,
DbtLSLocalOperator,
DbtRunLocalOperator,
DbtRunOperationLocalOperator,
DbtSeedLocalOperator,
DbtSnapshotLocalOperator,
DbtSourceLocalOperator,
DbtTestLocalOperator,
)


def test_dbt_build_airflow_async_operator_inheritance():
assert issubclass(DbtBuildAirflowAsyncOperator, DbtBuildLocalOperator)


def test_dbt_ls_airflow_async_operator_inheritance():
assert issubclass(DbtLSAirflowAsyncOperator, DbtLSLocalOperator)


def test_dbt_seed_airflow_async_operator_inheritance():
assert issubclass(DbtSeedAirflowAsyncOperator, DbtSeedLocalOperator)


def test_dbt_snapshot_airflow_async_operator_inheritance():
assert issubclass(DbtSnapshotAirflowAsyncOperator, DbtSnapshotLocalOperator)


def test_dbt_source_airflow_async_operator_inheritance():
assert issubclass(DbtSourceAirflowAsyncOperator, DbtSourceLocalOperator)


def test_dbt_run_airflow_async_operator_inheritance():
assert issubclass(DbtRunAirflowAsyncOperator, DbtRunLocalOperator)


def test_dbt_test_airflow_async_operator_inheritance():
assert issubclass(DbtTestAirflowAsyncOperator, DbtTestLocalOperator)


def test_dbt_run_operation_airflow_async_operator_inheritance():
assert issubclass(DbtRunOperationAirflowAsyncOperator, DbtRunOperationLocalOperator)


def test_dbt_docs_airflow_async_operator_inheritance():
assert issubclass(DbtDocsAirflowAsyncOperator, DbtDocsLocalOperator)


def test_dbt_docs_s3_airflow_async_operator_inheritance():
assert issubclass(DbtDocsS3AirflowAsyncOperator, DbtDocsS3LocalOperator)


def test_dbt_docs_azure_storage_airflow_async_operator_inheritance():
assert issubclass(DbtDocsAzureStorageAirflowAsyncOperator, DbtDocsAzureStorageLocalOperator)


def test_dbt_docs_gcs_airflow_async_operator_inheritance():
assert issubclass(DbtDocsGCSAirflowAsyncOperator, DbtDocsGCSLocalOperator)


def test_dbt_compile_airflow_async_operator_inheritance():
assert issubclass(DbtCompileAirflowAsyncOperator, DbtCompileLocalOperator)
2 changes: 2 additions & 0 deletions tests/operators/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from cosmos.operators.base import (
AbstractDbtBaseOperator,
DbtBuildMixin,
DbtCompileMixin,
DbtLSMixin,
DbtRunMixin,
DbtRunOperationMixin,
Expand Down Expand Up @@ -143,6 +144,7 @@ def test_dbt_base_operator_context_merge(
("seed", DbtSeedMixin),
("run", DbtRunMixin),
("build", DbtBuildMixin),
("compile", DbtCompileMixin),
],
)
def test_dbt_mixin_base_cmd(dbt_command, dbt_operator_class):
Expand Down
Loading

0 comments on commit 3f00cc9

Please sign in to comment.