Skip to content

Commit

Permalink
test(permission-db): unittests for permissions filtering
Browse files Browse the repository at this point in the history
  • Loading branch information
mabw-rte committed Feb 19, 2024
1 parent ed23924 commit ca97700
Show file tree
Hide file tree
Showing 2 changed files with 221 additions and 26 deletions.
48 changes: 22 additions & 26 deletions antarest/study/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import typing as t

from pydantic import BaseModel, NonNegativeInt
from sqlalchemy import func, not_, or_ # type: ignore
from sqlalchemy import and_, func, not_, or_ # type: ignore
from sqlalchemy.orm import Query, Session, joinedload, with_polymorphic # type: ignore

from antarest.core.interfaces.cache import ICache
Expand Down Expand Up @@ -243,14 +243,8 @@ def get_all(
q = self._search_studies(study_filter)

# permissions filtering
if not study_filter.access_permissions.is_admin:
if study_filter.access_permissions.user_id is None:
return []
condition_1 = entity.public_mode != PublicMode.NONE
condition_2 = entity.owner_id == study_filter.access_permissions.user_id
condition_3 = Group.id.in_(study_filter.access_permissions.user_groups or [])
q0 = q.filter(condition_3) if study_filter.groups else q.join(entity.groups).filter(condition_3)
q = q0.union(q.filter(or_(condition_1, condition_2)))
if not study_filter.access_permissions.is_admin and study_filter.access_permissions.user_id is None:
return []

# sorting
if sort_by:
Expand Down Expand Up @@ -282,30 +276,18 @@ def count_studies(self, study_filter: StudyFilter = StudyFilter()) -> int:
Returns:
Integer, corresponding to total number of studies matching with specified filters.
"""
# When we fetch a study, we also need to fetch the associated owner and groups
# to check the permissions of the current user efficiently.
# We also need to fetch the additional data to display the study information
# efficiently (see: `AbstractStorageService.get_study_information`)
entity = with_polymorphic(Study, "*")

q = self._search_studies(study_filter)

# permissions filtering
if not study_filter.access_permissions.is_admin:
if study_filter.access_permissions.user_id is None:
return 0
condition_1 = entity.public_mode != PublicMode.NONE
condition_2 = entity.owner_id == study_filter.access_permissions.user_id
condition_3 = Group.id.in_(study_filter.access_permissions.user_groups or [])
q0 = q.filter(condition_3) if study_filter.groups else q.join(entity.groups).filter(condition_3)
q = q0.union(q.filter(or_(condition_1, condition_2)))
if not study_filter.access_permissions.is_admin and study_filter.access_permissions.user_id is None:
return 0
total: int = q.count()

return total

def _search_studies(
self,
study_filter: StudyFilter = StudyFilter(),
study_filter: StudyFilter,
) -> Query:
"""
Build a `SQL Query` based on specified filters.
Expand Down Expand Up @@ -344,8 +326,6 @@ def _search_studies(
q = q.filter(entity.id.in_(study_filter.study_ids)) if study_filter.study_ids else q
if study_filter.users:
q = q.filter(entity.owner_id.in_(study_filter.users))
if study_filter.groups:
q = q.join(entity.groups).filter(Group.id.in_(study_filter.groups))
if study_filter.tags:
q = q.join(entity.tags).filter(Tag.label.in_(study_filter.tags))
if study_filter.archived is not None:
Expand All @@ -363,6 +343,22 @@ def _search_studies(
if study_filter.versions:
q = q.filter(entity.version.in_(study_filter.versions))

# permissions + groups filtering
if not study_filter.access_permissions.is_admin and study_filter.access_permissions.user_id is not None:
condition_1 = entity.public_mode != PublicMode.NONE
condition_2 = entity.owner_id == study_filter.access_permissions.user_id
q1 = q.join(entity.groups).filter(Group.id.in_(study_filter.access_permissions.user_groups or []))
if study_filter.groups:
q2 = q.join(entity.groups).filter(Group.id.in_(study_filter.groups or []))
q2 = q1.intersect(q2)
q = q2.union(
q.join(entity.groups).filter(and_(or_(condition_1, condition_2), Group.id.in_(study_filter.groups)))
)
else:
q = q1.union(q.filter(or_(condition_1, condition_2)))
elif study_filter.groups:
q = q.join(entity.groups).filter(Group.id.in_(study_filter.groups))

return q

def get_all_raw(self, exists: t.Optional[bool] = None) -> t.Sequence[RawStudy]:
Expand Down
199 changes: 199 additions & 0 deletions tests/study/test_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from sqlalchemy.orm import Session # type: ignore

from antarest.core.interfaces.cache import ICache
from antarest.core.jwt import JWTUser
from antarest.core.model import PublicMode
from antarest.core.requests import RequestParameters
from antarest.login.model import Group, User
from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, Tag
from antarest.study.repository import AccessPermissions, StudyFilter, StudyMetadataRepository
Expand Down Expand Up @@ -692,3 +695,199 @@ def test_repository_get_all__study_tags_filter(

if expected_ids is not None:
assert {s.id for s in all_studies} == expected_ids


@pytest.mark.parametrize(
"user_id, study_groups, expected_ids",
[
(
1,
[],
{
"1",
"2",
"5",
"6",
"7",
"8",
"9",
"10",
"13",
"14",
"15",
"16",
"17",
"18",
"21",
"22",
"23",
"24",
"25",
"26",
"29",
"30",
"31",
"32",
},
),
(1, ["1"], {"1", "7", "8", "9", "17", "23", "24", "25"}),
(1, ["2"], {"2", "5", "6", "7", "8", "9", "18", "21", "22", "23", "24", "25"}),
(1, ["1", "2"], {"1", "2", "5", "6", "7", "8", "9", "17", "18", "21", "22", "23", "24", "25"}),
(
2,
[],
{
"1",
"3",
"4",
"5",
"7",
"8",
"9",
"11",
"13",
"14",
"15",
"16",
"17",
"19",
"20",
"21",
"23",
"24",
"25",
"27",
"29",
"30",
"31",
"32",
},
),
(2, ["1"], {"1", "3", "4", "7", "8", "9", "17", "19", "20", "23", "24", "25"}),
(2, ["2"], {"5", "7", "8", "9", "21", "23", "24", "25"}),
(2, ["1", "2"], {"1", "3", "4", "5", "7", "8", "9", "17", "19", "20", "21", "23", "24", "25"}),
(None, [], set()),
(None, ["1"], set()),
(None, ["2"], set()),
(None, ["1", "2"], set()),
],
)
def test_repository_get_all__non_admin_access_permissions_filter(
db_session: Session,
user_id: t.Optional[int],
study_groups: t.Sequence[str],
expected_ids: t.Set[str],
) -> None:
icache: Mock = Mock(spec=ICache)
repository = StudyMetadataRepository(cache_service=icache, session=db_session)

user_1 = User(id=1, name="user1")
user_2 = User(id=2, name="user2")

group_1 = Group(id=1, name="group1")
group_2 = Group(id=2, name="group2")

user_groups_mapping = {1: [group_2.id], 2: [group_1.id]}

study_1 = VariantStudy(id=1, owner=user_1, groups=[group_1])
study_2 = VariantStudy(id=2, owner=user_1, groups=[group_2])
study_3 = VariantStudy(id=3, groups=[group_1])
study_4 = VariantStudy(id=4, owner=user_2, groups=[group_1])
study_5 = VariantStudy(id=5, owner=user_2, groups=[group_2])
study_6 = VariantStudy(id=6, groups=[group_2])
study_7 = VariantStudy(id=7, owner=user_1, groups=[group_1, group_2])
study_8 = VariantStudy(id=8, owner=user_2, groups=[group_1, group_2])
study_9 = VariantStudy(id=9, groups=[group_1, group_2])
study_10 = VariantStudy(id=10, owner=user_1)
study_11 = VariantStudy(id=11, owner=user_2)
study_12 = VariantStudy(id=12)
study_13 = VariantStudy(id=13, public_mode=PublicMode.READ)
study_14 = VariantStudy(id=14, public_mode=PublicMode.EDIT)
study_15 = VariantStudy(id=15, public_mode=PublicMode.EXECUTE)
study_16 = VariantStudy(id=16, public_mode=PublicMode.FULL)

study_17 = RawStudy(id=17, owner=user_1, groups=[group_1])
study_18 = RawStudy(id=18, owner=user_1, groups=[group_2])
study_19 = RawStudy(id=19, groups=[group_1])
study_20 = RawStudy(id=20, owner=user_2, groups=[group_1])
study_21 = RawStudy(id=21, owner=user_2, groups=[group_2])
study_22 = RawStudy(id=22, groups=[group_2])
study_23 = RawStudy(id=23, owner=user_1, groups=[group_1, group_2])
study_24 = RawStudy(id=24, owner=user_2, groups=[group_1, group_2])
study_25 = RawStudy(id=25, groups=[group_1, group_2])
study_26 = RawStudy(id=26, owner=user_1)
study_27 = RawStudy(id=27, owner=user_2)
study_28 = RawStudy(id=28)
study_29 = RawStudy(id=29, public_mode=PublicMode.READ)
study_30 = RawStudy(id=30, public_mode=PublicMode.EDIT)
study_31 = RawStudy(id=31, public_mode=PublicMode.EXECUTE)
study_32 = RawStudy(id=32, public_mode=PublicMode.FULL)

db_session.add_all([user_1, user_2, group_1, group_2])
db_session.add_all(
[
study_1,
study_2,
study_3,
study_4,
study_5,
study_6,
study_7,
study_8,
study_9,
study_10,
study_11,
study_12,
study_13,
study_14,
study_15,
study_16,
study_17,
study_18,
study_19,
study_20,
study_21,
study_22,
study_23,
study_24,
study_25,
study_26,
study_27,
study_28,
study_29,
study_30,
study_31,
study_32,
]
)
db_session.commit()

access_permissions = (
AccessPermissions(user_id=user_id, user_groups=user_groups_mapping.get(user_id))
if user_id
else AccessPermissions()
)
study_filter = (
StudyFilter(groups=study_groups, access_permissions=access_permissions)
if study_groups
else StudyFilter(access_permissions=access_permissions)
)

# use the db recorder to check that:
# 1- retrieving all studies requires only 1 query if user_id is not None
# 2- accessing studies attributes does not require additional queries to db
# 3- having an exact total of queries equals to 1 if user_id is not None
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(study_filter=study_filter)
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
_ = [s.tags for s in all_studies]
if user_id:
assert len(db_recorder.sql_statements) == 1, str(db_recorder)
else:
# no query should be executed if user_id is None
assert len(db_recorder.sql_statements) == 0, str(db_recorder)

if expected_ids is not None:
assert {s.id for s in all_studies} == expected_ids

0 comments on commit ca97700

Please sign in to comment.