diff --git a/antarest/study/repository.py b/antarest/study/repository.py index 3aa6e60681..f1caf703a6 100644 --- a/antarest/study/repository.py +++ b/antarest/study/repository.py @@ -7,6 +7,7 @@ from sqlalchemy.orm import Session, joinedload, with_polymorphic # type: ignore from antarest.core.interfaces.cache import ICache +from antarest.core.model import PublicMode from antarest.core.utils.fastapi_sqlalchemy import db from antarest.login.model import Group from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, Study, StudyAdditionalData, Tag @@ -66,6 +67,17 @@ class StudyFilter(BaseModel, frozen=True, extra="forbid"): folder: str = "" +class QueryUser(BaseModel, frozen=True, extra="forbid"): + """ + This class object is build to pass on the user identity and its associated groups information + into the listing function get_all below + """ + + is_admin: bool = False + user_id: t.Optional[int] = None + user_groups: t.Optional[t.Sequence[str]] = None + + class StudySortBy(str, enum.Enum): """How to sort the results of studies query results""" @@ -182,6 +194,7 @@ def get_all( study_filter: StudyFilter = StudyFilter(), sort_by: t.Optional[StudySortBy] = None, pagination: StudyPagination = StudyPagination(), + query_user: QueryUser = QueryUser(), ) -> t.Sequence[Study]: """ Retrieve studies based on specified filters, sorting, and pagination. @@ -190,6 +203,7 @@ def get_all( study_filter: composed of all filtering criteria. sort_by: how the user would like the results to be sorted. pagination: specifies the number of results to displayed in each page and the actually displayed page. + query_user: user id and groups info Returns: The matching studies in proper order and pagination. @@ -243,6 +257,21 @@ def get_all( if study_filter.versions: q = q.filter(entity.version.in_(study_filter.versions)) + # permissions filtering + if not query_user.is_admin: + if query_user.user_id is not None: + condition_1 = entity.public_mode != PublicMode.NONE + condition_2 = entity.owner_id == query_user.user_id + condition_3 = Group.id.in_(query_user.user_groups or []) + if study_filter.groups: + q0 = q.filter(condition_3) + q = q0.union(q.filter(or_(condition_1, condition_2))) + else: + q0 = q.join(entity.groups).filter(condition_3) + q = q0.union(q.filter(or_(condition_1, condition_2))) + else: + return [] + if sort_by: if sort_by == StudySortBy.DATE_DESC: q = q.order_by(entity.created_at.desc()) diff --git a/antarest/study/service.py b/antarest/study/service.py index ae86fe62ae..ab60291cb5 100644 --- a/antarest/study/service.py +++ b/antarest/study/service.py @@ -92,7 +92,7 @@ StudyMetadataPatchDTO, StudySimResultDTO, ) -from antarest.study.repository import StudyFilter, StudyMetadataRepository, StudyPagination, StudySortBy +from antarest.study.repository import QueryUser, StudyFilter, StudyMetadataRepository, StudyPagination, StudySortBy from antarest.study.storage.rawstudy.model.filesystem.config.model import FileStudyTreeConfigDTO from antarest.study.storage.rawstudy.model.filesystem.folder_node import ChildNotFoundError from antarest.study.storage.rawstudy.model.filesystem.ini_file_node import IniFileNode @@ -456,10 +456,29 @@ def get_studies_information( """ logger.info("Retrieving matching studies") studies: t.Dict[str, StudyMetadataDTO] = {} + + # retrieve user id and groups + user_id = None + user_groups = None + if params.user: + if params.user.id: + user_id = params.user.id + if params.user.groups: + user_groups = [group.id for group in params.user.groups] + else: + logger.error("FAIL permission: user is not logged") + raise UserHasNotPermissionError() + + query_user = QueryUser( + is_admin=params.user.is_site_admin() or params.user.is_admin_token(), + user_id=user_id, + user_groups=user_groups, + ) matching_studies = self.repository.get_all( study_filter=study_filter, sort_by=sort_by, pagination=pagination, + query_user=query_user, ) logger.info("Studies retrieved") for study in matching_studies: @@ -697,7 +716,7 @@ def get_input_matrix_startdate( def remove_duplicates(self) -> None: study_paths: t.Dict[str, t.List[str]] = {} - for study in self.repository.get_all(): + for study in self.repository.get_all(query_user=QueryUser(is_admin=True)): if isinstance(study, RawStudy) and not study.archived: path = str(study.path) if path not in study_paths: @@ -2140,7 +2159,7 @@ def check_and_update_all_study_versions_in_database(self, params: RequestParamet if params.user and not params.user.is_site_admin(): logger.error(f"User {params.user.id} is not site admin") raise UserHasNotPermissionError() - studies = self.repository.get_all(study_filter=StudyFilter(managed=False)) + studies = self.repository.get_all(study_filter=StudyFilter(managed=False), query_user=QueryUser(is_amin=True)) for study in studies: storage = self.storage_service.raw_study_service diff --git a/antarest/study/storage/auto_archive_service.py b/antarest/study/storage/auto_archive_service.py index 911b715f2d..f8e4bdee75 100644 --- a/antarest/study/storage/auto_archive_service.py +++ b/antarest/study/storage/auto_archive_service.py @@ -10,7 +10,7 @@ from antarest.core.requests import RequestParameters from antarest.core.utils.fastapi_sqlalchemy import db from antarest.study.model import RawStudy, Study -from antarest.study.repository import StudyFilter +from antarest.study.repository import QueryUser, StudyFilter from antarest.study.service import StudyService from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy @@ -28,7 +28,9 @@ def __init__(self, study_service: StudyService, config: Config): def _try_archive_studies(self) -> None: old_date = datetime.datetime.utcnow() - datetime.timedelta(days=self.config.storage.auto_archive_threshold_days) with db(): - studies: t.Sequence[Study] = self.study_service.repository.get_all(study_filter=StudyFilter(managed=True)) + studies: t.Sequence[Study] = self.study_service.repository.get_all( + study_filter=StudyFilter(managed=True), query_user=QueryUser(is_admin=True) + ) # list of study IDs and boolean indicating if it's a raw study (True) or a variant (False) study_ids_to_archive = [ (study.id, isinstance(study, RawStudy)) diff --git a/tests/storage/repository/test_study.py b/tests/storage/repository/test_study.py index f865ab613a..7de63bc508 100644 --- a/tests/storage/repository/test_study.py +++ b/tests/storage/repository/test_study.py @@ -4,7 +4,7 @@ from antarest.core.model import PublicMode from antarest.login.model import Group, User from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, Study, StudyContentStatus -from antarest.study.repository import StudyMetadataRepository +from antarest.study.repository import QueryUser, StudyMetadataRepository from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy from tests.helpers import with_db_context @@ -64,7 +64,7 @@ def test_lifecycle() -> None: c = repo.one(a.id) assert a == c - assert len(repo.get_all()) == 4 + assert len(repo.get_all(query_user=QueryUser(is_admin=True))) == 4 assert len(repo.get_all_raw(exists=True)) == 1 assert len(repo.get_all_raw(exists=False)) == 1 assert len(repo.get_all_raw()) == 2 diff --git a/tests/study/test_repository.py b/tests/study/test_repository.py index d30c051a6a..f4041d46b4 100644 --- a/tests/study/test_repository.py +++ b/tests/study/test_repository.py @@ -8,7 +8,7 @@ from antarest.core.interfaces.cache import ICache from antarest.login.model import Group, User from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, Tag -from antarest.study.repository import StudyFilter, StudyMetadataRepository +from antarest.study.repository import QueryUser, StudyFilter, StudyMetadataRepository from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy from tests.db_statement_recorder import DBStatementRecorder @@ -66,7 +66,10 @@ def test_repository_get_all__general_case( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(managed=managed, study_ids=study_ids, exists=exists)) + all_studies = repository.get_all( + study_filter=StudyFilter(managed=managed, study_ids=study_ids, exists=exists), + query_user=QueryUser(is_admin=True), + ) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -99,7 +102,7 @@ def test_repository_get_all__incompatible_case( # case 1 study_filter = StudyFilter(managed=False, variant=True) with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=study_filter) + all_studies = repository.get_all(study_filter=study_filter, query_user=QueryUser(is_admin=True)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -110,7 +113,7 @@ def test_repository_get_all__incompatible_case( # case 2 study_filter = StudyFilter(workspace=test_workspace, variant=True) with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=study_filter) + all_studies = repository.get_all(study_filter=study_filter, query_user=QueryUser(is_admin=True)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -121,7 +124,7 @@ def test_repository_get_all__incompatible_case( # case 3 study_filter = StudyFilter(exists=False, variant=True) with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=study_filter) + all_studies = repository.get_all(study_filter=study_filter, query_user=QueryUser(is_admin=True)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -169,7 +172,7 @@ def test_repository_get_all__study_name_filter( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(name=name)) + all_studies = repository.get_all(study_filter=StudyFilter(name=name), query_user=QueryUser(is_admin=True)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -214,7 +217,7 @@ def test_repository_get_all__managed_study_filter( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(managed=managed)) + all_studies = repository.get_all(study_filter=StudyFilter(managed=managed), query_user=QueryUser(is_admin=True)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -254,7 +257,9 @@ def test_repository_get_all__archived_study_filter( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(archived=archived)) + all_studies = repository.get_all( + study_filter=StudyFilter(archived=archived), query_user=QueryUser(is_admin=True) + ) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -294,7 +299,7 @@ def test_repository_get_all__variant_study_filter( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(variant=variant)) + all_studies = repository.get_all(study_filter=StudyFilter(variant=variant), query_user=QueryUser(is_admin=True)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -336,7 +341,9 @@ def test_repository_get_all__study_version_filter( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(versions=versions)) + all_studies = repository.get_all( + study_filter=StudyFilter(versions=versions), query_user=QueryUser(is_admin=True) + ) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -384,7 +391,7 @@ def test_repository_get_all__study_users_filter( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(users=users)) + all_studies = repository.get_all(study_filter=StudyFilter(users=users), query_user=QueryUser(is_admin=True)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -432,7 +439,7 @@ def test_repository_get_all__study_groups_filter( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(groups=groups)) + all_studies = repository.get_all(study_filter=StudyFilter(groups=groups), query_user=QueryUser(is_admin=True)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -475,7 +482,9 @@ def test_repository_get_all__study_ids_filter( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(study_ids=study_ids)) + all_studies = repository.get_all( + study_filter=StudyFilter(study_ids=study_ids), query_user=QueryUser(is_admin=True) + ) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -515,7 +524,7 @@ def test_repository_get_all__study_existence_filter( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(exists=exists)) + all_studies = repository.get_all(study_filter=StudyFilter(exists=exists), query_user=QueryUser(is_admin=True)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -556,7 +565,9 @@ def test_repository_get_all__study_workspace_filter( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(workspace=workspace)) + all_studies = repository.get_all( + study_filter=StudyFilter(workspace=workspace), query_user=QueryUser(is_admin=True) + ) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -599,7 +610,7 @@ def test_repository_get_all__study_folder_filter( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(folder=folder)) + all_studies = repository.get_all(study_filter=StudyFilter(folder=folder), query_user=QueryUser(is_admin=True)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -650,7 +661,7 @@ def test_repository_get_all__study_tags_filter( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(tags=tags)) + all_studies = repository.get_all(study_filter=StudyFilter(tags=tags), query_user=QueryUser(is_admin=True)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies]