job-runner: move async one level up in the call stack
* The job_runner was converted to async; when an event loop is already
  running, its tasks are scheduled in the background (i.e. the call
  returns before the work has finished, and the tasks cannot be awaited
  from synchronous code).
* This broke some `finally` blocks, which were being invoked too soon
  (see the sketch below).
* This PR moves the async code up one level in the call stack, bringing
  these `finally` blocks within the lifecycle of the job runner's async
  code and allowing them to function correctly.
* Closes #2784
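
To make the failure mode concrete, here is a minimal sketch (not part of the commit); `install_files` and `process` are hypothetical stand-ins for the job runner's work and its synchronous wrapper:

```python
import asyncio


async def install_files():
    """Stand-in for the async job-runner work."""
    await asyncio.sleep(0.1)
    print("2. files installed")


def process():
    """Synchronous wrapper, as `process` behaved after the async conversion."""
    try:
        # Schedules the work on the running loop; it cannot be awaited here.
        asyncio.get_event_loop().create_task(install_files())
    finally:
        # Runs immediately, long before install_files() has finished.
        print("1. finally ran (too soon)")


async def main():
    process()
    await asyncio.sleep(0.2)  # give the background task time to complete


asyncio.run(main())
```

Running this prints "1. finally ran (too soon)" before "2. files installed": the synchronous wrapper's cleanup cannot wait for work it can only schedule, which is why this commit moves the cleanup inside the awaited coroutine.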
oliver-sanders committed Jun 11, 2024
1 parent b4f653f commit e1392e7
Showing 2 changed files with 59 additions and 42 deletions.
metomi/rose/config_processors/fileinstall.py (54 changes: 41 additions & 13 deletions)
@@ -17,13 +17,13 @@
 """Process "file:*" sections in node of a metomi.rose.config_tree.ConfigTree.
 """
 
+import asyncio
 from contextlib import suppress
 from fnmatch import fnmatch
 from glob import glob
 from io import BytesIO
 import os
 import shlex
-from shutil import rmtree
 import sqlite3
 from tempfile import mkdtemp
 from typing import Any, Optional
@@ -47,6 +47,13 @@
 from metomi.rose.scheme_handler import SchemeHandlersManager
 
 
+# set containing references to "background" coroutines that are not referenced
+# from any code (i.e. are not directly awaited), adding them to this list
+# avoids the potential for garbage collection to delete them whilst they are
+# running
+_BACKGROUND_TASKS = set()
+
+
 class ConfigProcessorForFile(ConfigProcessorBase):
     """Processor for [file:*] in node of a ConfigTree."""
 
@@ -89,29 +96,51 @@ def process(
         if not nodes:
             return
 
-        # Create database to store information for incremental updates,
-        # if it does not already exist.
-        loc_dao = LocDAO()
-        loc_dao.create()
-
-        cwd = os.getcwd()
         file_install_root = conf_tree.node.get_value(
             ["file-install-root"], os.getenv("ROSE_FILE_INSTALL_ROOT", None)
         )
         if file_install_root:
             file_install_root = env_var_process(file_install_root)
             self.manager.fs_util.makedirs(file_install_root)
             self.manager.fs_util.chdir(file_install_root)
 
+        loop = asyncio.get_event_loop()
+        loop.set_exception_handler(self.handle_event)
+        coro = self.__process(conf_tree, nodes, **kwargs)
+
         try:
-            self._process(conf_tree, nodes, loc_dao, **kwargs)
+            # event loop is not running (e.g. rose CLI use)
+            loop.run_until_complete(coro)
+        except RuntimeError:
+            # event loop is already running (e.g. cylc CLI use)
+            # WARNING: this starts the file installation running, but it
+            # doesn't wait for it to finish, that's your problem :(
+            task = loop.create_task(coro)
+            # reference this task from a global variable to prevent it from
+            # being garbage collected
+            _BACKGROUND_TASKS.add(task)
+            # tidy up afterwards
+            task.add_done_callback(_BACKGROUND_TASKS.discard)
+
+    async def __process(self, conf_tree, nodes, **kwargs):
+        """Helper for self._process."""
+        cwd = os.getcwd()
+
+        # Create database to store information for incremental updates,
+        # if it does not already exist.
+        loc_dao = LocDAO()
+        loc_dao.create()
+
+        try:
+            await self._process(conf_tree, nodes, loc_dao, **kwargs)
         finally:
             if cwd != os.getcwd():
                 self.manager.fs_util.chdir(cwd)
             if loc_dao.conn:
                 with suppress(Exception):
                     loc_dao.conn.close()
 
-    def _process(self, conf_tree, nodes, loc_dao, **kwargs):
+    async def _process(self, conf_tree, nodes, loc_dao, **kwargs):
         """Helper for self.process."""
         # Ensure that everything is overwritable
         # Ensure that container directories exist
@@ -360,7 +389,9 @@ def _process(self, conf_tree, nodes, loc_dao, **kwargs):
             if nproc_str is not None:
                 nproc = int(nproc_str)
             job_runner = JobRunner(self, nproc)
-            job_runner(JobManager(jobs), conf_tree, loc_dao, work_dir)
+            await job_runner(
+                JobManager(jobs), conf_tree, loc_dao, work_dir
+            )
         except ValueError as exc:
             if exc.args and exc.args[0] in jobs:
                 job = jobs[exc.args[0]]
@@ -372,9 +403,6 @@ def _process(self, conf_tree, nodes, loc_dao, **kwargs):
                 ]
                 raise ConfigProcessError(keys, source.name)
             raise exc
-        finally:
-            loc_dao.execute_queued_items()
-            rmtree(work_dir)
 
         # Target checksum compare and report
         for target in list(targets.values()):
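The dual-mode dispatch added to `process` above can be restated as a standalone sketch; `run_sync_or_background` is a hypothetical helper name, but the `run_until_complete` / `RuntimeError` / `create_task` pattern is the one in the diff:

```python
import asyncio

_BACKGROUND_TASKS = set()


def run_sync_or_background(coro):
    """Run `coro` to completion if possible, else schedule it in the background."""
    loop = asyncio.get_event_loop()
    try:
        # No event loop running (e.g. rose CLI use): block until done.
        loop.run_until_complete(coro)
    except RuntimeError:
        # A loop is already running (e.g. cylc CLI use): schedule the
        # coroutine as a task; this call returns before the work finishes.
        task = loop.create_task(coro)
        # Hold a strong reference so the pending task is not garbage
        # collected, then drop the reference once the task completes.
        _BACKGROUND_TASKS.add(task)
        task.add_done_callback(_BACKGROUND_TASKS.discard)
```

The module-level set matters because the event loop keeps only weak references to tasks; without a strong reference a still-running task can be garbage collected mid-flight.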
metomi/rose/job_runner.py (47 changes: 18 additions & 29 deletions)
@@ -17,17 +17,11 @@
 """A multiprocessing runner of jobs with dependencies."""
 
 import asyncio
+from shutil import rmtree
 
 from metomi.rose.reporter import Event
 
 
-# set containing references to "background" coroutines that are not referenced
-# from any code (i.e. are not directly awaited), adding them to this list
-# avoids the potential for garbage collection to delete them whilst they are
-# running
-_BACKGROUND_TASKS = set()
-
-
 class JobEvent(Event):
     """Event raised when a job completes."""
 
@@ -171,7 +165,8 @@ def __init__(self, job_processor, nproc=None):
         """
         self.job_processor = job_processor
 
-    def run(self, job_manager, *args, concurrency=6):
+    async def run(self, job_manager, conf_tree, loc_dao, work_dir, concurrency=6):
+
         """Start the job runner with an instance of JobManager.
         Args:
@@ -183,32 +178,26 @@ def run(self, job_manager, *args, concurrency=6):
                 The maximum number of jobs to run concurrently.
         """
-        loop = asyncio.get_event_loop()
-        loop.set_exception_handler(self.job_processor.handle_event)
-        coro = self._run(job_manager, *args, concurrency=concurrency)
-        try:
-            # event loop is not running (e.g. rose CLI use)
-            loop.run_until_complete(coro)
-        except RuntimeError:
-            # event loop is already running (e.g. cylc CLI use)
-            # WARNING: this starts the file installation running, but it
-            # doesn't wait for it to finish, that's your problem :(
-            task = loop.create_task(coro)
-            # reference this task from a global variable to prevent it from
-            # being garbage collected
-            _BACKGROUND_TASKS.add(task)
-            # tidy up afterwards
-            task.add_done_callback(_BACKGROUND_TASKS.discard)
+        await self._run(
+            job_manager, conf_tree, loc_dao, work_dir, concurrency=concurrency
+        )
         dead_jobs = job_manager.get_dead_jobs()
         if dead_jobs:
            raise JobRunnerNotCompletedError(dead_jobs)
 
-    async def _run(self, job_manager, *args, concurrency=6):
+    async def _run(
+        self, job_manager, conf_tree, loc_dao, work_dir, concurrency=6
+    ):
         running = []
-        await asyncio.gather(
-            self._run_jobs(running, job_manager, args, concurrency),
-            self._post_process_jobs(running, job_manager, args),
-        )
+        args = (conf_tree, loc_dao, work_dir)
+        try:
+            await asyncio.gather(
+                self._run_jobs(running, job_manager, args, concurrency),
+                self._post_process_jobs(running, job_manager, args),
+            )
+        finally:
+            loc_dao.execute_queued_items()
+            rmtree(work_dir)
 
     async def _run_jobs(self, running, job_manager, args, concurrency):
         """Run pending jobs subject to the concurrency limit.
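Why the cleanup now lives inside `_run` rather than in the caller's `finally` can also be shown with a minimal sketch; `run_jobs` and `demo_job` are hypothetical, but the try / `gather` / `finally` shape matches the diff above:

```python
import asyncio
from shutil import rmtree
from tempfile import mkdtemp


async def run_jobs(jobs):
    """Run `jobs` (coroutine functions) concurrently, then clean up."""
    work_dir = mkdtemp()
    try:
        await asyncio.gather(*(job(work_dir) for job in jobs))
    finally:
        # Because this `finally` is inside the awaited coroutine, it runs
        # only after every job has completed (or failed), even when
        # run_jobs() itself is scheduled as a background task.
        rmtree(work_dir)


async def demo_job(work_dir):
    await asyncio.sleep(0.1)  # pretend to install a file into work_dir


asyncio.run(run_jobs([demo_job, demo_job]))
```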
