Skip to content

Commit

Permalink
perf(db): improved study query performance using owner and groups pre…
Browse files Browse the repository at this point in the history
…loading
  • Loading branch information
laurent-laporte-pro committed Nov 13, 2023
1 parent 96ad182 commit db391cb
Showing 1 changed file with 62 additions and 15 deletions.
77 changes: 62 additions & 15 deletions antarest/study/repository.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import datetime
import logging
from datetime import datetime
from typing import List, Optional
import typing as t

from sqlalchemy.orm import with_polymorphic # type: ignore
from sqlalchemy.orm import Session, joinedload, with_polymorphic # type: ignore

from antarest.core.interfaces.cache import CacheConstants, ICache
from antarest.core.utils.fastapi_sqlalchemy import db
Expand All @@ -17,8 +17,30 @@ class StudyMetadataRepository:
Database connector to manage Study entity
"""

def __init__(self, cache_service: ICache):
def __init__(self, cache_service: ICache, session: t.Optional[Session] = None):
"""
Initialize the repository.
Args:
cache_service: Cache service for the repository.
session: Optional SQLAlchemy session to be used.
"""
self.cache_service = cache_service
self._session = session

@property
def session(self) -> Session:
"""
Get the SQLAlchemy session for the repository.
Returns:
SQLAlchemy session.
"""
if self._session is None:
# Get or create the session from a context variable (thread local variable)
return db.session
# Get the user-defined session
return self._session

def save(
self,
Expand All @@ -29,7 +51,7 @@ def save(
metadata_id = metadata.id or metadata.name
logger.debug(f"Saving study {metadata_id}")
if update_modification_date:
metadata.updated_at = datetime.utcnow()
metadata.updated_at = datetime.datetime.utcnow()

metadata.groups = [db.session.merge(g) for g in metadata.groups]
if metadata.owner:
Expand All @@ -44,34 +66,59 @@ def save(
def refresh(self, metadata: Study) -> None:
db.session.refresh(metadata)

def get(self, id: str) -> Optional[Study]:
def get(self, id: str) -> t.Optional[Study]:
"""Get the study by ID or return `None` if not found in database."""
metadata: Study = db.session.query(Study).get(id)
# When we fetch a study, we also need to fetch the associated owner and groups
# to check the permissions of the current user efficiently.
metadata: Study = (
# fmt: off
db.session.query(Study)
.options(joinedload(Study.owner))
.options(joinedload(Study.groups))
.get(id)
# fmt: on
)
return metadata

def one(self, id: str) -> Study:
"""Get the study by ID or raise `sqlalchemy.exc.NoResultFound` if not found in database."""
study: Study = db.session.query(Study).filter_by(id=id).one()
# When we fetch a study, we also need to fetch the associated owner and groups
# to check the permissions of the current user efficiently.
study: Study = (
db.session.query(Study)
.options(joinedload(Study.owner))
.options(joinedload(Study.groups))
.filter_by(id=id)
.one()
)
return study

def get_list(self, study_id: List[str]) -> List[Study]:
studies: List[Study] = db.session.query(Study).where(Study.id.in_(study_id)).all()
def get_list(self, study_id: t.List[str]) -> t.List[Study]:
# When we fetch a study, we also need to fetch the associated owner and groups
# to check the permissions of the current user efficiently.
studies: t.List[Study] = (
db.session.query(Study)
.options(joinedload(Study.owner))
.options(joinedload(Study.groups))
.where(Study.id.in_(study_id))
.all()
)
return studies

def get_additional_data(self, study_id: str) -> Optional[StudyAdditionalData]:
def get_additional_data(self, study_id: str) -> t.Optional[StudyAdditionalData]:
metadata: StudyAdditionalData = db.session.query(StudyAdditionalData).get(study_id)
return metadata

def get_all(self) -> List[Study]:
def get_all(self) -> t.List[Study]:
entity = with_polymorphic(Study, "*")
metadatas: List[Study] = db.session.query(entity).filter(RawStudy.missing.is_(None)).all()
metadatas: t.List[Study] = db.session.query(entity).filter(RawStudy.missing.is_(None)).all()
return metadatas

def get_all_raw(self, show_missing: bool = True) -> List[RawStudy]:
def get_all_raw(self, show_missing: bool = True) -> t.List[RawStudy]:
query = db.session.query(RawStudy)
if not show_missing:
query = query.filter(RawStudy.missing.is_(None))
metadatas: List[RawStudy] = query.all()
metadatas: t.List[RawStudy] = query.all()
return metadatas

def delete(self, id: str) -> None:
Expand Down

0 comments on commit db391cb

Please sign in to comment.