Skip to content

Commit

Permalink
test(study-search): generalize pagination testing
Browse files Browse the repository at this point in the history
  • Loading branch information
mabw-rte committed Feb 27, 2024
1 parent a3571d5 commit ca8f55d
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 158 deletions.
40 changes: 18 additions & 22 deletions tests/integration/studies_blueprint/test_get_studies.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,10 +460,9 @@ def test_study_listing(
headers={"Authorization": f"Bearer {john_doe_access_token}"},
params={"pageNb": 1, "pageSize": 2},
)
if len(study_map) > 2:
assert res.status_code == LIST_STATUS_CODE, res.json()
page_studies = res.json()
assert len(page_studies) == min(2, len(study_map) - 2)
assert res.status_code == LIST_STATUS_CODE, res.json()
page_studies = res.json()
assert len(page_studies) == max(0, min(2, len(study_map) - 2))

# test 1.b for an admin user
res = client.get(
Expand All @@ -479,10 +478,9 @@ def test_study_listing(
headers={"Authorization": f"Bearer {admin_access_token}"},
params={"pageNb": 1, "pageSize": 2},
)
if len(study_map) > 2:
assert res.status_code == LIST_STATUS_CODE, res.json()
page_studies = res.json()
assert len(page_studies) == min(2, len(study_map) - 2)
assert res.status_code == LIST_STATUS_CODE, res.json()
page_studies = res.json()
assert len(page_studies) == max(0, min(len(study_map) - 2, 2))

# test 1.c for a user with access to select studies
res = client.get(
Expand Down Expand Up @@ -646,10 +644,9 @@ def test_study_listing(
headers={"Authorization": f"Bearer {admin_access_token}"},
params={"versions": "850,860", "pageNb": 1, "pageSize": 2},
)
if len(study_map) > 2:
assert res.status_code == LIST_STATUS_CODE, res.json()
page_studies = res.json()
assert len(page_studies) == min(2, len(study_map) - 2)
assert res.status_code == LIST_STATUS_CODE, res.json()
page_studies = res.json()
assert len(page_studies) == max(0, min(len(study_map) - 2, 2))

# tests (7) for users filtering
# test 7.a to get studies for one user: James Bond
Expand Down Expand Up @@ -1375,16 +1372,15 @@ def test_get_studies__access_permissions(self, client: TestClient, admin_access_
assert not expected_studies.difference(set(study_map))
assert not all_studies.difference(expected_studies).intersection(set(study_map))
# test pagination
if len(expected_studies) > 2:
res = client.get(
STUDIES_URL,
headers={"Authorization": f"Bearer {users_tokens['user_1']}"},
params={"groups": ",".join(request_groups_ids), "pageNb": 1, "pageSize": 2}
if request_groups_ids
else {"pageNb": 1, "pageSize": 2},
)
assert res.status_code == LIST_STATUS_CODE, res.json()
assert len(res.json()) == min(2, len(expected_studies) - 2)
res = client.get(
STUDIES_URL,
headers={"Authorization": f"Bearer {users_tokens['user_1']}"},
params={"groups": ",".join(request_groups_ids), "pageNb": 1, "pageSize": 2}
if request_groups_ids
else {"pageNb": 1, "pageSize": 2},
)
assert res.status_code == LIST_STATUS_CODE, res.json()
assert len(res.json()) == max(0, min(2, len(expected_studies) - 2))

# user_2 access
requests_params_expected_studies = [
Expand Down
223 changes: 87 additions & 136 deletions tests/study/test_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,13 @@ def test_get_all__general_case(
assert {s.id for s in all_studies} == expected_ids

# test pagination
if len(expected_ids) > 2:
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(
study_filter=study_filter,
pagination=StudyPagination(page_nb=1, page_size=2),
)
assert len(all_studies) == min(len(expected_ids) - 2, 2)
assert len(db_recorder.sql_statements) == 1, str(db_recorder)
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(
study_filter=study_filter,
pagination=StudyPagination(page_nb=1, page_size=2),
)
assert len(all_studies) == max(0, min(len(expected_ids) - 2, 2))
assert len(db_recorder.sql_statements) == 1, str(db_recorder)


def test_get_all__incompatible_case(
Expand Down Expand Up @@ -205,7 +204,7 @@ def test_get_all__study_name_filter(
study_filter=StudyFilter(name=name, access_permissions=AccessPermissions(is_admin=True)),
pagination=StudyPagination(page_nb=1, page_size=2),
)
assert len(all_studies) == min(len(expected_ids) - 2, 2)
assert len(all_studies) == max(0, min(len(expected_ids) - 2, 2))
assert len(db_recorder.sql_statements) == 1, str(db_recorder)


Expand Down Expand Up @@ -262,7 +261,7 @@ def test_get_all__managed_study_filter(
study_filter=StudyFilter(managed=managed, access_permissions=AccessPermissions(is_admin=True)),
pagination=StudyPagination(page_nb=1, page_size=2),
)
assert len(all_studies) == min(len(expected_ids) - 2, 2)
assert len(all_studies) == max(0, min(len(expected_ids) - 2, 2))
assert len(db_recorder.sql_statements) == 1, str(db_recorder)


Expand Down Expand Up @@ -307,22 +306,13 @@ def test_get_all__archived_study_filter(
assert {s.id for s in all_studies} == expected_ids

# test pagination
if len(expected_ids) > 2:
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(
study_filter=study_filter,
pagination=StudyPagination(page_nb=1, page_size=2),
)
assert len(all_studies) == min(len(expected_ids) - 2, 2)
assert len(db_recorder.sql_statements) == 1, str(db_recorder)
elif len(expected_ids) == 2:
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(
study_filter=study_filter,
pagination=StudyPagination(page_nb=1, page_size=1),
)
assert len(all_studies) == 1
assert len(db_recorder.sql_statements) == 1, str(db_recorder)
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(
study_filter=study_filter,
pagination=StudyPagination(page_nb=1, page_size=1),
)
assert len(all_studies) == max(0, min(len(expected_ids) - 1, 1))
assert len(db_recorder.sql_statements) == 1, str(db_recorder)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -366,22 +356,13 @@ def test_get_all__variant_study_filter(
assert {s.id for s in all_studies} == expected_ids

# test pagination
if len(expected_ids) > 2:
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(
study_filter=study_filter,
pagination=StudyPagination(page_nb=1, page_size=2),
)
assert len(all_studies) == min(len(expected_ids) - 2, 2)
assert len(db_recorder.sql_statements) == 1, str(db_recorder)
elif len(expected_ids) == 2:
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(
study_filter=study_filter,
pagination=StudyPagination(page_nb=1, page_size=1),
)
assert len(all_studies) == 1
assert len(db_recorder.sql_statements) == 1, str(db_recorder)
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(
study_filter=study_filter,
pagination=StudyPagination(page_nb=1, page_size=1),
)
assert len(all_studies) == max(0, min(len(expected_ids) - 1, 1))
assert len(db_recorder.sql_statements) == 1, str(db_recorder)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -427,22 +408,13 @@ def test_get_all__study_version_filter(
assert {s.id for s in all_studies} == expected_ids

# test pagination
if len(expected_ids) > 2:
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(
study_filter=study_filter,
pagination=StudyPagination(page_nb=1, page_size=2),
)
assert len(all_studies) == min(len(expected_ids) - 2, 2)
assert len(db_recorder.sql_statements) == 1, str(db_recorder)
elif len(expected_ids) == 2:
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(
study_filter=study_filter,
pagination=StudyPagination(page_nb=1, page_size=1),
)
assert len(all_studies) == 1
assert len(db_recorder.sql_statements) == 1, str(db_recorder)
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(
study_filter=study_filter,
pagination=StudyPagination(page_nb=1, page_size=1),
)
assert len(all_studies) == max(0, min(len(expected_ids) - 1, 1))
assert len(db_recorder.sql_statements) == 1, str(db_recorder)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -495,14 +467,13 @@ def test_get_all__study_users_filter(
assert {s.id for s in all_studies} == expected_ids

# test pagination
if len(expected_ids) > 2:
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(
study_filter=StudyFilter(users=users, access_permissions=AccessPermissions(is_admin=True)),
pagination=StudyPagination(page_nb=1, page_size=2),
)
assert len(all_studies) == min(len(expected_ids) - 2, 2)
assert len(db_recorder.sql_statements) == 1, str(db_recorder)
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(
study_filter=StudyFilter(users=users, access_permissions=AccessPermissions(is_admin=True)),
pagination=StudyPagination(page_nb=1, page_size=2),
)
assert len(all_studies) == max(0, min(len(expected_ids) - 2, 2))
assert len(db_recorder.sql_statements) == 1, str(db_recorder)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -555,14 +526,13 @@ def test_get_all__study_groups_filter(
assert {s.id for s in all_studies} == expected_ids

# test pagination
if len(expected_ids) > 2:
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(
study_filter=StudyFilter(groups=groups, access_permissions=AccessPermissions(is_admin=True)),
pagination=StudyPagination(page_nb=1, page_size=2),
)
assert len(all_studies) == min(len(expected_ids) - 2, 2)
assert len(db_recorder.sql_statements) == 1, str(db_recorder)
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(
study_filter=StudyFilter(groups=groups, access_permissions=AccessPermissions(is_admin=True)),
pagination=StudyPagination(page_nb=1, page_size=2),
)
assert len(all_studies) == max(0, min(len(expected_ids) - 2, 2))
assert len(db_recorder.sql_statements) == 1, str(db_recorder)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -610,14 +580,13 @@ def test_get_all__study_ids_filter(
assert {s.id for s in all_studies} == expected_ids

# test pagination
if len(expected_ids) > 2:
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(
study_filter=StudyFilter(study_ids=study_ids, access_permissions=AccessPermissions(is_admin=True)),
pagination=StudyPagination(page_nb=1, page_size=2),
)
assert len(all_studies) == min(len(expected_ids) - 2, 2)
assert len(db_recorder.sql_statements) == 1, str(db_recorder)
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(
study_filter=StudyFilter(study_ids=study_ids, access_permissions=AccessPermissions(is_admin=True)),
pagination=StudyPagination(page_nb=1, page_size=2),
)
assert len(all_studies) == max(0, min(len(expected_ids) - 2, 2))
assert len(db_recorder.sql_statements) == 1, str(db_recorder)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -662,14 +631,13 @@ def test_get_all__study_existence_filter(
assert {s.id for s in all_studies} == expected_ids

# test pagination
if len(expected_ids) > 2:
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(
study_filter=StudyFilter(exists=exists, access_permissions=AccessPermissions(is_admin=True)),
pagination=StudyPagination(page_nb=1, page_size=2),
)
assert len(all_studies) == min(len(expected_ids) - 2, 2)
assert len(db_recorder.sql_statements) == 1, str(db_recorder)
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(
study_filter=StudyFilter(exists=exists, access_permissions=AccessPermissions(is_admin=True)),
pagination=StudyPagination(page_nb=1, page_size=2),
)
assert len(all_studies) == max(0, min(len(expected_ids) - 2, 2))
assert len(db_recorder.sql_statements) == 1, str(db_recorder)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -715,14 +683,13 @@ def test_get_all__study_workspace_filter(
assert {s.id for s in all_studies} == expected_ids

# test pagination
if len(expected_ids) > 2:
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(
study_filter=StudyFilter(workspace=workspace, access_permissions=AccessPermissions(is_admin=True)),
pagination=StudyPagination(page_nb=1, page_size=2),
)
assert len(all_studies) == min(len(expected_ids) - 2, 2)
assert len(db_recorder.sql_statements) == 1, str(db_recorder)
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(
study_filter=StudyFilter(workspace=workspace, access_permissions=AccessPermissions(is_admin=True)),
pagination=StudyPagination(page_nb=1, page_size=2),
)
assert len(all_studies) == max(0, min(len(expected_ids) - 2, 2))
assert len(db_recorder.sql_statements) == 1, str(db_recorder)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -770,22 +737,13 @@ def test_get_all__study_folder_filter(
assert {s.id for s in all_studies} == expected_ids

# test pagination
if len(expected_ids) > 2:
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(
study_filter=study_filter,
pagination=StudyPagination(page_nb=1, page_size=2),
)
assert len(all_studies) == min(len(expected_ids) - 2, 2)
assert len(db_recorder.sql_statements) == 1, str(db_recorder)
elif len(expected_ids) == 2:
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(
study_filter=study_filter,
pagination=StudyPagination(page_nb=1, page_size=1),
)
assert len(all_studies) == 1
assert len(db_recorder.sql_statements) == 1, str(db_recorder)
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(
study_filter=study_filter,
pagination=StudyPagination(page_nb=1, page_size=1),
)
assert len(all_studies) == max(0, min(len(expected_ids) - 1, 1))
assert len(db_recorder.sql_statements) == 1, str(db_recorder)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -841,14 +799,13 @@ def test_get_all__study_tags_filter(
assert {s.id for s in all_studies} == expected_ids

# test pagination
if len(expected_ids) > 2:
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(
study_filter=StudyFilter(tags=tags, access_permissions=AccessPermissions(is_admin=True)),
pagination=StudyPagination(page_nb=1, page_size=2),
)
assert len(all_studies) == min(len(expected_ids) - 2, 2)
assert len(db_recorder.sql_statements) == 1, str(db_recorder)
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(
study_filter=StudyFilter(tags=tags, access_permissions=AccessPermissions(is_admin=True)),
pagination=StudyPagination(page_nb=1, page_size=2),
)
assert len(all_studies) == max(0, min(len(expected_ids) - 2, 2))
assert len(db_recorder.sql_statements) == 1, str(db_recorder)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -1003,13 +960,10 @@ def test_get_all__non_admin_permissions_filter(
assert {s.id for s in all_studies} == expected_ids

# test pagination
if len(expected_ids) > 2:
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(
study_filter=study_filter, pagination=StudyPagination(page_nb=1, page_size=2)
)
assert len(all_studies) == min(len(expected_ids) - 2, 2)
assert len(db_recorder.sql_statements) == 1, str(db_recorder)
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(study_filter=study_filter, pagination=StudyPagination(page_nb=1, page_size=2))
assert len(all_studies) == max(0, min(len(expected_ids) - 2, 2))
assert len(db_recorder.sql_statements) == 1, str(db_recorder)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -1137,13 +1091,10 @@ def test_get_all__admin_permissions_filter(
assert {s.id for s in all_studies} == expected_ids

# test pagination
if len(expected_ids) > 2:
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(
study_filter=study_filter, pagination=StudyPagination(page_nb=1, page_size=2)
)
assert len(all_studies) == min(len(expected_ids) - 2, 2)
assert len(db_recorder.sql_statements) == 1, str(db_recorder)
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(study_filter=study_filter, pagination=StudyPagination(page_nb=1, page_size=2))
assert len(all_studies) == max(0, min(len(expected_ids) - 2, 2))
assert len(db_recorder.sql_statements) == 1, str(db_recorder)


def test_update_tags(
Expand Down

0 comments on commit ca8f55d

Please sign in to comment.