From 37e97080f21ac78f709a98be16d498fdf21d47b7 Mon Sep 17 00:00:00 2001 From: John Davis Date: Thu, 10 Aug 2023 15:03:56 -0400 Subject: [PATCH] Fix SA2.0 (query->select) in galaxy.tools --- lib/galaxy/tools/__init__.py | 18 ++++--- lib/galaxy/tools/actions/__init__.py | 4 +- lib/galaxy/tools/actions/upload_common.py | 16 +++--- lib/galaxy/tools/errors.py | 4 +- lib/galaxy/tools/imp_exp/__init__.py | 5 +- lib/galaxy/tools/parameters/basic.py | 61 +++++++++++------------ lib/galaxy/tools/parameters/meta.py | 2 +- 7 files changed, 59 insertions(+), 51 deletions(-) diff --git a/lib/galaxy/tools/__init__.py b/lib/galaxy/tools/__init__.py index 53e8f1e0c4cb..6c0ce878bd55 100644 --- a/lib/galaxy/tools/__init__.py +++ b/lib/galaxy/tools/__init__.py @@ -30,6 +30,11 @@ from lxml import etree from mako.template import Template from packaging.version import Version +from sqlalchemy import ( + delete, + func, + select, +) from galaxy import ( exceptions, @@ -346,9 +351,9 @@ def __init__(self, app): def reset_tags(self): log.info( - f"removing all tool tag associations ({str(self.sa_session.query(self.app.model.ToolTagAssociation).count())})" + f"removing all tool tag associations ({str(self.sa_session.scalar(select(func.count(self.app.model.ToolTagAssociation))))})" ) - self.sa_session.query(self.app.model.ToolTagAssociation).delete() + self.sa_session.execute(delete(self.app.model.ToolTagAssociation)) with transaction(self.sa_session): self.sa_session.commit() @@ -359,7 +364,8 @@ def handle_tags(self, tool_id, tool_definition_source): for tag_name in tag_names: if tag_name == "": continue - tag = self.sa_session.query(self.app.model.Tag).filter_by(name=tag_name).first() + stmt = select(self.app.model.Tag).filter_by(name=tag_name).limit(1) + tag = self.sa_session.scalars(stmt).first() if not tag: tag = self.app.model.Tag(name=tag_name) self.sa_session.add(tag) @@ -618,7 +624,7 @@ def _load_workflow(self, workflow_id): which is encoded in the tool panel. """ id = self.app.security.decode_id(workflow_id) - stored = self.app.model.context.query(self.app.model.StoredWorkflow).get(id) + stored = self.app.model.context.get(self.app.model.StoredWorkflow, id) return stored.latest_workflow def __build_tool_version_select_field(self, tools, tool_id, set_selected): @@ -3121,7 +3127,7 @@ def exec_after_process(self, app, inp_data, out_data, param_dict, job=None, **kw self.sa_session.commit() def job_failed(self, job_wrapper, message, exception=False): - job = job_wrapper.sa_session.query(model.Job).get(job_wrapper.job_id) + job = job_wrapper.sa_session.get(model.Job, job_wrapper.job_id) if job: inp_data = {} for dataset_assoc in job.input_datasets: @@ -3168,7 +3174,7 @@ def exec_after_process(self, app, inp_data, out_data, param_dict, job=None, **kw def job_failed(self, job_wrapper, message, exception=False): super().job_failed(job_wrapper, message, exception=exception) - job = job_wrapper.sa_session.query(model.Job).get(job_wrapper.job_id) + job = job_wrapper.sa_session.get(model.Job, job_wrapper.job_id) self.__remove_interactivetool_by_job(job) diff --git a/lib/galaxy/tools/actions/__init__.py b/lib/galaxy/tools/actions/__init__.py index 934604a8ef8e..e6259d9a33df 100644 --- a/lib/galaxy/tools/actions/__init__.py +++ b/lib/galaxy/tools/actions/__init__.py @@ -481,7 +481,7 @@ def handle_output(name, output, hidden=None): if async_tool and name in incoming: # HACK: output data has already been created as a result of the async controller dataid = incoming[name] - data = trans.sa_session.query(app.model.HistoryDatasetAssociation).get(dataid) + data = trans.sa_session.get(app.model.HistoryDatasetAssociation, dataid) assert data is not None out_data[name] = data else: @@ -745,7 +745,7 @@ def _remap_job_on_rerun(self, trans, galaxy_session, rerun_remap_job_id, current input datasets to be those of the job that is being rerun. """ try: - old_job = trans.sa_session.query(trans.app.model.Job).get(rerun_remap_job_id) + old_job = trans.sa_session.get(trans.app.model.Job, rerun_remap_job_id) assert old_job is not None, f"({rerun_remap_job_id}/{current_job.id}): Old job id is invalid" assert ( old_job.tool_id == current_job.tool_id diff --git a/lib/galaxy/tools/actions/upload_common.py b/lib/galaxy/tools/actions/upload_common.py index 14e4861a9ad8..b0a4c691477d 100644 --- a/lib/galaxy/tools/actions/upload_common.py +++ b/lib/galaxy/tools/actions/upload_common.py @@ -13,6 +13,7 @@ Optional, ) +from sqlalchemy import select from sqlalchemy.orm import joinedload from webob.compat import cgi_FieldStorage @@ -94,12 +95,12 @@ def handle_library_params( # See if we have any template field contents template_field_contents = {} template_id = params.get("template_id", None) - folder = trans.sa_session.query(LibraryFolder).get(folder_id) + folder = trans.sa_session.get(LibraryFolder, folder_id) # We are inheriting the folder's info_association, so we may have received inherited contents or we may have redirected # here after the user entered template contents ( due to errors ). template: Optional[FormDefinition] = None if template_id not in [None, "None"]: - template = trans.sa_session.query(FormDefinition).get(template_id) + template = trans.sa_session.get(FormDefinition, template_id) assert template for field in template.fields: field_name = field["name"] @@ -108,7 +109,7 @@ def handle_library_params( template_field_contents[field_name] = field_value roles: List[Role] = [] for role_id in util.listify(params.get("roles", [])): - role = trans.sa_session.query(Role).get(role_id) + role = trans.sa_session.get(Role, role_id) roles.append(role) tags = params.get("tags", None) return LibraryParams( @@ -436,10 +437,11 @@ def active_folders(trans, folder): # Stolen from galaxy.web.controllers.library_common (importing from which causes a circular issues). # Much faster way of retrieving all active sub-folders within a given folder than the # performance of the mapper. This query also eagerloads the permissions on each folder. - return ( - trans.sa_session.query(LibraryFolder) + stmt = ( + select(LibraryFolder) .filter_by(parent=folder, deleted=False) .options(joinedload(LibraryFolder.actions)) - .order_by(LibraryFolder.table.c.name) - .all() + .unique() + .order_by(LibraryFolder.name) ) + return trans.sa_session.scalars(stmt).all() diff --git a/lib/galaxy/tools/errors.py b/lib/galaxy/tools/errors.py index 26aa0e017af9..a2ee14a69ddf 100644 --- a/lib/galaxy/tools/errors.py +++ b/lib/galaxy/tools/errors.py @@ -137,10 +137,10 @@ def __init__(self, hda, app): if not isinstance(hda, model.HistoryDatasetAssociation): hda_id = hda try: - hda = sa_session.query(model.HistoryDatasetAssociation).get(hda_id) + hda = sa_session.get(model.HistoryDatasetAssociation, hda_id) assert hda is not None, ValueError("No HDA yet") except Exception: - hda = sa_session.query(model.HistoryDatasetAssociation).get(app.security.decode_id(hda_id)) + hda = sa_session.get(model.HistoryDatasetAssociation, app.security.decode_id(hda_id)) assert isinstance(hda, model.HistoryDatasetAssociation), ValueError(f"Bad value provided for HDA ({hda}).") self.hda = hda # Get the associated job diff --git a/lib/galaxy/tools/imp_exp/__init__.py b/lib/galaxy/tools/imp_exp/__init__.py index 4c8c53a91141..a1338d8c453b 100644 --- a/lib/galaxy/tools/imp_exp/__init__.py +++ b/lib/galaxy/tools/imp_exp/__init__.py @@ -4,6 +4,8 @@ import shutil from typing import Optional +from sqlalchemy import select + from galaxy import model from galaxy.model import store from galaxy.model.base import transaction @@ -49,7 +51,8 @@ def cleanup_after_job(self): # Import history. # - jiha = self.sa_session.query(model.JobImportHistoryArchive).filter_by(job_id=self.job_id).first() + stmt = select(model.JobImportHistoryArchive).filter_by(job_id=self.job_id).limit(1) + jiha = self.sa_session.scalars(stmt).first() if not jiha: return None user = jiha.job.user diff --git a/lib/galaxy/tools/parameters/basic.py b/lib/galaxy/tools/parameters/basic.py index 6de845c6fd32..969551cf3389 100644 --- a/lib/galaxy/tools/parameters/basic.py +++ b/lib/galaxy/tools/parameters/basic.py @@ -1943,13 +1943,13 @@ def single_to_python(value): if isinstance(value, dict) and "src" in value: id = value["id"] if isinstance(value["id"], int) else app.security.decode_id(value["id"]) if value["src"] == "dce": - return app.model.context.query(DatasetCollectionElement).get(id) + return app.model.context.get(DatasetCollectionElement, id) elif value["src"] == "hdca": - return app.model.context.query(HistoryDatasetCollectionAssociation).get(id) + return app.model.context.get(HistoryDatasetCollectionAssociation, id) elif value["src"] == "ldda": - return app.model.context.query(LibraryDatasetDatasetAssociation).get(id) + return app.model.context.get(LibraryDatasetDatasetAssociation, id) else: - return app.model.context.query(HistoryDatasetAssociation).get(id) + return app.model.context.get(HistoryDatasetAssociation, id) if isinstance(value, dict) and "values" in value: if hasattr(self, "multiple") and self.multiple is True: @@ -1963,7 +1963,7 @@ def single_to_python(value): return None if isinstance(value, str) and value.find(",") > -1: return [ - app.model.context.query(HistoryDatasetAssociation).get(int(v)) + app.model.context.get(HistoryDatasetAssociation, int(v)) for v in value.split(",") if v not in none_values ] @@ -1971,13 +1971,13 @@ def single_to_python(value): decoded_id = str(value)[len("__collection_reduce__|") :] if not decoded_id.isdigit(): decoded_id = app.security.decode_id(decoded_id) - return app.model.context.query(HistoryDatasetCollectionAssociation).get(int(decoded_id)) + return app.model.context.get(HistoryDatasetCollectionAssociation, int(decoded_id)) elif str(value).startswith("dce:"): - return app.model.context.query(DatasetCollectionElement).get(int(value[len("dce:") :])) + return app.model.context.get(DatasetCollectionElement, int(value[len("dce:") :])) elif str(value).startswith("hdca:"): - return app.model.context.query(HistoryDatasetCollectionAssociation).get(int(value[len("hdca:") :])) + return app.model.context.get(HistoryDatasetCollectionAssociation, int(value[len("hdca:") :])) else: - return app.model.context.query(HistoryDatasetAssociation).get(int(value)) + return app.model.context.get(HistoryDatasetAssociation, int(value)) def validate(self, value, trans=None): def do_validate(v): @@ -2097,17 +2097,17 @@ def from_json(self, value, trans, other_values=None): if isinstance(single_value, dict) and "src" in single_value and "id" in single_value: if single_value["src"] == "hda": decoded_id = trans.security.decode_id(single_value["id"]) - rval.append(trans.sa_session.query(HistoryDatasetAssociation).get(decoded_id)) + rval.append(trans.sa_session.get(HistoryDatasetAssociation, decoded_id)) elif single_value["src"] == "hdca": found_hdca = True decoded_id = trans.security.decode_id(single_value["id"]) - rval.append(trans.sa_session.query(HistoryDatasetCollectionAssociation).get(decoded_id)) + rval.append(trans.sa_session.get(HistoryDatasetCollectionAssociation, decoded_id)) elif single_value["src"] == "ldda": decoded_id = trans.security.decode_id(single_value["id"]) - rval.append(trans.sa_session.query(LibraryDatasetDatasetAssociation).get(decoded_id)) + rval.append(trans.sa_session.get(LibraryDatasetDatasetAssociation, decoded_id)) elif single_value["src"] == "dce": decoded_id = trans.security.decode_id(single_value["id"]) - rval.append(trans.sa_session.query(DatasetCollectionElement).get(decoded_id)) + rval.append(trans.sa_session.get(DatasetCollectionElement, decoded_id)) else: raise ValueError(f"Unknown input source {single_value['src']} passed to job submission API.") elif isinstance( @@ -2126,7 +2126,7 @@ def from_json(self, value, trans, other_values=None): # support that for integer column types. log.warning("Encoded ID where unencoded ID expected.") single_value = trans.security.decode_id(single_value) - rval.append(trans.sa_session.query(HistoryDatasetAssociation).get(single_value)) + rval.append(trans.sa_session.get(HistoryDatasetAssociation, single_value)) if found_hdca: for val in rval: if not isinstance(val, HistoryDatasetCollectionAssociation): @@ -2139,13 +2139,13 @@ def from_json(self, value, trans, other_values=None): elif isinstance(value, dict) and "src" in value and "id" in value: if value["src"] == "hda": decoded_id = trans.security.decode_id(value["id"]) - rval.append(trans.sa_session.query(HistoryDatasetAssociation).get(decoded_id)) + rval.append(trans.sa_session.get(HistoryDatasetAssociation, decoded_id)) elif value["src"] == "hdca": decoded_id = trans.security.decode_id(value["id"]) - rval.append(trans.sa_session.query(HistoryDatasetCollectionAssociation).get(decoded_id)) + rval.append(trans.sa_session.get(HistoryDatasetCollectionAssociation, decoded_id)) elif value["src"] == "dce": decoded_id = trans.security.decode_id(value["id"]) - rval.append(trans.sa_session.query(DatasetCollectionElement).get(decoded_id)) + rval.append(trans.sa_session.get(DatasetCollectionElement, decoded_id)) else: raise ValueError(f"Unknown input source {value['src']} passed to job submission API.") elif str(value).startswith("__collection_reduce__|"): @@ -2153,12 +2153,12 @@ def from_json(self, value, trans, other_values=None): decoded_ids = map(trans.security.decode_id, encoded_ids) rval = [] for decoded_id in decoded_ids: - hdca = trans.sa_session.query(HistoryDatasetCollectionAssociation).get(decoded_id) + hdca = trans.sa_session.get(HistoryDatasetCollectionAssociation, decoded_id) rval.append(hdca) elif isinstance(value, HistoryDatasetCollectionAssociation) or isinstance(value, DatasetCollectionElement): rval.append(value) else: - rval.append(trans.sa_session.query(HistoryDatasetAssociation).get(value)) + rval.append(trans.sa_session.get(HistoryDatasetAssociation, value)) dataset_matcher_factory = get_dataset_matcher_factory(trans) dataset_matcher = dataset_matcher_factory.dataset_matcher(self, other_values) for v in rval: @@ -2443,28 +2443,24 @@ def from_json(self, value, trans, other_values=None): rval = value elif isinstance(value, dict) and "src" in value and "id" in value: if value["src"] == "hdca": - rval = trans.sa_session.query(HistoryDatasetCollectionAssociation).get( - trans.security.decode_id(value["id"]) - ) + rval = trans.sa_session.get(HistoryDatasetCollectionAssociation, trans.security.decode_id(value["id"])) elif isinstance(value, list): if len(value) > 0: value = value[0] if isinstance(value, dict) and "src" in value and "id" in value: if value["src"] == "hdca": - rval = trans.sa_session.query(HistoryDatasetCollectionAssociation).get( - trans.security.decode_id(value["id"]) + rval = trans.sa_session.get( + HistoryDatasetCollectionAssociation, trans.security.decode_id(value["id"]) ) elif value["src"] == "dce": - rval = trans.sa_session.query(DatasetCollectionElement).get( - trans.security.decode_id(value["id"]) - ) + rval = trans.sa_session.get(DatasetCollectionElement, trans.security.decode_id(value["id"])) elif isinstance(value, str): if value.startswith("dce:"): - rval = trans.sa_session.query(DatasetCollectionElement).get(value[len("dce:") :]) + rval = trans.sa_session.get(DatasetCollectionElement, value[len("dce:") :]) elif value.startswith("hdca:"): - rval = trans.sa_session.query(HistoryDatasetCollectionAssociation).get(value[len("hdca:") :]) + rval = trans.sa_session.get(HistoryDatasetCollectionAssociation, value[len("hdca:") :]) else: - rval = trans.sa_session.query(HistoryDatasetCollectionAssociation).get(value) + rval = trans.sa_session.get(HistoryDatasetCollectionAssociation, value) if rval and isinstance(rval, HistoryDatasetCollectionAssociation): if rval.deleted: raise ParameterValueError("the previously selected dataset collection has been deleted", self.name) @@ -2634,8 +2630,9 @@ def to_python(self, value, app, other_values=None, validate=False): else: lst = [] break - lda = app.model.context.query(LibraryDatasetDatasetAssociation).get( - lda_id if isinstance(lda_id, int) else app.security.decode_id(lda_id) + lda = app.model.context.get( + LibraryDatasetDatasetAssociation, + lda_id if isinstance(lda_id, int) else app.security.decode_id(lda_id), ) if lda is not None: lst.append(lda) diff --git a/lib/galaxy/tools/parameters/meta.py b/lib/galaxy/tools/parameters/meta.py index a9f47ac0eae0..d11bfda71a5a 100644 --- a/lib/galaxy/tools/parameters/meta.py +++ b/lib/galaxy/tools/parameters/meta.py @@ -265,7 +265,7 @@ def __expand_collection_parameter(trans, input_key, incoming_val, collections_to encoded_hdc_id = incoming_val subcollection_type = None hdc_id = trans.app.security.decode_id(encoded_hdc_id) - hdc = trans.sa_session.query(model.HistoryDatasetCollectionAssociation).get(hdc_id) + hdc = trans.sa_session.get(model.HistoryDatasetCollectionAssociation, hdc_id) collections_to_match.add(input_key, hdc, subcollection_type=subcollection_type, linked=linked) if subcollection_type is not None: subcollection_elements = subcollections.split_dataset_collection_instance(hdc, subcollection_type)