From 9cd46d2353dff874cc94c66baa0b146255ec42f5 Mon Sep 17 00:00:00 2001 From: Joppe Vos <44348300+joppevos@users.noreply.github.com> Date: Wed, 8 Nov 2023 16:29:49 +0100 Subject: [PATCH] Add `operator_args` `full_refresh` as a templated field (#623) This allows you to fully refresh a model from the console. Full-refresh/backfill is a common task. Using Airflow parameters makes this easy. Without this, you'd have to trigger an entire deployment. In our setup, company analysts manage their models without modifying the DAG code. This empowers such users. Example of usage: ```python with DAG( dag_id="jaffle", params={"full_refresh": Param(default=False, type="boolean")}, render_template_as_native_obj=True ): task = DbtTaskGroup( operator_args={"full_refresh": "{{ params.get('full_refresh') }}", "install_deps": True}, ) ``` Closes: #151 --- cosmos/operators/local.py | 3 +++ tests/operators/test_local.py | 13 +++++++++++++ 2 files changed, 16 insertions(+) diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 1c00f476c..145741096 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -397,6 +397,8 @@ class DbtSeedLocalOperator(DbtLocalBaseOperator): ui_color = "#F58D7E" + template_fields: Sequence[str] = DbtBaseOperator.template_fields + ("full_refresh",) # type: ignore[operator] + def __init__(self, full_refresh: bool = False, **kwargs: Any) -> None: self.full_refresh = full_refresh super().__init__(**kwargs) @@ -434,6 +436,7 @@ class DbtRunLocalOperator(DbtLocalBaseOperator): ui_color = "#7352BA" ui_fgcolor = "#F4F2FC" + template_fields: Sequence[str] = DbtBaseOperator.template_fields + ("full_refresh",) # type: ignore[operator] def __init__(self, full_refresh: bool = False, **kwargs: Any) -> None: self.full_refresh = full_refresh diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py index 580d49e6c..14213b335 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -420,6 +420,19 @@ def test_calculate_openlineage_events_completes_openlineage_errors(mock_processo assert err_msg in caplog.text +@pytest.mark.parametrize( + "operator_class,expected_template", + [ + (DbtSeedLocalOperator, ("env", "vars", "full_refresh")), + (DbtRunLocalOperator, ("env", "vars", "full_refresh")), + ], +) +def test_dbt_base_operator_template_fields(operator_class, expected_template): + # Check if value of template fields is what we expect for the operators we're validating + dbt_base_operator = operator_class(profile_config=profile_config, task_id="my-task", project_dir="my/dir") + assert dbt_base_operator.template_fields == expected_template + + @patch.object(DbtDocsGCSLocalOperator, "required_files", ["file1", "file2"]) def test_dbt_docs_gcs_local_operator(): mock_gcs = MagicMock()