
Commit

feat: add adjustment of python file in case of local for spark_python
mikita-sakalouski committed Sep 14, 2024
1 parent 53ab6cc commit 45c115f
Showing 3 changed files with 96 additions and 78 deletions.
165 changes: 91 additions & 74 deletions brickflow/codegen/databricks_bundle.py
@@ -6,7 +6,7 @@
import typing
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Union
from typing import Any, Callable, Dict, Iterator, List, Optional, Union

import yaml
from databricks.sdk import WorkspaceClient
@@ -400,9 +400,9 @@ def __init__(
id_: str,
env: str,
mutators: Optional[List[DatabricksBundleResourceMutator]] = None,
**kwargs: Any,
**_kwargs: Any,
) -> None:
super().__init__(project, id_, env, **kwargs)
super().__init__(project, id_, env, **_kwargs)
self.imports: List[ImportBlock] = []
self.mutators = mutators or [
DatabricksBundleTagsAndNameMutator(),
@@ -422,15 +422,36 @@ def workflow_obj_to_schedule(workflow: Workflow) -> Optional[JobsSchedule]:
)
return None

def get_entrypoint_notebook_source(self) -> str:
def adjust_source(self) -> str:
return "WORKSPACE" if self.env == BrickflowDefaultEnvs.LOCAL.value else "GIT"

def task_to_task_obj(self, task: Task) -> Union[JobsTasksNotebookTask]:
def adjust_file_path(self, file_path: str) -> str:
if (
self.env == BrickflowDefaultEnvs.LOCAL.value
and self.project.bundle_base_path is not None
and self.project.bundle_obj_name is not None
):
bundle_files_local_path = "/".join(
[
self.project.bundle_base_path,
self.project.bundle_obj_name,
self.project.name,
str(BrickflowDefaultEnvs.LOCAL.value),
]
)
file_path = (
bundle_files_local_path + file_path
if file_path.startswith("/")
else f"{bundle_files_local_path}/{file_path}"
)
return file_path

def task_to_task_obj(self, task: Task) -> JobsTasksNotebookTask:
if task.task_type in [TaskType.BRICKFLOW_TASK, TaskType.CUSTOM_PYTHON_TASK]:
generated_path = handle_mono_repo_path(self.project, self.env)
return JobsTasksNotebookTask(
**task.get_obj_dict(generated_path),
source=self.get_entrypoint_notebook_source(),
source=self.adjust_source(),
)

def workflow_obj_to_pipelines(self, workflow: Workflow) -> Dict[str, Pipelines]:
@@ -459,6 +480,7 @@ def _build_native_notebook_task(
task_libraries: List[JobsTasksLibraries],
task_settings: TaskSettings,
depends_on: List[JobsTasksDependsOn],
**_kwargs: Any,
) -> JobsTasks:
try:
notebook_task: JobsTasksNotebookTask = task.task_func()
@@ -486,6 +508,7 @@ def _build_native_spark_jar_task(
task_libraries: List[JobsTasksLibraries],
task_settings: TaskSettings,
depends_on: List[JobsTasksDependsOn],
**_kwargs: Any,
) -> JobsTasks:
try:
spark_jar_task: JobsTasksSparkJarTask = task.task_func()
@@ -513,6 +536,7 @@ def _build_native_spark_python_task(
task_libraries: List[JobsTasksLibraries],
task_settings: TaskSettings,
depends_on: List[JobsTasksDependsOn],
**_kwargs: Any,
) -> JobsTasks:
try:
spark_python_task: JobsTasksSparkPythonTask = task.task_func()
@@ -522,6 +546,11 @@ def _build_native_spark_python_task(
f"Make sure {task_name} returns a SparkPythonTask object."
) from e

spark_python_task.source = self.adjust_source()
spark_python_task.python_file = self.adjust_file_path(
file_path=spark_python_task.python_file
)

return JobsTasks(
**task_settings.to_tf_dict(),
spark_python_task=spark_python_task,
@@ -539,6 +568,7 @@ def _build_native_run_job_task(
task: Task,
task_settings: TaskSettings,
depends_on: List[JobsTasksDependsOn],
**_kwargs: Any,
) -> JobsTasks:
try:
run_job_task: JobsTasksRunJobTask = task.task_func()
@@ -560,6 +590,7 @@ def _build_native_sql_file_task(
task: Task,
task_settings: TaskSettings,
depends_on: List[JobsTasksDependsOn],
**_kwargs: Any,
) -> JobsTasks:
try:
sql_task: JobsTasksSqlTask = task.task_func()
@@ -582,6 +613,7 @@ def _build_native_condition_task(
task: Task,
task_settings: TaskSettings,
depends_on: List[JobsTasksDependsOn],
**_kwargs: Any,
) -> JobsTasks:
try:
condition_task: JobsTasksConditionTask = task.task_func()
@@ -605,6 +637,7 @@ def _build_dlt_task(
workflow: Workflow,
task_settings: TaskSettings,
depends_on: List[JobsTasksDependsOn],
**_kwargs: Any,
) -> JobsTasks:
dlt_task: DLTPipeline = task.task_func()
# tasks.append(Pipelines(**dlt_task.to_dict())) # TODO: fix this so pipeline also gets created
@@ -619,11 +652,49 @@ def _build_dlt_task(
task_key=task_name,
)

def _build_brickflow_entrypoint_task(
self,
task_name: str,
task: Task,
task_settings: TaskSettings,
depends_on: List[JobsTasksDependsOn],
task_libraries: List[JobsTasksLibraries],
**_kwargs: Any,
) -> JobsTasks:
task_obj = JobsTasks(
**{
task.databricks_task_type_str: self.task_to_task_obj(task),
**task_settings.to_tf_dict(),
}, # type: ignore
libraries=task_libraries,
depends_on=depends_on,
task_key=task_name,
# unpack dictionary provided by cluster object, will either be key or
# existing cluster id
**task.cluster.job_task_field_dict,
)
return task_obj

def workflow_obj_to_tasks(
self, workflow: Workflow
) -> List[Union[JobsTasks, Pipelines]]:
tasks = []

map_task_type_to_builder: Dict[TaskType, Callable[..., Any]] = {
TaskType.DLT: self._build_dlt_task,
TaskType.NOTEBOOK_TASK: self._build_native_notebook_task,
TaskType.SPARK_JAR_TASK: self._build_native_spark_jar_task,
TaskType.SPARK_PYTHON_TASK: self._build_native_spark_python_task,
TaskType.RUN_JOB_TASK: self._build_native_run_job_task,
TaskType.SQL: self._build_native_sql_file_task,
TaskType.IF_ELSE_CONDITION_TASK: self._build_native_condition_task,
}

for task_name, task in workflow.tasks.items():
builder_func = map_task_type_to_builder.get(
task.task_type, self._build_brickflow_entrypoint_task
)

# TODO: DLT
# pipeline_task: Pipeline = self._create_dlt_notebooks(stack, task)
if task.depends_on_names:
@@ -642,75 +713,21 @@ def workflow_obj_to_tasks(
libraries += get_brickflow_libraries(workflow.enable_plugins)

task_libraries = [
JobsTasksLibraries(**library.dict)
for library in libraries # type: ignore
]
task_settings = workflow.default_task_settings.merge(task.task_settings)
if task.task_type == TaskType.DLT:
# native dlt task
tasks.append(
self._build_dlt_task(
task_name, task, workflow, task_settings, depends_on
)
)
elif task.task_type == TaskType.NOTEBOOK_TASK:
# native notebook task
tasks.append(
self._build_native_notebook_task(
task_name, task, task_libraries, task_settings, depends_on
)
)
elif task.task_type == TaskType.SPARK_JAR_TASK:
# native jar task
tasks.append(
self._build_native_spark_jar_task(
task_name, task, task_libraries, task_settings, depends_on
)
)
elif task.task_type == TaskType.SPARK_PYTHON_TASK:
# native python task
tasks.append(
self._build_native_spark_python_task(
task_name, task, task_libraries, task_settings, depends_on
)
)
elif task.task_type == TaskType.RUN_JOB_TASK:
# native run job task
tasks.append(
self._build_native_run_job_task(
task_name, task, task_settings, depends_on
)
)
elif task.task_type == TaskType.SQL:
# native SQL task
tasks.append(
self._build_native_sql_file_task(
task_name, task, task_settings, depends_on
)
)
elif task.task_type == TaskType.IF_ELSE_CONDITION_TASK:
# native If/Else task
tasks.append(
self._build_native_condition_task(
task_name, task, task_settings, depends_on
)
)
else:
# brickflow entrypoint task
task_obj = JobsTasks(
**{
task.databricks_task_type_str: self.task_to_task_obj(task),
**task_settings.to_tf_dict(),
}, # type: ignore
libraries=task_libraries,
depends_on=depends_on,
task_key=task_name,
# unpack dictionary provided by cluster object, will either be key or
# existing cluster id
**task.cluster.job_task_field_dict,
)
tasks.append(task_obj)
JobsTasksLibraries(**library.dict) for library in libraries
] # type: ignore
task_settings = workflow.default_task_settings.merge(task.task_settings) # type: ignore
task = builder_func(
task_name=task_name,
task=task,
workflow=workflow,
task_libraries=task_libraries,
task_settings=task_settings,
depends_on=depends_on,
)
tasks.append(task)

tasks.sort(key=lambda t: (t.task_key is None, t.task_key))

return tasks

@staticmethod
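Editor's note: a minimal, self-contained sketch of what the new adjust_source/adjust_file_path pair does for the local target — an illustration, not code from this commit; the argument values below are assumptions chosen to line up with the expected-bundle fixtures that follow.

from typing import Optional

def adjust_file_path_sketch(
    file_path: str,
    bundle_base_path: Optional[str],
    bundle_obj_name: Optional[str],
    project_name: str,
    env: str = "local",
) -> str:
    # Mirrors the logic added above: for the local environment, prefix the
    # configured python_file with the bundle's synced-files location.
    if env != "local" or bundle_base_path is None or bundle_obj_name is None:
        return file_path
    bundle_files_local_path = "/".join(
        [bundle_base_path, bundle_obj_name, project_name, env]
    )
    return (
        bundle_files_local_path + file_path
        if file_path.startswith("/")
        else f"{bundle_files_local_path}/{file_path}"
    )

# Hypothetical values (assumed for illustration, not taken from the commit):
print(
    adjust_file_path_sketch(
        "path/to/python/file.py",
        bundle_base_path="/Users/${workspace.current_user.userName}",
        bundle_obj_name=".brickflow_bundles",
        project_name="test-project",
    )
)
# -> /Users/${workspace.current_user.userName}/.brickflow_bundles/test-project/local/path/to/python/file.py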
5 changes: 3 additions & 2 deletions tests/codegen/expected_bundles/local_bundle.yml
@@ -237,12 +237,13 @@
"libraries":
- "pypi":
"package": "koheesio"
"repo": null
"max_retries": null
"min_retry_interval_millis": null
"retry_on_timeout": null
"spark_python_task":
"python_file": "path/to/python/file.py"
"source": "GIT"
"python_file": "/Users/${workspace.current_user.userName}/.brickflow_bundles/test-project/local/path/to/python/file.py"
"source": "WORKSPACE"
"parameters":
- "--param1"
- "World!"
4 changes: 2 additions & 2 deletions tests/codegen/expected_bundles/local_bundle_prefix_suffix.yml
@@ -120,9 +120,9 @@ targets:
max_retries: null
min_retry_interval_millis: null
spark_python_task:
python_file: path/to/python/file.py
python_file: /Users/${workspace.current_user.userName}/.brickflow_bundles/test-project/local/path/to/python/file.py
parameters: ["--param1", "World!"]
source: GIT
source: WORKSPACE
retry_on_timeout: null
task_key: spark_python_task_a
timeout_seconds: null
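Editor's note: the workflow_obj_to_tasks rewrite above swaps the long if/elif chain for a dispatch table of builder callables keyed by TaskType, with the brickflow entrypoint builder as the fallback; every builder accepts **_kwargs so the loop can pass one uniform keyword set and let each builder pick what it needs. A minimal sketch of that pattern with made-up names (an illustration, not code from this commit):

from typing import Any, Callable, Dict

def build_notebook(task_name: str, **_kwargs: Any) -> str:
    return f"notebook:{task_name}"

def build_entrypoint(task_name: str, **_kwargs: Any) -> str:
    # Fallback builder used when the task type has no dedicated entry.
    return f"entrypoint:{task_name}"

builders: Dict[str, Callable[..., str]] = {"NOTEBOOK_TASK": build_notebook}

def build(task_type: str, task_name: str) -> str:
    builder = builders.get(task_type, build_entrypoint)
    # Extra keyword arguments are silently absorbed by **_kwargs in each builder.
    return builder(task_name=task_name, extra="ignored")

print(build("NOTEBOOK_TASK", "t1"))  # notebook:t1
print(build("SQL", "t2"))            # entrypoint:t2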
