From 55e8609c74cc8bda00b95d1ba3634d89e82d6f00 Mon Sep 17 00:00:00 2001 From: mvdbeek Date: Thu, 2 Nov 2023 14:08:56 +0100 Subject: [PATCH] Don't store job in JobIO instance attributes and invalidate `_output_hdas_and_paths` when current session is not the same session that was used to populate `_output_hdas_and_paths`. The Job instance may originate from a session that is associated to another thread, and when that thread closes the session the instance (along with other instances retrieved through loading relationships) becomes detached. I am not sure at all if this will fix ``` DetachedInstanceError: Instance is not bound to a Session; attribute refresh operation cannot proceed (Background on this error at: https://sqlalche.me/e/14/bhk3) File "galaxy/jobs/runners/__init__.py", line 291, in prepare_job job_wrapper.prepare() File "galaxy/jobs/__init__.py", line 1248, in prepare tool_evaluator.set_compute_environment(compute_environment, get_special=get_special) File "galaxy/tools/evaluation.py", line 162, in set_compute_environment self.param_dict = self.build_param_dict( File "galaxy/tools/evaluation.py", line 204, in build_param_dict self.__populate_output_dataset_wrappers(param_dict, output_datasets, job_working_directory) File "galaxy/tools/evaluation.py", line 447, in __populate_output_dataset_wrappers param_dict[name] = DatasetFilenameWrapper( File "galaxy/tools/wrappers.py", line 403, in __init__ path_rewrite = compute_environment and compute_environment.output_path_rewrite(dataset_instance) File "galaxy/job_execution/compute_environment.py", line 132, in output_path_rewrite return str(self.job_io.get_output_path(dataset)) File "galaxy/job_execution/setup.py", line 226, in get_output_path if hda.id == dataset.id: File "sqlalchemy/orm/attributes.py", line 487, in __get__ return self.impl.get(state, dict_) File "sqlalchemy/orm/attributes.py", line 959, in get value = self._fire_loader_callables(state, key, passive) File "sqlalchemy/orm/attributes.py", line 990, in _fire_loader_callables return state._load_expired(state, passive) File "sqlalchemy/orm/state.py", line 712, in _load_expired self.manager.expired_attribute_loader(self, toload, passive) File "sqlalchemy/orm/loading.py", line 1369, in load_scalar_attributes raise orm_exc.DetachedInstanceError( ``` but it seems to make some sense. JobIO crosses thread boundaries as part of the job wrapper getting put into threading queues. Ideally we'd make sure that no ORM instance crosses the thread boundary (or we systematically re-associated with a session). I also tried flagging these patterns automatically using something like: ``` @event.listens_for(session, "persistent_to_detached") def on_detach(sess, instance): if not getattr(instance, "allow_detatch", False): raise Exception(f"{instance} detached. This ain't good for how we do things ?") ``` but it seems tricky to figure out when this is fine and when it is not. --- lib/galaxy/job_execution/setup.py | 58 ++++++++++++++++++++++--------- 1 file changed, 41 insertions(+), 17 deletions(-) diff --git a/lib/galaxy/job_execution/setup.py b/lib/galaxy/job_execution/setup.py index cfa281a329c2..330ba38d572e 100644 --- a/lib/galaxy/job_execution/setup.py +++ b/lib/galaxy/job_execution/setup.py @@ -1,11 +1,13 @@ """Utilities to help job and tool code setup jobs.""" import json import os +import threading from typing import ( Any, cast, Dict, List, + NamedTuple, Optional, Tuple, Union, @@ -38,6 +40,27 @@ OutputPaths = List[DatasetPath] +class JobOutput(NamedTuple): + output_name: str + dataset: DatasetInstance + dataset_path: DatasetPath + + +class JobOutputs(threading.local): + def __init__(self) -> None: + super().__init__() + self.output_hdas_and_paths: Optional[OutputHdasAndType] = None + self.output_paths: Optional[OutputPaths] = None + + @property + def populated(self) -> bool: + return self.output_hdas_and_paths is not None + + def set_job_outputs(self, job_outputs: List[JobOutput]) -> None: + self.output_paths = [t[2] for t in job_outputs] + self.output_hdas_and_paths = {t.output_name: (t.dataset, t.dataset_path) for t in job_outputs} + + class JobIO(Dictifiable): dict_collection_visible_keys = ( "job_id", @@ -99,7 +122,6 @@ def __init__( user_context_instance = user_context self.user_context = user_context_instance self.sa_session = sa_session - self.job = job self.job_id = job.id self.working_directory = working_directory self.outputs_directory = outputs_directory @@ -121,22 +143,25 @@ def __init__( self.is_task = is_task self.tool_source = tool_source self.tool_source_class = tool_source_class - self._output_paths: Optional[OutputPaths] = None - self._output_hdas_and_paths: Optional[OutputHdasAndType] = None + self.job_outputs = JobOutputs() self._dataset_path_rewriter: Optional[DatasetPathRewriter] = None + @property + def job(self): + return self.sa_session.query(Job).get(self.job_id) + @classmethod def from_json(cls, path, sa_session): with open(path) as job_io_serialized: io_dict = json.load(job_io_serialized) - return cls.from_dict(io_dict=io_dict, sa_session=sa_session) + job_id = io_dict.pop("job_id") + job = sa_session.query(Job).get(job_id) + return cls(sa_session=sa_session, job=job, **io_dict) @classmethod def from_dict(cls, io_dict, sa_session): io_dict.pop("model_class") - job_id = io_dict.pop("job_id") - job = sa_session.query(Job).get(job_id) - return cls(sa_session=sa_session, job=job, **io_dict) + return cls(sa_session=sa_session, **io_dict) def to_dict(self): io_dict = super().to_dict() @@ -165,15 +190,15 @@ def dataset_path_rewriter(self) -> DatasetPathRewriter: @property def output_paths(self) -> OutputPaths: - if self._output_paths is None: + if not self.job_outputs.populated: self.compute_outputs() - return cast(OutputPaths, self._output_paths) + return cast(OutputPaths, self.job_outputs.output_paths) @property def output_hdas_and_paths(self) -> OutputHdasAndType: - if self._output_hdas_and_paths is None: + if not self.job_outputs.populated: self.compute_outputs() - return cast(OutputHdasAndType, self._output_hdas_and_paths) + return cast(OutputHdasAndType, self.job_outputs.output_hdas_and_paths) def get_input_dataset_fnames(self, ds: DatasetInstance) -> List[str]: filenames = [ds.file_name] @@ -241,22 +266,21 @@ def compute_outputs(self) -> None: special = self.sa_session.query(JobExportHistoryArchive).filter_by(job=job).first() false_path = None - results = [] + job_outputs = [] for da in job.output_datasets + job.output_library_datasets: da_false_path = dataset_path_rewriter.rewrite_dataset_path(da.dataset, "output") mutable = da.dataset.dataset.external_filename is None dataset_path = DatasetPath( da.dataset.dataset.id, da.dataset.file_name, false_path=da_false_path, mutable=mutable ) - results.append((da.name, da.dataset, dataset_path)) + job_outputs.append(JobOutput(da.name, da.dataset, dataset_path)) - self._output_paths = [t[2] for t in results] - self._output_hdas_and_paths = {t[0]: t[1:] for t in results} if special: false_path = dataset_path_rewriter.rewrite_dataset_path(special, "output") dsp = DatasetPath(special.dataset.id, special.dataset.file_name, false_path) - self._output_paths.append(dsp) - self._output_hdas_and_paths["output_file"] = (special.fda, dsp) + job_outputs.append(JobOutput("output_file", special.fda, dsp)) + + self.job_outputs.set_job_outputs(job_outputs) def get_output_file_id(self, file: str) -> Optional[int]: for dp in self.output_paths: