Skip to content

Commit

Permalink
Add operator_args full_refresh as a templated field (#623)
Browse files Browse the repository at this point in the history
This allows you to fully refresh a model from the console.
Full-refresh/backfill is a common task. Using Airflow parameters makes
this easy. Without this, you'd have to trigger an entire deployment. In our
setup, company analysts manage their models without modifying
the DAG code. This empowers such users.

Example of usage:
```python
with DAG(
        dag_id="jaffle",
        params={"full_refresh": Param(default=False, type="boolean")},
        render_template_as_native_obj=True
):
    task = DbtTaskGroup(
        operator_args={"full_refresh": "{{ params.get('full_refresh') }}", "install_deps": True},
    )

```

Closes: #151
  • Loading branch information
joppevos authored Nov 8, 2023
1 parent 188fe56 commit 9cd46d2
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
3 changes: 3 additions & 0 deletions cosmos/operators/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,8 @@ class DbtSeedLocalOperator(DbtLocalBaseOperator):

ui_color = "#F58D7E"

template_fields: Sequence[str] = DbtBaseOperator.template_fields + ("full_refresh",) # type: ignore[operator]

def __init__(self, full_refresh: bool = False, **kwargs: Any) -> None:
self.full_refresh = full_refresh
super().__init__(**kwargs)
Expand Down Expand Up @@ -434,6 +436,7 @@ class DbtRunLocalOperator(DbtLocalBaseOperator):

ui_color = "#7352BA"
ui_fgcolor = "#F4F2FC"
template_fields: Sequence[str] = DbtBaseOperator.template_fields + ("full_refresh",) # type: ignore[operator]

def __init__(self, full_refresh: bool = False, **kwargs: Any) -> None:
self.full_refresh = full_refresh
Expand Down
13 changes: 13 additions & 0 deletions tests/operators/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,19 @@ def test_calculate_openlineage_events_completes_openlineage_errors(mock_processo
assert err_msg in caplog.text


@pytest.mark.parametrize(
"operator_class,expected_template",
[
(DbtSeedLocalOperator, ("env", "vars", "full_refresh")),
(DbtRunLocalOperator, ("env", "vars", "full_refresh")),
],
)
def test_dbt_base_operator_template_fields(operator_class, expected_template):
# Check if value of template fields is what we expect for the operators we're validating
dbt_base_operator = operator_class(profile_config=profile_config, task_id="my-task", project_dir="my/dir")
assert dbt_base_operator.template_fields == expected_template


@patch.object(DbtDocsGCSLocalOperator, "required_files", ["file1", "file2"])
def test_dbt_docs_gcs_local_operator():
mock_gcs = MagicMock()
Expand Down

0 comments on commit 9cd46d2

Please sign in to comment.