Skip to content

Commit

Permalink
Refactor get_user_by_username for use from galaxy and tool_shed
Browse files Browse the repository at this point in the history
  • Loading branch information
jdavcs committed Sep 11, 2023
1 parent 2a093a8 commit 386662e
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 31 deletions.
13 changes: 0 additions & 13 deletions lib/galaxy/managers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
and_,
exc,
func,
select,
true,
)
from sqlalchemy.orm.exc import NoResultFound
Expand Down Expand Up @@ -838,15 +837,3 @@ def _add_parsers(self):
)

self.fn_filter_parsers.update({})


def get_user_by_username(session, user_class, username):
"""
Get a user from the database by username.
(We pass the session and the user_class to accommodate usage from the tool_shed app.)
"""
try:
stmt = select(user_class).filter(user_class.username == username)
return session.execute(stmt).scalar_one()
except Exception:
return None
21 changes: 11 additions & 10 deletions lib/galaxy/model/repositories/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@ class UserRepository(ModelRepository):
def __init__(self, session: SessionType):
super().__init__(session, User)

def get_by_email(self, email: str):
stmt = select(User).filter(User.email == email).limit(1)
return self.session.scalars(stmt).first() # type:ignore[union-attr]

# The get_user_by_email and get_user_by_username functions may be called from
# the tool_shed app, which has its own User model, which is different from
# galaxy.model.User. In that case, the tool_shed user model should be passed as
# the model_class argument.
def get_user_by_email(session, email: str, model_class=User):
stmt = select(model_class).filter(model_class.email == email).limit(1)
return session.scalars(stmt).first()

def get_user_by_username(session: SessionType, user_class, username: str):
"""Get a user from the database by username."""
# This may be called from the tool_shed app, which has a different
# definition of the User mapped class. Therefore, we must pass the User
# class as an argument instead of importing from galaxy.model.
stmt = select(user_class).filter(user_class.username == username)
return session.execute(stmt).scalar_one() # type:ignore[union-attr]

def get_user_by_username(session, username: str, model_class=User):
stmt = select(model_class).filter(model_class.username == username).limit(1)
return session.scalars(stmt).first()
7 changes: 2 additions & 5 deletions lib/tool_shed/test/base/test_db_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import galaxy.model
import galaxy.model.tool_shed_install
import tool_shed.webapp.model as model
from galaxy.model.repositories.user import get_user_by_username

log = logging.getLogger("test.tool_shed.test_db_util")

Expand Down Expand Up @@ -171,10 +172,6 @@ def get_user(email):
return sa_session().query(model.User).filter(model.User.table.c.email == email).first()


def get_user_by_name(username):
return sa_session().query(model.User).filter(model.User.table.c.username == username).first()


def mark_obj_deleted(obj):
obj.deleted = True
sa_session().add(obj)
Expand All @@ -190,7 +187,7 @@ def ga_refresh(obj):


def get_repository_by_name_and_owner(name, owner_username, return_multiple=False):
owner = get_user_by_name(owner_username)
owner = get_user_by_username(sa_session(), owner_username, model.User)
repository = (
sa_session()
.query(model.Repository)
Expand Down
6 changes: 3 additions & 3 deletions lib/tool_shed/webapp/controllers/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
util,
web,
)
from galaxy.managers.users import get_user_by_username
from galaxy.model.base import transaction
from galaxy.model.repositories.user import get_user_by_username
from galaxy.tool_shed.util import dependency_display
from galaxy.tools.repositories import ValidationContext
from galaxy.web.form_builder import (
Expand Down Expand Up @@ -2293,7 +2293,7 @@ def set_malicious(self, trans, id, ctx_str, **kwd):
def sharable_owner(self, trans, owner):
"""Support for sharable URL for each repository owner's tools, e.g. http://example.org/view/owner."""
try:
user = get_user_by_username(trans.model.session, trans.model.User, owner)
user = get_user_by_username(trans.model.session, owner, trans.model.User)
except Exception:
user = None
if user:
Expand Down Expand Up @@ -2321,7 +2321,7 @@ def sharable_repository(self, trans, owner, name):
else:
# If the owner is valid, then show all of their repositories.
try:
user = get_user_by_username(trans.model.session, trans.model.User, owner)
user = get_user_by_username(trans.model.session, owner, trans.model.User)
except Exception:
user = None
if user:
Expand Down
9 changes: 9 additions & 0 deletions lib/tool_shed/webapp/model/repositories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from sqlalchemy import select
from tool_shed.webapp.model import User


class UserRepository:

def get_by_username(self, session, username: str):
stmt = select(User).filter(User.username == username).limit(1)
return session.scalars(stmt).first()

0 comments on commit 386662e

Please sign in to comment.