Skip to content

Commit

Permalink
feature(tags-db): update tags related services and endpoints (#1925)
Browse files Browse the repository at this point in the history
feature(tags-db): update tags related services and endpoints (#1925)
  • Loading branch information
mabw-rte authored Feb 8, 2024
2 parents d11b7a5 + de3f20f commit bac31cb
Show file tree
Hide file tree
Showing 15 changed files with 369 additions and 267 deletions.
3 changes: 0 additions & 3 deletions antarest/core/interfaces/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,10 @@ class CacheConstants(Enum):
This cache is used by the `create_from_fs` function when retrieving the configuration
of a study from the data on the disk.
- `STUDY_LISTING`: variable used to store objects of type `StudyMetadataDTO`.
This cache is used by the `get_studies_information` function to store the list of studies.
"""

RAW_STUDY = "RAW_STUDY"
STUDY_FACTORY = "STUDY_FACTORY"
STUDY_LISTING = "STUDY_LISTING"


class ICache:
Expand Down
60 changes: 0 additions & 60 deletions antarest/study/common/utils.py

This file was deleted.

32 changes: 27 additions & 5 deletions antarest/study/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

DEFAULT_WORKSPACE_NAME = "default"

STUDY_REFERENCE_TEMPLATES: t.Dict[str, str] = {
STUDY_REFERENCE_TEMPLATES: t.Mapping[str, str] = {
"600": "empty_study_613.zip",
"610": "empty_study_613.zip",
"640": "empty_study_613.zip",
Expand Down Expand Up @@ -61,6 +61,10 @@
class StudyTag(Base): # type:ignore
"""
A table to manage the many-to-many relationship between `Study` and `Tag`
Attributes:
study_id (str): The ID of the study associated with the tag.
tag_label (str): The label of the tag associated with the study.
"""

__tablename__ = "study_tag"
Expand All @@ -69,13 +73,25 @@ class StudyTag(Base): # type:ignore
study_id: str = Column(String(36), ForeignKey("study.id", ondelete="CASCADE"), index=True, nullable=False)
tag_label: str = Column(String(40), ForeignKey("tag.label", ondelete="CASCADE"), index=True, nullable=False)

def __str__(self) -> str:
def __str__(self) -> str: # pragma: no cover
return f"[StudyTag] study_id={self.study_id}, tag={self.tag}"

def __repr__(self) -> str: # pragma: no cover
cls_name = self.__class__.__name__
study_id = self.study_id
tag = self.tag
return f"{cls_name}({study_id=}, {tag=})"


class Tag(Base): # type:ignore
"""
A table to store all tags
Represents a tag in the database.
This class is used to store tags associated with studies.
Attributes:
label (str): The label of the tag.
color (str): The color code associated with the tag.
"""

__tablename__ = "tag"
Expand All @@ -85,8 +101,14 @@ class Tag(Base): # type:ignore

studies: t.List["Study"] = relationship("Study", secondary=StudyTag.__table__, back_populates="tags")

def __str__(self) -> str:
return f"[Tag] label={self.label}, css-color-code={self.color}"
def __str__(self) -> str: # pragma: no cover
return t.cast(str, self.label)

def __repr__(self) -> str: # pragma: no cover
cls_name = self.__class__.__name__
label = self.label
color = self.color
return f"{cls_name}({label=}, {color=})"


class StudyContentStatus(enum.Enum):
Expand Down
86 changes: 32 additions & 54 deletions antarest/study/repository.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
import datetime
import enum
import logging
import typing as t

from pydantic import BaseModel, NonNegativeInt
from sqlalchemy import func, not_, or_ # type: ignore
from sqlalchemy.orm import Session, joinedload, with_polymorphic # type: ignore

from antarest.core.interfaces.cache import CacheConstants, ICache
from antarest.core.interfaces.cache import ICache
from antarest.core.utils.fastapi_sqlalchemy import db
from antarest.login.model import Group
from antarest.study.common.utils import get_study_information
from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, Study, StudyAdditionalData

logger = logging.getLogger(__name__)
from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, Study, StudyAdditionalData, Tag


def escape_like(string: str, escape_char: str = "\\") -> str:
Expand Down Expand Up @@ -126,10 +122,7 @@ def save(
self,
metadata: Study,
update_modification_date: bool = False,
update_in_listing: bool = True,
) -> Study:
metadata_id = metadata.id or metadata.name
logger.debug(f"Saving study {metadata_id}")
if update_modification_date:
metadata.updated_at = datetime.datetime.utcnow()

Expand All @@ -140,8 +133,6 @@ def save(
session.add(metadata)
session.commit()

if update_in_listing:
self._update_study_from_cache_listing(metadata)
return metadata

def refresh(self, metadata: Study) -> None:
Expand All @@ -159,6 +150,7 @@ def get(self, id: str) -> t.Optional[Study]:
self.session.query(Study)
.options(joinedload(Study.owner))
.options(joinedload(Study.groups))
.options(joinedload(Study.tags))
.get(id)
# fmt: on
)
Expand All @@ -175,6 +167,7 @@ def one(self, study_id: str) -> Study:
self.session.query(Study)
.options(joinedload(Study.owner))
.options(joinedload(Study.groups))
.options(joinedload(Study.tags))
.filter_by(id=study_id)
.one()
)
Expand All @@ -189,23 +182,22 @@ def get_all(
study_filter: StudyFilter = StudyFilter(),
sort_by: t.Optional[StudySortBy] = None,
pagination: StudyPagination = StudyPagination(),
) -> t.List[Study]:
) -> t.Sequence[Study]:
"""
This function goal is to create a search engine throughout the studies with optimal
runtime.
Retrieve studies based on specified filters, sorting, and pagination.
Args:
study_filter: composed of all filtering criteria
sort_by: how the user would like the results to be sorted
pagination: specifies the number of results to displayed in each page and the actually displayed page
study_filter: composed of all filtering criteria.
sort_by: how the user would like the results to be sorted.
pagination: specifies the number of results to displayed in each page and the actually displayed page.
Returns:
The matching studies in proper order and pagination
The matching studies in proper order and pagination.
"""
# When we fetch a study, we also need to fetch the associated owner and groups
# to check the permissions of the current user efficiently.
# We also need to fetch the additional data to display the study information
# efficiently (see: `utils.get_study_information`)
# efficiently (see: `AbstractStorageService.get_study_information`)
entity = with_polymorphic(Study, "*")

# noinspection PyTypeChecker
Expand All @@ -218,6 +210,7 @@ def get_all(
q = q.options(joinedload(entity.owner))
q = q.options(joinedload(entity.groups))
q = q.options(joinedload(entity.additional_data))
q = q.options(joinedload(entity.tags))
if study_filter.managed is not None:
if study_filter.managed:
q = q.filter(or_(entity.type == "variantstudy", RawStudy.workspace == DEFAULT_WORKSPACE_NAME))
Expand All @@ -230,6 +223,8 @@ def get_all(
q = q.filter(entity.owner_id.in_(study_filter.users))
if study_filter.groups:
q = q.join(entity.groups).filter(Group.id.in_(study_filter.groups))
if study_filter.tags:
q = q.join(entity.tags).filter(Tag.label.in_(study_filter.tags))
if study_filter.archived is not None:
q = q.filter(entity.archived == study_filter.archived)
if study_filter.name:
Expand Down Expand Up @@ -264,53 +259,36 @@ def get_all(
if pagination.page_nb or pagination.page_size:
q = q.offset(pagination.page_nb * pagination.page_size).limit(pagination.page_size)

studies: t.List[Study] = q.all()
studies: t.Sequence[Study] = q.all()
return studies

def get_all_raw(self, exists: t.Optional[bool] = None) -> t.List[RawStudy]:
def get_all_raw(self, exists: t.Optional[bool] = None) -> t.Sequence[RawStudy]:
query = self.session.query(RawStudy)
if exists is not None:
if exists:
query = query.filter(RawStudy.missing.is_(None))
else:
query = query.filter(not_(RawStudy.missing.is_(None)))
studies: t.List[RawStudy] = query.all()
studies: t.Sequence[RawStudy] = query.all()
return studies

def delete(self, id: str) -> None:
logger.debug(f"Deleting study {id}")
session = self.session
u: Study = session.query(Study).get(id)
session.delete(u)
session.commit()
self._remove_study_from_cache_listing(id)

def _remove_study_from_cache_listing(self, study_id: str) -> None:
try:
cached_studies = self.cache_service.get(CacheConstants.STUDY_LISTING.value)
if cached_studies:
if study_id in cached_studies:
del cached_studies[study_id]
self.cache_service.put(CacheConstants.STUDY_LISTING.value, cached_studies)
except Exception as e:
logger.error("Failed to update study listing cache", exc_info=e)
try:
self.cache_service.invalidate(CacheConstants.STUDY_LISTING.value)
except Exception as e:
logger.error("Failed to invalidate listing cache", exc_info=e)

def _update_study_from_cache_listing(self, study: Study) -> None:
try:
cached_studies = self.cache_service.get(CacheConstants.STUDY_LISTING.value)
if cached_studies:
if isinstance(study, RawStudy) and study.missing is not None:
del cached_studies[study.id]
else:
cached_studies[study.id] = get_study_information(study)
self.cache_service.put(CacheConstants.STUDY_LISTING.value, cached_studies)
except Exception as e:
logger.error("Failed to update study listing cache", exc_info=e)
try:
self.cache_service.invalidate(CacheConstants.STUDY_LISTING.value)
except Exception as e:
logger.error("Failed to invalidate listing cache", exc_info=e)

def update_tags(self, study: Study, new_tags: t.Sequence[str]) -> None:
"""
Updates the tags associated with a given study in the database,
replacing existing tags with new ones.
Args:
study: The pre-existing study to be updated with the new tags.
new_tags: The new tags to be associated with the input study in the database.
"""
existing_tags = self.session.query(Tag).filter(Tag.label.in_(new_tags)).all()
new_labels = set(new_tags) - set([tag.label for tag in existing_tags])
study.tags = [Tag(label=tag) for tag in new_labels] + existing_tags
self.session.merge(study)
self.session.commit()
Loading

0 comments on commit bac31cb

Please sign in to comment.