diff --git a/lib/galaxy/managers/users.py b/lib/galaxy/managers/users.py index 4489a00689e6..bd739a169ffe 100644 --- a/lib/galaxy/managers/users.py +++ b/lib/galaxy/managers/users.py @@ -20,7 +20,6 @@ and_, exc, func, - select, true, ) from sqlalchemy.orm.exc import NoResultFound @@ -838,15 +837,3 @@ def _add_parsers(self): ) self.fn_filter_parsers.update({}) - - -def get_user_by_username(session, user_class, username): - """ - Get a user from the database by username. - (We pass the session and the user_class to accommodate usage from the tool_shed app.) - """ - try: - stmt = select(user_class).filter(user_class.username == username) - return session.execute(stmt).scalar_one() - except Exception: - return None diff --git a/lib/galaxy/model/repositories/user.py b/lib/galaxy/model/repositories/user.py index 157bf7c2bf1f..f0cf0b2c98fe 100644 --- a/lib/galaxy/model/repositories/user.py +++ b/lib/galaxy/model/repositories/user.py @@ -11,15 +11,16 @@ class UserRepository(ModelRepository): def __init__(self, session: SessionType): super().__init__(session, User) - def get_by_email(self, email: str): - stmt = select(User).filter(User.email == email).limit(1) - return self.session.scalars(stmt).first() # type:ignore[union-attr] +# The get_user_by_email and get_user_by_username functions may be called from +# the tool_shed app, which has its own User model, which is different from +# galaxy.model.User. In that case, the tool_shed user model should be passed as +# the model_class argument. +def get_user_by_email(session, email: str, model_class=User): + stmt = select(model_class).filter(model_class.email == email).limit(1) + return session.scalars(stmt).first() -def get_user_by_username(session: SessionType, user_class, username: str): - """Get a user from the database by username.""" - # This may be called from the tool_shed app, which has a different - # definition of the User mapped class. Therefore, we must pass the User - # class as an argument instead of importing from galaxy.model. - stmt = select(user_class).filter(user_class.username == username) - return session.execute(stmt).scalar_one() # type:ignore[union-attr] + +def get_user_by_username(session, username: str, model_class=User): + stmt = select(model_class).filter(model_class.username == username).limit(1) + return session.scalars(stmt).first() diff --git a/lib/tool_shed/test/base/test_db_util.py b/lib/tool_shed/test/base/test_db_util.py index 3b4cbce31f16..e7e4c88415f9 100644 --- a/lib/tool_shed/test/base/test_db_util.py +++ b/lib/tool_shed/test/base/test_db_util.py @@ -10,6 +10,7 @@ import galaxy.model import galaxy.model.tool_shed_install import tool_shed.webapp.model as model +from galaxy.model.repositories.user import get_user_by_username log = logging.getLogger("test.tool_shed.test_db_util") @@ -171,10 +172,6 @@ def get_user(email): return sa_session().query(model.User).filter(model.User.table.c.email == email).first() -def get_user_by_name(username): - return sa_session().query(model.User).filter(model.User.table.c.username == username).first() - - def mark_obj_deleted(obj): obj.deleted = True sa_session().add(obj) @@ -190,7 +187,7 @@ def ga_refresh(obj): def get_repository_by_name_and_owner(name, owner_username, return_multiple=False): - owner = get_user_by_name(owner_username) + owner = get_user_by_username(sa_session(), owner_username, model.User) repository = ( sa_session() .query(model.Repository) diff --git a/lib/tool_shed/webapp/controllers/repository.py b/lib/tool_shed/webapp/controllers/repository.py index e7627292f426..747162130692 100644 --- a/lib/tool_shed/webapp/controllers/repository.py +++ b/lib/tool_shed/webapp/controllers/repository.py @@ -24,8 +24,8 @@ util, web, ) -from galaxy.managers.users import get_user_by_username from galaxy.model.base import transaction +from galaxy.model.repositories.user import get_user_by_username from galaxy.tool_shed.util import dependency_display from galaxy.tools.repositories import ValidationContext from galaxy.web.form_builder import ( @@ -2293,7 +2293,7 @@ def set_malicious(self, trans, id, ctx_str, **kwd): def sharable_owner(self, trans, owner): """Support for sharable URL for each repository owner's tools, e.g. http://example.org/view/owner.""" try: - user = get_user_by_username(trans.model.session, trans.model.User, owner) + user = get_user_by_username(trans.model.session, owner, trans.model.User) except Exception: user = None if user: @@ -2321,7 +2321,7 @@ def sharable_repository(self, trans, owner, name): else: # If the owner is valid, then show all of their repositories. try: - user = get_user_by_username(trans.model.session, trans.model.User, owner) + user = get_user_by_username(trans.model.session, owner, trans.model.User) except Exception: user = None if user: diff --git a/lib/tool_shed/webapp/model/repositories.py b/lib/tool_shed/webapp/model/repositories.py new file mode 100644 index 000000000000..67f83430e389 --- /dev/null +++ b/lib/tool_shed/webapp/model/repositories.py @@ -0,0 +1,9 @@ +from sqlalchemy import select +from tool_shed.webapp.model import User + + +class UserRepository: + + def get_by_username(self, session, username: str): + stmt = select(User).filter(User.username == username).limit(1) + return session.scalars(stmt).first()