diff --git a/cosmos/core/airflow.py b/cosmos/core/airflow.py
index 9e1d08ac1..1bdce9361 100644
--- a/cosmos/core/airflow.py
+++ b/cosmos/core/airflow.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import importlib
 
 from airflow.models import BaseOperator
@@ -10,7 +12,7 @@
 logger = get_logger(__name__)
 
 
-def get_airflow_task(task: Task, dag: DAG, task_group: "TaskGroup | None" = None) -> BaseOperator:
+def get_airflow_task(task: Task, dag: DAG, task_group: TaskGroup | None = None) -> BaseOperator:
     """
     Get the Airflow Operator class for a Task.
 
diff --git a/cosmos/dataset.py b/cosmos/dataset.py
new file mode 100644
index 000000000..2a308c54e
--- /dev/null
+++ b/cosmos/dataset.py
@@ -0,0 +1,33 @@
+from __future__ import annotations
+
+from airflow import DAG
+from airflow.utils.task_group import TaskGroup
+
+
+def get_dataset_alias_name(dag: DAG | None, task_group: TaskGroup | None, task_id: str) -> str:
+    """
+    Given the Airflow DAG, Airflow TaskGroup and the Airflow Task ID, return the name of the
+    Airflow DatasetAlias associated with that task.
+    """
+    dag_id = None
+    task_group_id = None
+
+    if task_group:
+        if task_group.dag_id is not None:
+            dag_id = task_group.dag_id
+        if task_group.group_id is not None:
+            task_group_id = task_group.group_id
+            task_group_id = task_group_id.replace(".", "__")
+    elif dag:
+        dag_id = dag.dag_id
+
+    identifiers_list = []
+
+    if dag_id:
+        identifiers_list.append(dag_id)
+    if task_group_id:
+        identifiers_list.append(task_group_id)
+
+    identifiers_list.append(task_id)
+
+    return "__".join(identifiers_list)
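
For quick reference while reviewing, here is a minimal sketch of how the helper added in cosmos/dataset.py composes alias names. The expected strings mirror the parametrized cases in tests/test_dataset.py at the end of this diff; the DAG and TaskGroup ids are only illustrative.

    from datetime import datetime

    from airflow import DAG
    from airflow.utils.task_group import TaskGroup

    from cosmos.dataset import get_dataset_alias_name

    example_dag = DAG("dag", start_date=datetime(2024, 4, 16))
    inner_tg = TaskGroup(dag=example_dag, group_id="inner_tg")

    # DAG only: "<dag_id>__<task_id>"
    assert get_dataset_alias_name(example_dag, None, "task_id") == "dag__task_id"

    # TaskGroup present: the group id (with "." replaced by "__") is included
    assert get_dataset_alias_name(None, inner_tg, "task_id") == "dag__inner_tg__task_id"
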
diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py
index 701552f56..557bfe500 100644
--- a/cosmos/operators/local.py
+++ b/cosmos/operators/local.py
@@ -10,6 +10,7 @@
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence
 
+import airflow
 import jinja2
 from airflow import DAG
 from airflow.exceptions import AirflowException, AirflowSkipException
@@ -17,17 +18,18 @@
 from airflow.utils.context import Context
 from airflow.utils.session import NEW_SESSION, create_session, provide_session
 from attr import define
+from packaging.version import Version
 
-from cosmos import cache
+from cosmos import cache, settings
 from cosmos.cache import (
     _copy_cached_package_lockfile_to_project,
     _get_latest_cached_package_lockfile,
     is_cache_package_lockfile_enabled,
 )
 from cosmos.constants import InvocationMode
+from cosmos.dataset import get_dataset_alias_name
 from cosmos.dbt.project import get_partial_parse_path, has_non_empty_dependencies_file
 from cosmos.exceptions import AirflowCompatibilityError
-from cosmos.settings import LINEAGE_NAMESPACE
 
 try:
     from airflow.datasets import Dataset
@@ -43,6 +45,7 @@
     from dbt.cli.main import dbtRunner, dbtRunnerResult
     from openlineage.client.run import RunEvent
+    from sqlalchemy.orm import Session
 
     from cosmos.config import ProfileConfig
 
@@ -73,6 +76,8 @@
     DbtTestMixin,
 )
 
+AIRFLOW_VERSION = Version(airflow.__version__)
+
 logger = get_logger(__name__)
 
 try:
@@ -126,6 +131,7 @@ class DbtLocalBaseOperator(AbstractDbtBaseOperator):
 
     def __init__(
         self,
+        task_id: str,
         profile_config: ProfileConfig,
         invocation_mode: InvocationMode | None = None,
         install_deps: bool = False,
@@ -134,6 +140,7 @@ def __init__(
         append_env: bool = True,
         **kwargs: Any,
     ) -> None:
+        self.task_id = task_id
         self.profile_config = profile_config
         self.callback = callback
         self.compiled_sql = ""
@@ -146,7 +153,19 @@ def __init__(
         self._dbt_runner: dbtRunner | None = None
         if self.invocation_mode:
             self._set_invocation_methods()
-        super().__init__(**kwargs)
+
+        if kwargs.get("emit_datasets", True) and settings.enable_dataset_alias and AIRFLOW_VERSION >= Version("2.10"):
+            from airflow.datasets import DatasetAlias
+
+            # ignoring the type because older versions of Airflow raise the following error in mypy
+            # error: Incompatible types in assignment (expression has type "list[DatasetAlias]", target has type "str")
+            dag_id = kwargs.get("dag")
+            task_group_id = kwargs.get("task_group")
+            kwargs["outlets"] = [
+                DatasetAlias(name=get_dataset_alias_name(dag_id, task_group_id, task_id))
+            ]  # type: ignore
+
+        super().__init__(task_id=task_id, **kwargs)
 
         # For local execution mode, we're consistent with the LoadMode.DBT_LS command in forwarding the environment
         # variables to the subprocess by default. Although this behavior is designed for ExecuteMode.LOCAL and
@@ -388,7 +407,7 @@ def run_command(
                 outlets = self.get_datasets("outputs")
                 self.log.info("Inlets: %s", inlets)
                 self.log.info("Outlets: %s", outlets)
-                self.register_dataset(inlets, outlets)
+                self.register_dataset(inlets, outlets, context)
 
             if self.partial_parse and self.cache_dir:
                 partial_parse_file = get_partial_parse_path(tmp_dir_path)
@@ -423,7 +442,7 @@ def calculate_openlineage_events_completes(
 
         openlineage_processor = DbtLocalArtifactProcessor(
             producer=OPENLINEAGE_PRODUCER,
-            job_namespace=LINEAGE_NAMESPACE,
+            job_namespace=settings.LINEAGE_NAMESPACE,
             project_dir=project_dir,
             profile_name=self.profile_config.profile_name,
             target=self.profile_config.target_name,
@@ -469,20 +488,37 @@ def get_datasets(self, source: Literal["inputs", "outputs"]) -> list[Dataset]:
             )
         return datasets
 
-    def register_dataset(self, new_inlets: list[Dataset], new_outlets: list[Dataset]) -> None:
+    def register_dataset(self, new_inlets: list[Dataset], new_outlets: list[Dataset], context: Context) -> None:
         """
-        Register a list of datasets as outlets of the current task.
+        Register a list of datasets as outlets of the current task, when possible.
+        Until Airflow 2.7, there was no better interface to associate outlets with a task during execution.
+        This works in Cosmos with versions before Airflow 2.10, with a few limitations, as described in the ticket:
+        https://github.com/astronomer/astronomer-cosmos/issues/522
+
+        Since Airflow 2.10, Cosmos uses DatasetAlias by default to generate datasets. This resolves the limitations
+        described above.
+
+        The only limitation is that with Airflow 2.10.0 and 2.10.1, the `airflow dags test` command will not work
+        with DatasetAlias:
+        https://github.com/apache/airflow/issues/42495
         """
-        with create_session() as session:
-            self.outlets.extend(new_outlets)
-            self.inlets.extend(new_inlets)
-            for task in self.dag.tasks:
-                if task.task_id == self.task_id:
-                    task.outlets.extend(new_outlets)
-                    task.inlets.extend(new_inlets)
-            DAG.bulk_write_to_db([self.dag], session=session)
-            session.commit()
+        if AIRFLOW_VERSION < Version("2.10") or not settings.enable_dataset_alias:
+            logger.info("Assigning inlets/outlets without DatasetAlias")
+            with create_session() as session:
+                self.outlets.extend(new_outlets)
+                self.inlets.extend(new_inlets)
+                for task in self.dag.tasks:
+                    if task.task_id == self.task_id:
+                        task.outlets.extend(new_outlets)
+                        task.inlets.extend(new_inlets)
+                DAG.bulk_write_to_db([self.dag], session=session)
+                session.commit()
+        else:
+            logger.info("Assigning inlets/outlets with DatasetAlias")
+            dataset_alias_name = get_dataset_alias_name(self.dag, self.task_group, self.task_id)
+            for outlet in new_outlets:
+                context["outlet_events"][dataset_alias_name].add(outlet)
 
     def get_openlineage_facets_on_complete(self, task_instance: TaskInstance) -> OperatorLineage:
         """
diff --git a/cosmos/settings.py b/cosmos/settings.py
index 43abc8897..6449630ae 100644
--- a/cosmos/settings.py
+++ b/cosmos/settings.py
@@ -18,6 +18,7 @@
 DEFAULT_CACHE_DIR = Path(tempfile.gettempdir(), DEFAULT_COSMOS_CACHE_DIR_NAME)
 cache_dir = Path(conf.get("cosmos", "cache_dir", fallback=DEFAULT_CACHE_DIR) or DEFAULT_CACHE_DIR)
 enable_cache = conf.getboolean("cosmos", "enable_cache", fallback=True)
+enable_dataset_alias = conf.getboolean("cosmos", "enable_dataset_alias", fallback=True)
 enable_cache_partial_parse = conf.getboolean("cosmos", "enable_cache_partial_parse", fallback=True)
 enable_cache_package_lockfile = conf.getboolean("cosmos", "enable_cache_package_lockfile", fallback=True)
 enable_cache_dbt_ls = conf.getboolean("cosmos", "enable_cache_dbt_ls", fallback=True)
diff --git a/docs/configuration/scheduling.rst b/docs/configuration/scheduling.rst
index 16fccca14..60e466d34 100644
--- a/docs/configuration/scheduling.rst
+++ b/docs/configuration/scheduling.rst
@@ -26,7 +26,7 @@ Data-Aware Scheduling
 
 `Apache Airflow® `_ 2.4 introduced the concept of `scheduling based on Datasets `_.
 
-By default, if Airflow 2.4 or higher is used, Cosmos emits `Airflow Datasets `_ when running dbt projects. This allows you to use Airflow's data-aware scheduling capabilities to schedule your dbt projects. Cosmos emits datasets using the OpenLineage URI format, as detailed in the `OpenLineage Naming Convention `_.
+By default, if using a version of Airflow 2.4 or higher, Cosmos emits `Airflow Datasets `_ when running dbt projects. This allows you to use Airflow's data-aware scheduling capabilities to schedule your dbt projects. Cosmos emits datasets using the OpenLineage URI format, as detailed in the `OpenLineage Naming Convention `_.
 
 Cosmos calculates these URIs during the task execution, by using the library `OpenLineage Integration Common `_.
 
@@ -62,3 +62,48 @@ Then, you can use Airflow's data-aware scheduling capabilities to schedule ``my_
     )
 
 In this scenario, ``project_one`` runs once a day and ``project_two`` runs immediately after ``project_one``. You can view these dependencies in Airflow's UI.
+
+Known Limitations
+.................
+
+Airflow 2.9 and below
+_____________________
+
+If using Cosmos with Airflow 2.9 or below, users will experience the following issues:
+
+- The task inlets and outlets generated by Cosmos will not be seen in the Airflow UI
+- The scheduler logs will contain many messages saying "Orphaning unreferenced dataset"
+
+Example of scheduler logs:
+
+.. code-block::
+
+    scheduler | [2023-09-08T10:18:34.252+0100] {scheduler_job_runner.py:1742} INFO - Orphaning unreferenced dataset 'postgres://0.0.0.0:5432/postgres.public.stg_customers'
+    scheduler | [2023-09-08T10:18:34.252+0100] {scheduler_job_runner.py:1742} INFO - Orphaning unreferenced dataset 'postgres://0.0.0.0:5432/postgres.public.stg_payments'
+    scheduler | [2023-09-08T10:18:34.252+0100] {scheduler_job_runner.py:1742} INFO - Orphaning unreferenced dataset 'postgres://0.0.0.0:5432/postgres.public.stg_orders'
+    scheduler | [2023-09-08T10:18:34.252+0100] {scheduler_job_runner.py:1742} INFO - Orphaning unreferenced dataset 'postgres://0.0.0.0:5432/postgres.public.customers'
+
+
+References about the root cause of these issues:
+
+- https://github.com/astronomer/astronomer-cosmos/issues/522
+- https://github.com/apache/airflow/issues/34206
+
+
+Airflow 2.10.0 and 2.10.1
+_________________________
+
+If using Cosmos with Airflow 2.10.0 or 2.10.1, the two issues previously described are resolved, since Cosmos uses ``DatasetAlias``
+to support the dynamic creation of datasets during task execution. However, users may face ``sqlalchemy.orm.exc.FlushError``
+errors if they attempt to run Cosmos-powered DAGs using ``airflow dags test`` with these versions.
+
+We've reported this issue and it will be resolved in future versions of Airflow:
+
+- https://github.com/apache/airflow/issues/42495
+
+To allow users to overcome this limitation in local tests until the Airflow community solves this issue, we introduced the configuration
+``AIRFLOW__COSMOS__ENABLE_DATASET_ALIAS``, which is ``True`` by default. If users want to run ``dags test`` and not see ``sqlalchemy.orm.exc.FlushError``,
+they can set this configuration to ``False``. It can also be set in the ``airflow.cfg`` file:
+
+.. code-block::
+
+    [cosmos]
+    enable_dataset_alias = False
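
To make the scheduling documentation above concrete: on Airflow 2.10+ a downstream DAG can subscribe directly to the alias a Cosmos task emits. This is a hedged sketch only; the DAG id ``my_cosmos_dag`` and task id ``orders_run`` are hypothetical, and the alias name simply follows the ``<dag_id>__<task_group_id>__<task_id>`` convention implemented by ``get_dataset_alias_name`` in this PR.

    from datetime import datetime

    from airflow import DAG
    from airflow.datasets import DatasetAlias

    # Runs whenever the (hypothetical) Cosmos task "orders_run" in DAG "my_cosmos_dag"
    # attaches a dataset to its alias at runtime.
    downstream_dag = DAG(
        dag_id="runs_after_cosmos_orders",
        start_date=datetime(2024, 1, 1),
        schedule=[DatasetAlias(name="my_cosmos_dag__orders_run")],
    )
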
\n" + "From Airflow 2.10 onwards, we started using DatasetAlias, which changed this behaviour.", ) @pytest.mark.integration def test_run_operator_dataset_inlets_and_outlets(caplog): @@ -453,6 +455,82 @@ def test_run_operator_dataset_inlets_and_outlets(caplog): assert test_operator.outlets == [] +@pytest.mark.skipif( + version.parse(airflow_version) < version.parse("2.10"), + reason="From Airflow 2.10 onwards, we started using DatasetAlias, which changed this behaviour.", +) +@pytest.mark.integration +def test_run_operator_dataset_inlets_and_outlets_airflow_210_onwards(caplog): + from airflow.models.dataset import DatasetAliasModel + from sqlalchemy.orm.exc import FlushError + + with DAG("test_id_1", start_date=datetime(2022, 1, 1)) as dag: + seed_operator = DbtSeedLocalOperator( + profile_config=real_profile_config, + project_dir=DBT_PROJ_DIR, + task_id="seed", + dag=dag, + emit_datasets=False, + dbt_cmd_flags=["--select", "raw_customers"], + install_deps=True, + append_env=True, + ) + run_operator = DbtRunLocalOperator( + profile_config=real_profile_config, + project_dir=DBT_PROJ_DIR, + task_id="run", + dag=dag, + dbt_cmd_flags=["--models", "stg_customers"], + install_deps=True, + append_env=True, + ) + test_operator = DbtTestLocalOperator( + profile_config=real_profile_config, + project_dir=DBT_PROJ_DIR, + task_id="test", + dag=dag, + dbt_cmd_flags=["--models", "stg_customers"], + install_deps=True, + append_env=True, + ) + seed_operator >> run_operator >> test_operator + + assert seed_operator.outlets == [] # because emit_datasets=False, + assert run_operator.outlets == [DatasetAliasModel(name="test_id_1__run")] + assert test_operator.outlets == [DatasetAliasModel(name="test_id_1__test")] + + with pytest.raises(FlushError): + # This is a known limitation of Airflow 2.10.0 and 2.10.1 + # https://github.com/apache/airflow/issues/42495 + dag_run, session = run_test_dag(dag) + + # Once this issue is solved, we should do some type of check on the actual datasets being emitted, + # so we guarantee Cosmos is backwards compatible via tests using something along the lines or an alternative, + # based on the resolution of the issue logged in Airflow: + # dataset_model = session.scalars(select(DatasetModel).where(DatasetModel.uri == "")) + # assert dataset_model == 1 + + +@patch("cosmos.settings.enable_dataset_alias", 0) +@pytest.mark.skipif( + version.parse(airflow_version) < version.parse("2.10"), + reason="From Airflow 2.10 onwards, we started using DatasetAlias, which changed this behaviour.", +) +@pytest.mark.integration +def test_run_operator_dataset_inlets_and_outlets_airflow_210_onwards_disabled_via_envvar(caplog): + with DAG("test_id_2", start_date=datetime(2022, 1, 1)) as dag: + run_operator = DbtRunLocalOperator( + profile_config=real_profile_config, + project_dir=DBT_PROJ_DIR, + task_id="run", + dag=dag, + dbt_cmd_flags=["--models", "stg_customers"], + install_deps=True, + append_env=True, + ) + assert run_operator.outlets == [] + + @pytest.mark.skipif( version.parse(airflow_version) not in PARTIALLY_SUPPORTED_AIRFLOW_VERSIONS, reason="Airflow 2.9.0 and 2.9.1 have a breaking change in Dataset URIs", @@ -495,6 +573,7 @@ def test_run_operator_dataset_emission_is_skipped(caplog): reason="Airflow DAG did not have datasets until the 2.4 release, inlets and outlets do not work by default in Airflow 2.9.0 and 2.9.1", ) @pytest.mark.integration +@patch("cosmos.settings.enable_dataset_alias", 0) def test_run_operator_dataset_url_encoded_names(caplog): from airflow.datasets import Dataset diff 
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
new file mode 100644
index 000000000..12e423323
--- /dev/null
+++ b/tests/test_dataset.py
@@ -0,0 +1,39 @@
+from datetime import datetime
+
+import pytest
+from airflow import DAG
+from airflow.utils.task_group import TaskGroup
+
+from cosmos.dataset import get_dataset_alias_name
+
+START_DATE = datetime(2024, 4, 16)
+example_dag = DAG("dag", start_date=START_DATE)
+
+
+@pytest.mark.parametrize(
+    "dag, task_group, result_identifier",
+    [
+        (example_dag, None, "dag__task_id"),
+        (None, TaskGroup(dag=example_dag, group_id="inner_tg"), "dag__inner_tg__task_id"),
+        (
+            None,
+            TaskGroup(
+                dag=example_dag, group_id="child_tg", parent_group=TaskGroup(dag=example_dag, group_id="parent_tg")
+            ),
+            "dag__parent_tg__child_tg__task_id",
+        ),
+        (
+            None,
+            TaskGroup(
+                dag=example_dag,
+                group_id="child_tg",
+                parent_group=TaskGroup(
+                    dag=example_dag, group_id="mum_tg", parent_group=TaskGroup(dag=example_dag, group_id="nana_tg")
+                ),
+            ),
+            "dag__nana_tg__mum_tg__child_tg__task_id",
+        ),
+    ],
+)
+def test_get_dataset_alias_name(dag, task_group, result_identifier):
+    assert get_dataset_alias_name(dag, task_group, "task_id") == result_identifier
diff --git a/tests/utils.py b/tests/utils.py
index 37f7a3223..1f73b693a 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -87,7 +87,7 @@ def test_dag(
 
     print("conn_file_path", conn_file_path)
 
-    return dr
+    return dr, session
 
 
 def add_logger_if_needed(dag: DAG, ti: TaskInstance):
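
Reviewer note on the mechanism behind the new ``register_dataset`` branch: on Airflow 2.10+, a task that declares a ``DatasetAlias`` outlet can attach concrete datasets to that alias while it runs through the ``outlet_events`` context accessor, which is what the operator change above relies on. A minimal standalone sketch of that Airflow mechanism (not Cosmos code; the alias name and dataset URI are illustrative):

    from airflow.datasets import Dataset, DatasetAlias
    from airflow.decorators import task

    @task(outlets=[DatasetAlias(name="example_dag__example_task")])
    def example_task(*, outlet_events):
        # Attach a concrete dataset to the alias at runtime; downstream DAGs
        # scheduled on the alias (or on the resolved dataset) get triggered.
        outlet_events["example_dag__example_task"].add(
            Dataset("postgres://0.0.0.0:5432/postgres.public.orders")
        )
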