diff --git a/ray_provider/decorators/ray.py b/ray_provider/decorators/ray.py index e7d38c3..5dcc4e7 100644 --- a/ray_provider/decorators/ray.py +++ b/ray_provider/decorators/ray.py @@ -35,12 +35,19 @@ class _RayDecoratedOperator(DecoratedOperator, SubmitRayJob): def __init__(self, config: dict[str, Any], **kwargs: Any) -> None: self.conn_id: str = config.get("conn_id", "") + self.is_decorated_function = False if "entrypoint" in config else True self.entrypoint: str = config.get("entrypoint", "python script.py") self.runtime_env: dict[str, Any] = config.get("runtime_env", {}) self.num_cpus: int | float = config.get("num_cpus", 1) self.num_gpus: int | float = config.get("num_gpus", 0) - self.memory: int | float = config.get("memory", 0) + self.memory: int | float = config.get("memory", None) + self.ray_resources: dict[str, Any] | None = config.get("resources", None) + self.fetch_logs: bool = config.get("fetch_logs", True) + self.wait_for_completion: bool = config.get("wait_for_completion", True) + self.job_timeout_seconds: int = config.get("job_timeout_seconds", 600) + self.poll_interval: int = config.get("poll_interval", 60) + self.xcom_task_key: str | None = config.get("xcom_task_key", None) self.config = config if not isinstance(self.num_cpus, (int, float)): @@ -55,6 +62,12 @@ def __init__(self, config: dict[str, Any], **kwargs: Any) -> None: num_cpus=self.num_cpus, num_gpus=self.num_gpus, memory=self.memory, + resources=self.ray_resources, + fetch_logs=self.fetch_logs, + wait_for_completion=self.wait_for_completion, + job_timeout_seconds=self.job_timeout_seconds, + poll_interval=self.poll_interval, + xcom_task_key=self.xcom_task_key, **kwargs, ) @@ -66,19 +79,25 @@ def execute(self, context: Context) -> Any: :return: The result of the Ray job execution. :raises AirflowException: If job submission fails. """ - tmp_dir = mkdtemp(prefix="ray_") + tmp_dir = None try: - py_source = self.get_python_source().splitlines() # type: ignore - function_body = textwrap.dedent("\n".join(py_source[1:])) + if self.is_decorated_function: + self.log.info( + f"Entrypoint is not provided, is_decorated_function is set to {self.is_decorated_function}" + ) + tmp_dir = mkdtemp(prefix="ray_") + py_source = self.get_python_source().splitlines() # type: ignore + function_body = textwrap.dedent("\n".join(py_source)) + + script_filename = os.path.join(tmp_dir, "script.py") + with open(script_filename, "w") as file: + all_args_str = self._build_args_str() + script_body = f"{function_body}\n{self._extract_function_name()}({all_args_str})" + file.write(script_body) + + self.entrypoint = "python script.py" + self.runtime_env["working_dir"] = tmp_dir - script_filename = os.path.join(tmp_dir, "script.py") - with open(script_filename, "w") as file: - all_args_str = self._build_args_str() - script_body = f"{function_body}\n{self._extract_function_name()}({all_args_str})" - file.write(script_body) - - self.entrypoint = "python script.py" - self.runtime_env["working_dir"] = tmp_dir self.log.info("Running ray job...") result = super().execute(context) @@ -87,7 +106,7 @@ def execute(self, context: Context) -> Any: self.log.error(f"Failed during execution with error: {e}") raise AirflowException("Job submission failed") from e finally: - if os.path.exists(tmp_dir): + if tmp_dir and os.path.exists(tmp_dir): shutil.rmtree(tmp_dir) def _build_args_str(self) -> str: @@ -109,22 +128,24 @@ def _extract_function_name(self) -> str: return self.python_callable.__name__ -def ray_task( - python_callable: Callable[..., Any] | None = None, - multiple_outputs: bool | None = None, - **kwargs: Any, -) -> TaskDecorator: - """ - Decorator to define a task that submits a Ray job. +class task: + @staticmethod + def ray( + python_callable: Callable[..., Any] | None = None, + multiple_outputs: bool | None = None, + **kwargs: Any, + ) -> TaskDecorator: + """ + Decorator to define a task that submits a Ray job. - :param python_callable: The callable function to decorate. - :param multiple_outputs: If True, will return multiple outputs. - :param kwargs: Additional keyword arguments. - :return: The decorated task. - """ - return task_decorator_factory( - python_callable=python_callable, - multiple_outputs=multiple_outputs, - decorated_operator_class=_RayDecoratedOperator, - **kwargs, - ) + :param python_callable: The callable function to decorate. + :param multiple_outputs: If True, will return multiple outputs. + :param kwargs: Additional keyword arguments. + :return: The decorated task. + """ + return task_decorator_factory( + python_callable=python_callable, + multiple_outputs=multiple_outputs, + decorated_operator_class=_RayDecoratedOperator, + **kwargs, + ) diff --git a/tests/decorators/test_ray_decorators.py b/tests/decorators/test_ray_decorators.py index 56d1c90..545197f 100644 --- a/tests/decorators/test_ray_decorators.py +++ b/tests/decorators/test_ray_decorators.py @@ -2,9 +2,10 @@ import pytest from airflow.decorators.base import _TaskDecorator +from airflow.exceptions import AirflowException from airflow.utils.context import Context -from ray_provider.decorators.ray import _RayDecoratedOperator, ray_task +from ray_provider.decorators.ray import _RayDecoratedOperator, task from ray_provider.operators.ray import SubmitRayJob DEFAULT_DATE = "2023-01-01" @@ -15,12 +16,18 @@ class TestRayDecoratedOperator: def test_initialization(self): config = { - "host": "http://localhost:8265", + "conn_id": "ray_default", "entrypoint": "python my_script.py", "runtime_env": {"pip": ["ray"]}, "num_cpus": 2, "num_gpus": 1, - "memory": "1G", + "memory": 1024, + "resources": {"custom_resource": 1}, + "fetch_logs": True, + "wait_for_completion": True, + "job_timeout_seconds": 300, + "poll_interval": 30, + "xcom_task_key": "ray_result", } def dummy_callable(): @@ -28,17 +35,60 @@ def dummy_callable(): operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) + assert operator.conn_id == "ray_default" assert operator.entrypoint == "python my_script.py" assert operator.runtime_env == {"pip": ["ray"]} assert operator.num_cpus == 2 assert operator.num_gpus == 1 - assert operator.memory == "1G" + assert operator.memory == 1024 + assert operator.ray_resources == {"custom_resource": 1} + assert operator.fetch_logs == True + assert operator.wait_for_completion == True + assert operator.job_timeout_seconds == 300 + assert operator.poll_interval == 30 + assert operator.xcom_task_key == "ray_result" + + def test_initialization_defaults(self): + config = {} + + def dummy_callable(): + pass + + operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) + + assert operator.conn_id == "" + assert operator.entrypoint == "python script.py" + assert operator.runtime_env == {} + assert operator.num_cpus == 1 + assert operator.num_gpus == 0 + assert operator.memory is None + assert operator.resources is None + assert operator.fetch_logs == True + assert operator.wait_for_completion == True + assert operator.job_timeout_seconds == 600 + assert operator.poll_interval == 60 + assert operator.xcom_task_key is None + + def test_invalid_config_raises_exception(self): + config = { + "num_cpus": "invalid_number", + } + + def dummy_callable(): + pass + + with pytest.raises(TypeError): + _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) + + config["num_cpus"] = 1 + config["num_gpus"] = "invalid_number" + with pytest.raises(TypeError): + _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) @patch.object(_RayDecoratedOperator, "get_python_source") @patch.object(SubmitRayJob, "execute") - def test_execute(self, mock_super_execute, mock_get_python_source): + def test_execute_decorated_function(self, mock_super_execute, mock_get_python_source): config = { - "entrypoint": "python my_script.py", "runtime_env": {"pip": ["ray"]}, } @@ -47,15 +97,17 @@ def dummy_callable(): context = MagicMock(spec=Context) operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) - - mock_get_python_source.return_value = "def my_function():\n pass\n" + mock_get_python_source.return_value = "def dummy_callable():\n pass\n" mock_super_execute.return_value = "success" result = operator.execute(context) assert result == "success" + assert operator.entrypoint == "python script.py" + assert "working_dir" in operator.runtime_env - def test_missing_host_config(self): + @patch.object(SubmitRayJob, "execute") + def test_execute_with_entrypoint(self, mock_super_execute): config = { "entrypoint": "python my_script.py", } @@ -63,36 +115,84 @@ def test_missing_host_config(self): def dummy_callable(): pass + context = MagicMock(spec=Context) operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) + mock_super_execute.return_value = "success" + + result = operator.execute(context) + + assert result == "success" assert operator.entrypoint == "python my_script.py" - def test_invalid_config_raises_exception(self): - config = { - "host": "http://localhost:8265", - "entrypoint": "python my_script.py", - "runtime_env": {"pip": ["ray"]}, - "num_cpus": "invalid_number", - } + @patch.object(SubmitRayJob, "execute") + def test_execute_failure(self, mock_super_execute): + config = {} def dummy_callable(): pass - with pytest.raises(TypeError): - _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) + context = MagicMock(spec=Context) + operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) + mock_super_execute.side_effect = Exception("Ray job failed") + with pytest.raises(AirflowException): + operator.execute(context) -class TestRayTaskDecorator: + def test_build_args_str(self): + config = {} + + def dummy_callable(arg1, arg2, kwarg1="default"): + pass + + operator = _RayDecoratedOperator( + task_id="test_task", + config=config, + python_callable=dummy_callable, + op_args=["value1", "value2"], + op_kwargs={"kwarg1": "custom"}, + ) + + args_str = operator._build_args_str() + assert args_str == "'value1', 'value2', kwarg1='custom'" + + def test_extract_function_name(self): + config = {} + + def dummy_callable(): + pass + + operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) + function_name = operator._extract_function_name() + assert function_name == "dummy_callable" + + +class TestRayTaskDecorator: def test_ray_task_decorator(self): + @task.ray() def dummy_function(): return "dummy" - decorator = ray_task(python_callable=dummy_function) - assert isinstance(decorator, _TaskDecorator) + assert isinstance(dummy_function, _TaskDecorator) def test_ray_task_decorator_with_multiple_outputs(self): + @task.ray(multiple_outputs=True) def dummy_function(): return {"key": "value"} - decorator = ray_task(python_callable=dummy_function, multiple_outputs=True) - assert isinstance(decorator, _TaskDecorator) + assert isinstance(dummy_function, _TaskDecorator) + + def test_ray_task_decorator_with_config(self): + config = { + "num_cpus": 2, + "num_gpus": 1, + "memory": 1024, + } + + @task.ray(**config) + def dummy_function(): + return "dummy" + + assert isinstance(dummy_function, _TaskDecorator) + # We can't directly access the config here, but we can check if the decorator was applied + assert dummy_function.operator_class == _RayDecoratedOperator