From e8a0d6ffc59aa3ff40664080d99362e188cf48eb Mon Sep 17 00:00:00 2001 From: mabw-rte <41002227+mabw-rte@users.noreply.github.com> Date: Fri, 23 Feb 2024 11:35:49 +0100 Subject: [PATCH] feat(permission-db): check user permission through the db query (#1931) Context: Additionally to ANT-1107 related tags db management, another problem of permission filtering is blocking the pagination from working properly. Issue: The permission to read studies in the search engine function is performed after the db query. This is creating a problem for our pagination process. Solution: Pass on the permission status directly in queries to the db. --- antarest/launcher/service.py | 8 +- antarest/study/repository.py | 142 ++++- antarest/study/service.py | 39 +- .../study/storage/auto_archive_service.py | 7 +- antarest/study/storage/rawstudy/watcher.py | 1 + antarest/study/web/studies_blueprint.py | 9 +- .../studies_blueprint/test_get_studies.py | 599 ++++++++++++++++++ tests/storage/repository/test_study.py | 4 +- tests/storage/test_service.py | 23 +- tests/study/test_repository.py | 373 ++++++++++- 10 files changed, 1111 insertions(+), 94 deletions(-) diff --git a/antarest/launcher/service.py b/antarest/launcher/service.py index 4c4ea9aa15..3165df48c7 100644 --- a/antarest/launcher/service.py +++ b/antarest/launcher/service.py @@ -40,7 +40,7 @@ from antarest.launcher.repository import JobResultRepository from antarest.launcher.ssh_client import calculates_slurm_load from antarest.launcher.ssh_config import SSHConfigDTO -from antarest.study.repository import StudyFilter +from antarest.study.repository import AccessPermissions, StudyFilter from antarest.study.service import StudyService from antarest.study.storage.utils import assert_permission, extract_output_name, find_single_output_path @@ -312,7 +312,11 @@ def _filter_from_user_permission(self, job_results: List[JobResult], user: Optio if study_ids: studies = { study.id: study - for study in self.study_service.repository.get_all(study_filter=StudyFilter(study_ids=study_ids)) + for study in self.study_service.repository.get_all( + study_filter=StudyFilter( + study_ids=study_ids, access_permissions=AccessPermissions.from_params(user) + ) + ) } else: studies = {} diff --git a/antarest/study/repository.py b/antarest/study/repository.py index 25cdd77dd7..6e89d33d5a 100644 --- a/antarest/study/repository.py +++ b/antarest/study/repository.py @@ -3,10 +3,13 @@ import typing as t from pydantic import BaseModel, NonNegativeInt -from sqlalchemy import func, not_, or_ # type: ignore -from sqlalchemy.orm import Session, joinedload, with_polymorphic # type: ignore +from sqlalchemy import and_, func, not_, or_, sql # type: ignore +from sqlalchemy.orm import Query, Session, joinedload, with_polymorphic # type: ignore from antarest.core.interfaces.cache import ICache +from antarest.core.jwt import JWTUser +from antarest.core.model import PublicMode +from antarest.core.requests import RequestParameters from antarest.core.utils.fastapi_sqlalchemy import db from antarest.login.model import Group from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, Study, StudyAdditionalData, Tag @@ -34,6 +37,43 @@ def escape_like(string: str, escape_char: str = "\\") -> str: return string.replace(escape_char, escape_char * 2).replace("%", escape_char + "%").replace("_", escape_char + "_") +class AccessPermissions(BaseModel, frozen=True, extra="forbid"): + """ + This class object is build to pass on the user identity and its associated groups information + into the listing function get_all below + """ + + is_admin: bool = False + user_id: t.Optional[int] = None + user_groups: t.Sequence[str] = () + + @classmethod + def from_params(cls, params: t.Union[RequestParameters, JWTUser]) -> "AccessPermissions": + """ + This function makes it easier to pass on user ids and groups into the repository filtering function by + extracting the associated `AccessPermissions` object. + + Args: + params: `RequestParameters` or `JWTUser` holding user ids and groups + + Returns: `AccessPermissions` + + """ + if isinstance(params, RequestParameters): + user = params.user + else: + user = params + + if user: + return cls( + is_admin=user.is_site_admin() or user.is_admin_token(), + user_id=user.id, + user_groups=[group.id for group in user.groups], + ) + else: + return cls() + + class StudyFilter(BaseModel, frozen=True, extra="forbid"): """Study filter class gathering the main filtering parameters @@ -50,6 +90,7 @@ class StudyFilter(BaseModel, frozen=True, extra="forbid"): exists: if raw study missing workspace: optional workspace of the study folder: optional folder prefix of the study + access_permissions: query user ID, groups and admins status """ name: str = "" @@ -64,6 +105,7 @@ class StudyFilter(BaseModel, frozen=True, extra="forbid"): exists: t.Optional[bool] = None workspace: str = "" folder: str = "" + access_permissions: AccessPermissions = AccessPermissions() class StudySortBy(str, enum.Enum): @@ -198,6 +240,63 @@ def get_all( # efficiently (see: `AbstractStorageService.get_study_information`) entity = with_polymorphic(Study, "*") + q = self._search_studies(study_filter) + + # sorting + if sort_by: + if sort_by == StudySortBy.DATE_DESC: + q = q.order_by(entity.created_at.desc()) + elif sort_by == StudySortBy.DATE_ASC: + q = q.order_by(entity.created_at.asc()) + elif sort_by == StudySortBy.NAME_DESC: + q = q.order_by(func.upper(entity.name).desc()) + elif sort_by == StudySortBy.NAME_ASC: + q = q.order_by(func.upper(entity.name).asc()) + else: + raise NotImplementedError(sort_by) + + # pagination + if pagination.page_nb or pagination.page_size: + q = q.offset(pagination.page_nb * pagination.page_size).limit(pagination.page_size) + + studies: t.Sequence[Study] = q.all() + return studies + + def count_studies(self, study_filter: StudyFilter = StudyFilter()) -> int: + """ + Count all studies matching with specified filters. + + Args: + study_filter: composed of all filtering criteria. + + Returns: + Integer, corresponding to total number of studies matching with specified filters. + """ + q = self._search_studies(study_filter) + + total: int = q.count() + + return total + + def _search_studies( + self, + study_filter: StudyFilter, + ) -> Query: + """ + Build a `SQL Query` based on specified filters. + + Args: + study_filter: composed of all filtering criteria. + + Returns: + The `Query` corresponding to specified criteria (except for permissions). + """ + # When we fetch a study, we also need to fetch the associated owner and groups + # to check the permissions of the current user efficiently. + # We also need to fetch the additional data to display the study information + # efficiently (see: `AbstractStorageService.get_study_information`) + entity = with_polymorphic(Study, "*") + # noinspection PyTypeChecker q = self.session.query(entity) if study_filter.exists is not None: @@ -216,11 +315,9 @@ def get_all( q = q.filter(entity.type == "rawstudy") q = q.filter(RawStudy.workspace != DEFAULT_WORKSPACE_NAME) if study_filter.study_ids: - q = q.filter(entity.id.in_(study_filter.study_ids)) + q = q.filter(entity.id.in_(study_filter.study_ids)) if study_filter.study_ids else q if study_filter.users: q = q.filter(entity.owner_id.in_(study_filter.users)) - if study_filter.groups: - q = q.join(entity.groups).filter(Group.id.in_(study_filter.groups)) if study_filter.tags: upper_tags = [tag.upper() for tag in study_filter.tags] q = q.join(entity.tags).filter(func.upper(Tag.label).in_(upper_tags)) @@ -242,24 +339,27 @@ def get_all( if study_filter.versions: q = q.filter(entity.version.in_(study_filter.versions)) - if sort_by: - if sort_by == StudySortBy.DATE_DESC: - q = q.order_by(entity.created_at.desc()) - elif sort_by == StudySortBy.DATE_ASC: - q = q.order_by(entity.created_at.asc()) - elif sort_by == StudySortBy.NAME_DESC: - q = q.order_by(func.upper(entity.name).desc()) - elif sort_by == StudySortBy.NAME_ASC: - q = q.order_by(func.upper(entity.name).asc()) + # permissions + groups filtering + if not study_filter.access_permissions.is_admin and study_filter.access_permissions.user_id is not None: + condition_1 = entity.public_mode != PublicMode.NONE + condition_2 = entity.owner_id == study_filter.access_permissions.user_id + q1 = q.join(entity.groups).filter(Group.id.in_(study_filter.access_permissions.user_groups)) + if study_filter.groups: + q2 = q.join(entity.groups).filter(Group.id.in_(study_filter.groups)) + q2 = q1.intersect(q2) + q = q2.union( + q.join(entity.groups).filter(and_(or_(condition_1, condition_2), Group.id.in_(study_filter.groups))) + ) else: - raise NotImplementedError(sort_by) - - # pagination - if pagination.page_nb or pagination.page_size: - q = q.offset(pagination.page_nb * pagination.page_size).limit(pagination.page_size) + q = q1.union(q.filter(or_(condition_1, condition_2))) + elif not study_filter.access_permissions.is_admin and study_filter.access_permissions.user_id is None: + # return empty result + # noinspection PyTypeChecker + q = self.session.query(entity).filter(sql.false()) + elif study_filter.groups: + q = q.join(entity.groups).filter(Group.id.in_(study_filter.groups)) - studies: t.Sequence[Study] = q.all() - return studies + return q def get_all_raw(self, exists: t.Optional[bool] = None) -> t.Sequence[RawStudy]: query = self.session.query(RawStudy) diff --git a/antarest/study/service.py b/antarest/study/service.py index 9b22ae7638..81af28a473 100644 --- a/antarest/study/service.py +++ b/antarest/study/service.py @@ -96,7 +96,13 @@ StudyMetadataPatchDTO, StudySimResultDTO, ) -from antarest.study.repository import StudyFilter, StudyMetadataRepository, StudyPagination, StudySortBy +from antarest.study.repository import ( + AccessPermissions, + StudyFilter, + StudyMetadataRepository, + StudyPagination, + StudySortBy, +) from antarest.study.storage.matrix_profile import adjust_matrix_columns_index from antarest.study.storage.rawstudy.model.filesystem.config.model import FileStudyTreeConfigDTO from antarest.study.storage.rawstudy.model.filesystem.folder_node import ChildNotFoundError @@ -445,7 +451,6 @@ def edit_comments( def get_studies_information( self, - params: RequestParameters, study_filter: StudyFilter, sort_by: t.Optional[StudySortBy] = None, pagination: StudyPagination = StudyPagination(), @@ -453,7 +458,6 @@ def get_studies_information( """ Get information for matching studies of a search query. Args: - params: request parameters study_filter: filtering parameters sort_by: how to sort the db query results pagination: set offset and limit for db query @@ -472,18 +476,7 @@ def get_studies_information( study_metadata = self._try_get_studies_information(study) if study_metadata is not None: studies[study_metadata.id] = study_metadata - return { - s.id: s - for s in filter( - lambda study_dto: assert_permission( - params.user, - study_dto, - StudyPermissionType.READ, - raising=False, - ), - studies.values(), - ) - } + return studies def _try_get_studies_information(self, study: Study) -> t.Optional[StudyMetadataDTO]: try: @@ -2139,10 +2132,24 @@ def update_matrix( raise BadEditInstructionException(str(exc)) from exc def check_and_update_all_study_versions_in_database(self, params: RequestParameters) -> None: + """ + This function updates studies version on the db. + + **Warnings: Only users with Admins rights should be able to run this function.** + + Args: + params: Request parameters holding user ID and groups + + Raises: + UserHasNotPermissionError: if params user is not admin. + + """ if params.user and not params.user.is_site_admin(): logger.error(f"User {params.user.id} is not site admin") raise UserHasNotPermissionError() - studies = self.repository.get_all(study_filter=StudyFilter(managed=False)) + studies = self.repository.get_all( + study_filter=StudyFilter(managed=False, access_permissions=AccessPermissions.from_params(params)) + ) for study in studies: storage = self.storage_service.raw_study_service diff --git a/antarest/study/storage/auto_archive_service.py b/antarest/study/storage/auto_archive_service.py index 911b715f2d..a1eafc40a3 100644 --- a/antarest/study/storage/auto_archive_service.py +++ b/antarest/study/storage/auto_archive_service.py @@ -10,7 +10,7 @@ from antarest.core.requests import RequestParameters from antarest.core.utils.fastapi_sqlalchemy import db from antarest.study.model import RawStudy, Study -from antarest.study.repository import StudyFilter +from antarest.study.repository import AccessPermissions, StudyFilter from antarest.study.service import StudyService from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy @@ -28,7 +28,10 @@ def __init__(self, study_service: StudyService, config: Config): def _try_archive_studies(self) -> None: old_date = datetime.datetime.utcnow() - datetime.timedelta(days=self.config.storage.auto_archive_threshold_days) with db(): - studies: t.Sequence[Study] = self.study_service.repository.get_all(study_filter=StudyFilter(managed=True)) + # in this part full `Read` rights over studies are granted to this function + studies: t.Sequence[Study] = self.study_service.repository.get_all( + study_filter=StudyFilter(managed=True, access_permissions=AccessPermissions(is_admin=True)) + ) # list of study IDs and boolean indicating if it's a raw study (True) or a variant (False) study_ids_to_archive = [ (study.id, isinstance(study, RawStudy)) diff --git a/antarest/study/storage/rawstudy/watcher.py b/antarest/study/storage/rawstudy/watcher.py index d2b5c9883c..8f593ce0d6 100644 --- a/antarest/study/storage/rawstudy/watcher.py +++ b/antarest/study/storage/rawstudy/watcher.py @@ -94,6 +94,7 @@ def _loop(self) -> None: "Removing duplicates, this is a temporary fix that should be removed when previous duplicates are removed" ) with db(): + # in this part full `Read` rights over studies are granted to this function self.study_service.remove_duplicates() except Exception as e: logger.error("Unexpected error when removing duplicates", exc_info=e) diff --git a/antarest/study/web/studies_blueprint.py b/antarest/study/web/studies_blueprint.py index beeecd65c5..9007ed2894 100644 --- a/antarest/study/web/studies_blueprint.py +++ b/antarest/study/web/studies_blueprint.py @@ -15,7 +15,7 @@ from antarest.core.filetransfer.service import FileTransferManager from antarest.core.jwt import JWTUser from antarest.core.model import PublicMode -from antarest.core.requests import RequestParameters +from antarest.core.requests import RequestParameters, UserHasNotPermissionError from antarest.core.utils.utils import BadArchiveContent, sanitize_uuid from antarest.core.utils.web import APITag from antarest.login.auth import Auth @@ -28,7 +28,7 @@ StudyMetadataPatchDTO, StudySimResultDTO, ) -from antarest.study.repository import StudyFilter, StudyPagination, StudySortBy +from antarest.study.repository import AccessPermissions, StudyFilter, StudyPagination, StudySortBy from antarest.study.service import StudyService from antarest.study.storage.rawstudy.model.filesystem.config.model import FileStudyTreeConfigDTO @@ -144,6 +144,9 @@ def get_studies( user_list = [int(v) for v in _split_comma_separated_values(users)] + if not params.user: + raise UserHasNotPermissionError("FAIL permission: user is not logged") + study_filter = StudyFilter( name=name, managed=managed, @@ -157,10 +160,10 @@ def get_studies( exists=exists, workspace=workspace, folder=folder, + access_permissions=AccessPermissions.from_params(params), ) matching_studies = study_service.get_studies_information( - params=params, study_filter=study_filter, sort_by=sort_by, pagination=StudyPagination(page_nb=page_nb, page_size=page_size), diff --git a/tests/integration/studies_blueprint/test_get_studies.py b/tests/integration/studies_blueprint/test_get_studies.py index 579ad2dfe7..48dadaf829 100644 --- a/tests/integration/studies_blueprint/test_get_studies.py +++ b/tests/integration/studies_blueprint/test_get_studies.py @@ -11,6 +11,7 @@ from starlette.testclient import TestClient from antarest.core.model import PublicMode +from antarest.core.roles import RoleType from antarest.core.tasks.model import TaskStatus from tests.integration.assets import ASSETS_DIR from tests.integration.utils import wait_task_completion @@ -802,6 +803,604 @@ def test_study_listing( values = list(study_map.values()) assert values == sorted(values, key=lambda x: x["created"], reverse=True) + def test_get_studies__access_permissions(self, client: TestClient, admin_access_token: str) -> None: + """ + Test the access permissions for the `GET /studies` endpoint. + + Args: + client: client App fixture to perform the requests + admin_access_token: fixture to get the admin access token + + Returns: + + """ + ########################## + # 1. Database initialization + ########################## + + users = {"user_1": "pass_1", "user_2": "pass_2", "user_3": "pass_3"} + users_tokens = {} + users_ids = {} + groups = {"group_1", "group_2", "group_3"} + groups_ids = {} + user_groups_mapping = {"user_1": ["group_2"], "user_2": ["group_1"], "user_3": []} + + # create users + for user, password in users.items(): + res = client.post( + "/v1/users", + headers={"Authorization": f"Bearer {admin_access_token}"}, + json={"name": user, "password": password}, + ) + res.raise_for_status() + users_ids[user] = res.json().get("id") + + # create groups + for group in groups: + res = client.post( + "/v1/groups", + headers={"Authorization": f"Bearer {admin_access_token}"}, + json={"name": group}, + ) + res.raise_for_status() + groups_ids[group] = res.json().get("id") + + # associate users to groups + for user, groups in user_groups_mapping.items(): + user_id = users_ids[user] + for group in groups: + group_id = groups_ids[group] + res = client.post( + "/v1/roles", + headers={"Authorization": f"Bearer {admin_access_token}"}, + json={"identity_id": user_id, "group_id": group_id, "type": RoleType.READER.value}, + ) + res.raise_for_status() + + # login users + for user, password in users.items(): + res = client.post( + "/v1/login", + json={"username": user, "password": password}, + ) + res.raise_for_status() + assert res.json().get("user") == users_ids[user] + users_tokens[user] = res.json().get("access_token") + + # studies creation + studies_ids_mapping = {} + + # create variant studies for user_1 and user_2 that are part of some groups + # studies that have owner and groups + for study, study_info in { + "study_1": {"owner": "user_1", "groups": ["group_1"]}, + "study_2": {"owner": "user_1", "groups": ["group_2"]}, + "study_4": {"owner": "user_2", "groups": ["group_1"]}, + "study_5": {"owner": "user_2", "groups": ["group_2"]}, + "study_7": {"owner": "user_1", "groups": ["group_1", "group_2"]}, + "study_8": {"owner": "user_2", "groups": ["group_1", "group_2"]}, + }.items(): + res = client.post( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": f"dummy_{study}"}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + study_id = res.json() + res = client.post( + f"{STUDIES_URL}/{study_id}/variants", + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": study}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + study_id = res.json() + studies_ids_mapping[study] = study_id + owner_id = users_ids[study_info.get("owner")] + res = client.put( + f"{STUDIES_URL}/{study_id}/owner/{owner_id}", + headers={"Authorization": f"Bearer {admin_access_token}"}, + ) + assert res.status_code == 200, res.json() + for group in study_info.get("groups"): + group_id = groups_ids[group] + res = client.put( + f"{STUDIES_URL}/{study_id}/groups/{group_id}", + headers={"Authorization": f"Bearer {admin_access_token}"}, + ) + assert res.status_code == 200, res.json() + # studies that have owner but no groups + for study, study_info in { + "study_10": {"owner": "user_1"}, + "study_11": {"owner": "user_2"}, + }.items(): + res = client.post( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": f"dummy_{study}"}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + study_id = res.json() + res = client.post( + f"{STUDIES_URL}/{study_id}/variants", + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": study}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + study_id = res.json() + studies_ids_mapping[study] = study_id + owner_id = users_ids[study_info.get("owner")] + res = client.put( + f"{STUDIES_URL}/{study_id}/owner/{owner_id}", + headers={"Authorization": f"Bearer {admin_access_token}"}, + ) + assert res.status_code == 200, res.json() + # studies that have groups but no owner + for study, study_info in { + "study_3": {"groups": ["group_1"]}, + "study_6": {"groups": ["group_2"]}, + "study_9": {"groups": ["group_1", "group_2"]}, + }.items(): + res = client.post( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": f"dummy_{study}"}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + study_id = res.json() + res = client.post( + f"{STUDIES_URL}/{study_id}/variants", + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": study}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + study_id = res.json() + studies_ids_mapping[study] = study_id + for group in study_info.get("groups"): + group_id = groups_ids[group] + res = client.put( + f"{STUDIES_URL}/{study_id}/groups/{group_id}", + headers={"Authorization": f"Bearer {admin_access_token}"}, + ) + assert res.status_code == 200, res.json() + + # create variant studies with neither owner nor groups + for study, study_info in { + "study_12": {"public_mode": None}, + "study_13": {"public_mode": PublicMode.READ.value}, + "study_14": {"public_mode": PublicMode.EDIT.value}, + "study_15": {"public_mode": PublicMode.EXECUTE.value}, + "study_16": {"public_mode": PublicMode.FULL.value}, + }.items(): + res = client.post( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": f"dummy_{study}"}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + study_id = res.json() + res = client.post( + f"{STUDIES_URL}/{study_id}/variants", + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": study}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + study_id = res.json() + studies_ids_mapping[study] = study_id + public_mode = study_info.get("public_mode") + if public_mode: + res = client.put( + f"{STUDIES_URL}/{study_id}/public_mode/{public_mode}", + headers={"Authorization": f"Bearer {admin_access_token}"}, + ) + assert res.status_code == 200, res.json() + + # create raw studies for user_1 and user_2 that are part of some groups + # studies that have owner and groups + for study, study_info in { + "study_17": {"owner": "user_1", "groups": ["group_1"]}, + "study_18": {"owner": "user_1", "groups": ["group_2"]}, + "study_20": {"owner": "user_2", "groups": ["group_1"]}, + "study_21": {"owner": "user_2", "groups": ["group_2"]}, + "study_23": {"owner": "user_1", "groups": ["group_1", "group_2"]}, + "study_24": {"owner": "user_2", "groups": ["group_1", "group_2"]}, + }.items(): + res = client.post( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": study}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + study_id = res.json() + studies_ids_mapping[study] = study_id + owner = users_ids[study_info.get("owner")] + res = client.put( + f"{STUDIES_URL}/{study_id}/owner/{owner}", + headers={"Authorization": f"Bearer {admin_access_token}"}, + ) + assert res.status_code == 200, res.json() + for group in study_info.get("groups"): + group_id = groups_ids[group] + res = client.put( + f"{STUDIES_URL}/{study_id}/groups/{group_id}", + headers={"Authorization": f"Bearer {admin_access_token}"}, + ) + assert res.status_code == 200, res.json() + # studies that have owner but no groups + for study, study_info in { + "study_26": {"owner": "user_1"}, + "study_27": {"owner": "user_2"}, + }.items(): + res = client.post( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": study}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + study_id = res.json() + studies_ids_mapping[study] = study_id + owner_id = users_ids[study_info.get("owner")] + res = client.put( + f"{STUDIES_URL}/{study_id}/owner/{owner_id}", + headers={"Authorization": f"Bearer {admin_access_token}"}, + ) + assert res.status_code == 200, res.json() + # studies that have groups but no owner + for study, study_info in { + "study_19": {"groups": ["group_1"]}, + "study_22": {"groups": ["group_2"]}, + "study_25": {"groups": ["group_1", "group_2"]}, + }.items(): + res = client.post( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": study}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + study_id = res.json() + studies_ids_mapping[study] = study_id + for group in study_info.get("groups"): + group_id = groups_ids[group] + res = client.put( + f"{STUDIES_URL}/{study_id}/groups/{group_id}", + headers={"Authorization": f"Bearer {admin_access_token}"}, + ) + assert res.status_code == 200, res.json() + + # create raw studies with neither owner nor groups + for study, study_info in { + "study_28": {"public_mode": None}, + "study_29": {"public_mode": PublicMode.READ.value}, + "study_30": {"public_mode": PublicMode.EDIT.value}, + "study_31": {"public_mode": PublicMode.EXECUTE.value}, + "study_32": {"public_mode": PublicMode.FULL.value}, + }.items(): + res = client.post( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": study}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + study_id = res.json() + studies_ids_mapping[study] = study_id + public_mode = study_info.get("public_mode") + if public_mode: + res = client.put( + f"{STUDIES_URL}/{study_id}/public_mode/{public_mode}", + headers={"Authorization": f"Bearer {admin_access_token}"}, + ) + assert res.status_code == 200, res.json() + + # create studies for user_3 that is not part of any group + # variant studies + for study, study_info in { + "study_33": {"groups": ["group_1"]}, + "study_35": {"groups": []}, + }.items(): + res = client.post( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": f"dummy_{study}"}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + study_id = res.json() + res = client.post( + f"{STUDIES_URL}/{study_id}/variants", + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": study}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + study_id = res.json() + studies_ids_mapping[study] = study_id + owner_id = users_ids["user_3"] + res = client.put( + f"{STUDIES_URL}/{study_id}/owner/{owner_id}", + headers={"Authorization": f"Bearer {admin_access_token}"}, + ) + assert res.status_code == 200, res.json() + for group in study_info.get("groups", []): + group_id = groups_ids[group] + res = client.put( + f"{STUDIES_URL}/{study_id}/groups/{group_id}", + headers={"Authorization": f"Bearer {admin_access_token}"}, + ) + assert res.status_code == 200, res.json() + # raw studies + for study, study_info in { + "study_34": {"groups": ["group_2"]}, + "study_36": {"groups": []}, + }.items(): + res = client.post( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": study}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + study_id = res.json() + studies_ids_mapping[study] = study_id + owner_id = users_ids["user_3"] + res = client.put( + f"{STUDIES_URL}/{study_id}/owner/{owner_id}", + headers={"Authorization": f"Bearer {admin_access_token}"}, + ) + assert res.status_code == 200, res.json() + for group in study_info.get("groups"): + group_id = groups_ids[group] + res = client.put( + f"{STUDIES_URL}/{study_id}/groups/{group_id}", + headers={"Authorization": f"Bearer {admin_access_token}"}, + ) + assert res.status_code == 200, res.json() + + # create studies for group_3 that has no user + res = client.post( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": "dummy_study_37"}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + study_id = res.json() + res = client.post( + f"{STUDIES_URL}/{study_id}/variants", + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": "study_37"}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + study_id = res.json() + group_3_id = groups_ids["group_3"] + res = client.put( + f"{STUDIES_URL}/{study_id}/groups/{group_3_id}", + headers={"Authorization": f"Bearer {admin_access_token}"}, + ) + assert res.status_code == 200, res.json() + studies_ids_mapping["study_37"] = study_id + res = client.post( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": "study_38"}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + study_id = res.json() + res = client.put( + f"{STUDIES_URL}/{study_id}/groups/{group_3_id}", + headers={"Authorization": f"Bearer {admin_access_token}"}, + ) + assert res.status_code == 200, res.json() + studies_ids_mapping["study_38"] = study_id + + # verify the studies creation was done correctly and that admin has access to all studies + all_studies = set(studies_ids_mapping.values()) + studies_target_info = { + "study_1": { + "type": "variantstudy", + "owner": "user_1", + "groups": ["group_1"], + "public_mode": PublicMode.NONE, + }, + "study_2": { + "type": "variantstudy", + "owner": "user_1", + "groups": ["group_2"], + "public_mode": PublicMode.NONE, + }, + "study_3": {"type": "variantstudy", "owner": None, "groups": ["group_1"], "public_mode": PublicMode.NONE}, + "study_4": { + "type": "variantstudy", + "owner": "user_2", + "groups": ["group_1"], + "public_mode": PublicMode.NONE, + }, + "study_5": { + "type": "variantstudy", + "owner": "user_2", + "groups": ["group_2"], + "public_mode": PublicMode.NONE, + }, + "study_6": {"type": "variantstudy", "owner": None, "groups": ["group_2"], "public_mode": PublicMode.NONE}, + "study_7": { + "type": "variantstudy", + "owner": "user_1", + "groups": ["group_1", "group_2"], + "public_mode": PublicMode.NONE, + }, + "study_8": { + "type": "variantstudy", + "owner": "user_2", + "groups": ["group_1", "group_2"], + "public_mode": PublicMode.NONE, + }, + "study_9": { + "type": "variantstudy", + "owner": None, + "groups": ["group_1", "group_2"], + "public_mode": PublicMode.NONE, + }, + "study_10": {"type": "variantstudy", "owner": "user_1", "groups": None, "public_mode": PublicMode.NONE}, + "study_11": {"type": "variantstudy", "owner": "user_2", "groups": None, "public_mode": PublicMode.NONE}, + "study_12": {"type": "variantstudy", "owner": None, "groups": None, "public_mode": PublicMode.NONE}, + "study_13": {"type": "variantstudy", "owner": None, "groups": None, "public_mode": PublicMode.READ}, + "study_14": {"type": "variantstudy", "owner": None, "groups": None, "public_mode": PublicMode.EDIT}, + "study_15": {"type": "variantstudy", "owner": None, "groups": None, "public_mode": PublicMode.EXECUTE}, + "study_16": {"type": "variantstudy", "owner": None, "groups": None, "public_mode": PublicMode.FULL}, + "study_17": {"type": "rawstudy", "owner": "user_1", "groups": ["group_1"], "public_mode": PublicMode.NONE}, + "study_18": {"type": "rawstudy", "owner": "user_1", "groups": ["group_2"], "public_mode": PublicMode.NONE}, + "study_19": {"type": "rawstudy", "owner": None, "groups": ["group_1"], "public_mode": PublicMode.NONE}, + "study_20": {"type": "rawstudy", "owner": "user_2", "groups": ["group_1"], "public_mode": PublicMode.NONE}, + "study_21": {"type": "rawstudy", "owner": "user_2", "groups": ["group_2"], "public_mode": PublicMode.NONE}, + "study_22": {"type": "rawstudy", "owner": None, "groups": ["group_2"], "public_mode": PublicMode.NONE}, + "study_23": { + "type": "rawstudy", + "owner": "user_1", + "groups": ["group_1", "group_2"], + "public_mode": PublicMode.NONE, + }, + "study_24": { + "type": "rawstudy", + "owner": "user_2", + "groups": ["group_1", "group_2"], + "public_mode": PublicMode.NONE, + }, + "study_25": { + "type": "rawstudy", + "owner": None, + "groups": ["group_1", "group_2"], + "public_mode": PublicMode.NONE, + }, + "study_26": {"type": "rawstudy", "owner": "user_1", "groups": None, "public_mode": PublicMode.NONE}, + "study_27": {"type": "rawstudy", "owner": "user_2", "groups": None, "public_mode": PublicMode.NONE}, + "study_28": {"type": "rawstudy", "owner": None, "groups": None, "public_mode": PublicMode.NONE}, + "study_29": {"type": "rawstudy", "owner": None, "groups": None, "public_mode": PublicMode.READ}, + "study_30": {"type": "rawstudy", "owner": None, "groups": None, "public_mode": PublicMode.EDIT}, + "study_31": {"type": "rawstudy", "owner": None, "groups": None, "public_mode": PublicMode.EXECUTE}, + "study_32": {"type": "rawstudy", "owner": None, "groups": None, "public_mode": PublicMode.FULL}, + "study_33": { + "type": "variantstudy", + "owner": "user_3", + "groups": ["group_1"], + "public_mode": PublicMode.NONE, + }, + "study_34": {"type": "rawstudy", "owner": "user_3", "groups": ["group_2"], "public_mode": PublicMode.NONE}, + "study_35": {"type": "variantstudy", "owner": "user_3", "groups": None, "public_mode": PublicMode.NONE}, + "study_36": {"type": "rawstudy", "owner": "user_3", "groups": None, "public_mode": PublicMode.NONE}, + "study_37": {"type": "variantstudy", "owner": None, "groups": ["group_3"], "public_mode": PublicMode.NONE}, + "study_38": {"type": "rawstudy", "owner": None, "groups": ["group_3"], "public_mode": PublicMode.NONE}, + } + res = client.get(STUDIES_URL, headers={"Authorization": f"Bearer {admin_access_token}"}) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + assert len(all_studies) == 38 + assert not all_studies.difference(study_map) + for study, study_info in studies_target_info.items(): + study_id = studies_ids_mapping[study] + study_data = study_map[study_id] + assert study_data.get("type") == study_info.get("type") + if study_data.get("owner") and study_info.get("owner"): + assert study_data["owner"]["name"] == study_info.get("owner") + assert study_data["owner"]["id"] == users_ids[study_info.get("owner")] + else: + assert not study_info.get("owner") + assert study_data["owner"]["name"] == "admin" + if study_data.get("groups"): + expected_groups = set(study_info.get("groups")) + assert all( + (group["name"] in expected_groups) and groups_ids[group["name"]] == group["id"] + for group in study_data["groups"] + ) + else: + assert not study_info.get("groups") + assert study_data["public_mode"] == study_info.get("public_mode") + + ########################## + # 2. Tests + ########################## + + # user_1 access + requests_params_expected_studies = [ + # fmt: off + ([], {"1", "2", "5", "6", "7", "8", "9", "10", "13", "14", "15", "16", "17", + "18", "21", "22", "23", "24", "25", "26", "29", "30", "31", "32", "34"}), + (["1"], {"1", "7", "8", "9", "17", "23", "24", "25"}), + (["2"], {"2", "5", "6", "7", "8", "9", "18", "21", "22", "23", "24", "25", "34"}), + (["3"], set()), + (["1", "2"], {"1", "2", "5", "6", "7", "8", "9", "17", "18", "21", "22", "23", "24", "25", "34"}), + (["1", "3"], {"1", "7", "8", "9", "17", "23", "24", "25"}), + (["2", "3"], {"2", "5", "6", "7", "8", "9", "18", "21", "22", "23", "24", "25", "34"}), + ( + ["1", "2", "3"], + {"1", "2", "5", "6", "7", "8", "9", "17", "18", "21", "22", "23", "24", "25", "34"}, + ), + ] + for request_groups_numbers, expected_studies_numbers in requests_params_expected_studies: + request_groups_ids = [groups_ids[f"group_{group_number}"] for group_number in request_groups_numbers] + expected_studies = { + studies_ids_mapping[f"study_{study_number}"] for study_number in expected_studies_numbers + } + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {users_tokens['user_1']}"}, + params={"groups": ",".join(request_groups_ids)} if request_groups_ids else {}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + assert not expected_studies.difference(set(study_map)) + assert not all_studies.difference(expected_studies).intersection(set(study_map)) + + # user_2 access + requests_params_expected_studies = [ + # fmt: off + ([], {"1", "3", "4", "5", "7", "8", "9", "11", "13", "14", "15", "16", "17", + "19", "20", "21", "23", "24", "25", "27", "29", "30", "31", "32", "33"}), + (["1"], {"1", "3", "4", "7", "8", "9", "17", "19", "20", "23", "24", "25", "33"}), + (["2"], {"5", "7", "8", "9", "21", "23", "24", "25"}), + (["3"], set()), + (["1", "2"], {"1", "3", "4", "5", "7", "8", "9", "17", "19", "20", "21", "23", "24", "25", "33"}), + (["1", "3"], {"1", "3", "4", "7", "8", "9", "17", "19", "20", "23", "24", "25", "33"}), + (["2", "3"], {"5", "7", "8", "9", "21", "23", "24", "25"}), + ( + ["1", "2", "3"], + {"1", "3", "4", "5", "7", "8", "9", "17", "19", "20", "21", "23", "24", "25", "33"}, + ), + ] + for request_groups_numbers, expected_studies_numbers in requests_params_expected_studies: + request_groups_ids = [groups_ids[f"group_{group_number}"] for group_number in request_groups_numbers] + expected_studies = { + studies_ids_mapping[f"study_{study_number}"] for study_number in expected_studies_numbers + } + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {users_tokens['user_2']}"}, + params={"groups": ",".join(request_groups_ids)} if request_groups_ids else {}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + assert not expected_studies.difference(set(study_map)) + assert not all_studies.difference(expected_studies).intersection(set(study_map)) + + # user_3 access + requests_params_expected_studies = [ + ([], {"13", "14", "15", "16", "29", "30", "31", "32", "33", "34", "35", "36"}), + (["1"], {"33"}), + (["2"], {"34"}), + (["3"], set()), + (["1", "2"], {"33", "34"}), + (["1", "3"], {"33"}), + (["2", "3"], {"34"}), + (["1", "2", "3"], {"33", "34"}), + ] + for request_groups_numbers, expected_studies_numbers in requests_params_expected_studies: + request_groups_ids = [groups_ids[f"group_{group_number}"] for group_number in request_groups_numbers] + expected_studies = { + studies_ids_mapping[f"study_{study_number}"] for study_number in expected_studies_numbers + } + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {users_tokens['user_3']}"}, + params={"groups": ",".join(request_groups_ids)} if request_groups_ids else {}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + assert not expected_studies.difference(set(study_map)) + assert not all_studies.difference(expected_studies).intersection(set(study_map)) + def test_get_studies__invalid_parameters( self, client: TestClient, diff --git a/tests/storage/repository/test_study.py b/tests/storage/repository/test_study.py index 5535f4d074..7aa5fb23cd 100644 --- a/tests/storage/repository/test_study.py +++ b/tests/storage/repository/test_study.py @@ -7,7 +7,7 @@ from antarest.core.model import PublicMode from antarest.login.model import Group, User from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, Study, StudyAdditionalData, StudyContentStatus -from antarest.study.repository import StudyMetadataRepository +from antarest.study.repository import AccessPermissions, StudyFilter, StudyMetadataRepository from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy @@ -69,7 +69,7 @@ def test_lifecycle(db_session: Session) -> None: c = repo.one(a_id) assert a_id == c.id - assert len(repo.get_all()) == 4 + assert len(repo.get_all(study_filter=StudyFilter(access_permissions=AccessPermissions(is_admin=True)))) == 4 assert len(repo.get_all_raw(exists=True)) == 1 assert len(repo.get_all_raw(exists=False)) == 1 assert len(repo.get_all_raw()) == 2 diff --git a/tests/storage/test_service.py b/tests/storage/test_service.py index fa7ed5c62d..c322b69672 100644 --- a/tests/storage/test_service.py +++ b/tests/storage/test_service.py @@ -44,7 +44,7 @@ TimeSerie, TimeSeriesData, ) -from antarest.study.repository import StudyFilter, StudyMetadataRepository +from antarest.study.repository import AccessPermissions, StudyFilter, StudyMetadataRepository from antarest.study.service import MAX_MISSING_STUDY_TIMEOUT, StudyService, StudyUpgraderTask, UserHasNotPermissionError from antarest.study.storage.patch_service import PatchService from antarest.study.storage.rawstudy.model.filesystem.config.model import ( @@ -172,6 +172,7 @@ def test_study_listing(db_session: Session) -> None: config = Config(storage=StorageConfig(workspaces={DEFAULT_WORKSPACE_NAME: WorkspaceConfig()})) repository = StudyMetadataRepository(cache_service=Mock(spec=ICache), session=db_session) service = build_study_service(raw_study_service, repository, config, cache_service=cache) + params: RequestParameters = RequestParameters(user=JWTUser(id=2, impersonator=2, type="users")) # retrieve studies that are not managed # use the db recorder to check that: @@ -179,10 +180,7 @@ def test_study_listing(db_session: Session) -> None: # 2- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: studies = service.get_studies_information( - study_filter=StudyFilter( - managed=False, - ), - params=RequestParameters(user=JWTUser(id=2, impersonator=2, type="users")), + study_filter=StudyFilter(managed=False, access_permissions=AccessPermissions.from_params(params)), ) assert len(db_recorder.sql_statements) == 1, str(db_recorder) @@ -196,10 +194,7 @@ def test_study_listing(db_session: Session) -> None: # 2- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: studies = service.get_studies_information( - study_filter=StudyFilter( - managed=True, - ), - params=RequestParameters(user=JWTUser(id=2, impersonator=2, type="users")), + study_filter=StudyFilter(managed=True, access_permissions=AccessPermissions.from_params(params)), ) assert len(db_recorder.sql_statements) == 1, str(db_recorder) @@ -213,10 +208,7 @@ def test_study_listing(db_session: Session) -> None: # 2- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: studies = service.get_studies_information( - study_filter=StudyFilter( - managed=None, - ), - params=RequestParameters(user=JWTUser(id=2, impersonator=2, type="users")), + study_filter=StudyFilter(managed=None, access_permissions=AccessPermissions.from_params(params)), ) assert len(db_recorder.sql_statements) == 1, str(db_recorder) @@ -230,10 +222,7 @@ def test_study_listing(db_session: Session) -> None: # 2- the `put` method of `cache` was never used with DBStatementRecorder(db_session.bind) as db_recorder: studies = service.get_studies_information( - study_filter=StudyFilter( - managed=None, - ), - params=RequestParameters(user=JWTUser(id=2, impersonator=2, type="users")), + study_filter=StudyFilter(managed=None, access_permissions=AccessPermissions.from_params(params)), ) assert len(db_recorder.sql_statements) == 1, str(db_recorder) with contextlib.suppress(AssertionError): diff --git a/tests/study/test_repository.py b/tests/study/test_repository.py index 0a6063fac5..4762cc7fed 100644 --- a/tests/study/test_repository.py +++ b/tests/study/test_repository.py @@ -6,9 +6,10 @@ from sqlalchemy.orm import Session # type: ignore from antarest.core.interfaces.cache import ICache +from antarest.core.model import PublicMode from antarest.login.model import Group, User from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, Tag -from antarest.study.repository import StudyFilter, StudyMetadataRepository +from antarest.study.repository import AccessPermissions, StudyFilter, StudyMetadataRepository from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy from tests.db_statement_recorder import DBStatementRecorder @@ -38,7 +39,7 @@ (False, [1, 3, 5, 7], None, {"7"}), ], ) -def test_repository_get_all__general_case( +def test_get_all__general_case( db_session: Session, managed: t.Union[bool, None], study_ids: t.Sequence[str], @@ -66,7 +67,14 @@ def test_repository_get_all__general_case( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(managed=managed, study_ids=study_ids, exists=exists)) + all_studies = repository.get_all( + study_filter=StudyFilter( + managed=managed, + study_ids=study_ids, + exists=exists, + access_permissions=AccessPermissions(is_admin=True), + ) + ) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -77,7 +85,7 @@ def test_repository_get_all__general_case( assert {s.id for s in all_studies} == expected_ids -def test_repository_get_all__incompatible_case( +def test_get_all__incompatible_case( db_session: Session, ) -> None: test_workspace = "workspace1" @@ -97,7 +105,7 @@ def test_repository_get_all__incompatible_case( db_session.commit() # case 1 - study_filter = StudyFilter(managed=False, variant=True) + study_filter = StudyFilter(managed=False, variant=True, access_permissions=AccessPermissions(is_admin=True)) with DBStatementRecorder(db_session.bind) as db_recorder: all_studies = repository.get_all(study_filter=study_filter) _ = [s.owner for s in all_studies] @@ -108,7 +116,9 @@ def test_repository_get_all__incompatible_case( assert not {s.id for s in all_studies} # case 2 - study_filter = StudyFilter(workspace=test_workspace, variant=True) + study_filter = StudyFilter( + workspace=test_workspace, variant=True, access_permissions=AccessPermissions(is_admin=True) + ) with DBStatementRecorder(db_session.bind) as db_recorder: all_studies = repository.get_all(study_filter=study_filter) _ = [s.owner for s in all_studies] @@ -119,7 +129,7 @@ def test_repository_get_all__incompatible_case( assert not {s.id for s in all_studies} # case 3 - study_filter = StudyFilter(exists=False, variant=True) + study_filter = StudyFilter(exists=False, variant=True, access_permissions=AccessPermissions(is_admin=True)) with DBStatementRecorder(db_session.bind) as db_recorder: all_studies = repository.get_all(study_filter=study_filter) _ = [s.owner for s in all_studies] @@ -144,7 +154,7 @@ def test_repository_get_all__incompatible_case( ("specie-suffix", set()), ], ) -def test_repository_get_all__study_name_filter( +def test_get_all__study_name_filter( db_session: Session, name: str, expected_ids: t.Set[str], @@ -169,7 +179,9 @@ def test_repository_get_all__study_name_filter( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(name=name)) + all_studies = repository.get_all( + study_filter=StudyFilter(name=name, access_permissions=AccessPermissions(is_admin=True)) + ) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -188,7 +200,7 @@ def test_repository_get_all__study_name_filter( (False, {"6", "7"}), ], ) -def test_repository_get_all__managed_study_filter( +def test_get_all__managed_study_filter( db_session: Session, managed: t.Optional[bool], expected_ids: t.Set[str], @@ -214,7 +226,9 @@ def test_repository_get_all__managed_study_filter( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(managed=managed)) + all_studies = repository.get_all( + study_filter=StudyFilter(managed=managed, access_permissions=AccessPermissions(is_admin=True)) + ) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -233,7 +247,7 @@ def test_repository_get_all__managed_study_filter( (False, {"2", "4"}), ], ) -def test_repository_get_all__archived_study_filter( +def test_get_all__archived_study_filter( db_session: Session, archived: t.Optional[bool], expected_ids: t.Set[str], @@ -254,7 +268,9 @@ def test_repository_get_all__archived_study_filter( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(archived=archived)) + all_studies = repository.get_all( + study_filter=StudyFilter(archived=archived, access_permissions=AccessPermissions(is_admin=True)) + ) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -273,7 +289,7 @@ def test_repository_get_all__archived_study_filter( (False, {"3", "4"}), ], ) -def test_repository_get_all__variant_study_filter( +def test_get_all__variant_study_filter( db_session: Session, variant: t.Optional[bool], expected_ids: t.Set[str], @@ -294,7 +310,9 @@ def test_repository_get_all__variant_study_filter( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(variant=variant)) + all_studies = repository.get_all( + study_filter=StudyFilter(variant=variant, access_permissions=AccessPermissions(is_admin=True)) + ) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -315,7 +333,7 @@ def test_repository_get_all__variant_study_filter( (["3"], set()), ], ) -def test_repository_get_all__study_version_filter( +def test_get_all__study_version_filter( db_session: Session, versions: t.Sequence[str], expected_ids: t.Set[str], @@ -336,7 +354,9 @@ def test_repository_get_all__study_version_filter( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(versions=versions)) + all_studies = repository.get_all( + study_filter=StudyFilter(versions=versions, access_permissions=AccessPermissions(is_admin=True)) + ) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -357,7 +377,7 @@ def test_repository_get_all__study_version_filter( (["3000"], set()), ], ) -def test_repository_get_all__study_users_filter( +def test_get_all__study_users_filter( db_session: Session, users: t.Sequence["int"], expected_ids: t.Set[str], @@ -384,7 +404,9 @@ def test_repository_get_all__study_users_filter( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(users=users)) + all_studies = repository.get_all( + study_filter=StudyFilter(users=users, access_permissions=AccessPermissions(is_admin=True)) + ) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -405,7 +427,7 @@ def test_repository_get_all__study_users_filter( (["3000"], set()), ], ) -def test_repository_get_all__study_groups_filter( +def test_get_all__study_groups_filter( db_session: Session, groups: t.Sequence[str], expected_ids: t.Set[str], @@ -432,7 +454,9 @@ def test_repository_get_all__study_groups_filter( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(groups=groups)) + all_studies = repository.get_all( + study_filter=StudyFilter(groups=groups, access_permissions=AccessPermissions(is_admin=True)) + ) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -454,7 +478,7 @@ def test_repository_get_all__study_groups_filter( (["3000"], set()), ], ) -def test_repository_get_all__study_ids_filter( +def test_get_all__study_ids_filter( db_session: Session, study_ids: t.Sequence[str], expected_ids: t.Set[str], @@ -475,7 +499,9 @@ def test_repository_get_all__study_ids_filter( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(study_ids=study_ids)) + all_studies = repository.get_all( + study_filter=StudyFilter(study_ids=study_ids, access_permissions=AccessPermissions(is_admin=True)) + ) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -494,7 +520,7 @@ def test_repository_get_all__study_ids_filter( (False, {"3"}), ], ) -def test_repository_get_all__study_existence_filter( +def test_get_all__study_existence_filter( db_session: Session, exists: t.Optional[bool], expected_ids: t.Set[str], @@ -515,7 +541,9 @@ def test_repository_get_all__study_existence_filter( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(exists=exists)) + all_studies = repository.get_all( + study_filter=StudyFilter(exists=exists, access_permissions=AccessPermissions(is_admin=True)) + ) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -535,7 +563,7 @@ def test_repository_get_all__study_existence_filter( ("workspace-3", set()), ], ) -def test_repository_get_all__study_workspace_filter( +def test_get_all__study_workspace_filter( db_session: Session, workspace: str, expected_ids: t.Set[str], @@ -556,7 +584,9 @@ def test_repository_get_all__study_workspace_filter( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(workspace=workspace)) + all_studies = repository.get_all( + study_filter=StudyFilter(workspace=workspace, access_permissions=AccessPermissions(is_admin=True)) + ) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -578,7 +608,7 @@ def test_repository_get_all__study_workspace_filter( ("folder-1", set()), ], ) -def test_repository_get_all__study_folder_filter( +def test_get_all__study_folder_filter( db_session: Session, folder: str, expected_ids: t.Set[str], @@ -599,7 +629,9 @@ def test_repository_get_all__study_folder_filter( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(folder=folder)) + all_studies = repository.get_all( + study_filter=StudyFilter(folder=folder, access_permissions=AccessPermissions(is_admin=True)) + ) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -621,7 +653,7 @@ def test_repository_get_all__study_folder_filter( (["no-study-tag"], set()), ], ) -def test_repository_get_all__study_tags_filter( +def test_get_all__study_tags_filter( db_session: Session, tags: t.Sequence[str], expected_ids: t.Set[str], @@ -650,7 +682,9 @@ def test_repository_get_all__study_tags_filter( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(tags=tags)) + all_studies = repository.get_all( + study_filter=StudyFilter(tags=tags, access_permissions=AccessPermissions(is_admin=True)) + ) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -662,6 +696,283 @@ def test_repository_get_all__study_tags_filter( assert {s.id for s in all_studies} == expected_ids +@pytest.mark.parametrize( + "user_id, study_groups, expected_ids", + [ + # fmt: off + (101, [], {"1", "2", "5", "6", "7", "8", "9", "10", "13", "14", "15", "16", "17", "18", + "21", "22", "23", "24", "25", "26", "29", "30", "31", "32", "34"}), + (101, ["101"], {"1", "7", "8", "9", "17", "23", "24", "25"}), + (101, ["102"], {"2", "5", "6", "7", "8", "9", "18", "21", "22", "23", "24", "25", "34"}), + (101, ["103"], set()), + (101, ["101", "102"], {"1", "2", "5", "6", "7", "8", "9", "17", "18", "21", "22", "23", "24", "25", "34"}), + (101, ["101", "103"], {"1", "7", "8", "9", "17", "23", "24", "25"}), + (101, ["102", "103"], {"2", "5", "6", "7", "8", "9", "18", "21", "22", "23", "24", "25", "34"}), + (101, ["101", "102", "103"], {"1", "2", "5", "6", "7", "8", "9", "17", "18", "21", "22", + "23", "24", "25", "34"}), + (102, [], {"1", "3", "4", "5", "7", "8", "9", "11", "13", "14", "15", "16", "17", "19", + "20", "21", "23", "24", "25", "27", "29", "30", "31", "32", "33"}), + (102, ["101"], {"1", "3", "4", "7", "8", "9", "17", "19", "20", "23", "24", "25", "33"}), + (102, ["102"], {"5", "7", "8", "9", "21", "23", "24", "25"}), + (102, ["103"], set()), + (102, ["101", "102"], {"1", "3", "4", "5", "7", "8", "9", "17", "19", "20", "21", "23", "24", "25", "33"}), + (102, ["101", "103"], {"1", "3", "4", "7", "8", "9", "17", "19", "20", "23", "24", "25", "33"}), + (102, ["102", "103"], {"5", "7", "8", "9", "21", "23", "24", "25"}), + (102, ["101", "102", "103"], {"1", "3", "4", "5", "7", "8", "9", "17", "19", "20", "21", + "23", "24", "25", "33"}), + (103, [], {"13", "14", "15", "16", "29", "30", "31", "32", "33", "34", "35", "36"}), + (103, ["101"], {"33"}), + (103, ["102"], {"34"}), + (103, ["103"], set()), + (103, ["101", "102"], {"33", "34"}), + (103, ["101", "103"], {"33"}), + (103, ["102", "103"], {"34"}), + (103, ["101", "102", "103"], {"33", "34"}), + (None, [], set()), + (None, ["101"], set()), + (None, ["102"], set()), + (None, ["103"], set()), + (None, ["101", "102"], set()), + (None, ["101", "103"], set()), + (None, ["102", "103"], set()), + (None, ["101", "102", "103"], set()), + # fmt: on + ], +) +def test_get_all__non_admin_permissions_filter( + db_session: Session, + user_id: t.Optional[int], + study_groups: t.Sequence[str], + expected_ids: t.Set[str], +) -> None: + icache: Mock = Mock(spec=ICache) + repository = StudyMetadataRepository(cache_service=icache, session=db_session) + + user_1 = User(id=101, name="user1") + user_2 = User(id=102, name="user2") + user_3 = User(id=103, name="user3") + + group_1 = Group(id=101, name="group1") + group_2 = Group(id=102, name="group2") + group_3 = Group(id=103, name="group3") + + user_groups_mapping = {101: [group_2.id], 102: [group_1.id], 103: []} + + # create variant studies for user_1 and user_2 that are part of some groups + study_1 = VariantStudy(id=1, owner=user_1, groups=[group_1]) + study_2 = VariantStudy(id=2, owner=user_1, groups=[group_2]) + study_3 = VariantStudy(id=3, groups=[group_1]) + study_4 = VariantStudy(id=4, owner=user_2, groups=[group_1]) + study_5 = VariantStudy(id=5, owner=user_2, groups=[group_2]) + study_6 = VariantStudy(id=6, groups=[group_2]) + study_7 = VariantStudy(id=7, owner=user_1, groups=[group_1, group_2]) + study_8 = VariantStudy(id=8, owner=user_2, groups=[group_1, group_2]) + study_9 = VariantStudy(id=9, groups=[group_1, group_2]) + study_10 = VariantStudy(id=10, owner=user_1) + study_11 = VariantStudy(id=11, owner=user_2) + + # create variant studies with neither owner nor groups + study_12 = VariantStudy(id=12) + study_13 = VariantStudy(id=13, public_mode=PublicMode.READ) + study_14 = VariantStudy(id=14, public_mode=PublicMode.EDIT) + study_15 = VariantStudy(id=15, public_mode=PublicMode.EXECUTE) + study_16 = VariantStudy(id=16, public_mode=PublicMode.FULL) + + # create raw studies for user_1 and user_2 that are part of some groups + study_17 = RawStudy(id=17, owner=user_1, groups=[group_1]) + study_18 = RawStudy(id=18, owner=user_1, groups=[group_2]) + study_19 = RawStudy(id=19, groups=[group_1]) + study_20 = RawStudy(id=20, owner=user_2, groups=[group_1]) + study_21 = RawStudy(id=21, owner=user_2, groups=[group_2]) + study_22 = RawStudy(id=22, groups=[group_2]) + study_23 = RawStudy(id=23, owner=user_1, groups=[group_1, group_2]) + study_24 = RawStudy(id=24, owner=user_2, groups=[group_1, group_2]) + study_25 = RawStudy(id=25, groups=[group_1, group_2]) + study_26 = RawStudy(id=26, owner=user_1) + study_27 = RawStudy(id=27, owner=user_2) + + # create raw studies with neither owner nor groups + study_28 = RawStudy(id=28) + study_29 = RawStudy(id=29, public_mode=PublicMode.READ) + study_30 = RawStudy(id=30, public_mode=PublicMode.EDIT) + study_31 = RawStudy(id=31, public_mode=PublicMode.EXECUTE) + study_32 = RawStudy(id=32, public_mode=PublicMode.FULL) + + # create studies for user_3 that is not part of any group + study_33 = VariantStudy(id=33, owner=user_3, groups=[group_1]) + study_34 = RawStudy(id=34, owner=user_3, groups=[group_2]) + study_35 = VariantStudy(id=35, owner=user_3) + study_36 = RawStudy(id=36, owner=user_3) + + # create studies for group_3 that has no user + study_37 = VariantStudy(id=37, groups=[group_3]) + study_38 = RawStudy(id=38, groups=[group_3]) + + db_session.add_all([user_1, user_2, user_3, group_1, group_2, group_3]) + db_session.add_all( + [ + # fmt: off + study_1, study_2, study_3, study_4, study_5, study_6, study_7, study_8, study_9, study_10, + study_11, study_12, study_13, study_14, study_15, study_16, study_17, study_18, study_19, study_20, + study_21, study_22, study_23, study_24, study_25, study_26, study_27, study_28, study_29, study_30, + study_31, study_32, study_33, study_34, study_35, study_36, study_37, study_38, + # fmt: on + ] + ) + db_session.commit() + + access_permissions = ( + AccessPermissions(user_id=user_id, user_groups=user_groups_mapping.get(user_id)) + if user_id + else AccessPermissions() + ) + study_filter = ( + StudyFilter(groups=study_groups, access_permissions=access_permissions) + if study_groups + else StudyFilter(access_permissions=access_permissions) + ) + + # use the db recorder to check that: + # 1- retrieving all studies requires only 1 query + # 2- accessing studies attributes does not require additional queries to db + # 3- having an exact total of queries equals to 1 + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all(study_filter=study_filter) + _ = [s.owner for s in all_studies] + _ = [s.groups for s in all_studies] + _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + + if expected_ids is not None: + assert {s.id for s in all_studies} == expected_ids + + +@pytest.mark.parametrize( + "is_admin, study_groups, expected_ids", + [ + # fmt: off + (True, [], {str(e) for e in range(1, 39)}), + (True, ["101"], {"1", "3", "4", "7", "8", "9", "17", "19", "20", "23", "24", "25", "33"}), + (True, ["102"], {"2", "5", "6", "7", "8", "9", "18", "21", "22", "23", "24", "25", "34"}), + (True, ["103"], {"37", "38"}), + (True, ["101", "102"], {"1", "2", "3", "4", "5", "6", "7", "8", "9", "17", "18", "19", + "20", "21", "22", "23", "24", "25", "33", "34"}), + (True, ["101", "103"], {"1", "3", "4", "7", "8", "9", "17", "19", "20", "23", "24", "25", "33", "37", "38"}), + (True, ["101", "102", "103"], {"1", "2", "3", "4", "5", "6", "7", "8", "9", "17", "18", + "19", "20", "21", "22", "23", "24", "25", "33", "34", "37", "38"}), + (False, [], set()), + (False, ["101"], set()), + (False, ["102"], set()), + (False, ["103"], set()), + (False, ["101", "102"], set()), + (False, ["101", "103"], set()), + (False, ["101", "102", "103"], set()), + # fmt: on + ], +) +def test_get_all__admin_permissions_filter( + db_session: Session, + is_admin: bool, + study_groups: t.Sequence[str], + expected_ids: t.Set[str], +) -> None: + icache: Mock = Mock(spec=ICache) + repository = StudyMetadataRepository(cache_service=icache, session=db_session) + user_1 = User(id=101, name="user1") + user_2 = User(id=102, name="user2") + user_3 = User(id=103, name="user3") + + group_1 = Group(id=101, name="group1") + group_2 = Group(id=102, name="group2") + group_3 = Group(id=103, name="group3") + + # create variant studies for user_1 and user_2 that are part of some groups + study_1 = VariantStudy(id=1, owner=user_1, groups=[group_1]) + study_2 = VariantStudy(id=2, owner=user_1, groups=[group_2]) + study_3 = VariantStudy(id=3, groups=[group_1]) + study_4 = VariantStudy(id=4, owner=user_2, groups=[group_1]) + study_5 = VariantStudy(id=5, owner=user_2, groups=[group_2]) + study_6 = VariantStudy(id=6, groups=[group_2]) + study_7 = VariantStudy(id=7, owner=user_1, groups=[group_1, group_2]) + study_8 = VariantStudy(id=8, owner=user_2, groups=[group_1, group_2]) + study_9 = VariantStudy(id=9, groups=[group_1, group_2]) + study_10 = VariantStudy(id=10, owner=user_1) + study_11 = VariantStudy(id=11, owner=user_2) + + # create variant studies with neither owner nor groups + study_12 = VariantStudy(id=12) + study_13 = VariantStudy(id=13, public_mode=PublicMode.READ) + study_14 = VariantStudy(id=14, public_mode=PublicMode.EDIT) + study_15 = VariantStudy(id=15, public_mode=PublicMode.EXECUTE) + study_16 = VariantStudy(id=16, public_mode=PublicMode.FULL) + + # create raw studies for user_1 and user_2 that are part of some groups + study_17 = RawStudy(id=17, owner=user_1, groups=[group_1]) + study_18 = RawStudy(id=18, owner=user_1, groups=[group_2]) + study_19 = RawStudy(id=19, groups=[group_1]) + study_20 = RawStudy(id=20, owner=user_2, groups=[group_1]) + study_21 = RawStudy(id=21, owner=user_2, groups=[group_2]) + study_22 = RawStudy(id=22, groups=[group_2]) + study_23 = RawStudy(id=23, owner=user_1, groups=[group_1, group_2]) + study_24 = RawStudy(id=24, owner=user_2, groups=[group_1, group_2]) + study_25 = RawStudy(id=25, groups=[group_1, group_2]) + study_26 = RawStudy(id=26, owner=user_1) + study_27 = RawStudy(id=27, owner=user_2) + + # create raw studies with neither owner nor groups + study_28 = RawStudy(id=28) + study_29 = RawStudy(id=29, public_mode=PublicMode.READ) + study_30 = RawStudy(id=30, public_mode=PublicMode.EDIT) + study_31 = RawStudy(id=31, public_mode=PublicMode.EXECUTE) + study_32 = RawStudy(id=32, public_mode=PublicMode.FULL) + + # create studies for user_3 that is not part of any group + study_33 = VariantStudy(id=33, owner=user_3, groups=[group_1]) + study_34 = RawStudy(id=34, owner=user_3, groups=[group_2]) + study_35 = VariantStudy(id=35, owner=user_3) + study_36 = RawStudy(id=36, owner=user_3) + + # create studies for group_3 that has no user + study_37 = VariantStudy(id=37, groups=[group_3]) + study_38 = RawStudy(id=38, groups=[group_3]) + + db_session.add_all([user_1, user_2, user_3, group_1, group_2, group_3]) + db_session.add_all( + [ + # fmt: off + study_1, study_2, study_3, study_4, study_5, study_6, study_7, study_8, study_9, study_10, + study_11, study_12, study_13, study_14, study_15, study_16, study_17, study_18, study_19, study_20, + study_21, study_22, study_23, study_24, study_25, study_26, study_27, study_28, study_29, study_30, + study_31, study_32, study_33, study_34, study_35, study_36, study_37, study_38, + # fmt: on + ] + ) + db_session.commit() + + access_permissions = AccessPermissions(is_admin=is_admin) + + study_filter = ( + StudyFilter(groups=study_groups, access_permissions=access_permissions) + if study_groups + else StudyFilter(access_permissions=access_permissions) + ) + # use the db recorder to check that: + # 1- retrieving all studies requires only 1 query + # 2- accessing studies attributes does not require additional queries to db + # 3- having an exact total of queries equals to 1 + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all(study_filter=study_filter) + _ = [s.owner for s in all_studies] + _ = [s.groups for s in all_studies] + _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + + if expected_ids is not None: + assert {s.id for s in all_studies} == expected_ids + + def test_update_tags( db_session: Session, ) -> None: