From 45c115fc74c6f771dc132154700d535a2304bd18 Mon Sep 17 00:00:00 2001
From: Mikita Sakalouski <38785549+mikita-sakalouski@users.noreply.github.com>
Date: Sat, 14 Sep 2024 02:30:42 +0200
Subject: [PATCH] feat: add adjustment of python file in case of local for spark_python

---
 brickflow/codegen/databricks_bundle.py        | 165 ++++++++++--------
 .../codegen/expected_bundles/local_bundle.yml |   5 +-
 .../local_bundle_prefix_suffix.yml            |   4 +-
 3 files changed, 96 insertions(+), 78 deletions(-)

diff --git a/brickflow/codegen/databricks_bundle.py b/brickflow/codegen/databricks_bundle.py
index 2bc9dff..20053a5 100644
--- a/brickflow/codegen/databricks_bundle.py
+++ b/brickflow/codegen/databricks_bundle.py
@@ -6,7 +6,7 @@
 import typing
 from enum import Enum
 from pathlib import Path
-from typing import Any, Dict, Iterator, List, Optional, Union
+from typing import Any, Callable, Dict, Iterator, List, Optional, Union
 
 import yaml
 from databricks.sdk import WorkspaceClient
@@ -400,9 +400,9 @@ def __init__(
         id_: str,
         env: str,
         mutators: Optional[List[DatabricksBundleResourceMutator]] = None,
-        **kwargs: Any,
+        **_kwargs: Any,
     ) -> None:
-        super().__init__(project, id_, env, **kwargs)
+        super().__init__(project, id_, env, **_kwargs)
         self.imports: List[ImportBlock] = []
         self.mutators = mutators or [
             DatabricksBundleTagsAndNameMutator(),
@@ -422,15 +422,36 @@ def workflow_obj_to_schedule(workflow: Workflow) -> Optional[JobsSchedule]:
             )
         return None
 
-    def get_entrypoint_notebook_source(self) -> str:
+    def adjust_source(self) -> str:
         return "WORKSPACE" if self.env == BrickflowDefaultEnvs.LOCAL.value else "GIT"
 
-    def task_to_task_obj(self, task: Task) -> Union[JobsTasksNotebookTask]:
+    def adjust_file_path(self, file_path: str) -> str:
+        if (
+            self.env == BrickflowDefaultEnvs.LOCAL.value
+            and self.project.bundle_base_path is not None
+            and self.project.bundle_obj_name is not None
+        ):
+            bundle_files_local_path = "/".join(
+                [
+                    self.project.bundle_base_path,
+                    self.project.bundle_obj_name,
+                    self.project.name,
+                    str(BrickflowDefaultEnvs.LOCAL.value),
+                ]
+            )
+            file_path = (
+                bundle_files_local_path + file_path
+                if file_path.startswith("/")
+                else f"{bundle_files_local_path}/{file_path}"
+            )
+        return file_path
+
+    def task_to_task_obj(self, task: Task) -> JobsTasksNotebookTask:
         if task.task_type in [TaskType.BRICKFLOW_TASK, TaskType.CUSTOM_PYTHON_TASK]:
             generated_path = handle_mono_repo_path(self.project, self.env)
             return JobsTasksNotebookTask(
                 **task.get_obj_dict(generated_path),
-                source=self.get_entrypoint_notebook_source(),
+                source=self.adjust_source(),
             )
 
     def workflow_obj_to_pipelines(self, workflow: Workflow) -> Dict[str, Pipelines]:
@@ -459,6 +480,7 @@ def _build_native_notebook_task(
         task_libraries: List[JobsTasksLibraries],
         task_settings: TaskSettings,
         depends_on: List[JobsTasksDependsOn],
+        **_kwargs: Any,
     ) -> JobsTasks:
         try:
             notebook_task: JobsTasksNotebookTask = task.task_func()
@@ -486,6 +508,7 @@ def _build_native_spark_jar_task(
         task_libraries: List[JobsTasksLibraries],
         task_settings: TaskSettings,
         depends_on: List[JobsTasksDependsOn],
+        **_kwargs: Any,
     ) -> JobsTasks:
         try:
             spark_jar_task: JobsTasksSparkJarTask = task.task_func()
@@ -513,6 +536,7 @@ def _build_native_spark_python_task(
         task_libraries: List[JobsTasksLibraries],
         task_settings: TaskSettings,
         depends_on: List[JobsTasksDependsOn],
+        **_kwargs: Any,
     ) -> JobsTasks:
         try:
             spark_python_task: JobsTasksSparkPythonTask = task.task_func()
@@ -522,6 +546,11 @@
                 f"Make sure {task_name} returns a SparkPythonTask object."
             ) from e
 
+        spark_python_task.source = self.adjust_source()
+        spark_python_task.python_file = self.adjust_file_path(
+            file_path=spark_python_task.python_file
+        )
+
         return JobsTasks(
             **task_settings.to_tf_dict(),
             spark_python_task=spark_python_task,
@@ -539,6 +568,7 @@ def _build_native_run_job_task(
         task: Task,
         task_settings: TaskSettings,
         depends_on: List[JobsTasksDependsOn],
+        **_kwargs: Any,
     ) -> JobsTasks:
         try:
             run_job_task: JobsTasksRunJobTask = task.task_func()
@@ -560,6 +590,7 @@ def _build_native_sql_file_task(
         task: Task,
         task_settings: TaskSettings,
         depends_on: List[JobsTasksDependsOn],
+        **_kwargs: Any,
     ) -> JobsTasks:
         try:
             sql_task: JobsTasksSqlTask = task.task_func()
@@ -582,6 +613,7 @@ def _build_native_condition_task(
         task: Task,
         task_settings: TaskSettings,
         depends_on: List[JobsTasksDependsOn],
+        **_kwargs: Any,
     ) -> JobsTasks:
         try:
             condition_task: JobsTasksConditionTask = task.task_func()
@@ -605,6 +637,7 @@ def _build_dlt_task(
         workflow: Workflow,
         task_settings: TaskSettings,
         depends_on: List[JobsTasksDependsOn],
+        **_kwargs: Any,
     ) -> JobsTasks:
         dlt_task: DLTPipeline = task.task_func()
         # tasks.append(Pipelines(**dlt_task.to_dict()))  # TODO: fix this so pipeline also gets created
@@ -619,11 +652,49 @@
             task_key=task_name,
         )
 
+    def _build_brickflow_entrypoint_task(
+        self,
+        task_name: str,
+        task: Task,
+        task_settings: TaskSettings,
+        depends_on: List[JobsTasksDependsOn],
+        task_libraries: List[JobsTasksLibraries],
+        **_kwargs: Any,
+    ) -> JobsTasks:
+        task_obj = JobsTasks(
+            **{
+                task.databricks_task_type_str: self.task_to_task_obj(task),
+                **task_settings.to_tf_dict(),
+            },  # type: ignore
+            libraries=task_libraries,
+            depends_on=depends_on,
+            task_key=task_name,
+            # unpack dictionary provided by cluster object, will either be key or
+            # existing cluster id
+            **task.cluster.job_task_field_dict,
+        )
+        return task_obj
+
     def workflow_obj_to_tasks(
         self, workflow: Workflow
     ) -> List[Union[JobsTasks, Pipelines]]:
         tasks = []
+
+        map_task_type_to_builder: Dict[TaskType, Callable[..., Any]] = {
+            TaskType.DLT: self._build_dlt_task,
+            TaskType.NOTEBOOK_TASK: self._build_native_notebook_task,
+            TaskType.SPARK_JAR_TASK: self._build_native_spark_jar_task,
+            TaskType.SPARK_PYTHON_TASK: self._build_native_spark_python_task,
+            TaskType.RUN_JOB_TASK: self._build_native_run_job_task,
+            TaskType.SQL: self._build_native_sql_file_task,
+            TaskType.IF_ELSE_CONDITION_TASK: self._build_native_condition_task,
+        }
+
         for task_name, task in workflow.tasks.items():
+            builder_func = map_task_type_to_builder.get(
+                task.task_type, self._build_brickflow_entrypoint_task
+            )
+
             # TODO: DLT
             # pipeline_task: Pipeline = self._create_dlt_notebooks(stack, task)
             if task.depends_on_names:
@@ -642,75 +713,21 @@
                 libraries += get_brickflow_libraries(workflow.enable_plugins)
 
             task_libraries = [
-                JobsTasksLibraries(**library.dict)
-                for library in libraries  # type: ignore
-            ]
-            task_settings = workflow.default_task_settings.merge(task.task_settings)
-            if task.task_type == TaskType.DLT:
-                # native dlt task
-                tasks.append(
-                    self._build_dlt_task(
-                        task_name, task, workflow, task_settings, depends_on
-                    )
-                )
-            elif task.task_type == TaskType.NOTEBOOK_TASK:
-                # native notebook task
-                tasks.append(
-                    self._build_native_notebook_task(
-                        task_name, task, task_libraries, task_settings, depends_on
-                    )
-                )
-            elif task.task_type == TaskType.SPARK_JAR_TASK:
-                # native jar task
-                tasks.append(
-                    self._build_native_spark_jar_task(
-                        task_name, task, task_libraries, task_settings, depends_on
-                    )
-                )
-            elif task.task_type == TaskType.SPARK_PYTHON_TASK:
-                # native python task
-                tasks.append(
-                    self._build_native_spark_python_task(
-                        task_name, task, task_libraries, task_settings, depends_on
-                    )
-                )
-            elif task.task_type == TaskType.RUN_JOB_TASK:
-                # native run job task
-                tasks.append(
-                    self._build_native_run_job_task(
-                        task_name, task, task_settings, depends_on
-                    )
-                )
-            elif task.task_type == TaskType.SQL:
-                # native SQL task
-                tasks.append(
-                    self._build_native_sql_file_task(
-                        task_name, task, task_settings, depends_on
-                    )
-                )
-            elif task.task_type == TaskType.IF_ELSE_CONDITION_TASK:
-                # native If/Else task
-                tasks.append(
-                    self._build_native_condition_task(
-                        task_name, task, task_settings, depends_on
-                    )
-                )
-            else:
-                # brickflow entrypoint task
-                task_obj = JobsTasks(
-                    **{
-                        task.databricks_task_type_str: self.task_to_task_obj(task),
-                        **task_settings.to_tf_dict(),
-                    },  # type: ignore
-                    libraries=task_libraries,
-                    depends_on=depends_on,
-                    task_key=task_name,
-                    # unpack dictionary provided by cluster object, will either be key or
-                    # existing cluster id
-                    **task.cluster.job_task_field_dict,
-                )
-                tasks.append(task_obj)
+                JobsTasksLibraries(**library.dict) for library in libraries
+            ]  # type: ignore
+            task_settings = workflow.default_task_settings.merge(task.task_settings)  # type: ignore
+            task = builder_func(
+                task_name=task_name,
+                task=task,
+                workflow=workflow,
+                task_libraries=task_libraries,
+                task_settings=task_settings,
+                depends_on=depends_on,
+            )
+            tasks.append(task)
+
+        tasks.sort(key=lambda t: (t.task_key is None, t.task_key))
+
         return tasks
 
     @staticmethod
diff --git a/tests/codegen/expected_bundles/local_bundle.yml b/tests/codegen/expected_bundles/local_bundle.yml
index 8ca4ecb..6f58acd 100644
--- a/tests/codegen/expected_bundles/local_bundle.yml
+++ b/tests/codegen/expected_bundles/local_bundle.yml
@@ -237,12 +237,13 @@
       "libraries":
       - "pypi":
           "package": "koheesio"
+          "repo": null
       "max_retries": null
       "min_retry_interval_millis": null
      "retry_on_timeout": null
       "spark_python_task":
-        "python_file": "path/to/python/file.py"
-        "source": "GIT"
+        "python_file": "/Users/${workspace.current_user.userName}/.brickflow_bundles/test-project/local/path/to/python/file.py"
+        "source": "WORKSPACE"
       "parameters":
       - "--param1"
       - "World!"
diff --git a/tests/codegen/expected_bundles/local_bundle_prefix_suffix.yml b/tests/codegen/expected_bundles/local_bundle_prefix_suffix.yml
index 303a00f..e3d2e08 100644
--- a/tests/codegen/expected_bundles/local_bundle_prefix_suffix.yml
+++ b/tests/codegen/expected_bundles/local_bundle_prefix_suffix.yml
@@ -120,9 +120,9 @@ targets:
           max_retries: null
           min_retry_interval_millis: null
           spark_python_task:
-            python_file: path/to/python/file.py
+            python_file: /Users/${workspace.current_user.userName}/.brickflow_bundles/test-project/local/path/to/python/file.py
             parameters: ["--param1", "World!"]
-            source: GIT
+            source: WORKSPACE
           retry_on_timeout: null
           task_key: spark_python_task_a
           timeout_seconds: null
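
For reference only (not part of the patch): a minimal standalone sketch of the path adjustment that the new adjust_file_path method performs for the local environment. The base path, bundle object name and project name below are illustrative values copied from the expected-bundle fixtures above, not brickflow API calls; the real code reads them from the Project object as shown in the diff.

# Hypothetical values mirroring the test fixtures, not brickflow API lookups.
BUNDLE_BASE_PATH = "/Users/${workspace.current_user.userName}"  # project.bundle_base_path
BUNDLE_OBJ_NAME = ".brickflow_bundles"  # project.bundle_obj_name
PROJECT_NAME = "test-project"  # project.name
ENV = "local"  # BrickflowDefaultEnvs.LOCAL.value


def adjust_file_path_local(file_path: str) -> str:
    # Root of the uploaded bundle files in the workspace for the local env.
    bundle_files_local_path = "/".join(
        [BUNDLE_BASE_PATH, BUNDLE_OBJ_NAME, PROJECT_NAME, ENV]
    )
    # Absolute inputs are appended directly; relative ones get a separator.
    return (
        bundle_files_local_path + file_path
        if file_path.startswith("/")
        else f"{bundle_files_local_path}/{file_path}"
    )


print(adjust_file_path_local("path/to/python/file.py"))
# -> /Users/${workspace.current_user.userName}/.brickflow_bundles/test-project/local/path/to/python/file.py

This is the same value the updated local_bundle.yml and local_bundle_prefix_suffix.yml fixtures now expect for spark_python_task.python_file, with source switched to WORKSPACE for local runs.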