Skip to content

Commit

Permalink
Refactor get_user_by_username, get_user_by_email; use across code base
Browse files Browse the repository at this point in the history
  • Loading branch information
jdavcs committed Sep 21, 2023
1 parent 9c0c569 commit c6cac90
Show file tree
Hide file tree
Showing 10 changed files with 37 additions and 69 deletions.
34 changes: 9 additions & 25 deletions lib/galaxy/managers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
and_,
exc,
func,
select,
true,
)
from sqlalchemy.orm.exc import NoResultFound
Expand All @@ -38,6 +37,10 @@
)
from galaxy.model import UserQuotaUsage
from galaxy.model.base import transaction
from galaxy.model.repositories import (
get_user_by_email,
get_user_by_username,
)
from galaxy.security.validate_user_input import (
VALID_EMAIL_RE,
validate_email,
Expand Down Expand Up @@ -355,7 +358,7 @@ def get_user_by_identity(self, identity):
user = None
if VALID_EMAIL_RE.match(identity):
# VALID_PUBLICNAME and VALID_EMAIL do not overlap, so 'identity' here is an email address
user = self.session().query(self.model_class).filter(self.model_class.table.c.email == identity).first()
user = get_user_by_email(self.session(), identity, self.model_class)
if not user:
# Try a case-insensitive match on the email
user = (
Expand All @@ -365,7 +368,7 @@ def get_user_by_identity(self, identity):
.first()
)
else:
user = self.session().query(self.model_class).filter(self.model_class.table.c.username == identity).first()
user = get_user_by_username(self.session(), identity, self.model_class)
return user

# ---- current
Expand Down Expand Up @@ -528,7 +531,7 @@ def __get_activation_token(self, trans, email):
"""
Check for the activation token. Create new activation token and store it in the database if no token found.
"""
user = trans.sa_session.query(self.app.model.User).filter(self.app.model.User.table.c.email == email).first()
user = get_user_by_email(trans.sa_session, email, self.app.model.User)
activation_token = user.activation_token
if activation_token is None:
activation_token = util.hash_util.new_secure_hash_v2(str(random.getrandbits(256)))
Expand Down Expand Up @@ -572,9 +575,7 @@ def send_reset_email(self, trans, payload, **kwd):
return "Failed to produce password reset token. User not found."

def get_reset_token(self, trans, email):
reset_user = (
trans.sa_session.query(self.app.model.User).filter(self.app.model.User.table.c.email == email).first()
)
reset_user = get_user_by_email(trans.sa_session, email, self.app.model.User)
if not reset_user and email != email.lower():
reset_user = (
trans.sa_session.query(self.app.model.User)
Expand Down Expand Up @@ -618,12 +619,7 @@ def get_or_create_remote_user(self, remote_user_email):
return None
if getattr(self.app.config, "normalize_remote_user_email", False):
remote_user_email = remote_user_email.lower()
user = (
self.session()
.query(self.app.model.User)
.filter(self.app.model.User.table.c.email == remote_user_email)
.first()
)
user = get_user_by_email(self.session(), remote_user_email, self.app.model.User)
if user:
# GVK: June 29, 2009 - This is to correct the behavior of a previous bug where a private
# role and default user / history permissions were not set for remote users. When a
Expand Down Expand Up @@ -838,15 +834,3 @@ def _add_parsers(self):
)

self.fn_filter_parsers.update({})


def get_user_by_username(session, user_class, username):
"""
Get a user from the database by username.
(We pass the session and the user_class to accommodate usage from the tool_shed app.)
"""
try:
stmt = select(user_class).filter(user_class.username == username)
return session.execute(stmt).scalar_one()
except Exception:
return None
20 changes: 8 additions & 12 deletions lib/galaxy/visualization/genomes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
from typing import Dict

from bx.seq.twobit import TwoBitFile
from sqlalchemy import select

from galaxy.exceptions import (
ObjectNotFound,
ReferenceDataError,
)
from galaxy.model.repositories.hda import HistoryDatasetAssociationRepository as hda_repo
from galaxy.model import HistoryDatasetAssociation
from galaxy.model.repositories import get_user_by_username
from galaxy.structured_app import StructuredApp
from galaxy.util.bunch import Bunch

Expand Down Expand Up @@ -289,11 +289,12 @@ def chroms(self, trans, dbkey=None, num=None, chrom=None, low=None):
Returns a naturally sorted list of chroms/contigs for a given dbkey.
Use either chrom or low to specify the starting chrom in the return list.
"""
session = trans.sa_session
self.check_and_reload()
# If there is no dbkey owner, default to current user.
dbkey_owner, dbkey = decode_dbkey(dbkey)
if dbkey_owner:
dbkey_user = self._get_dbkey_user(trans, dbkey_owner)
dbkey_user = get_user_by_username(session, dbkey_owner)
else:
dbkey_user = trans.user

Expand All @@ -309,10 +310,9 @@ def chroms(self, trans, dbkey=None, num=None, chrom=None, low=None):
if dbkey in user_keys:
dbkey_attributes = user_keys[dbkey]
dbkey_name = dbkey_attributes["name"]
_hda_repo = hda_repo(trans.sa_session)
# If there's a fasta for genome, convert to 2bit for later use.
if "fasta" in dbkey_attributes:
build_fasta = _hda_repo.get(dbkey_attributes["fasta"])
build_fasta = session.get(HistoryDatasetAssociation, dbkey_attributes["fasta"])
len_file = build_fasta.get_converted_dataset(trans, "len").file_name
build_fasta.get_converted_dataset(trans, "twobit")
# HACK: set twobit_file to True rather than a file name because
Expand All @@ -321,7 +321,7 @@ def chroms(self, trans, dbkey=None, num=None, chrom=None, low=None):
twobit_file = True
# Backwards compatibility: look for len file directly.
elif "len" in dbkey_attributes:
len_file = _hda_repo.get(user_keys[dbkey]["len"]).file_name
len_file = session.get(HistoryDatasetAssociation, user_keys[dbkey]["len"]).file_name
if len_file:
genome = Genome(dbkey, dbkey_name, len_file=len_file, twobit_file=twobit_file)

Expand Down Expand Up @@ -370,7 +370,7 @@ def reference(self, trans, dbkey, chrom, low, high):
# If there is no dbkey owner, default to current user.
dbkey_owner, dbkey = decode_dbkey(dbkey)
if dbkey_owner:
dbkey_user = self._get_dbkey_user(trans, dbkey_owner)
dbkey_user = get_user_by_username(trans.sa_session, dbkey_owner)
else:
dbkey_user = trans.user

Expand All @@ -387,7 +387,7 @@ def reference(self, trans, dbkey, chrom, low, high):
else:
user_keys = loads(dbkey_user.preferences["dbkeys"])
dbkey_attributes = user_keys[dbkey]
fasta_dataset = hda_repo(trans.sa_session).get(dbkey_attributes["fasta"])
fasta_dataset = trans.sa_session.get(HistoryDatasetAssociation, dbkey_attributes["fasta"])
msg = fasta_dataset.convert_dataset(trans, "twobit")
if msg:
return msg
Expand All @@ -405,7 +405,3 @@ def _get_reference_data(twobit_file_name, chrom, low, high):
if chrom in twobit:
seq_data = twobit[chrom].get(int(low), int(high))
return GenomeRegion(chrom=chrom, start=low, end=high, sequence=seq_data)

def _get_dbkey_user(self, trans, dbkey_owner):
stmt = select(trans.app.model.User).filter_by(username=dbkey_owner).limit(1)
return trans.sa_session.scalars(stmt).first()
5 changes: 2 additions & 3 deletions lib/galaxy/webapps/galaxy/controllers/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from galaxy.exceptions import Conflict
from galaxy.managers import users
from galaxy.model.repositories import get_user_by_email
from galaxy.security.validate_user_input import (
validate_email,
validate_publicname,
Expand Down Expand Up @@ -293,9 +294,7 @@ def activate(self, trans, **kwd):
)
else:
# Find the user
user = (
trans.sa_session.query(trans.app.model.User).filter(trans.app.model.User.table.c.email == email).first()
)
user = get_user_by_email(trans.sa_session, email)
if not user:
# Probably wrong email address
return trans.show_error_message(
Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy/webapps/galaxy/services/quotas.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def get_user_id(item):
try:
return trans.security.decode_id(item)
except Exception:
return user_repo.get_by_email(item).id
return get_user_by_email(trans.sa_session, item).id

def get_group_id(item):
try:
Expand Down
4 changes: 1 addition & 3 deletions lib/galaxy/webapps/reports/controllers/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
and_,
not_,
or_,
select,
)

from galaxy import (
Expand Down Expand Up @@ -1307,8 +1306,7 @@ def get_monitor_id(trans, monitor_email):
A convenience method to obtain the monitor job id.
"""
monitor_user_id = None
stmt = select(trans.model.User.id).filter(trans.model.User.email == monitor_email).limit(1)
monitor_row = trans.sa_session.scalars(stmt).first()
monitor_row = get_user_by_email(trans.sa_session, monitor_email)
if monitor_row is not None:
monitor_user_id = monitor_row[0]
return monitor_user_id
7 changes: 2 additions & 5 deletions lib/tool_shed/test/base/test_db_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import galaxy.model
import galaxy.model.tool_shed_install
import tool_shed.webapp.model as model
from galaxy.model.repositories import get_user_by_username

log = logging.getLogger("test.tool_shed.test_db_util")

Expand Down Expand Up @@ -171,10 +172,6 @@ def get_user(email):
return sa_session().query(model.User).filter(model.User.table.c.email == email).first()


def get_user_by_name(username):
return sa_session().query(model.User).filter(model.User.table.c.username == username).first()


def mark_obj_deleted(obj):
obj.deleted = True
sa_session().add(obj)
Expand All @@ -190,7 +187,7 @@ def ga_refresh(obj):


def get_repository_by_name_and_owner(name, owner_username, return_multiple=False):
owner = get_user_by_name(owner_username)
owner = get_user_by_username(sa_session(), owner_username, model.User)
repository = (
sa_session()
.query(model.Repository)
Expand Down
6 changes: 3 additions & 3 deletions lib/tool_shed/webapp/controllers/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
util,
web,
)
from galaxy.managers.users import get_user_by_username
from galaxy.model.base import transaction
from galaxy.model.repositories import get_user_by_username
from galaxy.tool_shed.util import dependency_display
from galaxy.tools.repositories import ValidationContext
from galaxy.web.form_builder import (
Expand Down Expand Up @@ -2293,7 +2293,7 @@ def set_malicious(self, trans, id, ctx_str, **kwd):
def sharable_owner(self, trans, owner):
"""Support for sharable URL for each repository owner's tools, e.g. http://example.org/view/owner."""
try:
user = get_user_by_username(trans.model.session, trans.model.User, owner)
user = get_user_by_username(trans.model.session, owner, trans.model.User)
except Exception:
user = None
if user:
Expand Down Expand Up @@ -2321,7 +2321,7 @@ def sharable_repository(self, trans, owner, name):
else:
# If the owner is valid, then show all of their repositories.
try:
user = get_user_by_username(trans.model.session, trans.model.User, owner)
user = get_user_by_username(trans.model.session, owner, trans.model.User)
except Exception:
user = None
if user:
Expand Down
6 changes: 3 additions & 3 deletions test/integration/test_user_preferences.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
get,
put,
)
from sqlalchemy import select

from galaxy.model.repositories import get_user_by_email
from galaxy_test.driver import integration_util

TEST_USER_EMAIL = "[email protected]"
Expand All @@ -19,8 +19,8 @@ def test_user_theme(self):
user = self._setup_user(TEST_USER_EMAIL)
url = self._api_url(f"users/{user['id']}/theme/test_theme", params=dict(key=self.master_api_key))
app = cast(Any, self._test_driver.app if self._test_driver else None)
stmt = select(app.model.User).filter(app.model.User.email == user["email"]).limit(1)
db_user = app.model.session.scalars(stmt).first()

db_user = get_user_by_email(app.model.session, user["email"])

# create some initial data
put(url)
Expand Down
5 changes: 2 additions & 3 deletions test/integration/test_vault_extra_prefs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
get,
put,
)
from sqlalchemy import select

from galaxy.model.repositories import get_user_by_email
from galaxy_test.driver import integration_util

TEST_USER_EMAIL = "[email protected]"
Expand Down Expand Up @@ -133,5 +133,4 @@ def __url(self, action, user):
return self._api_url(f"users/{user['id']}/{action}", params=dict(key=self.master_api_key))

def _get_dbuser(self, app, user):
stmt = select(app.model.User).filter(app.model.User.email == user["email"]).limit(1)
return app.model.session.scalars(stmt).first()
return get_user_by_email(app.model.session, user["email"])
17 changes: 6 additions & 11 deletions test/integration/test_vault_file_source.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import os
import tempfile

from sqlalchemy import select

from galaxy.model.repositories import get_user_by_email
from galaxy.security.vault import UserVaultWrapper
from galaxy_test.base import api_asserts
from galaxy_test.base.populators import DatasetPopulator
Expand Down Expand Up @@ -36,7 +35,7 @@ def test_vault_secret_per_user_in_file_source(self):
app = self._app

with self._different_user(email=self.USER_1_APP_VAULT_ENTRY):
user = self._get_user_by_email(self.USER_1_APP_VAULT_ENTRY)
user = get_user_by_email(self._app.model.session, self.USER_1_APP_VAULT_ENTRY)
user_vault = UserVaultWrapper(app.vault, user)
# use a valid symlink path so the posix list succeeds
user_vault.write_secret("posix/root_path", app.config.user_library_import_symlink_allowlist[0])
Expand All @@ -48,7 +47,7 @@ def test_vault_secret_per_user_in_file_source(self):
print(remote_files)

with self._different_user(email=self.USER_2_APP_VAULT_ENTRY):
user = self._get_user_by_email(self.USER_2_APP_VAULT_ENTRY)
user = get_user_by_email(self._app.model.session, self.USER_2_APP_VAULT_ENTRY)
user_vault = UserVaultWrapper(app.vault, user)
# use an invalid symlink path so the posix list fails
user_vault.write_secret("posix/root_path", "/invalid/root")
Expand All @@ -69,7 +68,7 @@ def test_vault_secret_per_app_in_file_source(self):
app.vault.write_secret("posix/root_path", app.config.user_library_import_symlink_allowlist[0])

with self._different_user(email=self.USER_1_APP_VAULT_ENTRY):
user = self._get_user_by_email(self.USER_1_APP_VAULT_ENTRY)
user = get_user_by_email(self._app.model.session, self.USER_1_APP_VAULT_ENTRY)
user_vault = UserVaultWrapper(app.vault, user)
# use a valid symlink path so the posix list succeeds
user_vault.write_secret("posix/root_path", app.config.user_library_import_symlink_allowlist[0])
Expand All @@ -81,7 +80,7 @@ def test_vault_secret_per_app_in_file_source(self):
print(remote_files)

with self._different_user(email=self.USER_2_APP_VAULT_ENTRY):
user = self._get_user_by_email(self.USER_2_APP_VAULT_ENTRY)
user = get_user_by_email(self._app.model.session, self.USER_2_APP_VAULT_ENTRY)
user_vault = UserVaultWrapper(app.vault, user)
# use an invalid symlink path so the posix list would fail if used
user_vault.write_secret("posix/root_path", "/invalid/root")
Expand All @@ -95,7 +94,7 @@ def test_vault_secret_per_app_in_file_source(self):
def test_upload_file_from_remote_source(self):
with self._different_user(email=self.USER_1_APP_VAULT_ENTRY):
app = self._app
user = self._get_user_by_email(self.USER_1_APP_VAULT_ENTRY)
user = get_user_by_email(self._app.model.session, self.USER_1_APP_VAULT_ENTRY)
user_vault = UserVaultWrapper(app.vault, user)
# use a valid symlink path so the posix list succeeds
user_vault.write_secret("posix/root_path", app.config.user_library_import_symlink_allowlist[0])
Expand All @@ -120,7 +119,3 @@ def test_upload_file_from_remote_source(self):
new_dataset = self.dataset_populator.fetch(payload, assert_ok=True).json()["outputs"][0]
content = self.dataset_populator.get_history_dataset_content(history_id, dataset=new_dataset)
assert content == "I require access to the vault", content

def _get_user_by_email(self, email):
stmt = select(self._app.model.User).filter(self._app.model.User.email == email).limit(1)
return self._app.model.session.scalars(stmt).first()

0 comments on commit c6cac90

Please sign in to comment.