diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 1c00f476c..145741096 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -397,6 +397,8 @@ class DbtSeedLocalOperator(DbtLocalBaseOperator): ui_color = "#F58D7E" + template_fields: Sequence[str] = DbtBaseOperator.template_fields + ("full_refresh",) # type: ignore[operator] + def __init__(self, full_refresh: bool = False, **kwargs: Any) -> None: self.full_refresh = full_refresh super().__init__(**kwargs) @@ -434,6 +436,7 @@ class DbtRunLocalOperator(DbtLocalBaseOperator): ui_color = "#7352BA" ui_fgcolor = "#F4F2FC" + template_fields: Sequence[str] = DbtBaseOperator.template_fields + ("full_refresh",) # type: ignore[operator] def __init__(self, full_refresh: bool = False, **kwargs: Any) -> None: self.full_refresh = full_refresh diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py index 580d49e6c..14213b335 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -420,6 +420,19 @@ def test_calculate_openlineage_events_completes_openlineage_errors(mock_processo assert err_msg in caplog.text +@pytest.mark.parametrize( + "operator_class,expected_template", + [ + (DbtSeedLocalOperator, ("env", "vars", "full_refresh")), + (DbtRunLocalOperator, ("env", "vars", "full_refresh")), + ], +) +def test_dbt_base_operator_template_fields(operator_class, expected_template): + # Check if value of template fields is what we expect for the operators we're validating + dbt_base_operator = operator_class(profile_config=profile_config, task_id="my-task", project_dir="my/dir") + assert dbt_base_operator.template_fields == expected_template + + @patch.object(DbtDocsGCSLocalOperator, "required_files", ["file1", "file2"]) def test_dbt_docs_gcs_local_operator(): mock_gcs = MagicMock()