diff --git a/lib/galaxy/managers/users.py b/lib/galaxy/managers/users.py index b3c309383e84..8d08d767f355 100644 --- a/lib/galaxy/managers/users.py +++ b/lib/galaxy/managers/users.py @@ -844,8 +844,11 @@ def get_users_by_ids(session: Session, user_ids): # the tool_shed app, which has its own User model, which is different from # galaxy.model.User. In that case, the tool_shed user model should be passed as # the model_class argument. -def get_user_by_email(session, email: str, model_class=User): - stmt = select(model_class).filter(model_class.email == email).limit(1) +def get_user_by_email(session, email: str, model_class=User, case_sensitive=True): + filter_clause = model_class.email == email + if not case_sensitive: + filter_clause = func.lower(model_class.email) == func.lower(email) + stmt = select(model_class).where(filter_clause).limit(1) return session.scalars(stmt).first() diff --git a/lib/tool_shed/managers/categories.py b/lib/tool_shed/managers/categories.py index 9176fef9e978..5341c4dea450 100644 --- a/lib/tool_shed/managers/categories.py +++ b/lib/tool_shed/managers/categories.py @@ -5,6 +5,8 @@ List, ) +from sqlalchemy import select + import tool_shed.util.shed_util_common as suc from galaxy import ( exceptions, @@ -41,14 +43,9 @@ def create(self, trans: ProvidesUserContext, category_request: CreateCategoryReq raise exceptions.RequestParameterMissingException('Missing required parameter "name".') def index_db(self, trans: ProvidesUserContext, deleted: bool) -> List[Category]: - category_db_objects: List[Category] = [] if deleted and not trans.user_is_admin: raise exceptions.AdminRequiredException("Only administrators can query deleted categories.") - for category in ( - trans.sa_session.query(Category).filter(Category.table.c.deleted == deleted).order_by(Category.table.c.name) - ): - category_db_objects.append(category) - return category_db_objects + return list(get_categories_by_deleted(trans.sa_session, deleted)) def index(self, trans: ProvidesUserContext, deleted: bool) -> List[Dict[str, Any]]: category_dicts: List[Dict[str, Any]] = [] @@ -80,3 +77,8 @@ def to_model(self, category: Category) -> CategoryResponse: def get_value_mapper(app: ToolShedApp) -> Dict[str, Callable]: value_mapper = {"id": app.security.encode_id} return value_mapper + + +def get_categories_by_deleted(session, deleted): + stmt = select(Category).where(Category.deleted == deleted).order_by(Category.name) + return session.scalars(stmt) diff --git a/lib/tool_shed/managers/groups.py b/lib/tool_shed/managers/groups.py index 50f5e06222f5..8ade7cb76c92 100644 --- a/lib/tool_shed/managers/groups.py +++ b/lib/tool_shed/managers/groups.py @@ -5,6 +5,7 @@ from sqlalchemy import ( false, + select, true, ) from sqlalchemy.orm.exc import ( @@ -47,13 +48,11 @@ def get(self, trans, decoded_group_id=None, name=None): if decoded_group_id is None and name is None: raise RequestParameterInvalidException("You must supply either ID or a name of the group.") - name_query = trans.sa_session.query(trans.app.model.Group).filter(trans.app.model.Group.table.c.name == name) - id_query = trans.sa_session.query(trans.app.model.Group).filter( - trans.app.model.Group.table.c.id == decoded_group_id - ) - try: - group = id_query.one() if decoded_group_id else name_query.one() + if decoded_group_id: + group = trans.sa_session.get(trans.app.model.Group, decoded_group_id) + else: + group = get_group_by_name(trans.sa_session, name, trans.app.model.Group) except MultipleResultsFound: raise InconsistentDatabase("Multiple groups found with the same identifier.") except NoResultFound: @@ -121,16 +120,21 @@ def list(self, trans, deleted=False): :returns: query that will emit all groups :rtype: sqlalchemy query """ - is_admin = trans.user_is_admin - query = trans.sa_session.query(trans.app.model.Group) - if is_admin: + Group = trans.app.model.Group + stmt = select(Group) + if trans.user_is_admin: if deleted is None: # Flag is not specified, do not filter on it. pass elif deleted: - query = query.filter(trans.app.model.Group.table.c.deleted == true()) + stmt = stmt.where(Group.deleted == true()) else: - query = query.filter(trans.app.model.Group.table.c.deleted == false()) + stmt = stmt.where(Group.deleted == false()) else: - query = query.filter(trans.app.model.Group.table.c.deleted == false()) - return query + stmt = stmt.where(Group.deleted == false()) + return trans.sa_session.scalars(stmt) + + +def get_group_by_name(session, name, group_model): + stmt = select(group_model).where(group_model.name == name) + return session.execute(stmt).scalar_one() diff --git a/lib/tool_shed/managers/repositories.py b/lib/tool_shed/managers/repositories.py index 805f716164c5..596aca642b0e 100644 --- a/lib/tool_shed/managers/repositories.py +++ b/lib/tool_shed/managers/repositories.py @@ -17,8 +17,8 @@ from pydantic import BaseModel from sqlalchemy import ( - and_, false, + select, ) from galaxy import web @@ -202,17 +202,7 @@ def check_updates(app: ToolShedApp, request: UpdatesRequest) -> Union[str, Dict[ def guid_to_repository(app: ToolShedApp, tool_id: str) -> "Repository": # tool_id = remove_protocol_and_user_from_clone_url(tool_id) shed, _, owner, name, rest = tool_id.split("/", 5) - clause_list = [ - and_( - app.model.Repository.table.c.deprecated == false(), - app.model.Repository.table.c.deleted == false(), - app.model.Repository.table.c.name == name, - app.model.User.table.c.username == owner, - app.model.Repository.table.c.user_id == app.model.User.table.c.id, - ) - ] - repository = app.model.context.query(app.model.Repository).filter(*clause_list).first() - return repository + return _get_repository_by_name_and_owner(app.model.context, name, owner, app.model.User) def index_tool_ids(app: ToolShedApp, tool_ids: List[str]) -> Dict[str, Any]: @@ -222,16 +212,7 @@ def index_tool_ids(app: ToolShedApp, tool_ids: List[str]) -> Dict[str, Any]: repository = guid_to_repository(app, tool_id) owner = repository.user.username name = repository.name - clause_list = [ - and_( - app.model.Repository.table.c.deprecated == false(), - app.model.Repository.table.c.deleted == false(), - app.model.Repository.table.c.name == name, - app.model.User.table.c.username == owner, - app.model.Repository.table.c.user_id == app.model.User.table.c.id, - ) - ] - repository = app.model.context.current.sa_session.query(app.model.Repository).filter(*clause_list).first() + repository = _get_repository_by_name_and_owner(app.model.context.current, name, owner, app.model.User) if not repository: log.warning(f"Repository {owner}/{name} does not exist, skipping") continue @@ -273,27 +254,7 @@ def index_tool_ids(app: ToolShedApp, tool_ids: List[str]) -> Dict[str, Any]: def index_repositories(app: ToolShedApp, name: Optional[str], owner: Optional[str], deleted: bool): - clause_list = [ - and_( - app.model.Repository.table.c.deprecated == false(), - app.model.Repository.table.c.deleted == deleted, - ) - ] - if owner is not None: - clause_list.append( - and_( - app.model.User.table.c.username == owner, - app.model.Repository.table.c.user_id == app.model.User.table.c.id, - ) - ) - if name is not None: - clause_list.append(app.model.Repository.table.c.name == name) - repositories = [] - for repository in ( - app.model.context.query(app.model.Repository).filter(*clause_list).order_by(app.model.Repository.table.c.name) - ): - repositories.append(repository) - return repositories + return list(_get_repository_by_name_and_owner_and_deleted(app.model.context, name, owner, deleted, app.model.User)) def can_manage_repo(trans: ProvidesUserContext, repository: Repository) -> bool: @@ -616,3 +577,27 @@ def upload_tar_and_set_metadata( else: raise InternalServerError(message) return message + + +def _get_repository_by_name_and_owner(session, name, owner, user_model): + stmt = ( + select(Repository) + .where(Repository.deprecated == false()) + .where(Repository.deleted == false()) + .where(Repository.name == name) + .where(user_model.username == owner) + .where(Repository.user_id == user_model.id) + .limit(1) + ) + return session.scalars(stmt).first() + + +def _get_repository_by_name_and_owner_and_deleted(session, name, owner, deleted, user_model): + stmt = select(Repository).where(Repository.deprecated == false()).where(Repository.deleted == deleted) + if owner is not None: + stmt = stmt.where(user_model.username == owner) + stmt = stmt.where(Repository.user_id == user_model.id) + if name is not None: + stmt = stmt.where(Repository.name == name) + stmt = stmt.order_by(Repository.name).limit(1) + return session.scalars(stmt).first() diff --git a/lib/tool_shed/managers/users.py b/lib/tool_shed/managers/users.py index 3514c6e76cf8..ef1ca01ff6f5 100644 --- a/lib/tool_shed/managers/users.py +++ b/lib/tool_shed/managers/users.py @@ -1,5 +1,7 @@ from typing import List +from sqlalchemy import select + from galaxy.exceptions import RequestParameterInvalidException from galaxy.model.base import transaction from galaxy.security.validate_user_input import ( @@ -18,11 +20,7 @@ def index(app: ToolShedApp, deleted: bool) -> List[ApiUser]: users: List[ApiUser] = [] - for user in ( - app.model.context.query(app.model.User) - .filter(app.model.User.table.c.deleted == deleted) - .order_by(app.model.User.table.c.username) - ): + for user in get_users_by_deleted(app.model.context, app.model.User, deleted): users.append(get_api_user(app, user)) return users @@ -77,3 +75,8 @@ def _validate(trans: ProvidesUserContext, email: str, password: str, confirm: st ) ).rstrip() return message + + +def get_users_by_deleted(session, user_model, deleted): + stmt = select(user_model).where(user_model.deleted == deleted).order_by(user_model.username) + return session.scalars(stmt) diff --git a/lib/tool_shed/metadata/repository_metadata_manager.py b/lib/tool_shed/metadata/repository_metadata_manager.py index 9ec1b29271b2..03c5562d5c02 100644 --- a/lib/tool_shed/metadata/repository_metadata_manager.py +++ b/lib/tool_shed/metadata/repository_metadata_manager.py @@ -8,9 +8,8 @@ ) from sqlalchemy import ( - and_, false, - or_, + select, ) from galaxy import util @@ -36,7 +35,11 @@ tool_util, ) from tool_shed.util.metadata_util import repository_metadata_by_changeset_revision -from tool_shed.webapp.model import Repository +from tool_shed.webapp.model import ( + Repository, + RepositoryMetadata, + User, +) log = logging.getLogger(__name__) @@ -150,11 +153,7 @@ def handle_repository_elem(self, repository_elem, only_if_compiling_contained_td if suc.tool_shed_is_this_tool_shed(toolshed, trans=self.trans): try: - user = ( - self.sa_session.query(self.app.model.User) - .filter(self.app.model.User.table.c.username == owner) - .one() - ) + user = get_user_by_username(self.sa_session, owner) except Exception: error_message = ( f"Ignoring repository dependency definition for tool shed {toolshed}, name {name}, owner {owner}, " @@ -164,16 +163,7 @@ def handle_repository_elem(self, repository_elem, only_if_compiling_contained_td is_valid = False return repository_dependency_tup, is_valid, error_message try: - repository = ( - self.sa_session.query(self.app.model.Repository) - .filter( - and_( - self.app.model.Repository.table.c.name == name, - self.app.model.Repository.table.c.user_id == user.id, - ) - ) - .one() - ) + repository = get_repository(self.sa_session, name, user.id) except Exception: error_message = f"Ignoring repository dependency definition for tool shed {toolshed}," error_message += f"name {name}, owner {owner}, " @@ -288,8 +278,7 @@ def build_repository_ids_select_field( ): """Generate the current list of repositories for resetting metadata.""" repositories_select_field = SelectField(name=name, multiple=multiple, display=display) - query = self.get_query_for_setting_metadata_on_repositories(my_writable=my_writable, order=True) - for repository in query: + for repository in self.get_repositories_for_setting_metadata(my_writable=my_writable, order=True): owner = str(repository.user.username) option_label = f"{str(repository.name)} ({owner})" option_value = f"{self.app.security.encode_id(repository.id)}" @@ -303,14 +292,7 @@ def _clean_repository_metadata(self, changeset_revisions): # records with the same changeset revision value - no idea how this happens. We'll # assume we can delete the older records, so we'll order by update_time descending and # delete records that have the same changeset_revision we come across later. - for repository_metadata in ( - self.sa_session.query(self.app.model.RepositoryMetadata) - .filter(self.app.model.RepositoryMetadata.table.c.repository_id == self.repository.id) - .order_by( - self.app.model.RepositoryMetadata.table.c.changeset_revision, - self.app.model.RepositoryMetadata.table.c.update_time.desc(), - ) - ): + for repository_metadata in get_repository_metadata(self.sa_session, self.repository.id): changeset_revision = repository_metadata.changeset_revision if changeset_revision not in changeset_revisions: self.sa_session.delete(repository_metadata) @@ -600,9 +582,9 @@ def _get_parent_id(self, id: int, old_id, version, guid, changeset_revisions): # The tool did not change through all of the changeset revisions. return old_id - def get_query_for_setting_metadata_on_repositories(self, my_writable=False, order=True): + def get_repositories_for_setting_metadata(self, my_writable=False, order=True): """ - Return a query containing repositories for resetting metadata. The order parameter + Return a list of repositories for resetting metadata. The order parameter is used for displaying the list of repositories ordered alphabetically for display on a page. When called from the Tool Shed API, order is False. """ @@ -611,46 +593,25 @@ def get_query_for_setting_metadata_on_repositories(self, my_writable=False, orde # repositories. if my_writable: username = self.user.username - clause_list = [] - for repository in self.sa_session.query(self.app.model.Repository).filter( - self.app.model.Repository.table.c.deleted == false() - ): + repo_ids = [] + for repository in get_current_repositories(self.sa_session): # Always reset metadata on all repositories of types repository_suite_definition and # tool_dependency_definition. if repository.type in [rt_util.REPOSITORY_SUITE_DEFINITION, rt_util.TOOL_DEPENDENCY_DEFINITION]: - clause_list.append(self.app.model.Repository.table.c.id == repository.id) + repo_ids.append(repository.id) else: allow_push = repository.allow_push() if allow_push: # Include all repositories that are writable by the current user. allow_push_usernames = allow_push.split(",") if username in allow_push_usernames: - clause_list.append(self.app.model.Repository.table.c.id == repository.id) - if clause_list: - if order: - return ( - self.sa_session.query(self.app.model.Repository) - .filter(or_(*clause_list)) - .order_by(self.app.model.Repository.table.c.name, self.app.model.Repository.table.c.user_id) - ) - else: - return self.sa_session.query(self.app.model.Repository).filter(or_(*clause_list)) + repo_ids.append(repository.id) + if repo_ids: + return get_filtered_repositories(self.sa_session, repo_ids, order) else: - # Return an empty query. - return self.sa_session.query(self.app.model.Repository).filter( - self.app.model.Repository.table.c.id == -1 - ) + return [] else: - if order: - return ( - self.sa_session.query(self.app.model.Repository) - .filter(self.app.model.Repository.table.c.deleted == false()) - .order_by(self.app.model.Repository.table.c.name, self.app.model.Repository.table.c.user_id) - ) - else: - return self.sa_session.query(self.app.model.Repository).filter( - self.app.model.Repository.table.c.deleted == false() - ) + return get_current_repositories(self.sa_session, order) def new_metadata_required_for_utilities(self): """ @@ -1107,3 +1068,36 @@ def _get_changeset_revisions_that_contain_tools(app: "ToolShedApp", repo, reposi if metadata.get("tools", None): changeset_revisions_that_contain_tools.append(changeset_revision) return changeset_revisions_that_contain_tools + + +def get_user_by_username(session, username): + stmt = select(User).where(User.username == username) + return session.execute(stmt).scalar_one() + + +def get_repository(session, name, user_id): + stmt = select(Repository).where(Repository.name == name).where(Repository.user_id == user_id) + return session.execute(stmt).scalar_one() + + +def get_repository_metadata(session, repository_id): + stmt = ( + select(RepositoryMetadata) + .where(RepositoryMetadata.repository_id == repository_id) + .order_by(RepositoryMetadata.changeset_revision, RepositoryMetadata.update_time.desc()) # type: ignore[attr-defined] # mapped attribute + ) + return session.scalars(stmt) + + +def get_current_repositories(session, order=False): + stmt = select(Repository).where(Repository.deleted == false()) + if order: + stmt = stmt.order_by(Repository.name, Repository.user_id) + return session.scalars(stmt) + + +def get_filtered_repositories(session, repo_ids, order): + stmt = select(Repository).where(Repository.in_(repo_ids)) + if order: + stmt = stmt.order_by(Repository.name, Repository.user_id) + return session.scalars(stmt) diff --git a/lib/tool_shed/repository_registry.py b/lib/tool_shed/repository_registry.py index 941bd83285f2..6927b4e68e74 100644 --- a/lib/tool_shed/repository_registry.py +++ b/lib/tool_shed/repository_registry.py @@ -4,6 +4,7 @@ and_, false, or_, + select, ) import tool_shed.repository_types.util as rt_util @@ -12,6 +13,11 @@ metadata_util, ) from tool_shed.webapp import model +from tool_shed.webapp.model import ( + Category, + Repository, + RepositoryMetadata, +) log = logging.getLogger(__name__) @@ -20,7 +26,6 @@ class Registry: def __init__(self, app): log.debug("Loading the repository registry...") self.app = app - self.certified_level_one_clause_list = self.get_certified_level_one_clause_list() # The following lists contain tuples like ( repository.name, repository.user.username, changeset_revision ) # where the changeset_revision entry is always the latest installable changeset_revision.. self.certified_level_one_repository_and_suite_tuples = [] @@ -143,19 +148,15 @@ def edit_category_entry(self, old_name, new_name): self.certified_level_one_viewable_suites_by_category[new_name] = 0 def get_certified_level_one_clause_list(self): - certified_level_one_tuples = [] clause_list = [] - for repository in self.sa_session.query(model.Repository).filter( - and_(model.Repository.table.c.deleted == false(), model.Repository.table.c.deprecated == false()) - ): + for repository in get_repositories(self.sa_session): certified_level_one_tuple = self.get_certified_level_one_tuple(repository) latest_installable_changeset_revision, is_level_one_certified = certified_level_one_tuple if is_level_one_certified: - certified_level_one_tuples.append(certified_level_one_tuple) clause_list.append( and_( - model.RepositoryMetadata.table.c.repository_id == repository.id, - model.RepositoryMetadata.table.c.changeset_revision == latest_installable_changeset_revision, + RepositoryMetadata.repository_id == repository.id, + RepositoryMetadata.changeset_revision == latest_installable_changeset_revision, ) ) return clause_list @@ -234,19 +235,11 @@ def load_repository_and_suite_tuple(self, repository): def load_repository_and_suite_tuples(self): # Load self.certified_level_one_repository_and_suite_tuples and self.certified_level_one_suite_tuples. - for repository in ( - self.sa_session.query(model.Repository) - .join(model.RepositoryMetadata.table) - .filter(or_(*self.certified_level_one_clause_list)) - .join(model.User.table) - ): + clauses = self.get_certified_level_one_clause_list() + for repository in get_certified_repositories_with_user(self.sa_session, clauses, model.User): self.load_certified_level_one_repository_and_suite_tuple(repository) # Load self.repository_and_suite_tuples and self.suite_tuples - for repository in ( - self.sa_session.query(model.Repository) - .filter(and_(model.Repository.table.c.deleted == false(), model.Repository.table.c.deprecated == false())) - .join(model.User.table) - ): + for repository in get_repositories_with_user(self.sa_session, model.User): self.load_repository_and_suite_tuple(repository) def load_viewable_repositories_and_suites_by_category(self): @@ -259,7 +252,7 @@ def load_viewable_repositories_and_suites_by_category(self): self.viewable_suites_by_category = {} self.viewable_valid_repositories_and_suites_by_category = {} self.viewable_valid_suites_by_category = {} - for category in self.sa_session.query(model.Category): + for category in self.sa_session.scalars(select(Category)): category_name = str(category.name) if category not in self.certified_level_one_viewable_repositories_and_suites_by_category: self.certified_level_one_viewable_repositories_and_suites_by_category[category_name] = 0 @@ -393,3 +386,20 @@ def unload_repository_and_suite_tuple(self, repository): if repository.type == rt_util.REPOSITORY_SUITE_DEFINITION: if tuple in self.suite_tuples: self.suite_tuples.remove(tuple) + + +def get_repositories(session): + stmt = select(Repository).where(Repository.deleted == false()).where(Repository.deprecated == false()) + return session.scalars(stmt) + + +def get_repositories_with_user(session, user_model): + stmt = ( + select(Repository).where(Repository.deleted == false()).where(Repository.deprecated == false()).join(user_model) + ) + return session.scalars(stmt) + + +def get_certified_repositories_with_user(session, where_clauses, user_model): + stmt = select(Repository).join(RepositoryMetadata).where(or_(*where_clauses)).join(user_model) + return session.scalars(stmt) diff --git a/lib/tool_shed/test/base/test_db_util.py b/lib/tool_shed/test/base/test_db_util.py index 9081a007a097..03b1bbcc7550 100644 --- a/lib/tool_shed/test/base/test_db_util.py +++ b/lib/tool_shed/test/base/test_db_util.py @@ -1,19 +1,18 @@ import logging -from typing import ( - List, - Optional, -) +from typing import List from sqlalchemy import ( - and_, false, - true, + select, ) import galaxy.model import galaxy.model.tool_shed_install import tool_shed.webapp.model as model -from galaxy.managers.users import get_user_by_username +from galaxy.managers.users import ( + get_user_by_email, + get_user_by_username, +) log = logging.getLogger("test.tool_shed.test_db_util") @@ -30,79 +29,44 @@ def install_session(): return install_session -def delete_obj(obj): - sa_session().delete(obj) - sa_session().flush() - - -def delete_user_roles(user): - for ura in user.roles: - sa_session().delete(ura) - sa_session().flush() - - def flush(obj): sa_session().add(obj) sa_session().flush() def get_all_repositories(): - return sa_session().query(model.Repository).all() + return sa_session().scalars(select(model.Repository)).all() def get_all_installed_repositories(session=None) -> List[galaxy.model.tool_shed_install.ToolShedRepository]: if session is None: session = install_session() - return list( - session.query(galaxy.model.tool_shed_install.ToolShedRepository) - .filter( - and_( - galaxy.model.tool_shed_install.ToolShedRepository.table.c.deleted == false(), - galaxy.model.tool_shed_install.ToolShedRepository.table.c.uninstalled == false(), - galaxy.model.tool_shed_install.ToolShedRepository.table.c.status - == galaxy.model.tool_shed_install.ToolShedRepository.installation_status.INSTALLED, - ) - ) - .all() - ) - - -def get_galaxy_repository_by_name_owner_changeset_revision(repository_name, owner, changeset_revision): - return ( - install_session() - .query(galaxy.model.tool_shed_install.ToolShedRepository) - .filter( - and_( - galaxy.model.tool_shed_install.ToolShedRepository.table.c.name == repository_name, - galaxy.model.tool_shed_install.ToolShedRepository.table.c.owner == owner, - galaxy.model.tool_shed_install.ToolShedRepository.table.c.changeset_revision == changeset_revision, - ) - ) - .first() + ToolShedRepository = galaxy.model.tool_shed_install.ToolShedRepository + stmt = ( + select(ToolShedRepository) + .where(ToolShedRepository.deleted == false()) + .where(ToolShedRepository.uninstalled == false()) + .where(ToolShedRepository.status == ToolShedRepository.installation_status.INSTALLED) ) + return session.scalars(stmt).all() def get_installed_repository_by_id(repository_id): - return ( - install_session() - .query(galaxy.model.tool_shed_install.ToolShedRepository) - .filter(galaxy.model.tool_shed_install.ToolShedRepository.table.c.id == repository_id) - .first() - ) + return install_session().get(galaxy.model.tool_shed_install.ToolShedRepository, repository_id) def get_installed_repository_by_name_owner(repository_name, owner, return_multiple=False, session=None): if session is None: session = install_session() - query = session.query(galaxy.model.tool_shed_install.ToolShedRepository).filter( - and_( - galaxy.model.tool_shed_install.ToolShedRepository.table.c.name == repository_name, - galaxy.model.tool_shed_install.ToolShedRepository.table.c.owner == owner, - ) + ToolShedRepository = galaxy.model.tool_shed_install.ToolShedRepository + stmt = ( + select(ToolShedRepository) + .where(ToolShedRepository.name == repository_name) + .where(ToolShedRepository.owner == owner) ) if return_multiple: - return query.all() - return query.first() + return session.scalars(stmt).all() + return session.scalars(stmt.limit(1)).first() def get_role(user, role_name): @@ -113,68 +77,21 @@ def get_role(user, role_name): def get_repository_role_association(repository_id, role_id): - rra = ( - sa_session() - .query(model.RepositoryRoleAssociation) - .filter( - and_( - model.RepositoryRoleAssociation.table.c.role_id == role_id, - model.RepositoryRoleAssociation.table.c.repository_id == repository_id, - ) - ) - .first() + stmt = ( + select(model.RepositoryRoleAssociation) + .where(model.RepositoryRoleAssociation.role_id == role_id) + .where(model.RepositoryRoleAssociation.repository_id == repository_id) + .limit(1) ) - return rra + return sa_session().scalars(stmt).first() def get_repository_by_id(repository_id): - return sa_session().query(model.Repository).filter(model.Repository.table.c.id == repository_id).first() - - -def get_repository_downloadable_revisions(repository_id): - revisions = ( - sa_session() - .query(model.RepositoryMetadata) - .filter( - and_( - model.RepositoryMetadata.table.c.repository_id == repository_id, - model.RepositoryMetadata.table.c.downloadable == true(), - ) - ) - .all() - ) - return revisions - - -def get_repository_metadata_for_changeset_revision( - repository_id: int, changeset_revision: Optional[str] -) -> model.RepositoryMetadata: - repository_metadata = ( - sa_session() - .query(model.RepositoryMetadata) - .filter( - and_( - model.RepositoryMetadata.table.c.repository_id == repository_id, - model.RepositoryMetadata.table.c.changeset_revision == changeset_revision, - ) - ) - .first() - ) - return repository_metadata - - -def get_role_by_name(role_name): - return sa_session().query(model.Role).filter(model.Role.table.c.name == role_name).first() + return sa_session().get(model.Repository, repository_id) def get_user(email): - return sa_session().query(model.User).filter(model.User.table.c.email == email).first() - - -def mark_obj_deleted(obj): - obj.deleted = True - sa_session().add(obj) - sa_session().flush() + return get_user_by_email(sa_session(), email, model.User) def refresh(obj): @@ -187,25 +104,20 @@ def ga_refresh(obj): def get_repository_by_name_and_owner(name, owner_username, return_multiple=False): owner = get_user_by_username(sa_session(), owner_username, model.User) - repository = ( - sa_session() - .query(model.Repository) - .filter(and_(model.Repository.table.c.name == name, model.Repository.table.c.user_id == owner.id)) - .first() + stmt = ( + select(model.Repository) + .where(model.Repository.name == name) + .where(model.Repository.user_id == owner.id) + .limit(1) ) - return repository + return sa_session().scalars(stmt).first() def get_repository_metadata_by_repository_id_changeset_revision(repository_id, changeset_revision): - repository_metadata = ( - sa_session() - .query(model.RepositoryMetadata) - .filter( - and_( - model.RepositoryMetadata.table.c.repository_id == repository_id, - model.RepositoryMetadata.table.c.changeset_revision == changeset_revision, - ) - ) - .first() + stmt = ( + select(model.RepositoryMetadata) + .where(model.RepositoryMetadata.repository_id == repository_id) + .where(model.RepositoryMetadata.changeset_revision == changeset_revision) + .limit(1) ) - return repository_metadata + return sa_session().scalars(stmt).first() diff --git a/lib/tool_shed/test/base/twilltestcase.py b/lib/tool_shed/test/base/twilltestcase.py index 8c15af3dc210..2d9b5e25c86f 100644 --- a/lib/tool_shed/test/base/twilltestcase.py +++ b/lib/tool_shed/test/base/twilltestcase.py @@ -31,8 +31,8 @@ ) from playwright.sync_api import Page from sqlalchemy import ( - and_, false, + select, ) import galaxy.model.tool_shed_install as galaxy_model @@ -572,20 +572,7 @@ def get_installed_repositories_by_name_owner( def get_installed_repository_for( self, owner: Optional[str] = None, name: Optional[str] = None, changeset: Optional[str] = None ) -> Optional[Dict[str, Any]]: - clause_list = [] - if name is not None: - clause_list.append(galaxy_model.ToolShedRepository.table.c.name == name) - if owner is not None: - clause_list.append(galaxy_model.ToolShedRepository.table.c.owner == owner) - if changeset is not None: - clause_list.append(galaxy_model.ToolShedRepository.table.c.changeset_revision == changeset) - clause_list.append(galaxy_model.ToolShedRepository.table.c.deleted == false()) - clause_list.append(galaxy_model.ToolShedRepository.table.c.uninstalled == false()) - - query = self._installation_target.install_model.context.query(galaxy_model.ToolShedRepository) - if len(clause_list) > 0: - query = query.filter(and_(*clause_list)) - repository = query.one_or_none() + repository = get_installed_repository(self._installation_target.install_model.context, name, owner, changeset) if repository: return repository.to_dict() else: @@ -1500,9 +1487,9 @@ def get_repository_metadata_for_db_object(self, repository: DbRepository): return [metadata_revision for metadata_revision in repository.metadata_revisions] def get_repository_metadata_by_changeset_revision(self, repository_id: int, changeset_revision): - return test_db_util.get_repository_metadata_for_changeset_revision( + return test_db_util.get_repository_metadata_by_repository_id_changeset_revision( repository_id, changeset_revision - ) or test_db_util.get_repository_metadata_for_changeset_revision(repository_id, None) + ) or test_db_util.get_repository_metadata_by_repository_id_changeset_revision(repository_id, None) def get_repository_metadata_revisions(self, repository: Repository) -> List[str]: return [ @@ -2096,3 +2083,17 @@ def _get_tool_panel_section_from_repository_metadata(metadata): tool_panel_section_metadata = metadata["tool_panel_section"] tool_panel_section = tool_panel_section_metadata[tool_guid][0]["name"] return tool_panel_section + + +def get_installed_repository(session, name, owner, changeset): + ToolShedRepository = galaxy_model.ToolShedRepository + stmt = select(ToolShedRepository) + if name is not None: + stmt = stmt.where(ToolShedRepository.name == name) + if owner is not None: + stmt = stmt.where(ToolShedRepository.owner == owner) + if changeset is not None: + stmt = stmt.wehre(ToolShedRepository.changeset_revision == changeset) + stmt = stmt.where(ToolShedRepository.deleted == false()) + stmt = stmt.where(ToolShedRepository.uninstalled == false()) + return session.scalars(stmt).one_or_none() diff --git a/lib/tool_shed/util/commit_util.py b/lib/tool_shed/util/commit_util.py index e3e3148343c1..ab3864f8e98d 100644 --- a/lib/tool_shed/util/commit_util.py +++ b/lib/tool_shed/util/commit_util.py @@ -14,6 +14,7 @@ Union, ) +from sqlalchemy import select from sqlalchemy.sql.expression import null import tool_shed.repository_types.util as rt_util @@ -25,10 +26,10 @@ hg_util, shed_util_common as suc, ) +from tool_shed.webapp.model import Repository if TYPE_CHECKING: from tool_shed.structured_app import ToolShedApp - from tool_shed.webapp.model import Repository log = logging.getLogger(__name__) @@ -99,9 +100,7 @@ def check_file_contents_for_email_alerts(app: "ToolShedApp"): """ sa_session = app.model.session admin_users = app.config.get("admin_users", "").split(",") - for repository in sa_session.query(app.model.Repository).filter( - app.model.Repository.table.c.email_alerts != null() - ): + for repository in get_repositories_with_alerts(sa_session, app.model.Repository): email_alerts = json.loads(repository.email_alerts) for user_email in email_alerts: if user_email in admin_users: @@ -266,3 +265,8 @@ def uncompress(repository, uploaded_file_name, uploaded_file_filename, isgzip=Fa if isbz2: handle_bz2(repository, uploaded_file_name) return uploaded_file_filename.rstrip(".bz2") + + +def get_repositories_with_alerts(session, repository_model): + stmt = select(repository_model).where(repository_model.email_alerts != null()) + return session.scalars(stmt) diff --git a/lib/tool_shed/util/metadata_util.py b/lib/tool_shed/util/metadata_util.py index c41b62b7a7c0..0ca15a8aa919 100644 --- a/lib/tool_shed/util/metadata_util.py +++ b/lib/tool_shed/util/metadata_util.py @@ -5,7 +5,7 @@ TYPE_CHECKING, ) -from sqlalchemy import and_ +from sqlalchemy import select from galaxy.model.base import transaction from galaxy.tool_shed.util.hg_util import ( @@ -30,7 +30,6 @@ def get_all_dependencies(app, metadata_entry, processed_dependency_links=None): encoder = app.security.encode_id value_mapper = {"repository_id": encoder, "id": encoder, "user_id": encoder} metadata = metadata_entry.to_dict(value_mapper=value_mapper, view="element") - db = app.model.session returned_dependencies = [] required_metadata = get_dependencies_for_metadata_revision(app, metadata) if required_metadata is None: @@ -41,7 +40,9 @@ def get_all_dependencies(app, metadata_entry, processed_dependency_links=None): if dependency_link in processed_dependency_links: continue processed_dependency_links.append(dependency_link) - repository = db.query(app.model.Repository).get(app.security.decode_id(dependency_dict["repository_id"])) + repository = app.model.session.get( + app.model.Repository, app.security.decode_id(dependency_dict["repository_id"]) + ) dependency_dict["repository"] = repository.to_dict(value_mapper=value_mapper) if dependency_metadata.includes_tools: dependency_dict["tools"] = dependency_metadata.metadata["tools"] @@ -123,7 +124,7 @@ def get_latest_downloadable_changeset_revision(app, repository): def get_latest_repository_metadata(app, decoded_repository_id, downloadable=False): """Get last metadata defined for a specified repository from the database.""" sa_session = app.model.session - repository = sa_session.query(app.model.Repository).get(decoded_repository_id) + repository = sa_session.get(app.model.Repository, decoded_repository_id) if downloadable: changeset_revision = get_latest_downloadable_changeset_revision(app, repository) else: @@ -258,16 +259,10 @@ def repository_metadata_by_changeset_revision( # Make sure there are no duplicate records, and return the single unique record for the changeset_revision. # Duplicate records were somehow created in the past. The cause of this issue has been resolved, but we'll # leave this method as is for a while longer to ensure all duplicate records are removed. + sa_session = model_mapping.context - all_metadata_records = ( - sa_session.query(model_mapping.RepositoryMetadata) - .filter( - and_( - model_mapping.RepositoryMetadata.table.c.repository_id == id, - model_mapping.RepositoryMetadata.table.c.changeset_revision == changeset_revision, - ) - ) - .all() + all_metadata_records = get_metadata_by_changeset( + sa_session, id, changeset_revision, model_mapping.RepositoryMetadata ) if len(all_metadata_records) > 1: # Delete all records older than the last one updated. @@ -285,7 +280,7 @@ def repository_metadata_by_changeset_revision( def get_repository_metadata_by_id(app, id): """Get repository metadata from the database""" sa_session = app.model.session - return sa_session.query(app.model.RepositoryMetadata).get(app.security.decode_id(id)) + return sa_session.get(app.model.RepositoryMetadata, app.security.decode_id(id)) def get_repository_metadata_by_repository_id_changeset_revision(app, id, changeset_revision, metadata_only=False): @@ -348,3 +343,12 @@ def is_malicious(app, id, changeset_revision, **kwd): if repository_metadata: return repository_metadata.malicious return False + + +def get_metadata_by_changeset(session, repository_id, changeset_revision, repository_metadata_model): + stmt = ( + select(repository_metadata_model) + .where(repository_metadata_model.repository_id == repository_id) + .where(repository_metadata_model.changeset_revision == changeset_revision) + ) + return session.scalars(stmt).all() diff --git a/lib/tool_shed/util/repository_util.py b/lib/tool_shed/util/repository_util.py index 2f242a5d66f0..1cc1449bc40b 100644 --- a/lib/tool_shed/util/repository_util.py +++ b/lib/tool_shed/util/repository_util.py @@ -9,9 +9,12 @@ ) from markupsafe import escape -from sqlalchemy import false +from sqlalchemy import ( + delete, + false, + select, +) from sqlalchemy.orm import joinedload -from sqlalchemy.sql import select import tool_shed.dependencies.repository from galaxy import ( @@ -232,7 +235,7 @@ def create_repository( if category_ids: # Create category associations for category_id in category_ids: - category = sa_session.query(app.model.Category).get(app.security.decode_id(category_id)) + category = sa_session.get(app.model.Category, app.security.decode_id(category_id)) rca = app.model.RepositoryCategoryAssociation(repository, category) sa_session.add(rca) flush_needed = True @@ -335,39 +338,20 @@ def get_repo_info_dict(trans: "ProvidesRepositoriesContext", repository_id, chan def get_repositories_by_category( app: "ToolShedApp", category_id, installable=False, sort_order="asc", sort_key="name", page=None, per_page=25 ): - sa_session = app.model.session - query = ( - sa_session.query(app.model.Repository) - .join( - app.model.RepositoryCategoryAssociation, - app.model.Repository.id == app.model.RepositoryCategoryAssociation.repository_id, - ) - .join(app.model.User, app.model.User.id == app.model.Repository.user_id) - .filter(app.model.RepositoryCategoryAssociation.category_id == category_id) - ) - if installable: - subquery = select(app.model.RepositoryMetadata.table.c.repository_id) - query = query.filter(app.model.Repository.id.in_(subquery)) - if sort_key == "owner": - query = ( - query.order_by(app.model.User.username) - if sort_order == "asc" - else query.order_by(app.model.User.username.desc()) - ) - else: - query = ( - query.order_by(app.model.Repository.name) - if sort_order == "asc" - else query.order_by(app.model.Repository.name.desc()) - ) - if page is not None: - page = int(page) - query = query.limit(per_page) - if page > 1: - query = query.offset((page - 1) * per_page) - resultset = query.all() repositories = [] - for repository in resultset: + for repository in get_repositories( + app.model.session, + app.model.Repository, + app.model.RepositoryCategoryAssociation, + app.model.User, + app.model.RepositoryMetadata, + category_id, + installable, + sort_order, + sort_key, + page, + per_page, + ): default_value_mapper = { "id": app.security.encode_id, "user_id": app.security.encode_id, @@ -396,7 +380,7 @@ def handle_role_associations(app: "ToolShedApp", role, repository, **kwd): repository_owner = repository.user if kwd.get("manage_role_associations_button", False): in_users_list = util.listify(kwd.get("in_users", [])) - in_users = [sa_session.query(app.model.User).get(x) for x in in_users_list] + in_users = [sa_session.get(app.model.User, x) for x in in_users_list] # Make sure the repository owner is always associated with the repostory's admin role. owner_associated = False for user in in_users: @@ -408,7 +392,7 @@ def handle_role_associations(app: "ToolShedApp", role, repository, **kwd): message += "The repository owner must always be associated with the repository's administrator role. " status = "error" in_groups_list = util.listify(kwd.get("in_groups", [])) - in_groups = [sa_session.query(app.model.Group).get(x) for x in in_groups_list] + in_groups = [sa_session.get(app.model.Group, x) for x in in_groups_list] in_repositories = [repository] app.security_agent.set_entity_role_associations( roles=[role], users=in_users, groups=in_groups, repositories=in_repositories @@ -424,20 +408,12 @@ def handle_role_associations(app: "ToolShedApp", role, repository, **kwd): out_users = [] in_groups = [] out_groups = [] - for user in ( - sa_session.query(app.model.User) - .filter(app.model.User.table.c.deleted == false()) - .order_by(app.model.User.table.c.email) - ): + for user in get_current_users(sa_session, app.model.User): if user in [x.user for x in role.users]: in_users.append((user.id, user.email)) else: out_users.append((user.id, user.email)) - for group in ( - sa_session.query(app.model.Group) - .filter(app.model.Group.table.c.deleted == false()) - .order_by(app.model.Group.table.c.name) - ): + for group in get_current_groups(sa_session, app.model.Group): if group in [x.group for x in role.groups]: in_groups.append((group.id, group.name)) else: @@ -467,7 +443,7 @@ def update_repository(trans: "ProvidesUserContext", id: str, **kwds) -> Tuple[Op message = None flush_needed = False sa_session = app.model.session - repository = sa_session.query(app.model.Repository).get(app.security.decode_id(id)) + repository = sa_session.get(app.model.Repository, app.security.decode_id(id)) if repository is None: return None, "Unknown repository ID" @@ -483,17 +459,14 @@ def update_repository(trans: "ProvidesUserContext", id: str, **kwds) -> Tuple[Op flush_needed = True if "category_ids" in kwds and isinstance(kwds["category_ids"], list): - # Get existing category associations - category_associations = sa_session.query(app.model.RepositoryCategoryAssociation).filter( - app.model.RepositoryCategoryAssociation.table.c.repository_id == app.security.decode_id(id) + # Remove existing category associations + delete_repository_category_associations( + sa_session, app.model.RepositoryCategoryAssociation, app.security.decode_id(id) ) - # Remove all of them - for rca in category_associations: - sa_session.delete(rca) # Then (re)create category associations for category_id in kwds["category_ids"]: - category = sa_session.query(app.model.Category).get(app.security.decode_id(category_id)) + category = sa_session.get(app.model.Category, app.security.decode_id(category_id)) if category: rca = app.model.RepositoryCategoryAssociation(repository, category) sa_session.add(rca) @@ -562,6 +535,66 @@ def validate_repository_name(app: "ToolShedApp", name, user): return "" +def get_repositories( + session, + repository_model, + repository_category_assoc_model, + user_model, + repository_metadata_model, + category_id, + installable, + sort_order, + sort_key, + page, + per_page, +): + Repository = repository_model + RepositoryCategoryAssociation = repository_category_assoc_model + User = user_model + RepositoryMetadata = repository_metadata_model + + stmt = ( + select(Repository) + .join( + RepositoryCategoryAssociation, + Repository.id == RepositoryCategoryAssociation.repository_id, + ) + .join(User, User.id == Repository.user_id) + .where(RepositoryCategoryAssociation.category_id == category_id) + ) + if installable: + stmt1 = select(RepositoryMetadata.repository_id) + stmt = stmt.where(Repository.id.in_(stmt1)) + if sort_key == "owner": + stmt = stmt.order_by(User.username) + else: + stmt = stmt.order_by(Repository.name) + if sort_order == "desc": + stmt = stmt.desc() + if page is not None: + page = int(page) + stmt = stmt.limit(per_page) + if page > 1: + stmt = stmt.offset((page - 1) * per_page) + + return session.scalars(stmt) + + +def get_current_users(session, user_model): + stmt = select(user_model).where(user_model.deleted == false()).order_by(user_model.email) + return session.scalars(stmt) + + +def get_current_groups(session, group_model): + stmt = select(group_model).where(group_model.deleted == false()).order_by(group_model.name) + return session.scalars(stmt) + + +def delete_repository_category_associations(session, repository_category_assoc_model, repository_id): + stmt = delete(repository_category_assoc_model).where(repository_category_assoc_model.repository_id == repository_id) + return session.execute(stmt) + + __all__ = ( "change_repository_name_in_hgrc_file", "create_or_update_tool_shed_repository", diff --git a/lib/tool_shed/util/search_util.py b/lib/tool_shed/util/search_util.py index 243d16a809c8..02e65c3f8288 100644 --- a/lib/tool_shed/util/search_util.py +++ b/lib/tool_shed/util/search_util.py @@ -1,8 +1,8 @@ import logging from sqlalchemy import ( - and_, false, + select, true, ) @@ -105,16 +105,7 @@ def search_repository_metadata(app, exact_matches_checked, tool_ids="", tool_nam match_tuples = [] ok = True if tool_ids or tool_names or tool_versions: - for repository_metadata in ( - sa_session.query(app.model.RepositoryMetadata) - .filter(app.model.RepositoryMetadata.table.c.includes_tools == true()) - .join(app.model.Repository) - .filter( - and_( - app.model.Repository.table.c.deleted == false(), app.model.Repository.table.c.deprecated == false() - ) - ) - ): + for repository_metadata in get_metadata(sa_session, app.model.RepositoryMetadata, app.model.Repository): metadata = repository_metadata.metadata if metadata: tools = metadata.get("tools", []) @@ -221,3 +212,14 @@ def search_repository_metadata(app, exact_matches_checked, tool_ids="", tool_nam else: ok = False return ok, match_tuples + + +def get_metadata(session, repository_metadata_model, repository_model): + stmt = ( + select(repository_metadata_model) + .where(repository_metadata_model.includes_tools == true()) + .join(repository_model) + .where(repository_model.deleted == false()) + .where(repository_model.deprecated == false()) + ) + return session.scalars(stmt) diff --git a/lib/tool_shed/util/shed_index.py b/lib/tool_shed/util/shed_index.py index 4035a8c5c5c3..bd823b6cb094 100644 --- a/lib/tool_shed/util/shed_index.py +++ b/lib/tool_shed/util/shed_index.py @@ -5,6 +5,10 @@ hg, ui, ) +from sqlalchemy import ( + false, + select, +) from whoosh.writing import AsyncWriter import tool_shed.webapp.model.mapping as ts_mapping @@ -91,21 +95,11 @@ def get_repos(sa_session, file_path, hgweb_config_dir, **kwargs): """ hgwcm = hgweb_config_manager hgwcm.hgweb_config_dir = hgweb_config_dir - # Do not index deleted, deprecated, or "tool_dependency_definition" type repositories. - q = ( - sa_session.query(model.Repository) - .filter_by(deleted=False) - .filter_by(deprecated=False) - .order_by(model.Repository.update_time.desc()) - ) - q = q.filter(model.Repository.type != "tool_dependency_definition") - for repo in q: + for repo in get_repositories_for_indexing(sa_session): category_names = [] - for rca in sa_session.query(model.RepositoryCategoryAssociation).filter( - model.RepositoryCategoryAssociation.repository_id == repo.id - ): - for category in sa_session.query(model.Category).filter(model.Category.id == rca.category.id): - category_names.append(category.name.lower()) + for rca in get_repo_cat_associations(sa_session, repo.id): + category = sa_session.get(model.Category, rca.category.id) + category_names.append(category.name.lower()) categories = (",").join(category_names) repo_id = repo.id name = repo.name @@ -118,7 +112,7 @@ def get_repos(sa_session, file_path, hgweb_config_dir, **kwargs): repo_owner_username = "" if repo.user_id is not None: - user = sa_session.query(model.User).filter(model.User.id == repo.user_id).one() + user = sa_session.get(model.User, repo.user_id) repo_owner_username = user.username.lower() last_updated = pretty_print_time_interval(repo.update_time) @@ -196,3 +190,23 @@ def load_one_dir(path): ) tools_in_dir.append(tool) return tools_in_dir + + +def get_repositories_for_indexing(session): + # Do not index deleted, deprecated, or "tool_dependency_definition" type repositories. + Repository = model.Repository + stmt = ( + select(Repository) + .where(Repository.deleted == false()) + .where(Repository.deprecated == false()) + .where(Repository.type != "tool_dependency_definition") + .order_by(Repository.update_time.desc()) + ) + return session.scalars(stmt) + + +def get_repo_cat_associations(session, repository_id): + stmt = select(model.RepositoryCategoryAssociation).where( + model.RepositoryCategoryAssociation.repository_id == repository_id + ) + return session.scalars(stmt) diff --git a/lib/tool_shed/util/shed_util_common.py b/lib/tool_shed/util/shed_util_common.py index 548f8ede40e4..3f107c652ab3 100644 --- a/lib/tool_shed/util/shed_util_common.py +++ b/lib/tool_shed/util/shed_util_common.py @@ -5,10 +5,10 @@ import string from typing import TYPE_CHECKING -import sqlalchemy.orm.exc from sqlalchemy import ( - and_, false, + func, + select, true, ) @@ -89,46 +89,41 @@ def count_repositories_in_category(app: "ToolShedApp", category_id: str) -> int: - sa_session = app.model.session - return ( - sa_session.query(app.model.RepositoryCategoryAssociation) - .filter(app.model.RepositoryCategoryAssociation.table.c.category_id == app.security.decode_id(category_id)) - .count() + stmt = ( + select(func.count()) + .select_from(app.model.RepositoryCategoryAssociation) + .where(app.model.RepositoryCategoryAssociation.category_id == app.security.decode_id(category_id)) ) + return app.model.session.scalar(stmt) def get_categories(app: "ToolShedApp"): """Get all categories from the database.""" sa_session = app.model.session - return ( - sa_session.query(app.model.Category) - .filter(app.model.Category.table.c.deleted == false()) - .order_by(app.model.Category.table.c.name) - .all() - ) + stmt = select(app.model.Category).where(app.model.Category.deleted == false()).order_by(app.model.Category.name) + return sa_session.scalars(stmt).all() def get_category(app: "ToolShedApp", id: str): """Get a category from the database.""" sa_session = app.model.session - return sa_session.query(app.model.Category).get(app.security.decode_id(id)) + return sa_session.get(app.model.Category, app.security.decode_id(id)) def get_category_by_name(app: "ToolShedApp", name: str): """Get a category from the database via name.""" sa_session = app.model.session - try: - return sa_session.query(app.model.Category).filter_by(name=name).one() - except sqlalchemy.orm.exc.NoResultFound: - return None + stmt = select(app.model.Category).filter_by(name=name).limit(1) + return sa_session.scalars(stmt).first() def get_repository_categories(app, id): """Get categories of a repository on the tool shed side from the database via id""" sa_session = app.model.session - return sa_session.query(app.model.RepositoryCategoryAssociation).filter( - app.model.RepositoryCategoryAssociation.table.c.repository_id == app.security.decode_id(id) + stmt = select(app.model.RepositoryCategoryAssociation).where( + app.model.RepositoryCategoryAssociation.repository_id == app.security.decode_id(id) ) + return sa_session.scalars(stmt).all() def get_repository_file_contents(app, file_path, repository_id, is_admin=False): @@ -336,9 +331,7 @@ def handle_email_alerts(app, host, repository, content_alert_str="", new_repo_al subject = f"Galaxy tool shed alert for new repository named {str(repository.name)}" subject = subject[:80] email_alerts = [] - for user in sa_session.query(app.model.User).filter( - and_(app.model.User.table.c.deleted == false(), app.model.User.table.c.new_repo_alert == true()) - ): + for user in get_users_with_repo_alert(sa_session.query, app.model.User): if admin_only: if user.email in admin_users: email_alerts.append(user.email) @@ -468,3 +461,8 @@ def tool_shed_is_this_tool_shed(toolshed_base_url, trans=None): "set_image_paths", "tool_shed_is_this_tool_shed", ) + + +def get_users_with_repo_alert(session, user_model): + stmt = select(user_model).where(user_model.deleted == false()).where(user_model.new_repo_alert == true()) + return session.scalars(stmt) diff --git a/lib/tool_shed/webapp/api/groups.py b/lib/tool_shed/webapp/api/groups.py index 72329e20b787..14228462e675 100644 --- a/lib/tool_shed/webapp/api/groups.py +++ b/lib/tool_shed/webapp/api/groups.py @@ -4,6 +4,8 @@ Dict, ) +from sqlalchemy import select + from galaxy import ( util, web, @@ -24,6 +26,13 @@ ) from tool_shed.managers import groups from tool_shed.structured_app import ToolShedApp +from tool_shed.webapp.model import ( + Category, + Repository, + RepositoryCategoryAssociation, + RepositoryMetadata, + User, +) from . import BaseShedAPIController log = logging.getLogger(__name__) @@ -109,22 +118,14 @@ def _populate(self, trans, group): Turn the given group information from DB into a dict and add other characteristics like members and repositories. """ - model = trans.app.model group_dict = group.to_dict(view="collection", value_mapper=self.__get_value_mapper(trans)) group_members = [] group_repos = [] total_downloads = 0 for uga in group.users: - user = trans.sa_session.query(model.User).filter(model.User.table.c.id == uga.user_id).one() + user = trans.sa_session.get(User, uga.user_id) user_repos_count = 0 - for repo in ( - trans.sa_session.query(model.Repository) - .filter(model.Repository.table.c.user_id == uga.user_id) - .join(model.RepositoryMetadata.table) - .join(model.User.table) - .outerjoin(model.RepositoryCategoryAssociation.table) - .outerjoin(model.Category.table) - ): + for repo in get_user_repositories(trans.sa_session, uga.user_id): categories = [] for rca in repo.categories: cat_dict = dict(name=rca.category.name, id=trans.app.security.encode_id(rca.category.id)) @@ -169,3 +170,15 @@ def _populate(self, trans, group): group_dict["total_repos"] = len(group_repos) group_dict["total_downloads"] = total_downloads return group_dict + + +def get_user_repositories(session, user_id): + stmt = ( + select(Repository) + .where(Repository.user_id == user_id) + .join(RepositoryMetadata) + .join(User) + .outerjoin(RepositoryCategoryAssociation) + .outerjoin(Category) + ) + return session.scalars(stmt) diff --git a/lib/tool_shed/webapp/api/repositories.py b/lib/tool_shed/webapp/api/repositories.py index 4499534c0cdf..740f3a995421 100644 --- a/lib/tool_shed/webapp/api/repositories.py +++ b/lib/tool_shed/webapp/api/repositories.py @@ -427,9 +427,8 @@ def handle_repository(trans, repository, results): updating_installed_repository=False, persist=False, ) - query = rmm.get_query_for_setting_metadata_on_repositories(my_writable=my_writable, order=False) # First reset metadata on all repositories of type repository_dependency_definition. - for repository in query: + for repository in rmm.get_repositories_for_setting_metadata(my_writable=my_writable, order=False): encoded_id = trans.security.encode_id(repository.id) if encoded_id in encoded_ids_to_skip: log.debug( @@ -440,7 +439,7 @@ def handle_repository(trans, repository, results): elif repository.type == rt_util.TOOL_DEPENDENCY_DEFINITION and repository.id not in handled_repository_ids: results = handle_repository(trans, repository, results) # Now reset metadata on all remaining repositories. - for repository in query: + for repository in rmm.get_repositories_for_setting_metadata(my_writable=my_writable, order=False): encoded_id = trans.security.encode_id(repository.id) if encoded_id in encoded_ids_to_skip: log.debug( diff --git a/lib/tool_shed/webapp/api/repository_revisions.py b/lib/tool_shed/webapp/api/repository_revisions.py index 1b413827aced..be3b358adedf 100644 --- a/lib/tool_shed/webapp/api/repository_revisions.py +++ b/lib/tool_shed/webapp/api/repository_revisions.py @@ -4,7 +4,7 @@ Dict, ) -from sqlalchemy import and_ +from sqlalchemy import select from galaxy import ( util, @@ -16,6 +16,7 @@ metadata_util, repository_util, ) +from tool_shed.webapp.model import RepositoryMetadata from . import BaseShedAPIController log = logging.getLogger(__name__) @@ -39,32 +40,15 @@ def index(self, trans, **kwd): Displays a collection (list) of repository revisions. """ # Example URL: http://localhost:9009/api/repository_revisions - repository_metadata_dicts = [] - # Build up an anded clause list of filters. - clause_list = [] - # Filter by downloadable if received. downloadable = kwd.get("downloadable", None) - if downloadable is not None: - clause_list.append(trans.model.RepositoryMetadata.table.c.downloadable == util.asbool(downloadable)) - # Filter by malicious if received. malicious = kwd.get("malicious", None) - if malicious is not None: - clause_list.append(trans.model.RepositoryMetadata.table.c.malicious == util.asbool(malicious)) - # Filter by missing_test_components if received. missing_test_components = kwd.get("missing_test_components", None) - if missing_test_components is not None: - clause_list.append( - trans.model.RepositoryMetadata.table.c.missing_test_components == util.asbool(missing_test_components) - ) - # Filter by includes_tools if received. includes_tools = kwd.get("includes_tools", None) - if includes_tools is not None: - clause_list.append(trans.model.RepositoryMetadata.table.c.includes_tools == util.asbool(includes_tools)) - for repository_metadata in ( - trans.sa_session.query(trans.app.model.RepositoryMetadata) - .filter(and_(*clause_list)) - .order_by(trans.app.model.RepositoryMetadata.table.c.repository_id.desc()) - ): + repository_metadata_dicts = [] + all_repository_metadata = get_repository_metadata( + trans.sa_session, downloadable, malicious, missing_test_components, includes_tools + ) + for repository_metadata in all_repository_metadata: repository_metadata_dict = repository_metadata.to_dict( view="collection", value_mapper=self.__get_value_mapper(trans) ) @@ -222,3 +206,17 @@ def update(self, trans, payload, **kwd): controller="repository_revisions", action="show", id=repository_metadata_id ) return repository_metadata_dict + + +def get_repository_metadata(session, downloadable, malicious, missing_test_components, includes_tools): + stmt = select(RepositoryMetadata) + if downloadable is not None: + stmt = stmt.where(RepositoryMetadata.downloadable == util.asbool(downloadable)) + if malicious is not None: + stmt = stmt.where(RepositoryMetadata.malicious == util.asbool(malicious)) + if missing_test_components is not None: + stmt = stmt.where(RepositoryMetadata.missing_test_components == util.asbool(missing_test_components)) + if includes_tools is not None: + stmt = stmt.where(RepositoryMetadata.includes_tools == util.asbool(includes_tools)) + stmt = stmt.order_by(RepositoryMetadata.repository_id.desc()) + return session.scalars(stmt) diff --git a/lib/tool_shed/webapp/controllers/repository.py b/lib/tool_shed/webapp/controllers/repository.py index b3c1725e8507..ce313c9d0821 100644 --- a/lib/tool_shed/webapp/controllers/repository.py +++ b/lib/tool_shed/webapp/controllers/repository.py @@ -12,9 +12,13 @@ patch, ) from sqlalchemy import ( - and_, false, null, + select, +) +from toolshed.webapp.model import ( + Repository, + RepositoryMetadata, ) import tool_shed.grids.repository_grids as repository_grids @@ -56,6 +60,10 @@ from tool_shed.util.web_util import escape from tool_shed.utility_containers import ToolShedUtilityContainerManager from tool_shed.webapp.framework.decorators import require_login +from tool_shed.webapp.model import ( + Category, + RepositoryCategoryAssociation, +) from tool_shed.webapp.util import ratings_util log = logging.getLogger(__name__) @@ -1498,7 +1506,7 @@ def index(self, trans, **kwd): message = escape(kwd.get("message", "")) status = kwd.get("status", "done") # See if there are any RepositoryMetadata records since menu items require them. - repository_metadata = trans.sa_session.query(trans.model.RepositoryMetadata).first() + repository_metadata = get_first_repository_metadata(trans.sa_session) current_user = trans.user # TODO: move the following to some in-memory register so these queries can be done once # at startup. The in-memory register can then be managed during the current session. @@ -1515,9 +1523,7 @@ def index(self, trans, **kwd): if current_user.active_repositories: can_administer_repositories = True else: - for repository in trans.sa_session.query(trans.model.Repository).filter( - trans.model.Repository.table.c.deleted == false() - ): + for repository in get_current_repositories(trans.sa_session): if trans.app.security_agent.user_can_administer_repository(current_user, repository): can_administer_repositories = True break @@ -1615,16 +1621,7 @@ def manage_email_alerts(self, trans, **kwd): checked = new_repo_alert_checked or (user and user.new_repo_alert) new_repo_alert_check_box = CheckboxField("new_repo_alert", value=checked) email_alert_repositories = [] - for repository in ( - trans.sa_session.query(trans.model.Repository) - .filter( - and_( - trans.model.Repository.table.c.deleted == false(), - trans.model.Repository.table.c.email_alerts != null(), - ) - ) - .order_by(trans.model.Repository.table.c.name) - ): + for repository in get_current_email_alert_repositories(trans.sa_session): if user.email in repository.email_alerts: email_alert_repositories.append(repository) return trans.fill_template( @@ -1692,8 +1689,8 @@ def manage_repository(self, trans, id, **kwd): if category_ids: # Create category associations for category_id in category_ids: - category = trans.sa_session.query(trans.model.Category).get(trans.security.decode_id(category_id)) - rca = trans.app.model.RepositoryCategoryAssociation(repository, category) + category = trans.sa_session.get(Category, trans.security.decode_id(category_id)) + rca = RepositoryCategoryAssociation(repository, category) trans.sa_session.add(rca) with transaction(trans.sa_session): trans.sa_session.commit() @@ -1707,7 +1704,7 @@ def manage_repository(self, trans, id, **kwd): user_ids = util.listify(allow_push) usernames = [] for user_id in user_ids: - user = trans.sa_session.query(trans.model.User).get(trans.security.decode_id(user_id)) + user = trans.sa_session.get(trans.model.User, trans.security.decode_id(user_id)) usernames.append(user.username) usernames = ",".join(usernames) repository.set_allow_push(usernames, remove_auth=remove_auth) @@ -1738,7 +1735,7 @@ def manage_repository(self, trans, id, **kwd): else: current_allow_push_list = [] options = [] - for user in trans.sa_session.query(trans.model.User): + for user in trans.sa_session.scalars(select(trans.model.User)): if user.username not in current_allow_push_list: options.append(user) for obj in options: @@ -2660,3 +2657,23 @@ def validate_changeset_revision(self, trans, changeset_revision, repository_id): status="error", ) ) + + +def get_first_repository_metadata(session): + stmt = select(RepositoryMetadata).limit(1) + return session.select(stmt).first() + + +def get_current_repositories(session): + stmt = select(Repository).where(Repository.deleted == false()) + return session.scalars(stmt) + + +def get_current_email_alert_repositories(session): + stmt = ( + select(Repository) + .where(Repository.deleted == false()) + .where(Repository.email_alerts != null()) + .order_by(Repository.name) + ) + return session.scalars(stmt) diff --git a/lib/tool_shed/webapp/controllers/user.py b/lib/tool_shed/webapp/controllers/user.py index 1091647bc8e1..2740ee9f6c7d 100644 --- a/lib/tool_shed/webapp/controllers/user.py +++ b/lib/tool_shed/webapp/controllers/user.py @@ -2,13 +2,13 @@ import socket from markupsafe import escape -from sqlalchemy import func from galaxy import ( util, web, ) from galaxy.managers.api_keys import ApiKeyManager +from galaxy.managers.users import get_user_by_email from galaxy.model.base import transaction from galaxy.security.validate_user_input import ( validate_email, @@ -247,18 +247,10 @@ def reset_password(self, trans, email=None, **kwd): "Please check your email account for more instructions. " "If you do not receive an email shortly, please contact an administrator." % (escape(email)) ) - reset_user = ( - trans.sa_session.query(trans.app.model.User) - .filter(trans.app.model.User.table.c.email == email) - .first() - ) + reset_user = get_user_by_email(trans.sa_session, email, trans.app.model.User) if not reset_user: # Perform a case-insensitive check only if the user wasn't found - reset_user = ( - trans.sa_session.query(trans.app.model.User) - .filter(func.lower(trans.app.model.User.table.c.email) == func.lower(email)) - .first() - ) + reset_user = get_user_by_email(trans.sa_session, email, trans.app.model.User, False) if reset_user: prt = trans.app.model.PasswordResetToken(reset_user) trans.sa_session.add(prt) @@ -291,7 +283,7 @@ def manage_user_info(self, trans, cntrller, **kwd): params = util.Params(kwd) user_id = params.get("id", None) if user_id: - user = trans.sa_session.query(trans.app.model.User).get(trans.security.decode_id(user_id)) + user = trans.sa_session.get(trans.app.model.User, trans.security.decode_id(user_id)) else: user = trans.user if not user: @@ -336,7 +328,7 @@ def edit_username(self, trans, cntrller, **kwd): status = params.get("status", "done") user_id = params.get("user_id", None) if user_id and is_admin: - user = trans.sa_session.query(trans.app.model.User).get(trans.security.decode_id(user_id)) + user = trans.sa_session.get(trans.app.model.User, trans.security.decode_id(user_id)) else: user = trans.user if user and params.get("change_username_button", False): @@ -371,7 +363,7 @@ def edit_info(self, trans, cntrller, **kwd): status = params.get("status", "done") user_id = params.get("user_id", None) if user_id and is_admin: - user = trans.sa_session.query(trans.app.model.User).get(trans.security.decode_id(user_id)) + user = trans.sa_session.get(trans.app.model.User, trans.security.decode_id(user_id)) elif user_id and (not trans.user or trans.user.id != trans.security.decode_id(user_id)): message = "Invalid user id" status = "error" @@ -422,8 +414,8 @@ def edit_info(self, trans, cntrller, **kwd): # Edit user information - webapp MUST BE 'galaxy' user_type_fd_id = params.get("user_type_fd_id", "none") if user_type_fd_id not in ["none"]: - user_type_form_definition = trans.sa_session.query(trans.app.model.FormDefinition).get( - trans.security.decode_id(user_type_fd_id) + user_type_form_definition = trans.sa_session.get( + trans.app.model.FormDefinition, trans.security.decode_id(user_type_fd_id) ) elif user.values: user_type_form_definition = user.values.form_definition