diff --git a/antarest/study/service.py b/antarest/study/service.py index 81af28a473..dc3288b4e2 100644 --- a/antarest/study/service.py +++ b/antarest/study/service.py @@ -478,6 +478,22 @@ def get_studies_information( studies[study_metadata.id] = study_metadata return studies + def count_studies( + self, + study_filter: StudyFilter, + ) -> int: + """ + Get number of matching studies. + Args: + study_filter: filtering parameters + + Returns: total number of studies matching the filtering criteria + """ + total: int = self.repository.count_studies( + study_filter=study_filter, + ) + return total + def _try_get_studies_information(self, study: Study) -> t.Optional[StudyMetadataDTO]: try: return self.storage_service.get_storage(study).get_study_information(study) diff --git a/antarest/study/web/studies_blueprint.py b/antarest/study/web/studies_blueprint.py index 9007ed2894..6565538281 100644 --- a/antarest/study/web/studies_blueprint.py +++ b/antarest/study/web/studies_blueprint.py @@ -34,6 +34,8 @@ logger = logging.getLogger(__name__) +QUERY_REGEX = r"^\s*(?:\d+\s*(?:,\s*\d+\s*)*)?$" + def _split_comma_separated_values(value: str, *, default: t.Sequence[str] = ()) -> t.Sequence[str]: """Split a comma-separated list of values into an ordered set of strings.""" @@ -76,23 +78,11 @@ def get_studies( managed: t.Optional[bool] = Query(None, description="Filter studies based on their management status."), archived: t.Optional[bool] = Query(None, description="Filter studies based on their archive status."), variant: t.Optional[bool] = Query(None, description="Filter studies based on their variant status."), - versions: str = Query( - "", - description="Comma-separated list of versions for filtering.", - regex=r"^\s*(?:\d+\s*(?:,\s*\d+\s*)*)?$", - ), - users: str = Query( - "", - description="Comma-separated list of user IDs for filtering.", - regex=r"^\s*(?:\d+\s*(?:,\s*\d+\s*)*)?$", - ), + versions: str = Query("", description="Comma-separated list of versions for filtering.", regex=QUERY_REGEX), + users: str = Query("", description="Comma-separated list of user IDs for filtering.", regex=QUERY_REGEX), groups: str = Query("", description="Comma-separated list of group IDs for filtering."), tags: str = Query("", description="Comma-separated list of tags for filtering."), - study_ids: str = Query( - "", - description="Comma-separated list of study IDs for filtering.", - alias="studyIds", - ), + study_ids: str = Query("", description="Comma-separated list of study IDs for filtering.", alias="studyIds"), exists: t.Optional[bool] = Query(None, description="Filter studies based on their existence on disk."), workspace: str = Query("", description="Filter studies based on their workspace."), folder: str = Query("", description="Filter studies based on their folder."), @@ -102,23 +92,17 @@ def get_studies( description="Sort studies based on their name (case-insensitive) or creation date.", alias="sortBy", ), - page_nb: NonNegativeInt = Query( - 0, - description="Page number (starting from 0).", - alias="pageNb", - ), + page_nb: NonNegativeInt = Query(0, description="Page number (starting from 0).", alias="pageNb"), page_size: NonNegativeInt = Query( - 0, - description="Number of studies per page (0 = no limit).", - alias="pageSize", + 0, description="Number of studies per page (0 = no limit).", alias="pageSize" ), ) -> t.Dict[str, StudyMetadataDTO]: """ Get the list of studies matching the specified criteria. Args: + - `name`: Filter studies based on their name. Case-insensitive search for studies - whose name contains the specified value. - `managed`: Filter studies based on their management status. - `archived`: Filter studies based on their archive status. - `variant`: Filter studies based on their variant status. @@ -171,6 +155,76 @@ def get_studies( return matching_studies + @bp.get( + "/studies/count", + tags=[APITag.study_management], + summary="Count Studies", + ) + def count_studies( + current_user: JWTUser = Depends(auth.get_current_user), + name: str = Query("", description="Case-insensitive: filter studies based on their name.", alias="name"), + managed: t.Optional[bool] = Query(None, description="Management status filter."), + archived: t.Optional[bool] = Query(None, description="Archive status filter."), + variant: t.Optional[bool] = Query(None, description="Variant status filter."), + versions: str = Query("", description="Comma-separated versions filter.", regex=QUERY_REGEX), + users: str = Query("", description="Comma-separated user IDs filter.", regex=QUERY_REGEX), + groups: str = Query("", description="Comma-separated group IDs filter."), + tags: str = Query("", description="Comma-separated tags filter."), + study_ids: str = Query("", description="Comma-separated study IDs filter.", alias="studyIds"), + exists: t.Optional[bool] = Query(None, description="Existence on disk filter."), + workspace: str = Query("", description="Workspace filter."), + folder: str = Query("", description="Study folder filter."), + ) -> int: + """ + Get the number of studies matching the specified criteria. + + Args: + + - `name`: Regexp to filter through studies based on their names + - `managed`: Whether to limit the selection based on management status. + - `archived`: Whether to limit the selection based on archive status. + - `variant`: Whether to limit the selection either raw or variant studies. + - `versions`: Comma-separated versions for studies to be selected. + - `users`: Comma-separated user IDs for studies to be selected. + - `groups`: Comma-separated group IDs for studies to be selected. + - `tags`: Comma-separated tags for studies to be selected. + - `studyIds`: Comma-separated IDs of studies to be selected. + - `exists`: Whether to limit the selection based on studies' existence on disk. + - `workspace`: to limit studies selection based on their workspace. + - `folder`: to limit studies selection based on their folder. + + Returns: + - An integer representing the total number of studies matching the filters above and the user permissions. + """ + + logger.info("Counting matching studies", extra={"user": current_user.id}) + params = RequestParameters(user=current_user) + + user_list = [int(v) for v in _split_comma_separated_values(users)] + + if not params.user: + raise UserHasNotPermissionError("FAIL permission: user is not logged") + + count = study_service.count_studies( + study_filter=StudyFilter( + name=name, + managed=managed, + archived=archived, + variant=variant, + versions=_split_comma_separated_values(versions), + users=user_list, + groups=_split_comma_separated_values(groups), + tags=_split_comma_separated_values(tags), + study_ids=_split_comma_separated_values(study_ids), + exists=exists, + workspace=workspace, + folder=folder, + access_permissions=AccessPermissions.from_params(params), + ), + ) + + return count + @bp.get( "/studies/{uuid}/comments", tags=[APITag.study_management], diff --git a/tests/integration/studies_blueprint/test_get_studies.py b/tests/integration/studies_blueprint/test_get_studies.py index 48dadaf829..2cff53f047 100644 --- a/tests/integration/studies_blueprint/test_get_studies.py +++ b/tests/integration/studies_blueprint/test_get_studies.py @@ -454,6 +454,15 @@ def test_study_listing( study_map = res.json() assert not all_studies.intersection(study_map) assert all(map(lambda x: pm(x) in [PublicMode.READ, PublicMode.FULL], study_map.values())) + # test pagination + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {john_doe_access_token}"}, + params={"pageNb": 1, "pageSize": 2}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + page_studies = res.json() + assert len(page_studies) == max(0, min(2, len(study_map) - 2)) # test 1.b for an admin user res = client.get( @@ -463,6 +472,31 @@ def test_study_listing( assert res.status_code == LIST_STATUS_CODE, res.json() study_map = res.json() assert not all_studies.difference(study_map) + # test pagination + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"pageNb": 1, "pageSize": 2}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + page_studies = res.json() + assert len(page_studies) == max(0, min(len(study_map) - 2, 2)) + # test pagination concatenation + paginated_studies = {} + page_number = 0 + number_of_pages = 0 + while len(paginated_studies) < len(study_map): + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"pageNb": page_number, "pageSize": 2}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + paginated_studies.update(res.json()) + page_number += 1 + number_of_pages += 1 + assert paginated_studies == study_map + assert number_of_pages == len(study_map) // 2 + len(study_map) % 2 # test 1.c for a user with access to select studies res = client.get( @@ -620,6 +654,15 @@ def test_study_listing( study_map = res.json() assert not all_studies.difference(studies_version_850.union(studies_version_860)).intersection(study_map) assert not studies_version_850.union(studies_version_860).difference(study_map) + # test pagination + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"versions": "850,860", "pageNb": 1, "pageSize": 2}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + page_studies = res.json() + assert len(page_studies) == max(0, min(len(study_map) - 2, 2)) # tests (7) for users filtering # test 7.a to get studies for one user: James Bond @@ -1318,6 +1361,7 @@ def test_get_studies__access_permissions(self, client: TestClient, admin_access_ # fmt: off ([], {"1", "2", "5", "6", "7", "8", "9", "10", "13", "14", "15", "16", "17", "18", "21", "22", "23", "24", "25", "26", "29", "30", "31", "32", "34"}), + # fmt: on (["1"], {"1", "7", "8", "9", "17", "23", "24", "25"}), (["2"], {"2", "5", "6", "7", "8", "9", "18", "21", "22", "23", "24", "25", "34"}), (["3"], set()), @@ -1343,12 +1387,23 @@ def test_get_studies__access_permissions(self, client: TestClient, admin_access_ study_map = res.json() assert not expected_studies.difference(set(study_map)) assert not all_studies.difference(expected_studies).intersection(set(study_map)) + # test pagination + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {users_tokens['user_1']}"}, + params={"groups": ",".join(request_groups_ids), "pageNb": 1, "pageSize": 2} + if request_groups_ids + else {"pageNb": 1, "pageSize": 2}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + assert len(res.json()) == max(0, min(2, len(expected_studies) - 2)) # user_2 access requests_params_expected_studies = [ # fmt: off ([], {"1", "3", "4", "5", "7", "8", "9", "11", "13", "14", "15", "16", "17", "19", "20", "21", "23", "24", "25", "27", "29", "30", "31", "32", "33"}), + # fmt: on (["1"], {"1", "3", "4", "7", "8", "9", "17", "19", "20", "23", "24", "25", "33"}), (["2"], {"5", "7", "8", "9", "21", "23", "24", "25"}), (["3"], set()), @@ -1473,3 +1528,14 @@ def test_get_studies__invalid_parameters( assert res.status_code == INVALID_PARAMS_STATUS_CODE, res.json() description = res.json()["description"] assert re.search(r"could not be parsed to a boolean", description), f"{description=}" + + +def test_studies_counting(client: TestClient, admin_access_token: str, user_access_token: str) -> None: + # test admin and non admin user studies count requests + for access_token in [admin_access_token, user_access_token]: + res = client.get(STUDIES_URL, headers={"Authorization": f"Bearer {access_token}"}) + assert res.status_code == 200, res.json() + expected_studies_count = len(res.json()) + res = client.get(STUDIES_URL + "/count", headers={"Authorization": f"Bearer {access_token}"}) + assert res.status_code == 200, res.json() + assert res.json() == expected_studies_count diff --git a/tests/study/test_repository.py b/tests/study/test_repository.py index 4762cc7fed..b698497b9c 100644 --- a/tests/study/test_repository.py +++ b/tests/study/test_repository.py @@ -9,7 +9,7 @@ from antarest.core.model import PublicMode from antarest.login.model import Group, User from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, Tag -from antarest.study.repository import AccessPermissions, StudyFilter, StudyMetadataRepository +from antarest.study.repository import AccessPermissions, StudyFilter, StudyMetadataRepository, StudyPagination from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy from tests.db_statement_recorder import DBStatementRecorder @@ -66,24 +66,30 @@ def test_get_all__general_case( # 1- retrieving all studies requires only 1 query # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 + study_filter = StudyFilter( + managed=managed, study_ids=study_ids, exists=exists, access_permissions=AccessPermissions(is_admin=True) + ) with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all( - study_filter=StudyFilter( - managed=managed, - study_ids=study_ids, - exists=exists, - access_permissions=AccessPermissions(is_admin=True), - ) - ) + all_studies = repository.get_all(study_filter=study_filter) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] _ = [s.tags for s in all_studies] assert len(db_recorder.sql_statements) == 1, str(db_recorder) + # test that the expected studies are returned if expected_ids is not None: assert {s.id for s in all_studies} == expected_ids + # test pagination + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all( + study_filter=study_filter, + pagination=StudyPagination(page_nb=1, page_size=2), + ) + assert len(all_studies) == max(0, min(len(expected_ids) - 2, 2)) + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + def test_get_all__incompatible_case( db_session: Session, @@ -191,6 +197,15 @@ def test_get_all__study_name_filter( if expected_ids is not None: assert {s.id for s in all_studies} == expected_ids + # test pagination + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all( + study_filter=StudyFilter(name=name, access_permissions=AccessPermissions(is_admin=True)), + pagination=StudyPagination(page_nb=1, page_size=2), + ) + assert len(all_studies) == max(0, min(len(expected_ids) - 2, 2)) + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + @pytest.mark.parametrize( "managed, expected_ids", @@ -238,6 +253,15 @@ def test_get_all__managed_study_filter( if expected_ids is not None: assert {s.id for s in all_studies} == expected_ids + # test pagination + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all( + study_filter=StudyFilter(managed=managed, access_permissions=AccessPermissions(is_admin=True)), + pagination=StudyPagination(page_nb=1, page_size=2), + ) + assert len(all_studies) == max(0, min(len(expected_ids) - 2, 2)) + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + @pytest.mark.parametrize( "archived, expected_ids", @@ -267,10 +291,9 @@ def test_get_all__archived_study_filter( # 1- retrieving all studies requires only 1 query # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 + study_filter = StudyFilter(archived=archived, access_permissions=AccessPermissions(is_admin=True)) with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all( - study_filter=StudyFilter(archived=archived, access_permissions=AccessPermissions(is_admin=True)) - ) + all_studies = repository.get_all(study_filter=study_filter) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -280,6 +303,15 @@ def test_get_all__archived_study_filter( if expected_ids is not None: assert {s.id for s in all_studies} == expected_ids + # test pagination + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all( + study_filter=study_filter, + pagination=StudyPagination(page_nb=1, page_size=1), + ) + assert len(all_studies) == max(0, min(len(expected_ids) - 1, 1)) + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + @pytest.mark.parametrize( "variant, expected_ids", @@ -309,10 +341,9 @@ def test_get_all__variant_study_filter( # 1- retrieving all studies requires only 1 query # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 + study_filter = StudyFilter(variant=variant, access_permissions=AccessPermissions(is_admin=True)) with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all( - study_filter=StudyFilter(variant=variant, access_permissions=AccessPermissions(is_admin=True)) - ) + all_studies = repository.get_all(study_filter=study_filter) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -322,6 +353,15 @@ def test_get_all__variant_study_filter( if expected_ids is not None: assert {s.id for s in all_studies} == expected_ids + # test pagination + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all( + study_filter=study_filter, + pagination=StudyPagination(page_nb=1, page_size=1), + ) + assert len(all_studies) == max(0, min(len(expected_ids) - 1, 1)) + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + @pytest.mark.parametrize( "versions, expected_ids", @@ -353,10 +393,9 @@ def test_get_all__study_version_filter( # 1- retrieving all studies requires only 1 query # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 + study_filter = StudyFilter(versions=versions, access_permissions=AccessPermissions(is_admin=True)) with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all( - study_filter=StudyFilter(versions=versions, access_permissions=AccessPermissions(is_admin=True)) - ) + all_studies = repository.get_all(study_filter=study_filter) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -366,6 +405,15 @@ def test_get_all__study_version_filter( if expected_ids is not None: assert {s.id for s in all_studies} == expected_ids + # test pagination + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all( + study_filter=study_filter, + pagination=StudyPagination(page_nb=1, page_size=1), + ) + assert len(all_studies) == max(0, min(len(expected_ids) - 1, 1)) + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + @pytest.mark.parametrize( "users, expected_ids", @@ -416,6 +464,15 @@ def test_get_all__study_users_filter( if expected_ids is not None: assert {s.id for s in all_studies} == expected_ids + # test pagination + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all( + study_filter=StudyFilter(users=users, access_permissions=AccessPermissions(is_admin=True)), + pagination=StudyPagination(page_nb=1, page_size=2), + ) + assert len(all_studies) == max(0, min(len(expected_ids) - 2, 2)) + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + @pytest.mark.parametrize( "groups, expected_ids", @@ -466,6 +523,15 @@ def test_get_all__study_groups_filter( if expected_ids is not None: assert {s.id for s in all_studies} == expected_ids + # test pagination + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all( + study_filter=StudyFilter(groups=groups, access_permissions=AccessPermissions(is_admin=True)), + pagination=StudyPagination(page_nb=1, page_size=2), + ) + assert len(all_studies) == max(0, min(len(expected_ids) - 2, 2)) + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + @pytest.mark.parametrize( "study_ids, expected_ids", @@ -511,6 +577,15 @@ def test_get_all__study_ids_filter( if expected_ids is not None: assert {s.id for s in all_studies} == expected_ids + # test pagination + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all( + study_filter=StudyFilter(study_ids=study_ids, access_permissions=AccessPermissions(is_admin=True)), + pagination=StudyPagination(page_nb=1, page_size=2), + ) + assert len(all_studies) == max(0, min(len(expected_ids) - 2, 2)) + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + @pytest.mark.parametrize( "exists, expected_ids", @@ -553,6 +628,15 @@ def test_get_all__study_existence_filter( if expected_ids is not None: assert {s.id for s in all_studies} == expected_ids + # test pagination + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all( + study_filter=StudyFilter(exists=exists, access_permissions=AccessPermissions(is_admin=True)), + pagination=StudyPagination(page_nb=1, page_size=2), + ) + assert len(all_studies) == max(0, min(len(expected_ids) - 2, 2)) + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + @pytest.mark.parametrize( "workspace, expected_ids", @@ -596,6 +680,15 @@ def test_get_all__study_workspace_filter( if expected_ids is not None: assert {s.id for s in all_studies} == expected_ids + # test pagination + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all( + study_filter=StudyFilter(workspace=workspace, access_permissions=AccessPermissions(is_admin=True)), + pagination=StudyPagination(page_nb=1, page_size=2), + ) + assert len(all_studies) == max(0, min(len(expected_ids) - 2, 2)) + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + @pytest.mark.parametrize( "folder, expected_ids", @@ -628,10 +721,9 @@ def test_get_all__study_folder_filter( # 1- retrieving all studies requires only 1 query # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 + study_filter = StudyFilter(folder=folder, access_permissions=AccessPermissions(is_admin=True)) with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all( - study_filter=StudyFilter(folder=folder, access_permissions=AccessPermissions(is_admin=True)) - ) + all_studies = repository.get_all(study_filter=study_filter) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -642,6 +734,15 @@ def test_get_all__study_folder_filter( if expected_ids is not None: assert {s.id for s in all_studies} == expected_ids + # test pagination + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all( + study_filter=study_filter, + pagination=StudyPagination(page_nb=1, page_size=1), + ) + assert len(all_studies) == max(0, min(len(expected_ids) - 1, 1)) + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + @pytest.mark.parametrize( "tags, expected_ids", @@ -695,6 +796,15 @@ def test_get_all__study_tags_filter( if expected_ids is not None: assert {s.id for s in all_studies} == expected_ids + # test pagination + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all( + study_filter=StudyFilter(tags=tags, access_permissions=AccessPermissions(is_admin=True)), + pagination=StudyPagination(page_nb=1, page_size=2), + ) + assert len(all_studies) == max(0, min(len(expected_ids) - 2, 2)) + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + @pytest.mark.parametrize( "user_id, study_groups, expected_ids", @@ -847,6 +957,12 @@ def test_get_all__non_admin_permissions_filter( if expected_ids is not None: assert {s.id for s in all_studies} == expected_ids + # test pagination + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all(study_filter=study_filter, pagination=StudyPagination(page_nb=1, page_size=2)) + assert len(all_studies) == max(0, min(len(expected_ids) - 2, 2)) + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + @pytest.mark.parametrize( "is_admin, study_groups, expected_ids", @@ -972,6 +1088,12 @@ def test_get_all__admin_permissions_filter( if expected_ids is not None: assert {s.id for s in all_studies} == expected_ids + # test pagination + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all(study_filter=study_filter, pagination=StudyPagination(page_nb=1, page_size=2)) + assert len(all_studies) == max(0, min(len(expected_ids) - 2, 2)) + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + def test_update_tags( db_session: Session, @@ -1004,3 +1126,71 @@ def test_update_tags( # Check that only "Tag1" and "Tag3" are present in the database tags = db_session.query(Tag).all() assert {tag.label for tag in tags} == {"Tag1", "Tag3"} + + +@pytest.mark.parametrize( + "managed, study_ids, exists, expected_ids", + [ + (None, [], False, {"5", "6"}), + (None, [], True, {"1", "2", "3", "4", "7", "8"}), + (None, [], None, {"1", "2", "3", "4", "5", "6", "7", "8"}), + (None, [1, 3, 5, 7], False, {"5"}), + (None, [1, 3, 5, 7], True, {"1", "3", "7"}), + (None, [1, 3, 5, 7], None, {"1", "3", "5", "7"}), + (True, [], False, {"5"}), + (True, [], True, {"1", "2", "3", "4", "8"}), + (True, [], None, {"1", "2", "3", "4", "5", "8"}), + (True, [1, 3, 5, 7], False, {"5"}), + (True, [1, 3, 5, 7], True, {"1", "3"}), + (True, [1, 3, 5, 7], None, {"1", "3", "5"}), + (True, [2, 4, 6, 8], True, {"2", "4", "8"}), + (True, [2, 4, 6, 8], None, {"2", "4", "8"}), + (False, [], False, {"6"}), + (False, [], True, {"7"}), + (False, [], None, {"6", "7"}), + (False, [1, 3, 5, 7], False, set()), + (False, [1, 3, 5, 7], True, {"7"}), + (False, [1, 3, 5, 7], None, {"7"}), + ], +) +def test_count_studies__general_case( + db_session: Session, + managed: t.Union[bool, None], + study_ids: t.Sequence[str], + exists: t.Union[bool, None], + expected_ids: t.Set[str], +) -> None: + test_workspace = "test-repository" + icache: Mock = Mock(spec=ICache) + repository = StudyMetadataRepository(cache_service=icache, session=db_session) + + study_1 = VariantStudy(id=1) + study_2 = VariantStudy(id=2) + study_3 = VariantStudy(id=3) + study_4 = VariantStudy(id=4) + study_5 = RawStudy(id=5, missing=datetime.datetime.now(), workspace=DEFAULT_WORKSPACE_NAME) + study_6 = RawStudy(id=6, missing=datetime.datetime.now(), workspace=test_workspace) + study_7 = RawStudy(id=7, missing=None, workspace=test_workspace) + study_8 = RawStudy(id=8, missing=None, workspace=DEFAULT_WORKSPACE_NAME) + + db_session.add_all([study_1, study_2, study_3, study_4, study_5, study_6, study_7, study_8]) + db_session.commit() + + # use the db recorder to check that: + # 1- retrieving all studies requires only 1 query + # 2- accessing studies attributes does not require additional queries to db + # 3- having an exact total of queries equals to 1 + with DBStatementRecorder(db_session.bind) as db_recorder: + count = repository.count_studies( + study_filter=StudyFilter( + managed=managed, + study_ids=study_ids, + exists=exists, + access_permissions=AccessPermissions(is_admin=True), + ) + ) + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + + # test that the expected studies are returned + if expected_ids is not None: + assert count == len(expected_ids)