Skip to content

Commit

Permalink
refactor(permission-db): redefine access permissions, studies countin…
Browse files Browse the repository at this point in the history
…g and searching
  • Loading branch information
mabw-rte committed Feb 16, 2024
1 parent 817f383 commit ed23924
Show file tree
Hide file tree
Showing 8 changed files with 177 additions and 105 deletions.
6 changes: 4 additions & 2 deletions antarest/launcher/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from antarest.launcher.repository import JobResultRepository
from antarest.launcher.ssh_client import calculates_slurm_load
from antarest.launcher.ssh_config import SSHConfigDTO
from antarest.study.repository import QueryUser, StudyFilter, build_query_user_from_params
from antarest.study.repository import AccessPermissions, StudyFilter
from antarest.study.service import StudyService
from antarest.study.storage.utils import assert_permission, extract_output_name, find_single_output_path

Expand Down Expand Up @@ -313,7 +313,9 @@ def _filter_from_user_permission(self, job_results: List[JobResult], user: Optio
studies = {
study.id: study
for study in self.study_service.repository.get_all(
study_filter=StudyFilter(study_ids=study_ids, query_user=build_query_user_from_params(user))
study_filter=StudyFilter(
study_ids=study_ids, access_permissions=AccessPermissions.from_params(user)
)
)
}
else:
Expand Down
193 changes: 121 additions & 72 deletions antarest/study/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pydantic import BaseModel, NonNegativeInt
from sqlalchemy import func, not_, or_ # type: ignore
from sqlalchemy.orm import Session, joinedload, with_polymorphic # type: ignore
from sqlalchemy.orm import Query, Session, joinedload, with_polymorphic # type: ignore

from antarest.core.interfaces.cache import ICache
from antarest.core.jwt import JWTUser
Expand Down Expand Up @@ -37,7 +37,7 @@ def escape_like(string: str, escape_char: str = "\\") -> str:
return string.replace(escape_char, escape_char * 2).replace("%", escape_char + "%").replace("_", escape_char + "_")


class QueryUser(BaseModel, frozen=True, extra="forbid"):
class AccessPermissions(BaseModel, frozen=True, extra="forbid"):
"""
This class object is build to pass on the user identity and its associated groups information
into the listing function get_all below
Expand All @@ -47,30 +47,30 @@ class QueryUser(BaseModel, frozen=True, extra="forbid"):
user_id: t.Optional[int] = None
user_groups: t.Sequence[str] = ()

@classmethod
def from_params(cls, params: t.Union[RequestParameters, JWTUser]) -> "AccessPermissions":
"""
This function makes it easier to pass on user ids and groups into the repository filtering function by
extracting the associated `AccessPermissions` object.
Args:
params: `RequestParameters` or `JWTUser` holding user ids and groups
def build_query_user_from_params(params: t.Union[RequestParameters, JWTUser]) -> QueryUser:
"""
This function makes it easier to pass on user ids and groups into the repository filtering function by
extracting the associated `QueryUser` object.
Args:
params: `RequestParameters` or `JWTUser` holding user ids and groups
Returns: `AccessPermissions`
Returns: `QueryUser`
"""
if isinstance(params, RequestParameters):
user = params.user
else:
user = params

"""
if isinstance(params, RequestParameters):
user = params.user
else:
user = params

if user:
return QueryUser(
is_admin=user.is_site_admin() or user.is_admin_token(),
user_id=user.id,
user_groups=[group.id for group in user.groups],
)
else:
return QueryUser()
if user:
return cls(
is_admin=user.is_site_admin() or user.is_admin_token(),
user_id=user.id,
user_groups=[group.id for group in user.groups],
)
else:
return cls()


class StudyFilter(BaseModel, frozen=True, extra="forbid"):
Expand All @@ -89,7 +89,7 @@ class StudyFilter(BaseModel, frozen=True, extra="forbid"):
exists: if raw study missing
workspace: optional workspace of the study
folder: optional folder prefix of the study
query_user: query user id, groups and admins status
access_permissions: query user id, groups and admins status
"""

name: str = ""
Expand All @@ -104,7 +104,7 @@ class StudyFilter(BaseModel, frozen=True, extra="forbid"):
exists: t.Optional[bool] = None
workspace: str = ""
folder: str = ""
query_user: QueryUser = QueryUser()
access_permissions: AccessPermissions = AccessPermissions()


class StudySortBy(str, enum.Enum):
Expand Down Expand Up @@ -233,6 +233,89 @@ def get_all(
Returns:
The matching studies in proper order and pagination.
"""

# When we fetch a study, we also need to fetch the associated owner and groups
# to check the permissions of the current user efficiently.
# We also need to fetch the additional data to display the study information
# efficiently (see: `AbstractStorageService.get_study_information`)
entity = with_polymorphic(Study, "*")

q = self._search_studies(study_filter)

# permissions filtering
if not study_filter.access_permissions.is_admin:
if study_filter.access_permissions.user_id is None:
return []
condition_1 = entity.public_mode != PublicMode.NONE
condition_2 = entity.owner_id == study_filter.access_permissions.user_id
condition_3 = Group.id.in_(study_filter.access_permissions.user_groups or [])
q0 = q.filter(condition_3) if study_filter.groups else q.join(entity.groups).filter(condition_3)
q = q0.union(q.filter(or_(condition_1, condition_2)))

# sorting
if sort_by:
if sort_by == StudySortBy.DATE_DESC:
q = q.order_by(entity.created_at.desc())
elif sort_by == StudySortBy.DATE_ASC:
q = q.order_by(entity.created_at.asc())
elif sort_by == StudySortBy.NAME_DESC:
q = q.order_by(func.upper(entity.name).desc())
elif sort_by == StudySortBy.NAME_ASC:
q = q.order_by(func.upper(entity.name).asc())
else:
raise NotImplementedError(sort_by)

# pagination
if pagination.page_nb or pagination.page_size:
q = q.offset(pagination.page_nb * pagination.page_size).limit(pagination.page_size)

studies: t.Sequence[Study] = q.all()
return studies

def count_studies(self, study_filter: StudyFilter = StudyFilter()) -> int:
"""
Count all studies matching with specified filters.
Args:
study_filter: composed of all filtering criteria.
Returns:
Integer, corresponding to total number of studies matching with specified filters.
"""
# When we fetch a study, we also need to fetch the associated owner and groups
# to check the permissions of the current user efficiently.
# We also need to fetch the additional data to display the study information
# efficiently (see: `AbstractStorageService.get_study_information`)
entity = with_polymorphic(Study, "*")

q = self._search_studies(study_filter)

# permissions filtering
if not study_filter.access_permissions.is_admin:
if study_filter.access_permissions.user_id is None:
return 0
condition_1 = entity.public_mode != PublicMode.NONE
condition_2 = entity.owner_id == study_filter.access_permissions.user_id
condition_3 = Group.id.in_(study_filter.access_permissions.user_groups or [])
q0 = q.filter(condition_3) if study_filter.groups else q.join(entity.groups).filter(condition_3)
q = q0.union(q.filter(or_(condition_1, condition_2)))
total: int = q.count()

return total

def _search_studies(
self,
study_filter: StudyFilter = StudyFilter(),
) -> Query:
"""
Build a `SQL Query` based on specified filters.
Args:
study_filter: composed of all filtering criteria.
Returns:
The `Query` corresponding to specified criteria (except for permissions).
"""
# When we fetch a study, we also need to fetch the associated owner and groups
# to check the permissions of the current user efficiently.
# We also need to fetch the additional data to display the study information
Expand All @@ -242,22 +325,23 @@ def get_all(
# noinspection PyTypeChecker
q = self.session.query(entity)
if study_filter.exists is not None:
if study_filter.exists:
q = q.filter(RawStudy.missing.is_(None))
else:
q = q.filter(not_(RawStudy.missing.is_(None)))
q = (
q.filter(RawStudy.missing.is_(None))
if study_filter.exists
else q.filter(not_(RawStudy.missing.is_(None)))
)
q = q.options(joinedload(entity.owner))
q = q.options(joinedload(entity.groups))
q = q.options(joinedload(entity.additional_data))
q = q.options(joinedload(entity.tags))
if study_filter.managed is not None:
if study_filter.managed:
q = q.filter(or_(entity.type == "variantstudy", RawStudy.workspace == DEFAULT_WORKSPACE_NAME))
else:
q = q.filter(entity.type == "rawstudy")
q = q.filter(RawStudy.workspace != DEFAULT_WORKSPACE_NAME)
q = (
q.filter(or_(entity.type == "variantstudy", RawStudy.workspace == DEFAULT_WORKSPACE_NAME))
if study_filter.managed
else q.filter(entity.type == "rawstudy").filter(RawStudy.workspace != DEFAULT_WORKSPACE_NAME)
)
if study_filter.study_ids:
q = q.filter(entity.id.in_(study_filter.study_ids))
q = q.filter(entity.id.in_(study_filter.study_ids)) if study_filter.study_ids else q
if study_filter.users:
q = q.filter(entity.owner_id.in_(study_filter.users))
if study_filter.groups:
Expand All @@ -275,46 +359,11 @@ def get_all(
if study_filter.workspace:
q = q.filter(RawStudy.workspace == study_filter.workspace)
if study_filter.variant is not None:
if study_filter.variant:
q = q.filter(entity.type == "variantstudy")
else:
q = q.filter(entity.type == "rawstudy")
q = q.filter(entity.type == "variantstudy") if study_filter.variant else q.filter(entity.type == "rawstudy")
if study_filter.versions:
q = q.filter(entity.version.in_(study_filter.versions))

# permissions filtering
if not study_filter.query_user.is_admin:
if study_filter.query_user.user_id is not None:
condition_1 = entity.public_mode != PublicMode.NONE
condition_2 = entity.owner_id == study_filter.query_user.user_id
condition_3 = Group.id.in_(study_filter.query_user.user_groups or [])
if study_filter.groups:
q0 = q.filter(condition_3)
q = q0.union(q.filter(or_(condition_1, condition_2)))
else:
q0 = q.join(entity.groups).filter(condition_3)
q = q0.union(q.filter(or_(condition_1, condition_2)))
else:
return []

if sort_by:
if sort_by == StudySortBy.DATE_DESC:
q = q.order_by(entity.created_at.desc())
elif sort_by == StudySortBy.DATE_ASC:
q = q.order_by(entity.created_at.asc())
elif sort_by == StudySortBy.NAME_DESC:
q = q.order_by(func.upper(entity.name).desc())
elif sort_by == StudySortBy.NAME_ASC:
q = q.order_by(func.upper(entity.name).asc())
else:
raise NotImplementedError(sort_by)

# pagination
if pagination.page_nb or pagination.page_size:
q = q.offset(pagination.page_nb * pagination.page_size).limit(pagination.page_size)

studies: t.Sequence[Study] = q.all()
return studies
return q

def get_all_raw(self, exists: t.Optional[bool] = None) -> t.Sequence[RawStudy]:
query = self.session.query(RawStudy)
Expand Down
5 changes: 2 additions & 3 deletions antarest/study/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,11 @@
StudySimResultDTO,
)
from antarest.study.repository import (
QueryUser,
AccessPermissions,
StudyFilter,
StudyMetadataRepository,
StudyPagination,
StudySortBy,
build_query_user_from_params,
)
from antarest.study.storage.matrix_profile import adjust_matrix_columns_index
from antarest.study.storage.rawstudy.model.filesystem.config.model import FileStudyTreeConfigDTO
Expand Down Expand Up @@ -2161,7 +2160,7 @@ def check_and_update_all_study_versions_in_database(self, params: RequestParamet
logger.error(f"User {params.user.id} is not site admin")
raise UserHasNotPermissionError()
studies = self.repository.get_all(
study_filter=StudyFilter(managed=False, query_user=build_query_user_from_params(params))
study_filter=StudyFilter(managed=False, access_permissions=AccessPermissions.from_params(params))
)

for study in studies:
Expand Down
4 changes: 2 additions & 2 deletions antarest/study/storage/auto_archive_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from antarest.core.requests import RequestParameters
from antarest.core.utils.fastapi_sqlalchemy import db
from antarest.study.model import RawStudy, Study
from antarest.study.repository import QueryUser, StudyFilter
from antarest.study.repository import AccessPermissions, StudyFilter
from antarest.study.service import StudyService
from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy

Expand All @@ -30,7 +30,7 @@ def _try_archive_studies(self) -> None:
with db():
# in this part full `Read` rights over studies are granted to this function
studies: t.Sequence[Study] = self.study_service.repository.get_all(
study_filter=StudyFilter(managed=True, query_user=QueryUser(is_admin=True))
study_filter=StudyFilter(managed=True, access_permissions=AccessPermissions(is_admin=True))
)
# list of study IDs and boolean indicating if it's a raw study (True) or a variant (False)
study_ids_to_archive = [
Expand Down
4 changes: 2 additions & 2 deletions antarest/study/web/studies_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
StudyMetadataPatchDTO,
StudySimResultDTO,
)
from antarest.study.repository import StudyFilter, StudyPagination, StudySortBy, build_query_user_from_params
from antarest.study.repository import AccessPermissions, StudyFilter, StudyPagination, StudySortBy
from antarest.study.service import StudyService
from antarest.study.storage.rawstudy.model.filesystem.config.model import FileStudyTreeConfigDTO

Expand Down Expand Up @@ -161,7 +161,7 @@ def get_studies(
exists=exists,
workspace=workspace,
folder=folder,
query_user=build_query_user_from_params(params),
access_permissions=AccessPermissions.from_params(params),
)

matching_studies = study_service.get_studies_information(
Expand Down
4 changes: 2 additions & 2 deletions tests/storage/repository/test_study.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from antarest.core.model import PublicMode
from antarest.login.model import Group, User
from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, Study, StudyContentStatus
from antarest.study.repository import QueryUser, StudyFilter, StudyMetadataRepository
from antarest.study.repository import AccessPermissions, StudyFilter, StudyMetadataRepository
from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy


Expand Down Expand Up @@ -68,7 +68,7 @@ def test_lifecycle(db_session: Session) -> None:
c = repo.one(a_id)
assert a_id == c.id

assert len(repo.get_all(study_filter=StudyFilter(query_user=QueryUser(is_admin=True)))) == 4
assert len(repo.get_all(study_filter=StudyFilter(access_permissions=AccessPermissions(is_admin=True)))) == 4
assert len(repo.get_all_raw(exists=True)) == 1
assert len(repo.get_all_raw(exists=False)) == 1
assert len(repo.get_all_raw()) == 2
Expand Down
Loading

0 comments on commit ed23924

Please sign in to comment.