Stop catching generic Exception in triggerer
By catching `Exception`, we run the risk of hitting an unexpected exception that the program can't recover from or, worse, swallowing an important exception without properly logging it - a huge headache when trying to debug programs that fail in weird ways.

This was identified during #81 development.
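
For illustration only (not part of the commit): a minimal sketch of the pattern this change moves towards, catching only the specific exception the caller knows how to handle and logging it with its traceback, while anything unexpected propagates instead of being silently swallowed. The `fetch_status` helper and `TransientApiError` class are hypothetical stand-ins.

```python
import logging

logger = logging.getLogger(__name__)


class TransientApiError(Exception):
    """Hypothetical stand-in for a specific, recoverable error (e.g. an API timeout)."""


def fetch_status(job_id: str) -> str:
    """Hypothetical call that can fail with TransientApiError."""
    raise TransientApiError(f"could not reach the API for job {job_id}")


def poll_once(job_id: str) -> str:
    try:
        return fetch_status(job_id)
    except TransientApiError:
        # Narrow except: this is the one failure we can recover from,
        # and the traceback is kept in the logs for debugging.
        logger.error("Polling failed for job %s", job_id, exc_info=True)
        return "UNKNOWN"
    # Any other exception (e.g. a programming bug) propagates to the caller
    # instead of being swallowed by a blanket `except Exception`.
```

In the provider code itself, the diff below applies the same idea by catching kubernetes' `ApiException` rather than `Exception`.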
tatiana committed Nov 28, 2024
1 parent fde7d5e commit 1f76a47
Showing 2 changed files with 58 additions and 94 deletions.
77 changes: 38 additions & 39 deletions ray_provider/triggers.py
@@ -5,8 +5,10 @@
from typing import Any, AsyncIterator

from airflow.triggers.base import BaseTrigger, TriggerEvent
from kubernetes.client.exceptions import ApiException
from ray.job_submission import JobStatus

from ray_provider.constants import TERMINAL_JOB_STATUSES
from ray_provider.hooks import RayHook


@@ -43,6 +45,7 @@ def __init__(
self.gpu_device_plugin_yaml = gpu_device_plugin_yaml
self.fetch_logs = fetch_logs
self.poll_interval = poll_interval
self._job_status: JobStatus | None = None

def serialize(self) -> tuple[str, dict[str, Any]]:
"""
@@ -81,22 +84,22 @@ async def cleanup(self) -> None:
resources are not deleted.
"""
try:
if self.ray_cluster_yaml:
self.log.info(f"Attempting to delete Ray cluster using YAML: {self.ray_cluster_yaml}")
loop = asyncio.get_running_loop()
await loop.run_in_executor(
None, self.hook.delete_ray_cluster, self.ray_cluster_yaml, self.gpu_device_plugin_yaml
)
self.log.info("Ray cluster deletion process completed")
else:
self.log.info("No Ray cluster YAML provided, skipping cluster deletion")
except Exception as e:
self.log.error(f"Unexpected error during cleanup: {str(e)}")
if self.ray_cluster_yaml:
self.log.info(f"Attempting to delete Ray cluster using YAML: {self.ray_cluster_yaml}")
loop = asyncio.get_running_loop()
await loop.run_in_executor(
None, self.hook.delete_ray_cluster, self.ray_cluster_yaml, self.gpu_device_plugin_yaml
)
self.log.info("Ray cluster deletion process completed")
else:
self.log.info("No Ray cluster YAML provided, skipping cluster deletion")

async def _poll_status(self) -> None:
while not self._is_terminal_state():
self._job_status = self.hook.get_ray_job_status(self.dashboard_url, self.job_id)
while self._job_status not in TERMINAL_JOB_STATUSES:
self.log.info(f"Status of job {self.job_id} is: {self._job_status}")
await asyncio.sleep(self.poll_interval)
self._job_status = self.hook.get_ray_job_status(self.dashboard_url, self.job_id)

async def _stream_logs(self) -> None:
"""
@@ -111,46 +114,42 @@ async def _stream_logs(self) -> None:

async def run(self) -> AsyncIterator[TriggerEvent]:
"""
Asynchronously polls the job status and yields events based on the job's state.
Asynchronously polls the Ray job status and yields events based on the job's state.
This method gets job status at each poll interval and streams logs if available.
It yields a TriggerEvent upon job completion, cancellation, or failure.
:yield: TriggerEvent containing the status, message, and job ID related to the job.
"""
try:
self.log.info(f"Polling for job {self.job_id} every {self.poll_interval} seconds...")
self.log.info(f"::group:: Trigger 1/2: Checking the job status")
self.log.info(f"Polling for job {self.job_id} every {self.poll_interval} seconds...")

try:
tasks = [self._poll_status()]
if self.fetch_logs:
tasks.append(self._stream_logs())

await asyncio.gather(*tasks)
except ApiException as e:
error_msg = str(e)
self.log.info(f"::endgroup::")
self.log.error("::group:: Trigger unable to poll job status")
self.log.error("Exception details:", exc_info=True)
self.log.info("Attempting to clean up...")
await self.cleanup()
self.log.info("Cleanup completed!")
self.log.info(f"::endgroup::")

yield TriggerEvent({"status": "EXCEPTION", "message": error_msg, "job_id": self.job_id})
else:
self.log.info(f"::endgroup::")
self.log.info(f"::group:: Trigger 2/2: Job reached a terminal state")
self.log.info(f"Status of completed job {self.job_id} is: {self._job_status}")
self.log.info(f"::endgroup::")

completed_status = self.hook.get_ray_job_status(self.dashboard_url, self.job_id)
self.log.info(f"Status of completed job {self.job_id} is: {completed_status}")
yield TriggerEvent(
{
"status": completed_status,
"message": f"Job {self.job_id} completed with status {completed_status}",
"status": self._job_status,
"message": f"Job {self.job_id} completed with status {self._job_status}",
"job_id": self.job_id,
}
)
except Exception as e:
self.log.error(f"Error occurred: {str(e)}")
await self.cleanup()
yield TriggerEvent({"status": str(JobStatus.FAILED), "message": str(e), "job_id": self.job_id})

def _is_terminal_state(self) -> bool:
"""
Checks if the Ray job is in a terminal state.
A terminal state is one of the following: SUCCEEDED, STOPPED, or FAILED.
:return: True if the job is in a terminal state, False otherwise.
"""
return self.hook.get_ray_job_status(self.dashboard_url, self.job_id) in (
JobStatus.SUCCEEDED,
JobStatus.STOPPED,
JobStatus.FAILED,
)
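
For reference, not part of the diff: a minimal, self-contained sketch of the new polling loop, assuming `TERMINAL_JOB_STATUSES` (imported above from `ray_provider.constants`) mirrors the statuses listed in the removed `_is_terminal_state` docstring; it requires the `ray` package.

```python
import asyncio
from typing import Callable

from ray.job_submission import JobStatus

# Assumed contents of ray_provider.constants.TERMINAL_JOB_STATUSES, based on the
# statuses listed in the removed _is_terminal_state docstring.
TERMINAL_JOB_STATUSES = {JobStatus.SUCCEEDED, JobStatus.STOPPED, JobStatus.FAILED}


async def poll_status(
    get_status: Callable[[str], JobStatus], job_id: str, poll_interval: int = 1
) -> JobStatus:
    """Poll get_status(job_id) until it returns a terminal status, then return it."""
    status = get_status(job_id)
    while status not in TERMINAL_JOB_STATUSES:
        await asyncio.sleep(poll_interval)
        status = get_status(job_id)
    return status
```

Driving it with a stub, e.g. `asyncio.run(poll_status(lambda _: JobStatus.SUCCEEDED, "job-123"))`, returns immediately with the terminal status.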
75 changes: 20 additions & 55 deletions tests/test_triggers.py
@@ -3,6 +3,7 @@

import pytest
from airflow.triggers.base import TriggerEvent
from kubernetes.client.exceptions import ApiException
from ray.job_submission import JobStatus

from ray_provider.triggers import RayJobTrigger
@@ -22,11 +23,9 @@ def trigger(self):
)

@pytest.mark.asyncio
@patch("ray_provider.triggers.RayJobTrigger._is_terminal_state")
@patch("ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", return_value=JobStatus.FAILED)
@patch("ray_provider.triggers.RayJobTrigger.hook")
async def test_run_no_job_id(self, mock_hook, mock_is_terminal):
mock_is_terminal.return_value = True
mock_hook.get_ray_job_status.return_value = JobStatus.FAILED
async def test_run_no_job_id(self, mock_hook, mock_job_status):
trigger = RayJobTrigger(
job_id="",
poll_interval=1,
@@ -42,11 +41,9 @@ async def test_run_no_job_id(self, mock_hook, mock_is_terminal):
)

@pytest.mark.asyncio
@patch("ray_provider.triggers.RayJobTrigger._is_terminal_state")
@patch("ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", side_effect=[None, JobStatus.SUCCEEDED])
@patch("ray_provider.triggers.RayJobTrigger.hook")
async def test_run_job_succeeded(self, mock_hook, mock_is_terminal):
mock_is_terminal.side_effect = [False, True]
mock_hook.get_ray_job_status.return_value = JobStatus.SUCCEEDED
async def test_run_job_succeeded(self, mock_hook, mock_job_status):
trigger = RayJobTrigger(
job_id="test_job_id",
poll_interval=1,
@@ -66,12 +63,9 @@ async def test_run_job_succeeded(self, mock_hook, mock_is_terminal):
)

@pytest.mark.asyncio
@patch("ray_provider.triggers.RayJobTrigger._is_terminal_state")
@patch("ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", side_effect=[None, JobStatus.STOPPED])
@patch("ray_provider.triggers.RayJobTrigger.hook")
async def test_run_job_stopped(self, mock_hook, mock_is_terminal, trigger):
mock_is_terminal.side_effect = [False, True]
mock_hook.get_ray_job_status.return_value = JobStatus.STOPPED

async def test_run_job_stopped(self, mock_hook, mock_job_status, trigger):
generator = trigger.run()
event = await generator.asend(None)

@@ -84,12 +78,9 @@ async def test_run_job_stopped(self, mock_hook, mock_is_terminal, trigger):
)

@pytest.mark.asyncio
@patch("ray_provider.triggers.RayJobTrigger._is_terminal_state")
@patch("ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", side_effect=[None, JobStatus.FAILED])
@patch("ray_provider.triggers.RayJobTrigger.hook")
async def test_run_job_failed(self, mock_hook, mock_is_terminal, trigger):
mock_is_terminal.side_effect = [False, True]
mock_hook.get_ray_job_status.return_value = JobStatus.FAILED

async def test_run_job_failed(self, mock_hook, mock_job_status, trigger):
generator = trigger.run()
event = await generator.asend(None)

@@ -102,12 +93,10 @@ async def test_run_job_failed(self, mock_hook, mock_is_terminal, trigger):
)

@pytest.mark.asyncio
@patch("ray_provider.triggers.RayJobTrigger._is_terminal_state")
@patch("ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", side_effect=[None, JobStatus.SUCCEEDED])
@patch("ray_provider.triggers.RayJobTrigger.hook")
@patch("ray_provider.triggers.RayJobTrigger._stream_logs")
async def test_run_with_log_streaming(self, mock_stream_logs, mock_hook, mock_is_terminal, trigger):
mock_is_terminal.side_effect = [False, True]
mock_hook.get_ray_job_status.return_value = JobStatus.SUCCEEDED
async def test_run_with_log_streaming(self, mock_stream_logs, mock_hook, mock_job_status, trigger):
mock_stream_logs.return_value = None

generator = trigger.run()
@@ -156,19 +145,6 @@ def test_serialize(self, trigger):
},
)

@pytest.mark.asyncio
@patch("ray_provider.triggers.RayJobTrigger.hook")
async def test_is_terminal_state(self, mock_hook, trigger):
mock_hook.get_ray_job_status.side_effect = [
JobStatus.PENDING,
JobStatus.RUNNING,
JobStatus.SUCCEEDED,
]

assert not trigger._is_terminal_state()
assert not trigger._is_terminal_state()
assert trigger._is_terminal_state()

@pytest.mark.asyncio
@patch.object(RayJobTrigger, "hook")
@patch.object(logging.Logger, "info")
@@ -200,41 +176,30 @@ async def test_cleanup_without_cluster_yaml(self, mock_log_info):

mock_log_info.assert_called_once_with("No Ray cluster YAML provided, skipping cluster deletion")

@pytest.mark.asyncio
@patch.object(RayJobTrigger, "hook")
@patch.object(logging.Logger, "error")
async def test_cleanup_with_exception(self, mock_log_error, mock_hook, trigger):
mock_hook.delete_ray_cluster.side_effect = Exception("Test exception")

await trigger.cleanup()

mock_log_error.assert_called_once_with("Unexpected error during cleanup: Test exception")

@pytest.mark.asyncio
@patch("asyncio.sleep", new_callable=AsyncMock)
@patch("ray_provider.triggers.RayJobTrigger._is_terminal_state")
async def test_poll_status(self, mock_is_terminal, mock_sleep, trigger):
mock_is_terminal.side_effect = [False, False, True]

@patch("ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", side_effect=[None, None, JobStatus.SUCCEEDED])
@patch("ray_provider.triggers.RayJobTrigger.hook")
async def test_poll_status(self, mock_hook, mock_job_status, mock_sleep, trigger):
await trigger._poll_status()

assert mock_sleep.call_count == 2
mock_sleep.assert_called_with(1)

@pytest.mark.asyncio
@patch("ray_provider.triggers.RayJobTrigger._is_terminal_state")
@patch(
"ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", side_effect=ApiException("Failed to get job.")
)
@patch("ray_provider.triggers.RayJobTrigger.hook")
@patch("ray_provider.triggers.RayJobTrigger.cleanup")
async def test_run_with_exception(self, mock_cleanup, mock_hook, mock_is_terminal, trigger):
mock_is_terminal.side_effect = Exception("Test exception")

async def test_run_with_exception(self, mock_cleanup, mock_hook, mock_job_status, trigger):
generator = trigger.run()
event = await generator.asend(None)

assert event == TriggerEvent(
{
"status": str(JobStatus.FAILED),
"message": "Test exception",
"status": "EXCEPTION",
"message": "(Failed to get job.)\nReason: None\n",
"job_id": "test_job_id",
}
)
