Extend DbtDag caching to DbtTaskGroup
tatiana committed May 23, 2024
1 parent 0db7121 commit daad0b3
Showing 3 changed files with 152 additions and 67 deletions.
87 changes: 20 additions & 67 deletions cosmos/airflow/dag.py
@@ -4,17 +4,14 @@

 from __future__ import annotations

-import functools
-
-# import inspect
 import pickle
 import time
 from pathlib import Path
 from typing import Any

 from airflow.models.dag import DAG

-from cosmos import cache, settings
+from cosmos import cache
 from cosmos.converter import DbtToAirflowConverter, airflow_kwargs, specific_kwargs
 from cosmos.log import get_logger

@@ -26,78 +23,32 @@ class DbtDag(DAG, DbtToAirflowConverter):
     Render a dbt project as an Airflow DAG.
     """

-    @staticmethod
-    @functools.lru_cache
-    def get_cache_filepath(cache_identifier: str) -> Path:
-        cache_dir_path = cache._obtain_cache_dir_path(cache_identifier)
-        return cache_dir_path / f"{cache_identifier}.pkl"
-
-    @staticmethod
-    @functools.lru_cache
-    def get_cache_version_filepath(cache_identifier: str) -> Path:
-        return Path(str(DbtDag.get_cache_filepath(cache_identifier)) + ".version")
-
-    @staticmethod
-    @functools.lru_cache
-    def should_use_cache() -> bool:
-        return settings.enable_cache and settings.experimental_cache
-
-    @staticmethod
-    @functools.lru_cache
-    def is_project_unmodified(dag_id: str, current_version: str) -> Path | None:
-        cache_filepath = DbtDag.get_cache_filepath(dag_id)
-        cache_version_filepath = DbtDag.get_cache_version_filepath(dag_id)
-        if cache_version_filepath.exists() and cache_filepath.exists():
-            previous_cache_version = cache_version_filepath.read_text()
-            if previous_cache_version == current_version:
-                return cache_filepath
-        return None
-
-    @staticmethod
-    @functools.lru_cache
-    def calculate_current_version(dag_id: str, project_dir: Path) -> str:
-        start_time = time.process_time()
-
-        # When DAG file was last changed - this is very slow (e.g. 0.6s)
-        # caller_dag_frame = inspect.stack()[1]
-        # caller_dag_filepath = Path(caller_dag_frame.filename)
-        # logger.info("The %s DAG is located in: %s" % (dag_id, caller_dag_filepath))
-        # dag_last_modified = caller_dag_filepath.stat().st_mtime
-        # mid_time = time.process_time() - start_time
-        # logger.info(f"It took {mid_time:.3}s to calculate the first part of the version")
-        dag_last_modified = None
-
-        # Combined value for when the dbt project directory files were last modified
-        # This is fast (e.g. 0.01s for jaffle shop, 0.135s for a 5k models dbt folder)
-        dbt_combined_last_modified = sum([path.stat().st_mtime for path in project_dir.glob("**/*")])
-
-        elapsed_time = time.process_time() - start_time
-        logger.info(f"It took {elapsed_time:.3}s to calculate the cache version for the DbtDag {dag_id}")
-        return f"{dag_last_modified} {dbt_combined_last_modified}"
-
     def __new__(cls, *args, **kwargs):  # type: ignore
         dag_id = kwargs.get("dag_id")
         project_config = kwargs.get("project_config")

-        # When we load a Pickle dump of a DbtDag, __new__ is invoked without kwargs
-        # In those cases, we should not call DbtDag.__new__ again, otherwise we'll have an infinite recursion
+        # When we load a Pickle dump of an instance, __new__ is invoked without kwargs
+        # In those cases, we should not call __new__ again, otherwise we'll have an infinite recursion
         if dag_id is not None and project_config and project_config.dbt_project_path:
-            current_version = DbtDag.calculate_current_version(dag_id, project_config.dbt_project_path)
-            cache_filepath = DbtDag.should_use_cache() and DbtDag.is_project_unmodified(dag_id, current_version)
+            cache_id = cache.create_cache_identifier_v2(dag_id, None)
+            current_version = cache.calculate_current_version(cache_id, project_config.dbt_project_path)
+            cache_filepath = cache.should_use_cache() and cache.is_project_unmodified(cache_id, current_version)
             if cache_filepath:
-                logger.info(f"Restoring DbtDag {dag_id} from cache {cache_filepath}")
+                logger.info(f"Restoring {cls.__name__} {dag_id} from cache {cache_filepath}")
                 with open(cache_filepath, "rb") as fp:
                     start_time = time.process_time()
                     dbt_dag = pickle.load(fp)
                     elapsed_time = time.process_time() - start_time
-                    logger.info(f"It took {elapsed_time:.3}s to restore the cached version of the DbtDag {dag_id}")
+                    logger.info(
+                        f"It took {elapsed_time:.3}s to restore the cached version of the {cls.__name__} {dag_id}"
+                    )
                 return dbt_dag

         instance = DAG.__new__(DAG)
-        DbtDag.__init__(instance, *args, **kwargs)  # type: ignore
+        cls.__init__(instance, *args, **kwargs)  # type: ignore
         return instance
-    # The __init__ is not called when restoring the cached DbtDag in __new__
+    # The __init__ is not called when restoring the cached instance in __new__
     def __init__(
         self,
         *args: Any,
@@ -110,14 +61,16 @@ def __init__(
         DAG.__init__(self, *args, **airflow_kwargs(**kwargs))
         kwargs["dag"] = self
         DbtToAirflowConverter.__init__(self, *args, **specific_kwargs(**kwargs))
+
         elapsed_time = time.process_time() - start_time
-        logger.info(f"It took {elapsed_time} to create the DbtDag {dag_id} from scratch")
+        logger.info(f"It took {elapsed_time} to create the {self.__class__.__name__} {dag_id} from scratch")

-        if DbtDag.should_use_cache() and project_config:
-            cache_filepath = DbtDag.get_cache_filepath(dag_id)
+        if cache.should_use_cache() and project_config:
+            cache_id = cache.create_cache_identifier_v2(dag_id, None)
+            cache_filepath = cache.get_cache_filepath(cache_id)
             with open(cache_filepath, "wb") as fp:
                 pickle.dump(self, fp)
-            cache_version_filepath = DbtDag.get_cache_version_filepath(dag_id)
-            current_version = DbtDag.calculate_current_version(dag_id, project_config.dbt_project_path)
+            cache_version_filepath = cache.get_cache_version_filepath(cache_id)
+            current_version = cache.calculate_current_version(cache_id, project_config.dbt_project_path)
             cache_version_filepath.write_text(current_version)
-            logger.info(f"Stored DbtDag {dag_id} cache {cache_filepath}")
+            logger.info(f"Stored {self.__class__.__name__} {dag_id} cache {cache_filepath}")
49 changes: 49 additions & 0 deletions cosmos/airflow/task_group.py
@@ -4,25 +4,74 @@

 from __future__ import annotations

+import pickle
+import time
 from typing import Any

 from airflow.utils.task_group import TaskGroup

+from cosmos import cache
 from cosmos.converter import DbtToAirflowConverter, airflow_kwargs, specific_kwargs
 from cosmos.log import get_logger

 logger = get_logger()


 class DbtTaskGroup(TaskGroup, DbtToAirflowConverter):
     """
     Render a dbt project as an Airflow Task Group.
     """

+    def __new__(cls, *args, **kwargs):  # type: ignore
+        dag_id = kwargs.get("dag_id")
+        task_id = kwargs.get("task_id")
+        project_config = kwargs.get("project_config")
+
+        # When we load a Pickle dump of an instance, __new__ is invoked without kwargs
+        # In those cases, we should not call __new__ again, otherwise we'll have an infinite recursion
+        if task_id is not None and project_config and project_config.dbt_project_path:
+            cache_id = cache.create_cache_identifier_v2(dag_id, task_id)
+            current_version = cache.calculate_current_version(cache_id, project_config.dbt_project_path)
+            cache_filepath = cache.should_use_cache() and cache.is_project_unmodified(cache_id, current_version)
+            if cache_filepath:
+                logger.info(f"Restoring {cls.__name__} {dag_id} from cache {cache_filepath}")
+                with open(cache_filepath, "rb") as fp:
+                    start_time = time.process_time()
+                    dbt_dag = pickle.load(fp)
+                    elapsed_time = time.process_time() - start_time
+                    logger.info(
+                        f"It took {elapsed_time:.3}s to restore the cached version of the {cls.__name__} {dag_id}"
+                    )
+                return dbt_dag
+
+        instance = TaskGroup.__new__(TaskGroup)
+        cls.__init__(instance, *args, **kwargs)  # type: ignore
+        return instance
+
     def __init__(
         self,
         group_id: str = "dbt_task_group",
         *args: Any,
         **kwargs: Any,
     ) -> None:
+        start_time = time.process_time()
         kwargs["group_id"] = group_id
+        dag_id = kwargs.get("dag_id")
+        project_config = kwargs.get("project_config")
+
         TaskGroup.__init__(self, *args, **airflow_kwargs(**kwargs))
         kwargs["task_group"] = self
         DbtToAirflowConverter.__init__(self, *args, **specific_kwargs(**kwargs))

+        elapsed_time = time.process_time() - start_time
+        logger.info(f"It took {elapsed_time} to create the {self.__class__.__name__} {dag_id} from scratch")
+
+        if cache.should_use_cache() and project_config:
+            cache_id = cache.create_cache_identifier_v2(dag_id, group_id)
+            cache_filepath = cache.get_cache_filepath(cache_id)
+            with open(cache_filepath, "wb") as fp:
+                pickle.dump(self, fp)
+            cache_version_filepath = cache.get_cache_version_filepath(cache_id)
+            current_version = cache.calculate_current_version(cache_id, project_config.dbt_project_path)
+            cache_version_filepath.write_text(current_version)
+            logger.info(f"Stored {self.__class__.__name__} {dag_id} cache {cache_filepath}")
83 changes: 83 additions & 0 deletions cosmos/cache.py
@@ -1,6 +1,8 @@
 from __future__ import annotations

+import functools
 import shutil
+import time
 from pathlib import Path

 import msgpack
@@ -171,3 +173,84 @@ def _copy_partial_parse_to_project(partial_parse_filepath: Path, project_path: P

     if source_manifest_filepath.exists():
         shutil.copy(str(source_manifest_filepath), str(target_manifest_filepath))
+
+
+# The following methods are being used to cache DbtDag / DbtTaskGroup
+
+
+# It was considered to create a cache identifier based on the dbt project path, as opposed
+# to where it is used in Airflow. However, we could have concurrency issues if the same
+# dbt cached directory was being used by different dbt task groups or DAGs within the same
+# node. For this reason, as a starting point, the cache is identified by where it is used.
+# This can be reviewed in the future.
+def create_cache_identifier_v2(dag_id: str | None, task_group_id: str | None) -> str:
+    # FIXME: To be refactored and merged with _create_cache_identifier
+    # Missing support to: task_group.group_id
"""
Given a DAG name and a (optional) task_group_name, create the identifier for caching.
:param dag_name: Name of the Cosmos DbtDag being cached
:param task_group_name: (optional) Name of the Cosmos DbtTaskGroup being cached
:return: Unique identifier representing the cache
"""
+    cache_identifiers_list = []
+    if task_group_id:
+        if dag_id is not None:
+            cache_identifiers_list.append(dag_id)
+        if task_group_id is not None:
+            cache_identifiers_list.append(task_group_id)
+        cache_identifier = "__".join(cache_identifiers_list)
+    else:
+        cache_identifier = str(dag_id)
+
+    return cache_identifier


+@functools.lru_cache
+def get_cache_filepath(cache_identifier: str) -> Path:
+    cache_dir_path = _obtain_cache_dir_path(cache_identifier)
+    return cache_dir_path / f"{cache_identifier}.pkl"
+
+
+@functools.lru_cache
+def get_cache_version_filepath(cache_identifier: str) -> Path:
+    return Path(str(get_cache_filepath(cache_identifier)) + ".version")
+
+
+@functools.lru_cache
+def should_use_cache() -> bool:
+    return settings.enable_cache and settings.experimental_cache
+
+
+@functools.lru_cache
+def calculate_current_version(dag_id: str, project_dir: Path) -> str:
+    start_time = time.process_time()
+
+    # When DAG file was last changed - this is very slow (e.g. 0.6s)
+    # caller_dag_frame = inspect.stack()[1]
+    # caller_dag_filepath = Path(caller_dag_frame.filename)
+    # logger.info("The %s DAG is located in: %s" % (dag_id, caller_dag_filepath))
+    # dag_last_modified = caller_dag_filepath.stat().st_mtime
+    # mid_time = time.process_time() - start_time
+    # logger.info(f"It took {mid_time:.3}s to calculate the first part of the version")
+    # dag_last_modified = None
+
+    # Combined value for when the dbt project directory files were last modified
+    # This is fast (e.g. 0.01s for jaffle shop, 0.135s for a 5k models dbt folder)
+    dbt_combined_last_modified = sum([path.stat().st_mtime for path in project_dir.glob("**/*")])
+
+    elapsed_time = time.process_time() - start_time
+    logger.info(f"It took {elapsed_time:.3}s to calculate the cache version for the {dag_id}")
+    # return f"{dag_last_modified} {dbt_combined_last_modified}"
+    return f"{dbt_combined_last_modified}"


+@functools.lru_cache
+def is_project_unmodified(dag_id: str, current_version: str) -> Path | None:
+    cache_filepath = get_cache_filepath(dag_id)
+    cache_version_filepath = get_cache_version_filepath(dag_id)
+    if cache_version_filepath.exists() and cache_filepath.exists():
+        previous_cache_version = cache_version_filepath.read_text()
+        if previous_cache_version == current_version:
+            return cache_filepath
+    return None
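
For reference, create_cache_identifier_v2 produces identifiers like these (argument values are illustrative):

create_cache_identifier_v2("jaffle_shop", None)          # -> "jaffle_shop" (DbtDag)
create_cache_identifier_v2("jaffle_shop", "dbt_models")  # -> "jaffle_shop__dbt_models" (DbtTaskGroup)
create_cache_identifier_v2(None, "dbt_models")           # -> "dbt_models"

Taken together, these helpers implement the gate used by DbtDag.__new__ and DbtTaskGroup.__new__: compute the current version (a sum of file modification times across the project, so touching any file changes it), compare it against the stored .version file, and unpickle only on a match. A standalone sketch of that flow, with an assumed project path:

import pickle
from pathlib import Path

from cosmos import cache

cache_id = cache.create_cache_identifier_v2("jaffle_shop", None)
current_version = cache.calculate_current_version(cache_id, Path("/usr/local/airflow/dbt/jaffle_shop"))

# Falsy when caching is disabled or the project changed since the cache was written
cache_filepath = cache.should_use_cache() and cache.is_project_unmodified(cache_id, current_version)
if cache_filepath:
    with open(cache_filepath, "rb") as fp:
        restored = pickle.load(fp)  # the previously pickled DbtDag / DbtTaskGroup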
