Skip to content

Commit

Permalink
feat(permission-db): check user permission through the db query (#1931)
Browse files Browse the repository at this point in the history
Context: 
Additionally to ANT-1107 related tags db management, another problem of
permission filtering is blocking the pagination from working properly.

Issue: 
The permission to read studies in the search engine function is
performed after the db query. This is creating a problem for our
pagination process.

Solution:
Pass on the permission status directly in queries to the db.
  • Loading branch information
mabw-rte authored Feb 23, 2024
1 parent 3396f2e commit e8a0d6f
Show file tree
Hide file tree
Showing 10 changed files with 1,111 additions and 94 deletions.
8 changes: 6 additions & 2 deletions antarest/launcher/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from antarest.launcher.repository import JobResultRepository
from antarest.launcher.ssh_client import calculates_slurm_load
from antarest.launcher.ssh_config import SSHConfigDTO
from antarest.study.repository import StudyFilter
from antarest.study.repository import AccessPermissions, StudyFilter
from antarest.study.service import StudyService
from antarest.study.storage.utils import assert_permission, extract_output_name, find_single_output_path

Expand Down Expand Up @@ -312,7 +312,11 @@ def _filter_from_user_permission(self, job_results: List[JobResult], user: Optio
if study_ids:
studies = {
study.id: study
for study in self.study_service.repository.get_all(study_filter=StudyFilter(study_ids=study_ids))
for study in self.study_service.repository.get_all(
study_filter=StudyFilter(
study_ids=study_ids, access_permissions=AccessPermissions.from_params(user)
)
)
}
else:
studies = {}
Expand Down
142 changes: 121 additions & 21 deletions antarest/study/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
import typing as t

from pydantic import BaseModel, NonNegativeInt
from sqlalchemy import func, not_, or_ # type: ignore
from sqlalchemy.orm import Session, joinedload, with_polymorphic # type: ignore
from sqlalchemy import and_, func, not_, or_, sql # type: ignore
from sqlalchemy.orm import Query, Session, joinedload, with_polymorphic # type: ignore

from antarest.core.interfaces.cache import ICache
from antarest.core.jwt import JWTUser
from antarest.core.model import PublicMode
from antarest.core.requests import RequestParameters
from antarest.core.utils.fastapi_sqlalchemy import db
from antarest.login.model import Group
from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, Study, StudyAdditionalData, Tag
Expand Down Expand Up @@ -34,6 +37,43 @@ def escape_like(string: str, escape_char: str = "\\") -> str:
return string.replace(escape_char, escape_char * 2).replace("%", escape_char + "%").replace("_", escape_char + "_")


class AccessPermissions(BaseModel, frozen=True, extra="forbid"):
"""
This class object is build to pass on the user identity and its associated groups information
into the listing function get_all below
"""

is_admin: bool = False
user_id: t.Optional[int] = None
user_groups: t.Sequence[str] = ()

@classmethod
def from_params(cls, params: t.Union[RequestParameters, JWTUser]) -> "AccessPermissions":
"""
This function makes it easier to pass on user ids and groups into the repository filtering function by
extracting the associated `AccessPermissions` object.
Args:
params: `RequestParameters` or `JWTUser` holding user ids and groups
Returns: `AccessPermissions`
"""
if isinstance(params, RequestParameters):
user = params.user
else:
user = params

if user:
return cls(
is_admin=user.is_site_admin() or user.is_admin_token(),
user_id=user.id,
user_groups=[group.id for group in user.groups],
)
else:
return cls()


class StudyFilter(BaseModel, frozen=True, extra="forbid"):
"""Study filter class gathering the main filtering parameters
Expand All @@ -50,6 +90,7 @@ class StudyFilter(BaseModel, frozen=True, extra="forbid"):
exists: if raw study missing
workspace: optional workspace of the study
folder: optional folder prefix of the study
access_permissions: query user ID, groups and admins status
"""

name: str = ""
Expand All @@ -64,6 +105,7 @@ class StudyFilter(BaseModel, frozen=True, extra="forbid"):
exists: t.Optional[bool] = None
workspace: str = ""
folder: str = ""
access_permissions: AccessPermissions = AccessPermissions()


class StudySortBy(str, enum.Enum):
Expand Down Expand Up @@ -198,6 +240,63 @@ def get_all(
# efficiently (see: `AbstractStorageService.get_study_information`)
entity = with_polymorphic(Study, "*")

q = self._search_studies(study_filter)

# sorting
if sort_by:
if sort_by == StudySortBy.DATE_DESC:
q = q.order_by(entity.created_at.desc())
elif sort_by == StudySortBy.DATE_ASC:
q = q.order_by(entity.created_at.asc())
elif sort_by == StudySortBy.NAME_DESC:
q = q.order_by(func.upper(entity.name).desc())
elif sort_by == StudySortBy.NAME_ASC:
q = q.order_by(func.upper(entity.name).asc())
else:
raise NotImplementedError(sort_by)

# pagination
if pagination.page_nb or pagination.page_size:
q = q.offset(pagination.page_nb * pagination.page_size).limit(pagination.page_size)

studies: t.Sequence[Study] = q.all()
return studies

def count_studies(self, study_filter: StudyFilter = StudyFilter()) -> int:
"""
Count all studies matching with specified filters.
Args:
study_filter: composed of all filtering criteria.
Returns:
Integer, corresponding to total number of studies matching with specified filters.
"""
q = self._search_studies(study_filter)

total: int = q.count()

return total

def _search_studies(
self,
study_filter: StudyFilter,
) -> Query:
"""
Build a `SQL Query` based on specified filters.
Args:
study_filter: composed of all filtering criteria.
Returns:
The `Query` corresponding to specified criteria (except for permissions).
"""
# When we fetch a study, we also need to fetch the associated owner and groups
# to check the permissions of the current user efficiently.
# We also need to fetch the additional data to display the study information
# efficiently (see: `AbstractStorageService.get_study_information`)
entity = with_polymorphic(Study, "*")

# noinspection PyTypeChecker
q = self.session.query(entity)
if study_filter.exists is not None:
Expand All @@ -216,11 +315,9 @@ def get_all(
q = q.filter(entity.type == "rawstudy")
q = q.filter(RawStudy.workspace != DEFAULT_WORKSPACE_NAME)
if study_filter.study_ids:
q = q.filter(entity.id.in_(study_filter.study_ids))
q = q.filter(entity.id.in_(study_filter.study_ids)) if study_filter.study_ids else q
if study_filter.users:
q = q.filter(entity.owner_id.in_(study_filter.users))
if study_filter.groups:
q = q.join(entity.groups).filter(Group.id.in_(study_filter.groups))
if study_filter.tags:
upper_tags = [tag.upper() for tag in study_filter.tags]
q = q.join(entity.tags).filter(func.upper(Tag.label).in_(upper_tags))
Expand All @@ -242,24 +339,27 @@ def get_all(
if study_filter.versions:
q = q.filter(entity.version.in_(study_filter.versions))

if sort_by:
if sort_by == StudySortBy.DATE_DESC:
q = q.order_by(entity.created_at.desc())
elif sort_by == StudySortBy.DATE_ASC:
q = q.order_by(entity.created_at.asc())
elif sort_by == StudySortBy.NAME_DESC:
q = q.order_by(func.upper(entity.name).desc())
elif sort_by == StudySortBy.NAME_ASC:
q = q.order_by(func.upper(entity.name).asc())
# permissions + groups filtering
if not study_filter.access_permissions.is_admin and study_filter.access_permissions.user_id is not None:
condition_1 = entity.public_mode != PublicMode.NONE
condition_2 = entity.owner_id == study_filter.access_permissions.user_id
q1 = q.join(entity.groups).filter(Group.id.in_(study_filter.access_permissions.user_groups))
if study_filter.groups:
q2 = q.join(entity.groups).filter(Group.id.in_(study_filter.groups))
q2 = q1.intersect(q2)
q = q2.union(
q.join(entity.groups).filter(and_(or_(condition_1, condition_2), Group.id.in_(study_filter.groups)))
)
else:
raise NotImplementedError(sort_by)

# pagination
if pagination.page_nb or pagination.page_size:
q = q.offset(pagination.page_nb * pagination.page_size).limit(pagination.page_size)
q = q1.union(q.filter(or_(condition_1, condition_2)))
elif not study_filter.access_permissions.is_admin and study_filter.access_permissions.user_id is None:
# return empty result
# noinspection PyTypeChecker
q = self.session.query(entity).filter(sql.false())
elif study_filter.groups:
q = q.join(entity.groups).filter(Group.id.in_(study_filter.groups))

studies: t.Sequence[Study] = q.all()
return studies
return q

def get_all_raw(self, exists: t.Optional[bool] = None) -> t.Sequence[RawStudy]:
query = self.session.query(RawStudy)
Expand Down
39 changes: 23 additions & 16 deletions antarest/study/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,13 @@
StudyMetadataPatchDTO,
StudySimResultDTO,
)
from antarest.study.repository import StudyFilter, StudyMetadataRepository, StudyPagination, StudySortBy
from antarest.study.repository import (
AccessPermissions,
StudyFilter,
StudyMetadataRepository,
StudyPagination,
StudySortBy,
)
from antarest.study.storage.matrix_profile import adjust_matrix_columns_index
from antarest.study.storage.rawstudy.model.filesystem.config.model import FileStudyTreeConfigDTO
from antarest.study.storage.rawstudy.model.filesystem.folder_node import ChildNotFoundError
Expand Down Expand Up @@ -445,15 +451,13 @@ def edit_comments(

def get_studies_information(
self,
params: RequestParameters,
study_filter: StudyFilter,
sort_by: t.Optional[StudySortBy] = None,
pagination: StudyPagination = StudyPagination(),
) -> t.Dict[str, StudyMetadataDTO]:
"""
Get information for matching studies of a search query.
Args:
params: request parameters
study_filter: filtering parameters
sort_by: how to sort the db query results
pagination: set offset and limit for db query
Expand All @@ -472,18 +476,7 @@ def get_studies_information(
study_metadata = self._try_get_studies_information(study)
if study_metadata is not None:
studies[study_metadata.id] = study_metadata
return {
s.id: s
for s in filter(
lambda study_dto: assert_permission(
params.user,
study_dto,
StudyPermissionType.READ,
raising=False,
),
studies.values(),
)
}
return studies

def _try_get_studies_information(self, study: Study) -> t.Optional[StudyMetadataDTO]:
try:
Expand Down Expand Up @@ -2139,10 +2132,24 @@ def update_matrix(
raise BadEditInstructionException(str(exc)) from exc

def check_and_update_all_study_versions_in_database(self, params: RequestParameters) -> None:
"""
This function updates studies version on the db.
**Warnings: Only users with Admins rights should be able to run this function.**
Args:
params: Request parameters holding user ID and groups
Raises:
UserHasNotPermissionError: if params user is not admin.
"""
if params.user and not params.user.is_site_admin():
logger.error(f"User {params.user.id} is not site admin")
raise UserHasNotPermissionError()
studies = self.repository.get_all(study_filter=StudyFilter(managed=False))
studies = self.repository.get_all(
study_filter=StudyFilter(managed=False, access_permissions=AccessPermissions.from_params(params))
)

for study in studies:
storage = self.storage_service.raw_study_service
Expand Down
7 changes: 5 additions & 2 deletions antarest/study/storage/auto_archive_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from antarest.core.requests import RequestParameters
from antarest.core.utils.fastapi_sqlalchemy import db
from antarest.study.model import RawStudy, Study
from antarest.study.repository import StudyFilter
from antarest.study.repository import AccessPermissions, StudyFilter
from antarest.study.service import StudyService
from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy

Expand All @@ -28,7 +28,10 @@ def __init__(self, study_service: StudyService, config: Config):
def _try_archive_studies(self) -> None:
old_date = datetime.datetime.utcnow() - datetime.timedelta(days=self.config.storage.auto_archive_threshold_days)
with db():
studies: t.Sequence[Study] = self.study_service.repository.get_all(study_filter=StudyFilter(managed=True))
# in this part full `Read` rights over studies are granted to this function
studies: t.Sequence[Study] = self.study_service.repository.get_all(
study_filter=StudyFilter(managed=True, access_permissions=AccessPermissions(is_admin=True))
)
# list of study IDs and boolean indicating if it's a raw study (True) or a variant (False)
study_ids_to_archive = [
(study.id, isinstance(study, RawStudy))
Expand Down
1 change: 1 addition & 0 deletions antarest/study/storage/rawstudy/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def _loop(self) -> None:
"Removing duplicates, this is a temporary fix that should be removed when previous duplicates are removed"
)
with db():
# in this part full `Read` rights over studies are granted to this function
self.study_service.remove_duplicates()
except Exception as e:
logger.error("Unexpected error when removing duplicates", exc_info=e)
Expand Down
9 changes: 6 additions & 3 deletions antarest/study/web/studies_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from antarest.core.filetransfer.service import FileTransferManager
from antarest.core.jwt import JWTUser
from antarest.core.model import PublicMode
from antarest.core.requests import RequestParameters
from antarest.core.requests import RequestParameters, UserHasNotPermissionError
from antarest.core.utils.utils import BadArchiveContent, sanitize_uuid
from antarest.core.utils.web import APITag
from antarest.login.auth import Auth
Expand All @@ -28,7 +28,7 @@
StudyMetadataPatchDTO,
StudySimResultDTO,
)
from antarest.study.repository import StudyFilter, StudyPagination, StudySortBy
from antarest.study.repository import AccessPermissions, StudyFilter, StudyPagination, StudySortBy
from antarest.study.service import StudyService
from antarest.study.storage.rawstudy.model.filesystem.config.model import FileStudyTreeConfigDTO

Expand Down Expand Up @@ -144,6 +144,9 @@ def get_studies(

user_list = [int(v) for v in _split_comma_separated_values(users)]

if not params.user:
raise UserHasNotPermissionError("FAIL permission: user is not logged")

study_filter = StudyFilter(
name=name,
managed=managed,
Expand All @@ -157,10 +160,10 @@ def get_studies(
exists=exists,
workspace=workspace,
folder=folder,
access_permissions=AccessPermissions.from_params(params),
)

matching_studies = study_service.get_studies_information(
params=params,
study_filter=study_filter,
sort_by=sort_by,
pagination=StudyPagination(page_nb=page_nb, page_size=page_size),
Expand Down
Loading

0 comments on commit e8a0d6f

Please sign in to comment.