Skip to content

Commit

Permalink
Decorator fixes (#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
venkatajagannath authored Jul 28, 2024
1 parent 96009cc commit f286997
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 54 deletions.
83 changes: 52 additions & 31 deletions ray_provider/decorators/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,19 @@ class _RayDecoratedOperator(DecoratedOperator, SubmitRayJob):

def __init__(self, config: dict[str, Any], **kwargs: Any) -> None:
self.conn_id: str = config.get("conn_id", "")
self.is_decorated_function = False if "entrypoint" in config else True
self.entrypoint: str = config.get("entrypoint", "python script.py")
self.runtime_env: dict[str, Any] = config.get("runtime_env", {})

self.num_cpus: int | float = config.get("num_cpus", 1)
self.num_gpus: int | float = config.get("num_gpus", 0)
self.memory: int | float = config.get("memory", 0)
self.memory: int | float = config.get("memory", None)
self.ray_resources: dict[str, Any] | None = config.get("resources", None)
self.fetch_logs: bool = config.get("fetch_logs", True)
self.wait_for_completion: bool = config.get("wait_for_completion", True)
self.job_timeout_seconds: int = config.get("job_timeout_seconds", 600)
self.poll_interval: int = config.get("poll_interval", 60)
self.xcom_task_key: str | None = config.get("xcom_task_key", None)
self.config = config

if not isinstance(self.num_cpus, (int, float)):
Expand All @@ -55,6 +62,12 @@ def __init__(self, config: dict[str, Any], **kwargs: Any) -> None:
num_cpus=self.num_cpus,
num_gpus=self.num_gpus,
memory=self.memory,
resources=self.ray_resources,
fetch_logs=self.fetch_logs,
wait_for_completion=self.wait_for_completion,
job_timeout_seconds=self.job_timeout_seconds,
poll_interval=self.poll_interval,
xcom_task_key=self.xcom_task_key,
**kwargs,
)

Expand All @@ -66,19 +79,25 @@ def execute(self, context: Context) -> Any:
:return: The result of the Ray job execution.
:raises AirflowException: If job submission fails.
"""
tmp_dir = mkdtemp(prefix="ray_")
tmp_dir = None
try:
py_source = self.get_python_source().splitlines() # type: ignore
function_body = textwrap.dedent("\n".join(py_source[1:]))
if self.is_decorated_function:
self.log.info(
f"Entrypoint is not provided, is_decorated_function is set to {self.is_decorated_function}"
)
tmp_dir = mkdtemp(prefix="ray_")
py_source = self.get_python_source().splitlines() # type: ignore
function_body = textwrap.dedent("\n".join(py_source))

script_filename = os.path.join(tmp_dir, "script.py")
with open(script_filename, "w") as file:
all_args_str = self._build_args_str()
script_body = f"{function_body}\n{self._extract_function_name()}({all_args_str})"
file.write(script_body)

self.entrypoint = "python script.py"
self.runtime_env["working_dir"] = tmp_dir

script_filename = os.path.join(tmp_dir, "script.py")
with open(script_filename, "w") as file:
all_args_str = self._build_args_str()
script_body = f"{function_body}\n{self._extract_function_name()}({all_args_str})"
file.write(script_body)

self.entrypoint = "python script.py"
self.runtime_env["working_dir"] = tmp_dir
self.log.info("Running ray job...")

result = super().execute(context)
Expand All @@ -87,7 +106,7 @@ def execute(self, context: Context) -> Any:
self.log.error(f"Failed during execution with error: {e}")
raise AirflowException("Job submission failed") from e
finally:
if os.path.exists(tmp_dir):
if tmp_dir and os.path.exists(tmp_dir):
shutil.rmtree(tmp_dir)

def _build_args_str(self) -> str:
Expand All @@ -109,22 +128,24 @@ def _extract_function_name(self) -> str:
return self.python_callable.__name__


def ray_task(
python_callable: Callable[..., Any] | None = None,
multiple_outputs: bool | None = None,
**kwargs: Any,
) -> TaskDecorator:
"""
Decorator to define a task that submits a Ray job.
class task:
@staticmethod
def ray(
python_callable: Callable[..., Any] | None = None,
multiple_outputs: bool | None = None,
**kwargs: Any,
) -> TaskDecorator:
"""
Decorator to define a task that submits a Ray job.
:param python_callable: The callable function to decorate.
:param multiple_outputs: If True, will return multiple outputs.
:param kwargs: Additional keyword arguments.
:return: The decorated task.
"""
return task_decorator_factory(
python_callable=python_callable,
multiple_outputs=multiple_outputs,
decorated_operator_class=_RayDecoratedOperator,
**kwargs,
)
:param python_callable: The callable function to decorate.
:param multiple_outputs: If True, will return multiple outputs.
:param kwargs: Additional keyword arguments.
:return: The decorated task.
"""
return task_decorator_factory(
python_callable=python_callable,
multiple_outputs=multiple_outputs,
decorated_operator_class=_RayDecoratedOperator,
**kwargs,
)
146 changes: 123 additions & 23 deletions tests/decorators/test_ray_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import pytest
from airflow.decorators.base import _TaskDecorator
from airflow.exceptions import AirflowException
from airflow.utils.context import Context

from ray_provider.decorators.ray import _RayDecoratedOperator, ray_task
from ray_provider.decorators.ray import _RayDecoratedOperator, task
from ray_provider.operators.ray import SubmitRayJob

DEFAULT_DATE = "2023-01-01"
Expand All @@ -15,30 +16,79 @@ class TestRayDecoratedOperator:

def test_initialization(self):
config = {
"host": "http://localhost:8265",
"conn_id": "ray_default",
"entrypoint": "python my_script.py",
"runtime_env": {"pip": ["ray"]},
"num_cpus": 2,
"num_gpus": 1,
"memory": "1G",
"memory": 1024,
"resources": {"custom_resource": 1},
"fetch_logs": True,
"wait_for_completion": True,
"job_timeout_seconds": 300,
"poll_interval": 30,
"xcom_task_key": "ray_result",
}

def dummy_callable():
pass

operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable)

assert operator.conn_id == "ray_default"
assert operator.entrypoint == "python my_script.py"
assert operator.runtime_env == {"pip": ["ray"]}
assert operator.num_cpus == 2
assert operator.num_gpus == 1
assert operator.memory == "1G"
assert operator.memory == 1024
assert operator.ray_resources == {"custom_resource": 1}
assert operator.fetch_logs == True
assert operator.wait_for_completion == True
assert operator.job_timeout_seconds == 300
assert operator.poll_interval == 30
assert operator.xcom_task_key == "ray_result"

def test_initialization_defaults(self):
config = {}

def dummy_callable():
pass

operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable)

assert operator.conn_id == ""
assert operator.entrypoint == "python script.py"
assert operator.runtime_env == {}
assert operator.num_cpus == 1
assert operator.num_gpus == 0
assert operator.memory is None
assert operator.resources is None
assert operator.fetch_logs == True
assert operator.wait_for_completion == True
assert operator.job_timeout_seconds == 600
assert operator.poll_interval == 60
assert operator.xcom_task_key is None

def test_invalid_config_raises_exception(self):
config = {
"num_cpus": "invalid_number",
}

def dummy_callable():
pass

with pytest.raises(TypeError):
_RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable)

config["num_cpus"] = 1
config["num_gpus"] = "invalid_number"
with pytest.raises(TypeError):
_RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable)

@patch.object(_RayDecoratedOperator, "get_python_source")
@patch.object(SubmitRayJob, "execute")
def test_execute(self, mock_super_execute, mock_get_python_source):
def test_execute_decorated_function(self, mock_super_execute, mock_get_python_source):
config = {
"entrypoint": "python my_script.py",
"runtime_env": {"pip": ["ray"]},
}

Expand All @@ -47,52 +97,102 @@ def dummy_callable():

context = MagicMock(spec=Context)
operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable)

mock_get_python_source.return_value = "def my_function():\n pass\n"
mock_get_python_source.return_value = "def dummy_callable():\n pass\n"
mock_super_execute.return_value = "success"

result = operator.execute(context)

assert result == "success"
assert operator.entrypoint == "python script.py"
assert "working_dir" in operator.runtime_env

def test_missing_host_config(self):
@patch.object(SubmitRayJob, "execute")
def test_execute_with_entrypoint(self, mock_super_execute):
config = {
"entrypoint": "python my_script.py",
}

def dummy_callable():
pass

context = MagicMock(spec=Context)
operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable)
mock_super_execute.return_value = "success"

result = operator.execute(context)

assert result == "success"
assert operator.entrypoint == "python my_script.py"

def test_invalid_config_raises_exception(self):
config = {
"host": "http://localhost:8265",
"entrypoint": "python my_script.py",
"runtime_env": {"pip": ["ray"]},
"num_cpus": "invalid_number",
}
@patch.object(SubmitRayJob, "execute")
def test_execute_failure(self, mock_super_execute):
config = {}

def dummy_callable():
pass

with pytest.raises(TypeError):
_RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable)
context = MagicMock(spec=Context)
operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable)
mock_super_execute.side_effect = Exception("Ray job failed")

with pytest.raises(AirflowException):
operator.execute(context)

class TestRayTaskDecorator:
def test_build_args_str(self):
config = {}

def dummy_callable(arg1, arg2, kwarg1="default"):
pass

operator = _RayDecoratedOperator(
task_id="test_task",
config=config,
python_callable=dummy_callable,
op_args=["value1", "value2"],
op_kwargs={"kwarg1": "custom"},
)

args_str = operator._build_args_str()
assert args_str == "'value1', 'value2', kwarg1='custom'"

def test_extract_function_name(self):
config = {}

def dummy_callable():
pass

operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable)

function_name = operator._extract_function_name()
assert function_name == "dummy_callable"


class TestRayTaskDecorator:
def test_ray_task_decorator(self):
@task.ray()
def dummy_function():
return "dummy"

decorator = ray_task(python_callable=dummy_function)
assert isinstance(decorator, _TaskDecorator)
assert isinstance(dummy_function, _TaskDecorator)

def test_ray_task_decorator_with_multiple_outputs(self):
@task.ray(multiple_outputs=True)
def dummy_function():
return {"key": "value"}

decorator = ray_task(python_callable=dummy_function, multiple_outputs=True)
assert isinstance(decorator, _TaskDecorator)
assert isinstance(dummy_function, _TaskDecorator)

def test_ray_task_decorator_with_config(self):
config = {
"num_cpus": 2,
"num_gpus": 1,
"memory": 1024,
}

@task.ray(**config)
def dummy_function():
return "dummy"

assert isinstance(dummy_function, _TaskDecorator)
# We can't directly access the config here, but we can check if the decorator was applied
assert dummy_function.operator_class == _RayDecoratedOperator

0 comments on commit f286997

Please sign in to comment.