Refactor, fix tests
tatiana committed Sep 30, 2024
1 parent 8ab7075 commit 3f72cc1
Showing 5 changed files with 27 additions and 27 deletions.
cosmos/core/airflow.py: 5 changes (0 additions & 5 deletions)
@@ -2,19 +2,14 @@

 import importlib

-import airflow
 from airflow.models import BaseOperator
 from airflow.models.dag import DAG
 from airflow.utils.task_group import TaskGroup
-from packaging.version import Version

 from cosmos.core.graph.entities import Task
-
-# from cosmos.dataset import get_dataset_alias_name
 from cosmos.log import get_logger

 logger = get_logger(__name__)
-AIRFLOW_VERSION = Version(airflow.__version__)


 def get_airflow_task(task: Task, dag: DAG, task_group: TaskGroup | None = None) -> BaseOperator:
cosmos/operators/local.py: 18 changes (10 additions & 8 deletions)
@@ -19,7 +19,7 @@
 from attr import define
 from packaging.version import Version

-from cosmos import cache
+from cosmos import cache, settings
 from cosmos.cache import (
     _copy_cached_package_lockfile_to_project,
     _get_latest_cached_package_lockfile,
@@ -29,7 +29,6 @@
 from cosmos.dataset import get_dataset_alias_name
 from cosmos.dbt.project import get_partial_parse_path, has_non_empty_dependencies_file
 from cosmos.exceptions import AirflowCompatibilityError
-from cosmos.settings import LINEAGE_NAMESPACE

 try:
     from airflow.datasets import Dataset
@@ -154,7 +153,7 @@ def __init__(
         if self.invocation_mode:
             self._set_invocation_methods()

-        if AIRFLOW_VERSION >= Version("2.10"):
+        if kwargs.get("emit_datasets", True) and settings.enable_dataset_alias and AIRFLOW_VERSION >= Version("2.10"):
             from airflow.datasets import DatasetAlias

             # ignoring the type because older versions of Airflow raise the following error in mypy
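Note: with the new guard, dataset aliases are created only when all three conditions hold: the operator was not constructed with `emit_datasets=False`, the `enable_dataset_alias` setting is on, and Airflow is 2.10+. A minimal sketch of the per-operator opt-out, which appears to be honored through `kwargs.get("emit_datasets", True)` above (the `project_dir` and `profile_config` values are placeholders, not part of this commit):

```python
from cosmos.operators.local import DbtRunLocalOperator

# Hypothetical usage inside a DAG file; paths and profile_config are assumptions.
run_task = DbtRunLocalOperator(
    task_id="run",
    project_dir="/path/to/dbt/project",
    profile_config=profile_config,  # assumed defined elsewhere in the DAG file
    emit_datasets=False,  # short-circuits the DatasetAlias guard shown above
)
```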
@@ -442,7 +441,7 @@ def calculate_openlineage_events_completes(

         openlineage_processor = DbtLocalArtifactProcessor(
             producer=OPENLINEAGE_PRODUCER,
-            job_namespace=LINEAGE_NAMESPACE,
+            job_namespace=settings.LINEAGE_NAMESPACE,
             project_dir=project_dir,
             profile_name=self.profile_config.profile_name,
             target=self.profile_config.target_name,
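Note: replacing the from-import of `LINEAGE_NAMESPACE` with attribute access on the `settings` module is presumably what lets tests and config reloads override the value at runtime; a from-import binds the value once at import time. A small illustration of the difference (not part of the commit):

```python
from unittest import mock

from cosmos import settings

# Code that reads settings.LINEAGE_NAMESPACE at call time sees the patch;
# a name bound via `from cosmos.settings import LINEAGE_NAMESPACE` would not.
with mock.patch.object(settings, "LINEAGE_NAMESPACE", "custom-namespace"):
    assert settings.LINEAGE_NAMESPACE == "custom-namespace"
```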
@@ -490,12 +489,16 @@ def get_datasets(self, source: Literal["inputs", "outputs"]) -> list[Dataset]:

     def register_dataset(self, new_inlets: list[Dataset], new_outlets: list[Dataset], context: Context) -> None:
         """
-        Register a list of datasets as outlets of the current task.
+        Register a list of datasets as outlets of the current task, when possible.

         Until Airflow 2.7, there was not a better interface to associate outlets to a task during execution.
         This works before Airflow 2.10 with a few limitations, as described in the ticket:
-        TODO: add the link to the GH issue related to orphaned nodes
+        https://github.com/astronomer/astronomer-cosmos/issues/522
+
+        In Airflow 2.10.0 and 2.10.1, we are not able to test Airflow DAGs powered with DatasetAlias.
+        https://github.com/apache/airflow/issues/42495
         """
-        if AIRFLOW_VERSION < Version("2.10"):
+        if AIRFLOW_VERSION < Version("2.10") or not settings.enable_dataset_alias:
             logger.info("Assigning inlets/outlets without DatasetAlias")
             with create_session() as session:
                 self.outlets.extend(new_outlets)
@@ -511,7 +514,6 @@
             dataset_alias_name = get_dataset_alias_name(self.dag, self.task_group, self.task_id)
             for outlet in new_outlets:
                 context["outlet_events"][dataset_alias_name].add(outlet)
-            # TODO: check equivalent to inlets

     def get_openlineage_facets_on_complete(self, task_instance: TaskInstance) -> OperatorLineage:
         """
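Note: on Airflow 2.10+, the alias that Cosmos populates via `context["outlet_events"]` can be consumed by downstream DAGs. A sketch of what that might look like, assuming the `<dag_id>__<task_id>` alias naming that the tests below rely on:

```python
from datetime import datetime

from airflow import DAG
from airflow.datasets import DatasetAlias  # requires Airflow >= 2.10

# Hypothetical downstream DAG: runs whenever the aliased Cosmos task
# attaches dataset events to the "test_id_1__run" alias.
with DAG(
    dag_id="downstream_of_dbt_run",
    start_date=datetime(2024, 1, 1),
    schedule=[DatasetAlias(name="test_id_1__run")],
):
    ...
```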
cosmos/settings.py: 1 change (1 addition & 0 deletions)
@@ -18,6 +18,7 @@
 DEFAULT_CACHE_DIR = Path(tempfile.gettempdir(), DEFAULT_COSMOS_CACHE_DIR_NAME)
 cache_dir = Path(conf.get("cosmos", "cache_dir", fallback=DEFAULT_CACHE_DIR) or DEFAULT_CACHE_DIR)
 enable_cache = conf.getboolean("cosmos", "enable_cache", fallback=True)
+enable_dataset_alias = conf.getboolean("cosmos", "enable_dataset_alias", fallback=True)
 enable_cache_partial_parse = conf.getboolean("cosmos", "enable_cache_partial_parse", fallback=True)
 enable_cache_package_lockfile = conf.getboolean("cosmos", "enable_cache_package_lockfile", fallback=True)
 enable_cache_dbt_ls = conf.getboolean("cosmos", "enable_cache_dbt_ls", fallback=True)
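Note: like the neighboring flags, the new `enable_dataset_alias` toggle is read through Airflow's configuration, so it can be set under `[cosmos]` in `airflow.cfg` or via Airflow's standard `AIRFLOW__SECTION__KEY` environment-variable mapping. A sketch of disabling it; since the flag is evaluated at import time, the variable must be set before `cosmos.settings` is imported:

```python
import os

# Equivalent to `enable_dataset_alias = False` under [cosmos] in airflow.cfg.
os.environ["AIRFLOW__COSMOS__ENABLE_DATASET_ALIAS"] = "False"

from cosmos import settings  # noqa: E402  (imported after the env var on purpose)

assert settings.enable_dataset_alias is False
```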
tests/operators/test_local.py: 28 changes (15 additions & 13 deletions)
@@ -461,10 +461,10 @@ def test_run_operator_dataset_inlets_and_outlets(caplog):
 )
 @pytest.mark.integration
 def test_run_operator_dataset_inlets_and_outlets_airflow_210_onwards(caplog):
-    from airflow.models.dataset import DatasetAliasModel, DatasetModel
-    from sqlalchemy import select
+    from airflow.models.dataset import DatasetAliasModel
+    from sqlalchemy.orm.exc import FlushError

-    with DAG("test-id-1", start_date=datetime(2022, 1, 1)) as dag:
+    with DAG("test_id_1", start_date=datetime(2022, 1, 1)) as dag:
         seed_operator = DbtSeedLocalOperator(
             profile_config=real_profile_config,
             project_dir=DBT_PROJ_DIR,
@@ -494,18 +494,20 @@ def test_run_operator_dataset_inlets_and_outlets_airflow_210_onwards(caplog):
     )
     seed_operator >> run_operator >> test_operator

-    dag_run, session = run_test_dag(dag)
-
-    assert session.scalars(select(DatasetModel)).all()
-    assert session.scalars(select(DatasetAliasModel)).all()
-    assert False
-    # assert session == session
-    # dataset_model = session.scalars(select(DatasetModel).where(DatasetModel.uri == "<something>"))
-    # assert dataset_model == 1
-    # dataset_alias_models = dataset_model.aliases  # Aliases associated to the URI.
-    # session.query(Dataset).filter_by
+    assert seed_operator.outlets == [DatasetAliasModel(name="test_id_1__seed")]
+    assert run_operator.outlets == [DatasetAliasModel(name="test_id_1__run")]
+    assert test_operator.outlets == [DatasetAliasModel(name="test_id_1__test")]
+
+    with pytest.raises(FlushError):
+        # This is a known limitation of Airflow 2.10.0 and 2.10.1
+        # https://github.com/apache/airflow/issues/42495
+        dag_run, session = run_test_dag(dag)
+
+    # Once this issue is solved, we should check the actual datasets being emitted, so we
+    # guarantee Cosmos is backwards compatible, via tests along these lines (or an
+    # alternative, based on the resolution of the issue logged in Airflow):
+    # dataset_model = session.scalars(select(DatasetModel).where(DatasetModel.uri == "<something>"))
+    # assert dataset_model == 1


 @pytest.mark.skipif(
tests/utils.py: 2 changes (1 addition & 1 deletion)
@@ -87,7 +87,7 @@ def test_dag(

     print("conn_file_path", conn_file_path)

-    return dr
+    return dr, session


 def add_logger_if_needed(dag: DAG, ti: TaskInstance):
