Don't store job in JobIO instance attributes
and invalidate `_output_hdas_and_paths` when the current session is not the
same session that was used to populate it.

The Job instance may originate from a session that is associated with
another thread, and when that thread closes the session, the instance
(along with other instances retrieved through loading relationships)
becomes detached.
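
For reference, the failure is easy to reproduce outside Galaxy. A minimal
standalone sketch (toy model and names, not Galaxy code), relying on
SQLAlchemy's default `expire_on_commit=True`:

```
from sqlalchemy import Column, Integer, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Record(Base):
    __tablename__ = "record"
    id = Column(Integer, primary_key=True)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

session = Session(engine)
record = Record()
session.add(record)
session.commit()  # expire_on_commit=True marks all attributes as expired
session.close()   # the instance is now detached

record.id  # raises DetachedInstanceError: refresh cannot proceed
```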

I am not at all sure whether this will fix
```
DetachedInstanceError: Instance <HistoryDatasetAssociation at 0x7fe68bbf14f0> is not bound to a Session; attribute refresh operation cannot proceed (Background on this error at: https://sqlalche.me/e/14/bhk3)
  File "galaxy/jobs/runners/__init__.py", line 291, in prepare_job
    job_wrapper.prepare()
  File "galaxy/jobs/__init__.py", line 1248, in prepare
    tool_evaluator.set_compute_environment(compute_environment, get_special=get_special)
  File "galaxy/tools/evaluation.py", line 162, in set_compute_environment
    self.param_dict = self.build_param_dict(
  File "galaxy/tools/evaluation.py", line 204, in build_param_dict
    self.__populate_output_dataset_wrappers(param_dict, output_datasets, job_working_directory)
  File "galaxy/tools/evaluation.py", line 447, in __populate_output_dataset_wrappers
    param_dict[name] = DatasetFilenameWrapper(
  File "galaxy/tools/wrappers.py", line 403, in __init__
    path_rewrite = compute_environment and compute_environment.output_path_rewrite(dataset_instance)
  File "galaxy/job_execution/compute_environment.py", line 132, in output_path_rewrite
    return str(self.job_io.get_output_path(dataset))
  File "galaxy/job_execution/setup.py", line 226, in get_output_path
    if hda.id == dataset.id:
  File "sqlalchemy/orm/attributes.py", line 487, in __get__
    return self.impl.get(state, dict_)
  File "sqlalchemy/orm/attributes.py", line 959, in get
    value = self._fire_loader_callables(state, key, passive)
  File "sqlalchemy/orm/attributes.py", line 990, in _fire_loader_callables
    return state._load_expired(state, passive)
  File "sqlalchemy/orm/state.py", line 712, in _load_expired
    self.manager.expired_attribute_loader(self, toload, passive)
  File "sqlalchemy/orm/loading.py", line 1369, in load_scalar_attributes
    raise orm_exc.DetachedInstanceError(
```
but it seems to make some sense: JobIO crosses thread boundaries as part
of the job wrapper being put into threading queues.
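
The fix leans on `threading.local` semantics: a subclass's `__init__` runs
again the first time each new thread touches the instance, so every thread
gets its own cache and never sees ORM objects loaded by another thread. A
toy sketch (hypothetical names, not the Galaxy classes):

```
import threading


class PerThreadCache(threading.local):
    def __init__(self) -> None:
        super().__init__()  # re-run per thread on first access
        self.value = None


cache = PerThreadCache()
cache.value = "populated in the main thread"


def worker():
    print(cache.value)  # None: this thread never populated its copy


t = threading.Thread(target=worker)
t.start()
t.join()
print(cache.value)  # still "populated in the main thread"
```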

Ideally we'd make sure that no ORM instance crosses a thread boundary
(or that we systematically re-associate instances with a session).
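
As a sketch of that second option (hedged; `ensure_attached` is a
hypothetical helper, not existing Galaxy code), detached instances could be
re-attached with `Session.merge()` before use:

```
from sqlalchemy import inspect


def ensure_attached(instance, session):
    """Return an instance bound to `session`, merging it if detached."""
    if inspect(instance).detached:
        # merge() returns a copy attached to this session; callers must
        # use the returned object, not the original detached one
        return session.merge(instance, load=True)
    return instance
```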

I also tried flagging these patterns automatically using something like:

```
from sqlalchemy import event

# `session` here is the session (or scoped session) being listened to
@event.listens_for(session, "persistent_to_detached")
def on_detach(sess, instance):
    if not getattr(instance, "allow_detach", False):
        raise Exception(f"{instance} detached. This ain't good for how we do things?")
```

but it seems tricky to figure out when this is fine and when it is not.
mvdbeek committed Nov 2, 2023
1 parent 12b20c6 commit 55e8609
Showing 1 changed file with 41 additions and 17 deletions.
lib/galaxy/job_execution/setup.py
@@ -1,11 +1,13 @@
 """Utilities to help job and tool code setup jobs."""
 import json
 import os
+import threading
 from typing import (
     Any,
     cast,
     Dict,
     List,
+    NamedTuple,
     Optional,
     Tuple,
     Union,
@@ -38,6 +40,27 @@
 OutputPaths = List[DatasetPath]
 
 
+class JobOutput(NamedTuple):
+    output_name: str
+    dataset: DatasetInstance
+    dataset_path: DatasetPath
+
+
+class JobOutputs(threading.local):
+    def __init__(self) -> None:
+        super().__init__()
+        self.output_hdas_and_paths: Optional[OutputHdasAndType] = None
+        self.output_paths: Optional[OutputPaths] = None
+
+    @property
+    def populated(self) -> bool:
+        return self.output_hdas_and_paths is not None
+
+    def set_job_outputs(self, job_outputs: List[JobOutput]) -> None:
+        self.output_paths = [t[2] for t in job_outputs]
+        self.output_hdas_and_paths = {t.output_name: (t.dataset, t.dataset_path) for t in job_outputs}
+
+
 class JobIO(Dictifiable):
     dict_collection_visible_keys = (
         "job_id",
@@ -99,7 +122,6 @@ def __init__(
             user_context_instance = user_context
         self.user_context = user_context_instance
         self.sa_session = sa_session
-        self.job = job
         self.job_id = job.id
         self.working_directory = working_directory
         self.outputs_directory = outputs_directory
@@ -121,22 +143,25 @@ def __init__(
         self.is_task = is_task
         self.tool_source = tool_source
         self.tool_source_class = tool_source_class
-        self._output_paths: Optional[OutputPaths] = None
-        self._output_hdas_and_paths: Optional[OutputHdasAndType] = None
+        self.job_outputs = JobOutputs()
         self._dataset_path_rewriter: Optional[DatasetPathRewriter] = None
 
+    @property
+    def job(self):
+        return self.sa_session.query(Job).get(self.job_id)
+
     @classmethod
     def from_json(cls, path, sa_session):
         with open(path) as job_io_serialized:
             io_dict = json.load(job_io_serialized)
-        return cls.from_dict(io_dict=io_dict, sa_session=sa_session)
+        job_id = io_dict.pop("job_id")
+        job = sa_session.query(Job).get(job_id)
+        return cls(sa_session=sa_session, job=job, **io_dict)
 
     @classmethod
     def from_dict(cls, io_dict, sa_session):
         io_dict.pop("model_class")
-        job_id = io_dict.pop("job_id")
-        job = sa_session.query(Job).get(job_id)
-        return cls(sa_session=sa_session, job=job, **io_dict)
+        return cls(sa_session=sa_session, **io_dict)
 
     def to_dict(self):
         io_dict = super().to_dict()
@@ -165,15 +190,15 @@ def dataset_path_rewriter(self) -> DatasetPathRewriter:
 
     @property
     def output_paths(self) -> OutputPaths:
-        if self._output_paths is None:
+        if not self.job_outputs.populated:
             self.compute_outputs()
-        return cast(OutputPaths, self._output_paths)
+        return cast(OutputPaths, self.job_outputs.output_paths)
 
     @property
     def output_hdas_and_paths(self) -> OutputHdasAndType:
-        if self._output_hdas_and_paths is None:
+        if not self.job_outputs.populated:
             self.compute_outputs()
-        return cast(OutputHdasAndType, self._output_hdas_and_paths)
+        return cast(OutputHdasAndType, self.job_outputs.output_hdas_and_paths)
 
     def get_input_dataset_fnames(self, ds: DatasetInstance) -> List[str]:
         filenames = [ds.file_name]
@@ -241,22 +266,21 @@ def compute_outputs(self) -> None:
         special = self.sa_session.query(JobExportHistoryArchive).filter_by(job=job).first()
         false_path = None
 
-        results = []
+        job_outputs = []
         for da in job.output_datasets + job.output_library_datasets:
             da_false_path = dataset_path_rewriter.rewrite_dataset_path(da.dataset, "output")
             mutable = da.dataset.dataset.external_filename is None
             dataset_path = DatasetPath(
                 da.dataset.dataset.id, da.dataset.file_name, false_path=da_false_path, mutable=mutable
             )
-            results.append((da.name, da.dataset, dataset_path))
+            job_outputs.append(JobOutput(da.name, da.dataset, dataset_path))
 
-        self._output_paths = [t[2] for t in results]
-        self._output_hdas_and_paths = {t[0]: t[1:] for t in results}
         if special:
             false_path = dataset_path_rewriter.rewrite_dataset_path(special, "output")
             dsp = DatasetPath(special.dataset.id, special.dataset.file_name, false_path)
-            self._output_paths.append(dsp)
-            self._output_hdas_and_paths["output_file"] = (special.fda, dsp)
+            job_outputs.append(JobOutput("output_file", special.fda, dsp))
+
+        self.job_outputs.set_job_outputs(job_outputs)
 
     def get_output_file_id(self, file: str) -> Optional[int]:
         for dp in self.output_paths:
