From c6cac90b21d3168876c782f1e4d84512f7ebfa16 Mon Sep 17 00:00:00 2001 From: John Davis Date: Wed, 20 Sep 2023 14:39:26 -0400 Subject: [PATCH] Refactor get_user_by_username, get_user_by_email; use across code base --- lib/galaxy/managers/users.py | 34 +++++-------------- lib/galaxy/visualization/genomes.py | 20 +++++------ lib/galaxy/webapps/galaxy/controllers/user.py | 5 ++- lib/galaxy/webapps/galaxy/services/quotas.py | 2 +- .../webapps/reports/controllers/jobs.py | 4 +-- lib/tool_shed/test/base/test_db_util.py | 7 ++-- .../webapp/controllers/repository.py | 6 ++-- test/integration/test_user_preferences.py | 6 ++-- test/integration/test_vault_extra_prefs.py | 5 ++- test/integration/test_vault_file_source.py | 17 ++++------ 10 files changed, 37 insertions(+), 69 deletions(-) diff --git a/lib/galaxy/managers/users.py b/lib/galaxy/managers/users.py index 4489a00689e6..86879131d12d 100644 --- a/lib/galaxy/managers/users.py +++ b/lib/galaxy/managers/users.py @@ -20,7 +20,6 @@ and_, exc, func, - select, true, ) from sqlalchemy.orm.exc import NoResultFound @@ -38,6 +37,10 @@ ) from galaxy.model import UserQuotaUsage from galaxy.model.base import transaction +from galaxy.model.repositories import ( + get_user_by_email, + get_user_by_username, +) from galaxy.security.validate_user_input import ( VALID_EMAIL_RE, validate_email, @@ -355,7 +358,7 @@ def get_user_by_identity(self, identity): user = None if VALID_EMAIL_RE.match(identity): # VALID_PUBLICNAME and VALID_EMAIL do not overlap, so 'identity' here is an email address - user = self.session().query(self.model_class).filter(self.model_class.table.c.email == identity).first() + user = get_user_by_email(self.session(), identity, self.model_class) if not user: # Try a case-insensitive match on the email user = ( @@ -365,7 +368,7 @@ def get_user_by_identity(self, identity): .first() ) else: - user = self.session().query(self.model_class).filter(self.model_class.table.c.username == identity).first() + user = get_user_by_username(self.session(), identity, self.model_class) return user # ---- current @@ -528,7 +531,7 @@ def __get_activation_token(self, trans, email): """ Check for the activation token. Create new activation token and store it in the database if no token found. """ - user = trans.sa_session.query(self.app.model.User).filter(self.app.model.User.table.c.email == email).first() + user = get_user_by_email(trans.sa_session, email, self.app.model.User) activation_token = user.activation_token if activation_token is None: activation_token = util.hash_util.new_secure_hash_v2(str(random.getrandbits(256))) @@ -572,9 +575,7 @@ def send_reset_email(self, trans, payload, **kwd): return "Failed to produce password reset token. User not found." def get_reset_token(self, trans, email): - reset_user = ( - trans.sa_session.query(self.app.model.User).filter(self.app.model.User.table.c.email == email).first() - ) + reset_user = get_user_by_email(trans.sa_session, email, self.app.model.User) if not reset_user and email != email.lower(): reset_user = ( trans.sa_session.query(self.app.model.User) @@ -618,12 +619,7 @@ def get_or_create_remote_user(self, remote_user_email): return None if getattr(self.app.config, "normalize_remote_user_email", False): remote_user_email = remote_user_email.lower() - user = ( - self.session() - .query(self.app.model.User) - .filter(self.app.model.User.table.c.email == remote_user_email) - .first() - ) + user = get_user_by_email(self.session(), remote_user_email, self.app.model.User) if user: # GVK: June 29, 2009 - This is to correct the behavior of a previous bug where a private # role and default user / history permissions were not set for remote users. When a @@ -838,15 +834,3 @@ def _add_parsers(self): ) self.fn_filter_parsers.update({}) - - -def get_user_by_username(session, user_class, username): - """ - Get a user from the database by username. - (We pass the session and the user_class to accommodate usage from the tool_shed app.) - """ - try: - stmt = select(user_class).filter(user_class.username == username) - return session.execute(stmt).scalar_one() - except Exception: - return None diff --git a/lib/galaxy/visualization/genomes.py b/lib/galaxy/visualization/genomes.py index 95cbb31b4654..6981e4deccec 100644 --- a/lib/galaxy/visualization/genomes.py +++ b/lib/galaxy/visualization/genomes.py @@ -6,13 +6,13 @@ from typing import Dict from bx.seq.twobit import TwoBitFile -from sqlalchemy import select from galaxy.exceptions import ( ObjectNotFound, ReferenceDataError, ) -from galaxy.model.repositories.hda import HistoryDatasetAssociationRepository as hda_repo +from galaxy.model import HistoryDatasetAssociation +from galaxy.model.repositories import get_user_by_username from galaxy.structured_app import StructuredApp from galaxy.util.bunch import Bunch @@ -289,11 +289,12 @@ def chroms(self, trans, dbkey=None, num=None, chrom=None, low=None): Returns a naturally sorted list of chroms/contigs for a given dbkey. Use either chrom or low to specify the starting chrom in the return list. """ + session = trans.sa_session self.check_and_reload() # If there is no dbkey owner, default to current user. dbkey_owner, dbkey = decode_dbkey(dbkey) if dbkey_owner: - dbkey_user = self._get_dbkey_user(trans, dbkey_owner) + dbkey_user = get_user_by_username(session, dbkey_owner) else: dbkey_user = trans.user @@ -309,10 +310,9 @@ def chroms(self, trans, dbkey=None, num=None, chrom=None, low=None): if dbkey in user_keys: dbkey_attributes = user_keys[dbkey] dbkey_name = dbkey_attributes["name"] - _hda_repo = hda_repo(trans.sa_session) # If there's a fasta for genome, convert to 2bit for later use. if "fasta" in dbkey_attributes: - build_fasta = _hda_repo.get(dbkey_attributes["fasta"]) + build_fasta = session.get(HistoryDatasetAssociation, dbkey_attributes["fasta"]) len_file = build_fasta.get_converted_dataset(trans, "len").file_name build_fasta.get_converted_dataset(trans, "twobit") # HACK: set twobit_file to True rather than a file name because @@ -321,7 +321,7 @@ def chroms(self, trans, dbkey=None, num=None, chrom=None, low=None): twobit_file = True # Backwards compatibility: look for len file directly. elif "len" in dbkey_attributes: - len_file = _hda_repo.get(user_keys[dbkey]["len"]).file_name + len_file = session.get(HistoryDatasetAssociation, user_keys[dbkey]["len"]).file_name if len_file: genome = Genome(dbkey, dbkey_name, len_file=len_file, twobit_file=twobit_file) @@ -370,7 +370,7 @@ def reference(self, trans, dbkey, chrom, low, high): # If there is no dbkey owner, default to current user. dbkey_owner, dbkey = decode_dbkey(dbkey) if dbkey_owner: - dbkey_user = self._get_dbkey_user(trans, dbkey_owner) + dbkey_user = get_user_by_username(trans.sa_session, dbkey_owner) else: dbkey_user = trans.user @@ -387,7 +387,7 @@ def reference(self, trans, dbkey, chrom, low, high): else: user_keys = loads(dbkey_user.preferences["dbkeys"]) dbkey_attributes = user_keys[dbkey] - fasta_dataset = hda_repo(trans.sa_session).get(dbkey_attributes["fasta"]) + fasta_dataset = trans.sa_session.get(HistoryDatasetAssociation, dbkey_attributes["fasta"]) msg = fasta_dataset.convert_dataset(trans, "twobit") if msg: return msg @@ -405,7 +405,3 @@ def _get_reference_data(twobit_file_name, chrom, low, high): if chrom in twobit: seq_data = twobit[chrom].get(int(low), int(high)) return GenomeRegion(chrom=chrom, start=low, end=high, sequence=seq_data) - - def _get_dbkey_user(self, trans, dbkey_owner): - stmt = select(trans.app.model.User).filter_by(username=dbkey_owner).limit(1) - return trans.sa_session.scalars(stmt).first() diff --git a/lib/galaxy/webapps/galaxy/controllers/user.py b/lib/galaxy/webapps/galaxy/controllers/user.py index 193255a0b59a..275510cff2ee 100644 --- a/lib/galaxy/webapps/galaxy/controllers/user.py +++ b/lib/galaxy/webapps/galaxy/controllers/user.py @@ -18,6 +18,7 @@ ) from galaxy.exceptions import Conflict from galaxy.managers import users +from galaxy.model.repositories import get_user_by_email from galaxy.security.validate_user_input import ( validate_email, validate_publicname, @@ -293,9 +294,7 @@ def activate(self, trans, **kwd): ) else: # Find the user - user = ( - trans.sa_session.query(trans.app.model.User).filter(trans.app.model.User.table.c.email == email).first() - ) + user = get_user_by_email(trans.sa_session, email) if not user: # Probably wrong email address return trans.show_error_message( diff --git a/lib/galaxy/webapps/galaxy/services/quotas.py b/lib/galaxy/webapps/galaxy/services/quotas.py index 8e7a9d9fff74..17f4630183e3 100644 --- a/lib/galaxy/webapps/galaxy/services/quotas.py +++ b/lib/galaxy/webapps/galaxy/services/quotas.py @@ -119,7 +119,7 @@ def get_user_id(item): try: return trans.security.decode_id(item) except Exception: - return user_repo.get_by_email(item).id + return get_user_by_email(trans.sa_session, item).id def get_group_id(item): try: diff --git a/lib/galaxy/webapps/reports/controllers/jobs.py b/lib/galaxy/webapps/reports/controllers/jobs.py index 1c6e9a9c2ec6..0c4e684cf62e 100644 --- a/lib/galaxy/webapps/reports/controllers/jobs.py +++ b/lib/galaxy/webapps/reports/controllers/jobs.py @@ -18,7 +18,6 @@ and_, not_, or_, - select, ) from galaxy import ( @@ -1307,8 +1306,7 @@ def get_monitor_id(trans, monitor_email): A convenience method to obtain the monitor job id. """ monitor_user_id = None - stmt = select(trans.model.User.id).filter(trans.model.User.email == monitor_email).limit(1) - monitor_row = trans.sa_session.scalars(stmt).first() + monitor_row = get_user_by_email(trans.sa_session, monitor_email) if monitor_row is not None: monitor_user_id = monitor_row[0] return monitor_user_id diff --git a/lib/tool_shed/test/base/test_db_util.py b/lib/tool_shed/test/base/test_db_util.py index 3b4cbce31f16..1f1655ec73a3 100644 --- a/lib/tool_shed/test/base/test_db_util.py +++ b/lib/tool_shed/test/base/test_db_util.py @@ -10,6 +10,7 @@ import galaxy.model import galaxy.model.tool_shed_install import tool_shed.webapp.model as model +from galaxy.model.repositories import get_user_by_username log = logging.getLogger("test.tool_shed.test_db_util") @@ -171,10 +172,6 @@ def get_user(email): return sa_session().query(model.User).filter(model.User.table.c.email == email).first() -def get_user_by_name(username): - return sa_session().query(model.User).filter(model.User.table.c.username == username).first() - - def mark_obj_deleted(obj): obj.deleted = True sa_session().add(obj) @@ -190,7 +187,7 @@ def ga_refresh(obj): def get_repository_by_name_and_owner(name, owner_username, return_multiple=False): - owner = get_user_by_name(owner_username) + owner = get_user_by_username(sa_session(), owner_username, model.User) repository = ( sa_session() .query(model.Repository) diff --git a/lib/tool_shed/webapp/controllers/repository.py b/lib/tool_shed/webapp/controllers/repository.py index e7627292f426..0f0b9548e7f7 100644 --- a/lib/tool_shed/webapp/controllers/repository.py +++ b/lib/tool_shed/webapp/controllers/repository.py @@ -24,8 +24,8 @@ util, web, ) -from galaxy.managers.users import get_user_by_username from galaxy.model.base import transaction +from galaxy.model.repositories import get_user_by_username from galaxy.tool_shed.util import dependency_display from galaxy.tools.repositories import ValidationContext from galaxy.web.form_builder import ( @@ -2293,7 +2293,7 @@ def set_malicious(self, trans, id, ctx_str, **kwd): def sharable_owner(self, trans, owner): """Support for sharable URL for each repository owner's tools, e.g. http://example.org/view/owner.""" try: - user = get_user_by_username(trans.model.session, trans.model.User, owner) + user = get_user_by_username(trans.model.session, owner, trans.model.User) except Exception: user = None if user: @@ -2321,7 +2321,7 @@ def sharable_repository(self, trans, owner, name): else: # If the owner is valid, then show all of their repositories. try: - user = get_user_by_username(trans.model.session, trans.model.User, owner) + user = get_user_by_username(trans.model.session, owner, trans.model.User) except Exception: user = None if user: diff --git a/test/integration/test_user_preferences.py b/test/integration/test_user_preferences.py index 84dd304b643c..f7a365393535 100644 --- a/test/integration/test_user_preferences.py +++ b/test/integration/test_user_preferences.py @@ -7,8 +7,8 @@ get, put, ) -from sqlalchemy import select +from galaxy.model.repositories import get_user_by_email from galaxy_test.driver import integration_util TEST_USER_EMAIL = "test_user_preferences@bx.psu.edu" @@ -19,8 +19,8 @@ def test_user_theme(self): user = self._setup_user(TEST_USER_EMAIL) url = self._api_url(f"users/{user['id']}/theme/test_theme", params=dict(key=self.master_api_key)) app = cast(Any, self._test_driver.app if self._test_driver else None) - stmt = select(app.model.User).filter(app.model.User.email == user["email"]).limit(1) - db_user = app.model.session.scalars(stmt).first() + + db_user = get_user_by_email(app.model.session, user["email"]) # create some initial data put(url) diff --git a/test/integration/test_vault_extra_prefs.py b/test/integration/test_vault_extra_prefs.py index 4599d2decbe5..40b82c0ae04b 100644 --- a/test/integration/test_vault_extra_prefs.py +++ b/test/integration/test_vault_extra_prefs.py @@ -9,8 +9,8 @@ get, put, ) -from sqlalchemy import select +from galaxy.model.repositories import get_user_by_email from galaxy_test.driver import integration_util TEST_USER_EMAIL = "vault_test_user@bx.psu.edu" @@ -133,5 +133,4 @@ def __url(self, action, user): return self._api_url(f"users/{user['id']}/{action}", params=dict(key=self.master_api_key)) def _get_dbuser(self, app, user): - stmt = select(app.model.User).filter(app.model.User.email == user["email"]).limit(1) - return app.model.session.scalars(stmt).first() + return get_user_by_email(app.model.session, user["email"]) diff --git a/test/integration/test_vault_file_source.py b/test/integration/test_vault_file_source.py index ab6295058ff9..a3cf477721a5 100644 --- a/test/integration/test_vault_file_source.py +++ b/test/integration/test_vault_file_source.py @@ -1,8 +1,7 @@ import os import tempfile -from sqlalchemy import select - +from galaxy.model.repositories import get_user_by_email from galaxy.security.vault import UserVaultWrapper from galaxy_test.base import api_asserts from galaxy_test.base.populators import DatasetPopulator @@ -36,7 +35,7 @@ def test_vault_secret_per_user_in_file_source(self): app = self._app with self._different_user(email=self.USER_1_APP_VAULT_ENTRY): - user = self._get_user_by_email(self.USER_1_APP_VAULT_ENTRY) + user = get_user_by_email(self._app.model.session, self.USER_1_APP_VAULT_ENTRY) user_vault = UserVaultWrapper(app.vault, user) # use a valid symlink path so the posix list succeeds user_vault.write_secret("posix/root_path", app.config.user_library_import_symlink_allowlist[0]) @@ -48,7 +47,7 @@ def test_vault_secret_per_user_in_file_source(self): print(remote_files) with self._different_user(email=self.USER_2_APP_VAULT_ENTRY): - user = self._get_user_by_email(self.USER_2_APP_VAULT_ENTRY) + user = get_user_by_email(self._app.model.session, self.USER_2_APP_VAULT_ENTRY) user_vault = UserVaultWrapper(app.vault, user) # use an invalid symlink path so the posix list fails user_vault.write_secret("posix/root_path", "/invalid/root") @@ -69,7 +68,7 @@ def test_vault_secret_per_app_in_file_source(self): app.vault.write_secret("posix/root_path", app.config.user_library_import_symlink_allowlist[0]) with self._different_user(email=self.USER_1_APP_VAULT_ENTRY): - user = self._get_user_by_email(self.USER_1_APP_VAULT_ENTRY) + user = get_user_by_email(self._app.model.session, self.USER_1_APP_VAULT_ENTRY) user_vault = UserVaultWrapper(app.vault, user) # use a valid symlink path so the posix list succeeds user_vault.write_secret("posix/root_path", app.config.user_library_import_symlink_allowlist[0]) @@ -81,7 +80,7 @@ def test_vault_secret_per_app_in_file_source(self): print(remote_files) with self._different_user(email=self.USER_2_APP_VAULT_ENTRY): - user = self._get_user_by_email(self.USER_2_APP_VAULT_ENTRY) + user = get_user_by_email(self._app.model.session, self.USER_2_APP_VAULT_ENTRY) user_vault = UserVaultWrapper(app.vault, user) # use an invalid symlink path so the posix list would fail if used user_vault.write_secret("posix/root_path", "/invalid/root") @@ -95,7 +94,7 @@ def test_vault_secret_per_app_in_file_source(self): def test_upload_file_from_remote_source(self): with self._different_user(email=self.USER_1_APP_VAULT_ENTRY): app = self._app - user = self._get_user_by_email(self.USER_1_APP_VAULT_ENTRY) + user = get_user_by_email(self._app.model.session, self.USER_1_APP_VAULT_ENTRY) user_vault = UserVaultWrapper(app.vault, user) # use a valid symlink path so the posix list succeeds user_vault.write_secret("posix/root_path", app.config.user_library_import_symlink_allowlist[0]) @@ -120,7 +119,3 @@ def test_upload_file_from_remote_source(self): new_dataset = self.dataset_populator.fetch(payload, assert_ok=True).json()["outputs"][0] content = self.dataset_populator.get_history_dataset_content(history_id, dataset=new_dataset) assert content == "I require access to the vault", content - - def _get_user_by_email(self, email): - stmt = select(self._app.model.User).filter(self._app.model.User.email == email).limit(1) - return self._app.model.session.scalars(stmt).first()