
Commit

Address feedback to include task_group in the remote store path as identifier & put files under compiled dir
pankajkoti committed Sep 30, 2024
1 parent 061ea1b commit 77c7c6c
Showing 3 changed files with 32 additions and 15 deletions.
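
For context, a minimal sketch (not part of the commit; the bucket, DAG, task group, and file names below are hypothetical) of the destination layout the new path construction produces: compiled files now land under <base>/<dag_id>__<task_group_id>/compiled/<path relative to the compiled dir>, and the task-group segment is dropped when the task has no task group.

# Illustration only: mirrors the path logic this commit adds to cosmos/operators/local.py.
# "s3://bucket/target", "my_dag", "my_group", and the file paths are example values.
import os

dest_target_dir = "s3://bucket/target"
dag_id = "my_dag"
task_group_id = "my_group"  # would be None for a task outside any TaskGroup
source_compiled_dir = "/tmp/project/target/compiled"
file_path = "/tmp/project/target/compiled/models/my_model.sql"

identifiers = [part for part in (dag_id, task_group_id) if part]
dag_task_group_identifier = "__".join(identifiers)
rel_path = os.path.relpath(file_path, source_compiled_dir).lstrip("/")

dest = f"{dest_target_dir.rstrip('/')}/{dag_task_group_identifier}/compiled/{rel_path}"
print(dest)  # s3://bucket/target/my_dag__my_group/compiled/models/my_model.sql
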
5 changes: 3 additions & 2 deletions cosmos/airflow/graph.py
@@ -259,6 +259,7 @@ def _add_dbt_compile_task(
     execution_mode: ExecutionMode,
     task_args: dict[str, Any],
     tasks_map: dict[str, Any],
+    task_group: TaskGroup | None,
 ) -> None:
     if execution_mode != ExecutionMode.AIRFLOW_ASYNC:
         return
@@ -269,7 +270,7 @@ def _add_dbt_compile_task(
         arguments=task_args,
         extra_context={},
     )
-    compile_airflow_task = create_airflow_task(compile_task_metadata, dag, task_group=None)
+    compile_airflow_task = create_airflow_task(compile_task_metadata, dag, task_group=task_group)
     tasks_map[DBT_COMPILE_TASK_ID] = compile_airflow_task
 
     for node_id, node in nodes.items():
@@ -357,7 +358,7 @@ def build_airflow_graph(
         for leaf_node_id in leaves_ids:
             tasks_map[leaf_node_id] >> test_task
 
-    _add_dbt_compile_task(nodes, dag, execution_mode, task_args, tasks_map)
+    _add_dbt_compile_task(nodes, dag, execution_mode, task_args, tasks_map, task_group)
 
     create_airflow_task_dependencies(nodes, tasks_map)
 
35 changes: 26 additions & 9 deletions cosmos/operators/local.py
@@ -310,6 +310,28 @@ def _configure_remote_target_path() -> tuple[Path, str] | tuple[None, None]:
 
         return _configured_target_path, remote_conn_id
 
+    def _construct_dest_file_path(
+        self, dest_target_dir: Path, file_path: str, source_compiled_dir: Path, context: Context
+    ) -> str:
+        """
+        Construct the destination path for the compiled SQL files to be uploaded to the remote store.
+        """
+        dest_target_dir_str = str(dest_target_dir).rstrip("/")
+
+        task = context["task"]
+        dag_id = task.dag_id
+        task_group_id = task.task_group.group_id if task.task_group else None
+        identifiers_list = []
+        if dag_id:
+            identifiers_list.append(dag_id)
+        if task_group_id:
+            identifiers_list.append(task_group_id)
+        dag_task_group_identifier = "__".join(identifiers_list)
+
+        rel_path = os.path.relpath(file_path, source_compiled_dir).lstrip("/")
+
+        return f"{dest_target_dir_str}/{dag_task_group_identifier}/compiled/{rel_path}"
+
     def upload_compiled_sql(self, tmp_project_dir: str, context: Context) -> None:
         """
         Uploads the compiled SQL files from the dbt compile output to the remote store.
@@ -327,16 +349,11 @@ def upload_compiled_sql(self, tmp_project_dir: str, context: Context) -> None:
 
         source_compiled_dir = Path(tmp_project_dir) / "target" / "compiled"
         files = [str(file) for file in source_compiled_dir.rglob("*") if file.is_file()]
 
         for file_path in files:
-            rel_path = os.path.relpath(file_path, source_compiled_dir)
-
-            dest_path = ObjectStoragePath(
-                f"{str(dest_target_dir).rstrip('/')}/{context['dag'].dag_id}/{rel_path.lstrip('/')}",
-                conn_id=dest_conn_id,
-            )
-            ObjectStoragePath(file_path).copy(dest_path)
-            self.log.debug("Copied %s to %s", file_path, dest_path)
+            dest_file_path = self._construct_dest_file_path(dest_target_dir, file_path, source_compiled_dir, context)
+            dest_object_storage_path = ObjectStoragePath(dest_file_path, conn_id=dest_conn_id)
+            ObjectStoragePath(file_path).copy(dest_object_storage_path)
+            self.log.debug("Copied %s to %s", file_path, dest_object_storage_path)
 
     @provide_session
     def store_freshness_json(self, tmp_project_dir: str, context: Context, session: Session = NEW_SESSION) -> None:
7 changes: 3 additions & 4 deletions tests/operators/test_local.py
@@ -1241,6 +1241,7 @@ def test_upload_compiled_sql_should_upload(mock_configure_remote, mock_object_st
         task_id="fake-task",
         profile_config=profile_config,
         project_dir="fake-dir",
+        dag=DAG("test_dag", start_date=datetime(2024, 4, 16)),
     )
 
     mock_configure_remote.return_value = ("mock_remote_path", "mock_conn_id")
@@ -1259,12 +1260,10 @@ def test_upload_compiled_sql_should_upload(mock_configure_remote, mock_object_st
     files = [file1, file2]
 
     with patch.object(Path, "rglob", return_value=files):
-        context = {"dag": MagicMock(dag_id="test_dag")}
-
-        operator.upload_compiled_sql(tmp_project_dir, context)
+        operator.upload_compiled_sql(tmp_project_dir, context={"task": operator})
 
         for file_path in files:
             rel_path = os.path.relpath(str(file_path), str(source_compiled_dir))
-            expected_dest_path = f"mock_remote_path/test_dag/{rel_path.lstrip('/')}"
+            expected_dest_path = f"mock_remote_path/test_dag/compiled/{rel_path.lstrip('/')}"
             mock_object_storage_path.assert_any_call(expected_dest_path, conn_id="mock_conn_id")
             mock_object_storage_path.return_value.copy.assert_any_call(mock_object_storage_path.return_value)
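
Why the test now passes context={"task": operator} and still expects a path without a task-group segment: the operator is itself the task, and, assuming Airflow 2.x semantics, a task attached to a DAG but not placed in an explicit TaskGroup belongs to the DAG's root group, whose group_id is None, so the identifier reduces to the dag_id. A small sketch of that fallback (values mirror the test):

# Sketch of the identifier fallback exercised above (hypothetical standalone snippet).
dag_id = "test_dag"
task_group_id = None  # root TaskGroup's group_id is None when no explicit group is used
identifiers = [part for part in (dag_id, task_group_id) if part]
assert "__".join(identifiers) == "test_dag"  # path becomes mock_remote_path/test_dag/compiled/<rel_path>
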
