Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(study-search): add a studies counting endpoint #1942

Merged
merged 5 commits into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions antarest/study/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,22 @@ def get_studies_information(
studies[study_metadata.id] = study_metadata
return studies

def count_studies(
self,
study_filter: StudyFilter,
) -> int:
"""
Get number of matching studies.
Args:
study_filter: filtering parameters

Returns: total number of studies matching the filtering criteria
"""
total: int = self.repository.count_studies(
study_filter=study_filter,
)
return total

def _try_get_studies_information(self, study: Study) -> t.Optional[StudyMetadataDTO]:
try:
return self.storage_service.get_storage(study).get_study_information(study)
Expand Down
102 changes: 78 additions & 24 deletions antarest/study/web/studies_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@

logger = logging.getLogger(__name__)

QUERY_REGEX = r"^\s*(?:\d+\s*(?:,\s*\d+\s*)*)?$"


def _split_comma_separated_values(value: str, *, default: t.Sequence[str] = ()) -> t.Sequence[str]:
"""Split a comma-separated list of values into an ordered set of strings."""
Expand Down Expand Up @@ -76,23 +78,11 @@ def get_studies(
managed: t.Optional[bool] = Query(None, description="Filter studies based on their management status."),
archived: t.Optional[bool] = Query(None, description="Filter studies based on their archive status."),
variant: t.Optional[bool] = Query(None, description="Filter studies based on their variant status."),
versions: str = Query(
"",
description="Comma-separated list of versions for filtering.",
regex=r"^\s*(?:\d+\s*(?:,\s*\d+\s*)*)?$",
),
users: str = Query(
"",
description="Comma-separated list of user IDs for filtering.",
regex=r"^\s*(?:\d+\s*(?:,\s*\d+\s*)*)?$",
),
versions: str = Query("", description="Comma-separated list of versions for filtering.", regex=QUERY_REGEX),
users: str = Query("", description="Comma-separated list of user IDs for filtering.", regex=QUERY_REGEX),
groups: str = Query("", description="Comma-separated list of group IDs for filtering."),
tags: str = Query("", description="Comma-separated list of tags for filtering."),
study_ids: str = Query(
"",
description="Comma-separated list of study IDs for filtering.",
alias="studyIds",
),
study_ids: str = Query("", description="Comma-separated list of study IDs for filtering.", alias="studyIds"),
exists: t.Optional[bool] = Query(None, description="Filter studies based on their existence on disk."),
workspace: str = Query("", description="Filter studies based on their workspace."),
folder: str = Query("", description="Filter studies based on their folder."),
Expand All @@ -102,23 +92,17 @@ def get_studies(
description="Sort studies based on their name (case-insensitive) or creation date.",
alias="sortBy",
),
page_nb: NonNegativeInt = Query(
0,
description="Page number (starting from 0).",
alias="pageNb",
),
page_nb: NonNegativeInt = Query(0, description="Page number (starting from 0).", alias="pageNb"),
page_size: NonNegativeInt = Query(
0,
description="Number of studies per page (0 = no limit).",
alias="pageSize",
0, description="Number of studies per page (0 = no limit).", alias="pageSize"
),
) -> t.Dict[str, StudyMetadataDTO]:
"""
Get the list of studies matching the specified criteria.

Args:

- `name`: Filter studies based on their name. Case-insensitive search for studies
whose name contains the specified value.
mabw-rte marked this conversation as resolved.
Show resolved Hide resolved
- `managed`: Filter studies based on their management status.
- `archived`: Filter studies based on their archive status.
- `variant`: Filter studies based on their variant status.
Expand Down Expand Up @@ -171,6 +155,76 @@ def get_studies(

return matching_studies

@bp.get(
"/studies/count",
tags=[APITag.study_management],
summary="Count Studies",
)
def count_studies(
current_user: JWTUser = Depends(auth.get_current_user),
name: str = Query("", description="Case-insensitive: filter studies based on their name.", alias="name"),
managed: t.Optional[bool] = Query(None, description="Management status filter."),
archived: t.Optional[bool] = Query(None, description="Archive status filter."),
variant: t.Optional[bool] = Query(None, description="Variant status filter."),
versions: str = Query("", description="Comma-separated versions filter.", regex=QUERY_REGEX),
mabw-rte marked this conversation as resolved.
Show resolved Hide resolved
users: str = Query("", description="Comma-separated user IDs filter.", regex=QUERY_REGEX),
groups: str = Query("", description="Comma-separated group IDs filter."),
tags: str = Query("", description="Comma-separated tags filter."),
study_ids: str = Query("", description="Comma-separated study IDs filter.", alias="studyIds"),
exists: t.Optional[bool] = Query(None, description="Existence on disk filter."),
workspace: str = Query("", description="Workspace filter."),
folder: str = Query("", description="Study folder filter."),
) -> int:
"""
Get the number of studies matching the specified criteria.

Args:

- `name`: Regexp to filter through studies based on their names
- `managed`: Whether to limit the selection based on management status.
- `archived`: Whether to limit the selection based on archive status.
- `variant`: Whether to limit the selection either raw or variant studies.
- `versions`: Comma-separated versions for studies to be selected.
- `users`: Comma-separated user IDs for studies to be selected.
- `groups`: Comma-separated group IDs for studies to be selected.
- `tags`: Comma-separated tags for studies to be selected.
- `studyIds`: Comma-separated IDs of studies to be selected.
- `exists`: Whether to limit the selection based on studies' existence on disk.
- `workspace`: to limit studies selection based on their workspace.
- `folder`: to limit studies selection based on their folder.

Returns:
- An integer representing the total number of studies matching the filters above and the user permissions.
"""
mabw-rte marked this conversation as resolved.
Show resolved Hide resolved

logger.info("Counting matching studies", extra={"user": current_user.id})
params = RequestParameters(user=current_user)

user_list = [int(v) for v in _split_comma_separated_values(users)]

if not params.user:
raise UserHasNotPermissionError("FAIL permission: user is not logged")

count = study_service.count_studies(
study_filter=StudyFilter(
name=name,
managed=managed,
archived=archived,
variant=variant,
versions=_split_comma_separated_values(versions),
users=user_list,
groups=_split_comma_separated_values(groups),
tags=_split_comma_separated_values(tags),
study_ids=_split_comma_separated_values(study_ids),
exists=exists,
workspace=workspace,
folder=folder,
access_permissions=AccessPermissions.from_params(params),
),
)

return count

@bp.get(
"/studies/{uuid}/comments",
tags=[APITag.study_management],
Expand Down
61 changes: 61 additions & 0 deletions tests/integration/studies_blueprint/test_get_studies.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,16 @@ def test_study_listing(
study_map = res.json()
assert not all_studies.intersection(study_map)
assert all(map(lambda x: pm(x) in [PublicMode.READ, PublicMode.FULL], study_map.values()))
# test pagination
res = client.get(
STUDIES_URL,
headers={"Authorization": f"Bearer {john_doe_access_token}"},
params={"pageNb": 1, "pageSize": 2},
)
if len(study_map) > 2:
mabw-rte marked this conversation as resolved.
Show resolved Hide resolved
assert res.status_code == LIST_STATUS_CODE, res.json()
page_studies = res.json()
assert len(page_studies) == min(2, len(study_map) - 2)

# test 1.b for an admin user
res = client.get(
Expand All @@ -463,6 +473,16 @@ def test_study_listing(
assert res.status_code == LIST_STATUS_CODE, res.json()
study_map = res.json()
assert not all_studies.difference(study_map)
# test pagination
res = client.get(
STUDIES_URL,
headers={"Authorization": f"Bearer {admin_access_token}"},
params={"pageNb": 1, "pageSize": 2},
)
if len(study_map) > 2:
assert res.status_code == LIST_STATUS_CODE, res.json()
page_studies = res.json()
assert len(page_studies) == min(2, len(study_map) - 2)

# test 1.c for a user with access to select studies
res = client.get(
Expand Down Expand Up @@ -620,6 +640,16 @@ def test_study_listing(
study_map = res.json()
assert not all_studies.difference(studies_version_850.union(studies_version_860)).intersection(study_map)
assert not studies_version_850.union(studies_version_860).difference(study_map)
# test pagination
res = client.get(
STUDIES_URL,
headers={"Authorization": f"Bearer {admin_access_token}"},
params={"versions": "850,860", "pageNb": 1, "pageSize": 2},
)
if len(study_map) > 2:
assert res.status_code == LIST_STATUS_CODE, res.json()
page_studies = res.json()
assert len(page_studies) == min(2, len(study_map) - 2)

# tests (7) for users filtering
# test 7.a to get studies for one user: James Bond
Expand Down Expand Up @@ -1318,6 +1348,7 @@ def test_get_studies__access_permissions(self, client: TestClient, admin_access_
# fmt: off
([], {"1", "2", "5", "6", "7", "8", "9", "10", "13", "14", "15", "16", "17",
"18", "21", "22", "23", "24", "25", "26", "29", "30", "31", "32", "34"}),
# fmt: on
(["1"], {"1", "7", "8", "9", "17", "23", "24", "25"}),
(["2"], {"2", "5", "6", "7", "8", "9", "18", "21", "22", "23", "24", "25", "34"}),
(["3"], set()),
Expand All @@ -1343,12 +1374,24 @@ def test_get_studies__access_permissions(self, client: TestClient, admin_access_
study_map = res.json()
assert not expected_studies.difference(set(study_map))
assert not all_studies.difference(expected_studies).intersection(set(study_map))
# test pagination
if len(expected_studies) > 2:
mabw-rte marked this conversation as resolved.
Show resolved Hide resolved
res = client.get(
STUDIES_URL,
headers={"Authorization": f"Bearer {users_tokens['user_1']}"},
params={"groups": ",".join(request_groups_ids), "pageNb": 1, "pageSize": 2}
if request_groups_ids
else {"pageNb": 1, "pageSize": 2},
)
assert res.status_code == LIST_STATUS_CODE, res.json()
assert len(res.json()) == min(2, len(expected_studies) - 2)

# user_2 access
requests_params_expected_studies = [
# fmt: off
([], {"1", "3", "4", "5", "7", "8", "9", "11", "13", "14", "15", "16", "17",
"19", "20", "21", "23", "24", "25", "27", "29", "30", "31", "32", "33"}),
# fmt: on
(["1"], {"1", "3", "4", "7", "8", "9", "17", "19", "20", "23", "24", "25", "33"}),
(["2"], {"5", "7", "8", "9", "21", "23", "24", "25"}),
(["3"], set()),
Expand Down Expand Up @@ -1473,3 +1516,21 @@ def test_get_studies__invalid_parameters(
assert res.status_code == INVALID_PARAMS_STATUS_CODE, res.json()
description = res.json()["description"]
assert re.search(r"could not be parsed to a boolean", description), f"{description=}"


def test_studies_counting(client: TestClient, admin_access_token: str, user_access_token: str) -> None:
mabw-rte marked this conversation as resolved.
Show resolved Hide resolved
# test admin studies count
res = client.get(STUDIES_URL, headers={"Authorization": f"Bearer {admin_access_token}"})
assert res.status_code == 200, res.json()
expected_studies_count = len(res.json())
res = client.get(STUDIES_URL + "/count", headers={"Authorization": f"Bearer {admin_access_token}"})
assert res.status_code == 200, res.json()
assert res.json() == expected_studies_count

# test user studies count
mabw-rte marked this conversation as resolved.
Show resolved Hide resolved
res = client.get(STUDIES_URL, headers={"Authorization": f"Bearer {user_access_token}"})
assert res.status_code == 200, res.json()
expected_studies_count = len(res.json())
res = client.get(STUDIES_URL + "/count", headers={"Authorization": f"Bearer {user_access_token}"})
assert res.status_code == 200, res.json()
assert res.json() == expected_studies_count
Loading
Loading