From aa6c9e9b0fd36e839256f9eba8ce6a3e956788ab Mon Sep 17 00:00:00 2001 From: Mohamed Abdel Wedoud Date: Fri, 23 Feb 2024 10:43:29 +0100 Subject: [PATCH] feat(permission-db): update following code review --- antarest/study/repository.py | 31 +++++---- antarest/study/service.py | 15 +--- antarest/study/web/studies_blueprint.py | 4 +- .../studies_blueprint/test_get_studies.py | 68 ++----------------- tests/storage/test_service.py | 4 -- 5 files changed, 26 insertions(+), 96 deletions(-) diff --git a/antarest/study/repository.py b/antarest/study/repository.py index b311512bff..6e89d33d5a 100644 --- a/antarest/study/repository.py +++ b/antarest/study/repository.py @@ -300,21 +300,20 @@ def _search_studies( # noinspection PyTypeChecker q = self.session.query(entity) if study_filter.exists is not None: - q = ( - q.filter(RawStudy.missing.is_(None)) - if study_filter.exists - else q.filter(not_(RawStudy.missing.is_(None))) - ) + if study_filter.exists: + q = q.filter(RawStudy.missing.is_(None)) + else: + q = q.filter(not_(RawStudy.missing.is_(None))) q = q.options(joinedload(entity.owner)) q = q.options(joinedload(entity.groups)) q = q.options(joinedload(entity.additional_data)) q = q.options(joinedload(entity.tags)) if study_filter.managed is not None: - q = ( - q.filter(or_(entity.type == "variantstudy", RawStudy.workspace == DEFAULT_WORKSPACE_NAME)) - if study_filter.managed - else q.filter(entity.type == "rawstudy").filter(RawStudy.workspace != DEFAULT_WORKSPACE_NAME) - ) + if study_filter.managed: + q = q.filter(or_(entity.type == "variantstudy", RawStudy.workspace == DEFAULT_WORKSPACE_NAME)) + else: + q = q.filter(entity.type == "rawstudy") + q = q.filter(RawStudy.workspace != DEFAULT_WORKSPACE_NAME) if study_filter.study_ids: q = q.filter(entity.id.in_(study_filter.study_ids)) if study_filter.study_ids else q if study_filter.users: @@ -333,7 +332,10 @@ def _search_studies( if study_filter.workspace: q = q.filter(RawStudy.workspace == study_filter.workspace) if study_filter.variant is not None: - q = q.filter(entity.type == "variantstudy") if study_filter.variant else q.filter(entity.type == "rawstudy") + if study_filter.variant: + q = q.filter(entity.type == "variantstudy") + else: + q = q.filter(entity.type == "rawstudy") if study_filter.versions: q = q.filter(entity.version.in_(study_filter.versions)) @@ -341,9 +343,9 @@ def _search_studies( if not study_filter.access_permissions.is_admin and study_filter.access_permissions.user_id is not None: condition_1 = entity.public_mode != PublicMode.NONE condition_2 = entity.owner_id == study_filter.access_permissions.user_id - q1 = q.join(entity.groups).filter(Group.id.in_(study_filter.access_permissions.user_groups or [])) + q1 = q.join(entity.groups).filter(Group.id.in_(study_filter.access_permissions.user_groups)) if study_filter.groups: - q2 = q.join(entity.groups).filter(Group.id.in_(study_filter.groups or [])) + q2 = q.join(entity.groups).filter(Group.id.in_(study_filter.groups)) q2 = q1.intersect(q2) q = q2.union( q.join(entity.groups).filter(and_(or_(condition_1, condition_2), Group.id.in_(study_filter.groups))) @@ -352,7 +354,8 @@ def _search_studies( q = q1.union(q.filter(or_(condition_1, condition_2))) elif not study_filter.access_permissions.is_admin and study_filter.access_permissions.user_id is None: # return empty result - q = q.filter(sql.false()) + # noinspection PyTypeChecker + q = self.session.query(entity).filter(sql.false()) elif study_filter.groups: q = q.join(entity.groups).filter(Group.id.in_(study_filter.groups)) diff --git a/antarest/study/service.py b/antarest/study/service.py index 7290cf692d..81af28a473 100644 --- a/antarest/study/service.py +++ b/antarest/study/service.py @@ -451,7 +451,6 @@ def edit_comments( def get_studies_information( self, - params: RequestParameters, study_filter: StudyFilter, sort_by: t.Optional[StudySortBy] = None, pagination: StudyPagination = StudyPagination(), @@ -459,7 +458,6 @@ def get_studies_information( """ Get information for matching studies of a search query. Args: - params: request parameters study_filter: filtering parameters sort_by: how to sort the db query results pagination: set offset and limit for db query @@ -478,18 +476,7 @@ def get_studies_information( study_metadata = self._try_get_studies_information(study) if study_metadata is not None: studies[study_metadata.id] = study_metadata - return { - s.id: s - for s in filter( - lambda study_dto: assert_permission( - params.user, - study_dto, - StudyPermissionType.READ, - raising=False, - ), - studies.values(), - ) - } + return studies def _try_get_studies_information(self, study: Study) -> t.Optional[StudyMetadataDTO]: try: diff --git a/antarest/study/web/studies_blueprint.py b/antarest/study/web/studies_blueprint.py index cf76765122..9007ed2894 100644 --- a/antarest/study/web/studies_blueprint.py +++ b/antarest/study/web/studies_blueprint.py @@ -145,8 +145,7 @@ def get_studies( user_list = [int(v) for v in _split_comma_separated_values(users)] if not params.user: - logger.error("FAIL permission: user is not logged") - raise UserHasNotPermissionError() + raise UserHasNotPermissionError("FAIL permission: user is not logged") study_filter = StudyFilter( name=name, @@ -165,7 +164,6 @@ def get_studies( ) matching_studies = study_service.get_studies_information( - params=params, study_filter=study_filter, sort_by=sort_by, pagination=StudyPagination(page_nb=page_nb, page_size=page_size), diff --git a/tests/integration/studies_blueprint/test_get_studies.py b/tests/integration/studies_blueprint/test_get_studies.py index 4ff5ba6dc2..48dadaf829 100644 --- a/tests/integration/studies_blueprint/test_get_studies.py +++ b/tests/integration/studies_blueprint/test_get_studies.py @@ -851,7 +851,7 @@ def test_get_studies__access_permissions(self, client: TestClient, admin_access_ for group in groups: group_id = groups_ids[group] res = client.post( - f"/v1/roles", + "/v1/roles", headers={"Authorization": f"Bearer {admin_access_token}"}, json={"identity_id": user_id, "group_id": group_id, "type": RoleType.READER.value}, ) @@ -1315,36 +1315,9 @@ def test_get_studies__access_permissions(self, client: TestClient, admin_access_ # user_1 access requests_params_expected_studies = [ - ( - [], - { - "1", - "2", - "5", - "6", - "7", - "8", - "9", - "10", - "13", - "14", - "15", - "16", - "17", - "18", - "21", - "22", - "23", - "24", - "25", - "26", - "29", - "30", - "31", - "32", - "34", - }, - ), + # fmt: off + ([], {"1", "2", "5", "6", "7", "8", "9", "10", "13", "14", "15", "16", "17", + "18", "21", "22", "23", "24", "25", "26", "29", "30", "31", "32", "34"}), (["1"], {"1", "7", "8", "9", "17", "23", "24", "25"}), (["2"], {"2", "5", "6", "7", "8", "9", "18", "21", "22", "23", "24", "25", "34"}), (["3"], set()), @@ -1373,36 +1346,9 @@ def test_get_studies__access_permissions(self, client: TestClient, admin_access_ # user_2 access requests_params_expected_studies = [ - ( - [], - { - "1", - "3", - "4", - "5", - "7", - "8", - "9", - "11", - "13", - "14", - "15", - "16", - "17", - "19", - "20", - "21", - "23", - "24", - "25", - "27", - "29", - "30", - "31", - "32", - "33", - }, - ), + # fmt: off + ([], {"1", "3", "4", "5", "7", "8", "9", "11", "13", "14", "15", "16", "17", + "19", "20", "21", "23", "24", "25", "27", "29", "30", "31", "32", "33"}), (["1"], {"1", "3", "4", "7", "8", "9", "17", "19", "20", "23", "24", "25", "33"}), (["2"], {"5", "7", "8", "9", "21", "23", "24", "25"}), (["3"], set()), diff --git a/tests/storage/test_service.py b/tests/storage/test_service.py index 84a8ded30a..c322b69672 100644 --- a/tests/storage/test_service.py +++ b/tests/storage/test_service.py @@ -181,7 +181,6 @@ def test_study_listing(db_session: Session) -> None: with DBStatementRecorder(db_session.bind) as db_recorder: studies = service.get_studies_information( study_filter=StudyFilter(managed=False, access_permissions=AccessPermissions.from_params(params)), - params=params, ) assert len(db_recorder.sql_statements) == 1, str(db_recorder) @@ -196,7 +195,6 @@ def test_study_listing(db_session: Session) -> None: with DBStatementRecorder(db_session.bind) as db_recorder: studies = service.get_studies_information( study_filter=StudyFilter(managed=True, access_permissions=AccessPermissions.from_params(params)), - params=params, ) assert len(db_recorder.sql_statements) == 1, str(db_recorder) @@ -211,7 +209,6 @@ def test_study_listing(db_session: Session) -> None: with DBStatementRecorder(db_session.bind) as db_recorder: studies = service.get_studies_information( study_filter=StudyFilter(managed=None, access_permissions=AccessPermissions.from_params(params)), - params=params, ) assert len(db_recorder.sql_statements) == 1, str(db_recorder) @@ -226,7 +223,6 @@ def test_study_listing(db_session: Session) -> None: with DBStatementRecorder(db_session.bind) as db_recorder: studies = service.get_studies_information( study_filter=StudyFilter(managed=None, access_permissions=AccessPermissions.from_params(params)), - params=params, ) assert len(db_recorder.sql_statements) == 1, str(db_recorder) with contextlib.suppress(AssertionError):