From 0e4df328e21f52e61f7dc698ac408215ba97292d Mon Sep 17 00:00:00 2001 From: John Davis Date: Wed, 20 Sep 2023 15:25:28 -0400 Subject: [PATCH] Fix SA2.0 ORM usage in managers.groups Move data access method to managers.users --- lib/galaxy/managers/groups.py | 34 ++++++++++++++++++++++++++-------- lib/galaxy/managers/users.py | 11 ++++++++++- 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/lib/galaxy/managers/groups.py b/lib/galaxy/managers/groups.py index 2fc03d42738a..374428c979f1 100644 --- a/lib/galaxy/managers/groups.py +++ b/lib/galaxy/managers/groups.py @@ -4,7 +4,11 @@ List, ) -from sqlalchemy import false +from sqlalchemy import ( + false, + select, +) +from sqlalchemy.orm import Session from galaxy import model from galaxy.exceptions import ( @@ -14,6 +18,11 @@ ) from galaxy.managers.base import decode_id from galaxy.managers.context import ProvidesAppContext +from galaxy.managers.users import get_users_by_ids +from galaxy.model import ( + Group, + Role, +) from galaxy.model.base import transaction from galaxy.model.scoped_session import galaxy_scoped_session from galaxy.schema.fields import ( @@ -35,7 +44,7 @@ def index(self, trans: ProvidesAppContext): Displays a collection (list) of groups. """ rval = [] - for group in trans.sa_session.query(model.Group).filter(model.Group.deleted == false()): + for group in get_not_deleted_groups(trans.sa_session): item = group.to_dict(value_mapper={"id": DecodedDatabaseIdField.encode}) encoded_id = DecodedDatabaseIdField.encode(group.id) item["url"] = url_for("group", id=encoded_id) @@ -101,11 +110,11 @@ def update(self, trans: ProvidesAppContext, group_id: int, payload: Dict[str, An sa_session.commit() def _check_duplicated_group_name(self, sa_session: galaxy_scoped_session, group_name: str) -> None: - if sa_session.query(model.Group).filter(model.Group.name == group_name).first(): + if get_group_by_name(sa_session, group_name): raise Conflict(f"A group with name '{group_name}' already exists") def _get_group(self, sa_session: galaxy_scoped_session, group_id: int) -> model.Group: - group = sa_session.query(model.Group).get(group_id) + group = sa_session.get(model.Group, group_id) if group is None: raise ObjectNotFound("Group with the provided id was not found.") return group @@ -114,18 +123,27 @@ def _get_users_by_encoded_ids( self, sa_session: galaxy_scoped_session, encoded_user_ids: List[EncodedDatabaseIdField] ) -> List[model.User]: user_ids = self._decode_ids(encoded_user_ids) - users = sa_session.query(model.User).filter(model.User.table.c.id.in_(user_ids)).all() - return users + return get_users_by_ids(sa_session, user_ids) def _get_roles_by_encoded_ids( self, sa_session: galaxy_scoped_session, encoded_role_ids: List[EncodedDatabaseIdField] ) -> List[model.Role]: role_ids = self._decode_ids(encoded_role_ids) - roles = sa_session.query(model.Role).filter(model.Role.id.in_(role_ids)).all() - return roles + stmt = select(Role).where(Role.id.in_(role_ids)) + return sa_session.scalars(stmt).all() def _decode_id(self, encoded_id: EncodedDatabaseIdField) -> int: return decode_id(self._app, encoded_id) def _decode_ids(self, encoded_ids: List[EncodedDatabaseIdField]) -> List[int]: return [self._decode_id(encoded_id) for encoded_id in encoded_ids] + + +def get_group_by_name(session: Session, name: str): + stmt = select(Group).filter(Group.name == name).limit(1) + return session.scalars(stmt).first() + + +def get_not_deleted_groups(session: Session): + stmt = select(Group).where(Group.deleted == false()) + return session.scalars(stmt) diff --git a/lib/galaxy/managers/users.py b/lib/galaxy/managers/users.py index 4489a00689e6..dedc1d5ef90d 100644 --- a/lib/galaxy/managers/users.py +++ b/lib/galaxy/managers/users.py @@ -23,6 +23,7 @@ select, true, ) +from sqlalchemy.orm import Session from sqlalchemy.orm.exc import NoResultFound from galaxy import ( @@ -36,7 +37,10 @@ base, deletable, ) -from galaxy.model import UserQuotaUsage +from galaxy.model import ( + User, + UserQuotaUsage, +) from galaxy.model.base import transaction from galaxy.security.validate_user_input import ( VALID_EMAIL_RE, @@ -850,3 +854,8 @@ def get_user_by_username(session, user_class, username): return session.execute(stmt).scalar_one() except Exception: return None + + +def get_users_by_ids(session: Session, user_ids): + stmt = select(User).where(User.id.in_(user_ids)) + return session.scalars(stmt).all()