diff --git a/cosmos/core/airflow.py b/cosmos/core/airflow.py
index 458b8fa56..2782412ac 100644
--- a/cosmos/core/airflow.py
+++ b/cosmos/core/airflow.py
@@ -9,7 +9,8 @@
 from packaging.version import Version

 from cosmos.core.graph.entities import Task
-from cosmos.dataset import get_dataset_alias_name
+
+# from cosmos.dataset import get_dataset_alias_name
 from cosmos.log import get_logger

 logger = get_logger(__name__)
@@ -35,23 +36,6 @@ def get_airflow_task(task: Task, dag: DAG, task_group: TaskGroup | None = None)
     if task.owner != "":
         task_kwargs["owner"] = task.owner

-    if module_name.split(".")[-1] == "local" and AIRFLOW_VERSION >= Version("2.10"):
-        from airflow.datasets import DatasetAlias
-
-        # ignoring the type because older versions of Airflow raise the follow error in mypy
-        # error: Incompatible types in assignment (expression has type "list[DatasetAlias]", target has type "str") [assignment] Found 1 error in 1 file (checked 3 source files)
-        task_kwargs["outlets"] = [DatasetAlias(name=get_dataset_alias_name(dag, task_group, task.id))]  # type: ignore
-
-    logger.info("HELP ME!!!")
-    logger.info(module_name)
-    logger.info(Operator)
-    logger.info(task.id)
-    logger.info(dag)
-    logger.info(task_group)
-    logger.info(task_kwargs)
-    logger.info({} if class_name == "EmptyOperator" else {"extra_context": task.extra_context})
-    logger.info(task.arguments)
-
     airflow_task = Operator(
         task_id=task.id,
         dag=dag,
diff --git a/cosmos/dataset.py b/cosmos/dataset.py
index 4ab727fa9..2a308c54e 100644
--- a/cosmos/dataset.py
+++ b/cosmos/dataset.py
@@ -17,16 +17,17 @@ def get_dataset_alias_name(dag: DAG | None, task_group: TaskGroup | None, task_id: str) -> str:
         dag_id = task_group.dag_id
         if task_group.group_id is not None:
             task_group_id = task_group.group_id
+            task_group_id = task_group_id.replace(".", "__")
     elif dag:
         dag_id = dag.dag_id

     identifiers_list = []
-    dag_id = dag_id
-    task_group_id = task_group_id

     if dag_id:
         identifiers_list.append(dag_id)
     if task_group_id:
-        identifiers_list.append(task_group_id.replace(".", "__"))
+        identifiers_list.append(task_group_id)
+
+    identifiers_list.append(task_id)

     return "__".join(identifiers_list)
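A quick sanity check of the refactored helper (reviewer sketch, not part of the patch; the DAG/TaskGroup stand-ins and their ids are hypothetical, and real call sites pass actual DAG and TaskGroup objects):

    from types import SimpleNamespace

    from cosmos.dataset import get_dataset_alias_name

    dag = SimpleNamespace(dag_id="jaffle_shop")  # hypothetical dag_id
    task_group = SimpleNamespace(dag_id="jaffle_shop", group_id="transform.staging")  # hypothetical ids

    # DAG only: "<dag_id>__<task_id>"
    assert get_dataset_alias_name(dag, None, "run_customers") == "jaffle_shop__run_customers"

    # TaskGroup: dots in the group_id are normalised to "__" before the parts are joined
    assert (
        get_dataset_alias_name(None, task_group, "run_customers")
        == "jaffle_shop__transform__staging__run_customers"
    )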
diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py
index 58fa5f63b..a593a2fea 100644
--- a/cosmos/operators/local.py
+++ b/cosmos/operators/local.py
@@ -131,6 +131,7 @@ class DbtLocalBaseOperator(AbstractDbtBaseOperator):

     def __init__(
         self,
+        task_id: str,
         profile_config: ProfileConfig,
         invocation_mode: InvocationMode | None = None,
         install_deps: bool = False,
@@ -139,6 +140,7 @@ def __init__(
         append_env: bool = True,
         **kwargs: Any,
     ) -> None:
+        self.task_id = task_id
         self.profile_config = profile_config
         self.callback = callback
         self.compiled_sql = ""
@@ -151,7 +153,19 @@ def __init__(
         self._dbt_runner: dbtRunner | None = None
         if self.invocation_mode:
             self._set_invocation_methods()
-        super().__init__(**kwargs)
+
+        if AIRFLOW_VERSION >= Version("2.10"):
+            from airflow.datasets import DatasetAlias
+
+            # Ignoring the type because older versions of Airflow raise the following error in mypy:
+            # error: Incompatible types in assignment (expression has type "list[DatasetAlias]", target has type "str")
+            dag = kwargs.get("dag")
+            task_group = kwargs.get("task_group")
+            kwargs["outlets"] = [  # type: ignore
+                DatasetAlias(name=get_dataset_alias_name(dag, task_group, task_id))
+            ]
+
+        super().__init__(task_id=task_id, **kwargs)

     # For local execution mode, we're consistent with the LoadMode.DBT_LS command in forwarding the environment
     # variables to the subprocess by default. Although this behavior is designed for ExecuteMode.LOCAL and
@@ -478,6 +492,8 @@ def register_dataset(self, new_inlets: list[Dataset], new_outlets: list[Dataset]
         """
         Register a list of datasets as outlets of the current task.
         Until Airflow 2.7, there was not a better interface to associate outlets to a task during execution.
+        Before Airflow 2.10, this approach works with a few limitations, as described in the ticket:
+        TODO: add the link to the GH issue related to orphaned nodes
         """
         if AIRFLOW_VERSION < Version("2.10"):
             logger.info("Assigning inlets/outlets without DatasetAlias")
diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py
index effd604fa..e89615bd7 100644
--- a/tests/operators/test_local.py
+++ b/tests/operators/test_local.py
@@ -411,8 +411,10 @@ def test_dbt_test_local_operator_invocation_mode_methods(mock_extract_log_issues
 @pytest.mark.skipif(
     version.parse(airflow_version) < version.parse("2.4")
+    or version.parse(airflow_version) >= version.parse("2.10")
     or version.parse(airflow_version) in PARTIALLY_SUPPORTED_AIRFLOW_VERSIONS,
-    reason="Airflow DAG did not have datasets until the 2.4 release, inlets and outlets do not work by default in Airflow 2.9.0 and 2.9.1",
+    reason="Airflow DAG did not have datasets until the 2.4 release; inlets and outlets do not work by default in Airflow 2.9.0 and 2.9.1.\n"
+    "From Airflow 2.10 onwards, we started using DatasetAlias, which changed this behaviour.",
 )
 @pytest.mark.integration
 def test_run_operator_dataset_inlets_and_outlets(caplog):
@@ -453,6 +455,59 @@ def test_run_operator_dataset_inlets_and_outlets(caplog):
     assert test_operator.outlets == []


+@pytest.mark.skipif(
+    version.parse(airflow_version) < version.parse("2.10"),
+    reason="From Airflow 2.10 onwards, we started using DatasetAlias, which changed this behaviour.",
+)
+@pytest.mark.integration
+def test_run_operator_dataset_inlets_and_outlets_airflow_210_onwards(caplog):
+    from airflow.models.dataset import DatasetAliasModel, DatasetModel
+    from sqlalchemy import select
+
+    with DAG("test-id-1", start_date=datetime(2022, 1, 1)) as dag:
+        seed_operator = DbtSeedLocalOperator(
+            profile_config=real_profile_config,
+            project_dir=DBT_PROJ_DIR,
+            task_id="seed",
+            dag=dag,
+            dbt_cmd_flags=["--select", "raw_customers"],
+            install_deps=True,
+            append_env=True,
+        )
+        run_operator = DbtRunLocalOperator(
+            profile_config=real_profile_config,
+            project_dir=DBT_PROJ_DIR,
+            task_id="run",
+            dag=dag,
+            dbt_cmd_flags=["--models", "stg_customers"],
+            install_deps=True,
+            append_env=True,
+        )
+        test_operator = DbtTestLocalOperator(
+            profile_config=real_profile_config,
+            project_dir=DBT_PROJ_DIR,
+            task_id="test",
+            dag=dag,
+            dbt_cmd_flags=["--models", "stg_customers"],
+            install_deps=True,
+            append_env=True,
+        )
+        seed_operator >> run_operator >> test_operator
+
+    dag_run, session = run_test_dag(dag)
+
+    # The run should have registered both the concrete datasets and the aliases pointing at them.
+    assert session.scalars(select(DatasetModel)).all()
+    assert session.scalars(select(DatasetAliasModel)).all()
+
+
 @pytest.mark.skipif(
     version.parse(airflow_version) not in PARTIALLY_SUPPORTED_AIRFLOW_VERSIONS,
     reason="Airflow 2.9.0 and 2.9.1 have a breaking change in Dataset URIs",
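For reviewers unfamiliar with the mechanism this patch moves to: from Airflow 2.10 onwards, a DatasetAlias declared in outlets at parse time is resolved to concrete datasets during execution via outlet_events. A minimal standalone sketch of that pattern, following the Airflow 2.10 dataset-alias API as documented upstream (the alias name and dataset URI are illustrative; this is not Cosmos code):

    from airflow.datasets import Dataset, DatasetAlias
    from airflow.decorators import task


    @task(outlets=[DatasetAlias("my_dag__run_customers")])  # hypothetical alias name
    def emit_via_alias(*, outlet_events):
        # Attach the concrete dataset to the alias at runtime. Airflow records
        # the association, so downstream DAGs can schedule on either the alias
        # or the resolved dataset URI instead of ending up with orphaned nodes.
        outlet_events[DatasetAlias("my_dag__run_customers")].add(
            Dataset("postgres://host:5432/database.schema.stg_customers")  # hypothetical URI
        )

This is why DbtLocalBaseOperator.__init__ can set outlets to a single DatasetAlias up front: register_dataset can then attach the real dbt datasets after the run completes, without mutating the task's outlets post-parse.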