Skip to content

Commit

Permalink
Drop old repositories, group repo functions in one module
Browse files Browse the repository at this point in the history
  • Loading branch information
jdavcs committed Sep 20, 2023
1 parent 077a8b3 commit 1175c77
Show file tree
Hide file tree
Showing 35 changed files with 218 additions and 348 deletions.
5 changes: 2 additions & 3 deletions lib/galaxy/jobs/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
)
from galaxy.jobs.mapper import JobNotReadyException
from galaxy.model.base import transaction
from galaxy.model.repositories.job import JobRepository
from galaxy.model.repositories import get_jobs_to_check_at_startup
from galaxy.structured_app import MinimalManagerApp
from galaxy.util import unicodify
from galaxy.util.custom_logging import get_logger
Expand Down Expand Up @@ -275,8 +275,7 @@ def __check_jobs_at_startup(self):
"""
with self.sa_session() as session, session.begin():
try:
job_repo = JobRepository(self.sa_session())
for job in job_repo.get_jobs_to_check_at_startup(self.track_jobs_in_database, self.app.config):
for job in get_jobs_to_check_at_startup(session, self.track_jobs_in_database, self.app.config):
with session.begin_nested():
self._check_job_at_startup(job)
finally:
Expand Down
23 changes: 12 additions & 11 deletions lib/galaxy/managers/api_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@

from galaxy.model import User
from galaxy.model.base import transaction
from galaxy.model.repositories.api_keys import APIKeysRepository
from galaxy.model.repositories import (
get_api_key,
mark_all_api_keys_as_deleted,
)
from galaxy.structured_app import BasicSharedApp

if TYPE_CHECKING:
Expand All @@ -15,20 +18,19 @@
class ApiKeyManager:
def __init__(self, app: BasicSharedApp):
self.app = app
self.apikeys_repo = APIKeysRepository(self.app.model.context)
self.session = self.app.model.context

def get_api_key(self, user: User) -> Optional["APIKeys"]:
return self.apikeys_repo.get_api_key(user.id)
return get_api_key(self.session, user.id)

def create_api_key(self, user: User) -> "APIKeys":
guid = self.app.security.get_new_guid()
new_key = self.app.model.APIKeys()
new_key.user_id = user.id
new_key.key = guid
sa_session = self.app.model.context
sa_session.add(new_key)
with transaction(sa_session):
sa_session.commit()
self.session.add(new_key)
with transaction(self.session):
self.session.commit()
return new_key

def get_or_create_api_key(self, user: User) -> str:
Expand All @@ -43,7 +45,6 @@ def delete_api_key(self, user: User) -> None:
"""Marks the current user API key as deleted."""
# Before it was possible to create multiple API keys for the same user although they were not considered valid
# So all non-deleted keys are marked as deleted for backward compatibility
self.apikeys_repo.mark_all_as_deleted(user.id)
sa_session = self.app.model.context
with transaction(sa_session):
sa_session.commit()
mark_all_api_keys_as_deleted(self.session, user.id)
with transaction(self.session):
self.session.commit()
4 changes: 2 additions & 2 deletions lib/galaxy/managers/dbkeys.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)

from galaxy.model import HistoryDatasetAssociation
from galaxy.model.repositories.hda import HistoryDatasetAssociationRepository as HDARepository
from galaxy.model.repositories import get_len_files_by_history
from galaxy.util import (
galaxy_directory,
sanitize_lists_to_string,
Expand Down Expand Up @@ -98,7 +98,7 @@ def get_genome_build_names(self, trans=None):
# It does allow one-off, history specific dbkeys to be created by a user. But we are not filtering,
# so a len file will be listed twice (as the build name and again as dataset name),
# if custom dbkey creation/conversion occurred within the current history.
datasets = HDARepository(trans.sa_session).get_len_files_by_history(trans.history.id)
datasets = get_len_files_by_history(trans.sa_session, trans.history.id)
for dataset in datasets:
rval.append((dataset.dbkey, f"{dataset.name} ({dataset.dbkey}) [History]"))
user = trans.user
Expand Down
4 changes: 2 additions & 2 deletions lib/galaxy/managers/group_roles.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from galaxy.exceptions import ObjectNotFound
from galaxy.managers.context import ProvidesAppContext
from galaxy.model.base import transaction
from galaxy.model.repositories.gra import GroupRoleAssociationRepository
from galaxy.model.repositories import get_group_role
from galaxy.structured_app import MinimalManagerApp

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -76,7 +76,7 @@ def _get_role(self, trans: ProvidesAppContext, role_id: int) -> model.Role:
def _get_group_role(
self, trans: ProvidesAppContext, group: model.Group, role: model.Role
) -> Optional[model.GroupRoleAssociation]:
return GroupRoleAssociationRepository(trans.sa_session).get_group_role(group, role)
return get_group_role(trans.sa_session, group, role)

def _add_role_to_group(self, trans: ProvidesAppContext, group: model.Group, role: model.Role):
gra = model.GroupRoleAssociation(group, role)
Expand Down
4 changes: 2 additions & 2 deletions lib/galaxy/managers/group_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from galaxy.managers.context import ProvidesAppContext
from galaxy.model import User
from galaxy.model.base import transaction
from galaxy.model.repositories.uga import UserGroupAssociationRepository
from galaxy.model.repositories import get_group_user
from galaxy.structured_app import MinimalManagerApp

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -77,7 +77,7 @@ def _get_user(self, trans: ProvidesAppContext, user_id: int) -> model.User:
def _get_group_user(
self, trans: ProvidesAppContext, group: model.Group, user: model.User
) -> Optional[model.UserGroupAssociation]:
return UserGroupAssociationRepository(trans.sa_session).get_group_user(user, group)
return get_group_user(trans.sa_session, user, group)

def _add_user_to_group(self, trans: ProvidesAppContext, group: model.Group, user: model.User):
gra = model.UserGroupAssociation(user, group)
Expand Down
21 changes: 13 additions & 8 deletions lib/galaxy/managers/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
List,
)

from sqlalchemy import select

from galaxy import model
from galaxy.exceptions import (
Conflict,
Expand All @@ -12,10 +14,13 @@
)
from galaxy.managers.base import decode_id
from galaxy.managers.context import ProvidesAppContext
from galaxy.model import Role
from galaxy.model.base import transaction
from galaxy.model.repositories.group import GroupRepository
from galaxy.model.repositories.role import RoleRepository
from galaxy.model.repositories.user import UserRepository
from galaxy.model.repositories import (
get_group_by_name,
get_not_deleted_groups,
get_users_by_ids,
)
from galaxy.model.scoped_session import galaxy_scoped_session
from galaxy.schema.fields import (
DecodedDatabaseIdField,
Expand All @@ -36,8 +41,7 @@ def index(self, trans: ProvidesAppContext):
Displays a collection (list) of groups.
"""
rval = []
group_repo = GroupRepository(trans.sa_session)
for group in group_repo.get_not_deleted_groups():
for group in get_not_deleted_groups(trans.sa_session):
item = group.to_dict(value_mapper={"id": DecodedDatabaseIdField.encode})
encoded_id = DecodedDatabaseIdField.encode(group.id)
item["url"] = url_for("group", id=encoded_id)
Expand Down Expand Up @@ -103,7 +107,7 @@ def update(self, trans: ProvidesAppContext, group_id: int, payload: Dict[str, An
sa_session.commit()

def _check_duplicated_group_name(self, sa_session: galaxy_scoped_session, group_name: str) -> None:
if GroupRepository(sa_session).get_by_name(group_name):
if get_group_by_name(sa_session, group_name):
raise Conflict(f"A group with name '{group_name}' already exists")

def _get_group(self, sa_session: galaxy_scoped_session, group_id: int) -> model.Group:
Expand All @@ -116,13 +120,14 @@ def _get_users_by_encoded_ids(
self, sa_session: galaxy_scoped_session, encoded_user_ids: List[EncodedDatabaseIdField]
) -> List[model.User]:
user_ids = self._decode_ids(encoded_user_ids)
return UserRepository(sa_session).get_users_by_ids(user_ids)
return get_users_by_ids(sa_session, user_ids)

def _get_roles_by_encoded_ids(
self, sa_session: galaxy_scoped_session, encoded_role_ids: List[EncodedDatabaseIdField]
) -> List[model.Role]:
role_ids = self._decode_ids(encoded_role_ids)
return RoleRepository(sa_session).get_roles_by_ids(role_ids)
stmt = select(Role).where(Role.id.in_(role_ids))
return sa_session.scalars(stmt).all()

def _decode_id(self, encoded_id: EncodedDatabaseIdField) -> int:
return decode_id(self._app, encoded_id)
Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy/managers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
)
from galaxy.model import UserQuotaUsage
from galaxy.model.base import transaction
from galaxy.model.repositories.user import (
from galaxy.model.repositories import (
get_user_by_email,
get_user_by_username,
)
Expand Down
167 changes: 157 additions & 10 deletions lib/galaxy/model/repositories/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,164 @@
from typing import Any
from typing import Optional

from sqlalchemy import select
from sqlalchemy import (
cast,
false,
func,
Integer,
null,
or_,
select,
true,
update,
)
from sqlalchemy.orm import Session

from galaxy.model import Base
from galaxy.model import (
APIKeys,
Group,
GroupRoleAssociation,
HistoryDatasetAssociation,
Job,
PageRevision,
Quota,
StoredWorkflowUserShareAssociation,
User,
UserGroupAssociation,
YIELD_PER_ROWS,
)
from galaxy.model.tool_shed_install import ToolShedRepository

MappedType = Base

# all models
def get_all(session: Session, model_class):
return session.scalars(select(model_class)).all()

class ModelRepository:
def __init__(self, session: Session, model_class: MappedType):
self.session = session
self.model_class = model_class

def get_all(self):
return self.session.scalars(select(self.model_class)).all()
# api_keys
def get_api_key(session: Session, user_id: int):
stmt = select(APIKeys).filter_by(user_id=user_id, deleted=False).order_by(APIKeys.create_time.desc()).limit(1)
return session.scalars(stmt).first()


def mark_all_api_keys_as_deleted(session: Session, user_id: int):
stmt = (
update(APIKeys)
.where(APIKeys.user_id == user_id)
.where(APIKeys.deleted == false())
.values(deleted=True)
.execution_options(synchronize_session="evaluate")
)
return session.execute(stmt)


# groups
def get_group_by_name(session: Session, name: str):
stmt = select(Group).filter(Group.name == name).limit(1)
return session.scalars(stmt).first()


def get_not_deleted_groups(session: Session):
stmt = select(Group).where(Group.deleted == false())
return session.scalars(stmt)


# hdas
def get_fasta_hdas_by_history(session: Session, history_id: int):
stmt = (
select(HistoryDatasetAssociation)
.filter_by(history_id=history_id, extension="fasta", deleted=False)
.order_by(HistoryDatasetAssociation.hid.desc())
)
return session.scalars(stmt).all()


def get_len_files_by_history(session: Session, history_id: int):
stmt = select(HistoryDatasetAssociation).filter_by(history_id=history_id, extension="len", deleted=False)
return session.scalars(stmt)


# jobs
def get_jobs_to_check_at_startup(session: Session, track_jobs_in_database: bool, config):
if track_jobs_in_database:
in_list = (Job.states.QUEUED, Job.states.RUNNING, Job.states.STOPPED)
else:
in_list = (Job.states.NEW, Job.states.QUEUED, Job.states.RUNNING)

stmt = (
select(Job)
.execution_options(yield_per=YIELD_PER_ROWS)
.filter(Job.state.in_(in_list) & (Job.handler == config.server_name))
)
if config.user_activation_on:
# Filter out the jobs of inactive users.
stmt = stmt.outerjoin(User).filter(or_((Job.user_id == null()), (User.active == true())))

return session.scalars(stmt)


# page_revisions
def get_page_revision(session: Session, page_id: int):
stmt = select(PageRevision).filter_by(page_id=page_id)
return session.scalars(stmt)


# quotas
def get_quotas(session: Session, deleted: bool = False):
is_deleted = true()
if not deleted:
is_deleted = false()
stmt = select(Quota).where(Quota.deleted == is_deleted)
return session.scalars(stmt)


# roles
def get_group_role(session: Session, group, role) -> Optional[GroupRoleAssociation]:
stmt = (
select(GroupRoleAssociation).where(GroupRoleAssociation.group == group).where(GroupRoleAssociation.role == role)
)
return session.execute(stmt).scalar_one_or_none()


# tool_shed_repositories
def get_tool_shed_repositories(session: Session, **kwd):
stmt = select(ToolShedRepository)
for key, value in kwd.items():
if value is not None:
column = ToolShedRepository.table.c[key]
stmt = stmt.filter(column == value)
stmt = stmt.order_by(ToolShedRepository.name).order_by(cast(ToolShedRepository.ctx_rev, Integer).desc())
return session.scalars(stmt).all()


# users
def get_group_user(session: Session, user, group) -> Optional[UserGroupAssociation]:
stmt = (
select(UserGroupAssociation).where(UserGroupAssociation.user == user).where(UserGroupAssociation.group == group)
)
return session.execute(stmt).scalar_one_or_none()


def get_users_by_ids(session: Session, user_ids):
stmt = select(User).where(User.id.in_(user_ids))
return session.scalars(stmt).all()


# The get_user_by_email and get_user_by_username functions may be called from
# the tool_shed app, which has its own User model, which is different from
# galaxy.model.User. In that case, the tool_shed user model should be passed as
# the model_class argument.
def get_user_by_email(session, email: str, model_class=User):
stmt = select(model_class).filter(model_class.email == email).limit(1)
return session.scalars(stmt).first()


def get_user_by_username(session, username: str, model_class=User):
stmt = select(model_class).filter(model_class.username == username).limit(1)
return session.scalars(stmt).first()


# workflows
def count_stored_workflow_user_assocs(session: Session, user, stored_workflow) -> int:
stmt = select(StoredWorkflowUserShareAssociation).filter_by(user=user, stored_workflow=stored_workflow)
stmt = select(func.count()).select_from(stmt)
return session.scalar(stmt)
27 changes: 0 additions & 27 deletions lib/galaxy/model/repositories/api_keys.py

This file was deleted.

Loading

0 comments on commit 1175c77

Please sign in to comment.