Skip to content

Commit

Permalink
Merge branch 'release_23.2' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
mvdbeek committed Jan 18, 2024
2 parents 274a154 + 064360d commit e02a765
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 33 deletions.
8 changes: 8 additions & 0 deletions lib/galaxy/celery/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import uuid
from functools import (
lru_cache,
wraps,
Expand Down Expand Up @@ -166,6 +167,10 @@ def wrapper(*args, **kwds):
app = get_galaxy_app()
assert app

# Ensure sqlalchemy session registry scope is specific to this instance of the celery task
scoped_id = str(uuid.uuid4())
app.model.set_request_id(scoped_id)

desc = func.__name__
if action is not None:
desc += f" to {action}"
Expand All @@ -183,6 +188,9 @@ def wrapper(*args, **kwds):
except Exception:
log.warning(f"Celery task execution failed for {desc} {timer}")
raise
finally:
# Close and remove any open session this task has created
app.model.unset_request_id(scoped_id)

return wrapper

Expand Down
8 changes: 8 additions & 0 deletions lib/galaxy/jobs/runners/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import time
import traceback
import typing
import uuid
from queue import (
Empty,
Queue,
Expand Down Expand Up @@ -829,12 +830,19 @@ def monitor(self):
self.watched.append(async_job_state)
except Empty:
pass
# Ideally we'd construct a sqlalchemy session now and pass it into `check_watched_items`
# and have that be the only session being used. The next best thing is to scope
# the session and discard it after each `check_watched_items` pass.
scoped_id = str(uuid.uuid4())
self.app.model.set_request_id(scoped_id)
# Iterate over the list of watched jobs and check state
try:
check_database_connection(self.sa_session)
self.check_watched_items()
except Exception:
log.exception("Unhandled exception checking active jobs")
finally:
self.app.model.unset_request_id(scoped_id)
# Sleep a bit before the next state check
time.sleep(self.app.config.job_runner_monitor_sleep)

Expand Down
15 changes: 15 additions & 0 deletions test/integration/test_celery_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
purge_hda,
)
from galaxy.model import HistoryDatasetAssociation
from galaxy.model.scoped_session import galaxy_scoped_session
from galaxy.schema import PdfDocumentType
from galaxy.schema.schema import CreatePagePayload
from galaxy.schema.tasks import GeneratePdfDownload
Expand All @@ -31,13 +32,27 @@ def process_page(request: CreatePagePayload):
return f"content_format is {request.content_format} with annotation {request.annotation}"


@galaxy_task
def invalidate_connection(sa_session: galaxy_scoped_session):
    """Invalidate the connection held by the session obtained from ``sa_session``.

    Test helper task: it obtains the session from the injected scoped
    session registry, fetches that session's connection and invalidates it.
    """
    session = sa_session()
    connection = session.connection()
    connection.invalidate()


@galaxy_task
def use_session(sa_session: galaxy_scoped_session):
    """Look up a ``HistoryDatasetAssociation`` by primary key via ``sa_session``.

    Test helper task: the lookup result is discarded; the point is only that
    the session obtained from the injected scoped session can be used.
    """
    # ``Session.get`` is the supported replacement for the legacy
    # ``Query.get`` and performs the same primary-key lookup.
    sa_session().get(HistoryDatasetAssociation, 1)


class TestCeleryTasksIntegration(IntegrationTestCase):
dataset_populator: DatasetPopulator

def setUp(self):
    """Run the parent setup, then attach a ``DatasetPopulator`` for the interactor."""
    super().setUp()
    populator = DatasetPopulator(self.galaxy_interactor)
    self.dataset_populator = populator

def test_recover_from_invalid_connection(self):
    """Run ``invalidate_connection`` and then ``use_session``, waiting on each result.

    The second task must complete after the first one invalidated its
    connection; any task failure is re-raised by ``get()``.
    """
    # Bound each wait, as the neighbouring task test does, so a stuck worker
    # fails the test instead of hanging the whole run.
    invalidate_connection.delay().get(timeout=10)
    use_session.delay().get(timeout=10)

def test_random_simple_task_to_verify_framework_for_testing(self):
    """Submit ``mul(4, 4)`` as a task and check that the result is 16."""
    async_result = mul.delay(4, 4)
    product = async_result.get(timeout=10)
    assert product == 16

Expand Down
51 changes: 18 additions & 33 deletions test/unit/app/test_tasks.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
from contextlib import contextmanager
from typing import (
Iterator,
List,
)
from typing import List

from galaxy.celery import set_thread_app
from galaxy.app_unittest_utils.galaxy_mock import MockApp
from galaxy.celery.tasks import clean_object_store_caches
from galaxy.di import Container
from galaxy.objectstore import BaseObjectStore
from galaxy.objectstore.caching import CacheTarget

Expand All @@ -20,34 +15,24 @@ def cache_targets(self) -> List[CacheTarget]:


def test_clean_object_store_caches(tmp_path):
with celery_injected_app_container() as container:
cache_targets: List[CacheTarget] = []
container[BaseObjectStore] = MockObjectStore(cache_targets) # type: ignore[assignment]
container = MockApp()
cache_targets: List[CacheTarget] = []
container[BaseObjectStore] = MockObjectStore(cache_targets) # type: ignore[assignment]

# similar code used in object store unit tests
cache_dir = tmp_path
path = cache_dir / "a_file_0"
path.write_text("this is an example file")
# similar code used in object store unit tests
cache_dir = tmp_path
path = cache_dir / "a_file_0"
path.write_text("this is an example file")

# works fine on an empty list of cache targets...
clean_object_store_caches()
# works fine on an empty list of cache targets...
clean_object_store_caches()

assert path.exists()
assert path.exists()

# place the file in mock object store's cache targets and
# run the task again and the above file should be gone.
cache_targets.append(CacheTarget(cache_dir, 1, 0.000000001))
# works fine on an empty list of cache targets...
clean_object_store_caches()
# place the file in mock object store's cache targets and
# run the task again and the above file should be gone.
cache_targets.append(CacheTarget(cache_dir, 1, 0.000000001))
# with a cache target registered, the task should now clean up the file...
clean_object_store_caches()

assert not path.exists()


@contextmanager
def celery_injected_app_container() -> Iterator[Container]:
container = Container()
set_thread_app(container)
try:
yield container
finally:
set_thread_app(None)
assert not path.exists()

0 comments on commit e02a765

Please sign in to comment.