Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
venkatajagannath committed Aug 3, 2024
1 parent 5638780 commit 57aedfd
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 16 deletions.
6 changes: 3 additions & 3 deletions ray_provider/hooks/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,13 +193,13 @@ def get_ray_job_logs(self, job_id: str) -> str:
async def get_ray_tail_logs(self, job_id: str) -> AsyncIterator[str]:
    """
    Tails the logs of a submitted job asynchronously.

    :param job_id: The ID of the job.
    :return: An async iterator of log text, yielded as the Ray client produces it.
    """
    client = self.ray_client
    # The client's log tail is an async iterator: consume it directly with
    # ``async for`` instead of awaiting the call and walking the result with a
    # plain ``for`` loop, which fails at runtime and would duplicate output.
    async for lines in client.tail_job_logs(job_id):
        yield lines

def load_yaml_content(self, path_or_link: str) -> Any:
"""
Expand Down
33 changes: 20 additions & 13 deletions ray_provider/triggers/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class RayJobTrigger(BaseTrigger):
:param conn_id: The connection ID for the Ray cluster.
:param xcom_dashboard_url: Optional URL for the Ray dashboard.
:param poll_interval: The interval in seconds at which to poll the job status. Defaults to 30 seconds.
:param fetch_logs: Whether to fetch and stream logs. Defaults to True.
"""

def __init__(
Expand All @@ -38,7 +39,6 @@ def __init__(
self.dashboard_url = xcom_dashboard_url
self.fetch_logs = fetch_logs
self.poll_interval = poll_interval
self.log_iterator: AsyncIterator[str] | None = None

def serialize(self) -> tuple[str, dict[str, Any]]:
"""
Expand Down Expand Up @@ -66,6 +66,21 @@ def hook(self) -> RayHook:
"""
return RayHook(conn_id=self.conn_id, xcom_dashboard_url=self.dashboard_url)

async def _poll_status(self) -> None:
while not self._is_terminal_state():
await asyncio.sleep(self.poll_interval)

async def _stream_logs(self) -> None:
"""
Streams logs from the Ray job in real-time.
"""
self.log.info(f"::group::{self.job_id} logs")
async for log_lines in self.hook.get_ray_tail_logs(self.job_id):
for line in log_lines.split("\n"):
if line.strip(): # Avoid logging empty lines
self.log.info(line.strip())
self.log.info("::endgroup::")

async def run(self) -> AsyncIterator[TriggerEvent]:
"""
Asynchronously polls the job status and yields events based on the job's state.
Expand All @@ -78,19 +93,11 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
try:
self.log.info(f"Polling for job {self.job_id} every {self.poll_interval} seconds...")

while not self._is_terminal_state():
if self.log_iterator is None:
self.log_iterator = self.hook.get_ray_tail_logs(self.job_id)

# Check for new log lines
try:
async for line in self.log_iterator:
self.log.info(line)
except StopIteration:
# No more logs available at this time
pass
tasks = [self._poll_status()]
if self.fetch_logs:
tasks.append(self._stream_logs())

await asyncio.sleep(self.poll_interval)
await asyncio.gather(*tasks)

completed_status = self.hook.get_ray_job_status(self.job_id)
self.log.info(f"Status of completed job {self.job_id} is: {completed_status}")
Expand Down

0 comments on commit 57aedfd

Please sign in to comment.