diff --git a/tests/integration/studies_blueprint/test_get_studies.py b/tests/integration/studies_blueprint/test_get_studies.py index 99812eb0f8..2cff53f047 100644 --- a/tests/integration/studies_blueprint/test_get_studies.py +++ b/tests/integration/studies_blueprint/test_get_studies.py @@ -481,6 +481,22 @@ def test_study_listing( assert res.status_code == LIST_STATUS_CODE, res.json() page_studies = res.json() assert len(page_studies) == max(0, min(len(study_map) - 2, 2)) + # test pagination concatenation + paginated_studies = {} + page_number = 0 + number_of_pages = 0 + while len(paginated_studies) < len(study_map): + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"pageNb": page_number, "pageSize": 2}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + paginated_studies.update(res.json()) + page_number += 1 + number_of_pages += 1 + assert paginated_studies == study_map + assert number_of_pages == len(study_map) // 2 + len(study_map) % 2 # test 1.c for a user with access to select studies res = client.get( @@ -1515,18 +1531,11 @@ def test_get_studies__invalid_parameters( def test_studies_counting(client: TestClient, admin_access_token: str, user_access_token: str) -> None: - # test admin studies count - res = client.get(STUDIES_URL, headers={"Authorization": f"Bearer {admin_access_token}"}) - assert res.status_code == 200, res.json() - expected_studies_count = len(res.json()) - res = client.get(STUDIES_URL + "/count", headers={"Authorization": f"Bearer {admin_access_token}"}) - assert res.status_code == 200, res.json() - assert res.json() == expected_studies_count - - # test user studies count - res = client.get(STUDIES_URL, headers={"Authorization": f"Bearer {user_access_token}"}) - assert res.status_code == 200, res.json() - expected_studies_count = len(res.json()) - res = client.get(STUDIES_URL + "/count", headers={"Authorization": f"Bearer {user_access_token}"}) - assert res.status_code == 200, res.json() - assert res.json() == expected_studies_count + # test admin and non admin user studies count requests + for access_token in [admin_access_token, user_access_token]: + res = client.get(STUDIES_URL, headers={"Authorization": f"Bearer {access_token}"}) + assert res.status_code == 200, res.json() + expected_studies_count = len(res.json()) + res = client.get(STUDIES_URL + "/count", headers={"Authorization": f"Bearer {access_token}"}) + assert res.status_code == 200, res.json() + assert res.json() == expected_studies_count diff --git a/tests/study/test_repository.py b/tests/study/test_repository.py index c8fb90e73a..b698497b9c 100644 --- a/tests/study/test_repository.py +++ b/tests/study/test_repository.py @@ -198,14 +198,13 @@ def test_get_all__study_name_filter( assert {s.id for s in all_studies} == expected_ids # test pagination - if len(expected_ids) >= 2: - with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all( - study_filter=StudyFilter(name=name, access_permissions=AccessPermissions(is_admin=True)), - pagination=StudyPagination(page_nb=1, page_size=2), - ) - assert len(all_studies) == max(0, min(len(expected_ids) - 2, 2)) - assert len(db_recorder.sql_statements) == 1, str(db_recorder) + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all( + study_filter=StudyFilter(name=name, access_permissions=AccessPermissions(is_admin=True)), + pagination=StudyPagination(page_nb=1, page_size=2), + ) + assert len(all_studies) == max(0, min(len(expected_ids) - 2, 2)) + assert len(db_recorder.sql_statements) == 1, str(db_recorder) @pytest.mark.parametrize( @@ -255,14 +254,13 @@ def test_get_all__managed_study_filter( assert {s.id for s in all_studies} == expected_ids # test pagination - if len(expected_ids) >= 2: - with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all( - study_filter=StudyFilter(managed=managed, access_permissions=AccessPermissions(is_admin=True)), - pagination=StudyPagination(page_nb=1, page_size=2), - ) - assert len(all_studies) == max(0, min(len(expected_ids) - 2, 2)) - assert len(db_recorder.sql_statements) == 1, str(db_recorder) + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all( + study_filter=StudyFilter(managed=managed, access_permissions=AccessPermissions(is_admin=True)), + pagination=StudyPagination(page_nb=1, page_size=2), + ) + assert len(all_studies) == max(0, min(len(expected_ids) - 2, 2)) + assert len(db_recorder.sql_statements) == 1, str(db_recorder) @pytest.mark.parametrize(