From e561322ad9adb66e9ab2f426a65ffd767b50e9e2 Mon Sep 17 00:00:00 2001 From: Mohamed Abdel Wedoud Date: Tue, 6 Feb 2024 01:24:32 +0100 Subject: [PATCH 01/11] feature(tags-db): update tags related services and endpoints --- antarest/core/interfaces/cache.py | 3 - antarest/study/common/utils.py | 60 ------------------- antarest/study/repository.py | 59 +++++++----------- antarest/study/service.py | 15 ++--- .../study/storage/abstract_storage_service.py | 49 ++++++++++++++- .../variantstudy/variant_study_service.py | 3 +- antarest/study/web/studies_blueprint.py | 6 +- tests/storage/repository/test_study.py | 35 ----------- tests/study/test_model.py | 10 ++++ 9 files changed, 87 insertions(+), 153 deletions(-) delete mode 100644 antarest/study/common/utils.py diff --git a/antarest/core/interfaces/cache.py b/antarest/core/interfaces/cache.py index 33d5fa541e..3fce146145 100644 --- a/antarest/core/interfaces/cache.py +++ b/antarest/core/interfaces/cache.py @@ -19,13 +19,10 @@ class CacheConstants(Enum): This cache is used by the `create_from_fs` function when retrieving the configuration of a study from the data on the disk. - - `STUDY_LISTING`: variable used to store objects of type `StudyMetadataDTO`. - This cache is used by the `get_studies_information` function to store the list of studies. """ RAW_STUDY = "RAW_STUDY" STUDY_FACTORY = "STUDY_FACTORY" - STUDY_LISTING = "STUDY_LISTING" class ICache: diff --git a/antarest/study/common/utils.py b/antarest/study/common/utils.py deleted file mode 100644 index e1dd26e86f..0000000000 --- a/antarest/study/common/utils.py +++ /dev/null @@ -1,60 +0,0 @@ -import logging -from typing import Optional - -from antarest.core.model import PublicMode -from antarest.login.model import GroupDTO -from antarest.study.model import ( - DEFAULT_WORKSPACE_NAME, - OwnerInfo, - Patch, - PatchStudy, - Study, - StudyAdditionalData, - StudyMetadataDTO, -) - -logger = logging.getLogger(__name__) - - -def get_study_information(study: Study) -> StudyMetadataDTO: - additional_data = study.additional_data or StudyAdditionalData() - - try: - patch = Patch.parse_raw(additional_data.patch or "{}") - except Exception as e: - logger.warning(f"Failed to parse patch for study {study.id}", exc_info=e) - patch = Patch() - - patch_metadata = patch.study or PatchStudy() - - study_workspace = getattr(study, "workspace", DEFAULT_WORKSPACE_NAME) - folder: Optional[str] = None - if hasattr(study, "folder"): - folder = study.folder - - owner_info = ( - OwnerInfo(id=study.owner.id, name=study.owner.name) - if study.owner is not None - else OwnerInfo(name=additional_data.author or "Unknown") - ) - - return StudyMetadataDTO( - id=study.id, - name=study.name, - version=int(study.version), - created=str(study.created_at), - updated=str(study.updated_at), - workspace=study_workspace, - managed=study_workspace == DEFAULT_WORKSPACE_NAME, - type=study.type, - archived=study.archived if study.archived is not None else False, - owner=owner_info, - groups=[GroupDTO(id=group.id, name=group.name) for group in study.groups], - public_mode=study.public_mode or PublicMode.NONE, - horizon=additional_data.horizon, - scenario=patch_metadata.scenario, - status=patch_metadata.status, - doc=patch_metadata.doc, - folder=folder, - tags=patch_metadata.tags, - ) diff --git a/antarest/study/repository.py b/antarest/study/repository.py index e4646e1546..d8e2675cd6 100644 --- a/antarest/study/repository.py +++ b/antarest/study/repository.py @@ -7,11 +7,10 @@ from sqlalchemy import func, not_, or_ # type: ignore from sqlalchemy.orm import Session, joinedload, with_polymorphic # type: ignore -from antarest.core.interfaces.cache import CacheConstants, ICache +from antarest.core.interfaces.cache import ICache from antarest.core.utils.fastapi_sqlalchemy import db from antarest.login.model import Group -from antarest.study.common.utils import get_study_information -from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, Study, StudyAdditionalData +from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, Study, StudyAdditionalData, Tag logger = logging.getLogger(__name__) @@ -126,7 +125,6 @@ def save( self, metadata: Study, update_modification_date: bool = False, - update_in_listing: bool = True, ) -> Study: metadata_id = metadata.id or metadata.name logger.debug(f"Saving study {metadata_id}") @@ -140,8 +138,6 @@ def save( session.add(metadata) session.commit() - if update_in_listing: - self._update_study_from_cache_listing(metadata) return metadata def refresh(self, metadata: Study) -> None: @@ -218,6 +214,7 @@ def get_all( q = q.options(joinedload(entity.owner)) q = q.options(joinedload(entity.groups)) q = q.options(joinedload(entity.additional_data)) + q = q.options(joinedload(entity.tags)) if study_filter.managed is not None: if study_filter.managed: q = q.filter(or_(entity.type == "variantstudy", RawStudy.workspace == DEFAULT_WORKSPACE_NAME)) @@ -230,6 +227,8 @@ def get_all( q = q.filter(entity.owner_id.in_(study_filter.users)) if study_filter.groups: q = q.join(entity.groups).filter(Group.id.in_(study_filter.groups)) + if study_filter.tags: + q = q.join(entity.tags).filter(Tag.id.in_(study_filter.tags)) if study_filter.archived is not None: q = q.filter(entity.archived == study_filter.archived) if study_filter.name: @@ -283,34 +282,20 @@ def delete(self, id: str) -> None: u: Study = session.query(Study).get(id) session.delete(u) session.commit() - self._remove_study_from_cache_listing(id) - - def _remove_study_from_cache_listing(self, study_id: str) -> None: - try: - cached_studies = self.cache_service.get(CacheConstants.STUDY_LISTING.value) - if cached_studies: - if study_id in cached_studies: - del cached_studies[study_id] - self.cache_service.put(CacheConstants.STUDY_LISTING.value, cached_studies) - except Exception as e: - logger.error("Failed to update study listing cache", exc_info=e) - try: - self.cache_service.invalidate(CacheConstants.STUDY_LISTING.value) - except Exception as e: - logger.error("Failed to invalidate listing cache", exc_info=e) - - def _update_study_from_cache_listing(self, study: Study) -> None: - try: - cached_studies = self.cache_service.get(CacheConstants.STUDY_LISTING.value) - if cached_studies: - if isinstance(study, RawStudy) and study.missing is not None: - del cached_studies[study.id] - else: - cached_studies[study.id] = get_study_information(study) - self.cache_service.put(CacheConstants.STUDY_LISTING.value, cached_studies) - except Exception as e: - logger.error("Failed to update study listing cache", exc_info=e) - try: - self.cache_service.invalidate(CacheConstants.STUDY_LISTING.value) - except Exception as e: - logger.error("Failed to invalidate listing cache", exc_info=e) + + def update_tags(self, study: Study, new_tags: t.List[str]) -> None: + """ + Using the repository session we can update the study tags on the DB. + Thus, the tables `study_tag` and `tag` will be updated too accordingly. + + Args: + study: a pre-existing study to be updated with the new tags + new_tags: the new tags to be associated with the input study on the db + + Returns: + + """ + logger.debug(f"Updating tags for study: {study.id}") + study.tags = [Tag(label=tag) for tag in new_tags] + self.session.merge(study) + self.session.commit() diff --git a/antarest/study/service.py b/antarest/study/service.py index 79cdb930fe..3d10b2bef1 100644 --- a/antarest/study/service.py +++ b/antarest/study/service.py @@ -505,16 +505,9 @@ def get_study_information(self, uuid: str, params: RequestParameters) -> StudyMe logger.info("study %s metadata asked by user %s", uuid, params.get_user_id()) # todo debounce this with a "update_study_last_access" method updating only every some seconds study.last_access = datetime.utcnow() - self.repository.save(study, update_in_listing=False) + self.repository.save(study) return self.storage_service.get_storage(study).get_study_information(study) - def invalidate_cache_listing(self, params: RequestParameters) -> None: - if params.user and params.user.is_site_admin(): - self.cache_service.invalidate(CacheConstants.STUDY_LISTING.value) - else: - logger.error(f"User {params.user} is not site admin") - raise UserHasNotPermissionError() - def update_study_information( self, uuid: str, @@ -567,6 +560,10 @@ def update_study_information( permissions=PermissionInfo.from_study(study), ) ) + + new_tags = new_metadata.tags + self.repository.update_tags(study, new_tags) + return new_metadata def check_study_access( @@ -676,7 +673,7 @@ def get_study_synthesis(self, study_id: str, params: RequestParameters) -> FileS study = self.get_study(study_id) assert_permission(params.user, study, StudyPermissionType.READ) study.last_access = datetime.utcnow() - self.repository.save(study, update_in_listing=False) + self.repository.save(study) study_storage_service = self.storage_service.get_storage(study) return study_storage_service.get_synthesis(study, params) diff --git a/antarest/study/storage/abstract_storage_service.py b/antarest/study/storage/abstract_storage_service.py index 60cd48b782..cd6b8be080 100644 --- a/antarest/study/storage/abstract_storage_service.py +++ b/antarest/study/storage/abstract_storage_service.py @@ -9,11 +9,14 @@ from antarest.core.config import Config from antarest.core.exceptions import BadOutputError, StudyOutputNotFoundError from antarest.core.interfaces.cache import CacheConstants, ICache -from antarest.core.model import JSON +from antarest.core.model import JSON, PublicMode from antarest.core.utils.utils import StopWatch, extract_zip, unzip, zip_dir +from antarest.login.model import GroupDTO from antarest.study.common.studystorage import IStudyStorageService, T -from antarest.study.common.utils import get_study_information from antarest.study.model import ( + DEFAULT_WORKSPACE_NAME, + OwnerInfo, + Patch, PatchOutputs, PatchStudy, StudyAdditionalData, @@ -68,7 +71,47 @@ def get_study_information( self, study: T, ) -> StudyMetadataDTO: - return get_study_information(study) + additional_data = study.additional_data or StudyAdditionalData() + + try: + patch = Patch.parse_raw(additional_data.patch or "{}") + except Exception as e: + logger.warning(f"Failed to parse patch for study {study.id}", exc_info=e) + patch = Patch() + + patch_metadata = patch.study or PatchStudy() + + study_workspace = getattr(study, "workspace", DEFAULT_WORKSPACE_NAME) + folder: Optional[str] = None + if hasattr(study, "folder"): + folder = study.folder + + owner_info = ( + OwnerInfo(id=study.owner.id, name=study.owner.name) + if study.owner is not None + else OwnerInfo(name=additional_data.author or "Unknown") + ) + + return StudyMetadataDTO( + id=study.id, + name=study.name, + version=int(study.version), + created=str(study.created_at), + updated=str(study.updated_at), + workspace=study_workspace, + managed=study_workspace == DEFAULT_WORKSPACE_NAME, + type=study.type, + archived=study.archived if study.archived is not None else False, + owner=owner_info, + groups=[GroupDTO(id=group.id, name=group.name) for group in study.groups], + public_mode=study.public_mode or PublicMode.NONE, + horizon=additional_data.horizon, + scenario=patch_metadata.scenario, + status=patch_metadata.status, + doc=patch_metadata.doc, + folder=folder, + tags=[tag.label for tag in study.tags], + ) def get( self, diff --git a/antarest/study/storage/variantstudy/variant_study_service.py b/antarest/study/storage/variantstudy/variant_study_service.py index f8e9fb95a1..e59ef3fa94 100644 --- a/antarest/study/storage/variantstudy/variant_study_service.py +++ b/antarest/study/storage/variantstudy/variant_study_service.py @@ -373,7 +373,6 @@ def invalidate_cache( self.repository.save( metadata=variant_study, update_modification_date=True, - update_in_listing=False, ) for child in self.repository.get_children(parent_id=variant_study.id): self.invalidate_cache(child, invalidate_self_snapshot=True) @@ -631,7 +630,7 @@ def callback(notifier: TaskUpdateNotifier) -> TaskResult: custom_event_messages=CustomTaskEventMessages(start=metadata.id, running=metadata.id, end=metadata.id), request_params=RequestParameters(DEFAULT_ADMIN_USER), ) - self.repository.save(metadata, update_in_listing=False) + self.repository.save(metadata) return str(metadata.generation_task) def generate( diff --git a/antarest/study/web/studies_blueprint.py b/antarest/study/web/studies_blueprint.py index e4fcbe100e..f4b7a5e741 100644 --- a/antarest/study/web/studies_blueprint.py +++ b/antarest/study/web/studies_blueprint.py @@ -817,18 +817,16 @@ def unarchive_study( @bp.post( "/studies/_invalidate_cache_listing", - summary="Invalidate the study listing cache", + summary="Invalidate the study listing cache [DEPRECATED] and will be removed soon", tags=[APITag.study_management], ) def invalidate_study_listing_cache( current_user: JWTUser = Depends(auth.get_current_user), ) -> t.Any: logger.info( - "Invalidating the study listing cache", + "Invalidating the study listing cache endpoint is deprecated", extra={"user": current_user.id}, ) - params = RequestParameters(user=current_user) - return study_service.invalidate_cache_listing(params) @bp.get( "/studies/{uuid}/disk-usage", diff --git a/tests/storage/repository/test_study.py b/tests/storage/repository/test_study.py index 12afde7b05..0b2f6c4e2f 100644 --- a/tests/storage/repository/test_study.py +++ b/tests/storage/repository/test_study.py @@ -3,9 +3,7 @@ from sqlalchemy.orm import scoped_session, sessionmaker # type: ignore from antarest.core.cache.business.local_chache import LocalCache -from antarest.core.interfaces.cache import CacheConstants from antarest.login.model import Group, User -from antarest.study.common.utils import get_study_information from antarest.study.model import DEFAULT_WORKSPACE_NAME, PublicMode, RawStudy, Study, StudyContentStatus from antarest.study.repository import StudyMetadataRepository from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy @@ -100,36 +98,3 @@ def test_study_inheritance(): assert isinstance(b, RawStudy) assert b.path == "study" - - -@with_db_context -def test_cache(): - user = User(id=0, name="admin") - group = Group(id="my-group", name="group") - - cache = LocalCache() - - repo = StudyMetadataRepository(cache) - a = RawStudy( - name="a", - version="42", - author="John Smith", - created_at=datetime.utcnow(), - updated_at=datetime.utcnow(), - public_mode=PublicMode.FULL, - owner=user, - groups=[group], - workspace=DEFAULT_WORKSPACE_NAME, - path="study", - content_status=StudyContentStatus.WARNING, - ) - - repo.save(a) - cache.put( - CacheConstants.STUDY_LISTING.value, - {a.id: get_study_information(a)}, - ) - repo.save(a) - repo.delete(a.id) - - assert len(cache.get(CacheConstants.STUDY_LISTING.value)) == 0 diff --git a/tests/study/test_model.py b/tests/study/test_model.py index 6683366b48..04785a28df 100644 --- a/tests/study/test_model.py +++ b/tests/study/test_model.py @@ -120,3 +120,13 @@ def test_study_tag_relationship(self, db_session: Session) -> None: assert len(studies) == 1 assert set(study.id for study in studies) == {study_id_1} assert set(tag.label for tag in studies[0].tags) == {"test-tag-1"} + + # verify updating works + study = db_session.query(Study).get(study_id_1) + study.tags = [Tag(label="test-tag-2"), Tag(label="test-tag-3")] + db_session.merge(study) + db_session.commit() + study_tag_pairs = db_session.query(StudyTag).all() + assert len(study_tag_pairs) == 2 + assert set(e.tag_label for e in study_tag_pairs) == {"test-tag-2", "test-tag-3"} + assert set(e.study_id for e in study_tag_pairs) == {study_id_1} From dac11fda515ddf0c9db7218efc4572f39a2127e8 Mon Sep 17 00:00:00 2001 From: Mohamed Abdel Wedoud Date: Tue, 6 Feb 2024 14:48:06 +0100 Subject: [PATCH 02/11] test(tags-db): unittests for repository tags filter --- antarest/study/repository.py | 4 +- tests/study/test_repository.py | 92 +++++++++++++++++++++++++++++----- 2 files changed, 82 insertions(+), 14 deletions(-) diff --git a/antarest/study/repository.py b/antarest/study/repository.py index d8e2675cd6..077d4ccd3c 100644 --- a/antarest/study/repository.py +++ b/antarest/study/repository.py @@ -155,6 +155,7 @@ def get(self, id: str) -> t.Optional[Study]: self.session.query(Study) .options(joinedload(Study.owner)) .options(joinedload(Study.groups)) + .options(joinedload(Study.tags)) .get(id) # fmt: on ) @@ -171,6 +172,7 @@ def one(self, study_id: str) -> Study: self.session.query(Study) .options(joinedload(Study.owner)) .options(joinedload(Study.groups)) + .options(joinedload(Study.tags)) .filter_by(id=study_id) .one() ) @@ -228,7 +230,7 @@ def get_all( if study_filter.groups: q = q.join(entity.groups).filter(Group.id.in_(study_filter.groups)) if study_filter.tags: - q = q.join(entity.tags).filter(Tag.id.in_(study_filter.tags)) + q = q.join(entity.tags).filter(Tag.label.in_(study_filter.tags)) if study_filter.archived is not None: q = q.filter(entity.archived == study_filter.archived) if study_filter.name: diff --git a/tests/study/test_repository.py b/tests/study/test_repository.py index 77e4f1554c..a5d08ca658 100644 --- a/tests/study/test_repository.py +++ b/tests/study/test_repository.py @@ -7,7 +7,7 @@ from antarest.core.interfaces.cache import ICache from antarest.login.model import Group, User -from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy +from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, Tag from antarest.study.repository import StudyFilter, StudyMetadataRepository from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy from tests.db_statement_recorder import DBStatementRecorder @@ -63,13 +63,14 @@ def test_repository_get_all__general_case( # use the db recorder to check that: # 1- retrieving all studies requires only 1 query - # 2- accessing studies attributes does require additional queries to db + # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: all_studies = repository.get_all(study_filter=StudyFilter(managed=managed, study_ids=study_ids, exists=exists)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] assert len(db_recorder.sql_statements) == 1, str(db_recorder) if expected_ids is not None: @@ -102,6 +103,7 @@ def test_repository_get_all__incompatible_case( _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] assert len(db_recorder.sql_statements) == 1, str(db_recorder) assert not {s.id for s in all_studies} @@ -112,6 +114,7 @@ def test_repository_get_all__incompatible_case( _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] assert len(db_recorder.sql_statements) == 1, str(db_recorder) assert not {s.id for s in all_studies} @@ -122,6 +125,7 @@ def test_repository_get_all__incompatible_case( _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] assert len(db_recorder.sql_statements) == 1, str(db_recorder) assert not {s.id for s in all_studies} @@ -162,13 +166,14 @@ def test_repository_get_all__study_name_filter( # use the db recorder to check that: # 1- retrieving all studies requires only 1 query - # 2- accessing studies attributes does require additional queries to db + # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: all_studies = repository.get_all(study_filter=StudyFilter(name=name)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] assert len(db_recorder.sql_statements) == 1, str(db_recorder) if expected_ids is not None: @@ -206,13 +211,14 @@ def test_repository_get_all__managed_study_filter( # use the db recorder to check that: # 1- retrieving all studies requires only 1 query - # 2- accessing studies attributes does require additional queries to db + # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: all_studies = repository.get_all(study_filter=StudyFilter(managed=managed)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] assert len(db_recorder.sql_statements) == 1, str(db_recorder) if expected_ids is not None: @@ -245,13 +251,14 @@ def test_repository_get_all__archived_study_filter( # use the db recorder to check that: # 1- retrieving all studies requires only 1 query - # 2- accessing studies attributes does require additional queries to db + # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: all_studies = repository.get_all(study_filter=StudyFilter(archived=archived)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] assert len(db_recorder.sql_statements) == 1, str(db_recorder) if expected_ids is not None: @@ -284,13 +291,14 @@ def test_repository_get_all__variant_study_filter( # use the db recorder to check that: # 1- retrieving all studies requires only 1 query - # 2- accessing studies attributes does require additional queries to db + # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: all_studies = repository.get_all(study_filter=StudyFilter(variant=variant)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] assert len(db_recorder.sql_statements) == 1, str(db_recorder) if expected_ids is not None: @@ -325,13 +333,14 @@ def test_repository_get_all__study_version_filter( # use the db recorder to check that: # 1- retrieving all studies requires only 1 query - # 2- accessing studies attributes does require additional queries to db + # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: all_studies = repository.get_all(study_filter=StudyFilter(versions=versions)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] assert len(db_recorder.sql_statements) == 1, str(db_recorder) if expected_ids is not None: @@ -372,13 +381,14 @@ def test_repository_get_all__study_users_filter( # use the db recorder to check that: # 1- retrieving all studies requires only 1 query - # 2- accessing studies attributes does require additional queries to db + # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: all_studies = repository.get_all(study_filter=StudyFilter(users=users)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] assert len(db_recorder.sql_statements) == 1, str(db_recorder) if expected_ids is not None: @@ -419,13 +429,14 @@ def test_repository_get_all__study_groups_filter( # use the db recorder to check that: # 1- retrieving all studies requires only 1 query - # 2- accessing studies attributes does require additional queries to db + # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: all_studies = repository.get_all(study_filter=StudyFilter(groups=groups)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] assert len(db_recorder.sql_statements) == 1, str(db_recorder) if expected_ids is not None: @@ -461,13 +472,14 @@ def test_repository_get_all__study_ids_filter( # use the db recorder to check that: # 1- retrieving all studies requires only 1 query - # 2- accessing studies attributes does require additional queries to db + # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: all_studies = repository.get_all(study_filter=StudyFilter(study_ids=study_ids)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] assert len(db_recorder.sql_statements) == 1, str(db_recorder) if expected_ids is not None: @@ -500,13 +512,14 @@ def test_repository_get_all__study_existence_filter( # use the db recorder to check that: # 1- retrieving all studies requires only 1 query - # 2- accessing studies attributes does require additional queries to db + # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: all_studies = repository.get_all(study_filter=StudyFilter(exists=exists)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] assert len(db_recorder.sql_statements) == 1, str(db_recorder) if expected_ids is not None: @@ -540,13 +553,14 @@ def test_repository_get_all__study_workspace_filter( # use the db recorder to check that: # 1- retrieving all studies requires only 1 query - # 2- accessing studies attributes does require additional queries to db + # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: all_studies = repository.get_all(study_filter=StudyFilter(workspace=workspace)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] assert len(db_recorder.sql_statements) == 1, str(db_recorder) if expected_ids is not None: @@ -582,13 +596,65 @@ def test_repository_get_all__study_folder_filter( # use the db recorder to check that: # 1- retrieving all studies requires only 1 query - # 2- accessing studies attributes does require additional queries to db + # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: all_studies = repository.get_all(study_filter=StudyFilter(folder=folder)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] + + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + + if expected_ids is not None: + assert {s.id for s in all_studies} == expected_ids + + +@pytest.mark.parametrize( + "tags, expected_ids", + [ + ([], {"1", "2", "3", "4", "5", "6", "7", "8"}), + (["decennial"], {"2", "4", "6", "8"}), + (["winter_transition"], {"3", "4", "7", "8"}), + (["decennial", "winter_transition"], {"2", "3", "4", "6", "7", "8"}), + (["no-study-tag"], set()), + ], +) +def test_repository_get_all__study_tags_filter( + db_session: Session, + tags: t.List[str], + expected_ids: t.Set[str], +) -> None: + icache: Mock = Mock(spec=ICache) + repository = StudyMetadataRepository(cache_service=icache, session=db_session) + + test_tag_1 = Tag(label="hidden-tag") + test_tag_2 = Tag(label="decennial") + test_tag_3 = Tag(label="winter_transition") + + study_1 = VariantStudy(id=1, tags=[test_tag_1]) + study_2 = VariantStudy(id=2, tags=[test_tag_2]) + study_3 = VariantStudy(id=3, tags=[test_tag_3]) + study_4 = VariantStudy(id=4, tags=[test_tag_2, test_tag_3]) + study_5 = RawStudy(id=5, tags=[test_tag_1]) + study_6 = RawStudy(id=6, tags=[test_tag_2]) + study_7 = RawStudy(id=7, tags=[test_tag_3]) + study_8 = RawStudy(id=8, tags=[test_tag_2, test_tag_3]) + + db_session.add_all([study_1, study_2, study_3, study_4, study_5, study_6, study_7, study_8]) + db_session.commit() + + # use the db recorder to check that: + # 1- retrieving all studies requires only 1 query + # 2- accessing studies attributes does not require additional queries to db + # 3- having an exact total of queries equals to 1 + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all(study_filter=StudyFilter(tags=tags)) + _ = [s.owner for s in all_studies] + _ = [s.groups for s in all_studies] + _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] assert len(db_recorder.sql_statements) == 1, str(db_recorder) if expected_ids is not None: From 23ad643a69e2e58f79255560a767743fb7db5954 Mon Sep 17 00:00:00 2001 From: Mohamed Abdel Wedoud Date: Tue, 6 Feb 2024 19:22:51 +0100 Subject: [PATCH 03/11] test(tags-db): integration tests for study tags filtering and updating --- antarest/study/repository.py | 4 +- antarest/study/service.py | 6 +- .../studies_blueprint/test_get_studies.py | 171 +++++++++++++++--- 3 files changed, 151 insertions(+), 30 deletions(-) diff --git a/antarest/study/repository.py b/antarest/study/repository.py index 077d4ccd3c..7a6e99cae2 100644 --- a/antarest/study/repository.py +++ b/antarest/study/repository.py @@ -298,6 +298,8 @@ def update_tags(self, study: Study, new_tags: t.List[str]) -> None: """ logger.debug(f"Updating tags for study: {study.id}") - study.tags = [Tag(label=tag) for tag in new_tags] + existing_tags = self.session.query(Tag).filter(Tag.label.in_(new_tags)).all() + new_labels = set(new_tags) - set([tag.label for tag in existing_tags]) + study.tags = [Tag(label=tag) for tag in new_labels] + existing_tags self.session.merge(study) self.session.commit() diff --git a/antarest/study/service.py b/antarest/study/service.py index 3d10b2bef1..20d4640ffe 100644 --- a/antarest/study/service.py +++ b/antarest/study/service.py @@ -551,6 +551,9 @@ def update_study_information( if metadata_patch.horizon: study.additional_data.horizon = metadata_patch.horizon + new_tags = metadata_patch.tags + self.repository.update_tags(study, new_tags) + new_metadata = self.storage_service.get_storage(study).patch_update_study_metadata(study, metadata_patch) self.event_bus.push( @@ -561,9 +564,6 @@ def update_study_information( ) ) - new_tags = new_metadata.tags - self.repository.update_tags(study, new_tags) - return new_metadata def check_study_access( diff --git a/tests/integration/studies_blueprint/test_get_studies.py b/tests/integration/studies_blueprint/test_get_studies.py index b134406e50..c9bf08bb62 100644 --- a/tests/integration/studies_blueprint/test_get_studies.py +++ b/tests/integration/studies_blueprint/test_get_studies.py @@ -266,24 +266,6 @@ def test_study_listing( assert res.status_code in CREATE_STATUS_CODES, res.json() archived_raw_850_id = res.json() - # create a variant study version 840 - res = client.post( - f"{STUDIES_URL}/{archived_raw_840_id}/variants", - headers={"Authorization": f"Bearer {admin_access_token}"}, - params={"name": "archived-variant-840", "version": "840"}, - ) - assert res.status_code in CREATE_STATUS_CODES, res.json() - archived_variant_840_id = res.json() - - # create a variant study version 850 to be archived - res = client.post( - f"{STUDIES_URL}/{archived_raw_850_id}/variants", - headers={"Authorization": f"Bearer {admin_access_token}"}, - params={"name": "archived-variant-850", "version": "850"}, - ) - assert res.status_code in CREATE_STATUS_CODES, res.json() - archived_variant_850_id = res.json() - # create a raw study to be transferred in folder1 zip_path = ASSETS_DIR / "STA-mini.zip" res = client.post( @@ -337,6 +319,120 @@ def test_study_listing( task = wait_task_completion(client, admin_access_token, archiving_study_task_id) assert task.status == TaskStatus.COMPLETED, task + # create a raw study version 840 to be tagged with `winter_transition` + res = client.post( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": "winter-transition-raw-840", "version": "840"}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + tagged_raw_840_id = res.json() + res = client.put( + f"{STUDIES_URL}/{tagged_raw_840_id}", + headers={"Authorization": f"Bearer {admin_access_token}"}, + json={"tags": ["winter_transition"]}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": "winter-transition-raw-840"}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map: t.Dict[str, t.Dict[str, t.Any]] = res.json() + assert len(study_map) == 1 + assert set(study_map.get(tagged_raw_840_id).get("tags")) == {"winter_transition"} + + # create a raw study version 850 to be tagged with `decennial` + res = client.post( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": "decennial-raw-850", "version": "850"}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + tagged_raw_850_id = res.json() + res = client.put( + f"{STUDIES_URL}/{tagged_raw_850_id}", + headers={"Authorization": f"Bearer {admin_access_token}"}, + json={"tags": ["decennial"]}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + res = client.get( + STUDIES_URL, headers={"Authorization": f"Bearer {admin_access_token}"}, params={"name": "decennial-raw-850"} + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map: t.Dict[str, t.Dict[str, t.Any]] = res.json() + assert len(study_map) == 1 + assert set(study_map.get(tagged_raw_850_id).get("tags")) == {"decennial"} + + # create a variant study version 840 to be tagged with `decennial` + res = client.post( + f"{STUDIES_URL}/{tagged_raw_840_id}/variants", + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": "decennial-variant-840", "version": "840"}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + tagged_variant_840_id = res.json() + res = client.put( + f"{STUDIES_URL}/{tagged_variant_840_id}/generate", + headers={"Authorization": f"Bearer {admin_access_token}"}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + generation_task_id = res.json() + task = wait_task_completion(client, admin_access_token, generation_task_id) + assert task.status == TaskStatus.COMPLETED, task + res = client.put( + f"{STUDIES_URL}/{tagged_variant_840_id}", + headers={"Authorization": f"Bearer {admin_access_token}"}, + json={"tags": ["decennial"]}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": "decennial-variant-840"}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map: t.Dict[str, t.Dict[str, t.Any]] = res.json() + assert len(study_map) == 1 + assert set(study_map.get(tagged_variant_840_id).get("tags")) == {"decennial"} + + # create a variant study version 850 to be tagged with `winter_transition` + res = client.post( + f"{STUDIES_URL}/{tagged_raw_850_id}/variants", + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": "winter-transition-variant-850", "version": "850"}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + tagged_variant_850_id = res.json() + res = client.put( + f"{STUDIES_URL}/{tagged_variant_850_id}/generate", + headers={"Authorization": f"Bearer {admin_access_token}"}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + generation_task_id = res.json() + task = wait_task_completion(client, admin_access_token, generation_task_id) + assert task.status == TaskStatus.COMPLETED, task + res = client.put( + f"{STUDIES_URL}/{tagged_variant_850_id}", + headers={"Authorization": f"Bearer {admin_access_token}"}, + json={"tags": ["winter_transition"]}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": "winter-transition-variant-850"}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map: t.Dict[str, t.Dict[str, t.Any]] = res.json() + assert len(study_map) == 1 + assert set(study_map.get(tagged_variant_850_id).get("tags")) == {"winter_transition"} + + # ========================== + # 2. Filtering testing + # ========================== + # the testing studies set all_studies = { raw_840_id, @@ -350,10 +446,12 @@ def test_study_listing( variant_860_id, archived_raw_840_id, archived_raw_850_id, - archived_variant_840_id, - archived_variant_850_id, folder1_study_id, to_be_deleted_id, + tagged_raw_840_id, + tagged_raw_850_id, + tagged_variant_840_id, + tagged_variant_850_id, } pm = operator.itemgetter("public_mode") @@ -392,15 +490,14 @@ def test_study_listing( [e for k, e in study_map.items() if k not in james_bond_studies], ) ) - # #TODO you need to update the permission for James Bond bot + # #TODO you need to update the permission for James Bond bot above # test 1.d for a user bot with access to select studies res = client.get( STUDIES_URL, headers={"Authorization": f"Bearer {james_bond_bot_token}"}, ) assert res.status_code == LIST_STATUS_CODE, res.json() - # #TODO add the correct test assertions # ] = res.json() # assert not set(james_bond_studies).difference(study_map) @@ -477,8 +574,8 @@ def test_study_listing( variant_840_id, variant_850_id, variant_860_id, - archived_variant_840_id, - archived_variant_850_id, + tagged_variant_840_id, + tagged_variant_850_id, } # test 5.a get variant studies res = client.get( @@ -507,7 +604,8 @@ def test_study_listing( non_managed_850_id, variant_850_id, archived_raw_850_id, - archived_variant_850_id, + tagged_variant_850_id, + tagged_raw_850_id, } studies_version_860 = { raw_860_id, @@ -579,10 +677,31 @@ def test_study_listing( assert not all_studies.difference(group_x_studies.union(group_y_studies)).intersection(study_map) assert not group_x_studies.union(group_y_studies).difference(study_map) - # TODO you need to add filtering through tags to the search engine # tests (9) for tags filtering # test 9.a filtering for one tag: decennial + decennial_tagged_studies = {tagged_raw_850_id, tagged_variant_840_id} + winter_transition_tagged_studies = {tagged_raw_840_id, tagged_variant_850_id} + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"tags": f"decennial"}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + assert not all_studies.difference(decennial_tagged_studies).intersection(study_map) + assert not decennial_tagged_studies.difference(study_map) # test 9.b filtering for two tags: decennial,winter_transition + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"tags": f"decennial,winter_transition"}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + assert not all_studies.difference( + decennial_tagged_studies.union(winter_transition_tagged_studies) + ).intersection(study_map) + assert not decennial_tagged_studies.union(winter_transition_tagged_studies).difference(study_map) # tests (10) for studies uuids sequence filtering # test 10.a filter for one uuid From 8e2800bded17ccb49c766fe317c0047a4f95c339 Mon Sep 17 00:00:00 2001 From: Mohamed Abdel Wedoud Date: Wed, 7 Feb 2024 16:45:17 +0100 Subject: [PATCH 04/11] test(tags-db): update following code review --- antarest/study/model.py | 6 ++++++ antarest/study/service.py | 7 +++---- .../study/storage/abstract_storage_service.py | 7 +++++-- antarest/study/storage/patch_service.py | 4 ++-- antarest/study/web/studies_blueprint.py | 15 +------------- .../studies_blueprint/test_get_studies.py | 20 +++---------------- tests/integration/test_integration.py | 3 --- 7 files changed, 20 insertions(+), 42 deletions(-) diff --git a/antarest/study/model.py b/antarest/study/model.py index f8218fd31d..4307e99945 100644 --- a/antarest/study/model.py +++ b/antarest/study/model.py @@ -88,6 +88,12 @@ class Tag(Base): # type:ignore def __str__(self) -> str: return f"[Tag] label={self.label}, css-color-code={self.color}" + def __repr__(self) -> str: + cls = self.__class__.__name__ + label = getattr(self, "label", None) + color = getattr(self, "color", None) + return f"{cls}(label={label!r}, color={color!r})" + class StudyContentStatus(enum.Enum): VALID = "VALID" diff --git a/antarest/study/service.py b/antarest/study/service.py index 20d4640ffe..0a5cbe3875 100644 --- a/antarest/study/service.py +++ b/antarest/study/service.py @@ -527,7 +527,7 @@ def update_study_information( params.get_user_id(), ) study = self.get_study(uuid) - assert_permission(params.user, study, StudyPermissionType.READ) + assert_permission(params.user, study, StudyPermissionType.WRITE) if metadata_patch.horizon: study_settings_url = "settings/generaldata/general" @@ -550,9 +550,8 @@ def update_study_information( study.additional_data.author = metadata_patch.author if metadata_patch.horizon: study.additional_data.horizon = metadata_patch.horizon - - new_tags = metadata_patch.tags - self.repository.update_tags(study, new_tags) + if metadata_patch.tags: + self.repository.update_tags(study, metadata_patch.tags) new_metadata = self.storage_service.get_storage(study).patch_update_study_metadata(study, metadata_patch) diff --git a/antarest/study/storage/abstract_storage_service.py b/antarest/study/storage/abstract_storage_service.py index cd6b8be080..af76bf533d 100644 --- a/antarest/study/storage/abstract_storage_service.py +++ b/antarest/study/storage/abstract_storage_service.py @@ -1,3 +1,4 @@ +import json import logging import shutil import tempfile @@ -74,8 +75,10 @@ def get_study_information( additional_data = study.additional_data or StudyAdditionalData() try: - patch = Patch.parse_raw(additional_data.patch or "{}") - except Exception as e: + patch_obj = json.loads(additional_data.patch or "{}") + patch = Patch.parse_obj(patch_obj) + except ValueError as e: + # The conversion to JSON and the parsing can fail if the patch is not valid logger.warning(f"Failed to parse patch for study {study.id}", exc_info=e) patch = Patch() diff --git a/antarest/study/storage/patch_service.py b/antarest/study/storage/patch_service.py index c71833beb2..e3ece9071e 100644 --- a/antarest/study/storage/patch_service.py +++ b/antarest/study/storage/patch_service.py @@ -53,6 +53,6 @@ def save(self, study: Union[RawStudy, VariantStudy], patch: Patch) -> None: study.additional_data.patch = patch.json() self.repository.save(study) - patch_content = patch.json() patch_path = (Path(study.path)) / "patch.json" - patch_path.write_text(patch_content) + patch_path.parent.mkdir(parents=True, exist_ok=True) + patch_path.write_text(patch.json()) diff --git a/antarest/study/web/studies_blueprint.py b/antarest/study/web/studies_blueprint.py index f4b7a5e741..beeecd65c5 100644 --- a/antarest/study/web/studies_blueprint.py +++ b/antarest/study/web/studies_blueprint.py @@ -803,7 +803,7 @@ def archive_study( @bp.put( "/studies/{study_id}/unarchive", - summary="Dearchive a study", + summary="Unarchive a study", tags=[APITag.study_management], ) def unarchive_study( @@ -815,19 +815,6 @@ def unarchive_study( params = RequestParameters(user=current_user) return study_service.unarchive(study_id, params) - @bp.post( - "/studies/_invalidate_cache_listing", - summary="Invalidate the study listing cache [DEPRECATED] and will be removed soon", - tags=[APITag.study_management], - ) - def invalidate_study_listing_cache( - current_user: JWTUser = Depends(auth.get_current_user), - ) -> t.Any: - logger.info( - "Invalidating the study listing cache endpoint is deprecated", - extra={"user": current_user.id}, - ) - @bp.get( "/studies/{uuid}/disk-usage", summary="Compute study disk usage", diff --git a/tests/integration/studies_blueprint/test_get_studies.py b/tests/integration/studies_blueprint/test_get_studies.py index c9bf08bb62..869466e9f4 100644 --- a/tests/integration/studies_blueprint/test_get_studies.py +++ b/tests/integration/studies_blueprint/test_get_studies.py @@ -358,7 +358,9 @@ def test_study_listing( ) assert res.status_code in CREATE_STATUS_CODES, res.json() res = client.get( - STUDIES_URL, headers={"Authorization": f"Bearer {admin_access_token}"}, params={"name": "decennial-raw-850"} + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": "decennial-raw-850"}, ) assert res.status_code == LIST_STATUS_CODE, res.json() study_map: t.Dict[str, t.Dict[str, t.Any]] = res.json() @@ -373,14 +375,6 @@ def test_study_listing( ) assert res.status_code in CREATE_STATUS_CODES, res.json() tagged_variant_840_id = res.json() - res = client.put( - f"{STUDIES_URL}/{tagged_variant_840_id}/generate", - headers={"Authorization": f"Bearer {admin_access_token}"}, - ) - assert res.status_code == LIST_STATUS_CODE, res.json() - generation_task_id = res.json() - task = wait_task_completion(client, admin_access_token, generation_task_id) - assert task.status == TaskStatus.COMPLETED, task res = client.put( f"{STUDIES_URL}/{tagged_variant_840_id}", headers={"Authorization": f"Bearer {admin_access_token}"}, @@ -405,14 +399,6 @@ def test_study_listing( ) assert res.status_code in CREATE_STATUS_CODES, res.json() tagged_variant_850_id = res.json() - res = client.put( - f"{STUDIES_URL}/{tagged_variant_850_id}/generate", - headers={"Authorization": f"Bearer {admin_access_token}"}, - ) - assert res.status_code == LIST_STATUS_CODE, res.json() - generation_task_id = res.json() - task = wait_task_completion(client, admin_access_token, generation_task_id) - assert task.status == TaskStatus.COMPLETED, task res = client.put( f"{STUDIES_URL}/{tagged_variant_850_id}", headers={"Authorization": f"Bearer {admin_access_token}"}, diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py index 03e3cbafe6..ffa4d4a91d 100644 --- a/tests/integration/test_integration.py +++ b/tests/integration/test_integration.py @@ -223,9 +223,6 @@ def test_main(client: TestClient, admin_access_token: str, study_id: str) -> Non assert len(res.json()) == 3 assert filter(lambda s: s["id"] == copied.json(), res.json().values()).__next__()["folder"] == "foo/bar" - res = client.post("/v1/studies/_invalidate_cache_listing", headers=admin_headers) - assert res.status_code == 200 - # Study delete client.delete( f"/v1/studies/{copied.json()}", From 50d3f5d07e5598a6acc762fee1b422ce38f23701 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Wed, 7 Feb 2024 23:25:37 +0100 Subject: [PATCH 05/11] test(tags-db): correct typing in unit tests --- .../studies_blueprint/test_get_studies.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/integration/studies_blueprint/test_get_studies.py b/tests/integration/studies_blueprint/test_get_studies.py index 869466e9f4..54292345f5 100644 --- a/tests/integration/studies_blueprint/test_get_studies.py +++ b/tests/integration/studies_blueprint/test_get_studies.py @@ -341,7 +341,7 @@ def test_study_listing( assert res.status_code == LIST_STATUS_CODE, res.json() study_map: t.Dict[str, t.Dict[str, t.Any]] = res.json() assert len(study_map) == 1 - assert set(study_map.get(tagged_raw_840_id).get("tags")) == {"winter_transition"} + assert set(study_map[tagged_raw_840_id]["tags"]) == {"winter_transition"} # create a raw study version 850 to be tagged with `decennial` res = client.post( @@ -363,9 +363,9 @@ def test_study_listing( params={"name": "decennial-raw-850"}, ) assert res.status_code == LIST_STATUS_CODE, res.json() - study_map: t.Dict[str, t.Dict[str, t.Any]] = res.json() + study_map = res.json() assert len(study_map) == 1 - assert set(study_map.get(tagged_raw_850_id).get("tags")) == {"decennial"} + assert set(study_map[tagged_raw_850_id]["tags"]) == {"decennial"} # create a variant study version 840 to be tagged with `decennial` res = client.post( @@ -387,9 +387,9 @@ def test_study_listing( params={"name": "decennial-variant-840"}, ) assert res.status_code == LIST_STATUS_CODE, res.json() - study_map: t.Dict[str, t.Dict[str, t.Any]] = res.json() + study_map = res.json() assert len(study_map) == 1 - assert set(study_map.get(tagged_variant_840_id).get("tags")) == {"decennial"} + assert set(study_map[tagged_variant_840_id]["tags"]) == {"decennial"} # create a variant study version 850 to be tagged with `winter_transition` res = client.post( @@ -411,9 +411,9 @@ def test_study_listing( params={"name": "winter-transition-variant-850"}, ) assert res.status_code == LIST_STATUS_CODE, res.json() - study_map: t.Dict[str, t.Dict[str, t.Any]] = res.json() + study_map = res.json() assert len(study_map) == 1 - assert set(study_map.get(tagged_variant_850_id).get("tags")) == {"winter_transition"} + assert set(study_map[tagged_variant_850_id]["tags"]) == {"winter_transition"} # ========================== # 2. Filtering testing @@ -449,7 +449,7 @@ def test_study_listing( headers={"Authorization": f"Bearer {john_doe_access_token}"}, ) assert res.status_code == LIST_STATUS_CODE, res.json() - study_map: t.Dict[str, t.Dict[str, t.Any]] = res.json() + study_map = res.json() assert not all_studies.intersection(study_map) assert all(map(lambda x: pm(x) in [PublicMode.READ, PublicMode.FULL], study_map.values())) @@ -499,7 +499,7 @@ def test_study_listing( res = client.get(STUDIES_URL, headers={"Authorization": f"Bearer {admin_access_token}"}, params={"name": "840"}) assert res.status_code == LIST_STATUS_CODE, res.json() study_map = res.json() - assert all(map(lambda x: "840" in x.get("name"), study_map.values())) and len(study_map) >= 5 + assert all(map(lambda x: "840" in x["name"], study_map.values())) and len(study_map) >= 5 # test 2.b with no matching studies res = client.get( STUDIES_URL, @@ -670,7 +670,7 @@ def test_study_listing( res = client.get( STUDIES_URL, headers={"Authorization": f"Bearer {admin_access_token}"}, - params={"tags": f"decennial"}, + params={"tags": "decennial"}, ) assert res.status_code == LIST_STATUS_CODE, res.json() study_map = res.json() @@ -680,7 +680,7 @@ def test_study_listing( res = client.get( STUDIES_URL, headers={"Authorization": f"Bearer {admin_access_token}"}, - params={"tags": f"decennial,winter_transition"}, + params={"tags": "decennial,winter_transition"}, ) assert res.status_code == LIST_STATUS_CODE, res.json() study_map = res.json() From 136efca3da8d68874df6c2115e18994c9053492e Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Wed, 7 Feb 2024 23:33:36 +0100 Subject: [PATCH 06/11] feat(tags-db): remove calls to logging function from `StudyMetadataRepository` (useless) --- antarest/study/repository.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/antarest/study/repository.py b/antarest/study/repository.py index 7a6e99cae2..ef434d94ab 100644 --- a/antarest/study/repository.py +++ b/antarest/study/repository.py @@ -1,6 +1,5 @@ import datetime import enum -import logging import typing as t from pydantic import BaseModel, NonNegativeInt @@ -12,8 +11,6 @@ from antarest.login.model import Group from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, Study, StudyAdditionalData, Tag -logger = logging.getLogger(__name__) - def escape_like(string: str, escape_char: str = "\\") -> str: """ @@ -127,7 +124,6 @@ def save( update_modification_date: bool = False, ) -> Study: metadata_id = metadata.id or metadata.name - logger.debug(f"Saving study {metadata_id}") if update_modification_date: metadata.updated_at = datetime.datetime.utcnow() @@ -279,7 +275,6 @@ def get_all_raw(self, exists: t.Optional[bool] = None) -> t.List[RawStudy]: return studies def delete(self, id: str) -> None: - logger.debug(f"Deleting study {id}") session = self.session u: Study = session.query(Study).get(id) session.delete(u) @@ -297,7 +292,6 @@ def update_tags(self, study: Study, new_tags: t.List[str]) -> None: Returns: """ - logger.debug(f"Updating tags for study: {study.id}") existing_tags = self.session.query(Tag).filter(Tag.label.in_(new_tags)).all() new_labels = set(new_tags) - set([tag.label for tag in existing_tags]) study.tags = [Tag(label=tag) for tag in new_labels] + existing_tags From 9a79c5d58fc69c75a0497c4ee8224aa68053e55c Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Wed, 7 Feb 2024 23:45:58 +0100 Subject: [PATCH 07/11] docs(tags-db): rephrase the function documentations in docstrings --- antarest/study/repository.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/antarest/study/repository.py b/antarest/study/repository.py index ef434d94ab..62b7904a28 100644 --- a/antarest/study/repository.py +++ b/antarest/study/repository.py @@ -123,7 +123,6 @@ def save( metadata: Study, update_modification_date: bool = False, ) -> Study: - metadata_id = metadata.id or metadata.name if update_modification_date: metadata.updated_at = datetime.datetime.utcnow() @@ -185,21 +184,20 @@ def get_all( pagination: StudyPagination = StudyPagination(), ) -> t.List[Study]: """ - This function goal is to create a search engine throughout the studies with optimal - runtime. + Retrieve studies based on specified filters, sorting, and pagination. Args: - study_filter: composed of all filtering criteria - sort_by: how the user would like the results to be sorted - pagination: specifies the number of results to displayed in each page and the actually displayed page + study_filter: composed of all filtering criteria. + sort_by: how the user would like the results to be sorted. + pagination: specifies the number of results to displayed in each page and the actually displayed page. Returns: - The matching studies in proper order and pagination + The matching studies in proper order and pagination. """ # When we fetch a study, we also need to fetch the associated owner and groups # to check the permissions of the current user efficiently. # We also need to fetch the additional data to display the study information - # efficiently (see: `utils.get_study_information`) + # efficiently (see: `AbstractStorageService.get_study_information`) entity = with_polymorphic(Study, "*") # noinspection PyTypeChecker @@ -282,15 +280,12 @@ def delete(self, id: str) -> None: def update_tags(self, study: Study, new_tags: t.List[str]) -> None: """ - Using the repository session we can update the study tags on the DB. - Thus, the tables `study_tag` and `tag` will be updated too accordingly. + Updates the tags associated with a given study in the database, + replacing existing tags with new ones. Args: - study: a pre-existing study to be updated with the new tags - new_tags: the new tags to be associated with the input study on the db - - Returns: - + study: The pre-existing study to be updated with the new tags. + new_tags: The new tags to be associated with the input study in the database. """ existing_tags = self.session.query(Tag).filter(Tag.label.in_(new_tags)).all() new_labels = set(new_tags) - set([tag.label for tag in existing_tags]) From ee707671d9b29e5da2856a7a9aaf2ae0a28190b9 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Wed, 7 Feb 2024 23:57:43 +0100 Subject: [PATCH 08/11] docs(tags-db): improve the documentation of the `Tag` and `StudyTag` models --- antarest/study/model.py | 34 +++++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/antarest/study/model.py b/antarest/study/model.py index 4307e99945..a8510f3494 100644 --- a/antarest/study/model.py +++ b/antarest/study/model.py @@ -61,6 +61,10 @@ class StudyTag(Base): # type:ignore """ A table to manage the many-to-many relationship between `Study` and `Tag` + + Attributes: + study_id (str): The ID of the study associated with the tag. + tag_label (str): The label of the tag associated with the study. """ __tablename__ = "study_tag" @@ -69,13 +73,25 @@ class StudyTag(Base): # type:ignore study_id: str = Column(String(36), ForeignKey("study.id", ondelete="CASCADE"), index=True, nullable=False) tag_label: str = Column(String(40), ForeignKey("tag.label", ondelete="CASCADE"), index=True, nullable=False) - def __str__(self) -> str: + def __str__(self) -> str: # pragma: no cover return f"[StudyTag] study_id={self.study_id}, tag={self.tag}" + def __repr__(self) -> str: # pragma: no cover + cls_name = self.__class__.__name__ + study_id = self.study_id + tag = self.tag + return f"{cls_name}({study_id=}, {tag=})" + class Tag(Base): # type:ignore """ - A table to store all tags + Represents a tag in the database. + + This class is used to store tags associated with studies. + + Attributes: + label (str): The label of the tag. + color (str): The color code associated with the tag. """ __tablename__ = "tag" @@ -85,14 +101,14 @@ class Tag(Base): # type:ignore studies: t.List["Study"] = relationship("Study", secondary=StudyTag.__table__, back_populates="tags") - def __str__(self) -> str: - return f"[Tag] label={self.label}, css-color-code={self.color}" + def __str__(self) -> str: # pragma: no cover + return t.cast(str, self.label) - def __repr__(self) -> str: - cls = self.__class__.__name__ - label = getattr(self, "label", None) - color = getattr(self, "color", None) - return f"{cls}(label={label!r}, color={color!r})" + def __repr__(self) -> str: # pragma: no cover + cls_name = self.__class__.__name__ + label = self.label + color = self.color + return f"{cls_name}({label=}, {color=})" class StudyContentStatus(enum.Enum): From e95c3e3086936bbffe46de31c2b2bd22440dad95 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Thu, 8 Feb 2024 00:11:31 +0100 Subject: [PATCH 09/11] style(tags-db): correct typing --- tests/storage/repository/test_study.py | 11 +++++------ tests/study/test_repository.py | 12 ++++++------ 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/tests/storage/repository/test_study.py b/tests/storage/repository/test_study.py index 0b2f6c4e2f..f865ab613a 100644 --- a/tests/storage/repository/test_study.py +++ b/tests/storage/repository/test_study.py @@ -1,17 +1,16 @@ from datetime import datetime -from sqlalchemy.orm import scoped_session, sessionmaker # type: ignore - from antarest.core.cache.business.local_chache import LocalCache +from antarest.core.model import PublicMode from antarest.login.model import Group, User -from antarest.study.model import DEFAULT_WORKSPACE_NAME, PublicMode, RawStudy, Study, StudyContentStatus +from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, Study, StudyContentStatus from antarest.study.repository import StudyMetadataRepository from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy from tests.helpers import with_db_context @with_db_context -def test_cyclelife(): +def test_lifecycle() -> None: user = User(id=0, name="admin") group = Group(id="my-group", name="group") repo = StudyMetadataRepository(LocalCache()) @@ -62,7 +61,7 @@ def test_cyclelife(): repo.save(c) repo.save(d) assert b.id - c = repo.get(a.id) + c = repo.one(a.id) assert a == c assert len(repo.get_all()) == 4 @@ -75,7 +74,7 @@ def test_cyclelife(): @with_db_context -def test_study_inheritance(): +def test_study_inheritance() -> None: user = User(id=0, name="admin") group = Group(id="my-group", name="group") repo = StudyMetadataRepository(LocalCache()) diff --git a/tests/study/test_repository.py b/tests/study/test_repository.py index a5d08ca658..d30c051a6a 100644 --- a/tests/study/test_repository.py +++ b/tests/study/test_repository.py @@ -41,7 +41,7 @@ def test_repository_get_all__general_case( db_session: Session, managed: t.Union[bool, None], - study_ids: t.List[str], + study_ids: t.Sequence[str], exists: t.Union[bool, None], expected_ids: t.Set[str], ) -> None: @@ -317,7 +317,7 @@ def test_repository_get_all__variant_study_filter( ) def test_repository_get_all__study_version_filter( db_session: Session, - versions: t.List[str], + versions: t.Sequence[str], expected_ids: t.Set[str], ) -> None: icache: Mock = Mock(spec=ICache) @@ -359,7 +359,7 @@ def test_repository_get_all__study_version_filter( ) def test_repository_get_all__study_users_filter( db_session: Session, - users: t.List["int"], + users: t.Sequence["int"], expected_ids: t.Set[str], ) -> None: icache: Mock = Mock(spec=ICache) @@ -407,7 +407,7 @@ def test_repository_get_all__study_users_filter( ) def test_repository_get_all__study_groups_filter( db_session: Session, - groups: t.List[str], + groups: t.Sequence[str], expected_ids: t.Set[str], ) -> None: icache: Mock = Mock(spec=ICache) @@ -456,7 +456,7 @@ def test_repository_get_all__study_groups_filter( ) def test_repository_get_all__study_ids_filter( db_session: Session, - study_ids: t.List[str], + study_ids: t.Sequence[str], expected_ids: t.Set[str], ) -> None: icache: Mock = Mock(spec=ICache) @@ -623,7 +623,7 @@ def test_repository_get_all__study_folder_filter( ) def test_repository_get_all__study_tags_filter( db_session: Session, - tags: t.List[str], + tags: t.Sequence[str], expected_ids: t.Set[str], ) -> None: icache: Mock = Mock(spec=ICache) From aaaa37b34b6020b49e30b8e7b3558ebe933dd669 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Thu, 8 Feb 2024 00:18:51 +0100 Subject: [PATCH 10/11] docs(tags-db): improve docstring in `StudyService` --- antarest/study/service.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/antarest/study/service.py b/antarest/study/service.py index 0a5cbe3875..b80b0b4438 100644 --- a/antarest/study/service.py +++ b/antarest/study/service.py @@ -30,7 +30,7 @@ ) from antarest.core.filetransfer.model import FileDownloadTaskDTO from antarest.core.filetransfer.service import FileTransferManager -from antarest.core.interfaces.cache import CacheConstants, ICache +from antarest.core.interfaces.cache import ICache from antarest.core.interfaces.eventbus import Event, EventType, IEventBus from antarest.core.jwt import DEFAULT_ADMIN_USER, JWTGroup, JWTUser from antarest.core.model import JSON, SUB_JSON, PermissionInfo, PublicMode, StudyPermissionType @@ -492,18 +492,19 @@ def _try_get_studies_information(self, study: Study) -> t.Optional[StudyMetadata def get_study_information(self, uuid: str, params: RequestParameters) -> StudyMetadataDTO: """ - Get study information - Args: - uuid: study uuid - params: request parameters + Retrieve study information. - Returns: study information + Args: + uuid: The UUID of the study. + params: The request parameters. + Returns: + Information about the study. """ study = self.get_study(uuid) assert_permission(params.user, study, StudyPermissionType.READ) - logger.info("study %s metadata asked by user %s", uuid, params.get_user_id()) - # todo debounce this with a "update_study_last_access" method updating only every some seconds + logger.info("Study metadata requested for study %s by user %s", uuid, params.get_user_id()) + # TODO: Debounce this with an "update_study_last_access" method updating only every few seconds. study.last_access = datetime.utcnow() self.repository.save(study) return self.storage_service.get_storage(study).get_study_information(study) @@ -1994,14 +1995,15 @@ def _assert_study_unarchived(self, study: Study, raise_exception: bool = True) - def _analyse_study(self, metadata: Study) -> StudyContentStatus: """ - Analyze study integrity - Args: - metadata: study to analyze + Analyzes the integrity of a study. - Returns: VALID if study has any integrity mistakes. - WARNING if studies has mistakes. - ERROR if tree was not able to analyse structuree without raise error. + Args: + metadata: The study to analyze. + Returns: + - VALID if the study has no integrity issues. + - WARNING if the study has some issues. + - ERROR if the tree was unable to analyze the structure without raising an error. """ try: if not isinstance(metadata, RawStudy): From de3f20f2b2bd315b7ced597a3e8f98a06f3e0d59 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Thu, 8 Feb 2024 00:22:27 +0100 Subject: [PATCH 11/11] style(tags-db): correct typing in `StudyMetadataRepository` --- antarest/study/model.py | 2 +- antarest/study/repository.py | 10 +++++----- antarest/study/storage/auto_archive_service.py | 11 +++-------- 3 files changed, 9 insertions(+), 14 deletions(-) diff --git a/antarest/study/model.py b/antarest/study/model.py index a8510f3494..fe10b4f211 100644 --- a/antarest/study/model.py +++ b/antarest/study/model.py @@ -32,7 +32,7 @@ DEFAULT_WORKSPACE_NAME = "default" -STUDY_REFERENCE_TEMPLATES: t.Dict[str, str] = { +STUDY_REFERENCE_TEMPLATES: t.Mapping[str, str] = { "600": "empty_study_613.zip", "610": "empty_study_613.zip", "640": "empty_study_613.zip", diff --git a/antarest/study/repository.py b/antarest/study/repository.py index 62b7904a28..3aa6e60681 100644 --- a/antarest/study/repository.py +++ b/antarest/study/repository.py @@ -182,7 +182,7 @@ def get_all( study_filter: StudyFilter = StudyFilter(), sort_by: t.Optional[StudySortBy] = None, pagination: StudyPagination = StudyPagination(), - ) -> t.List[Study]: + ) -> t.Sequence[Study]: """ Retrieve studies based on specified filters, sorting, and pagination. @@ -259,17 +259,17 @@ def get_all( if pagination.page_nb or pagination.page_size: q = q.offset(pagination.page_nb * pagination.page_size).limit(pagination.page_size) - studies: t.List[Study] = q.all() + studies: t.Sequence[Study] = q.all() return studies - def get_all_raw(self, exists: t.Optional[bool] = None) -> t.List[RawStudy]: + def get_all_raw(self, exists: t.Optional[bool] = None) -> t.Sequence[RawStudy]: query = self.session.query(RawStudy) if exists is not None: if exists: query = query.filter(RawStudy.missing.is_(None)) else: query = query.filter(not_(RawStudy.missing.is_(None))) - studies: t.List[RawStudy] = query.all() + studies: t.Sequence[RawStudy] = query.all() return studies def delete(self, id: str) -> None: @@ -278,7 +278,7 @@ def delete(self, id: str) -> None: session.delete(u) session.commit() - def update_tags(self, study: Study, new_tags: t.List[str]) -> None: + def update_tags(self, study: Study, new_tags: t.Sequence[str]) -> None: """ Updates the tags associated with a given study in the database, replacing existing tags with new ones. diff --git a/antarest/study/storage/auto_archive_service.py b/antarest/study/storage/auto_archive_service.py index b2ae1fae63..911b715f2d 100644 --- a/antarest/study/storage/auto_archive_service.py +++ b/antarest/study/storage/auto_archive_service.py @@ -1,7 +1,7 @@ import datetime import logging import time -from typing import List, Tuple +import typing as t from antarest.core.config import Config from antarest.core.exceptions import TaskAlreadyRunning @@ -12,7 +12,6 @@ from antarest.study.model import RawStudy, Study from antarest.study.repository import StudyFilter from antarest.study.service import StudyService -from antarest.study.storage.utils import is_managed from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy logger = logging.getLogger(__name__) @@ -29,12 +28,8 @@ def __init__(self, study_service: StudyService, config: Config): def _try_archive_studies(self) -> None: old_date = datetime.datetime.utcnow() - datetime.timedelta(days=self.config.storage.auto_archive_threshold_days) with db(): - studies: List[Study] = self.study_service.repository.get_all( - study_filter=StudyFilter( - managed=True, - ) - ) - # list of study id and boolean indicating if it's a raw study (True) or a variant (False) + studies: t.Sequence[Study] = self.study_service.repository.get_all(study_filter=StudyFilter(managed=True)) + # list of study IDs and boolean indicating if it's a raw study (True) or a variant (False) study_ids_to_archive = [ (study.id, isinstance(study, RawStudy)) for study in studies