Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pankajastro committed Jul 19, 2024
1 parent 7ad9aac commit 7e5cc2e
Show file tree
Hide file tree
Showing 9 changed files with 55 additions and 72 deletions.
3 changes: 3 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[run]
omit =
tests/*
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ jobs:
- name: Test Ray against Airflow ${{ matrix.airflow-version }} and Python ${{ matrix.python-version }}
run: |
hatch run tests.py${{ matrix.python-version }}-${{ matrix.airflow-version }}:test
hatch run tests.py${{ matrix.python-version }}-${{ matrix.airflow-version }}:test-cov
- name: Upload coverage to Github
uses: actions/upload-artifact@v4
Expand Down
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ repos:
name: Run codespell to check for common misspellings in files
language: python
types: [text]
args: ["--ignore-words", codespell-ignore-words.txt]
- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.10.0
hooks:
Expand Down
1 change: 1 addition & 0 deletions ray_provider/codespell-ignore-words.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ascend
6 changes: 6 additions & 0 deletions scripts/test/unit_cov.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pytest \
-vv \
--cov=ray_provider \
--cov-report=term-missing \
--cov-report=xml \
--durations=0
2 changes: 1 addition & 1 deletion scripts/test/unit_test.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pytest \
-vv \
--durations=0 \
-m "not (integration or perf)"
-m "not (integration or perf)"
5 changes: 1 addition & 4 deletions tests/decorators/test_ray_decorators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
from unittest.mock import MagicMock, patch

import pytest
Expand Down Expand Up @@ -29,13 +28,11 @@ def dummy_callable():

operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable)

assert operator.host == "http://localhost:8265"
assert operator.entrypoint == "python my_script.py"
assert operator.runtime_env == {"pip": ["ray"]}
assert operator.num_cpus == 2
assert operator.num_gpus == 1
assert operator.memory == "1G"
assert operator.node_group is None

@patch.object(_RayDecoratedOperator, "get_python_source")
@patch.object(SubmitRayJob, "execute")
Expand Down Expand Up @@ -67,7 +64,7 @@ def dummy_callable():
pass

operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable)
assert operator.host == os.getenv("RAY_DASHBOARD_URL")
assert operator.entrypoint == "python my_script.py"

def test_invalid_config_raises_exception(self):
config = {
Expand Down
50 changes: 16 additions & 34 deletions tests/operators/test_ray_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import pytest
from airflow.exceptions import AirflowException, TaskDeferred
from ray.job_submission import JobStatus

from ray_provider.operators.ray import SubmitRayJob

Expand All @@ -21,7 +20,7 @@
@pytest.fixture
def operator():
return SubmitRayJob(
host=host,
conn_id="test_conn",
entrypoint=entrypoint,
runtime_env=runtime_env,
num_cpus=num_cpus,
Expand All @@ -36,51 +35,34 @@ def operator():
class TestSubmitRayJob:

def test_init(self, operator):
assert operator.host == host
assert operator.conn_id == "test_conn"
assert operator.entrypoint == entrypoint
assert operator.runtime_env == runtime_env
assert operator.num_cpus == num_cpus
assert operator.num_gpus == num_gpus
assert operator.memory == memory
assert operator.resources == resources
# assert operator.resources == resources
assert operator.timeout == timeout
assert operator.client is None
assert operator.job_id is None
assert operator.status_to_wait_for == {JobStatus.SUCCEEDED, JobStatus.STOPPED, JobStatus.FAILED}

@patch("ray_provider.operators.kuberay.JobSubmissionClient")
def test_execute(self, mock_client_class, operator):
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.submit_job.return_value = "job_12345"
mock_client.get_job_status.return_value = JobStatus.RUNNING

try:
@patch("ray_provider.operators.ray.SubmitRayJob.hook")
def test_execute(self, mock_hook, operator):
with pytest.raises(TaskDeferred):
operator.execute(context)
except TaskDeferred:
pass

mock_client_class.assert_called_once_with(host)
mock_client.submit_job.assert_called_once_with(
entrypoint=entrypoint,
runtime_env=runtime_env,
entrypoint_num_cpus=num_cpus,
entrypoint_num_gpus=num_gpus,
entrypoint_memory=memory,
entrypoint_resources=resources,
mock_hook.submit_ray_job.assert_called_once_with(
entrypoint="python script.py",
runtime_env={"pip": ["requests"]},
entrypoint_num_cpus=2,
entrypoint_num_gpus=1,
entrypoint_memory=1024,
entrypoint_resources={"CPU": 2},
)
assert operator.job_id == "job_12345"

@patch("ray_provider.operators.kuberay.JobSubmissionClient")
def test_on_kill(self, mock_client_class, operator):
mock_client = MagicMock()
mock_client_class.return_value = mock_client
operator.client = mock_client
@patch("ray_provider.operators.ray.SubmitRayJob.hook")
def test_on_kill(self, mock_hook, operator):
operator.job_id = "job_12345"

operator.on_kill()

mock_client.delete_job.assert_called_once_with("job_12345")
mock_hook.delete_ray_job.assert_called_once_with("job_12345")

def test_execute_complete_success(self, operator):
event = {"status": "success", "message": "Job completed successfully"}
Expand Down
57 changes: 25 additions & 32 deletions tests/triggers/test_ray_triggers.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,39 @@
import time
from unittest import mock
from unittest.mock import patch

import pytest
from airflow.triggers.base import TriggerEvent
from ray.dashboard.modules.job.sdk import JobStatus, JobSubmissionClient
from ray.dashboard.modules.job.sdk import JobStatus

from ray_provider.triggers.ray import RayJobTrigger


class TestRayJobTrigger:

@pytest.mark.asyncio
async def test_run_no_job_id(self):
trigger = RayJobTrigger(job_id="", host="localhost", end_time=time.time() + 60, poll_interval=1)
@patch("ray_provider.triggers.ray.RayJobTrigger._is_terminal_state")
@patch("ray_provider.triggers.ray.RayJobTrigger.hook")
async def test_run_no_job_id(self, mock_hook, mock_is_terminal):
mock_is_terminal.return_value = True
trigger = RayJobTrigger(job_id="", poll_interval=1, conn_id="test", xcom_dashboard_url="test")

generator = trigger.run()
event = await generator.send(None)
assert event == TriggerEvent(
{"status": "error", "message": "No job_id provided to async trigger", "job_id": ""}
)
event = await generator.asend(None)
assert event == TriggerEvent({"status": "error", "message": "Job run has failed.", "job_id": ""})

@pytest.mark.asyncio
async def test_run_job_succeeded(self):
trigger = RayJobTrigger(job_id="test_job_id", host="localhost", end_time=time.time() + 60, poll_interval=1)

client_mock = mock.MagicMock(spec=JobSubmissionClient)
client_mock.get_job_status.return_value = JobStatus.SUCCEEDED

async def async_generator():
yield "log line 1"
yield "log line 2"

client_mock.tail_job_logs.return_value = async_generator()

with mock.patch("ray_provider.triggers.kuberay.JobSubmissionClient", return_value=client_mock):
generator = trigger.run()
async for event in generator:
assert event == TriggerEvent(
{
"status": "success",
"message": "Job run test_job_id has completed successfully.",
"job_id": "test_job_id",
}
)
break # Stop after the first event for testing purposes
@patch("ray_provider.triggers.ray.RayJobTrigger.hook")
async def test_run_job_succeeded(self, mock_hook):
trigger = RayJobTrigger(job_id="test_job_id", poll_interval=1, conn_id="test", xcom_dashboard_url="test")

mock_hook.get_ray_job_status.return_value = JobStatus.SUCCEEDED

generator = trigger.run()
async for event in generator:
assert event == TriggerEvent(
{
"status": "success",
"message": "Job run test_job_id has completed successfully.",
"job_id": "test_job_id",
}
)
break # Stop after the first event for testing purposes

0 comments on commit 7e5cc2e

Please sign in to comment.