Skip to content

Commit

Permalink
feat(study-search): add a studies counting endpoint (#1942)
Browse files Browse the repository at this point in the history
Context:
Related to ANT-1107 (tags-db) and ANT-1106 (permissions-db), it happens
that front-end can not predict the total number of studies matching some
given filtering parameters, to perform the pagination properly.

Solution:
Add an endpoint that return the total studies count.
  • Loading branch information
mabw-rte authored Feb 27, 2024
1 parent 8d95e70 commit bae236d
Show file tree
Hide file tree
Showing 4 changed files with 371 additions and 45 deletions.
16 changes: 16 additions & 0 deletions antarest/study/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,22 @@ def get_studies_information(
studies[study_metadata.id] = study_metadata
return studies

def count_studies(
self,
study_filter: StudyFilter,
) -> int:
"""
Get number of matching studies.
Args:
study_filter: filtering parameters
Returns: total number of studies matching the filtering criteria
"""
total: int = self.repository.count_studies(
study_filter=study_filter,
)
return total

def _try_get_studies_information(self, study: Study) -> t.Optional[StudyMetadataDTO]:
try:
return self.storage_service.get_storage(study).get_study_information(study)
Expand Down
102 changes: 78 additions & 24 deletions antarest/study/web/studies_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@

logger = logging.getLogger(__name__)

QUERY_REGEX = r"^\s*(?:\d+\s*(?:,\s*\d+\s*)*)?$"


def _split_comma_separated_values(value: str, *, default: t.Sequence[str] = ()) -> t.Sequence[str]:
"""Split a comma-separated list of values into an ordered set of strings."""
Expand Down Expand Up @@ -76,23 +78,11 @@ def get_studies(
managed: t.Optional[bool] = Query(None, description="Filter studies based on their management status."),
archived: t.Optional[bool] = Query(None, description="Filter studies based on their archive status."),
variant: t.Optional[bool] = Query(None, description="Filter studies based on their variant status."),
versions: str = Query(
"",
description="Comma-separated list of versions for filtering.",
regex=r"^\s*(?:\d+\s*(?:,\s*\d+\s*)*)?$",
),
users: str = Query(
"",
description="Comma-separated list of user IDs for filtering.",
regex=r"^\s*(?:\d+\s*(?:,\s*\d+\s*)*)?$",
),
versions: str = Query("", description="Comma-separated list of versions for filtering.", regex=QUERY_REGEX),
users: str = Query("", description="Comma-separated list of user IDs for filtering.", regex=QUERY_REGEX),
groups: str = Query("", description="Comma-separated list of group IDs for filtering."),
tags: str = Query("", description="Comma-separated list of tags for filtering."),
study_ids: str = Query(
"",
description="Comma-separated list of study IDs for filtering.",
alias="studyIds",
),
study_ids: str = Query("", description="Comma-separated list of study IDs for filtering.", alias="studyIds"),
exists: t.Optional[bool] = Query(None, description="Filter studies based on their existence on disk."),
workspace: str = Query("", description="Filter studies based on their workspace."),
folder: str = Query("", description="Filter studies based on their folder."),
Expand All @@ -102,23 +92,17 @@ def get_studies(
description="Sort studies based on their name (case-insensitive) or creation date.",
alias="sortBy",
),
page_nb: NonNegativeInt = Query(
0,
description="Page number (starting from 0).",
alias="pageNb",
),
page_nb: NonNegativeInt = Query(0, description="Page number (starting from 0).", alias="pageNb"),
page_size: NonNegativeInt = Query(
0,
description="Number of studies per page (0 = no limit).",
alias="pageSize",
0, description="Number of studies per page (0 = no limit).", alias="pageSize"
),
) -> t.Dict[str, StudyMetadataDTO]:
"""
Get the list of studies matching the specified criteria.
Args:
- `name`: Filter studies based on their name. Case-insensitive search for studies
whose name contains the specified value.
- `managed`: Filter studies based on their management status.
- `archived`: Filter studies based on their archive status.
- `variant`: Filter studies based on their variant status.
Expand Down Expand Up @@ -171,6 +155,76 @@ def get_studies(

return matching_studies

@bp.get(
"/studies/count",
tags=[APITag.study_management],
summary="Count Studies",
)
def count_studies(
current_user: JWTUser = Depends(auth.get_current_user),
name: str = Query("", description="Case-insensitive: filter studies based on their name.", alias="name"),
managed: t.Optional[bool] = Query(None, description="Management status filter."),
archived: t.Optional[bool] = Query(None, description="Archive status filter."),
variant: t.Optional[bool] = Query(None, description="Variant status filter."),
versions: str = Query("", description="Comma-separated versions filter.", regex=QUERY_REGEX),
users: str = Query("", description="Comma-separated user IDs filter.", regex=QUERY_REGEX),
groups: str = Query("", description="Comma-separated group IDs filter."),
tags: str = Query("", description="Comma-separated tags filter."),
study_ids: str = Query("", description="Comma-separated study IDs filter.", alias="studyIds"),
exists: t.Optional[bool] = Query(None, description="Existence on disk filter."),
workspace: str = Query("", description="Workspace filter."),
folder: str = Query("", description="Study folder filter."),
) -> int:
"""
Get the number of studies matching the specified criteria.
Args:
- `name`: Regexp to filter through studies based on their names
- `managed`: Whether to limit the selection based on management status.
- `archived`: Whether to limit the selection based on archive status.
- `variant`: Whether to limit the selection either raw or variant studies.
- `versions`: Comma-separated versions for studies to be selected.
- `users`: Comma-separated user IDs for studies to be selected.
- `groups`: Comma-separated group IDs for studies to be selected.
- `tags`: Comma-separated tags for studies to be selected.
- `studyIds`: Comma-separated IDs of studies to be selected.
- `exists`: Whether to limit the selection based on studies' existence on disk.
- `workspace`: to limit studies selection based on their workspace.
- `folder`: to limit studies selection based on their folder.
Returns:
- An integer representing the total number of studies matching the filters above and the user permissions.
"""

logger.info("Counting matching studies", extra={"user": current_user.id})
params = RequestParameters(user=current_user)

user_list = [int(v) for v in _split_comma_separated_values(users)]

if not params.user:
raise UserHasNotPermissionError("FAIL permission: user is not logged")

count = study_service.count_studies(
study_filter=StudyFilter(
name=name,
managed=managed,
archived=archived,
variant=variant,
versions=_split_comma_separated_values(versions),
users=user_list,
groups=_split_comma_separated_values(groups),
tags=_split_comma_separated_values(tags),
study_ids=_split_comma_separated_values(study_ids),
exists=exists,
workspace=workspace,
folder=folder,
access_permissions=AccessPermissions.from_params(params),
),
)

return count

@bp.get(
"/studies/{uuid}/comments",
tags=[APITag.study_management],
Expand Down
66 changes: 66 additions & 0 deletions tests/integration/studies_blueprint/test_get_studies.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,15 @@ def test_study_listing(
study_map = res.json()
assert not all_studies.intersection(study_map)
assert all(map(lambda x: pm(x) in [PublicMode.READ, PublicMode.FULL], study_map.values()))
# test pagination
res = client.get(
STUDIES_URL,
headers={"Authorization": f"Bearer {john_doe_access_token}"},
params={"pageNb": 1, "pageSize": 2},
)
assert res.status_code == LIST_STATUS_CODE, res.json()
page_studies = res.json()
assert len(page_studies) == max(0, min(2, len(study_map) - 2))

# test 1.b for an admin user
res = client.get(
Expand All @@ -463,6 +472,31 @@ def test_study_listing(
assert res.status_code == LIST_STATUS_CODE, res.json()
study_map = res.json()
assert not all_studies.difference(study_map)
# test pagination
res = client.get(
STUDIES_URL,
headers={"Authorization": f"Bearer {admin_access_token}"},
params={"pageNb": 1, "pageSize": 2},
)
assert res.status_code == LIST_STATUS_CODE, res.json()
page_studies = res.json()
assert len(page_studies) == max(0, min(len(study_map) - 2, 2))
# test pagination concatenation
paginated_studies = {}
page_number = 0
number_of_pages = 0
while len(paginated_studies) < len(study_map):
res = client.get(
STUDIES_URL,
headers={"Authorization": f"Bearer {admin_access_token}"},
params={"pageNb": page_number, "pageSize": 2},
)
assert res.status_code == LIST_STATUS_CODE, res.json()
paginated_studies.update(res.json())
page_number += 1
number_of_pages += 1
assert paginated_studies == study_map
assert number_of_pages == len(study_map) // 2 + len(study_map) % 2

# test 1.c for a user with access to select studies
res = client.get(
Expand Down Expand Up @@ -620,6 +654,15 @@ def test_study_listing(
study_map = res.json()
assert not all_studies.difference(studies_version_850.union(studies_version_860)).intersection(study_map)
assert not studies_version_850.union(studies_version_860).difference(study_map)
# test pagination
res = client.get(
STUDIES_URL,
headers={"Authorization": f"Bearer {admin_access_token}"},
params={"versions": "850,860", "pageNb": 1, "pageSize": 2},
)
assert res.status_code == LIST_STATUS_CODE, res.json()
page_studies = res.json()
assert len(page_studies) == max(0, min(len(study_map) - 2, 2))

# tests (7) for users filtering
# test 7.a to get studies for one user: James Bond
Expand Down Expand Up @@ -1318,6 +1361,7 @@ def test_get_studies__access_permissions(self, client: TestClient, admin_access_
# fmt: off
([], {"1", "2", "5", "6", "7", "8", "9", "10", "13", "14", "15", "16", "17",
"18", "21", "22", "23", "24", "25", "26", "29", "30", "31", "32", "34"}),
# fmt: on
(["1"], {"1", "7", "8", "9", "17", "23", "24", "25"}),
(["2"], {"2", "5", "6", "7", "8", "9", "18", "21", "22", "23", "24", "25", "34"}),
(["3"], set()),
Expand All @@ -1343,12 +1387,23 @@ def test_get_studies__access_permissions(self, client: TestClient, admin_access_
study_map = res.json()
assert not expected_studies.difference(set(study_map))
assert not all_studies.difference(expected_studies).intersection(set(study_map))
# test pagination
res = client.get(
STUDIES_URL,
headers={"Authorization": f"Bearer {users_tokens['user_1']}"},
params={"groups": ",".join(request_groups_ids), "pageNb": 1, "pageSize": 2}
if request_groups_ids
else {"pageNb": 1, "pageSize": 2},
)
assert res.status_code == LIST_STATUS_CODE, res.json()
assert len(res.json()) == max(0, min(2, len(expected_studies) - 2))

# user_2 access
requests_params_expected_studies = [
# fmt: off
([], {"1", "3", "4", "5", "7", "8", "9", "11", "13", "14", "15", "16", "17",
"19", "20", "21", "23", "24", "25", "27", "29", "30", "31", "32", "33"}),
# fmt: on
(["1"], {"1", "3", "4", "7", "8", "9", "17", "19", "20", "23", "24", "25", "33"}),
(["2"], {"5", "7", "8", "9", "21", "23", "24", "25"}),
(["3"], set()),
Expand Down Expand Up @@ -1473,3 +1528,14 @@ def test_get_studies__invalid_parameters(
assert res.status_code == INVALID_PARAMS_STATUS_CODE, res.json()
description = res.json()["description"]
assert re.search(r"could not be parsed to a boolean", description), f"{description=}"


def test_studies_counting(client: TestClient, admin_access_token: str, user_access_token: str) -> None:
# test admin and non admin user studies count requests
for access_token in [admin_access_token, user_access_token]:
res = client.get(STUDIES_URL, headers={"Authorization": f"Bearer {access_token}"})
assert res.status_code == 200, res.json()
expected_studies_count = len(res.json())
res = client.get(STUDIES_URL + "/count", headers={"Authorization": f"Bearer {access_token}"})
assert res.status_code == 200, res.json()
assert res.json() == expected_studies_count
Loading

0 comments on commit bae236d

Please sign in to comment.