Skip to content

Commit

Permalink
feat(permission-db): update following code review
Browse files Browse the repository at this point in the history
  • Loading branch information
mabw-rte committed Feb 23, 2024
1 parent c2f651a commit aa6c9e9
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 96 deletions.
31 changes: 17 additions & 14 deletions antarest/study/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,21 +300,20 @@ def _search_studies(
# noinspection PyTypeChecker
q = self.session.query(entity)
if study_filter.exists is not None:
q = (
q.filter(RawStudy.missing.is_(None))
if study_filter.exists
else q.filter(not_(RawStudy.missing.is_(None)))
)
if study_filter.exists:
q = q.filter(RawStudy.missing.is_(None))
else:
q = q.filter(not_(RawStudy.missing.is_(None)))
q = q.options(joinedload(entity.owner))
q = q.options(joinedload(entity.groups))
q = q.options(joinedload(entity.additional_data))
q = q.options(joinedload(entity.tags))
if study_filter.managed is not None:
q = (
q.filter(or_(entity.type == "variantstudy", RawStudy.workspace == DEFAULT_WORKSPACE_NAME))
if study_filter.managed
else q.filter(entity.type == "rawstudy").filter(RawStudy.workspace != DEFAULT_WORKSPACE_NAME)
)
if study_filter.managed:
q = q.filter(or_(entity.type == "variantstudy", RawStudy.workspace == DEFAULT_WORKSPACE_NAME))
else:
q = q.filter(entity.type == "rawstudy")
q = q.filter(RawStudy.workspace != DEFAULT_WORKSPACE_NAME)
if study_filter.study_ids:
q = q.filter(entity.id.in_(study_filter.study_ids)) if study_filter.study_ids else q
if study_filter.users:
Expand All @@ -333,17 +332,20 @@ def _search_studies(
if study_filter.workspace:
q = q.filter(RawStudy.workspace == study_filter.workspace)
if study_filter.variant is not None:
q = q.filter(entity.type == "variantstudy") if study_filter.variant else q.filter(entity.type == "rawstudy")
if study_filter.variant:
q = q.filter(entity.type == "variantstudy")
else:
q = q.filter(entity.type == "rawstudy")
if study_filter.versions:
q = q.filter(entity.version.in_(study_filter.versions))

# permissions + groups filtering
if not study_filter.access_permissions.is_admin and study_filter.access_permissions.user_id is not None:
condition_1 = entity.public_mode != PublicMode.NONE
condition_2 = entity.owner_id == study_filter.access_permissions.user_id
q1 = q.join(entity.groups).filter(Group.id.in_(study_filter.access_permissions.user_groups or []))
q1 = q.join(entity.groups).filter(Group.id.in_(study_filter.access_permissions.user_groups))
if study_filter.groups:
q2 = q.join(entity.groups).filter(Group.id.in_(study_filter.groups or []))
q2 = q.join(entity.groups).filter(Group.id.in_(study_filter.groups))
q2 = q1.intersect(q2)
q = q2.union(
q.join(entity.groups).filter(and_(or_(condition_1, condition_2), Group.id.in_(study_filter.groups)))
Expand All @@ -352,7 +354,8 @@ def _search_studies(
q = q1.union(q.filter(or_(condition_1, condition_2)))
elif not study_filter.access_permissions.is_admin and study_filter.access_permissions.user_id is None:
# return empty result
q = q.filter(sql.false())
# noinspection PyTypeChecker
q = self.session.query(entity).filter(sql.false())
elif study_filter.groups:
q = q.join(entity.groups).filter(Group.id.in_(study_filter.groups))

Expand Down
15 changes: 1 addition & 14 deletions antarest/study/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,15 +451,13 @@ def edit_comments(

def get_studies_information(
self,
params: RequestParameters,
study_filter: StudyFilter,
sort_by: t.Optional[StudySortBy] = None,
pagination: StudyPagination = StudyPagination(),
) -> t.Dict[str, StudyMetadataDTO]:
"""
Get information for matching studies of a search query.
Args:
params: request parameters
study_filter: filtering parameters
sort_by: how to sort the db query results
pagination: set offset and limit for db query
Expand All @@ -478,18 +476,7 @@ def get_studies_information(
study_metadata = self._try_get_studies_information(study)
if study_metadata is not None:
studies[study_metadata.id] = study_metadata
return {
s.id: s
for s in filter(
lambda study_dto: assert_permission(
params.user,
study_dto,
StudyPermissionType.READ,
raising=False,
),
studies.values(),
)
}
return studies

def _try_get_studies_information(self, study: Study) -> t.Optional[StudyMetadataDTO]:
try:
Expand Down
4 changes: 1 addition & 3 deletions antarest/study/web/studies_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,7 @@ def get_studies(
user_list = [int(v) for v in _split_comma_separated_values(users)]

if not params.user:
logger.error("FAIL permission: user is not logged")
raise UserHasNotPermissionError()
raise UserHasNotPermissionError("FAIL permission: user is not logged")

study_filter = StudyFilter(
name=name,
Expand All @@ -165,7 +164,6 @@ def get_studies(
)

matching_studies = study_service.get_studies_information(
params=params,
study_filter=study_filter,
sort_by=sort_by,
pagination=StudyPagination(page_nb=page_nb, page_size=page_size),
Expand Down
68 changes: 7 additions & 61 deletions tests/integration/studies_blueprint/test_get_studies.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,7 +851,7 @@ def test_get_studies__access_permissions(self, client: TestClient, admin_access_
for group in groups:
group_id = groups_ids[group]
res = client.post(
f"/v1/roles",
"/v1/roles",
headers={"Authorization": f"Bearer {admin_access_token}"},
json={"identity_id": user_id, "group_id": group_id, "type": RoleType.READER.value},
)
Expand Down Expand Up @@ -1315,36 +1315,9 @@ def test_get_studies__access_permissions(self, client: TestClient, admin_access_

# user_1 access
requests_params_expected_studies = [
(
[],
{
"1",
"2",
"5",
"6",
"7",
"8",
"9",
"10",
"13",
"14",
"15",
"16",
"17",
"18",
"21",
"22",
"23",
"24",
"25",
"26",
"29",
"30",
"31",
"32",
"34",
},
),
# fmt: off
([], {"1", "2", "5", "6", "7", "8", "9", "10", "13", "14", "15", "16", "17",
"18", "21", "22", "23", "24", "25", "26", "29", "30", "31", "32", "34"}),
(["1"], {"1", "7", "8", "9", "17", "23", "24", "25"}),
(["2"], {"2", "5", "6", "7", "8", "9", "18", "21", "22", "23", "24", "25", "34"}),
(["3"], set()),
Expand Down Expand Up @@ -1373,36 +1346,9 @@ def test_get_studies__access_permissions(self, client: TestClient, admin_access_

# user_2 access
requests_params_expected_studies = [
(
[],
{
"1",
"3",
"4",
"5",
"7",
"8",
"9",
"11",
"13",
"14",
"15",
"16",
"17",
"19",
"20",
"21",
"23",
"24",
"25",
"27",
"29",
"30",
"31",
"32",
"33",
},
),
# fmt: off
([], {"1", "3", "4", "5", "7", "8", "9", "11", "13", "14", "15", "16", "17",
"19", "20", "21", "23", "24", "25", "27", "29", "30", "31", "32", "33"}),
(["1"], {"1", "3", "4", "7", "8", "9", "17", "19", "20", "23", "24", "25", "33"}),
(["2"], {"5", "7", "8", "9", "21", "23", "24", "25"}),
(["3"], set()),
Expand Down
4 changes: 0 additions & 4 deletions tests/storage/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ def test_study_listing(db_session: Session) -> None:
with DBStatementRecorder(db_session.bind) as db_recorder:
studies = service.get_studies_information(
study_filter=StudyFilter(managed=False, access_permissions=AccessPermissions.from_params(params)),
params=params,
)
assert len(db_recorder.sql_statements) == 1, str(db_recorder)

Expand All @@ -196,7 +195,6 @@ def test_study_listing(db_session: Session) -> None:
with DBStatementRecorder(db_session.bind) as db_recorder:
studies = service.get_studies_information(
study_filter=StudyFilter(managed=True, access_permissions=AccessPermissions.from_params(params)),
params=params,
)
assert len(db_recorder.sql_statements) == 1, str(db_recorder)

Expand All @@ -211,7 +209,6 @@ def test_study_listing(db_session: Session) -> None:
with DBStatementRecorder(db_session.bind) as db_recorder:
studies = service.get_studies_information(
study_filter=StudyFilter(managed=None, access_permissions=AccessPermissions.from_params(params)),
params=params,
)
assert len(db_recorder.sql_statements) == 1, str(db_recorder)

Expand All @@ -226,7 +223,6 @@ def test_study_listing(db_session: Session) -> None:
with DBStatementRecorder(db_session.bind) as db_recorder:
studies = service.get_studies_information(
study_filter=StudyFilter(managed=None, access_permissions=AccessPermissions.from_params(params)),
params=params,
)
assert len(db_recorder.sql_statements) == 1, str(db_recorder)
with contextlib.suppress(AssertionError):
Expand Down

0 comments on commit aa6c9e9

Please sign in to comment.