Skip to content

Commit

Permalink
Bug fix
Browse files: browse the repository at this point in the history
  • Loading branch information
venkatajagannath committed Aug 25, 2024
1 parent 78c9b64 commit 584e09c
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 88 deletions.
22 changes: 13 additions & 9 deletions ray_provider/hooks/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import subprocess
import tempfile
import time
from pathlib import Path
from typing import Any, AsyncIterator

import requests
Expand Down Expand Up @@ -164,14 +165,14 @@ def submit_ray_job(
self.log.info(f"Submitted job with ID: {job_id}")
return str(job_id)

def delete_ray_job(self, job_id: str) -> Any:
def delete_ray_job(self, dashboard_url: str | None, job_id: str) -> Any:
"""
Deletes a job from the Ray cluster.
:param job_id: The ID of the job to delete.
:return: The result of the delete operation.
"""
client = self.ray_client
client = self.ray_client(dashboard_url=dashboard_url)
self.log.info(f"Deleting job with ID: {job_id}")
return client.delete_job(job_id=job_id)

Expand All @@ -187,14 +188,14 @@ def get_ray_job_status(self, dashboard_url: str | None, job_id: str) -> JobStatu
self.log.info(f"Job {job_id} status: {status}")
return status

def get_ray_job_logs(self, job_id: str) -> str:
def get_ray_job_logs(self, dashboard_url: str | None, job_id: str) -> str:
"""
Retrieves the logs of a submitted job.
:param job_id: The ID of the job.
:return: Logs of the job.
"""
client = self.ray_client
client = self.ray_client(dashboard_url=dashboard_url)
logs = client.get_job_logs(job_id=job_id)
return str(logs)

Expand Down Expand Up @@ -334,15 +335,18 @@ def _wait_for_load_balancer(

raise AirflowException(f"LoadBalancer did not become ready after {max_retries} attempts")

def _validate_yaml_file(self, yaml_file: str) -> None:
def _validate_yaml_file(self, yaml_file: str | Path) -> None:
"""Validate the existence and format of the YAML file."""
if not os.path.isfile(yaml_file):
raise AirflowException(f"The specified YAML file does not exist: {yaml_file}")
if not yaml_file.endswith((".yaml", ".yml")):
yaml_path = Path(yaml_file)

if not yaml_path.is_file():
raise AirflowException(f"The specified YAML file does not exist: {yaml_path}")

if not yaml_path.name.endswith((".yaml", ".yml")):
raise AirflowException("The specified YAML file must have a .yaml or .yml extension.")

try:
with open(yaml_file) as stream:
with yaml_path.open() as stream:
yaml.safe_load(stream)
except yaml.YAMLError as exc:
raise AirflowException(f"The specified YAML file is not valid YAML: {exc}")
Expand Down
2 changes: 1 addition & 1 deletion ray_provider/operators/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def on_kill(self) -> None:
"""
if hasattr(self, "hook") and self.job_id:
self.log.info(f"Deleting Ray job {self.job_id} due to task kill.")
self.hook.delete_ray_job(self.job_id)
self.hook.delete_ray_job(self.dashboard_url, self.job_id)

@cached_property
def hook(self) -> PodOperatorHookProtocol:
Expand Down
Loading

0 comments on commit 584e09c

Please sign in to comment.