Introduce ground work for ExecutionMode.AIRFLOW_ASYNC #1224

Merged: 11 commits, Sep 30, 2024
12 changes: 11 additions & 1 deletion .github/workflows/test.yml
@@ -2,7 +2,7 @@ name: test

on:
push: # Run on pushes to the default branch
branches: [main]
branches: [main,poc-dbt-compile-task]
pull_request_target: # Also run on pull requests originated from forks
branches: [main]

@@ -176,6 +176,8 @@ jobs:
POSTGRES_DB: postgres
POSTGRES_SCHEMA: public
POSTGRES_PORT: 5432
AIRFLOW__COSMOS__REMOTE_TARGET_PATH: "s3://cosmos-remote-cache/target_compiled/"
AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID: aws_s3_conn

- name: Upload coverage to Github
uses: actions/upload-artifact@v4
@@ -248,6 +250,8 @@ jobs:
POSTGRES_DB: postgres
POSTGRES_SCHEMA: public
POSTGRES_PORT: 5432
AIRFLOW__COSMOS__REMOTE_TARGET_PATH: "s3://cosmos-remote-cache/target_compiled/"
AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID: aws_s3_conn

- name: Upload coverage to Github
uses: actions/upload-artifact@v4
@@ -316,6 +320,8 @@ jobs:
POSTGRES_DB: postgres
POSTGRES_SCHEMA: public
POSTGRES_PORT: 5432
AIRFLOW__COSMOS__REMOTE_TARGET_PATH: "s3://cosmos-remote-cache/target_compiled/"
AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID: aws_s3_conn

- name: Upload coverage to Github
uses: actions/upload-artifact@v4
@@ -393,6 +399,8 @@ jobs:
POSTGRES_DB: postgres
POSTGRES_SCHEMA: public
POSTGRES_PORT: 5432
AIRFLOW__COSMOS__REMOTE_TARGET_PATH: "s3://cosmos-remote-cache/target_compiled/"
AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID: aws_s3_conn

- name: Upload coverage to Github
uses: actions/upload-artifact@v4
@@ -537,6 +545,8 @@ jobs:
POSTGRES_DB: postgres
POSTGRES_SCHEMA: public
POSTGRES_PORT: 5432
AIRFLOW__COSMOS__REMOTE_TARGET_PATH: "s3://cosmos-remote-cache/target_compiled/"
AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID: aws_s3_conn

- name: Upload coverage to Github
uses: actions/upload-artifact@v4
31 changes: 30 additions & 1 deletion cosmos/airflow/graph.py
@@ -8,6 +8,7 @@

from cosmos.config import RenderConfig
from cosmos.constants import (
DBT_COMPILE_TASK_ID,
DEFAULT_DBT_RESOURCES,
TESTABLE_DBT_RESOURCES,
DbtResourceType,
@@ -252,6 +253,31 @@ def generate_task_or_group(
return task_or_group


def _add_dbt_compile_task(
nodes: dict[str, DbtNode],
dag: DAG,
execution_mode: ExecutionMode,
task_args: dict[str, Any],
tasks_map: dict[str, Any],
task_group: TaskGroup | None,
) -> None:
if execution_mode != ExecutionMode.AIRFLOW_ASYNC:
return

compile_task_metadata = TaskMetadata(
id=DBT_COMPILE_TASK_ID,
operator_class="cosmos.operators.airflow_async.DbtCompileAirflowAsyncOperator",
arguments=task_args,
extra_context={},
)
compile_airflow_task = create_airflow_task(compile_task_metadata, dag, task_group=task_group)
tasks_map[DBT_COMPILE_TASK_ID] = compile_airflow_task

for node_id, node in nodes.items():
if not node.depends_on and node_id in tasks_map:
tasks_map[DBT_COMPILE_TASK_ID] >> tasks_map[node_id]


def build_airflow_graph(
nodes: dict[str, DbtNode],
dag: DAG, # Airflow-specific - parent DAG where to associate tasks and (optional) task groups
@@ -332,11 +358,14 @@ def build_airflow_graph(
for leaf_node_id in leaves_ids:
tasks_map[leaf_node_id] >> test_task

_add_dbt_compile_task(nodes, dag, execution_mode, task_args, tasks_map, task_group)

create_airflow_task_dependencies(nodes, tasks_map)


def create_airflow_task_dependencies(
nodes: dict[str, DbtNode], tasks_map: dict[str, Union[TaskGroup, BaseOperator]]
nodes: dict[str, DbtNode],
tasks_map: dict[str, Union[TaskGroup, BaseOperator]],
) -> None:
"""
Create the Airflow task dependencies between non-test nodes.
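For orientation (not part of the diff): a minimal sketch of the dependency shape `_add_dbt_compile_task` produces when the execution mode is `ExecutionMode.AIRFLOW_ASYNC`. Task and node names below are invented, and `EmptyOperator` stands in for the operators Cosmos creates; every dbt node with an empty `depends_on` is wired downstream of the `dbt_compile` task.

```python
from datetime import datetime

from airflow import DAG
from airflow.operators.empty import EmptyOperator

with DAG("sketch_dag", start_date=datetime(2024, 1, 1), schedule=None) as dag:
    # Stand-ins for the tasks Cosmos would create from the dbt graph.
    dbt_compile = EmptyOperator(task_id="dbt_compile")
    stg_orders = EmptyOperator(task_id="stg_orders")  # root node: empty depends_on
    orders = EmptyOperator(task_id="orders")          # depends on stg_orders

    # What _add_dbt_compile_task effectively does for every root node:
    dbt_compile >> stg_orders
    # Dependencies between dbt nodes are wired as before:
    stg_orders >> orders
```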
3 changes: 3 additions & 0 deletions cosmos/constants.py
@@ -86,6 +86,7 @@ class ExecutionMode(Enum):
"""

LOCAL = "local"
AIRFLOW_ASYNC = "airflow_async"
DOCKER = "docker"
KUBERNETES = "kubernetes"
AWS_EKS = "aws_eks"
@@ -147,3 +148,5 @@ def _missing_value_(cls, value):  # type: ignore
# It expects that you have already created those resources through the appropriate commands.
# https://docs.getdbt.com/reference/commands/test
TESTABLE_DBT_RESOURCES = {DbtResourceType.MODEL, DbtResourceType.SOURCE, DbtResourceType.SNAPSHOT, DbtResourceType.SEED}

DBT_COMPILE_TASK_ID = "dbt_compile"
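For context, a hedged sketch of how a user could opt into the new mode once released; the dbt project path, profile details, and DAG id are placeholders, and at this stage the mode behaves like `ExecutionMode.LOCAL` plus the extra `dbt_compile` task.

```python
from datetime import datetime

from cosmos import DbtDag, ExecutionConfig, ExecutionMode, ProfileConfig, ProjectConfig

jaffle_shop_async = DbtDag(
    dag_id="jaffle_shop_async",
    start_date=datetime(2024, 1, 1),
    schedule=None,
    project_config=ProjectConfig("/usr/local/airflow/dags/dbt/jaffle_shop"),  # placeholder path
    profile_config=ProfileConfig(
        profile_name="default",
        target_name="dev",
        profiles_yml_filepath="/usr/local/airflow/dags/dbt/jaffle_shop/profiles.yml",  # placeholder
    ),
    # Selects the new execution mode introduced by this PR.
    execution_config=ExecutionConfig(execution_mode=ExecutionMode.AIRFLOW_ASYNC),
)
```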
67 changes: 67 additions & 0 deletions cosmos/operators/airflow_async.py
@@ -0,0 +1,67 @@
from cosmos.operators.local import (
DbtBuildLocalOperator,
DbtCompileLocalOperator,
DbtDocsAzureStorageLocalOperator,
DbtDocsGCSLocalOperator,
DbtDocsLocalOperator,
DbtDocsS3LocalOperator,
DbtLSLocalOperator,
DbtRunLocalOperator,
DbtRunOperationLocalOperator,
DbtSeedLocalOperator,
DbtSnapshotLocalOperator,
DbtSourceLocalOperator,
DbtTestLocalOperator,
)


class DbtBuildAirflowAsyncOperator(DbtBuildLocalOperator):
pass


class DbtLSAirflowAsyncOperator(DbtLSLocalOperator):
pass


class DbtSeedAirflowAsyncOperator(DbtSeedLocalOperator):
pass


class DbtSnapshotAirflowAsyncOperator(DbtSnapshotLocalOperator):
pass


class DbtSourceAirflowAsyncOperator(DbtSourceLocalOperator):
pass


class DbtRunAirflowAsyncOperator(DbtRunLocalOperator):
pass


class DbtTestAirflowAsyncOperator(DbtTestLocalOperator):
pass


class DbtRunOperationAirflowAsyncOperator(DbtRunOperationLocalOperator):
pass


class DbtDocsAirflowAsyncOperator(DbtDocsLocalOperator):
pass


class DbtDocsS3AirflowAsyncOperator(DbtDocsS3LocalOperator):
pass


class DbtDocsAzureStorageAirflowAsyncOperator(DbtDocsAzureStorageLocalOperator):
pass


class DbtDocsGCSAirflowAsyncOperator(DbtDocsGCSLocalOperator):
pass


class DbtCompileAirflowAsyncOperator(DbtCompileLocalOperator):
pass
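Since every class above is a plain subclass, the illustrative check below holds: the async operators currently inherit all behaviour from their local counterparts, and deferrable execution is left for follow-up work on top of this ground work.

```python
from cosmos.operators.airflow_async import DbtRunAirflowAsyncOperator
from cosmos.operators.local import DbtRunLocalOperator

# At this stage the "async" variants are thin aliases over the local operators.
assert issubclass(DbtRunAirflowAsyncOperator, DbtRunLocalOperator)
```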
9 changes: 9 additions & 0 deletions cosmos/operators/base.py
@@ -429,3 +429,12 @@ def add_cmd_flags(self) -> list[str]:
flags.append("--args")
flags.append(yaml.dump(self.args))
return flags


class DbtCompileMixin:
"""
Mixin for dbt compile command.
"""

base_cmd = ["compile"]
ui_color = "#877c7c"
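To illustrate the pattern (toy classes, not the Cosmos API): a mixin like `DbtCompileMixin` only contributes the dbt sub-command and UI colour, while the operator it is combined with supplies command construction and execution.

```python
# Toy re-creation of the mixin pattern; class names are hypothetical.
class _ToyDbtBaseOperator:
    base_cmd = []

    def dbt_command(self):
        # The real base operator builds the full dbt command from base_cmd plus flags.
        return ["dbt", *self.base_cmd]


class _ToyCompileMixin:
    base_cmd = ["compile"]
    ui_color = "#877c7c"


class _ToyDbtCompileOperator(_ToyCompileMixin, _ToyDbtBaseOperator):
    pass


assert _ToyDbtCompileOperator().dbt_command() == ["dbt", "compile"]
```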
95 changes: 93 additions & 2 deletions cosmos/operators/local.py
@@ -9,6 +9,7 @@
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence
from urllib.parse import urlparse

import airflow
import jinja2
@@ -17,6 +18,7 @@
from airflow.models.taskinstance import TaskInstance
from airflow.utils.context import Context
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.version import version as airflow_version
from attr import define
from packaging.version import Version

@@ -26,10 +28,11 @@
_get_latest_cached_package_lockfile,
is_cache_package_lockfile_enabled,
)
from cosmos.constants import InvocationMode
from cosmos.constants import FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP, InvocationMode
from cosmos.dataset import get_dataset_alias_name
from cosmos.dbt.project import get_partial_parse_path, has_non_empty_dependencies_file
from cosmos.exceptions import AirflowCompatibilityError
from cosmos.exceptions import AirflowCompatibilityError, CosmosValueError
from cosmos.settings import AIRFLOW_IO_AVAILABLE, remote_target_path, remote_target_path_conn_id

try:
from airflow.datasets import Dataset
@@ -67,6 +70,7 @@
from cosmos.operators.base import (
AbstractDbtBaseOperator,
DbtBuildMixin,
DbtCompileMixin,
DbtLSMixin,
DbtRunMixin,
DbtRunOperationMixin,
@@ -137,6 +141,7 @@ def __init__(
install_deps: bool = False,
callback: Callable[[str], None] | None = None,
should_store_compiled_sql: bool = True,
should_upload_compiled_sql: bool = False,
append_env: bool = True,
**kwargs: Any,
) -> None:
@@ -146,6 +151,7 @@ def __init__(
self.compiled_sql = ""
self.freshness = ""
self.should_store_compiled_sql = should_store_compiled_sql
self.should_upload_compiled_sql = should_upload_compiled_sql
self.openlineage_events_completes: list[RunEvent] = []
self.invocation_mode = invocation_mode
self.invoke_dbt: Callable[..., FullOutputSubprocessResult | dbtRunnerResult]
@@ -271,6 +277,84 @@ def store_compiled_sql(self, tmp_project_dir: str, context: Context, session: Se
else:
self.log.info("Warning: ti is of type TaskInstancePydantic. Cannot update template_fields.")

@staticmethod
def _configure_remote_target_path() -> tuple[Path, str] | tuple[None, None]:
"""Configure the remote target path if it is provided."""
if not remote_target_path:
return None, None

_configured_target_path = None

target_path_str = str(remote_target_path)

remote_conn_id = remote_target_path_conn_id
if not remote_conn_id:
target_path_schema = urlparse(target_path_str).scheme
remote_conn_id = FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP.get(target_path_schema, None) # type: ignore[assignment]
if remote_conn_id is None:
return None, None

if not AIRFLOW_IO_AVAILABLE:
raise CosmosValueError(
f"You're trying to specify remote target path {target_path_str}, but the required "
f"Object Storage feature is unavailable in Airflow version {airflow_version}. Please upgrade to "
"Airflow 2.8 or later."
)

from airflow.io.path import ObjectStoragePath

_configured_target_path = ObjectStoragePath(target_path_str, conn_id=remote_conn_id)

if not _configured_target_path.exists(): # type: ignore[no-untyped-call]
_configured_target_path.mkdir(parents=True, exist_ok=True)

return _configured_target_path, remote_conn_id

def _construct_dest_file_path(
self, dest_target_dir: Path, file_path: str, source_compiled_dir: Path, context: Context
) -> str:
"""
Construct the destination path for the compiled SQL files to be uploaded to the remote store.
"""
dest_target_dir_str = str(dest_target_dir).rstrip("/")

task = context["task"]
dag_id = task.dag_id
task_group_id = task.task_group.group_id if task.task_group else None
identifiers_list = []
if dag_id:
identifiers_list.append(dag_id)
if task_group_id:
identifiers_list.append(task_group_id)
dag_task_group_identifier = "__".join(identifiers_list)

rel_path = os.path.relpath(file_path, source_compiled_dir).lstrip("/")

return f"{dest_target_dir_str}/{dag_task_group_identifier}/compiled/{rel_path}"

def upload_compiled_sql(self, tmp_project_dir: str, context: Context) -> None:
"""
Uploads the compiled SQL files from the dbt compile output to the remote store.
"""
if not self.should_upload_compiled_sql:
return

dest_target_dir, dest_conn_id = self._configure_remote_target_path()
if not dest_target_dir:
raise CosmosValueError(
"You're trying to upload compiled SQL files, but the remote target path is not configured. "
)

from airflow.io.path import ObjectStoragePath

source_compiled_dir = Path(tmp_project_dir) / "target" / "compiled"
files = [str(file) for file in source_compiled_dir.rglob("*") if file.is_file()]
for file_path in files:
dest_file_path = self._construct_dest_file_path(dest_target_dir, file_path, source_compiled_dir, context)
dest_object_storage_path = ObjectStoragePath(dest_file_path, conn_id=dest_conn_id)
ObjectStoragePath(file_path).copy(dest_object_storage_path)
self.log.debug("Copied %s to %s", file_path, dest_object_storage_path)

@provide_session
def store_freshness_json(self, tmp_project_dir: str, context: Context, session: Session = NEW_SESSION) -> None:
"""
@@ -416,6 +500,7 @@ def run_command(

self.store_freshness_json(tmp_project_dir, context)
self.store_compiled_sql(tmp_project_dir, context)
self.upload_compiled_sql(tmp_project_dir, context)
self.handle_exception(result)
if self.callback:
self.callback(tmp_project_dir)
@@ -920,3 +1005,9 @@ def __init__(self, **kwargs: str) -> None:
raise DeprecationWarning(
"The DbtDepsOperator has been deprecated. " "Please use the `install_deps` flag in dbt_args instead."
)


class DbtCompileLocalOperator(DbtCompileMixin, DbtLocalBaseOperator):
def __init__(self, *args: Any, **kwargs: Any) -> None:
kwargs["should_upload_compiled_sql"] = True
super().__init__(*args, **kwargs)
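As a worked example of the upload destination produced by `_construct_dest_file_path` (all values below are illustrative placeholders): with `remote_target_path` set to `s3://cosmos-remote-cache/target_compiled/`, a DAG `my_dag` outside any task group, and a compiled file `jaffle_shop/models/staging/stg_orders.sql`, the file is copied to `s3://cosmos-remote-cache/target_compiled/my_dag/compiled/jaffle_shop/models/staging/stg_orders.sql`.

```python
import os
from pathlib import Path

# Re-derivation of the path logic above with placeholder values.
dest_target_dir = "s3://cosmos-remote-cache/target_compiled/"
dag_task_group_identifier = "my_dag"  # becomes "my_dag__my_task_group" inside a TaskGroup
source_compiled_dir = Path("/tmp/project/target/compiled")
file_path = source_compiled_dir / "jaffle_shop" / "models" / "staging" / "stg_orders.sql"

rel_path = os.path.relpath(file_path, source_compiled_dir).lstrip("/")
dest_file_path = f"{dest_target_dir.rstrip('/')}/{dag_task_group_identifier}/compiled/{rel_path}"
print(dest_file_path)
# s3://cosmos-remote-cache/target_compiled/my_dag/compiled/jaffle_shop/models/staging/stg_orders.sql
```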
3 changes: 3 additions & 0 deletions cosmos/settings.py
@@ -35,6 +35,9 @@
remote_cache_dir = conf.get("cosmos", "remote_cache_dir", fallback=None)
remote_cache_dir_conn_id = conf.get("cosmos", "remote_cache_dir_conn_id", fallback=None)

remote_target_path = conf.get("cosmos", "remote_target_path", fallback=None)
remote_target_path_conn_id = conf.get("cosmos", "remote_target_path_conn_id", fallback=None)

try:
LINEAGE_NAMESPACE = conf.get("openlineage", "namespace")
except airflow.exceptions.AirflowConfigException:
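The new options can also be supplied through Airflow's standard `AIRFLOW__<SECTION>__<OPTION>` environment-variable mapping, which is how the CI workflow above configures them; a minimal sketch (the connection id is a placeholder):

```python
import os

# Airflow resolves AIRFLOW__COSMOS__<OPTION> to conf.get("cosmos", "<option>"),
# so these are equivalent to setting remote_target_path / remote_target_path_conn_id.
os.environ["AIRFLOW__COSMOS__REMOTE_TARGET_PATH"] = "s3://cosmos-remote-cache/target_compiled/"
os.environ["AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID"] = "aws_s3_conn"  # placeholder connection id
```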