diff --git a/antarest/launcher/service.py b/antarest/launcher/service.py index 25a777ef5a..86b65ec9ce 100644 --- a/antarest/launcher/service.py +++ b/antarest/launcher/service.py @@ -38,6 +38,7 @@ XpansionParametersDTO, ) from antarest.launcher.repository import JobResultRepository +from antarest.study.repository import StudyFilter from antarest.study.service import StudyService from antarest.study.storage.utils import assert_permission, extract_output_name, find_single_output_path @@ -305,8 +306,14 @@ def _filter_from_user_permission(self, job_results: List[JobResult], user: Optio orphan_visibility_threshold = datetime.utcnow() - timedelta(days=ORPHAN_JOBS_VISIBILITY_THRESHOLD) allowed_job_results = [] - studies_ids = [job_result.study_id for job_result in job_results] - studies = {study.id: study for study in self.study_service.repository.get_all(studies_ids=studies_ids)} + study_ids = [job_result.study_id for job_result in job_results] + if study_ids: + studies = { + study.id: study + for study in self.study_service.repository.get_all(study_filter=StudyFilter(study_ids=study_ids)) + } + else: + studies = {} for job_result in job_results: if job_result.study_id in studies: diff --git a/antarest/study/repository.py b/antarest/study/repository.py index ac7f730fca..e4646e1546 100644 --- a/antarest/study/repository.py +++ b/antarest/study/repository.py @@ -1,18 +1,97 @@ import datetime +import enum import logging import typing as t -from sqlalchemy import and_, or_ # type: ignore +from pydantic import BaseModel, NonNegativeInt +from sqlalchemy import func, not_, or_ # type: ignore from sqlalchemy.orm import Session, joinedload, with_polymorphic # type: ignore from antarest.core.interfaces.cache import CacheConstants, ICache from antarest.core.utils.fastapi_sqlalchemy import db +from antarest.login.model import Group from antarest.study.common.utils import get_study_information from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, Study, 
StudyAdditionalData logger = logging.getLogger(__name__) +def escape_like(string: str, escape_char: str = "\\") -> str: + """ + Escape the string parameter used in SQL LIKE expressions. + + Examples:: + + from sqlalchemy_utils import escape_like + + query = session.query(User).filter( + User.name.ilike(escape_like('John')) + ) + + Args: + string: a string to escape + escape_char: escape character + + Returns: + Escaped string. + """ + return string.replace(escape_char, escape_char * 2).replace("%", escape_char + "%").replace("_", escape_char + "_") + + +class StudyFilter(BaseModel, frozen=True, extra="forbid"): + """Study filter class gathering the main filtering parameters + + Attributes: + name: optional name regex of the study to match + managed: indicate if just managed studies should be retrieved + archived: optional if the study is archived + variant: optional if the study is raw study + versions: versions to filter by + users: users to filter by + groups: groups to filter by + tags: tags to filter by + study_ids: study IDs to filter by + exists: if raw study missing + workspace: optional workspace of the study + folder: optional folder prefix of the study + """ + + name: str = "" + managed: t.Optional[bool] = None + archived: t.Optional[bool] = None + variant: t.Optional[bool] = None + versions: t.Sequence[str] = () + users: t.Sequence[int] = () + groups: t.Sequence[str] = () + tags: t.Sequence[str] = () + study_ids: t.Sequence[str] = () + exists: t.Optional[bool] = None + workspace: str = "" + folder: str = "" + + +class StudySortBy(str, enum.Enum): + """How to sort the results of studies query results""" + + NAME_ASC = "+name" + NAME_DESC = "-name" + DATE_ASC = "+date" + DATE_DESC = "-date" + + +class StudyPagination(BaseModel, frozen=True, extra="forbid"): + """ + Pagination of a studies query results + + Attributes: + page_nb: offset + page_size: SQL limit + """ + + page_nb: NonNegativeInt = 0 + page_size: NonNegativeInt = 0 + + class 
StudyMetadataRepository: """ Database connector to manage Study entity @@ -70,6 +149,9 @@ def refresh(self, metadata: Study) -> None: def get(self, id: str) -> t.Optional[Study]: """Get the study by ID or return `None` if not found in database.""" + # todo: I think we should use a `entity = with_polymorphic(Study, "*")` + # to make sure RawStudy and VariantStudy fields are also fetched. + # see: antarest.study.service.StudyService.delete_study # When we fetch a study, we also need to fetch the associated owner and groups # to check the permissions of the current user efficiently. study: Study = ( @@ -84,6 +166,9 @@ def get(self, id: str) -> t.Optional[Study]: def one(self, study_id: str) -> Study: """Get the study by ID or raise `sqlalchemy.exc.NoResultFound` if not found in database.""" + # todo: I think we should use a `entity = with_polymorphic(Study, "*")` + # to make sure RawStudy and VariantStudy fields are also fetched. + # see: antarest.study.service.StudyService.delete_study # When we fetch a study, we also need to fetch the associated owner and groups # to check the permissions of the current user efficiently. study: Study = ( @@ -101,37 +186,94 @@ def get_additional_data(self, study_id: str) -> t.Optional[StudyAdditionalData]: def get_all( self, - managed: t.Optional[bool] = None, - studies_ids: t.Optional[t.List[str]] = None, - exists: bool = True, + study_filter: StudyFilter = StudyFilter(), + sort_by: t.Optional[StudySortBy] = None, + pagination: StudyPagination = StudyPagination(), ) -> t.List[Study]: + """ + This function goal is to create a search engine throughout the studies with optimal + runtime. 
+ + Args: + study_filter: composed of all filtering criteria + sort_by: how the user would like the results to be sorted + pagination: specifies the number of results to displayed in each page and the actually displayed page + + Returns: + The matching studies in proper order and pagination + """ # When we fetch a study, we also need to fetch the associated owner and groups # to check the permissions of the current user efficiently. # We also need to fetch the additional data to display the study information # efficiently (see: `utils.get_study_information`) entity = with_polymorphic(Study, "*") + # noinspection PyTypeChecker q = self.session.query(entity) - if exists: - q = q.filter(RawStudy.missing.is_(None)) + if study_filter.exists is not None: + if study_filter.exists: + q = q.filter(RawStudy.missing.is_(None)) + else: + q = q.filter(not_(RawStudy.missing.is_(None))) q = q.options(joinedload(entity.owner)) q = q.options(joinedload(entity.groups)) q = q.options(joinedload(entity.additional_data)) - if managed is not None: - if managed: + if study_filter.managed is not None: + if study_filter.managed: q = q.filter(or_(entity.type == "variantstudy", RawStudy.workspace == DEFAULT_WORKSPACE_NAME)) else: q = q.filter(entity.type == "rawstudy") q = q.filter(RawStudy.workspace != DEFAULT_WORKSPACE_NAME) - if studies_ids is not None: - q = q.filter(entity.id.in_(studies_ids)) + if study_filter.study_ids: + q = q.filter(entity.id.in_(study_filter.study_ids)) + if study_filter.users: + q = q.filter(entity.owner_id.in_(study_filter.users)) + if study_filter.groups: + q = q.join(entity.groups).filter(Group.id.in_(study_filter.groups)) + if study_filter.archived is not None: + q = q.filter(entity.archived == study_filter.archived) + if study_filter.name: + regex = f"%{escape_like(study_filter.name)}%" + q = q.filter(entity.name.ilike(regex)) + if study_filter.folder: + regex = f"{escape_like(study_filter.folder)}%" + q = q.filter(entity.folder.ilike(regex)) + if 
study_filter.workspace: + q = q.filter(RawStudy.workspace == study_filter.workspace) + if study_filter.variant is not None: + if study_filter.variant: + q = q.filter(entity.type == "variantstudy") + else: + q = q.filter(entity.type == "rawstudy") + if study_filter.versions: + q = q.filter(entity.version.in_(study_filter.versions)) + + if sort_by: + if sort_by == StudySortBy.DATE_DESC: + q = q.order_by(entity.created_at.desc()) + elif sort_by == StudySortBy.DATE_ASC: + q = q.order_by(entity.created_at.asc()) + elif sort_by == StudySortBy.NAME_DESC: + q = q.order_by(func.upper(entity.name).desc()) + elif sort_by == StudySortBy.NAME_ASC: + q = q.order_by(func.upper(entity.name).asc()) + else: + raise NotImplementedError(sort_by) + + # pagination + if pagination.page_nb or pagination.page_size: + q = q.offset(pagination.page_nb * pagination.page_size).limit(pagination.page_size) + studies: t.List[Study] = q.all() return studies - def get_all_raw(self, show_missing: bool = True) -> t.List[RawStudy]: + def get_all_raw(self, exists: t.Optional[bool] = None) -> t.List[RawStudy]: query = self.session.query(RawStudy) - if not show_missing: - query = query.filter(RawStudy.missing.is_(None)) + if exists is not None: + if exists: + query = query.filter(RawStudy.missing.is_(None)) + else: + query = query.filter(not_(RawStudy.missing.is_(None))) studies: t.List[RawStudy] = query.all() return studies diff --git a/antarest/study/service.py b/antarest/study/service.py index 7a9ee71507..98d36dee9b 100644 --- a/antarest/study/service.py +++ b/antarest/study/service.py @@ -4,11 +4,11 @@ import json import logging import os +import time +import typing as t from datetime import datetime, timedelta from http import HTTPStatus from pathlib import Path, PurePosixPath -from time import time -from typing import Any, BinaryIO, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast from uuid import uuid4 import numpy as np @@ -92,7 +92,7 @@ StudyMetadataPatchDTO, StudySimResultDTO, ) 
-from antarest.study.repository import StudyMetadataRepository +from antarest.study.repository import StudyFilter, StudyMetadataRepository, StudyPagination, StudySortBy from antarest.study.storage.rawstudy.model.filesystem.config.model import FileStudyTreeConfigDTO from antarest.study.storage.rawstudy.model.filesystem.folder_node import ChildNotFoundError from antarest.study.storage.rawstudy.model.filesystem.ini_file_node import IniFileNode @@ -109,7 +109,7 @@ should_study_be_denormalized, upgrade_study, ) -from antarest.study.storage.utils import assert_permission, get_start_date, is_managed, remove_from_cache, study_matcher +from antarest.study.storage.utils import assert_permission, get_start_date, is_managed, remove_from_cache from antarest.study.storage.variantstudy.model.command.icommand import ICommand from antarest.study.storage.variantstudy.model.command.replace_matrix import ReplaceMatrix from antarest.study.storage.variantstudy.model.command.update_comments import UpdateComments @@ -126,7 +126,7 @@ MAX_MISSING_STUDY_TIMEOUT = 2 # days -def get_disk_usage(path: Union[str, Path]) -> int: +def get_disk_usage(path: t.Union[str, Path]) -> int: path = Path(path) if path.suffix.lower() in {".zip", "7z"}: return os.path.getsize(path) @@ -269,9 +269,9 @@ def __init__( self.binding_constraint_manager = BindingConstraintManager(self.storage_service) self.cache_service = cache_service self.config = config - self.on_deletion_callbacks: List[Callable[[str], None]] = [] + self.on_deletion_callbacks: t.List[t.Callable[[str], None]] = [] - def add_on_deletion_callback(self, callback: Callable[[str], None]) -> None: + def add_on_deletion_callback(self, callback: t.Callable[[str], None]) -> None: self.on_deletion_callbacks.append(callback) def _on_study_delete(self, uuid: str) -> None: @@ -311,7 +311,7 @@ def get_logs( job_id: str, err_log: bool, params: RequestParameters, - ) -> Optional[str]: + ) -> t.Optional[str]: study = self.get_study(study_id) 
assert_permission(params.user, study, StudyPermissionType.READ) file_study = self.storage_service.get_storage(study).get_raw(study) @@ -331,7 +331,7 @@ def get_logs( empty_log = False for log_location in log_locations[err_log]: try: - log = cast( + log = t.cast( bytes, file_study.tree.get(log_location, depth=1, formatted=True), ).decode(encoding="utf-8") @@ -367,9 +367,9 @@ def save_logs( f"{job_id}-{log_suffix}", ], ) - stopwatch.log_elapsed(lambda t: logger.info(f"Saved logs for job {job_id} in {t}s")) + stopwatch.log_elapsed(lambda d: logger.info(f"Saved logs for job {job_id} in {d}s")) - def get_comments(self, study_id: str, params: RequestParameters) -> Union[str, JSON]: + def get_comments(self, study_id: str, params: RequestParameters) -> t.Union[str, JSON]: """ Get the comments of a study. @@ -382,7 +382,7 @@ def get_comments(self, study_id: str, params: RequestParameters) -> Union[str, J study = self.get_study(study_id) assert_permission(params.user, study, StudyPermissionType.READ) - output: Union[str, JSON] + output: t.Union[str, JSON] raw_study_service = self.storage_service.raw_study_service variant_study_service = self.storage_service.variant_study_service if isinstance(study, RawStudy): @@ -438,44 +438,33 @@ def edit_comments( def get_studies_information( self, - managed: bool, - name: Optional[str], - workspace: Optional[str], - folder: Optional[str], params: RequestParameters, - ) -> Dict[str, StudyMetadataDTO]: + study_filter: StudyFilter, + sort_by: t.Optional[StudySortBy] = None, + pagination: StudyPagination = StudyPagination(), + ) -> t.Dict[str, StudyMetadataDTO]: """ - Get information for all studies. + Get information for matching studies of a search query. 
Args: - managed: indicate if just managed studies should be retrieved - name: optional name of the study to match - folder: optional folder prefix of the study to match - workspace: optional workspace of the study to match params: request parameters + study_filter: filtering parameters + sort_by: how to sort the db query results + pagination: set offset and limit for db query Returns: List of study information - """ - logger.info("Fetching study listing") - studies: Dict[str, StudyMetadataDTO] = {} - cache_key = CacheConstants.STUDY_LISTING.value - cached_studies = self.cache_service.get(cache_key) - if cached_studies: - for k in cached_studies: - studies[k] = StudyMetadataDTO.parse_obj(cached_studies[k]) - else: - if managed: - logger.info("Retrieving all managed studies") - all_studies = self.repository.get_all(managed=True) - else: - logger.info("Retrieving all studies") - all_studies = self.repository.get_all() - logger.info("Studies retrieved") - for study in all_studies: - study_metadata = self._try_get_studies_information(study) - if study_metadata is not None: - studies[study_metadata.id] = study_metadata - self.cache_service.put(cache_key, studies) + logger.info("Retrieving matching studies") + studies: t.Dict[str, StudyMetadataDTO] = {} + matching_studies = self.repository.get_all( + study_filter=study_filter, + sort_by=sort_by, + pagination=pagination, + ) + logger.info("Studies retrieved") + for study in matching_studies: + study_metadata = self._try_get_studies_information(study) + if study_metadata is not None: + studies[study_metadata.id] = study_metadata return { s.id: s for s in filter( @@ -484,14 +473,12 @@ def get_studies_information( study_dto, StudyPermissionType.READ, raising=False, - ) - and study_matcher(name, workspace, folder)(study_dto) - and (not managed or study_dto.managed), + ), studies.values(), ) } - def _try_get_studies_information(self, study: Study) -> Optional[StudyMetadataDTO]: + def _try_get_studies_information(self, study: 
Study) -> t.Optional[StudyMetadataDTO]: try: return self.storage_service.get_storage(study).get_study_information(study) except Exception as e: @@ -612,8 +599,8 @@ def get_study_path(self, uuid: str, params: RequestParameters) -> Path: def create_study( self, study_name: str, - version: Optional[str], - group_ids: List[str], + version: t.Optional[str], + group_ids: t.List[str], params: RequestParameters, ) -> str: """ @@ -693,7 +680,9 @@ def get_study_synthesis(self, study_id: str, params: RequestParameters) -> FileS study_storage_service = self.storage_service.get_storage(study) return study_storage_service.get_synthesis(study, params) - def get_input_matrix_startdate(self, study_id: str, path: Optional[str], params: RequestParameters) -> MatrixIndex: + def get_input_matrix_startdate( + self, study_id: str, path: t.Optional[str], params: RequestParameters + ) -> MatrixIndex: study = self.get_study(study_id) assert_permission(params.user, study, StudyPermissionType.READ) file_study = self.storage_service.get_storage(study).get_raw(study) @@ -709,7 +698,7 @@ def get_input_matrix_startdate(self, study_id: str, path: Optional[str], params: return get_start_date(file_study, output_id, level) def remove_duplicates(self) -> None: - study_paths: Dict[str, List[str]] = {} + study_paths: t.Dict[str, t.List[str]] = {} for study in self.repository.get_all(): if isinstance(study, RawStudy) and not study.archived: path = str(study.path) @@ -724,7 +713,7 @@ def remove_duplicates(self) -> None: logger.info(f"Removing study {study_name}") self.repository.delete(study_name) - def sync_studies_on_disk(self, folders: List[StudyFolder], directory: Optional[Path] = None) -> None: + def sync_studies_on_disk(self, folders: t.List[StudyFolder], directory: t.Optional[Path] = None) -> None: """ Used by watcher to send list of studies present on filesystem. 
@@ -833,7 +822,7 @@ def copy_study( self, src_uuid: str, dest_study_name: str, - group_ids: List[str], + group_ids: t.List[str], use_task: bool, params: RequestParameters, with_outputs: bool = False, @@ -967,7 +956,7 @@ def output_variables_information( study_uuid: str, output_uuid: str, params: RequestParameters, - ) -> Dict[str, List[str]]: + ) -> t.Dict[str, t.List[str]]: """ Returns information about output variables using thematic and geographic trimming information Args: @@ -1040,7 +1029,7 @@ def export_study_flat( uuid: str, params: RequestParameters, dest: Path, - output_list: Optional[List[str]] = None, + output_list: t.Optional[t.List[str]] = None, ) -> None: logger.info(f"Flat exporting study {uuid}") study = self.get_study(uuid) @@ -1053,20 +1042,20 @@ def export_study_flat( def delete_study(self, uuid: str, children: bool, params: RequestParameters) -> None: """ - Delete study + Delete study and all its children + Args: uuid: study uuid + children: delete children or not params: request parameters - - Returns: - """ study = self.get_study(uuid) assert_permission(params.user, study, StudyPermissionType.WRITE) study_info = study.to_json_summary() - # this prefetch the workspace because it is lazy loaded and the object is deleted before using workspace attribute in raw study deletion + # this prefetch the workspace because it is lazy loaded and the object is deleted + # before using workspace attribute in raw study deletion # see https://github.com/AntaresSimulatorTeam/AntaREST/issues/606 if isinstance(study, RawStudy): _ = study.workspace @@ -1137,8 +1126,8 @@ def download_outputs( use_task: bool, filetype: ExportFormat, params: RequestParameters, - tmp_export_file: Optional[Path] = None, - ) -> Union[Response, FileDownloadTaskDTO, FileResponse]: + tmp_export_file: t.Optional[Path] = None, + ) -> t.Union[Response, FileDownloadTaskDTO, FileResponse]: """ Download outputs Args: @@ -1239,7 +1228,7 @@ def export_task(notifier: TaskUpdateNotifier) -> 
TaskResult: ).encode("utf-8") return Response(content=json_response, media_type="application/json") - def get_study_sim_result(self, study_id: str, params: RequestParameters) -> List[StudySimResultDTO]: + def get_study_sim_result(self, study_id: str, params: RequestParameters) -> t.List[StudySimResultDTO]: """ Get global result information Args: @@ -1293,8 +1282,8 @@ def set_sim_reference( def import_study( self, - stream: BinaryIO, - group_ids: List[str], + stream: t.BinaryIO, + group_ids: t.List[str], params: RequestParameters, ) -> str: """ @@ -1345,11 +1334,11 @@ def import_study( def import_output( self, uuid: str, - output: Union[BinaryIO, Path], + output: t.Union[t.BinaryIO, Path], params: RequestParameters, - output_name_suffix: Optional[str] = None, + output_name_suffix: t.Optional[str] = None, auto_unzip: bool = True, - ) -> Optional[str]: + ) -> t.Optional[str]: """ Import specific output simulation inside study Args: @@ -1486,7 +1475,7 @@ def _edit_study_using_command( # noinspection SpellCheckingInspection url = "study/antares/lastsave" last_save_node = file_study.tree.get_node(url.split("/")) - cmd = self._create_edit_study_command(tree_node=last_save_node, url=url, data=int(time())) + cmd = self._create_edit_study_command(tree_node=last_save_node, url=url, data=int(time.time())) cmd.apply(file_study) self.storage_service.variant_study_service.invalidate_cache(study) @@ -1503,7 +1492,9 @@ def _edit_study_using_command( return command # for testing purpose - def apply_commands(self, uuid: str, commands: List[CommandDTO], params: RequestParameters) -> Optional[List[str]]: + def apply_commands( + self, uuid: str, commands: t.List[CommandDTO], params: RequestParameters + ) -> t.Optional[t.List[str]]: study = self.get_study(uuid) if isinstance(study, VariantStudy): return self.storage_service.variant_study_service.append_commands(uuid, commands, params) @@ -1511,7 +1502,7 @@ def apply_commands(self, uuid: str, commands: List[CommandDTO], params: RequestP 
file_study = self.storage_service.raw_study_service.get_raw(study) assert_permission(params.user, study, StudyPermissionType.WRITE) self._assert_study_unarchived(study) - parsed_commands: List[ICommand] = [] + parsed_commands: t.List[ICommand] = [] for command in commands: parsed_commands.extend(self.storage_service.variant_study_service.command_factory.to_command(command)) execute_or_add_commands( @@ -1574,7 +1565,7 @@ def edit_study( uuid, params.get_user_id(), ) - return cast(JSON, new) + return t.cast(JSON, new) def change_owner(self, study_id: str, owner_id: int, params: RequestParameters) -> None: """ @@ -1702,7 +1693,7 @@ def set_public_mode(self, study_id: str, mode: PublicMode, params: RequestParame params.get_user_id(), ) - def check_errors(self, uuid: str) -> List[str]: + def check_errors(self, uuid: str) -> t.List[str]: study = self.get_study(uuid) self._assert_study_unarchived(study) return self.storage_service.raw_study_service.check_errors(study) @@ -1710,10 +1701,10 @@ def check_errors(self, uuid: str) -> List[str]: def get_all_areas( self, uuid: str, - area_type: Optional[AreaType], + area_type: t.Optional[AreaType], ui: bool, params: RequestParameters, - ) -> Union[List[AreaInfoDTO], Dict[str, Any]]: + ) -> t.Union[t.List[AreaInfoDTO], t.Dict[str, t.Any]]: study = self.get_study(uuid) assert_permission(params.user, study, StudyPermissionType.READ) return self.areas.get_all_areas_ui_info(study) if ui else self.areas.get_all_areas(study, area_type) @@ -1723,7 +1714,7 @@ def get_all_links( uuid: str, with_ui: bool, params: RequestParameters, - ) -> List[LinkInfoDTO]: + ) -> t.List[LinkInfoDTO]: study = self.get_study(uuid) assert_permission(params.user, study, StudyPermissionType.READ) return self.links.get_all_links(study, with_ui) @@ -1803,7 +1794,7 @@ def update_thermal_cluster_metadata( self, uuid: str, area_id: str, - clusters_metadata: Dict[str, PatchCluster], + clusters_metadata: t.Dict[str, PatchCluster], params: RequestParameters, ) -> 
AreaInfoDTO: study = self.get_study(uuid) @@ -1938,8 +1929,8 @@ def unarchive_task(notifier: TaskUpdateNotifier) -> TaskResult: def _save_study( self, study: Study, - owner: Optional[JWTUser] = None, - group_ids: Sequence[str] = (), + owner: t.Optional[JWTUser] = None, + group_ids: t.Sequence[str] = (), content_status: StudyContentStatus = StudyContentStatus.VALID, ) -> None: """ @@ -1969,7 +1960,7 @@ def _save_study( study.groups.clear() for gid in group_ids: - jwt_group: Optional[JWTGroup] = next(filter(lambda g: g.id == gid, owner.groups), None) # type: ignore + jwt_group: t.Optional[JWTGroup] = next(filter(lambda g: g.id == gid, owner.groups), None) # type: ignore if ( jwt_group is None or jwt_group.role is None @@ -2030,13 +2021,13 @@ def _analyse_study(self, metadata: Study) -> StudyContentStatus: # noinspection PyUnusedLocal @staticmethod - def get_studies_versions(params: RequestParameters) -> List[str]: + def get_studies_versions(params: RequestParameters) -> t.List[str]: return list(STUDY_REFERENCE_TEMPLATES) def create_xpansion_configuration( self, uuid: str, - zipped_config: Optional[UploadFile], + zipped_config: t.Optional[UploadFile], params: RequestParameters, ) -> None: study = self.get_study(uuid) @@ -2082,7 +2073,7 @@ def get_candidate(self, uuid: str, candidate_name: str, params: RequestParameter assert_permission(params.user, study, StudyPermissionType.READ) return self.xpansion_manager.get_candidate(study, candidate_name) - def get_candidates(self, uuid: str, params: RequestParameters) -> List[XpansionCandidateDTO]: + def get_candidates(self, uuid: str, params: RequestParameters) -> t.List[XpansionCandidateDTO]: study = self.get_study(uuid) assert_permission(params.user, study, StudyPermissionType.READ) return self.xpansion_manager.get_candidates(study) @@ -2120,7 +2111,7 @@ def update_matrix( self, uuid: str, path: str, - matrix_edit_instruction: List[MatrixEditInstruction], + matrix_edit_instruction: t.List[MatrixEditInstruction], params: 
RequestParameters, ) -> None: """ @@ -2150,7 +2141,8 @@ def check_and_update_all_study_versions_in_database(self, params: RequestParamet if params.user and not params.user.is_site_admin(): logger.error(f"User {params.user.id} is not site admin") raise UserHasNotPermissionError() - studies = self.repository.get_all(managed=False) + studies = self.repository.get_all(study_filter=StudyFilter(managed=False)) + for study in studies: storage = self.storage_service.raw_study_service storage.check_and_update_study_version_in_database(study) @@ -2167,7 +2159,7 @@ def archive_outputs(self, study_id: str, params: RequestParameters) -> None: self.archive_output(study_id, output, params) @staticmethod - def _get_output_archive_task_names(study: Study, output_id: str) -> Tuple[str, str]: + def _get_output_archive_task_names(study: Study, output_id: str) -> t.Tuple[str, str]: return ( f"Archive output {study.id}/{output_id}", f"Unarchive output {study.name}/{output_id} ({study.id})", @@ -2179,7 +2171,7 @@ def archive_output( output_id: str, params: RequestParameters, force: bool = False, - ) -> Optional[str]: + ) -> t.Optional[str]: study = self.get_study(study_id) assert_permission(params.user, study, StudyPermissionType.WRITE) self._assert_study_unarchived(study) @@ -2236,7 +2228,7 @@ def unarchive_output( output_id: str, keep_src_zip: bool, params: RequestParameters, - ) -> Optional[str]: + ) -> t.Optional[str]: study = self.get_study(study_id) assert_permission(params.user, study, StudyPermissionType.READ) self._assert_study_unarchived(study) @@ -2276,7 +2268,7 @@ def unarchive_output_task( ) raise e - task_id: Optional[str] = None + task_id: t.Optional[str] = None workspace = getattr(study, "workspace", DEFAULT_WORKSPACE_NAME) if workspace != DEFAULT_WORKSPACE_NAME: dest = Path(study.path) / "output" / output_id diff --git a/antarest/study/storage/auto_archive_service.py b/antarest/study/storage/auto_archive_service.py index 8a15cb0f49..b2ae1fae63 100644 --- 
a/antarest/study/storage/auto_archive_service.py +++ b/antarest/study/storage/auto_archive_service.py @@ -10,6 +10,7 @@ from antarest.core.requests import RequestParameters from antarest.core.utils.fastapi_sqlalchemy import db from antarest.study.model import RawStudy, Study +from antarest.study.repository import StudyFilter from antarest.study.service import StudyService from antarest.study.storage.utils import is_managed from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy @@ -28,7 +29,11 @@ def __init__(self, study_service: StudyService, config: Config): def _try_archive_studies(self) -> None: old_date = datetime.datetime.utcnow() - datetime.timedelta(days=self.config.storage.auto_archive_threshold_days) with db(): - studies: List[Study] = self.study_service.repository.get_all(managed=True, exists=False) + studies: List[Study] = self.study_service.repository.get_all( + study_filter=StudyFilter( + managed=True, + ) + ) # list of study id and boolean indicating if it's a raw study (True) or a variant (False) study_ids_to_archive = [ (study.id, isinstance(study, RawStudy)) diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/renewable.py b/antarest/study/storage/rawstudy/model/filesystem/config/renewable.py index 3c53c23053..4d34e21637 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/config/renewable.py +++ b/antarest/study/storage/rawstudy/model/filesystem/config/renewable.py @@ -12,6 +12,7 @@ "RenewableConfig", "RenewableConfigType", "create_renewable_config", + "RenewableClusterGroup", ) diff --git a/antarest/study/web/studies_blueprint.py b/antarest/study/web/studies_blueprint.py index 1b1dc72312..e4fcbe100e 100644 --- a/antarest/study/web/studies_blueprint.py +++ b/antarest/study/web/studies_blueprint.py @@ -1,11 +1,13 @@ +import collections import io import logging +import typing as t from http import HTTPStatus from pathlib import Path -from typing import Any, Dict, List, Optional -from fastapi import 
APIRouter, Depends, File, HTTPException, Request +from fastapi import APIRouter, Depends, File, HTTPException, Query, Request from markupsafe import escape +from pydantic import NonNegativeInt from antarest.core.config import Config from antarest.core.exceptions import BadZipBinary @@ -26,12 +28,22 @@ StudyMetadataPatchDTO, StudySimResultDTO, ) +from antarest.study.repository import StudyFilter, StudyPagination, StudySortBy from antarest.study.service import StudyService from antarest.study.storage.rawstudy.model.filesystem.config.model import FileStudyTreeConfigDTO logger = logging.getLogger(__name__) +def _split_comma_separated_values(value: str, *, default: t.Sequence[str] = ()) -> t.Sequence[str]: + """Split a comma-separated list of values into an ordered set of strings.""" + values = value.split(",") if value else default + # drop whitespace around values + values = [v.strip() for v in values] + # remove duplicates and preserve order (to have a deterministic result for unit tests). + return list(collections.OrderedDict.fromkeys(values)) + + def create_study_routes(study_service: StudyService, ftm: FileTransferManager, config: Config) -> APIRouter: """ Endpoint implementation for studies management @@ -50,19 +62,111 @@ def create_study_routes(study_service: StudyService, ftm: FileTransferManager, c "/studies", tags=[APITag.study_management], summary="Get Studies", - response_model=Dict[str, StudyMetadataDTO], ) def get_studies( - managed: bool = False, - name: str = "", - folder: str = "", - workspace: str = "", current_user: JWTUser = Depends(auth.get_current_user), - ) -> Any: - logger.info("Fetching study list", extra={"user": current_user.id}) + name: str = Query( + "", + description=( + "Filter studies based on their name." + "Case-insensitive search for studies whose name contains the specified value." 
+ ), + alias="name", + ), + managed: t.Optional[bool] = Query(None, description="Filter studies based on their management status."), + archived: t.Optional[bool] = Query(None, description="Filter studies based on their archive status."), + variant: t.Optional[bool] = Query(None, description="Filter studies based on their variant status."), + versions: str = Query( + "", + description="Comma-separated list of versions for filtering.", + regex=r"^\s*(?:\d+\s*(?:,\s*\d+\s*)*)?$", + ), + users: str = Query( + "", + description="Comma-separated list of user IDs for filtering.", + regex=r"^\s*(?:\d+\s*(?:,\s*\d+\s*)*)?$", + ), + groups: str = Query("", description="Comma-separated list of group IDs for filtering."), + tags: str = Query("", description="Comma-separated list of tags for filtering."), + study_ids: str = Query( + "", + description="Comma-separated list of study IDs for filtering.", + alias="studyIds", + ), + exists: t.Optional[bool] = Query(None, description="Filter studies based on their existence on disk."), + workspace: str = Query("", description="Filter studies based on their workspace."), + folder: str = Query("", description="Filter studies based on their folder."), + # It is advisable to use an optional Query parameter for enumerated types, like booleans. + sort_by: t.Optional[StudySortBy] = Query( + None, + description="Sort studies based on their name (case-insensitive) or creation date.", + alias="sortBy", + ), + page_nb: NonNegativeInt = Query( + 0, + description="Page number (starting from 0).", + alias="pageNb", + ), + page_size: NonNegativeInt = Query( + 0, + description="Number of studies per page (0 = no limit).", + alias="pageSize", + ), + ) -> t.Dict[str, StudyMetadataDTO]: + """ + Get the list of studies matching the specified criteria. + + Args: + - `name`: Filter studies based on their name. Case-insensitive search for studies + whose name contains the specified value. + - `managed`: Filter studies based on their management status. 
+ - `archived`: Filter studies based on their archive status. + - `variant`: Filter studies based on their variant status. + - `versions`: Comma-separated list of versions for filtering. + - `users`: Comma-separated list of user IDs for filtering. + - `groups`: Comma-separated list of group IDs for filtering. + - `tags`: Comma-separated list of tags for filtering. + - `studyIds`: Comma-separated list of study IDs for filtering. + - `exists`: Filter studies based on their existence on disk. + - `workspace`: Filter studies based on their workspace. + - `folder`: Filter studies based on their folder. + - `sortBy`: Sort studies based on their name (case-insensitive) or date. + - `pageNb`: Page number (starting from 0). + - `pageSize`: Number of studies per page (0 = no limit). + + Returns: + - A dictionary of studies matching the specified criteria, + where keys are study IDs and values are study properties. + """ + + logger.info("Fetching for matching studies", extra={"user": current_user.id}) params = RequestParameters(user=current_user) - available_studies = study_service.get_studies_information(managed, name, workspace, folder, params) - return available_studies + + user_list = [int(v) for v in _split_comma_separated_values(users)] + + study_filter = StudyFilter( + name=name, + managed=managed, + archived=archived, + variant=variant, + versions=_split_comma_separated_values(versions), + users=user_list, + groups=_split_comma_separated_values(groups), + tags=_split_comma_separated_values(tags), + study_ids=_split_comma_separated_values(study_ids), + exists=exists, + workspace=workspace, + folder=folder, + ) + + matching_studies = study_service.get_studies_information( + params=params, + study_filter=study_filter, + sort_by=sort_by, + pagination=StudyPagination(page_nb=page_nb, page_size=page_size), + ) + + return matching_studies @bp.get( "/studies/{uuid}/comments", @@ -72,7 +176,7 @@ def get_studies( def get_comments( uuid: str, current_user: JWTUser = 
Depends(auth.get_current_user), - ) -> Any: + ) -> t.Any: logger.info(f"Get comments of study {uuid}", extra={"user": current_user.id}) params = RequestParameters(user=current_user) study_id = sanitize_uuid(uuid) @@ -89,7 +193,7 @@ def edit_comments( uuid: str, data: CommentsDto, current_user: JWTUser = Depends(auth.get_current_user), - ) -> Any: + ) -> t.Any: logger.info( f"Editing comments for study {uuid}", extra={"user": current_user.id}, @@ -130,8 +234,8 @@ def import_study( zip_binary = io.BytesIO(study) params = RequestParameters(user=current_user) - group_ids = groups.split(",") if groups else [group.id for group in current_user.groups] - group_ids = [sanitize_uuid(gid) for gid in set(group_ids)] # sanitize and avoid duplicates + group_ids = _split_comma_separated_values(groups, default=[group.id for group in current_user.groups]) + group_ids = [sanitize_uuid(gid) for gid in group_ids] try: uuid = study_service.import_study(zip_binary, group_ids, params) @@ -209,8 +313,8 @@ def copy_study( extra={"user": current_user.id}, ) source_uuid = uuid - group_ids = groups.split(",") if groups else [group.id for group in current_user.groups] - group_ids = [sanitize_uuid(gid) for gid in set(group_ids)] # sanitize and avoid duplicates + group_ids = _split_comma_separated_values(groups, default=[group.id for group in current_user.groups]) + group_ids = [sanitize_uuid(gid) for gid in group_ids] source_uuid_sanitized = sanitize_uuid(source_uuid) destination_name_sanitized = escape(dest) @@ -236,7 +340,7 @@ def move_study( uuid: str, folder_dest: str, current_user: JWTUser = Depends(auth.get_current_user), - ) -> Any: + ) -> t.Any: logger.info( f"Moving study {uuid} into folder '{folder_dest}'", extra={"user": current_user.id}, @@ -256,11 +360,11 @@ def create_study( version: str = "", groups: str = "", current_user: JWTUser = Depends(auth.get_current_user), - ) -> Any: + ) -> t.Any: logger.info(f"Creating new study '{name}'", extra={"user": current_user.id}) 
name_sanitized = escape(name) - group_ids = groups.split(",") if groups else [] - group_ids = [sanitize_uuid(gid) for gid in set(group_ids)] # sanitize and avoid duplicates + group_ids = _split_comma_separated_values(groups) + group_ids = [sanitize_uuid(gid) for gid in group_ids] params = RequestParameters(user=current_user) uuid = study_service.create_study(name_sanitized, version, group_ids, params) @@ -276,7 +380,7 @@ def create_study( def get_study_synthesis( uuid: str, current_user: JWTUser = Depends(auth.get_current_user), - ) -> Any: + ) -> t.Any: study_id = sanitize_uuid(uuid) logger.info( f"Return a synthesis for study '{study_id}'", @@ -295,7 +399,7 @@ def get_study_matrix_index( uuid: str, path: str = "", current_user: JWTUser = Depends(auth.get_current_user), - ) -> Any: + ) -> t.Any: study_id = sanitize_uuid(uuid) logger.info( f"Return the start date for input matrix '{study_id}'", @@ -312,9 +416,9 @@ def get_study_matrix_index( ) def export_study( uuid: str, - no_output: Optional[bool] = False, + no_output: t.Optional[bool] = False, current_user: JWTUser = Depends(auth.get_current_user), - ) -> Any: + ) -> t.Any: logger.info(f"Exporting study {uuid}", extra={"user": current_user.id}) uuid_sanitized = sanitize_uuid(uuid) @@ -331,7 +435,7 @@ def delete_study( uuid: str, children: bool = False, current_user: JWTUser = Depends(auth.get_current_user), - ) -> Any: + ) -> t.Any: logger.info(f"Deleting study {uuid}", extra={"user": current_user.id}) uuid_sanitized = sanitize_uuid(uuid) @@ -351,7 +455,7 @@ def import_output( uuid: str, output: bytes = File(...), current_user: JWTUser = Depends(auth.get_current_user), - ) -> Any: + ) -> t.Any: logger.info( f"Importing output for study {uuid}", extra={"user": current_user.id}, @@ -373,7 +477,7 @@ def change_owner( uuid: str, user_id: int, current_user: JWTUser = Depends(auth.get_current_user), - ) -> Any: + ) -> t.Any: logger.info( f"Changing owner to {user_id} for study {uuid}", extra={"user": current_user.id}, 
@@ -393,7 +497,7 @@ def add_group( uuid: str, group_id: str, current_user: JWTUser = Depends(auth.get_current_user), - ) -> Any: + ) -> t.Any: logger.info( f"Adding group {group_id} to study {uuid}", extra={"user": current_user.id}, @@ -414,7 +518,7 @@ def remove_group( uuid: str, group_id: str, current_user: JWTUser = Depends(auth.get_current_user), - ) -> Any: + ) -> t.Any: logger.info( f"Removing group {group_id} to study {uuid}", extra={"user": current_user.id}, @@ -436,7 +540,7 @@ def set_public_mode( uuid: str, mode: PublicMode, current_user: JWTUser = Depends(auth.get_current_user), - ) -> Any: + ) -> t.Any: logger.info( f"Setting public mode to {mode} for study {uuid}", extra={"user": current_user.id}, @@ -451,11 +555,11 @@ def set_public_mode( "/studies/_versions", tags=[APITag.study_management], summary="Show available study versions", - response_model=List[str], + response_model=t.List[str], ) def get_study_versions( current_user: JWTUser = Depends(auth.get_current_user), - ) -> Any: + ) -> t.Any: params = RequestParameters(user=current_user) logger.info("Fetching version list") return StudyService.get_studies_versions(params=params) @@ -469,7 +573,7 @@ def get_study_versions( def get_study_metadata( uuid: str, current_user: JWTUser = Depends(auth.get_current_user), - ) -> Any: + ) -> t.Any: logger.info(f"Fetching study {uuid} metadata", extra={"user": current_user.id}) params = RequestParameters(user=current_user) study_metadata = study_service.get_study_information(uuid, params) @@ -485,7 +589,7 @@ def update_study_metadata( uuid: str, study_metadata_patch: StudyMetadataPatchDTO, current_user: JWTUser = Depends(auth.get_current_user), - ) -> Any: + ) -> t.Any: logger.info( f"Updating metadata for study {uuid}", extra={"user": current_user.id}, @@ -503,7 +607,7 @@ def output_variables_information( study_id: str, output_id: str, current_user: JWTUser = Depends(auth.get_current_user), - ) -> Any: + ) -> t.Any: study_id = sanitize_uuid(study_id) output_id 
= sanitize_uuid(output_id) logger.info(f"Fetching whole output of the simulation {output_id} for study {study_id}") @@ -523,7 +627,7 @@ def output_export( study_id: str, output_id: str, current_user: JWTUser = Depends(auth.get_current_user), - ) -> Any: + ) -> t.Any: study_id = sanitize_uuid(study_id) output_id = sanitize_uuid(output_id) logger.info(f"Fetching whole output of the simulation {output_id} for study {study_id}") @@ -547,7 +651,7 @@ def output_download( use_task: bool = False, tmp_export_file: Path = Depends(ftm.request_tmp_file), current_user: JWTUser = Depends(auth.get_current_user), - ) -> Any: + ) -> t.Any: study_id = sanitize_uuid(study_id) output_id = sanitize_uuid(output_id) logger.info( @@ -601,7 +705,7 @@ def archive_output( study_id: str, output_id: str, current_user: JWTUser = Depends(auth.get_current_user), - ) -> Any: + ) -> t.Any: study_id = sanitize_uuid(study_id) output_id = sanitize_uuid(output_id) logger.info( @@ -626,7 +730,7 @@ def unarchive_output( study_id: str, output_id: str, current_user: JWTUser = Depends(auth.get_current_user), - ) -> Any: + ) -> t.Any: study_id = sanitize_uuid(study_id) output_id = sanitize_uuid(output_id) logger.info( @@ -647,12 +751,12 @@ def unarchive_output( "/studies/{study_id}/outputs", summary="Get global information about a study simulation result", tags=[APITag.study_outputs], - response_model=List[StudySimResultDTO], + response_model=t.List[StudySimResultDTO], ) def sim_result( study_id: str, current_user: JWTUser = Depends(auth.get_current_user), - ) -> Any: + ) -> t.Any: logger.info( f"Fetching output list for study {study_id}", extra={"user": current_user.id}, @@ -672,7 +776,7 @@ def set_sim_reference( output_id: str, status: bool = True, current_user: JWTUser = Depends(auth.get_current_user), - ) -> Any: + ) -> t.Any: logger.info( f"Setting output {output_id} as reference simulation for study {study_id}", extra={"user": current_user.id}, @@ -691,7 +795,7 @@ def set_sim_reference( def 
archive_study( study_id: str, current_user: JWTUser = Depends(auth.get_current_user), - ) -> Any: + ) -> t.Any: logger.info(f"Archiving study {study_id}", extra={"user": current_user.id}) study_id = sanitize_uuid(study_id) params = RequestParameters(user=current_user) @@ -705,7 +809,7 @@ def archive_study( def unarchive_study( study_id: str, current_user: JWTUser = Depends(auth.get_current_user), - ) -> Any: + ) -> t.Any: logger.info(f"Unarchiving study {study_id}", extra={"user": current_user.id}) study_id = sanitize_uuid(study_id) params = RequestParameters(user=current_user) @@ -718,7 +822,7 @@ def unarchive_study( ) def invalidate_study_listing_cache( current_user: JWTUser = Depends(auth.get_current_user), - ) -> Any: + ) -> t.Any: logger.info( "Invalidating the study listing cache", extra={"user": current_user.id}, diff --git a/tests/integration/assets/ext-840.zip b/tests/integration/assets/ext-840.zip new file mode 100644 index 0000000000..ac81e0ca0c Binary files /dev/null and b/tests/integration/assets/ext-840.zip differ diff --git a/tests/integration/assets/ext-850.zip b/tests/integration/assets/ext-850.zip new file mode 100644 index 0000000000..ff25db9c38 Binary files /dev/null and b/tests/integration/assets/ext-850.zip differ diff --git a/tests/integration/assets/ext-860.zip b/tests/integration/assets/ext-860.zip new file mode 100644 index 0000000000..427c7463ed Binary files /dev/null and b/tests/integration/assets/ext-860.zip differ diff --git a/tests/integration/studies_blueprint/test_comments.py b/tests/integration/studies_blueprint/test_comments.py index ca6d746443..b282ed8781 100644 --- a/tests/integration/studies_blueprint/test_comments.py +++ b/tests/integration/studies_blueprint/test_comments.py @@ -113,7 +113,7 @@ def test_variant_study( ) assert res.status_code == 200, res.json() duration = time.time() - start - assert 0 <= duration <= 0.1, f"Duration is {duration} seconds" + assert 0 <= duration <= 0.3, f"Duration is {duration} seconds" # 
Update the comments of the study res = client.put( diff --git a/tests/integration/studies_blueprint/test_get_studies.py b/tests/integration/studies_blueprint/test_get_studies.py new file mode 100644 index 0000000000..b134406e50 --- /dev/null +++ b/tests/integration/studies_blueprint/test_get_studies.py @@ -0,0 +1,770 @@ +import io +import operator +import re +import shutil +import typing as t +import zipfile +from pathlib import Path + +import pytest +from fastapi import FastAPI +from starlette.testclient import TestClient + +from antarest.core.model import PublicMode +from antarest.core.tasks.model import TaskStatus +from tests.integration.assets import ASSETS_DIR +from tests.integration.utils import wait_task_completion + +# URL used to create or list studies +STUDIES_URL = "/v1/studies" + +# Status codes for study creation requests +CREATE_STATUS_CODES = {200, 201} + +# Status code for study listing requests +LIST_STATUS_CODE = 200 + +# Status code for study listing with invalid parameters +INVALID_PARAMS_STATUS_CODE = 422 + + +class TestStudiesListing: + """ + This class contains tests related to the following endpoints: + + - GET /v1/studies + """ + + # noinspection PyUnusedLocal + @pytest.fixture(name="to_be_deleted_study_path", autouse=True) + def studies_in_ext_fixture(self, tmp_path: Path, app: FastAPI) -> Path: + # create a non managed raw study version 840 + study_dir = tmp_path / "ext_workspace" / "ext-840" + study_dir.mkdir(exist_ok=True) + zip_path = ASSETS_DIR.joinpath("ext-840.zip") + with zipfile.ZipFile(zip_path) as zip_output: + zip_output.extractall(path=study_dir) + + # create a non managed raw study version 850 + study_dir = tmp_path / "ext_workspace" / "ext-850" + study_dir.mkdir(exist_ok=True) + zip_path = ASSETS_DIR.joinpath("ext-850.zip") + with zipfile.ZipFile(zip_path) as zip_output: + zip_output.extractall(path=study_dir) + + # create a non managed raw study version 860 + study_dir = tmp_path / "ext_workspace" / "ext-860" + 
study_dir.mkdir(exist_ok=True) + zip_path = ASSETS_DIR.joinpath("ext-860.zip") + with zipfile.ZipFile(zip_path) as zip_output: + zip_output.extractall(path=study_dir) + + # create a non managed raw study version 840 to be deleted from disk + study_dir = tmp_path / "ext_workspace" / "to-be-deleted-840" + study_dir.mkdir(exist_ok=True) + zip_path = ASSETS_DIR.joinpath("ext-840.zip") + with zipfile.ZipFile(zip_path) as zip_output: + zip_output.extractall(path=study_dir) + + return study_dir + + def test_study_listing( + self, + client: TestClient, + admin_access_token: str, + to_be_deleted_study_path: Path, + ) -> None: + """ + This test verifies that database is correctly initialized and then runs the filtering tests with different + parameters + """ + + # ========================== + # 1. Database initialization + # ========================== + + # database update to include non managed studies using the watcher + res = client.post( + "/v1/watcher/_scan", + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"path": "ext"}, + ) + res.raise_for_status() + task_id = res.json() + task = wait_task_completion(client, admin_access_token, task_id) + assert task.status == TaskStatus.COMPLETED, task + + # retrieve a created non managed + to be deleted study IDs + res = client.get(STUDIES_URL, headers={"Authorization": f"Bearer {admin_access_token}"}) + res.raise_for_status() + folder_map = {v["folder"]: k for k, v in res.json().items()} + non_managed_840_id = folder_map["ext-840"] + non_managed_850_id = folder_map["ext-850"] + non_managed_860_id = folder_map["ext-860"] + to_be_deleted_id = folder_map["to-be-deleted-840"] + + # delete study `to_be_deleted_id` from disk + shutil.rmtree(to_be_deleted_study_path) + assert not to_be_deleted_study_path.exists() + + # database update with missing studies + res = client.post( + "/v1/watcher/_scan", + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"path": "ext"}, + ) + res.raise_for_status() + 
task_id = res.json() + task = wait_task_completion(client, admin_access_token, task_id) + assert task.status == TaskStatus.COMPLETED, task + + # change permissions for non managed studies (no access but to admin) + non_managed_studies = {non_managed_840_id, non_managed_850_id, non_managed_860_id, to_be_deleted_id} + no_access_code = "NONE" + for non_managed_study in non_managed_studies: + res = client.put( + f"/v1/studies/{non_managed_study}/public_mode/{no_access_code}", + headers={"Authorization": f"Bearer {admin_access_token}"}, + json={"name": "James Bond", "password": "0007"}, + ) + res.raise_for_status() + + # create a user 'James Bond' with password '007' + res = client.post( + "/v1/users", + headers={"Authorization": f"Bearer {admin_access_token}"}, + json={"name": "James Bond", "password": "0007"}, + ) + res.raise_for_status() + james_bond_id = res.json().get("id") + + # create a user 'John Doe' with password '0011' + res = client.post( + "/v1/users", + headers={"Authorization": f"Bearer {admin_access_token}"}, + json={"name": "John Doe", "password": "0011"}, + ) + res.raise_for_status() + john_doe_id = res.json().get("id") + + # create a group 'Group X' with id 'groupX' + res = client.post( + "/v1/groups", + headers={"Authorization": f"Bearer {admin_access_token}"}, + json={"name": "Group X", "id": "groupX"}, + ) + res.raise_for_status() + group_x_id = res.json().get("id") + assert group_x_id == "groupX" + + # create a group 'Group Y' with id 'groupY' + res = client.post( + "/v1/groups", + headers={"Authorization": f"Bearer {admin_access_token}"}, + json={"name": "Group Y", "id": "groupY"}, + ) + res.raise_for_status() + group_y_id = res.json().get("id") + assert group_y_id == "groupY" + + # login 'James Bond' + res = client.post( + "/v1/login", + json={"username": "James Bond", "password": "0007"}, + ) + res.raise_for_status() + assert res.json().get("user") == james_bond_id + james_bond_access_token = res.json().get("access_token") + + # login 'John 
Doe' + res = client.post( + "/v1/login", + json={"username": "John Doe", "password": "0011"}, + ) + res.raise_for_status() + assert res.json().get("user") == john_doe_id + john_doe_access_token = res.json().get("access_token") + + # create a bot user 'James Bond' + res = client.post( + "/v1/bots", + headers={"Authorization": f"Bearer {james_bond_access_token}"}, + json={"name": "James Bond", "roles": []}, + ) + res.raise_for_status() + james_bond_bot_token = res.json() + + # create a raw study version 840 + res = client.post( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": "raw-840", "version": "840"}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + raw_840_id = res.json() + + # create a raw study version 850 + res = client.post( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": "raw-850", "version": "850"}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + raw_850_id = res.json() + + # create a raw study version 860 + res = client.post( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": "raw-860", "version": "860"}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + raw_860_id = res.json() + + # create a variant study version 840 + res = client.post( + f"{STUDIES_URL}/{raw_840_id}/variants", + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": "variant-840", "version": "840"}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + variant_840_id = res.json() + + # create a variant study version 850 + res = client.post( + f"{STUDIES_URL}/{raw_850_id}/variants", + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": "variant-850", "version": "850"}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + variant_850_id = res.json() + + # create a variant study version 860 + res = client.post( + 
f"{STUDIES_URL}/{raw_860_id}/variants", + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": "variant-860", "version": "860"}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + variant_860_id = res.json() + + # create a raw study version 840 to be archived + res = client.post( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": "archived-raw-840", "version": "840"}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + archived_raw_840_id = res.json() + + # create a raw study version 850 to be archived + res = client.post( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": "archived-raw-850", "version": "850"}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + archived_raw_850_id = res.json() + + # create a variant study version 840 + res = client.post( + f"{STUDIES_URL}/{archived_raw_840_id}/variants", + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": "archived-variant-840", "version": "840"}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + archived_variant_840_id = res.json() + + # create a variant study version 850 to be archived + res = client.post( + f"{STUDIES_URL}/{archived_raw_850_id}/variants", + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": "archived-variant-850", "version": "850"}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + archived_variant_850_id = res.json() + + # create a raw study to be transferred in folder1 + zip_path = ASSETS_DIR / "STA-mini.zip" + res = client.post( + f"{STUDIES_URL}/_import", + headers={"Authorization": f"Bearer {admin_access_token}"}, + files={"study": io.BytesIO(zip_path.read_bytes())}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + folder1_study_id = res.json() + res = client.put( + f"{STUDIES_URL}/{folder1_study_id}/move", + headers={"Authorization": 
f"Bearer {admin_access_token}"}, + params={"folder_dest": "folder1"}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + + # give permission to James Bond for some select studies + james_bond_studies = {raw_840_id, variant_850_id, non_managed_860_id} + for james_bond_study in james_bond_studies: + res = client.put( + f"{STUDIES_URL}/{james_bond_study}/owner/{james_bond_id}", + headers={"Authorization": f"Bearer {admin_access_token}"}, + ) + assert res.status_code == 200, res.json() + + # associate select studies to each group: groupX, groupY + group_x_studies = {variant_850_id, raw_860_id} + group_y_studies = {raw_850_id, raw_860_id} + for group_x_study in group_x_studies: + res = client.put( + f"{STUDIES_URL}/{group_x_study}/groups/{group_x_id}", + headers={"Authorization": f"Bearer {admin_access_token}"}, + ) + assert res.status_code == 200, res.json() + for group_y_study in group_y_studies: + res = client.put( + f"{STUDIES_URL}/{group_y_study}/groups/{group_y_id}", + headers={"Authorization": f"Bearer {admin_access_token}"}, + ) + assert res.status_code == 200, res.json() + + # archive studies + archive_studies = {archived_raw_840_id, archived_raw_850_id} + for archive_study in archive_studies: + res = client.put( + f"{STUDIES_URL}/{archive_study}/archive", + headers={"Authorization": f"Bearer {admin_access_token}"}, + ) + assert res.status_code == 200, res.json() + archiving_study_task_id = res.json() + task = wait_task_completion(client, admin_access_token, archiving_study_task_id) + assert task.status == TaskStatus.COMPLETED, task + + # the testing studies set + all_studies = { + raw_840_id, + raw_850_id, + raw_860_id, + non_managed_840_id, + non_managed_850_id, + non_managed_860_id, + variant_840_id, + variant_850_id, + variant_860_id, + archived_raw_840_id, + archived_raw_850_id, + archived_variant_840_id, + archived_variant_850_id, + folder1_study_id, + to_be_deleted_id, + } + + pm = operator.itemgetter("public_mode") + + # tests (1) for 
user permission filtering + # test 1.a for a user with no access permission + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {john_doe_access_token}"}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map: t.Dict[str, t.Dict[str, t.Any]] = res.json() + assert not all_studies.intersection(study_map) + assert all(map(lambda x: pm(x) in [PublicMode.READ, PublicMode.FULL], study_map.values())) + + # test 1.b for an admin user + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + assert not all_studies.difference(study_map) + + # test 1.c for a user with access to select studies + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {james_bond_access_token}"}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + assert not james_bond_studies.difference(study_map) + assert all( + map( + lambda x: pm(x) in [PublicMode.READ, PublicMode.FULL], + [e for k, e in study_map.items() if k not in james_bond_studies], + ) + ) + # #TODO you need to update the permission for James Bond bot + + # test 1.d for a user bot with access to select studies + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {james_bond_bot_token}"}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + + # #TODO add the correct test assertions + # ] = res.json() + # assert not set(james_bond_studies).difference(study_map) + # assert all( + # map( + # lambda x: pm(x) in [PublicMode.READ, PublicMode.FULL], + # [e for k, e in study_map.items() if k not in james_bond_studies], + # ) + # ) + + # tests (2) for studies names filtering + # test 2.a with matching studies + res = client.get(STUDIES_URL, headers={"Authorization": f"Bearer {admin_access_token}"}, params={"name": "840"}) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + 
assert all(map(lambda x: "840" in x.get("name"), study_map.values())) and len(study_map) >= 5 + # test 2.b with no matching studies + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": "NON-SENSE-746846351469798465"}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + assert not study_map + + # tests (3) managed studies vs non managed + # test 3.a managed + managed_studies = all_studies.difference(non_managed_studies) + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"managed": True}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + assert not managed_studies.difference(study_map) + assert not all_studies.difference(managed_studies).intersection(study_map) + # test 3.b non managed + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"managed": False}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + assert not all_studies.difference(managed_studies).difference(study_map) + assert not managed_studies.intersection(study_map) + + # tests (4) archived vs non archived + # test 4.a archived studies + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"archived": True}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + assert not archive_studies.difference(study_map) + assert not all_studies.difference(archive_studies).intersection(study_map) + # test 4.b non archived + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"archived": False}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + assert not all_studies.difference(archive_studies).difference(study_map) + assert not archive_studies.intersection(study_map) + 
+ # tests (5) for filtering variant studies + variant_studies = { + variant_840_id, + variant_850_id, + variant_860_id, + archived_variant_840_id, + archived_variant_850_id, + } + # test 5.a get variant studies + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"variant": True}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + assert not variant_studies.difference(study_map) + assert not all_studies.difference(variant_studies).intersection(study_map) + # test 5.b get raw studies + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"variant": False}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + assert not all_studies.difference(variant_studies).difference(study_map) + assert not variant_studies.intersection(study_map) + + # tests (6) for version filtering + studies_version_850 = { + raw_850_id, + non_managed_850_id, + variant_850_id, + archived_raw_850_id, + archived_variant_850_id, + } + studies_version_860 = { + raw_860_id, + non_managed_860_id, + variant_860_id, + } + # test 6.a filter for one version: 860 + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"versions": "860"}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + assert not all_studies.difference(studies_version_860).intersection(study_map) + assert not studies_version_860.difference(study_map) + # test 8.b filter for two versions: 850, 860 + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"versions": "850,860"}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + assert not all_studies.difference(studies_version_850.union(studies_version_860)).intersection(study_map) + assert not 
studies_version_850.union(studies_version_860).difference(study_map) + + # tests (7) for users filtering + # test 7.a to get studies for one user: James Bond + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"users": f"{james_bond_id}"}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + assert not all_studies.difference(james_bond_studies).intersection(study_map) + assert not james_bond_studies.difference(study_map) + # test 7.b to get studies for two users + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"users": f"{james_bond_id},{john_doe_id}"}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + assert not all_studies.difference(james_bond_studies).intersection(study_map) + assert not james_bond_studies.difference(study_map) + + # tests (8) for groups filtering + # test 8.a filter for one group: groupX + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"groups": f"{group_x_id}"}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + assert not all_studies.difference(group_x_studies).intersection(study_map) + assert not group_x_studies.difference(study_map) + # test 8.b filter for two groups: groupX, groupY + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"groups": f"{group_x_id},{group_y_id}"}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + assert not all_studies.difference(group_x_studies.union(group_y_studies)).intersection(study_map) + assert not group_x_studies.union(group_y_studies).difference(study_map) + + # TODO you need to add filtering through tags to the search engine + # tests (9) for tags filtering + # test 9.a filtering for one tag: decennial + # test 9.b filtering for two tags: 
decennial,winter_transition + + # tests (10) for studies uuids sequence filtering + # test 10.a filter for one uuid + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"studyIds": f"{raw_840_id}"}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + assert {raw_840_id} == set(study_map) + # test 10.b filter for two uuids + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"studyIds": f"{raw_840_id},{raw_860_id}"}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + assert {raw_840_id, raw_860_id} == set(study_map) + + # tests (11) studies filtering regarding existence on disk + existing_studies = all_studies.difference({to_be_deleted_id}) + # test 11.a filter existing studies + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"exists": True}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + assert not existing_studies.difference(study_map) + assert not all_studies.difference(existing_studies).intersection(study_map) + # test 11.b filter non-existing studies + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"exists": False}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + assert not all_studies.difference(existing_studies).difference(study_map) + assert not existing_studies.intersection(study_map) + + # tests (12) studies filtering with workspace + ext_workspace_studies = non_managed_studies + # test 12.a filter `ext` workspace studies + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"workspace": "ext"}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + assert not ext_workspace_studies.difference(study_map) + 
assert not all_studies.difference(ext_workspace_studies).intersection(study_map) + + # tests (13) studies filtering with folder + # test 13.a filter `folder1` studies + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"folder": "folder1"}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + assert not {folder1_study_id}.difference(study_map) + assert not all_studies.difference({folder1_study_id}).intersection(study_map) + + # test sort by name ASC + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"sortBy": "+name"}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + values = list(study_map.values()) + assert values == sorted(values, key=lambda x: x["name"].upper()) + + # test sort by name DESC + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"sortBy": "-name"}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + values = list(study_map.values()) + assert values == sorted(values, key=lambda x: x["name"].upper(), reverse=True) + + # test sort by date ASC + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"sortBy": "+date"}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + values = list(study_map.values()) + assert values == sorted(values, key=lambda x: x["created"]) + + # test sort by date DESC + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"sortBy": "-date"}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + values = list(study_map.values()) + assert values == sorted(values, key=lambda x: x["created"], reverse=True) + + def test_get_studies__invalid_parameters( + self, + client: TestClient, + 
user_access_token: str, + ) -> None: + headers = {"Authorization": f"Bearer {user_access_token}"} + + # Invalid `sortBy` parameter + res = client.get(STUDIES_URL, headers=headers, params={"sortBy": "invalid"}) + assert res.status_code == INVALID_PARAMS_STATUS_CODE, res.json() + description = res.json()["description"] + assert re.search(r"not a valid enumeration member", description), f"{description=}" + + # Invalid `pageNb` parameter (negative integer) + res = client.get(STUDIES_URL, headers=headers, params={"pageNb": -1}) + assert res.status_code == INVALID_PARAMS_STATUS_CODE, res.json() + description = res.json()["description"] + assert re.search(r"greater than or equal to 0", description), f"{description=}" + + # Invalid `pageNb` parameter (not an integer) + res = client.get(STUDIES_URL, headers=headers, params={"pageNb": "invalid"}) + assert res.status_code == INVALID_PARAMS_STATUS_CODE, res.json() + description = res.json()["description"] + assert re.search(r"not a valid integer", description), f"{description=}" + + # Invalid `pageSize` parameter (negative integer) + res = client.get(STUDIES_URL, headers=headers, params={"pageSize": -1}) + assert res.status_code == INVALID_PARAMS_STATUS_CODE, res.json() + description = res.json()["description"] + assert re.search(r"greater than or equal to 0", description), f"{description=}" + + # Invalid `pageSize` parameter (not an integer) + res = client.get(STUDIES_URL, headers=headers, params={"pageSize": "invalid"}) + assert res.status_code == INVALID_PARAMS_STATUS_CODE, res.json() + description = res.json()["description"] + assert re.search(r"not a valid integer", description), f"{description=}" + + # Invalid `managed` parameter (not a boolean) + res = client.get(STUDIES_URL, headers=headers, params={"managed": "invalid"}) + assert res.status_code == INVALID_PARAMS_STATUS_CODE, res.json() + description = res.json()["description"] + assert re.search(r"could not be parsed to a boolean", description), f"{description=}" + + 
# Invalid `archived` parameter (not a boolean) + res = client.get(STUDIES_URL, headers=headers, params={"archived": "invalid"}) + assert res.status_code == INVALID_PARAMS_STATUS_CODE, res.json() + description = res.json()["description"] + assert re.search(r"could not be parsed to a boolean", description), f"{description=}" + + # Invalid `variant` parameter (not a boolean) + res = client.get(STUDIES_URL, headers=headers, params={"variant": "invalid"}) + assert res.status_code == INVALID_PARAMS_STATUS_CODE, res.json() + description = res.json()["description"] + assert re.search(r"could not be parsed to a boolean", description), f"{description=}" + + # Invalid `versions` parameter (not a list of integers) + res = client.get(STUDIES_URL, headers=headers, params={"versions": "invalid"}) + assert res.status_code == INVALID_PARAMS_STATUS_CODE, res.json() + description = res.json()["description"] + assert re.search(r"string does not match regex", description), f"{description=}" + + # Invalid `users` parameter (not a list of integers) + res = client.get(STUDIES_URL, headers=headers, params={"users": "invalid"}) + assert res.status_code == INVALID_PARAMS_STATUS_CODE, res.json() + description = res.json()["description"] + assert re.search(r"string does not match regex", description), f"{description=}" + + # Invalid `exists` parameter (not a boolean) + res = client.get(STUDIES_URL, headers=headers, params={"exists": "invalid"}) + assert res.status_code == INVALID_PARAMS_STATUS_CODE, res.json() + description = res.json()["description"] + assert re.search(r"could not be parsed to a boolean", description), f"{description=}" diff --git a/tests/storage/repository/test_study.py b/tests/storage/repository/test_study.py index 03293ce16c..12afde7b05 100644 --- a/tests/storage/repository/test_study.py +++ b/tests/storage/repository/test_study.py @@ -67,9 +67,10 @@ def test_cyclelife(): c = repo.get(a.id) assert a == c - assert len(repo.get_all()) == 3 - assert 
len(repo.get_all_raw(show_missing=True)) == 2 - assert len(repo.get_all_raw(show_missing=False)) == 1 + assert len(repo.get_all()) == 4 + assert len(repo.get_all_raw(exists=True)) == 1 + assert len(repo.get_all_raw(exists=False)) == 1 + assert len(repo.get_all_raw()) == 2 repo.delete(a.id) assert repo.get(a.id) is None diff --git a/tests/storage/test_service.py b/tests/storage/test_service.py index b509445052..e7e8662394 100644 --- a/tests/storage/test_service.py +++ b/tests/storage/test_service.py @@ -44,7 +44,7 @@ TimeSerie, TimeSeriesData, ) -from antarest.study.repository import StudyMetadataRepository +from antarest.study.repository import StudyFilter, StudyMetadataRepository from antarest.study.service import MAX_MISSING_STUDY_TIMEOUT, StudyService, StudyUpgraderTask, UserHasNotPermissionError from antarest.study.storage.patch_service import PatchService from antarest.study.storage.rawstudy.model.filesystem.config.model import ( @@ -173,57 +173,73 @@ def test_study_listing(db_session: Session) -> None: repository = StudyMetadataRepository(cache_service=Mock(spec=ICache), session=db_session) service = build_study_service(raw_study_service, repository, config, cache_service=cache) + # retrieve studies that are not managed # use the db recorder to check that: # 1- retrieving studies information requires only 1 query # 2- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: studies = service.get_studies_information( - managed=False, - name=None, - workspace=None, - folder=None, + study_filter=StudyFilter( + managed=False, + ), params=RequestParameters(user=JWTUser(id=2, impersonator=2, type="users")), ) assert len(db_recorder.sql_statements) == 1, str(db_recorder) # verify that we get the expected studies information - expected_result = {e.id: e for e in map(lambda x: study_to_dto(x), [a, c])} + expected_result = {e.id: e for e in map(lambda x: study_to_dto(x), [c])} assert expected_result == studies - 
cache.get.return_value = {e.id: e for e in map(lambda x: study_to_dto(x), [a, b, c])} - # check that: - # 1- retrieving studies information requires no query at all (cache is used) - # 2- the `put` method of `cache` was used once + # retrieve managed studies + # use the db recorder to check that: + # 1- retrieving studies information requires only 1 query + # 2- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: studies = service.get_studies_information( - managed=False, - name=None, - workspace=None, - folder=None, + study_filter=StudyFilter( + managed=True, + ), params=RequestParameters(user=JWTUser(id=2, impersonator=2, type="users")), ) - assert len(db_recorder.sql_statements) == 0, str(db_recorder) - cache.put.assert_called_once() + assert len(db_recorder.sql_statements) == 1, str(db_recorder) # verify that we get the expected studies information + expected_result = {e.id: e for e in map(lambda x: study_to_dto(x), [a])} assert expected_result == studies - cache.get.return_value = None + # retrieve studies regardless of whether they are managed or not # use the db recorder to check that: - # 1- retrieving studies information requires only 1 query (cache reset to None) + # 1- retrieving studies information requires only 1 query # 2- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: studies = service.get_studies_information( - managed=True, - name=None, - workspace=None, - folder=None, + study_filter=StudyFilter( + managed=None, + ), params=RequestParameters(user=JWTUser(id=2, impersonator=2, type="users")), ) assert len(db_recorder.sql_statements) == 1, str(db_recorder) # verify that we get the expected studies information - expected_result = {e.id: e for e in map(lambda x: study_to_dto(x), [a])} + expected_result = {e.id: e for e in map(lambda x: study_to_dto(x), [a, c])} + assert expected_result == studies + + # in previous versions cache was used, verify 
that it is not used anymore + # check that: + # 1- retrieving studies information still requires 1 query + # 2- the `put` method of `cache` was never used + with DBStatementRecorder(db_session.bind) as db_recorder: + studies = service.get_studies_information( + study_filter=StudyFilter( + managed=None, + ), + params=RequestParameters(user=JWTUser(id=2, impersonator=2, type="users")), + ) + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + with contextlib.suppress(AssertionError): + cache.put.assert_any_call() + + # verify that we get the expected studies information assert expected_result == studies diff --git a/tests/study/test_repository.py b/tests/study/test_repository.py index be1a763602..77e4f1554c 100644 --- a/tests/study/test_repository.py +++ b/tests/study/test_repository.py @@ -1,42 +1,50 @@ +import datetime import typing as t -from datetime import datetime from unittest.mock import Mock import pytest from sqlalchemy.orm import Session # type: ignore from antarest.core.interfaces.cache import ICache +from antarest.login.model import Group, User from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy -from antarest.study.repository import StudyMetadataRepository +from antarest.study.repository import StudyFilter, StudyMetadataRepository from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy from tests.db_statement_recorder import DBStatementRecorder @pytest.mark.parametrize( - "managed, studies_ids, exists, expected_ids", + "managed, study_ids, exists, expected_ids", [ - (None, None, False, {"1", "2", "3", "4", "5", "6", "7", "8"}), - (None, None, True, {"1", "2", "3", "4", "7", "8"}), - (None, [1, 3, 5, 7], False, {"1", "3", "5", "7"}), + (None, [], False, {"5", "6"}), + (None, [], True, {"1", "2", "3", "4", "7", "8"}), + (None, [], None, {"1", "2", "3", "4", "5", "6", "7", "8"}), + (None, [1, 3, 5, 7], False, {"5"}), (None, [1, 3, 5, 7], True, {"1", "3", "7"}), - (True, None, False, {"1", "2", "3", "4", "5", 
"8"}), - (True, None, True, {"1", "2", "3", "4", "8"}), - (True, [1, 3, 5, 7], False, {"1", "3", "5"}), + (None, [1, 3, 5, 7], None, {"1", "3", "5", "7"}), + (True, [], False, {"5"}), + (True, [], True, {"1", "2", "3", "4", "8"}), + (True, [], None, {"1", "2", "3", "4", "5", "8"}), + (True, [1, 3, 5, 7], False, {"5"}), (True, [1, 3, 5, 7], True, {"1", "3"}), + (True, [1, 3, 5, 7], None, {"1", "3", "5"}), (True, [2, 4, 6, 8], True, {"2", "4", "8"}), - (False, None, False, {"6", "7"}), - (False, None, True, {"7"}), - (False, [1, 3, 5, 7], False, {"7"}), + (True, [2, 4, 6, 8], None, {"2", "4", "8"}), + (False, [], False, {"6"}), + (False, [], True, {"7"}), + (False, [], None, {"6", "7"}), + (False, [1, 3, 5, 7], False, set()), (False, [1, 3, 5, 7], True, {"7"}), + (False, [1, 3, 5, 7], None, {"7"}), ], ) -def test_repository_get_all( +def test_repository_get_all__general_case( db_session: Session, managed: t.Union[bool, None], - studies_ids: t.Union[t.List[str], None], - exists: bool, - expected_ids: set, -): + study_ids: t.List[str], + exists: t.Union[bool, None], + expected_ids: t.Set[str], +) -> None: test_workspace = "test-repository" icache: Mock = Mock(spec=ICache) repository = StudyMetadataRepository(cache_service=icache, session=db_session) @@ -45,8 +53,8 @@ def test_repository_get_all( study_2 = VariantStudy(id=2) study_3 = VariantStudy(id=3) study_4 = VariantStudy(id=4) - study_5 = RawStudy(id=5, missing=datetime.now(), workspace=DEFAULT_WORKSPACE_NAME) - study_6 = RawStudy(id=6, missing=datetime.now(), workspace=test_workspace) + study_5 = RawStudy(id=5, missing=datetime.datetime.now(), workspace=DEFAULT_WORKSPACE_NAME) + study_6 = RawStudy(id=6, missing=datetime.datetime.now(), workspace=test_workspace) study_7 = RawStudy(id=7, missing=None, workspace=test_workspace) study_8 = RawStudy(id=8, missing=None, workspace=DEFAULT_WORKSPACE_NAME) @@ -58,11 +66,530 @@ def test_repository_get_all( # 2- accessing studies attributes does require additional queries to 
db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(managed=managed, studies_ids=studies_ids, exists=exists) + all_studies = repository.get_all(study_filter=StudyFilter(managed=managed, study_ids=study_ids, exists=exists)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] assert len(db_recorder.sql_statements) == 1, str(db_recorder) if expected_ids is not None: - assert set([s.id for s in all_studies]) == expected_ids + assert {s.id for s in all_studies} == expected_ids + + +def test_repository_get_all__incompatible_case( + db_session: Session, +) -> None: + test_workspace = "workspace1" + icache: Mock = Mock(spec=ICache) + repository = StudyMetadataRepository(cache_service=icache, session=db_session) + + study_1 = VariantStudy(id=1) + study_2 = VariantStudy(id=2) + study_3 = VariantStudy(id=3) + study_4 = VariantStudy(id=4) + study_5 = RawStudy(id=5, missing=datetime.datetime.now(), workspace=DEFAULT_WORKSPACE_NAME) + study_6 = RawStudy(id=6, missing=datetime.datetime.now(), workspace=test_workspace) + study_7 = RawStudy(id=7, missing=None, workspace=test_workspace) + study_8 = RawStudy(id=8, missing=None, workspace=DEFAULT_WORKSPACE_NAME) + + db_session.add_all([study_1, study_2, study_3, study_4, study_5, study_6, study_7, study_8]) + db_session.commit() + + # case 1 + study_filter = StudyFilter(managed=False, variant=True) + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all(study_filter=study_filter) + _ = [s.owner for s in all_studies] + _ = [s.groups for s in all_studies] + _ = [s.additional_data for s in all_studies] + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + assert not {s.id for s in all_studies} + + # case 2 + study_filter = StudyFilter(workspace=test_workspace, variant=True) + with DBStatementRecorder(db_session.bind) as db_recorder: 
+ all_studies = repository.get_all(study_filter=study_filter) + _ = [s.owner for s in all_studies] + _ = [s.groups for s in all_studies] + _ = [s.additional_data for s in all_studies] + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + assert not {s.id for s in all_studies} + + # case 3 + study_filter = StudyFilter(exists=False, variant=True) + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all(study_filter=study_filter) + _ = [s.owner for s in all_studies] + _ = [s.groups for s in all_studies] + _ = [s.additional_data for s in all_studies] + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + assert not {s.id for s in all_studies} + + +@pytest.mark.parametrize( + "name, expected_ids", + [ + ("", {"1", "2", "3", "4", "5", "6", "7", "8"}), + ("specie", {"1", "2", "3", "4", "5", "6", "7", "8"}), + ("prefix-specie", {"2", "3", "6", "7"}), + ("variant", {"1", "2", "3", "4"}), + ("variant-suffix", {"3", "4"}), + ("raw", {"5", "6", "7", "8"}), + ("raw-suffix", {"7", "8"}), + ("prefix-variant", set()), + ("specie-suffix", set()), + ], +) +def test_repository_get_all__study_name_filter( + db_session: Session, + name: str, + expected_ids: t.Set[str], +) -> None: + icache: Mock = Mock(spec=ICache) + repository = StudyMetadataRepository(cache_service=icache, session=db_session) + + study_1 = VariantStudy(id=1, name="specie-variant") + study_2 = VariantStudy(id=2, name="prefix-specie-variant") + study_3 = VariantStudy(id=3, name="prefix-specie-variant-suffix") + study_4 = VariantStudy(id=4, name="specie-variant-suffix") + study_5 = RawStudy(id=5, name="specie-raw") + study_6 = RawStudy(id=6, name="prefix-specie-raw") + study_7 = RawStudy(id=7, name="prefix-specie-raw-suffix") + study_8 = RawStudy(id=8, name="specie-raw-suffix") + + db_session.add_all([study_1, study_2, study_3, study_4, study_5, study_6, study_7, study_8]) + db_session.commit() + + # use the db recorder to check that: + # 1- retrieving all 
studies requires only 1 query + # 2- accessing studies attributes does require additional queries to db + # 3- having an exact total of queries equals to 1 + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all(study_filter=StudyFilter(name=name)) + _ = [s.owner for s in all_studies] + _ = [s.groups for s in all_studies] + _ = [s.additional_data for s in all_studies] + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + + if expected_ids is not None: + assert {s.id for s in all_studies} == expected_ids + + +@pytest.mark.parametrize( + "managed, expected_ids", + [ + (None, {"1", "2", "3", "4", "5", "6", "7", "8"}), + (True, {"1", "2", "3", "4", "5", "8"}), + (False, {"6", "7"}), + ], +) +def test_repository_get_all__managed_study_filter( + db_session: Session, + managed: t.Optional[bool], + expected_ids: t.Set[str], +) -> None: + test_workspace = "test-workspace" + icache: Mock = Mock(spec=ICache) + repository = StudyMetadataRepository(cache_service=icache, session=db_session) + + study_1 = VariantStudy(id=1) + study_2 = VariantStudy(id=2) + study_3 = VariantStudy(id=3) + study_4 = VariantStudy(id=4) + study_5 = RawStudy(id=5, workspace=DEFAULT_WORKSPACE_NAME) + study_6 = RawStudy(id=6, workspace=test_workspace) + study_7 = RawStudy(id=7, workspace=test_workspace) + study_8 = RawStudy(id=8, workspace=DEFAULT_WORKSPACE_NAME) + + db_session.add_all([study_1, study_2, study_3, study_4, study_5, study_6, study_7, study_8]) + db_session.commit() + + # use the db recorder to check that: + # 1- retrieving all studies requires only 1 query + # 2- accessing studies attributes does require additional queries to db + # 3- having an exact total of queries equals to 1 + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all(study_filter=StudyFilter(managed=managed)) + _ = [s.owner for s in all_studies] + _ = [s.groups for s in all_studies] + _ = [s.additional_data for s in all_studies] + 
assert len(db_recorder.sql_statements) == 1, str(db_recorder)
+
+    if expected_ids is not None:
+        assert {s.id for s in all_studies} == expected_ids
+
+
+@pytest.mark.parametrize(
+    "archived, expected_ids",
+    [
+        (None, {"1", "2", "3", "4"}),
+        (True, {"1", "3"}),
+        (False, {"2", "4"}),
+    ],
+)
+def test_repository_get_all__archived_study_filter(
+    db_session: Session,
+    archived: t.Optional[bool],
+    expected_ids: t.Set[str],
+) -> None:
+    icache: Mock = Mock(spec=ICache)
+    repository = StudyMetadataRepository(cache_service=icache, session=db_session)
+
+    study_1 = VariantStudy(id=1, archived=True)
+    study_2 = VariantStudy(id=2, archived=False)
+    study_3 = RawStudy(id=3, archived=True)
+    study_4 = RawStudy(id=4, archived=False)
+
+    db_session.add_all([study_1, study_2, study_3, study_4])
+    db_session.commit()
+
+    # use the db recorder to check that:
+    # 1- retrieving all studies requires only 1 query
+    # 2- accessing studies attributes does not require additional queries to db
+    # 3- having an exact total of queries equals to 1
+    with DBStatementRecorder(db_session.bind) as db_recorder:
+        all_studies = repository.get_all(study_filter=StudyFilter(archived=archived))
+        _ = [s.owner for s in all_studies]
+        _ = [s.groups for s in all_studies]
+        _ = [s.additional_data for s in all_studies]
+    assert len(db_recorder.sql_statements) == 1, str(db_recorder)
+
+    if expected_ids is not None:
+        assert {s.id for s in all_studies} == expected_ids
+
+
+@pytest.mark.parametrize(
+    "variant, expected_ids",
+    [
+        (None, {"1", "2", "3", "4"}),
+        (True, {"1", "2"}),
+        (False, {"3", "4"}),
+    ],
+)
+def test_repository_get_all__variant_study_filter(
+    db_session: Session,
+    variant: t.Optional[bool],
+    expected_ids: t.Set[str],
+) -> None:
+    icache: Mock = Mock(spec=ICache)
+    repository = StudyMetadataRepository(cache_service=icache, session=db_session)
+
+    study_1 = VariantStudy(id=1)
+    study_2 = VariantStudy(id=2)
+    study_3 = RawStudy(id=3)
+    study_4 = RawStudy(id=4)
+
+    
db_session.add_all([study_1, study_2, study_3, study_4]) + db_session.commit() + + # use the db recorder to check that: + # 1- retrieving all studies requires only 1 query + # 2- accessing studies attributes does require additional queries to db + # 3- having an exact total of queries equals to 1 + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all(study_filter=StudyFilter(variant=variant)) + _ = [s.owner for s in all_studies] + _ = [s.groups for s in all_studies] + _ = [s.additional_data for s in all_studies] + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + + if expected_ids is not None: + assert {s.id for s in all_studies} == expected_ids + + +@pytest.mark.parametrize( + "versions, expected_ids", + [ + ([], {"1", "2", "3", "4"}), + (["1", "2"], {"1", "2", "3", "4"}), + (["1"], {"1", "3"}), + (["2"], {"2", "4"}), + (["3"], set()), + ], +) +def test_repository_get_all__study_version_filter( + db_session: Session, + versions: t.List[str], + expected_ids: t.Set[str], +) -> None: + icache: Mock = Mock(spec=ICache) + repository = StudyMetadataRepository(cache_service=icache, session=db_session) + + study_1 = VariantStudy(id=1, version="1") + study_2 = VariantStudy(id=2, version="2") + study_3 = RawStudy(id=3, version="1") + study_4 = RawStudy(id=4, version="2") + + db_session.add_all([study_1, study_2, study_3, study_4]) + db_session.commit() + + # use the db recorder to check that: + # 1- retrieving all studies requires only 1 query + # 2- accessing studies attributes does require additional queries to db + # 3- having an exact total of queries equals to 1 + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all(study_filter=StudyFilter(versions=versions)) + _ = [s.owner for s in all_studies] + _ = [s.groups for s in all_studies] + _ = [s.additional_data for s in all_studies] + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + + if expected_ids is not None: + 
assert {s.id for s in all_studies} == expected_ids + + +@pytest.mark.parametrize( + "users, expected_ids", + [ + ([], {"1", "2", "3", "4"}), + (["1000", "2000"], {"1", "2", "3", "4"}), + (["1000"], {"1", "3"}), + (["2000"], {"2", "4"}), + (["3000"], set()), + ], +) +def test_repository_get_all__study_users_filter( + db_session: Session, + users: t.List["int"], + expected_ids: t.Set[str], +) -> None: + icache: Mock = Mock(spec=ICache) + repository = StudyMetadataRepository(cache_service=icache, session=db_session) + + test_user_1 = User(id=1000) + test_user_2 = User(id=2000) + + study_1 = VariantStudy(id=1, owner=test_user_1) + study_2 = VariantStudy(id=2, owner=test_user_2) + study_3 = RawStudy(id=3, owner=test_user_1) + study_4 = RawStudy(id=4, owner=test_user_2) + + db_session.add_all([test_user_1, test_user_2]) + db_session.commit() + + db_session.add_all([study_1, study_2, study_3, study_4]) + db_session.commit() + + # use the db recorder to check that: + # 1- retrieving all studies requires only 1 query + # 2- accessing studies attributes does require additional queries to db + # 3- having an exact total of queries equals to 1 + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all(study_filter=StudyFilter(users=users)) + _ = [s.owner for s in all_studies] + _ = [s.groups for s in all_studies] + _ = [s.additional_data for s in all_studies] + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + + if expected_ids is not None: + assert {s.id for s in all_studies} == expected_ids + + +@pytest.mark.parametrize( + "groups, expected_ids", + [ + ([], {"1", "2", "3", "4"}), + (["1000", "2000"], {"1", "2", "3", "4"}), + (["1000"], {"1", "2", "4"}), + (["2000"], {"2", "3"}), + (["3000"], set()), + ], +) +def test_repository_get_all__study_groups_filter( + db_session: Session, + groups: t.List[str], + expected_ids: t.Set[str], +) -> None: + icache: Mock = Mock(spec=ICache) + repository = 
StudyMetadataRepository(cache_service=icache, session=db_session) + + test_group_1 = Group(id=1000) + test_group_2 = Group(id=2000) + + study_1 = VariantStudy(id=1, groups=[test_group_1]) + study_2 = VariantStudy(id=2, groups=[test_group_1, test_group_2]) + study_3 = RawStudy(id=3, groups=[test_group_2]) + study_4 = RawStudy(id=4, groups=[test_group_1]) + + db_session.add_all([test_group_1, test_group_2]) + db_session.commit() + + db_session.add_all([study_1, study_2, study_3, study_4]) + db_session.commit() + + # use the db recorder to check that: + # 1- retrieving all studies requires only 1 query + # 2- accessing studies attributes does require additional queries to db + # 3- having an exact total of queries equals to 1 + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all(study_filter=StudyFilter(groups=groups)) + _ = [s.owner for s in all_studies] + _ = [s.groups for s in all_studies] + _ = [s.additional_data for s in all_studies] + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + + if expected_ids is not None: + assert {s.id for s in all_studies} == expected_ids + + +@pytest.mark.parametrize( + "study_ids, expected_ids", + [ + ([], {"1", "2", "3", "4"}), + (["1", "2", "3", "4"], {"1", "2", "3", "4"}), + (["1", "2", "4"], {"1", "2", "4"}), + (["2", "3"], {"2", "3"}), + (["2"], {"2"}), + (["3000"], set()), + ], +) +def test_repository_get_all__study_ids_filter( + db_session: Session, + study_ids: t.List[str], + expected_ids: t.Set[str], +) -> None: + icache: Mock = Mock(spec=ICache) + repository = StudyMetadataRepository(cache_service=icache, session=db_session) + + study_1 = VariantStudy(id=1) + study_2 = VariantStudy(id=2) + study_3 = RawStudy(id=3) + study_4 = RawStudy(id=4) + + db_session.add_all([study_1, study_2, study_3, study_4]) + db_session.commit() + + # use the db recorder to check that: + # 1- retrieving all studies requires only 1 query + # 2- accessing studies attributes does require 
additional queries to db + # 3- having an exact total of queries equals to 1 + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all(study_filter=StudyFilter(study_ids=study_ids)) + _ = [s.owner for s in all_studies] + _ = [s.groups for s in all_studies] + _ = [s.additional_data for s in all_studies] + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + + if expected_ids is not None: + assert {s.id for s in all_studies} == expected_ids + + +@pytest.mark.parametrize( + "exists, expected_ids", + [ + (None, {"1", "2", "3", "4"}), + (True, {"1", "2", "4"}), + (False, {"3"}), + ], +) +def test_repository_get_all__study_existence_filter( + db_session: Session, + exists: t.Optional[bool], + expected_ids: t.Set[str], +) -> None: + icache: Mock = Mock(spec=ICache) + repository = StudyMetadataRepository(cache_service=icache, session=db_session) + + study_1 = VariantStudy(id=1) + study_2 = VariantStudy(id=2) + study_3 = RawStudy(id=3, missing=datetime.datetime.now()) + study_4 = RawStudy(id=4) + + db_session.add_all([study_1, study_2, study_3, study_4]) + db_session.commit() + + # use the db recorder to check that: + # 1- retrieving all studies requires only 1 query + # 2- accessing studies attributes does require additional queries to db + # 3- having an exact total of queries equals to 1 + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all(study_filter=StudyFilter(exists=exists)) + _ = [s.owner for s in all_studies] + _ = [s.groups for s in all_studies] + _ = [s.additional_data for s in all_studies] + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + + if expected_ids is not None: + assert {s.id for s in all_studies} == expected_ids + + +@pytest.mark.parametrize( + "workspace, expected_ids", + [ + ("", {"1", "2", "3", "4"}), + ("workspace-1", {"3"}), + ("workspace-2", {"4"}), + ("workspace-3", set()), + ], +) +def test_repository_get_all__study_workspace_filter( + 
db_session: Session, + workspace: str, + expected_ids: t.Set[str], +) -> None: + icache: Mock = Mock(spec=ICache) + repository = StudyMetadataRepository(cache_service=icache, session=db_session) + + study_1 = VariantStudy(id=1) + study_2 = VariantStudy(id=2) + study_3 = RawStudy(id=3, workspace="workspace-1") + study_4 = RawStudy(id=4, workspace="workspace-2") + + db_session.add_all([study_1, study_2, study_3, study_4]) + db_session.commit() + + # use the db recorder to check that: + # 1- retrieving all studies requires only 1 query + # 2- accessing studies attributes does require additional queries to db + # 3- having an exact total of queries equals to 1 + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all(study_filter=StudyFilter(workspace=workspace)) + _ = [s.owner for s in all_studies] + _ = [s.groups for s in all_studies] + _ = [s.additional_data for s in all_studies] + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + + if expected_ids is not None: + assert {s.id for s in all_studies} == expected_ids + + +@pytest.mark.parametrize( + "folder, expected_ids", + [ + ("", {"1", "2", "3", "4"}), + ("/home/folder-", {"1", "2", "3", "4"}), + ("/home/folder-1", {"1", "3"}), + ("/home/folder-2", {"2", "4"}), + ("/home/folder-3", set()), + ("folder-1", set()), + ], +) +def test_repository_get_all__study_folder_filter( + db_session: Session, + folder: str, + expected_ids: t.Set[str], +) -> None: + icache: Mock = Mock(spec=ICache) + repository = StudyMetadataRepository(cache_service=icache, session=db_session) + + study_1 = VariantStudy(id=1, folder="/home/folder-1") + study_2 = VariantStudy(id=2, folder="/home/folder-2") + study_3 = RawStudy(id=3, folder="/home/folder-1") + study_4 = RawStudy(id=4, folder="/home/folder-2") + + db_session.add_all([study_1, study_2, study_3, study_4]) + db_session.commit() + + # use the db recorder to check that: + # 1- retrieving all studies requires only 1 query + # 2- accessing 
studies attributes does require additional queries to db + # 3- having an exact total of queries equals to 1 + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all(study_filter=StudyFilter(folder=folder)) + _ = [s.owner for s in all_studies] + _ = [s.groups for s in all_studies] + _ = [s.additional_data for s in all_studies] + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + + if expected_ids is not None: + assert {s.id for s in all_studies} == expected_ids