Skip to content

Commit

Permalink
Merge pull request #17778 from jdavcs/dev_sa20
Browse files Browse the repository at this point in the history
SQLAlchemy 2.0
  • Loading branch information
jdavcs authored Apr 3, 2024
2 parents 6034a35 + 2c527d4 commit af53d03
Show file tree
Hide file tree
Showing 109 changed files with 2,102 additions and 1,847 deletions.
2 changes: 1 addition & 1 deletion lib/galaxy/app_unittest_utils/galaxy_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def __init__(self, config=None, **kwargs) -> None:
self[ShortTermStorageMonitor] = sts_manager # type: ignore[type-abstract]
self[galaxy_scoped_session] = self.model.context
self.visualizations_registry = MockVisualizationsRegistry()
self.tag_handler = tags.GalaxyTagHandler(self.model.context)
self.tag_handler = tags.GalaxyTagHandler(self.model.session)
self[tags.GalaxyTagHandler] = self.tag_handler
self.quota_agent = quota.DatabaseQuotaAgent(self.model)
self.job_config = Bunch(
Expand Down
45 changes: 17 additions & 28 deletions lib/galaxy/celery/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
)
from sqlalchemy.dialects.postgresql import insert as ps_insert
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from galaxy.model import CeleryUserRateLimit
from galaxy.model.base import transaction
Expand Down Expand Up @@ -70,7 +69,7 @@ def __call__(self, task: Task, task_id, args, kwargs):

@abstractmethod
def calculate_task_start_time(
self, user_id: int, sa_session: Session, task_interval_secs: float, now: datetime.datetime
self, user_id: int, sa_session: galaxy_scoped_session, task_interval_secs: float, now: datetime.datetime
) -> datetime.datetime:
return now

Expand All @@ -81,38 +80,28 @@ class GalaxyTaskBeforeStartUserRateLimitPostgres(GalaxyTaskBeforeStartUserRateLi
We take advantage of efficiencies in its dialect.
"""

_update_stmt = (
update(CeleryUserRateLimit)
.where(CeleryUserRateLimit.user_id == bindparam("userid"))
.values(last_scheduled_time=text("greatest(last_scheduled_time + ':interval second', " ":now) "))
.returning(CeleryUserRateLimit.last_scheduled_time)
)

_insert_stmt = (
ps_insert(CeleryUserRateLimit)
.values(user_id=bindparam("userid"), last_scheduled_time=bindparam("now"))
.returning(CeleryUserRateLimit.last_scheduled_time)
)

_upsert_stmt = _insert_stmt.on_conflict_do_update(
index_elements=["user_id"], set_=dict(last_scheduled_time=bindparam("sched_time"))
)

def calculate_task_start_time( # type: ignore
self, user_id: int, sa_session: Session, task_interval_secs: float, now: datetime.datetime
self, user_id: int, sa_session: galaxy_scoped_session, task_interval_secs: float, now: datetime.datetime
) -> datetime.datetime:
with transaction(sa_session):
result = sa_session.execute(
self._update_stmt, {"userid": user_id, "interval": task_interval_secs, "now": now}
update_stmt = (
update(CeleryUserRateLimit)
.where(CeleryUserRateLimit.user_id == user_id)
.values(last_scheduled_time=text("greatest(last_scheduled_time + ':interval second', " ":now) "))
.returning(CeleryUserRateLimit.last_scheduled_time)
)
if result.rowcount == 0:
result = sa_session.execute(update_stmt, {"interval": task_interval_secs, "now": now}).all()
if not result:
sched_time = now + datetime.timedelta(seconds=task_interval_secs)
result = sa_session.execute(
self._upsert_stmt, {"userid": user_id, "now": now, "sched_time": sched_time}
upsert_stmt = (
ps_insert(CeleryUserRateLimit) # type:ignore[attr-defined]
.values(user_id=user_id, last_scheduled_time=now)
.returning(CeleryUserRateLimit.last_scheduled_time)
.on_conflict_do_update(index_elements=["user_id"], set_=dict(last_scheduled_time=sched_time))
)
for row in result:
return row[0]
result = sa_session.execute(upsert_stmt).all()
sa_session.commit()
return result[0][0]


class GalaxyTaskBeforeStartUserRateLimitStandard(GalaxyTaskBeforeStartUserRateLimit):
Expand All @@ -138,7 +127,7 @@ class GalaxyTaskBeforeStartUserRateLimitStandard(GalaxyTaskBeforeStartUserRateLi
)

def calculate_task_start_time(
self, user_id: int, sa_session: Session, task_interval_secs: float, now: datetime.datetime
self, user_id: int, sa_session: galaxy_scoped_session, task_interval_secs: float, now: datetime.datetime
) -> datetime.datetime:
last_scheduled_time = None
with transaction(sa_session):
Expand Down
5 changes: 4 additions & 1 deletion lib/galaxy/celery/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def set_metadata(
try:
if overwrite:
hda_manager.overwrite_metadata(dataset_instance)
dataset_instance.datatype.set_meta(dataset_instance)
dataset_instance.datatype.set_meta(dataset_instance) # type:ignore [arg-type]
dataset_instance.set_peek()
# Reset SETTING_METADATA state so the dataset instance getter picks the dataset state
dataset_instance.set_metadata_success_state()
Expand Down Expand Up @@ -228,6 +228,7 @@ def setup_fetch_data(
):
tool = cached_create_tool_from_representation(app=app, raw_tool_source=raw_tool_source)
job = sa_session.get(Job, job_id)
assert job
# self.request.hostname is the actual worker name given by the `-n` argument, not the hostname as you might think.
job.handler = self.request.hostname
job.job_runner_name = "celery"
Expand Down Expand Up @@ -260,6 +261,7 @@ def finish_job(
):
tool = cached_create_tool_from_representation(app=app, raw_tool_source=raw_tool_source)
job = sa_session.get(Job, job_id)
assert job
# TODO: assert state ?
mini_job_wrapper = MinimalJobWrapper(job=job, app=app, tool=tool)
mini_job_wrapper.finish("", "")
Expand Down Expand Up @@ -320,6 +322,7 @@ def fetch_data(
task_user_id: Optional[int] = None,
) -> str:
job = sa_session.get(Job, job_id)
assert job
mini_job_wrapper = MinimalJobWrapper(job=job, app=app)
mini_job_wrapper.change_state(model.Job.states.RUNNING, flush=True, job=job)
return abort_when_job_stops(_fetch_data, session=sa_session, job_id=job_id, setup_return=setup_return)
Expand Down
31 changes: 0 additions & 31 deletions lib/galaxy/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,6 @@ class GalaxyAppConfiguration(BaseAppConfiguration, CommonConfigurationMixin):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._override_tempdir(kwargs)
self._configure_sqlalchemy20_warnings(kwargs)
self._process_config(kwargs)
self._set_dependent_defaults()

Expand All @@ -764,36 +763,6 @@ def _set_dependent_defaults(self):
f"{dependent_config_param}, {config_param}"
)

def _configure_sqlalchemy20_warnings(self, kwargs):
"""
This method should be deleted after migration to SQLAlchemy 2.0 is complete.
To enable warnings, set `GALAXY_CONFIG_SQLALCHEMY_WARN_20=1`,
"""
warn = string_as_bool(kwargs.get("sqlalchemy_warn_20", False))
if warn:
import sqlalchemy

sqlalchemy.util.deprecations.SQLALCHEMY_WARN_20 = True
self._setup_sqlalchemy20_warnings_filters()

def _setup_sqlalchemy20_warnings_filters(self):
import warnings

from sqlalchemy.exc import RemovedIn20Warning

# Always display RemovedIn20Warning warnings.
warnings.filterwarnings("always", category=RemovedIn20Warning)
# Optionally, enable filters for specific warnings (raise error, or log, etc.)
# messages = [
# r"replace with warning text to match",
# ]
# for msg in messages:
# warnings.filterwarnings('error', message=msg, category=RemovedIn20Warning)
#
# See documentation:
# https://docs.python.org/3.7/library/warnings.html#the-warnings-filter
# https://docs.sqlalchemy.org/en/14/changelog/migration_20.html#migration-to-2-0-step-three-resolve-all-removedin20warnings

def _load_schema(self):
return AppSchema(GALAXY_CONFIG_SCHEMA_PATH, GALAXY_APP_NAME)

Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy/dependencies/pinned-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ sniffio==1.3.1 ; python_version >= "3.8" and python_version < "3.13"
social-auth-core==4.5.3 ; python_version >= "3.8" and python_version < "3.13"
sortedcontainers==2.4.0 ; python_version >= "3.8" and python_version < "3.13"
spython==0.3.13 ; python_version >= "3.8" and python_version < "3.13"
sqlalchemy==1.4.52 ; python_version >= "3.8" and python_version < "3.13"
sqlalchemy==2.0.25 ; python_version >= "3.8" and python_version < "3.13"
sqlitedict==2.1.0 ; python_version >= "3.8" and python_version < "3.13"
sqlparse==0.4.4 ; python_version >= "3.8" and python_version < "3.13"
starlette-context==0.3.6 ; python_version >= "3.8" and python_version < "3.13"
Expand Down
6 changes: 4 additions & 2 deletions lib/galaxy/jobs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1182,7 +1182,9 @@ def galaxy_url(self):
return self.get_destination_configuration("galaxy_infrastructure_url")

def get_job(self) -> model.Job:
return self.sa_session.get(Job, self.job_id)
job = self.sa_session.get(Job, self.job_id)
assert job
return job

def get_id_tag(self):
# For compatibility with drmaa, which uses job_id right now, and TaskWrapper
Expand Down Expand Up @@ -1552,7 +1554,7 @@ def change_state(self, state, info=False, flush=True, job=None):
def get_state(self) -> str:
job = self.get_job()
self.sa_session.refresh(job)
return job.state
return job.state # type:ignore[return-value]

def set_runner(self, runner_url, external_id):
log.warning("set_runner() is deprecated, use set_job_destination()")
Expand Down
4 changes: 2 additions & 2 deletions lib/galaxy/managers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,9 +318,9 @@ def _one_with_recast_errors(self, query: Query) -> U:
# overridden to raise serializable errors
try:
return query.one()
except sqlalchemy.orm.exc.NoResultFound:
except sqlalchemy.exc.NoResultFound:
raise exceptions.ObjectNotFound(f"{self.model_class.__name__} not found")
except sqlalchemy.orm.exc.MultipleResultsFound:
except sqlalchemy.exc.MultipleResultsFound:
raise exceptions.InconsistentDatabase(f"found more than one {self.model_class.__name__}")

# NOTE: at this layer, all ids are expected to be decoded and in int form
Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy/managers/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,7 @@ def get_collection_contents(self, trans: ProvidesAppContext, parent_id, limit=No
def _get_collection_contents_qry(self, parent_id, limit=None, offset=None):
"""Build query to find first level of collection contents by containing collection parent_id"""
DCE = model.DatasetCollectionElement
qry = Query(DCE).filter(DCE.dataset_collection_id == parent_id)
qry = Query(DCE).filter(DCE.dataset_collection_id == parent_id) # type:ignore[var-annotated]
qry = qry.order_by(DCE.element_index)
qry = qry.options(
joinedload(model.DatasetCollectionElement.child_collection), joinedload(model.DatasetCollectionElement.hda)
Expand Down
4 changes: 2 additions & 2 deletions lib/galaxy/managers/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ def purge_datasets(self, request: PurgeDatasetsTaskRequest):
self.error_unless_dataset_purge_allowed()
with self.session().begin():
for dataset_id in request.dataset_ids:
dataset: Dataset = self.session().get(Dataset, dataset_id)
if dataset.user_can_purge:
dataset: Optional[Dataset] = self.session().get(Dataset, dataset_id)
if dataset and dataset.user_can_purge:
try:
dataset.full_delete()
except Exception:
Expand Down
4 changes: 2 additions & 2 deletions lib/galaxy/managers/dbkeys.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
)

from sqlalchemy import select
from sqlalchemy.orm import Session

from galaxy.model import HistoryDatasetAssociation
from galaxy.model.scoped_session import galaxy_scoped_session
from galaxy.util import (
galaxy_directory,
sanitize_lists_to_string,
Expand Down Expand Up @@ -166,6 +166,6 @@ def get_chrom_info(self, dbkey, trans=None, custom_build_hack_get_len_from_fasta
return (chrom_info, db_dataset)


def get_len_files_by_history(session: Session, history_id: int):
def get_len_files_by_history(session: galaxy_scoped_session, history_id: int):
stmt = select(HistoryDatasetAssociation).filter_by(history_id=history_id, extension="len", deleted=False)
return session.scalars(stmt)
6 changes: 3 additions & 3 deletions lib/galaxy/managers/export_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
and_,
select,
)
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm.scoping import scoped_session

from galaxy.exceptions import ObjectNotFound
Expand Down Expand Up @@ -44,7 +44,7 @@ def set_export_association_metadata(self, export_association_id: int, export_met
export_association: StoreExportAssociation = self.session.execute(stmt).scalars().one()
except NoResultFound:
raise ObjectNotFound("Cannot set export metadata. Reason: Export association not found")
export_association.export_metadata = export_metadata.json()
export_association.export_metadata = export_metadata.model_dump_json() # type:ignore[assignment]
with transaction(self.session):
self.session.commit()

Expand Down Expand Up @@ -76,4 +76,4 @@ def get_object_exports(
stmt = stmt.offset(offset)
if limit:
stmt = stmt.limit(limit)
return self.session.execute(stmt).scalars()
return self.session.execute(stmt).scalars() # type:ignore[return-value]
9 changes: 5 additions & 4 deletions lib/galaxy/managers/folders.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
or_,
select,
)
from sqlalchemy.orm import aliased
from sqlalchemy.orm.exc import (
from sqlalchemy.exc import (
MultipleResultsFound,
NoResultFound,
)
from sqlalchemy.orm import aliased

from galaxy import (
model,
Expand Down Expand Up @@ -505,7 +505,7 @@ def _get_contained_datasets_statement(
stmt = stmt.where(
or_(
func.lower(ldda.name).contains(search_text, autoescape=True),
func.lower(ldda.message).contains(search_text, autoescape=True),
func.lower(ldda.message).contains(search_text, autoescape=True), # type:ignore[attr-defined]
)
)
sort_column = LDDA_SORT_COLUMN_MAP[payload.order_by](ldda, associated_dataset)
Expand Down Expand Up @@ -536,7 +536,7 @@ def _filter_by_include_deleted(

def build_folder_path(
self, sa_session: galaxy_scoped_session, folder: model.LibraryFolder
) -> List[Tuple[str, str]]:
) -> List[Tuple[int, Optional[str]]]:
"""
Returns the folder path from root to the given folder.
Expand All @@ -546,6 +546,7 @@ def build_folder_path(
path_to_root = [(current_folder.id, current_folder.name)]
while current_folder.parent_id is not None:
parent_folder = sa_session.get(LibraryFolder, current_folder.parent_id)
assert parent_folder
current_folder = parent_folder
path_to_root.insert(0, (current_folder.id, current_folder.name))
return path_to_root
Expand Down
9 changes: 6 additions & 3 deletions lib/galaxy/managers/forms.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from sqlalchemy import select
from sqlalchemy.orm import exc as sqlalchemy_exceptions
from sqlalchemy.exc import (
MultipleResultsFound,
NoResultFound,
)

from galaxy.exceptions import (
InconsistentDatabase,
Expand Down Expand Up @@ -59,9 +62,9 @@ def get(self, trans: ProvidesUserContext, form_id: int) -> FormDefinitionCurrent
try:
stmt = select(FormDefinitionCurrent).where(FormDefinitionCurrent.id == form_id)
form = self.session().execute(stmt).scalar_one()
except sqlalchemy_exceptions.MultipleResultsFound:
except MultipleResultsFound:
raise InconsistentDatabase("Multiple forms found with the same id.")
except sqlalchemy_exceptions.NoResultFound:
except NoResultFound:
raise RequestParameterInvalidException("No accessible form found with the id provided.")
except Exception as e:
raise InternalServerError(f"Error loading from the database.{unicodify(e)}")
Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy/managers/genomes.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _create_genome_filter(model_class=None):
if self.database_connection.startswith("postgres"):
column = text("convert_from(metadata, 'UTF8')::json ->> 'dbkey'")
else:
column = func.json_extract(model_class.table.c._metadata, "$.dbkey")
column = func.json_extract(model_class.table.c._metadata, "$.dbkey") # type:ignore[assignment]
lower_val = val.lower() # Ignore case
# dbkey can either be "hg38" or '["hg38"]', so we need to check both
if op == "eq":
Expand Down
4 changes: 2 additions & 2 deletions lib/galaxy/managers/group_roles.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
)

from sqlalchemy import select
from sqlalchemy.orm import Session

from galaxy import model
from galaxy.exceptions import ObjectNotFound
from galaxy.managers.context import ProvidesAppContext
from galaxy.model import GroupRoleAssociation
from galaxy.model.base import transaction
from galaxy.model.scoped_session import galaxy_scoped_session
from galaxy.structured_app import MinimalManagerApp

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -93,7 +93,7 @@ def _remove_role_from_group(self, trans: ProvidesAppContext, group_role: model.G
trans.sa_session.commit()


def get_group_role(session: Session, group, role) -> Optional[GroupRoleAssociation]:
def get_group_role(session: galaxy_scoped_session, group, role) -> Optional[GroupRoleAssociation]:
stmt = (
select(GroupRoleAssociation).where(GroupRoleAssociation.group == group).where(GroupRoleAssociation.role == role)
)
Expand Down
4 changes: 2 additions & 2 deletions lib/galaxy/managers/group_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
)

from sqlalchemy import select
from sqlalchemy.orm import Session

from galaxy import model
from galaxy.exceptions import ObjectNotFound
Expand All @@ -15,6 +14,7 @@
UserGroupAssociation,
)
from galaxy.model.base import transaction
from galaxy.model.scoped_session import galaxy_scoped_session
from galaxy.structured_app import MinimalManagerApp

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -96,7 +96,7 @@ def _remove_user_from_group(self, trans: ProvidesAppContext, group_user: model.U
trans.sa_session.commit()


def get_group_user(session: Session, user, group) -> Optional[UserGroupAssociation]:
def get_group_user(session: galaxy_scoped_session, user, group) -> Optional[UserGroupAssociation]:
stmt = (
select(UserGroupAssociation).where(UserGroupAssociation.user == user).where(UserGroupAssociation.group == group)
)
Expand Down
Loading

0 comments on commit af53d03

Please sign in to comment.