From ca97700c66dab5780745f4082dd657c0bc4ee107 Mon Sep 17 00:00:00 2001 From: Mohamed Abdel Wedoud Date: Mon, 19 Feb 2024 17:02:31 +0100 Subject: [PATCH] test(permission-db): unittests for permissions filtering --- antarest/study/repository.py | 48 ++++---- tests/study/test_repository.py | 199 +++++++++++++++++++++++++++++++++ 2 files changed, 221 insertions(+), 26 deletions(-) diff --git a/antarest/study/repository.py b/antarest/study/repository.py index 5eefaf79b7..97e7bc0a0b 100644 --- a/antarest/study/repository.py +++ b/antarest/study/repository.py @@ -3,7 +3,7 @@ import typing as t from pydantic import BaseModel, NonNegativeInt -from sqlalchemy import func, not_, or_ # type: ignore +from sqlalchemy import and_, func, not_, or_ # type: ignore from sqlalchemy.orm import Query, Session, joinedload, with_polymorphic # type: ignore from antarest.core.interfaces.cache import ICache @@ -243,14 +243,8 @@ def get_all( q = self._search_studies(study_filter) # permissions filtering - if not study_filter.access_permissions.is_admin: - if study_filter.access_permissions.user_id is None: - return [] - condition_1 = entity.public_mode != PublicMode.NONE - condition_2 = entity.owner_id == study_filter.access_permissions.user_id - condition_3 = Group.id.in_(study_filter.access_permissions.user_groups or []) - q0 = q.filter(condition_3) if study_filter.groups else q.join(entity.groups).filter(condition_3) - q = q0.union(q.filter(or_(condition_1, condition_2))) + if not study_filter.access_permissions.is_admin and study_filter.access_permissions.user_id is None: + return [] # sorting if sort_by: @@ -282,30 +276,18 @@ def count_studies(self, study_filter: StudyFilter = StudyFilter()) -> int: Returns: Integer, corresponding to total number of studies matching with specified filters. """ - # When we fetch a study, we also need to fetch the associated owner and groups - # to check the permissions of the current user efficiently. - # We also need to fetch the additional data to display the study information - # efficiently (see: `AbstractStorageService.get_study_information`) - entity = with_polymorphic(Study, "*") - q = self._search_studies(study_filter) # permissions filtering - if not study_filter.access_permissions.is_admin: - if study_filter.access_permissions.user_id is None: - return 0 - condition_1 = entity.public_mode != PublicMode.NONE - condition_2 = entity.owner_id == study_filter.access_permissions.user_id - condition_3 = Group.id.in_(study_filter.access_permissions.user_groups or []) - q0 = q.filter(condition_3) if study_filter.groups else q.join(entity.groups).filter(condition_3) - q = q0.union(q.filter(or_(condition_1, condition_2))) + if not study_filter.access_permissions.is_admin and study_filter.access_permissions.user_id is None: + return 0 total: int = q.count() return total def _search_studies( self, - study_filter: StudyFilter = StudyFilter(), + study_filter: StudyFilter, ) -> Query: """ Build a `SQL Query` based on specified filters. @@ -344,8 +326,6 @@ def _search_studies( q = q.filter(entity.id.in_(study_filter.study_ids)) if study_filter.study_ids else q if study_filter.users: q = q.filter(entity.owner_id.in_(study_filter.users)) - if study_filter.groups: - q = q.join(entity.groups).filter(Group.id.in_(study_filter.groups)) if study_filter.tags: q = q.join(entity.tags).filter(Tag.label.in_(study_filter.tags)) if study_filter.archived is not None: @@ -363,6 +343,22 @@ def _search_studies( if study_filter.versions: q = q.filter(entity.version.in_(study_filter.versions)) + # permissions + groups filtering + if not study_filter.access_permissions.is_admin and study_filter.access_permissions.user_id is not None: + condition_1 = entity.public_mode != PublicMode.NONE + condition_2 = entity.owner_id == study_filter.access_permissions.user_id + q1 = q.join(entity.groups).filter(Group.id.in_(study_filter.access_permissions.user_groups or [])) + if study_filter.groups: + q2 = q.join(entity.groups).filter(Group.id.in_(study_filter.groups or [])) + q2 = q1.intersect(q2) + q = q2.union( + q.join(entity.groups).filter(and_(or_(condition_1, condition_2), Group.id.in_(study_filter.groups))) + ) + else: + q = q1.union(q.filter(or_(condition_1, condition_2))) + elif study_filter.groups: + q = q.join(entity.groups).filter(Group.id.in_(study_filter.groups)) + return q def get_all_raw(self, exists: t.Optional[bool] = None) -> t.Sequence[RawStudy]: diff --git a/tests/study/test_repository.py b/tests/study/test_repository.py index 116b3cd9ea..46201220d7 100644 --- a/tests/study/test_repository.py +++ b/tests/study/test_repository.py @@ -6,6 +6,9 @@ from sqlalchemy.orm import Session # type: ignore from antarest.core.interfaces.cache import ICache +from antarest.core.jwt import JWTUser +from antarest.core.model import PublicMode +from antarest.core.requests import RequestParameters from antarest.login.model import Group, User from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, Tag from antarest.study.repository import AccessPermissions, StudyFilter, StudyMetadataRepository @@ -692,3 +695,199 @@ def test_repository_get_all__study_tags_filter( if expected_ids is not None: assert {s.id for s in all_studies} == expected_ids + + +@pytest.mark.parametrize( + "user_id, study_groups, expected_ids", + [ + ( + 1, + [], + { + "1", + "2", + "5", + "6", + "7", + "8", + "9", + "10", + "13", + "14", + "15", + "16", + "17", + "18", + "21", + "22", + "23", + "24", + "25", + "26", + "29", + "30", + "31", + "32", + }, + ), + (1, ["1"], {"1", "7", "8", "9", "17", "23", "24", "25"}), + (1, ["2"], {"2", "5", "6", "7", "8", "9", "18", "21", "22", "23", "24", "25"}), + (1, ["1", "2"], {"1", "2", "5", "6", "7", "8", "9", "17", "18", "21", "22", "23", "24", "25"}), + ( + 2, + [], + { + "1", + "3", + "4", + "5", + "7", + "8", + "9", + "11", + "13", + "14", + "15", + "16", + "17", + "19", + "20", + "21", + "23", + "24", + "25", + "27", + "29", + "30", + "31", + "32", + }, + ), + (2, ["1"], {"1", "3", "4", "7", "8", "9", "17", "19", "20", "23", "24", "25"}), + (2, ["2"], {"5", "7", "8", "9", "21", "23", "24", "25"}), + (2, ["1", "2"], {"1", "3", "4", "5", "7", "8", "9", "17", "19", "20", "21", "23", "24", "25"}), + (None, [], set()), + (None, ["1"], set()), + (None, ["2"], set()), + (None, ["1", "2"], set()), + ], +) +def test_repository_get_all__non_admin_access_permissions_filter( + db_session: Session, + user_id: t.Optional[int], + study_groups: t.Sequence[str], + expected_ids: t.Set[str], +) -> None: + icache: Mock = Mock(spec=ICache) + repository = StudyMetadataRepository(cache_service=icache, session=db_session) + + user_1 = User(id=1, name="user1") + user_2 = User(id=2, name="user2") + + group_1 = Group(id=1, name="group1") + group_2 = Group(id=2, name="group2") + + user_groups_mapping = {1: [group_2.id], 2: [group_1.id]} + + study_1 = VariantStudy(id=1, owner=user_1, groups=[group_1]) + study_2 = VariantStudy(id=2, owner=user_1, groups=[group_2]) + study_3 = VariantStudy(id=3, groups=[group_1]) + study_4 = VariantStudy(id=4, owner=user_2, groups=[group_1]) + study_5 = VariantStudy(id=5, owner=user_2, groups=[group_2]) + study_6 = VariantStudy(id=6, groups=[group_2]) + study_7 = VariantStudy(id=7, owner=user_1, groups=[group_1, group_2]) + study_8 = VariantStudy(id=8, owner=user_2, groups=[group_1, group_2]) + study_9 = VariantStudy(id=9, groups=[group_1, group_2]) + study_10 = VariantStudy(id=10, owner=user_1) + study_11 = VariantStudy(id=11, owner=user_2) + study_12 = VariantStudy(id=12) + study_13 = VariantStudy(id=13, public_mode=PublicMode.READ) + study_14 = VariantStudy(id=14, public_mode=PublicMode.EDIT) + study_15 = VariantStudy(id=15, public_mode=PublicMode.EXECUTE) + study_16 = VariantStudy(id=16, public_mode=PublicMode.FULL) + + study_17 = RawStudy(id=17, owner=user_1, groups=[group_1]) + study_18 = RawStudy(id=18, owner=user_1, groups=[group_2]) + study_19 = RawStudy(id=19, groups=[group_1]) + study_20 = RawStudy(id=20, owner=user_2, groups=[group_1]) + study_21 = RawStudy(id=21, owner=user_2, groups=[group_2]) + study_22 = RawStudy(id=22, groups=[group_2]) + study_23 = RawStudy(id=23, owner=user_1, groups=[group_1, group_2]) + study_24 = RawStudy(id=24, owner=user_2, groups=[group_1, group_2]) + study_25 = RawStudy(id=25, groups=[group_1, group_2]) + study_26 = RawStudy(id=26, owner=user_1) + study_27 = RawStudy(id=27, owner=user_2) + study_28 = RawStudy(id=28) + study_29 = RawStudy(id=29, public_mode=PublicMode.READ) + study_30 = RawStudy(id=30, public_mode=PublicMode.EDIT) + study_31 = RawStudy(id=31, public_mode=PublicMode.EXECUTE) + study_32 = RawStudy(id=32, public_mode=PublicMode.FULL) + + db_session.add_all([user_1, user_2, group_1, group_2]) + db_session.add_all( + [ + study_1, + study_2, + study_3, + study_4, + study_5, + study_6, + study_7, + study_8, + study_9, + study_10, + study_11, + study_12, + study_13, + study_14, + study_15, + study_16, + study_17, + study_18, + study_19, + study_20, + study_21, + study_22, + study_23, + study_24, + study_25, + study_26, + study_27, + study_28, + study_29, + study_30, + study_31, + study_32, + ] + ) + db_session.commit() + + access_permissions = ( + AccessPermissions(user_id=user_id, user_groups=user_groups_mapping.get(user_id)) + if user_id + else AccessPermissions() + ) + study_filter = ( + StudyFilter(groups=study_groups, access_permissions=access_permissions) + if study_groups + else StudyFilter(access_permissions=access_permissions) + ) + + # use the db recorder to check that: + # 1- retrieving all studies requires only 1 query if user_id is not None + # 2- accessing studies attributes does not require additional queries to db + # 3- having an exact total of queries equals to 1 if user_id is not None + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all(study_filter=study_filter) + _ = [s.owner for s in all_studies] + _ = [s.groups for s in all_studies] + _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] + if user_id: + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + else: + # no query should be executed if user_id is None + assert len(db_recorder.sql_statements) == 0, str(db_recorder) + + if expected_ids is not None: + assert {s.id for s in all_studies} == expected_ids