diff --git a/cosmos/airflow/dag.py b/cosmos/airflow/dag.py
index dd5092080..bdcce3702 100644
--- a/cosmos/airflow/dag.py
+++ b/cosmos/airflow/dag.py
@@ -4,17 +4,14 @@
 
 from __future__ import annotations
 
-import functools
-
 # import inspect
 import pickle
 import time
-from pathlib import Path
 from typing import Any
 
 from airflow.models.dag import DAG
 
-from cosmos import cache, settings
+from cosmos import cache
 from cosmos.converter import DbtToAirflowConverter, airflow_kwargs, specific_kwargs
 from cosmos.log import get_logger
 
@@ -26,78 +23,32 @@ class DbtDag(DAG, DbtToAirflowConverter):
     Render a dbt project as an Airflow DAG.
     """
 
-    @staticmethod
-    @functools.lru_cache
-    def get_cache_filepath(cache_identifier: str) -> Path:
-        cache_dir_path = cache._obtain_cache_dir_path(cache_identifier)
-        return cache_dir_path / f"{cache_identifier}.pkl"
-
-    @staticmethod
-    @functools.lru_cache
-    def get_cache_version_filepath(cache_identifier: str) -> Path:
-        return Path(str(DbtDag.get_cache_filepath(cache_identifier)) + ".version")
-
-    @staticmethod
-    @functools.lru_cache
-    def should_use_cache() -> bool:
-        return settings.enable_cache and settings.experimental_cache
-
-    @staticmethod
-    @functools.lru_cache
-    def is_project_unmodified(dag_id: str, current_version: str) -> Path | None:
-        cache_filepath = DbtDag.get_cache_filepath(dag_id)
-        cache_version_filepath = DbtDag.get_cache_version_filepath(dag_id)
-        if cache_version_filepath.exists() and cache_filepath.exists():
-            previous_cache_version = cache_version_filepath.read_text()
-            if previous_cache_version == current_version:
-                return cache_filepath
-        return None
-
-    @staticmethod
-    @functools.lru_cache
-    def calculate_current_version(dag_id: str, project_dir: Path) -> str:
-        start_time = time.process_time()
-
-        # When DAG file was last changed - this is very slow (e.g. 0.6s)
-        # caller_dag_frame = inspect.stack()[1]
-        # caller_dag_filepath = Path(caller_dag_frame.filename)
-        # logger.info("The %s DAG is located in: %s" % (dag_id, caller_dag_filepath))
-        # dag_last_modified = caller_dag_filepath.stat().st_mtime
-        # mid_time = time.process_time() - start_time
-        # logger.info(f"It took {mid_time:.3}s to calculate the first part of the version")
-        dag_last_modified = None
-
-        # Combined value for when the dbt project directory files were last modified
-        # This is fast (e.g. 0.01s for jaffle shop, 0.135s for a 5k models dbt folder)
-        dbt_combined_last_modified = sum([path.stat().st_mtime for path in project_dir.glob("**/*")])
-
-        elapsed_time = time.process_time() - start_time
-        logger.info(f"It took {elapsed_time:.3}s to calculate the cache version for the DbtDag {dag_id}")
-        return f"{dag_last_modified} {dbt_combined_last_modified}"
-
     def __new__(cls, *args, **kwargs):  # type: ignore
         dag_id = kwargs.get("dag_id")
         project_config = kwargs.get("project_config")
 
-        # When we load a Pickle dump of a DbtDag, __new__ is invoked without kwargs
-        # In those cases, we should not call DbtDag.__new__ again, otherwise we'll have an infinite recursion
+        # When we load a Pickle dump of an instance, __new__ is invoked without kwargs
+        # In those cases, we should not call __new__ again, otherwise we'll have an infinite recursion
        if dag_id is not None and project_config and project_config.dbt_project_path:
-            current_version = DbtDag.calculate_current_version(dag_id, project_config.dbt_project_path)
-            cache_filepath = DbtDag.should_use_cache() and DbtDag.is_project_unmodified(dag_id, current_version)
+            cache_id = cache.create_cache_identifier_v2(dag_id, None)
+            current_version = cache.calculate_current_version(cache_id, project_config.dbt_project_path)
+            cache_filepath = cache.should_use_cache() and cache.is_project_unmodified(cache_id, current_version)
             if cache_filepath:
-                logger.info(f"Restoring DbtDag {dag_id} from cache {cache_filepath}")
+                logger.info(f"Restoring {cls.__name__} {dag_id} from cache {cache_filepath}")
                 with open(cache_filepath, "rb") as fp:
                     start_time = time.process_time()
                     dbt_dag = pickle.load(fp)
                     elapsed_time = time.process_time() - start_time
-                    logger.info(f"It took {elapsed_time:.3}s to restore the cached version of the DbtDag {dag_id}")
+                    logger.info(
+                        f"It took {elapsed_time:.3}s to restore the cached version of the {cls.__name__} {dag_id}"
+                    )
                 return dbt_dag
 
         instance = DAG.__new__(DAG)
-        DbtDag.__init__(instance, *args, **kwargs)  # type: ignore
+        cls.__init__(instance, *args, **kwargs)  # type: ignore
         return instance
 
-    # The __init__ is not called when restoring the cached DbtDag in __new__
+    # The __init__ is not called when restoring the cached instance in __new__
     def __init__(
         self,
         *args: Any,
@@ -110,14 +61,16 @@ def __init__(
         DAG.__init__(self, *args, **airflow_kwargs(**kwargs))
         kwargs["dag"] = self
         DbtToAirflowConverter.__init__(self, *args, **specific_kwargs(**kwargs))
+
         elapsed_time = time.process_time() - start_time
-        logger.info(f"It took {elapsed_time} to create the DbtDag {dag_id} from scratch")
+        logger.info(f"It took {elapsed_time:.3}s to create the {self.__class__.__name__} {dag_id} from scratch")
 
-        if DbtDag.should_use_cache() and project_config:
-            cache_filepath = DbtDag.get_cache_filepath(dag_id)
+        if cache.should_use_cache() and project_config:
+            cache_id = cache.create_cache_identifier_v2(dag_id, None)
+            cache_filepath = cache.get_cache_filepath(cache_id)
             with open(cache_filepath, "wb") as fp:
                 pickle.dump(self, fp)
-            cache_version_filepath = DbtDag.get_cache_version_filepath(dag_id)
-            current_version = DbtDag.calculate_current_version(dag_id, project_config.dbt_project_path)
+            cache_version_filepath = cache.get_cache_version_filepath(cache_id)
+            current_version = cache.calculate_current_version(cache_id, project_config.dbt_project_path)
             cache_version_filepath.write_text(current_version)
-            logger.info(f"Stored DbtDag {dag_id} cache {cache_filepath}")
+            logger.info(f"Stored {self.__class__.__name__} {dag_id} cache {cache_filepath}")
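A note on the restore path in DbtDag.__new__ above: pickle rebuilds an object by calling cls.__new__(cls) with no arguments and then restoring its state directly, so __init__ never runs on restore. That is why the cache lookup is guarded by the presence of the dag_id and project_config kwargs, and why the fresh instance is built as DAG.__new__(DAG) with cls.__init__ called by hand: returning an object that is not an instance of cls also keeps Python from invoking __init__ automatically a second time. A minimal sketch of the same pattern, using a hypothetical Expensive class that is not Cosmos code:

from __future__ import annotations

import pickle


class Expensive:
    """Hypothetical stand-in for DbtDag; not part of Cosmos."""

    def __new__(cls, *args, **kwargs):
        # Unpickling invokes __new__ with no kwargs, so the expensive
        # branch (the cache lookup, in Cosmos) is skipped on restore.
        if kwargs.get("name") is not None:
            print(f"expensive build for {kwargs['name']}")
        return super().__new__(cls)

    def __init__(self, name: str | None = None) -> None:
        self.name = name


original = Expensive(name="demo")                # prints: expensive build for demo
restored = pickle.loads(pickle.dumps(original))  # no print; __init__ is not called
assert restored.name == "demo"                   # state comes back via __dict__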
diff --git a/cosmos/airflow/task_group.py b/cosmos/airflow/task_group.py
index 64fcb298a..90881726e 100644
--- a/cosmos/airflow/task_group.py
+++ b/cosmos/airflow/task_group.py
@@ -4,11 +4,17 @@
 
 from __future__ import annotations
 
+import pickle
+import time
 from typing import Any
 
 from airflow.utils.task_group import TaskGroup
 
+from cosmos import cache
 from cosmos.converter import DbtToAirflowConverter, airflow_kwargs, specific_kwargs
+from cosmos.log import get_logger
+
+logger = get_logger()
 
 
 class DbtTaskGroup(TaskGroup, DbtToAirflowConverter):
@@ -16,13 +22,56 @@ class DbtTaskGroup(TaskGroup, DbtToAirflowConverter):
     Render a dbt project as an Airflow Task Group.
     """
 
+    def __new__(cls, *args, **kwargs):  # type: ignore
+        dag_id = kwargs.get("dag_id")
+        task_id = kwargs.get("task_id")
+        project_config = kwargs.get("project_config")
+
+        # When we load a Pickle dump of an instance, __new__ is invoked without kwargs
+        # In those cases, we should not call __new__ again, otherwise we'll have an infinite recursion
+        if task_id is not None and project_config and project_config.dbt_project_path:
+            cache_id = cache.create_cache_identifier_v2(dag_id, task_id)
+            current_version = cache.calculate_current_version(cache_id, project_config.dbt_project_path)
+            cache_filepath = cache.should_use_cache() and cache.is_project_unmodified(cache_id, current_version)
+            if cache_filepath:
+                logger.info(f"Restoring {cls.__name__} {dag_id} from cache {cache_filepath}")
+                with open(cache_filepath, "rb") as fp:
+                    start_time = time.process_time()
+                    dbt_dag = pickle.load(fp)
+                    elapsed_time = time.process_time() - start_time
+                    logger.info(
+                        f"It took {elapsed_time:.3}s to restore the cached version of the {cls.__name__} {dag_id}"
+                    )
+                return dbt_dag
+
+        instance = TaskGroup.__new__(TaskGroup)
+        cls.__init__(instance, *args, **kwargs)  # type: ignore
+        return instance
+
     def __init__(
         self,
         group_id: str = "dbt_task_group",
         *args: Any,
         **kwargs: Any,
     ) -> None:
+        start_time = time.process_time()
         kwargs["group_id"] = group_id
+        dag_id = kwargs.get("dag_id")
+        project_config = kwargs.get("project_config")
+
         TaskGroup.__init__(self, *args, **airflow_kwargs(**kwargs))
         kwargs["task_group"] = self
         DbtToAirflowConverter.__init__(self, *args, **specific_kwargs(**kwargs))
+
+        elapsed_time = time.process_time() - start_time
+        logger.info(f"It took {elapsed_time:.3}s to create the {self.__class__.__name__} {dag_id} from scratch")
+
+        if cache.should_use_cache() and project_config:
+            cache_id = cache.create_cache_identifier_v2(dag_id, group_id)
+            cache_filepath = cache.get_cache_filepath(cache_id)
+            with open(cache_filepath, "wb") as fp:
+                pickle.dump(self, fp)
+            cache_version_filepath = cache.get_cache_version_filepath(cache_id)
+            current_version = cache.calculate_current_version(cache_id, project_config.dbt_project_path)
+            cache_version_filepath.write_text(current_version)
+            logger.info(f"Stored {self.__class__.__name__} {dag_id} cache {cache_filepath}")
diff --git a/cosmos/cache.py b/cosmos/cache.py
index 1e0b341f0..37a88d179 100644
--- a/cosmos/cache.py
+++ b/cosmos/cache.py
@@ -1,6 +1,8 @@
 from __future__ import annotations
 
+import functools
 import shutil
+import time
 from pathlib import Path
 
 import msgpack
@@ -171,3 +173,84 @@ def _copy_partial_parse_to_project(partial_parse_filepath: Path, project_path: P
 
     if source_manifest_filepath.exists():
         shutil.copy(str(source_manifest_filepath), str(target_manifest_filepath))
+
+
+# The following functions are used to cache DbtDag / DbtTaskGroup
+
+
+# We considered creating the cache identifier based on the dbt project path, as opposed
+# to where the project is used in Airflow. However, that could lead to concurrency issues
+# if the same cached dbt directory were used by different dbt task groups or DAGs within
+# the same node. For this reason, as a starting point, the cache is identified by where
+# it is used. This can be reviewed in the future.
+def create_cache_identifier_v2(dag_id: str | None, task_group_id: str | None) -> str:
+    # FIXME: To be refactored and merged with _create_cache_identifier
+    # Missing support for: task_group.group_id
+    """
+    Given a DAG id and an (optional) task group id, create the identifier for caching.
+
+    :param dag_id: Name of the Cosmos DbtDag being cached
+    :param task_group_id: (optional) Name of the Cosmos DbtTaskGroup being cached
+    :return: Unique identifier representing the cache
+    """
+    cache_identifiers_list = []
+    if task_group_id:
+        if dag_id is not None:
+            cache_identifiers_list.append(dag_id)
+        if task_group_id is not None:
+            cache_identifiers_list.append(task_group_id)
+        cache_identifier = "__".join(cache_identifiers_list)
+    else:
+        cache_identifier = str(dag_id)
+
+    return cache_identifier
+
+
+@functools.lru_cache
+def get_cache_filepath(cache_identifier: str) -> Path:
+    cache_dir_path = _obtain_cache_dir_path(cache_identifier)
+    return cache_dir_path / f"{cache_identifier}.pkl"
+
+
+@functools.lru_cache
+def get_cache_version_filepath(cache_identifier: str) -> Path:
+    return Path(str(get_cache_filepath(cache_identifier)) + ".version")
+
+
+@functools.lru_cache
+def should_use_cache() -> bool:
+    return settings.enable_cache and settings.experimental_cache
+
+
+@functools.lru_cache
+def calculate_current_version(dag_id: str, project_dir: Path) -> str:
+    start_time = time.process_time()
+
+    # When DAG file was last changed - this is very slow (e.g. 0.6s)
+    # caller_dag_frame = inspect.stack()[1]
+    # caller_dag_filepath = Path(caller_dag_frame.filename)
+    # logger.info("The %s DAG is located in: %s" % (dag_id, caller_dag_filepath))
+    # dag_last_modified = caller_dag_filepath.stat().st_mtime
+    # mid_time = time.process_time() - start_time
+    # logger.info(f"It took {mid_time:.3}s to calculate the first part of the version")
+    # dag_last_modified = None
+
+    # Combined value for when the dbt project directory files were last modified
+    # This is fast (e.g. 0.01s for jaffle shop, 0.135s for a 5k models dbt folder)
+    dbt_combined_last_modified = sum([path.stat().st_mtime for path in project_dir.glob("**/*")])
+
+    elapsed_time = time.process_time() - start_time
+    logger.info(f"It took {elapsed_time:.3}s to calculate the cache version for {dag_id}")
+    # return f"{dag_last_modified} {dbt_combined_last_modified}"
+    return f"{dbt_combined_last_modified}"
+
+
+@functools.lru_cache
+def is_project_unmodified(dag_id: str, current_version: str) -> Path | None:
+    cache_filepath = get_cache_filepath(dag_id)
+    cache_version_filepath = get_cache_version_filepath(dag_id)
+    if cache_version_filepath.exists() and cache_filepath.exists():
+        previous_cache_version = cache_version_filepath.read_text()
+        if previous_cache_version == current_version:
+            return cache_filepath
+    return None
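Taken together, the new cache.py helpers implement a simple stamp-and-compare scheme: the pickled object sits next to a .version file holding a fingerprint (the sum of st_mtime over every path under the dbt project), and the pickle is only trusted while a freshly computed fingerprint matches the stored one. A rough usage sketch, assuming both settings.enable_cache and settings.experimental_cache are on; the project path, the identifiers, and built_object are placeholders, not real Cosmos names:

import pickle
from pathlib import Path

from cosmos import cache

project_dir = Path("/usr/local/airflow/dbt/jaffle_shop")  # hypothetical path
cache_id = cache.create_cache_identifier_v2("my_dag", "my_group")  # -> "my_dag__my_group"

if cache.should_use_cache():
    # Store: what DbtDag/DbtTaskGroup.__init__ does after building from scratch.
    version = cache.calculate_current_version(cache_id, project_dir)
    with open(cache.get_cache_filepath(cache_id), "wb") as fp:
        pickle.dump(built_object, fp)  # placeholder for the built DAG / task group
    cache.get_cache_version_filepath(cache_id).write_text(version)

    # Restore: what __new__ does on the next parse; None means stale or missing.
    cached_pkl = cache.is_project_unmodified(cache_id, version)
    if cached_pkl:
        with open(cached_pkl, "rb") as fp:
            restored = pickle.load(fp)

Note that calculate_current_version and is_project_unmodified are wrapped in functools.lru_cache, so within a single parsing process the fingerprint and the staleness check run at most once per identifier; a project modified mid-process is only noticed on the next parse.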