Introduce ground work for ExecutionMode.AIRFLOW_ASYNC (#1224)
This PR is the groundwork for the implementation of
`ExecutionMode.AIRFLOW_ASYNC`
(#1120), which, once all other epic tasks are completed, will enable
asynchronous execution of dbt resources using Apache Airflow's
deferrable operators.

As part of this work, this PR introduces a new option to the
`ExecutionMode` enum: `AIRFLOW_ASYNC`. When this execution mode is used,
Cosmos creates a setup task that pre-compiles the dbt project SQL and
makes it available to the remaining dbt tasks. This PR does not yet
leverage Airflow's deferrable operators: with this change alone, users
of `ExecutionMode.AIRFLOW_ASYNC` are effectively still running
`ExecutionMode.LOCAL` operators. PR #1230 contains a first experimental
version that uses deferrable operators for task execution.

## Setup task as the groundwork for the new execution mode `ExecutionMode.AIRFLOW_ASYNC`:
- Adds a new operator, `DbtCompileAirflowAsyncOperator`, as a root
task (analogous to a setup task) in the DAG. It runs the `dbt compile`
command and uploads the compiled SQL files to a remote storage location,
from which subsequent tasks fetch the compiled SQL and, once the
follow-up work lands, run it asynchronously using Airflow's deferrable
operators. The resulting remote layout is sketched below.
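
For context, `_construct_dest_file_path` in `cosmos/operators/local.py`
(part of this diff) derives the upload destination from the configured
remote target path (see the next section) plus a DAG/task-group
identifier. Assuming the `remote_target_path` used in CI and a
hypothetical DAG ID and dbt project, an uploaded file would land at a
path roughly like:

```
<remote_target_path>/<dag_id>__<task_group_id>/compiled/<path relative to target/compiled>
# illustrative example only:
s3://cosmos-remote-cache/target_compiled/simple_dag_async/compiled/my_project/models/staging/stg_orders.sql
```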

## Airflow Configurations:
- `remote_target_path`: Introduces a configurable path to store
dbt-generated files remotely, supporting any storage scheme that works
with Airflow’s Object Store (e.g., S3, GCS, Azure Blob).
- `remote_target_path_conn_id`: Allows specifying a custom connection ID
for the remote target path, defaulting to the scheme's associated
Airflow connection if not set. An example configuration is shown below.
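
A sketch of how these settings might be configured; the bucket path and
connection ID below mirror the values added to the CI workflow in this
PR and are otherwise placeholders:

```cfg
# airflow.cfg
[cosmos]
remote_target_path = s3://cosmos-remote-cache/target_compiled/
remote_target_path_conn_id = aws_s3_conn
```

Equivalently, these can be supplied via Airflow's standard
environment-variable convention, i.e. `AIRFLOW__COSMOS__REMOTE_TARGET_PATH`
and `AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID`, as done in
`.github/workflows/test.yml` below.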

## Example DAG for CI Testing:
Introduces an example DAG (`simple_dag_async.py`) demonstrating how to
use the new execution mode (as mentioned earlier, with this PR alone the
tasks still run as `ExecutionMode.LOCAL` operators).
This DAG is integrated into the CI pipeline to run integration tests,
and it aims to verify the functionality of `ExecutionMode.AIRFLOW_ASYNC`
as the implementation lands, starting with the experimental version in
#1230. A minimal sketch of such a DAG is shown below.
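
A minimal sketch of what such a DAG could look like, assuming a typical
Cosmos setup. The actual `simple_dag_async.py` added in this PR may
differ; the dbt project path, profile/target names, `profiles.yml`
location, and `dag_id` are placeholders:

```python
from datetime import datetime

from cosmos import DbtDag, ExecutionConfig, ProfileConfig, ProjectConfig
from cosmos.constants import ExecutionMode

simple_dag_async = DbtDag(
    # Placeholder paths and profile details; adjust to your project.
    project_config=ProjectConfig("/usr/local/airflow/dbt/my_project"),
    profile_config=ProfileConfig(
        profile_name="default",
        target_name="dev",
        profiles_yml_filepath="/usr/local/airflow/dbt/my_project/profiles.yml",
    ),
    # The new execution mode. With this PR alone, tasks still execute like
    # ExecutionMode.LOCAL operators, preceded by the dbt_compile setup task.
    execution_config=ExecutionConfig(execution_mode=ExecutionMode.AIRFLOW_ASYNC),
    schedule=None,
    start_date=datetime(2024, 1, 1),
    catchup=False,
    dag_id="simple_dag_async",
)
```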

## Unit & Integration Tests:
- Adds comprehensive unit and integration tests to ensure correct
behavior.
- Tests include validation for successful uploads, error handling for
misconfigured remote paths, and scenarios where `remote_target_path` is
not set.

## Documentation:
- Adds detailed documentation explaining how to configure and use
`ExecutionMode.AIRFLOW_ASYNC`.

## Scope & Limitations of the feature being introduced:
1. This feature is meant to be released as experimental and is marked
as such in the documentation.
2. Currently, only dbt models are scoped to run asynchronously (being
worked on in PR #1230); other resource types will continue to run
synchronously.
3. `BigQuery` will be the only supported target database for this
execution mode (also being worked on in PR #1230).

Thus, this PR enhances Cosmos by providing the groundwork for more
efficient execution of long-running dbt resources.

## Additional Notes:
- This feature is planned to be introduced in Cosmos v1.7.0.

related: #1134
pankajkoti authored Sep 30, 2024
1 parent 741c2eb commit 879b1a3
Showing 15 changed files with 575 additions and 6 deletions.
12 changes: 11 additions & 1 deletion .github/workflows/test.yml
@@ -2,7 +2,7 @@ name: test

on:
push: # Run on pushes to the default branch
branches: [main]
branches: [main,poc-dbt-compile-task]
pull_request_target: # Also run on pull requests originated from forks
branches: [main]

@@ -176,6 +176,8 @@ jobs:
POSTGRES_DB: postgres
POSTGRES_SCHEMA: public
POSTGRES_PORT: 5432
AIRFLOW__COSMOS__REMOTE_TARGET_PATH: "s3://cosmos-remote-cache/target_compiled/"
AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID: aws_s3_conn

- name: Upload coverage to Github
uses: actions/upload-artifact@v4
@@ -248,6 +250,8 @@ jobs:
POSTGRES_DB: postgres
POSTGRES_SCHEMA: public
POSTGRES_PORT: 5432
AIRFLOW__COSMOS__REMOTE_TARGET_PATH: "s3://cosmos-remote-cache/target_compiled/"
AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID: aws_s3_conn

- name: Upload coverage to Github
uses: actions/upload-artifact@v4
@@ -316,6 +320,8 @@ jobs:
POSTGRES_DB: postgres
POSTGRES_SCHEMA: public
POSTGRES_PORT: 5432
AIRFLOW__COSMOS__REMOTE_TARGET_PATH: "s3://cosmos-remote-cache/target_compiled/"
AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID: aws_s3_conn

- name: Upload coverage to Github
uses: actions/upload-artifact@v4
@@ -393,6 +399,8 @@ jobs:
POSTGRES_DB: postgres
POSTGRES_SCHEMA: public
POSTGRES_PORT: 5432
AIRFLOW__COSMOS__REMOTE_TARGET_PATH: "s3://cosmos-remote-cache/target_compiled/"
AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID: aws_s3_conn

- name: Upload coverage to Github
uses: actions/upload-artifact@v4
@@ -537,6 +545,8 @@ jobs:
POSTGRES_DB: postgres
POSTGRES_SCHEMA: public
POSTGRES_PORT: 5432
AIRFLOW__COSMOS__REMOTE_TARGET_PATH: "s3://cosmos-remote-cache/target_compiled/"
AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID: aws_s3_conn

- name: Upload coverage to Github
uses: actions/upload-artifact@v4
31 changes: 30 additions & 1 deletion cosmos/airflow/graph.py
@@ -8,6 +8,7 @@

from cosmos.config import RenderConfig
from cosmos.constants import (
DBT_COMPILE_TASK_ID,
DEFAULT_DBT_RESOURCES,
TESTABLE_DBT_RESOURCES,
DbtResourceType,
@@ -252,6 +253,31 @@ def generate_task_or_group(
return task_or_group


def _add_dbt_compile_task(
nodes: dict[str, DbtNode],
dag: DAG,
execution_mode: ExecutionMode,
task_args: dict[str, Any],
tasks_map: dict[str, Any],
task_group: TaskGroup | None,
) -> None:
if execution_mode != ExecutionMode.AIRFLOW_ASYNC:
return

compile_task_metadata = TaskMetadata(
id=DBT_COMPILE_TASK_ID,
operator_class="cosmos.operators.airflow_async.DbtCompileAirflowAsyncOperator",
arguments=task_args,
extra_context={},
)
compile_airflow_task = create_airflow_task(compile_task_metadata, dag, task_group=task_group)
tasks_map[DBT_COMPILE_TASK_ID] = compile_airflow_task

for node_id, node in nodes.items():
if not node.depends_on and node_id in tasks_map:
tasks_map[DBT_COMPILE_TASK_ID] >> tasks_map[node_id]


def build_airflow_graph(
nodes: dict[str, DbtNode],
dag: DAG, # Airflow-specific - parent DAG where to associate tasks and (optional) task groups
@@ -332,11 +358,14 @@ def build_airflow_graph(
for leaf_node_id in leaves_ids:
tasks_map[leaf_node_id] >> test_task

_add_dbt_compile_task(nodes, dag, execution_mode, task_args, tasks_map, task_group)

create_airflow_task_dependencies(nodes, tasks_map)


def create_airflow_task_dependencies(
nodes: dict[str, DbtNode], tasks_map: dict[str, Union[TaskGroup, BaseOperator]]
nodes: dict[str, DbtNode],
tasks_map: dict[str, Union[TaskGroup, BaseOperator]],
) -> None:
"""
Create the Airflow task dependencies between non-test nodes.
3 changes: 3 additions & 0 deletions cosmos/constants.py
@@ -86,6 +86,7 @@ class ExecutionMode(Enum):
"""

LOCAL = "local"
AIRFLOW_ASYNC = "airflow_async"
DOCKER = "docker"
KUBERNETES = "kubernetes"
AWS_EKS = "aws_eks"
@@ -147,3 +148,5 @@ def _missing_value_(cls, value): # type: ignore
# It expects that you have already created those resources through the appropriate commands.
# https://docs.getdbt.com/reference/commands/test
TESTABLE_DBT_RESOURCES = {DbtResourceType.MODEL, DbtResourceType.SOURCE, DbtResourceType.SNAPSHOT, DbtResourceType.SEED}

DBT_COMPILE_TASK_ID = "dbt_compile"
67 changes: 67 additions & 0 deletions cosmos/operators/airflow_async.py
@@ -0,0 +1,67 @@
from cosmos.operators.local import (
DbtBuildLocalOperator,
DbtCompileLocalOperator,
DbtDocsAzureStorageLocalOperator,
DbtDocsGCSLocalOperator,
DbtDocsLocalOperator,
DbtDocsS3LocalOperator,
DbtLSLocalOperator,
DbtRunLocalOperator,
DbtRunOperationLocalOperator,
DbtSeedLocalOperator,
DbtSnapshotLocalOperator,
DbtSourceLocalOperator,
DbtTestLocalOperator,
)


class DbtBuildAirflowAsyncOperator(DbtBuildLocalOperator):
pass


class DbtLSAirflowAsyncOperator(DbtLSLocalOperator):
pass


class DbtSeedAirflowAsyncOperator(DbtSeedLocalOperator):
pass


class DbtSnapshotAirflowAsyncOperator(DbtSnapshotLocalOperator):
pass


class DbtSourceAirflowAsyncOperator(DbtSourceLocalOperator):
pass


class DbtRunAirflowAsyncOperator(DbtRunLocalOperator):
pass


class DbtTestAirflowAsyncOperator(DbtTestLocalOperator):
pass


class DbtRunOperationAirflowAsyncOperator(DbtRunOperationLocalOperator):
pass


class DbtDocsAirflowAsyncOperator(DbtDocsLocalOperator):
pass


class DbtDocsS3AirflowAsyncOperator(DbtDocsS3LocalOperator):
pass


class DbtDocsAzureStorageAirflowAsyncOperator(DbtDocsAzureStorageLocalOperator):
pass


class DbtDocsGCSAirflowAsyncOperator(DbtDocsGCSLocalOperator):
pass


class DbtCompileAirflowAsyncOperator(DbtCompileLocalOperator):
pass
9 changes: 9 additions & 0 deletions cosmos/operators/base.py
@@ -429,3 +429,12 @@ def add_cmd_flags(self) -> list[str]:
flags.append("--args")
flags.append(yaml.dump(self.args))
return flags


class DbtCompileMixin:
"""
Mixin for dbt compile command.
"""

base_cmd = ["compile"]
ui_color = "#877c7c"
95 changes: 93 additions & 2 deletions cosmos/operators/local.py
@@ -9,6 +9,7 @@
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence
from urllib.parse import urlparse

import airflow
import jinja2
@@ -17,6 +18,7 @@
from airflow.models.taskinstance import TaskInstance
from airflow.utils.context import Context
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.version import version as airflow_version
from attr import define
from packaging.version import Version

@@ -26,10 +28,11 @@
_get_latest_cached_package_lockfile,
is_cache_package_lockfile_enabled,
)
from cosmos.constants import InvocationMode
from cosmos.constants import FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP, InvocationMode
from cosmos.dataset import get_dataset_alias_name
from cosmos.dbt.project import get_partial_parse_path, has_non_empty_dependencies_file
from cosmos.exceptions import AirflowCompatibilityError
from cosmos.exceptions import AirflowCompatibilityError, CosmosValueError
from cosmos.settings import AIRFLOW_IO_AVAILABLE, remote_target_path, remote_target_path_conn_id

try:
from airflow.datasets import Dataset
@@ -67,6 +70,7 @@
from cosmos.operators.base import (
AbstractDbtBaseOperator,
DbtBuildMixin,
DbtCompileMixin,
DbtLSMixin,
DbtRunMixin,
DbtRunOperationMixin,
@@ -137,6 +141,7 @@ def __init__(
install_deps: bool = False,
callback: Callable[[str], None] | None = None,
should_store_compiled_sql: bool = True,
should_upload_compiled_sql: bool = False,
append_env: bool = True,
**kwargs: Any,
) -> None:
@@ -146,6 +151,7 @@ def __init__(
self.compiled_sql = ""
self.freshness = ""
self.should_store_compiled_sql = should_store_compiled_sql
self.should_upload_compiled_sql = should_upload_compiled_sql
self.openlineage_events_completes: list[RunEvent] = []
self.invocation_mode = invocation_mode
self.invoke_dbt: Callable[..., FullOutputSubprocessResult | dbtRunnerResult]
@@ -271,6 +277,84 @@ def store_compiled_sql(self, tmp_project_dir: str, context: Context, session: Se
else:
self.log.info("Warning: ti is of type TaskInstancePydantic. Cannot update template_fields.")

@staticmethod
def _configure_remote_target_path() -> tuple[Path, str] | tuple[None, None]:
"""Configure the remote target path if it is provided."""
if not remote_target_path:
return None, None

_configured_target_path = None

target_path_str = str(remote_target_path)

remote_conn_id = remote_target_path_conn_id
if not remote_conn_id:
target_path_schema = urlparse(target_path_str).scheme
remote_conn_id = FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP.get(target_path_schema, None) # type: ignore[assignment]
if remote_conn_id is None:
return None, None

if not AIRFLOW_IO_AVAILABLE:
raise CosmosValueError(
f"You're trying to specify remote target path {target_path_str}, but the required "
f"Object Storage feature is unavailable in Airflow version {airflow_version}. Please upgrade to "
"Airflow 2.8 or later."
)

from airflow.io.path import ObjectStoragePath

_configured_target_path = ObjectStoragePath(target_path_str, conn_id=remote_conn_id)

if not _configured_target_path.exists(): # type: ignore[no-untyped-call]
_configured_target_path.mkdir(parents=True, exist_ok=True)

return _configured_target_path, remote_conn_id

def _construct_dest_file_path(
self, dest_target_dir: Path, file_path: str, source_compiled_dir: Path, context: Context
) -> str:
"""
Construct the destination path for the compiled SQL files to be uploaded to the remote store.
"""
dest_target_dir_str = str(dest_target_dir).rstrip("/")

task = context["task"]
dag_id = task.dag_id
task_group_id = task.task_group.group_id if task.task_group else None
identifiers_list = []
if dag_id:
identifiers_list.append(dag_id)
if task_group_id:
identifiers_list.append(task_group_id)
dag_task_group_identifier = "__".join(identifiers_list)

rel_path = os.path.relpath(file_path, source_compiled_dir).lstrip("/")

return f"{dest_target_dir_str}/{dag_task_group_identifier}/compiled/{rel_path}"

def upload_compiled_sql(self, tmp_project_dir: str, context: Context) -> None:
"""
Uploads the compiled SQL files from the dbt compile output to the remote store.
"""
if not self.should_upload_compiled_sql:
return

dest_target_dir, dest_conn_id = self._configure_remote_target_path()
if not dest_target_dir:
raise CosmosValueError(
"You're trying to upload compiled SQL files, but the remote target path is not configured. "
)

from airflow.io.path import ObjectStoragePath

source_compiled_dir = Path(tmp_project_dir) / "target" / "compiled"
files = [str(file) for file in source_compiled_dir.rglob("*") if file.is_file()]
for file_path in files:
dest_file_path = self._construct_dest_file_path(dest_target_dir, file_path, source_compiled_dir, context)
dest_object_storage_path = ObjectStoragePath(dest_file_path, conn_id=dest_conn_id)
ObjectStoragePath(file_path).copy(dest_object_storage_path)
self.log.debug("Copied %s to %s", file_path, dest_object_storage_path)

@provide_session
def store_freshness_json(self, tmp_project_dir: str, context: Context, session: Session = NEW_SESSION) -> None:
"""
@@ -416,6 +500,7 @@ def run_command(

self.store_freshness_json(tmp_project_dir, context)
self.store_compiled_sql(tmp_project_dir, context)
self.upload_compiled_sql(tmp_project_dir, context)
self.handle_exception(result)
if self.callback:
self.callback(tmp_project_dir)
@@ -920,3 +1005,9 @@ def __init__(self, **kwargs: str) -> None:
raise DeprecationWarning(
"The DbtDepsOperator has been deprecated. " "Please use the `install_deps` flag in dbt_args instead."
)


class DbtCompileLocalOperator(DbtCompileMixin, DbtLocalBaseOperator):
def __init__(self, *args: Any, **kwargs: Any) -> None:
kwargs["should_upload_compiled_sql"] = True
super().__init__(*args, **kwargs)
3 changes: 3 additions & 0 deletions cosmos/settings.py
@@ -35,6 +35,9 @@
remote_cache_dir = conf.get("cosmos", "remote_cache_dir", fallback=None)
remote_cache_dir_conn_id = conf.get("cosmos", "remote_cache_dir_conn_id", fallback=None)

remote_target_path = conf.get("cosmos", "remote_target_path", fallback=None)
remote_target_path_conn_id = conf.get("cosmos", "remote_target_path_conn_id", fallback=None)

try:
LINEAGE_NAMESPACE = conf.get("openlineage", "namespace")
except airflow.exceptions.AirflowConfigException:
