diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index eaf732bfa..9de21292e 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -259,6 +259,7 @@ def _add_dbt_compile_task( execution_mode: ExecutionMode, task_args: dict[str, Any], tasks_map: dict[str, Any], + task_group: TaskGroup | None, ) -> None: if execution_mode != ExecutionMode.AIRFLOW_ASYNC: return @@ -269,7 +270,7 @@ def _add_dbt_compile_task( arguments=task_args, extra_context={}, ) - compile_airflow_task = create_airflow_task(compile_task_metadata, dag, task_group=None) + compile_airflow_task = create_airflow_task(compile_task_metadata, dag, task_group=task_group) tasks_map[DBT_COMPILE_TASK_ID] = compile_airflow_task for node_id, node in nodes.items(): @@ -357,7 +358,7 @@ def build_airflow_graph( for leaf_node_id in leaves_ids: tasks_map[leaf_node_id] >> test_task - _add_dbt_compile_task(nodes, dag, execution_mode, task_args, tasks_map) + _add_dbt_compile_task(nodes, dag, execution_mode, task_args, tasks_map, task_group) create_airflow_task_dependencies(nodes, tasks_map) diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 8cda2f883..8bbee1a44 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -310,6 +310,28 @@ def _configure_remote_target_path() -> tuple[Path, str] | tuple[None, None]: return _configured_target_path, remote_conn_id + def _construct_dest_file_path( + self, dest_target_dir: Path, file_path: str, source_compiled_dir: Path, context: Context + ) -> str: + """ + Construct the destination path for the compiled SQL files to be uploaded to the remote store. + """ + dest_target_dir_str = str(dest_target_dir).rstrip("/") + + task = context["task"] + dag_id = task.dag_id + task_group_id = task.task_group.group_id if task.task_group else None + identifiers_list = [] + if dag_id: + identifiers_list.append(dag_id) + if task_group_id: + identifiers_list.append(task_group_id) + dag_task_group_identifier = "__".join(identifiers_list) + + rel_path = os.path.relpath(file_path, source_compiled_dir).lstrip("/") + + return f"{dest_target_dir_str}/{dag_task_group_identifier}/compiled/{rel_path}" + def upload_compiled_sql(self, tmp_project_dir: str, context: Context) -> None: """ Uploads the compiled SQL files from the dbt compile output to the remote store. @@ -327,16 +349,11 @@ def upload_compiled_sql(self, tmp_project_dir: str, context: Context) -> None: source_compiled_dir = Path(tmp_project_dir) / "target" / "compiled" files = [str(file) for file in source_compiled_dir.rglob("*") if file.is_file()] - for file_path in files: - rel_path = os.path.relpath(file_path, source_compiled_dir) - - dest_path = ObjectStoragePath( - f"{str(dest_target_dir).rstrip('/')}/{context['dag'].dag_id}/{rel_path.lstrip('/')}", - conn_id=dest_conn_id, - ) - ObjectStoragePath(file_path).copy(dest_path) - self.log.debug("Copied %s to %s", file_path, dest_path) + dest_file_path = self._construct_dest_file_path(dest_target_dir, file_path, source_compiled_dir, context) + dest_object_storage_path = ObjectStoragePath(dest_file_path, conn_id=dest_conn_id) + ObjectStoragePath(file_path).copy(dest_object_storage_path) + self.log.debug("Copied %s to %s", file_path, dest_object_storage_path) @provide_session def store_freshness_json(self, tmp_project_dir: str, context: Context, session: Session = NEW_SESSION) -> None: diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py index a6e7f2caa..c7615225f 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -1241,6 +1241,7 @@ def test_upload_compiled_sql_should_upload(mock_configure_remote, mock_object_st task_id="fake-task", profile_config=profile_config, project_dir="fake-dir", + dag=DAG("test_dag", start_date=datetime(2024, 4, 16)), ) mock_configure_remote.return_value = ("mock_remote_path", "mock_conn_id") @@ -1259,12 +1260,10 @@ def test_upload_compiled_sql_should_upload(mock_configure_remote, mock_object_st files = [file1, file2] with patch.object(Path, "rglob", return_value=files): - context = {"dag": MagicMock(dag_id="test_dag")} - - operator.upload_compiled_sql(tmp_project_dir, context) + operator.upload_compiled_sql(tmp_project_dir, context={"task": operator}) for file_path in files: rel_path = os.path.relpath(str(file_path), str(source_compiled_dir)) - expected_dest_path = f"mock_remote_path/test_dag/{rel_path.lstrip('/')}" + expected_dest_path = f"mock_remote_path/test_dag/compiled/{rel_path.lstrip('/')}" mock_object_storage_path.assert_any_call(expected_dest_path, conn_id="mock_conn_id") mock_object_storage_path.return_value.copy.assert_any_call(mock_object_storage_path.return_value)