From b8e700b21f23d5df6a27681ec4ad277ccad64592 Mon Sep 17 00:00:00 2001 From: Mikita Sakalouski <38785549+mikita-sakalouski@users.noreply.github.com> Date: Sun, 15 Sep 2024 23:38:41 +0200 Subject: [PATCH] fix: adjust logic for workspace --- brickflow/codegen/databricks_bundle.py | 23 +++++++++++++++++-- .../expected_bundles/dev_bundle_monorepo.yml | 2 +- .../expected_bundles/dev_bundle_polyrepo.yml | 2 +- .../dev_bundle_polyrepo_with_auto_libs.yml | 2 +- .../codegen/expected_bundles/local_bundle.yml | 2 +- .../local_bundle_prefix_suffix.yml | 2 +- tests/codegen/sample_workflows.py | 2 +- tests/engine/test_task.py | 8 +++---- 8 files changed, 31 insertions(+), 12 deletions(-) diff --git a/brickflow/codegen/databricks_bundle.py b/brickflow/codegen/databricks_bundle.py index 20053a5..e01d5f1 100644 --- a/brickflow/codegen/databricks_bundle.py +++ b/brickflow/codegen/databricks_bundle.py @@ -426,6 +426,15 @@ def adjust_source(self) -> str: return "WORKSPACE" if self.env == BrickflowDefaultEnvs.LOCAL.value else "GIT" def adjust_file_path(self, file_path: str) -> str: + """ + Adjusts the given file path based on the environment and project settings. + If the environment is local and the project has a defined bundle base path and bundle object name, + the method constructs a new file path by appending the local bundle path to the given file path. + Args: + file_path (str): The original file path to be adjusted. + Returns: + str: The adjusted file path. + """ if ( self.env == BrickflowDefaultEnvs.LOCAL.value and self.project.bundle_base_path is not None @@ -433,16 +442,26 @@ def adjust_file_path(self, file_path: str) -> str: ): bundle_files_local_path = "/".join( [ + "Workspace", self.project.bundle_base_path, self.project.bundle_obj_name, self.project.name, str(BrickflowDefaultEnvs.LOCAL.value), + "files", ] - ) + ).replace("//", "/") + + # Finds the start position of the project name in the given file path and calculates the cut position. + # - `file_path.find(self.project.name)`: Finds the start index of the project name in the file path. + # - `+ len(self.project.name) + 1`: Moves the start position to the character after the project name. + # - Adjusts the file path by appending the local bundle path to the cut file path. + cut_file_path = file_path[ + file_path.find(self.project.name) + len(self.project.name) + 1 : + ] file_path = ( bundle_files_local_path + file_path if file_path.startswith("/") - else f"{bundle_files_local_path}/{file_path}" + else f"{bundle_files_local_path}/{cut_file_path}" ) return file_path diff --git a/tests/codegen/expected_bundles/dev_bundle_monorepo.yml b/tests/codegen/expected_bundles/dev_bundle_monorepo.yml index 1dcbfac..d8ac64b 100644 --- a/tests/codegen/expected_bundles/dev_bundle_monorepo.yml +++ b/tests/codegen/expected_bundles/dev_bundle_monorepo.yml @@ -123,7 +123,7 @@ targets: max_retries: null min_retry_interval_millis: null spark_python_task: - python_file: path/to/python/file.py + python_file: ./products/test-project/spark/python/src/run_task.py parameters: ["--param1", "World!"] source: GIT retry_on_timeout: null diff --git a/tests/codegen/expected_bundles/dev_bundle_polyrepo.yml b/tests/codegen/expected_bundles/dev_bundle_polyrepo.yml index 43ad68c..4438f59 100644 --- a/tests/codegen/expected_bundles/dev_bundle_polyrepo.yml +++ b/tests/codegen/expected_bundles/dev_bundle_polyrepo.yml @@ -123,7 +123,7 @@ targets: max_retries: null min_retry_interval_millis: null spark_python_task: - python_file: path/to/python/file.py + python_file: ./products/test-project/spark/python/src/run_task.py parameters: ["--param1", "World!"] source: GIT retry_on_timeout: null diff --git a/tests/codegen/expected_bundles/dev_bundle_polyrepo_with_auto_libs.yml b/tests/codegen/expected_bundles/dev_bundle_polyrepo_with_auto_libs.yml index fbe3125..70acc97 100644 --- a/tests/codegen/expected_bundles/dev_bundle_polyrepo_with_auto_libs.yml +++ b/tests/codegen/expected_bundles/dev_bundle_polyrepo_with_auto_libs.yml @@ -208,7 +208,7 @@ targets: max_retries: null min_retry_interval_millis: null spark_python_task: - python_file: path/to/python/file.py + python_file: ./products/test-project/spark/python/src/run_task.py parameters: ["--param1", "World!"] source: GIT retry_on_timeout: null diff --git a/tests/codegen/expected_bundles/local_bundle.yml b/tests/codegen/expected_bundles/local_bundle.yml index 6f58acd..7c66eca 100644 --- a/tests/codegen/expected_bundles/local_bundle.yml +++ b/tests/codegen/expected_bundles/local_bundle.yml @@ -242,7 +242,7 @@ "min_retry_interval_millis": null "retry_on_timeout": null "spark_python_task": - "python_file": "/Users/${workspace.current_user.userName}/.brickflow_bundles/test-project/local/path/to/python/file.py" + "python_file": "Workspace/Users/${workspace.current_user.userName}/.brickflow_bundles/test-project/local/files/spark/python/src/run_task.py" "source": "WORKSPACE" "parameters": - "--param1" diff --git a/tests/codegen/expected_bundles/local_bundle_prefix_suffix.yml b/tests/codegen/expected_bundles/local_bundle_prefix_suffix.yml index e3d2e08..161f185 100644 --- a/tests/codegen/expected_bundles/local_bundle_prefix_suffix.yml +++ b/tests/codegen/expected_bundles/local_bundle_prefix_suffix.yml @@ -120,7 +120,7 @@ targets: max_retries: null min_retry_interval_millis: null spark_python_task: - python_file: /Users/${workspace.current_user.userName}/.brickflow_bundles/test-project/local/path/to/python/file.py + python_file: Workspace/Users/${workspace.current_user.userName}/.brickflow_bundles/test-project/local/files/spark/python/src/run_task.py parameters: ["--param1", "World!"] source: WORKSPACE retry_on_timeout: null diff --git a/tests/codegen/sample_workflows.py b/tests/codegen/sample_workflows.py index 03739a4..17088a7 100644 --- a/tests/codegen/sample_workflows.py +++ b/tests/codegen/sample_workflows.py @@ -82,7 +82,7 @@ def spark_jar_task_a(): ) def spark_python_task_a(): return SparkPythonTask( - python_file="path/to/python/file.py", + python_file="./products/test-project/spark/python/src/run_task.py", source="GIT", parameters=["--param1", "World!"], ) # type: ignore diff --git a/tests/engine/test_task.py b/tests/engine/test_task.py index a0ee982..f502378 100644 --- a/tests/engine/test_task.py +++ b/tests/engine/test_task.py @@ -539,19 +539,19 @@ def test_without_params_spark_jar(self): def test_init_spark_python(self): task = SparkPythonTask( - python_file="path/to/python/file.py", + python_file="./products/test-project/path/to/python/file.py", source="GIT", parameters=["--param1", "World!"], ) - assert task.python_file == "path/to/python/file.py" + assert task.python_file == "./products/test-project/path/to/python/file.py" assert task.source == "GIT" assert task.parameters == ["--param1", "World!"] def test_without_params_spark_python(self): task = SparkPythonTask( - python_file="path/to/python/file.py", + python_file="./products/test-project/path/to/python/file.py", ) - assert task.python_file == "path/to/python/file.py" + assert task.python_file == "./products/test-project/path/to/python/file.py" assert task.source is None assert task.parameters is None