Skip to content

Commit

Permalink
feature(tags-db): update tags related services and endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
mabw-rte committed Feb 6, 2024
1 parent d11b7a5 commit e561322
Show file tree
Hide file tree
Showing 9 changed files with 87 additions and 153 deletions.
3 changes: 0 additions & 3 deletions antarest/core/interfaces/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,10 @@ class CacheConstants(Enum):
This cache is used by the `create_from_fs` function when retrieving the configuration
of a study from the data on the disk.
- `STUDY_LISTING`: variable used to store objects of type `StudyMetadataDTO`.
This cache is used by the `get_studies_information` function to store the list of studies.
"""

RAW_STUDY = "RAW_STUDY"
STUDY_FACTORY = "STUDY_FACTORY"
STUDY_LISTING = "STUDY_LISTING"


class ICache:
Expand Down
60 changes: 0 additions & 60 deletions antarest/study/common/utils.py

This file was deleted.

59 changes: 22 additions & 37 deletions antarest/study/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@
from sqlalchemy import func, not_, or_ # type: ignore
from sqlalchemy.orm import Session, joinedload, with_polymorphic # type: ignore

from antarest.core.interfaces.cache import CacheConstants, ICache
from antarest.core.interfaces.cache import ICache
from antarest.core.utils.fastapi_sqlalchemy import db
from antarest.login.model import Group
from antarest.study.common.utils import get_study_information
from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, Study, StudyAdditionalData
from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, Study, StudyAdditionalData, Tag

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -126,7 +125,6 @@ def save(
self,
metadata: Study,
update_modification_date: bool = False,
update_in_listing: bool = True,
) -> Study:
metadata_id = metadata.id or metadata.name
logger.debug(f"Saving study {metadata_id}")
Expand All @@ -140,8 +138,6 @@ def save(
session.add(metadata)
session.commit()

if update_in_listing:
self._update_study_from_cache_listing(metadata)
return metadata

def refresh(self, metadata: Study) -> None:
Expand Down Expand Up @@ -218,6 +214,7 @@ def get_all(
q = q.options(joinedload(entity.owner))
q = q.options(joinedload(entity.groups))
q = q.options(joinedload(entity.additional_data))
q = q.options(joinedload(entity.tags))
if study_filter.managed is not None:
if study_filter.managed:
q = q.filter(or_(entity.type == "variantstudy", RawStudy.workspace == DEFAULT_WORKSPACE_NAME))
Expand All @@ -230,6 +227,8 @@ def get_all(
q = q.filter(entity.owner_id.in_(study_filter.users))
if study_filter.groups:
q = q.join(entity.groups).filter(Group.id.in_(study_filter.groups))
if study_filter.tags:
q = q.join(entity.tags).filter(Tag.id.in_(study_filter.tags))
if study_filter.archived is not None:
q = q.filter(entity.archived == study_filter.archived)
if study_filter.name:
Expand Down Expand Up @@ -283,34 +282,20 @@ def delete(self, id: str) -> None:
u: Study = session.query(Study).get(id)
session.delete(u)
session.commit()
self._remove_study_from_cache_listing(id)

def _remove_study_from_cache_listing(self, study_id: str) -> None:
try:
cached_studies = self.cache_service.get(CacheConstants.STUDY_LISTING.value)
if cached_studies:
if study_id in cached_studies:
del cached_studies[study_id]
self.cache_service.put(CacheConstants.STUDY_LISTING.value, cached_studies)
except Exception as e:
logger.error("Failed to update study listing cache", exc_info=e)
try:
self.cache_service.invalidate(CacheConstants.STUDY_LISTING.value)
except Exception as e:
logger.error("Failed to invalidate listing cache", exc_info=e)

def _update_study_from_cache_listing(self, study: Study) -> None:
try:
cached_studies = self.cache_service.get(CacheConstants.STUDY_LISTING.value)
if cached_studies:
if isinstance(study, RawStudy) and study.missing is not None:
del cached_studies[study.id]
else:
cached_studies[study.id] = get_study_information(study)
self.cache_service.put(CacheConstants.STUDY_LISTING.value, cached_studies)
except Exception as e:
logger.error("Failed to update study listing cache", exc_info=e)
try:
self.cache_service.invalidate(CacheConstants.STUDY_LISTING.value)
except Exception as e:
logger.error("Failed to invalidate listing cache", exc_info=e)

def update_tags(self, study: Study, new_tags: t.List[str]) -> None:
"""
Using the repository session we can update the study tags on the DB.
Thus, the tables `study_tag` and `tag` will be updated too accordingly.
Args:
study: a pre-existing study to be updated with the new tags
new_tags: the new tags to be associated with the input study on the db
Returns:
"""
logger.debug(f"Updating tags for study: {study.id}")
study.tags = [Tag(label=tag) for tag in new_tags]
self.session.merge(study)
self.session.commit()
15 changes: 6 additions & 9 deletions antarest/study/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,16 +505,9 @@ def get_study_information(self, uuid: str, params: RequestParameters) -> StudyMe
logger.info("study %s metadata asked by user %s", uuid, params.get_user_id())
# todo debounce this with a "update_study_last_access" method updating only every some seconds
study.last_access = datetime.utcnow()
self.repository.save(study, update_in_listing=False)
self.repository.save(study)
return self.storage_service.get_storage(study).get_study_information(study)

def invalidate_cache_listing(self, params: RequestParameters) -> None:
if params.user and params.user.is_site_admin():
self.cache_service.invalidate(CacheConstants.STUDY_LISTING.value)
else:
logger.error(f"User {params.user} is not site admin")
raise UserHasNotPermissionError()

def update_study_information(
self,
uuid: str,
Expand Down Expand Up @@ -567,6 +560,10 @@ def update_study_information(
permissions=PermissionInfo.from_study(study),
)
)

new_tags = new_metadata.tags
self.repository.update_tags(study, new_tags)

return new_metadata

def check_study_access(
Expand Down Expand Up @@ -676,7 +673,7 @@ def get_study_synthesis(self, study_id: str, params: RequestParameters) -> FileS
study = self.get_study(study_id)
assert_permission(params.user, study, StudyPermissionType.READ)
study.last_access = datetime.utcnow()
self.repository.save(study, update_in_listing=False)
self.repository.save(study)
study_storage_service = self.storage_service.get_storage(study)
return study_storage_service.get_synthesis(study, params)

Expand Down
49 changes: 46 additions & 3 deletions antarest/study/storage/abstract_storage_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@
from antarest.core.config import Config
from antarest.core.exceptions import BadOutputError, StudyOutputNotFoundError
from antarest.core.interfaces.cache import CacheConstants, ICache
from antarest.core.model import JSON
from antarest.core.model import JSON, PublicMode
from antarest.core.utils.utils import StopWatch, extract_zip, unzip, zip_dir
from antarest.login.model import GroupDTO
from antarest.study.common.studystorage import IStudyStorageService, T
from antarest.study.common.utils import get_study_information
from antarest.study.model import (
DEFAULT_WORKSPACE_NAME,
OwnerInfo,
Patch,
PatchOutputs,
PatchStudy,
StudyAdditionalData,
Expand Down Expand Up @@ -68,7 +71,47 @@ def get_study_information(
self,
study: T,
) -> StudyMetadataDTO:
return get_study_information(study)
additional_data = study.additional_data or StudyAdditionalData()

try:
patch = Patch.parse_raw(additional_data.patch or "{}")
except Exception as e:
logger.warning(f"Failed to parse patch for study {study.id}", exc_info=e)
patch = Patch()

patch_metadata = patch.study or PatchStudy()

study_workspace = getattr(study, "workspace", DEFAULT_WORKSPACE_NAME)
folder: Optional[str] = None
if hasattr(study, "folder"):
folder = study.folder

owner_info = (
OwnerInfo(id=study.owner.id, name=study.owner.name)
if study.owner is not None
else OwnerInfo(name=additional_data.author or "Unknown")
)

return StudyMetadataDTO(
id=study.id,
name=study.name,
version=int(study.version),
created=str(study.created_at),
updated=str(study.updated_at),
workspace=study_workspace,
managed=study_workspace == DEFAULT_WORKSPACE_NAME,
type=study.type,
archived=study.archived if study.archived is not None else False,
owner=owner_info,
groups=[GroupDTO(id=group.id, name=group.name) for group in study.groups],
public_mode=study.public_mode or PublicMode.NONE,
horizon=additional_data.horizon,
scenario=patch_metadata.scenario,
status=patch_metadata.status,
doc=patch_metadata.doc,
folder=folder,
tags=[tag.label for tag in study.tags],
)

def get(
self,
Expand Down
3 changes: 1 addition & 2 deletions antarest/study/storage/variantstudy/variant_study_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,6 @@ def invalidate_cache(
self.repository.save(
metadata=variant_study,
update_modification_date=True,
update_in_listing=False,
)
for child in self.repository.get_children(parent_id=variant_study.id):
self.invalidate_cache(child, invalidate_self_snapshot=True)
Expand Down Expand Up @@ -631,7 +630,7 @@ def callback(notifier: TaskUpdateNotifier) -> TaskResult:
custom_event_messages=CustomTaskEventMessages(start=metadata.id, running=metadata.id, end=metadata.id),
request_params=RequestParameters(DEFAULT_ADMIN_USER),
)
self.repository.save(metadata, update_in_listing=False)
self.repository.save(metadata)
return str(metadata.generation_task)

def generate(
Expand Down
6 changes: 2 additions & 4 deletions antarest/study/web/studies_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,18 +817,16 @@ def unarchive_study(

@bp.post(
"/studies/_invalidate_cache_listing",
summary="Invalidate the study listing cache",
summary="Invalidate the study listing cache [DEPRECATED] and will be removed soon",
tags=[APITag.study_management],
)
def invalidate_study_listing_cache(
current_user: JWTUser = Depends(auth.get_current_user),
) -> t.Any:
logger.info(
"Invalidating the study listing cache",
"Invalidating the study listing cache endpoint is deprecated",
extra={"user": current_user.id},
)
params = RequestParameters(user=current_user)
return study_service.invalidate_cache_listing(params)

@bp.get(
"/studies/{uuid}/disk-usage",
Expand Down
35 changes: 0 additions & 35 deletions tests/storage/repository/test_study.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
from sqlalchemy.orm import scoped_session, sessionmaker # type: ignore

from antarest.core.cache.business.local_chache import LocalCache
from antarest.core.interfaces.cache import CacheConstants
from antarest.login.model import Group, User
from antarest.study.common.utils import get_study_information
from antarest.study.model import DEFAULT_WORKSPACE_NAME, PublicMode, RawStudy, Study, StudyContentStatus
from antarest.study.repository import StudyMetadataRepository
from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy
Expand Down Expand Up @@ -100,36 +98,3 @@ def test_study_inheritance():

assert isinstance(b, RawStudy)
assert b.path == "study"


@with_db_context
def test_cache():
user = User(id=0, name="admin")
group = Group(id="my-group", name="group")

cache = LocalCache()

repo = StudyMetadataRepository(cache)
a = RawStudy(
name="a",
version="42",
author="John Smith",
created_at=datetime.utcnow(),
updated_at=datetime.utcnow(),
public_mode=PublicMode.FULL,
owner=user,
groups=[group],
workspace=DEFAULT_WORKSPACE_NAME,
path="study",
content_status=StudyContentStatus.WARNING,
)

repo.save(a)
cache.put(
CacheConstants.STUDY_LISTING.value,
{a.id: get_study_information(a)},
)
repo.save(a)
repo.delete(a.id)

assert len(cache.get(CacheConstants.STUDY_LISTING.value)) == 0
10 changes: 10 additions & 0 deletions tests/study/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,13 @@ def test_study_tag_relationship(self, db_session: Session) -> None:
assert len(studies) == 1
assert set(study.id for study in studies) == {study_id_1}
assert set(tag.label for tag in studies[0].tags) == {"test-tag-1"}

# verify updating works
study = db_session.query(Study).get(study_id_1)
study.tags = [Tag(label="test-tag-2"), Tag(label="test-tag-3")]
db_session.merge(study)
db_session.commit()
study_tag_pairs = db_session.query(StudyTag).all()
assert len(study_tag_pairs) == 2
assert set(e.tag_label for e in study_tag_pairs) == {"test-tag-2", "test-tag-3"}
assert set(e.study_id for e in study_tag_pairs) == {study_id_1}

0 comments on commit e561322

Please sign in to comment.