Commit
Refactor DatasetAlias implementation to be at the task level and reproduce the airflow standalone issue we're experiencing
tatiana committed Sep 26, 2024
1 parent d9bbc02 commit 8ab7075
Showing 4 changed files with 79 additions and 23 deletions.
20 changes: 2 additions & 18 deletions cosmos/core/airflow.py
@@ -9,7 +9,8 @@
from packaging.version import Version

from cosmos.core.graph.entities import Task
from cosmos.dataset import get_dataset_alias_name

# from cosmos.dataset import get_dataset_alias_name
from cosmos.log import get_logger

logger = get_logger(__name__)
@@ -35,23 +36,6 @@ def get_airflow_task(task: Task, dag: DAG, task_group: TaskGroup | None = None)
if task.owner != "":
task_kwargs["owner"] = task.owner

if module_name.split(".")[-1] == "local" and AIRFLOW_VERSION >= Version("2.10"):
from airflow.datasets import DatasetAlias

# ignoring the type because older versions of Airflow raise the follow error in mypy
# error: Incompatible types in assignment (expression has type "list[DatasetAlias]", target has type "str") [assignment] Found 1 error in 1 file (checked 3 source files)
task_kwargs["outlets"] = [DatasetAlias(name=get_dataset_alias_name(dag, task_group, task.id))] # type: ignore

logger.info("HELP ME!!!")
logger.info(module_name)
logger.info(Operator)
logger.info(task.id)
logger.info(dag)
logger.info(task_group)
logger.info(task_kwargs)
logger.info({} if class_name == "EmptyOperator" else {"extra_context": task.extra_context})
logger.info(task.arguments)

airflow_task = Operator(
task_id=task.id,
dag=dag,
7 changes: 4 additions & 3 deletions cosmos/dataset.py
@@ -17,16 +17,17 @@ def get_dataset_alias_name(dag: DAG | None, task_group: TaskGroup | None, task_i
dag_id = task_group.dag_id
if task_group.group_id is not None:
task_group_id = task_group.group_id
task_group_id = task_group_id.replace(".", "__")
elif dag:
dag_id = dag.dag_id

identifiers_list = []
dag_id = dag_id
task_group_id = task_group_id

if dag_id:
identifiers_list.append(dag_id)
if task_group_id:
identifiers_list.append(task_group_id.replace(".", "__"))
identifiers_list.append(task_group_id)

identifiers_list.append(task_id)

return "__".join(identifiers_list)
18 changes: 17 additions & 1 deletion cosmos/operators/local.py
@@ -131,6 +131,7 @@ class DbtLocalBaseOperator(AbstractDbtBaseOperator):

def __init__(
self,
task_id: str,
profile_config: ProfileConfig,
invocation_mode: InvocationMode | None = None,
install_deps: bool = False,
@@ -139,6 +140,7 @@ def __init__(
append_env: bool = True,
**kwargs: Any,
) -> None:
self.task_id = task_id
self.profile_config = profile_config
self.callback = callback
self.compiled_sql = ""
@@ -151,7 +153,19 @@ def __init__(
self._dbt_runner: dbtRunner | None = None
if self.invocation_mode:
self._set_invocation_methods()
super().__init__(**kwargs)

if AIRFLOW_VERSION >= Version("2.10"):
from airflow.datasets import DatasetAlias

# ignoring the type because older versions of Airflow raise the following error in mypy
# error: Incompatible types in assignment (expression has type "list[DatasetAlias]", target has type "str")
dag_id = kwargs.get("dag")
task_group_id = kwargs.get("task_group")
kwargs["outlets"] = [
DatasetAlias(name=get_dataset_alias_name(dag_id, task_group_id, task_id))
] # type: ignore

super().__init__(task_id=task_id, **kwargs)

# For local execution mode, we're consistent with the LoadMode.DBT_LS command in forwarding the environment
# variables to the subprocess by default. Although this behavior is designed for ExecuteMode.LOCAL and
@@ -478,6 +492,8 @@ def register_dataset(self, new_inlets: list[Dataset], new_outlets: list[Dataset]
"""
Register a list of datasets as outlets of the current task.
Until Airflow 2.7, there was not a better interface to associate outlets to a task during execution.
This works before Airflow 2.10 with a few limitations, as described in the ticket:
TODO: add the link to the GH issue related to orphaned nodes
"""
if AIRFLOW_VERSION < Version("2.10"):
logger.info("Assigning inlets/outlets without DatasetAlias")
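
With the outlets attached in `__init__`, a downstream DAG can, in principle, be scheduled on the alias the task emits. A hedged sketch, assuming Airflow >= 2.10; the alias name below is inferred from the naming scheme above (dag_id `test-id-1`, no task group, task_id `run`) and is not confirmed by this commit:

```python
# Hedged sketch (assumes Airflow >= 2.10): scheduling a downstream DAG on the
# DatasetAlias that the dbt task is expected to emit through its outlets.
# The alias name "test-id-1__run" is inferred, not confirmed by this commit.
from datetime import datetime

from airflow import DAG
from airflow.datasets import DatasetAlias

with DAG(
    dag_id="downstream_of_dbt_run",
    start_date=datetime(2024, 1, 1),
    schedule=[DatasetAlias(name="test-id-1__run")],
):
    ...  # downstream tasks would go here
```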
57 changes: 56 additions & 1 deletion tests/operators/test_local.py
@@ -411,8 +411,10 @@ def test_dbt_test_local_operator_invocation_mode_methods(mock_extract_log_issues

@pytest.mark.skipif(
version.parse(airflow_version) < version.parse("2.4")
or version.parse(airflow_version) >= version.parse("2.10")
or version.parse(airflow_version) in PARTIALLY_SUPPORTED_AIRFLOW_VERSIONS,
reason="Airflow DAG did not have datasets until the 2.4 release, inlets and outlets do not work by default in Airflow 2.9.0 and 2.9.1",
reason="Airflow DAG did not have datasets until the 2.4 release, inlets and outlets do not work by default in Airflow 2.9.0 and 2.9.1. \n"
"From Airflow 2.10 onwards, we started using DatasetAlias, which changed this behaviour.",
)
@pytest.mark.integration
def test_run_operator_dataset_inlets_and_outlets(caplog):
@@ -453,6 +455,59 @@ def test_run_operator_dataset_inlets_and_outlets(caplog):
assert test_operator.outlets == []


@pytest.mark.skipif(
version.parse(airflow_version) < version.parse("2.10"),
reason="From Airflow 2.10 onwards, we started using DatasetAlias, which changed this behaviour.",
)
@pytest.mark.integration
def test_run_operator_dataset_inlets_and_outlets_airflow_210_onwards(caplog):
from airflow.models.dataset import DatasetAliasModel, DatasetModel
from sqlalchemy import select

with DAG("test-id-1", start_date=datetime(2022, 1, 1)) as dag:
seed_operator = DbtSeedLocalOperator(
profile_config=real_profile_config,
project_dir=DBT_PROJ_DIR,
task_id="seed",
dag=dag,
dbt_cmd_flags=["--select", "raw_customers"],
install_deps=True,
append_env=True,
)
run_operator = DbtRunLocalOperator(
profile_config=real_profile_config,
project_dir=DBT_PROJ_DIR,
task_id="run",
dag=dag,
dbt_cmd_flags=["--models", "stg_customers"],
install_deps=True,
append_env=True,
)
test_operator = DbtTestLocalOperator(
profile_config=real_profile_config,
project_dir=DBT_PROJ_DIR,
task_id="test",
dag=dag,
dbt_cmd_flags=["--models", "stg_customers"],
install_deps=True,
append_env=True,
)
seed_operator >> run_operator >> test_operator

dag_run, session = run_test_dag(dag)

assert session.scalars(select(DatasetModel)).all()
assert session.scalars(select(DatasetAliasModel)).all()
assert False
# assert session == session
# dataset_model = session.scalars(select(DatasetModel).where(DatasetModel.uri == "<something>"))
# assert dataset_model == 1
# dataset_alias_models = dataset_model.aliases # Aliases associated to the URI.


# session.query(Dataset).filter_by


@pytest.mark.skipif(
version.parse(airflow_version) not in PARTIALLY_SUPPORTED_AIRFLOW_VERSIONS,
reason="Airflow 2.9.0 and 2.9.1 have a breaking change in Dataset URIs",
