Skip to content

Commit

Permalink
Fix SA2.0 ORM usage in managers.groups
Browse files Browse the repository at this point in the history
Move data access method to managers.users
  • Loading branch information
jdavcs committed Sep 25, 2023
1 parent 1a665c3 commit 0e4df32
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 9 deletions.
34 changes: 26 additions & 8 deletions lib/galaxy/managers/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
List,
)

from sqlalchemy import false
from sqlalchemy import (
false,
select,
)
from sqlalchemy.orm import Session

from galaxy import model
from galaxy.exceptions import (
Expand All @@ -14,6 +18,11 @@
)
from galaxy.managers.base import decode_id
from galaxy.managers.context import ProvidesAppContext
from galaxy.managers.users import get_users_by_ids
from galaxy.model import (
Group,
Role,
)
from galaxy.model.base import transaction
from galaxy.model.scoped_session import galaxy_scoped_session
from galaxy.schema.fields import (
Expand All @@ -35,7 +44,7 @@ def index(self, trans: ProvidesAppContext):
Displays a collection (list) of groups.
"""
rval = []
for group in trans.sa_session.query(model.Group).filter(model.Group.deleted == false()):
for group in get_not_deleted_groups(trans.sa_session):
item = group.to_dict(value_mapper={"id": DecodedDatabaseIdField.encode})
encoded_id = DecodedDatabaseIdField.encode(group.id)
item["url"] = url_for("group", id=encoded_id)
Expand Down Expand Up @@ -101,11 +110,11 @@ def update(self, trans: ProvidesAppContext, group_id: int, payload: Dict[str, An
sa_session.commit()

def _check_duplicated_group_name(self, sa_session: galaxy_scoped_session, group_name: str) -> None:
if sa_session.query(model.Group).filter(model.Group.name == group_name).first():
if get_group_by_name(sa_session, group_name):
raise Conflict(f"A group with name '{group_name}' already exists")

def _get_group(self, sa_session: galaxy_scoped_session, group_id: int) -> model.Group:
group = sa_session.query(model.Group).get(group_id)
group = sa_session.get(model.Group, group_id)
if group is None:
raise ObjectNotFound("Group with the provided id was not found.")
return group
Expand All @@ -114,18 +123,27 @@ def _get_users_by_encoded_ids(
self, sa_session: galaxy_scoped_session, encoded_user_ids: List[EncodedDatabaseIdField]
) -> List[model.User]:
user_ids = self._decode_ids(encoded_user_ids)
users = sa_session.query(model.User).filter(model.User.table.c.id.in_(user_ids)).all()
return users
return get_users_by_ids(sa_session, user_ids)

def _get_roles_by_encoded_ids(
self, sa_session: galaxy_scoped_session, encoded_role_ids: List[EncodedDatabaseIdField]
) -> List[model.Role]:
role_ids = self._decode_ids(encoded_role_ids)
roles = sa_session.query(model.Role).filter(model.Role.id.in_(role_ids)).all()
return roles
stmt = select(Role).where(Role.id.in_(role_ids))
return sa_session.scalars(stmt).all()

def _decode_id(self, encoded_id: EncodedDatabaseIdField) -> int:
return decode_id(self._app, encoded_id)

def _decode_ids(self, encoded_ids: List[EncodedDatabaseIdField]) -> List[int]:
return [self._decode_id(encoded_id) for encoded_id in encoded_ids]


def get_group_by_name(session: Session, name: str):
stmt = select(Group).filter(Group.name == name).limit(1)
return session.scalars(stmt).first()


def get_not_deleted_groups(session: Session):
stmt = select(Group).where(Group.deleted == false())
return session.scalars(stmt)
11 changes: 10 additions & 1 deletion lib/galaxy/managers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
select,
true,
)
from sqlalchemy.orm import Session
from sqlalchemy.orm.exc import NoResultFound

from galaxy import (
Expand All @@ -36,7 +37,10 @@
base,
deletable,
)
from galaxy.model import UserQuotaUsage
from galaxy.model import (
User,
UserQuotaUsage,
)
from galaxy.model.base import transaction
from galaxy.security.validate_user_input import (
VALID_EMAIL_RE,
Expand Down Expand Up @@ -850,3 +854,8 @@ def get_user_by_username(session, user_class, username):
return session.execute(stmt).scalar_one()
except Exception:
return None


def get_users_by_ids(session: Session, user_ids):
stmt = select(User).where(User.id.in_(user_ids))
return session.scalars(stmt).all()

0 comments on commit 0e4df32

Please sign in to comment.