Merge pull request #16852 from jdavcs/dev_sa20_fix20
SQLAlchemy 2.0 upgrades (part 4)
mvdbeek authored Nov 10, 2023
2 parents a1b5485 + 2fa96df commit 1956639
Showing 23 changed files with 547 additions and 483 deletions.
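
The recurring change in this part of the SQLAlchemy 2.0 migration is the replacement of the legacy Query API (session.query(Model).get(id), .filter(...)) with Session.get() for primary-key lookups and select() statements executed through Session.scalars(). Below is a minimal, self-contained sketch of both patterns; the Widget model and the in-memory SQLite engine are placeholders for illustration, not Galaxy code.

```python
from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Widget(Base):  # stand-in model, not part of Galaxy
    __tablename__ = "widget"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(30))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Widget(name="example"))
    session.commit()

    # Legacy 1.x style removed throughout this commit:
    #   widget = session.query(Widget).get(1)
    #   widgets = session.query(Widget).filter(Widget.name == "example").all()

    # 2.0 style used in the diff: Session.get() for primary-key lookups ...
    widget = session.get(Widget, 1)

    # ... and select() executed via Session.scalars() for everything else.
    stmt = select(Widget).where(Widget.name == "example")
    widgets = session.scalars(stmt).all()
```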
11 changes: 0 additions & 11 deletions lib/galaxy/managers/base.py
@@ -323,17 +323,6 @@ def _one_with_recast_errors(self, query: Query) -> Query:
except sqlalchemy.orm.exc.MultipleResultsFound:
raise exceptions.InconsistentDatabase(f"found more than one {self.model_class.__name__}")

def _one_or_none(self, query):
"""
Return the object if found, None if it's not.
:raises exceptions.InconsistentDatabase: if more than one model is found
"""
try:
return self._one_with_recast_errors(query)
except exceptions.ObjectNotFound:
return None

# NOTE: at this layer, all ids are expected to be decoded and in int form
def by_id(self, id: int) -> Query:
"""
14 changes: 5 additions & 9 deletions lib/galaxy/managers/collections.py
@@ -617,7 +617,7 @@ def __load_element(self, trans, element_identifier, hide_source_items, copy_elem
if the_object is not None and the_object.id:
context = self.model.context
if the_object not in context:
the_object = context.query(type(the_object)).get(the_object.id)
the_object = context.get(type(the_object), the_object.id)
return the_object

# dataset_identifier is dict {src=hda|ldda|hdca|new_collection, id=<encoded_id>}
@@ -693,7 +693,7 @@ def get_dataset_collection_instance(

def get_dataset_collection(self, trans, encoded_id):
collection_id = int(trans.app.security.decode_id(encoded_id))
collection = trans.sa_session.query(trans.app.model.DatasetCollection).get(collection_id)
collection = trans.sa_session.get(trans.app.model.DatasetCollection, collection_id)
return collection

def apply_rules(self, hdca, rule_set, handle_dataset):
@@ -800,9 +800,7 @@ def __get_history_collection_instance(
self, trans: ProvidesHistoryContext, id, check_ownership=False, check_accessible=True
) -> model.HistoryDatasetCollectionAssociation:
instance_id = trans.app.security.decode_id(id) if isinstance(id, str) else id
collection_instance = trans.sa_session.query(trans.app.model.HistoryDatasetCollectionAssociation).get(
instance_id
)
collection_instance = trans.sa_session.get(trans.app.model.HistoryDatasetCollectionAssociation, instance_id)
if not collection_instance:
raise RequestParameterInvalidException("History dataset collection association not found")
# TODO: that sure looks like a bug, we can't check ownership using the history of the object we're checking ownership for ...
@@ -823,9 +821,7 @@ def __get_library_collection_instance(
"Functionality (getting library dataset collection with ownership check) unimplemented."
)
instance_id = int(trans.security.decode_id(id))
collection_instance = trans.sa_session.query(trans.app.model.LibraryDatasetCollectionAssociation).get(
instance_id
)
collection_instance = trans.sa_session.get(trans.app.model.LibraryDatasetCollectionAssociation, instance_id)
if not collection_instance:
raise RequestParameterInvalidException("Library dataset collection association not found")
if check_accessible:
@@ -861,6 +857,6 @@ def write_dataset_collection(self, request: PrepareDatasetCollectionDownload):
short_term_storage_monitor = self.short_term_storage_monitor
instance_id = request.history_dataset_collection_association_id
with storage_context(request.short_term_storage_request_id, short_term_storage_monitor) as target:
collection_instance = self.model.context.query(model.HistoryDatasetCollectionAssociation).get(instance_id)
collection_instance = self.model.context.get(model.HistoryDatasetCollectionAssociation, instance_id)
with ZipFile(target.path, "w") as zip_f:
write_dataset_collection(collection_instance, zip_f)
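
The collections.py hunks above all collapse session.query(Model).get(id) into session.get(Model, id); like the old call, it returns the instance or None when the row is missing, so the existing "if not collection_instance:" checks still apply unchanged. The same call also covers the re-attach idiom in __load_element: when an instance is not tracked by the current session, it is re-fetched by its class and primary key. A rough sketch of that idiom, written against a generic mapped object rather than Galaxy's models:

```python
from typing import Optional, TypeVar

from sqlalchemy.orm import Session

T = TypeVar("T")


def reload_if_detached(session: Session, obj: T) -> Optional[T]:
    """Return a session-bound copy of obj, re-fetching it by primary key if needed."""
    # Mirrors the __load_element change: `obj in session` is True only when the
    # instance is tracked by this Session's identity map.
    if obj not in session:
        # Assumes a single integer primary key named `id`, as Galaxy's models have.
        return session.get(type(obj), obj.id)  # None if the row no longer exists
    return obj
```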
20 changes: 12 additions & 8 deletions lib/galaxy/managers/context.py
@@ -44,6 +44,8 @@
Optional,
)

from sqlalchemy import select

from galaxy.exceptions import (
AuthenticationRequired,
UserActivationRequiredException,
@@ -307,15 +309,8 @@ def db_dataset_for(self, dbkey) -> Optional[HistoryDatasetAssociation]:
return None
non_ready_or_ok = set(Dataset.non_ready_states)
non_ready_or_ok.add(HistoryDatasetAssociation.states.OK)
datasets = (
self.sa_session.query(HistoryDatasetAssociation)
.filter_by(deleted=False, history_id=self.history.id, extension="len")
.filter(
HistoryDatasetAssociation.table.c._state.in_(non_ready_or_ok),
)
)
valid_ds = None
for ds in datasets:
for ds in get_hdas(self.sa_session, self.history.id, non_ready_or_ok):
if ds.dbkey == dbkey:
if ds.state == HistoryDatasetAssociation.states.OK:
return ds
@@ -330,3 +325,12 @@ def db_builds(self):
"""
# FIXME: This method should be removed
return self.app.genome_builds.get_genome_build_names(trans=self)


def get_hdas(session, history_id, states):
stmt = (
select(HistoryDatasetAssociation)
.filter_by(deleted=False, history_id=history_id, extension="len")
.where(HistoryDatasetAssociation._state.in_(states))
)
return session.scalars(stmt)
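
The new module-level get_hdas helper returns the result of Session.scalars(), a ScalarResult that callers iterate exactly like the old Query (as db_dataset_for does above). A standalone sketch of that select()/filter_by()/where() composition against a throwaway model; the columns here are placeholders, not HistoryDatasetAssociation's real schema.

```python
from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Doc(Base):  # placeholder model for illustration only
    __tablename__ = "doc"
    id: Mapped[int] = mapped_column(primary_key=True)
    history_id: Mapped[int]
    deleted: Mapped[bool] = mapped_column(default=False)
    state: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all(
        [
            Doc(history_id=1, state="ok"),
            Doc(history_id=1, state="error"),
            Doc(history_id=2, state="ok"),
        ]
    )
    session.commit()

    # filter_by() and where() compose on a select() just as they did on Query.
    stmt = (
        select(Doc)
        .filter_by(deleted=False, history_id=1)
        .where(Doc.state.in_(["ok", "queued"]))
    )
    for doc in session.scalars(stmt):  # ScalarResult is directly iterable
        print(doc.id, doc.state)
```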
32 changes: 20 additions & 12 deletions lib/galaxy/managers/datasets.py
@@ -13,6 +13,8 @@
TypeVar,
)

from sqlalchemy import select

from galaxy import (
exceptions,
model,
@@ -25,6 +27,10 @@
secured,
users,
)
from galaxy.model import (
Dataset,
DatasetHash,
)
from galaxy.model.base import transaction
from galaxy.schema.tasks import (
ComputeDatasetHashTaskRequest,
@@ -103,7 +109,7 @@ def purge_datasets(self, request: PurgeDatasetsTaskRequest):
self.error_unless_dataset_purge_allowed()
with self.session().begin():
for dataset_id in request.dataset_ids:
dataset: model.Dataset = self.session().query(model.Dataset).get(dataset_id)
dataset: Dataset = self.session().get(Dataset, dataset_id)
if dataset.user_can_purge:
try:
dataset.full_delete()
@@ -158,15 +164,7 @@ def compute_hash(self, request: ComputeDatasetHashTaskRequest):
# TODO: replace/update if the combination of dataset_id/hash_function has already
# been stored.
sa_session = self.session()
hash = (
sa_session.query(model.DatasetHash)
.filter(
model.DatasetHash.dataset_id == dataset.id,
model.DatasetHash.hash_function == hash_function,
model.DatasetHash.extra_files_path == extra_files_path,
)
.one_or_none()
)
hash = get_dataset_hash(sa_session, dataset.id, hash_function, extra_files_path)
if hash is None:
sa_session.add(dataset_hash)
with transaction(sa_session):
@@ -477,7 +475,7 @@ def ensure_can_set_metadata(self, dataset: model.DatasetInstance, raiseException

def detect_datatype(self, trans, dataset_assoc):
"""Sniff and assign the datatype to a given dataset association (ldda or hda)"""
data = trans.sa_session.query(self.model_class).get(dataset_assoc.id)
data = trans.sa_session.get(self.model_class, dataset_assoc.id)
self.ensure_can_change_datatype(data)
self.ensure_can_set_metadata(data)
path = data.dataset.get_file_name()
@@ -489,7 +487,7 @@ def detect_datatype(self, trans, dataset_assoc):

def set_metadata(self, trans, dataset_assoc, overwrite=False, validate=True):
"""Trigger a job that detects and sets metadata on a given dataset association (ldda or hda)"""
data = trans.sa_session.query(self.model_class).get(dataset_assoc.id)
data = trans.sa_session.get(self.model_class, dataset_assoc.id)
self.ensure_can_set_metadata(data)
if overwrite:
self.overwrite_metadata(data)
@@ -874,3 +872,13 @@ def isinstance_datatype(self, dataset_assoc, class_strs):
if datatype_class:
comparison_classes.append(datatype_class)
return comparison_classes and isinstance(dataset_assoc.datatype, tuple(comparison_classes))


def get_dataset_hash(session, dataset_id, hash_function, extra_files_path):
stmt = (
select(DatasetHash)
.where(DatasetHash.dataset_id == dataset_id)
.where(DatasetHash.hash_function == hash_function)
.where(DatasetHash.extra_files_path == extra_files_path)
)
return session.scalars(stmt).one_or_none()
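
The get_dataset_hash helper added here ends with scalars(stmt).one_or_none(), which keeps the contract of the old Query.one_or_none(): None when no row matches, the object when exactly one does, and MultipleResultsFound when more than one does. A brief sketch of that behavior, again using a toy model rather than Galaxy's DatasetHash:

```python
from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Hash(Base):  # toy model, not Galaxy's DatasetHash
    __tablename__ = "hash"
    id: Mapped[int] = mapped_column(primary_key=True)
    dataset_id: Mapped[int]
    hash_function: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Hash(dataset_id=1, hash_function="sha256"))
    session.commit()

    stmt = (
        select(Hash)
        .where(Hash.dataset_id == 1)
        .where(Hash.hash_function == "sha256")
    )
    found = session.scalars(stmt).one_or_none()  # the single matching row
    missing = session.scalars(
        select(Hash).where(Hash.dataset_id == 99)
    ).one_or_none()  # None, no exception raised
    assert found is not None and missing is None
    # With more than one matching row, one_or_none() raises
    # sqlalchemy.exc.MultipleResultsFound, mirroring Query.one_or_none().
```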