Skip to content

Commit

Permalink
feature(permission-db): check user permission through the db query
Browse files Browse the repository at this point in the history
  • Loading branch information
mabw-rte committed Feb 13, 2024
1 parent b7d98fd commit 5f9824c
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 24 deletions.
29 changes: 29 additions & 0 deletions antarest/study/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sqlalchemy.orm import Session, joinedload, with_polymorphic # type: ignore

from antarest.core.interfaces.cache import ICache
from antarest.core.model import PublicMode
from antarest.core.utils.fastapi_sqlalchemy import db
from antarest.login.model import Group
from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, Study, StudyAdditionalData, Tag
Expand Down Expand Up @@ -66,6 +67,17 @@ class StudyFilter(BaseModel, frozen=True, extra="forbid"):
folder: str = ""


class QueryUser(BaseModel, frozen=True, extra="forbid"):
"""
This class object is build to pass on the user identity and its associated groups information
into the listing function get_all below
"""

is_admin: bool = False
user_id: t.Optional[int] = None
user_groups: t.Optional[t.Sequence[str]] = None


class StudySortBy(str, enum.Enum):
"""How to sort the results of studies query results"""

Expand Down Expand Up @@ -182,6 +194,7 @@ def get_all(
study_filter: StudyFilter = StudyFilter(),
sort_by: t.Optional[StudySortBy] = None,
pagination: StudyPagination = StudyPagination(),
query_user: QueryUser = QueryUser(),
) -> t.Sequence[Study]:
"""
Retrieve studies based on specified filters, sorting, and pagination.
Expand All @@ -190,6 +203,7 @@ def get_all(
study_filter: composed of all filtering criteria.
sort_by: how the user would like the results to be sorted.
pagination: specifies the number of results to displayed in each page and the actually displayed page.
query_user: user id and groups info
Returns:
The matching studies in proper order and pagination.
Expand Down Expand Up @@ -243,6 +257,21 @@ def get_all(
if study_filter.versions:
q = q.filter(entity.version.in_(study_filter.versions))

# permissions filtering
if not query_user.is_admin:
if query_user.user_id is not None:
condition_1 = entity.public_mode != PublicMode.NONE
condition_2 = entity.owner_id == query_user.user_id
condition_3 = Group.id.in_(query_user.user_groups or [])
if study_filter.groups:
q0 = q.filter(condition_3)
q = q0.union(q.filter(or_(condition_1, condition_2)))
else:
q0 = q.join(entity.groups).filter(condition_3)
q = q0.union(q.filter(or_(condition_1, condition_2)))
else:
return []

if sort_by:
if sort_by == StudySortBy.DATE_DESC:
q = q.order_by(entity.created_at.desc())
Expand Down
25 changes: 22 additions & 3 deletions antarest/study/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@
StudyMetadataPatchDTO,
StudySimResultDTO,
)
from antarest.study.repository import StudyFilter, StudyMetadataRepository, StudyPagination, StudySortBy
from antarest.study.repository import QueryUser, StudyFilter, StudyMetadataRepository, StudyPagination, StudySortBy
from antarest.study.storage.matrix_profile import adjust_matrix_columns_index
from antarest.study.storage.rawstudy.model.filesystem.config.model import FileStudyTreeConfigDTO
from antarest.study.storage.rawstudy.model.filesystem.folder_node import ChildNotFoundError
Expand Down Expand Up @@ -461,10 +461,29 @@ def get_studies_information(
"""
logger.info("Retrieving matching studies")
studies: t.Dict[str, StudyMetadataDTO] = {}

# retrieve user id and groups
user_id = None
user_groups = None
if params.user:
if params.user.id:
user_id = params.user.id
if params.user.groups:
user_groups = [group.id for group in params.user.groups]
else:
logger.error("FAIL permission: user is not logged")
raise UserHasNotPermissionError()

query_user = QueryUser(
is_admin=params.user.is_site_admin() or params.user.is_admin_token(),
user_id=user_id,
user_groups=user_groups,
)
matching_studies = self.repository.get_all(
study_filter=study_filter,
sort_by=sort_by,
pagination=pagination,
query_user=query_user,
)
logger.info("Studies retrieved")
for study in matching_studies:
Expand Down Expand Up @@ -702,7 +721,7 @@ def get_input_matrix_startdate(

def remove_duplicates(self) -> None:
study_paths: t.Dict[str, t.List[str]] = {}
for study in self.repository.get_all():
for study in self.repository.get_all(query_user=QueryUser(is_admin=True)):
if isinstance(study, RawStudy) and not study.archived:
path = str(study.path)
if path not in study_paths:
Expand Down Expand Up @@ -2145,7 +2164,7 @@ def check_and_update_all_study_versions_in_database(self, params: RequestParamet
if params.user and not params.user.is_site_admin():
logger.error(f"User {params.user.id} is not site admin")
raise UserHasNotPermissionError()
studies = self.repository.get_all(study_filter=StudyFilter(managed=False))
studies = self.repository.get_all(study_filter=StudyFilter(managed=False), query_user=QueryUser(is_amin=True))

for study in studies:
storage = self.storage_service.raw_study_service
Expand Down
6 changes: 4 additions & 2 deletions antarest/study/storage/auto_archive_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from antarest.core.requests import RequestParameters
from antarest.core.utils.fastapi_sqlalchemy import db
from antarest.study.model import RawStudy, Study
from antarest.study.repository import StudyFilter
from antarest.study.repository import QueryUser, StudyFilter
from antarest.study.service import StudyService
from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy

Expand All @@ -28,7 +28,9 @@ def __init__(self, study_service: StudyService, config: Config):
def _try_archive_studies(self) -> None:
old_date = datetime.datetime.utcnow() - datetime.timedelta(days=self.config.storage.auto_archive_threshold_days)
with db():
studies: t.Sequence[Study] = self.study_service.repository.get_all(study_filter=StudyFilter(managed=True))
studies: t.Sequence[Study] = self.study_service.repository.get_all(
study_filter=StudyFilter(managed=True), query_user=QueryUser(is_admin=True)
)
# list of study IDs and boolean indicating if it's a raw study (True) or a variant (False)
study_ids_to_archive = [
(study.id, isinstance(study, RawStudy))
Expand Down
4 changes: 2 additions & 2 deletions tests/storage/repository/test_study.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from antarest.core.model import PublicMode
from antarest.login.model import Group, User
from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, Study, StudyContentStatus
from antarest.study.repository import StudyMetadataRepository
from antarest.study.repository import QueryUser, StudyMetadataRepository
from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy
from tests.helpers import with_db_context

Expand Down Expand Up @@ -64,7 +64,7 @@ def test_lifecycle() -> None:
c = repo.one(a.id)
assert a == c

assert len(repo.get_all()) == 4
assert len(repo.get_all(query_user=QueryUser(is_admin=True))) == 4
assert len(repo.get_all_raw(exists=True)) == 1
assert len(repo.get_all_raw(exists=False)) == 1
assert len(repo.get_all_raw()) == 2
Expand Down
45 changes: 28 additions & 17 deletions tests/study/test_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from antarest.core.interfaces.cache import ICache
from antarest.login.model import Group, User
from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, Tag
from antarest.study.repository import StudyFilter, StudyMetadataRepository
from antarest.study.repository import QueryUser, StudyFilter, StudyMetadataRepository
from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy
from tests.db_statement_recorder import DBStatementRecorder

Expand Down Expand Up @@ -66,7 +66,10 @@ def test_repository_get_all__general_case(
# 2- accessing studies attributes does not require additional queries to db
# 3- having an exact total of queries equals to 1
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(study_filter=StudyFilter(managed=managed, study_ids=study_ids, exists=exists))
all_studies = repository.get_all(
study_filter=StudyFilter(managed=managed, study_ids=study_ids, exists=exists),
query_user=QueryUser(is_admin=True),
)
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
Expand Down Expand Up @@ -99,7 +102,7 @@ def test_repository_get_all__incompatible_case(
# case 1
study_filter = StudyFilter(managed=False, variant=True)
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(study_filter=study_filter)
all_studies = repository.get_all(study_filter=study_filter, query_user=QueryUser(is_admin=True))
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
Expand All @@ -110,7 +113,7 @@ def test_repository_get_all__incompatible_case(
# case 2
study_filter = StudyFilter(workspace=test_workspace, variant=True)
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(study_filter=study_filter)
all_studies = repository.get_all(study_filter=study_filter, query_user=QueryUser(is_admin=True))
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
Expand All @@ -121,7 +124,7 @@ def test_repository_get_all__incompatible_case(
# case 3
study_filter = StudyFilter(exists=False, variant=True)
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(study_filter=study_filter)
all_studies = repository.get_all(study_filter=study_filter, query_user=QueryUser(is_admin=True))
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
Expand Down Expand Up @@ -169,7 +172,7 @@ def test_repository_get_all__study_name_filter(
# 2- accessing studies attributes does not require additional queries to db
# 3- having an exact total of queries equals to 1
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(study_filter=StudyFilter(name=name))
all_studies = repository.get_all(study_filter=StudyFilter(name=name), query_user=QueryUser(is_admin=True))
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
Expand Down Expand Up @@ -214,7 +217,7 @@ def test_repository_get_all__managed_study_filter(
# 2- accessing studies attributes does not require additional queries to db
# 3- having an exact total of queries equals to 1
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(study_filter=StudyFilter(managed=managed))
all_studies = repository.get_all(study_filter=StudyFilter(managed=managed), query_user=QueryUser(is_admin=True))
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
Expand Down Expand Up @@ -254,7 +257,9 @@ def test_repository_get_all__archived_study_filter(
# 2- accessing studies attributes does not require additional queries to db
# 3- having an exact total of queries equals to 1
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(study_filter=StudyFilter(archived=archived))
all_studies = repository.get_all(
study_filter=StudyFilter(archived=archived), query_user=QueryUser(is_admin=True)
)
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
Expand Down Expand Up @@ -294,7 +299,7 @@ def test_repository_get_all__variant_study_filter(
# 2- accessing studies attributes does not require additional queries to db
# 3- having an exact total of queries equals to 1
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(study_filter=StudyFilter(variant=variant))
all_studies = repository.get_all(study_filter=StudyFilter(variant=variant), query_user=QueryUser(is_admin=True))
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
Expand Down Expand Up @@ -336,7 +341,9 @@ def test_repository_get_all__study_version_filter(
# 2- accessing studies attributes does not require additional queries to db
# 3- having an exact total of queries equals to 1
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(study_filter=StudyFilter(versions=versions))
all_studies = repository.get_all(
study_filter=StudyFilter(versions=versions), query_user=QueryUser(is_admin=True)
)
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
Expand Down Expand Up @@ -384,7 +391,7 @@ def test_repository_get_all__study_users_filter(
# 2- accessing studies attributes does not require additional queries to db
# 3- having an exact total of queries equals to 1
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(study_filter=StudyFilter(users=users))
all_studies = repository.get_all(study_filter=StudyFilter(users=users), query_user=QueryUser(is_admin=True))
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
Expand Down Expand Up @@ -432,7 +439,7 @@ def test_repository_get_all__study_groups_filter(
# 2- accessing studies attributes does not require additional queries to db
# 3- having an exact total of queries equals to 1
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(study_filter=StudyFilter(groups=groups))
all_studies = repository.get_all(study_filter=StudyFilter(groups=groups), query_user=QueryUser(is_admin=True))
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
Expand Down Expand Up @@ -475,7 +482,9 @@ def test_repository_get_all__study_ids_filter(
# 2- accessing studies attributes does not require additional queries to db
# 3- having an exact total of queries equals to 1
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(study_filter=StudyFilter(study_ids=study_ids))
all_studies = repository.get_all(
study_filter=StudyFilter(study_ids=study_ids), query_user=QueryUser(is_admin=True)
)
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
Expand Down Expand Up @@ -515,7 +524,7 @@ def test_repository_get_all__study_existence_filter(
# 2- accessing studies attributes does not require additional queries to db
# 3- having an exact total of queries equals to 1
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(study_filter=StudyFilter(exists=exists))
all_studies = repository.get_all(study_filter=StudyFilter(exists=exists), query_user=QueryUser(is_admin=True))
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
Expand Down Expand Up @@ -556,7 +565,9 @@ def test_repository_get_all__study_workspace_filter(
# 2- accessing studies attributes does not require additional queries to db
# 3- having an exact total of queries equals to 1
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(study_filter=StudyFilter(workspace=workspace))
all_studies = repository.get_all(
study_filter=StudyFilter(workspace=workspace), query_user=QueryUser(is_admin=True)
)
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
Expand Down Expand Up @@ -599,7 +610,7 @@ def test_repository_get_all__study_folder_filter(
# 2- accessing studies attributes does not require additional queries to db
# 3- having an exact total of queries equals to 1
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(study_filter=StudyFilter(folder=folder))
all_studies = repository.get_all(study_filter=StudyFilter(folder=folder), query_user=QueryUser(is_admin=True))
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
Expand Down Expand Up @@ -650,7 +661,7 @@ def test_repository_get_all__study_tags_filter(
# 2- accessing studies attributes does not require additional queries to db
# 3- having an exact total of queries equals to 1
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(study_filter=StudyFilter(tags=tags))
all_studies = repository.get_all(study_filter=StudyFilter(tags=tags), query_user=QueryUser(is_admin=True))
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
Expand Down

0 comments on commit 5f9824c

Please sign in to comment.