diff --git a/brickflow/engine/hooks.py b/brickflow/engine/hooks.py index 0a3d9de2..f3107f79 100644 --- a/brickflow/engine/hooks.py +++ b/brickflow/engine/hooks.py @@ -12,6 +12,20 @@ class BrickflowTaskPluginSpec: + @staticmethod + def handle_user_result_errors(resp: "TaskResponse") -> None: + """Custom execute method that is able to be plugged in.""" + if resp.user_code_error is not None: + original_message = str(resp.user_code_error) + additional_info = ( + "BRICKFLOW_USER_OR_DBR_ERROR: This is an error thrown in user code. \n" + f"BRICKFLOW_INPUT_ARGS: {resp.input_kwargs}\n" + "Original Exception Message: " + ) + new_message = additional_info + original_message + resp.user_code_error.args = (new_message,) + raise resp.user_code_error + @staticmethod @brickflow_plugin_spec(firstresult=True) def task_execute(task: "Task", workflow: "Workflow") -> "TaskResponse": diff --git a/brickflow/engine/task.py b/brickflow/engine/task.py index af52444d..e0a1f110 100644 --- a/brickflow/engine/task.py +++ b/brickflow/engine/task.py @@ -84,6 +84,10 @@ class InvalidTaskLibraryError(Exception): pass +class BrickflowUserCodeException(Exception): + pass + + class BrickflowTaskEnvVars(Enum): BRICKFLOW_SELECT_TASKS = "BRICKFLOW_SELECT_TASKS" @@ -303,6 +307,8 @@ def to_tf_dict( class TaskResponse: response: Any push_return_value: bool = True + user_code_error: Optional[Exception] = None + input_kwargs: Optional[Dict[str, Any]] = None @dataclass(frozen=True) @@ -373,6 +379,8 @@ class DefaultBrickflowTaskPluginImpl(BrickflowTaskPluginSpec): @brickflow_task_plugin_impl def handle_results(resp: "TaskResponse", task: "Task", workflow: "Workflow") -> Any: _ilog.info("using default for handling results") + + BrickflowTaskPluginSpec.handle_user_result_errors(resp) # by default don't do anything just return the response as is return resp @@ -381,6 +389,7 @@ def handle_results(resp: "TaskResponse", task: "Task", workflow: "Workflow") -> def task_execute(task: "Task", workflow: "Workflow") -> TaskResponse: """default execute implementation method.""" _ilog.info("using default plugin for handling task execute") + if ( task.task_type == TaskType.CUSTOM_PYTHON_TASK and task.custom_execute_callback is not None @@ -388,10 +397,18 @@ def task_execute(task: "Task", workflow: "Workflow") -> TaskResponse: _ilog.info("handling custom execute") return task.custom_execute_callback(task) else: - return TaskResponse( - task.task_func(**task.get_runtime_parameter_values()), - push_return_value=True, - ) + kwargs = task.get_runtime_parameter_values() + try: + return TaskResponse( + task.task_func(**kwargs), + user_code_error=None, + push_return_value=True, + input_kwargs=kwargs, + ) + except Exception as e: + return TaskResponse( + None, push_return_value=True, user_code_error=e, input_kwargs=kwargs + ) @functools.lru_cache @@ -706,13 +723,14 @@ def execute(self, ignore_all_deps: bool = False) -> Any: _ilog.info("Executing task... %s", self.name) _ilog.info("%s", pretty_print_function_source(self.name, self.task_func)) - initial_resp: TaskResponse = get_brickflow_tasks_hook().task_execute( + brickflow_execution_hook = get_brickflow_tasks_hook() + + initial_resp: TaskResponse = brickflow_execution_hook.task_execute( task=self, workflow=self.workflow ) - resp: TaskResponse = get_brickflow_tasks_hook().handle_results( + resp: TaskResponse = brickflow_execution_hook.handle_results( resp=initial_resp, task=self, workflow=self.workflow ) - if resp.push_return_value is True: ctx.task_coms.put(self.name, RETURN_VALUE_KEY, resp.response) ctx._reset_current_task() diff --git a/brickflow/resolver/__init__.py b/brickflow/resolver/__init__.py index 008bf230..802ae328 100644 --- a/brickflow/resolver/__init__.py +++ b/brickflow/resolver/__init__.py @@ -72,6 +72,8 @@ def get_relative_path_to_brickflow_root() -> None: _ilog.info("Sys path set to: %s", str(sys.path)) except BrickflowRootNotFound: _ilog.info("Unable to find for path: %s", path) + except PermissionError: + _ilog.info("Most likely not accessible due to shared cluster: %s", path) def get_notebook_ws_path(dbutils: Optional[Any]) -> Optional[str]: diff --git a/brickflow_plugins/airflow/brickflow_task_plugin.py b/brickflow_plugins/airflow/brickflow_task_plugin.py index 84ea7be6..7ee493aa 100644 --- a/brickflow_plugins/airflow/brickflow_task_plugin.py +++ b/brickflow_plugins/airflow/brickflow_task_plugin.py @@ -36,7 +36,15 @@ class AirflowOperatorBrickflowTaskPluginImpl(BrickflowTaskPluginSpec): def handle_results( resp: "TaskResponse", task: "Task", workflow: "Workflow" ) -> "TaskResponse": + log.info( + "using AirflowOperatorBrickflowTaskPlugin for handling results for task: %s", + task.task_id, + ) + + BrickflowTaskPluginSpec.handle_user_result_errors(resp) + _operator = resp.response + if not isinstance(_operator, BaseOperator): return resp @@ -55,10 +63,7 @@ def handle_results( epoch_to_pendulum_datetime(ctx.start_time(debug=None)), tz=workflow.timezone, ) - log.info( - "using AirflowOperatorBrickflowTaskPlugin for handling results for task: %s", - task.task_id, - ) + env: Optional[Environment] = Environment() env.globals.update({"macros": macros, "ti": context}) with BrickflowSecretsBackend(): diff --git a/tests/engine/sample_workflow.py b/tests/engine/sample_workflow.py index c28cbe33..f6c59754 100644 --- a/tests/engine/sample_workflow.py +++ b/tests/engine/sample_workflow.py @@ -33,6 +33,11 @@ def task_function(*, test="var"): return test +@wf.task() +def task_function_with_error(*, test="var"): + raise ValueError("throwing random error") + + @wf.task def task_function_no_deco_args(*, test="var"): return "hello world" diff --git a/tests/engine/test_task.py b/tests/engine/test_task.py index 63e80524..593fad8c 100644 --- a/tests/engine/test_task.py +++ b/tests/engine/test_task.py @@ -27,6 +27,8 @@ TaskLibrary, get_brickflow_lib_version, get_brickflow_libraries, + get_plugin_manager, + get_brickflow_tasks_hook, ) from tests.engine.sample_workflow import ( wf, @@ -36,6 +38,7 @@ task_function_3, task_function_4, custom_python_task_push, + task_function_with_error, ) @@ -226,6 +229,17 @@ def test_execute(self, task_coms_mock: Mock, dbutils: Mock): assert resp is task_function() + @patch("brickflow.context.ctx.get_parameter") + def test_execute_with_error(self, dbutils: Mock): + dbutils.return_value = "" + get_plugin_manager.cache_clear() + get_brickflow_tasks_hook.cache_clear() + with pytest.raises( + ValueError, + match="BRICKFLOW_USER_OR_DBR_ERROR: This is an error thrown in user code.", + ): + wf.get_task(task_function_with_error.__name__).execute() + @patch("brickflow.context.ctx.get_parameter") @patch("brickflow.context.ctx._task_coms") def test_execute_custom(self, task_coms_mock: Mock, dbutils: Mock): diff --git a/tests/engine/test_workflow.py b/tests/engine/test_workflow.py index 8c95c0f8..fdd176a9 100644 --- a/tests/engine/test_workflow.py +++ b/tests/engine/test_workflow.py @@ -155,7 +155,7 @@ def test_deco_no_args(self): wf.task("hello world") def test_get_tasks(self): - assert len(wf.tasks) == 9 + assert len(wf.tasks) == 10 def test_task_iter(self): arr = [] @@ -163,7 +163,7 @@ def test_task_iter(self): assert isinstance(t, Task) assert callable(t.task_func) arr.append(t) - assert len(arr) == 9, print([t.name for t in arr]) + assert len(arr) == 10, print([t.name for t in arr]) def test_permissions(self): assert wf.permissions.to_access_controls() == [ @@ -236,7 +236,7 @@ def test_another_workflow(self): from tests.engine.sample_workflow_2 import wf as wf1 assert len(wf1.graph.nodes) == 2 - assert len(wf.graph.nodes) == 10 + assert len(wf.graph.nodes) == 11 def test_schedule_run_status_workflow(self): this_wf = Workflow("test", clusters=[Cluster("name", "spark", "vm-node")]) diff --git a/tests/resolver/test_resolver.py b/tests/resolver/test_resolver.py new file mode 100644 index 00000000..a7c2f5e5 --- /dev/null +++ b/tests/resolver/test_resolver.py @@ -0,0 +1,52 @@ +# test_resolver.py +from typing import Type + +import pytest + +import brickflow +from brickflow.resolver import ( + BrickflowRootNotFound, +) + + +@pytest.fixture +def default_mocks(mocker): + # Create mocks for the three methods + mocker.patch( + "brickflow.resolver.get_caller_file_paths", return_value=["path1", "path2"] + ) + mocker.patch( + "brickflow.resolver.get_notebook_ws_path", return_value="/notebook/ws/path" + ) + + +def test_resolver_methods(default_mocks, mocker): # noqa + error_msg = "This is a test message" + + def make_exception_function(exc: Type[Exception]): + def raise_exception(*args, **kwargs): + raise exc(error_msg) + + return raise_exception + + # catch random error + mocker.patch( + "brickflow.resolver.go_up_till_brickflow_root", + side_effect=make_exception_function(ValueError), + ) + with pytest.raises(ValueError, match=error_msg): + brickflow.resolver.get_relative_path_to_brickflow_root() + + mocker.patch( + "brickflow.resolver.go_up_till_brickflow_root", + side_effect=make_exception_function(BrickflowRootNotFound), + ) + + brickflow.resolver.get_relative_path_to_brickflow_root() + + mocker.patch( + "brickflow.resolver.go_up_till_brickflow_root", + side_effect=make_exception_function(PermissionError), + ) + + brickflow.resolver.get_relative_path_to_brickflow_root()