Skip to content

Commit

Permalink
fix(study-search): correct the SQL query used for pagination
Browse files Browse the repository at this point in the history
  • Loading branch information
laurent-laporte-pro committed Mar 25, 2024
1 parent e757801 commit 62a5908
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 45 deletions.
13 changes: 9 additions & 4 deletions antarest/study/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pydantic import BaseModel, NonNegativeInt
from sqlalchemy import and_, func, not_, or_, sql # type: ignore
from sqlalchemy.orm import Query, Session, joinedload, with_polymorphic # type: ignore
from sqlalchemy.orm import Query, Session, joinedload, subqueryload, with_polymorphic # type: ignore

from antarest.core.interfaces.cache import ICache
from antarest.core.jwt import JWTUser
Expand Down Expand Up @@ -304,10 +304,15 @@ def _search_studies(
q = q.filter(RawStudy.missing.is_(None))
else:
q = q.filter(not_(RawStudy.missing.is_(None)))
q = q.options(joinedload(entity.owner))
q = q.options(joinedload(entity.groups))

if study_filter.users is not None:
q = q.options(joinedload(entity.owner))
if study_filter.groups is not None:
q = q.options(joinedload(entity.groups))
if study_filter.tags is not None:
q = q.options(joinedload(entity.tags))
q = q.options(joinedload(entity.additional_data))
q = q.options(joinedload(entity.tags))

if study_filter.managed is not None:
if study_filter.managed:
q = q.filter(or_(entity.type == "variantstudy", RawStudy.workspace == DEFAULT_WORKSPACE_NAME))
Expand Down
112 changes: 71 additions & 41 deletions tests/study/test_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,86 +9,116 @@
from antarest.core.model import PublicMode
from antarest.login.model import Group, User
from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, Tag
from antarest.study.repository import AccessPermissions, StudyFilter, StudyMetadataRepository, StudyPagination
from antarest.study.repository import (
AccessPermissions,
StudyFilter,
StudyMetadataRepository,
StudyPagination,
StudySortBy,
)
from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy
from tests.db_statement_recorder import DBStatementRecorder


@pytest.mark.parametrize(
"managed, study_ids, exists, expected_ids",
"managed, study_names, exists, expected_names",
[
(None, [], False, {"5", "6"}),
(None, [], True, {"1", "2", "3", "4", "7", "8"}),
(None, [], None, {"1", "2", "3", "4", "5", "6", "7", "8"}),
(None, [1, 3, 5, 7], False, {"5"}),
(None, [1, 3, 5, 7], True, {"1", "3", "7"}),
(None, [1, 3, 5, 7], None, {"1", "3", "5", "7"}),
(True, [], False, {"5"}),
(True, [], True, {"1", "2", "3", "4", "8"}),
(True, [], None, {"1", "2", "3", "4", "5", "8"}),
(True, [1, 3, 5, 7], False, {"5"}),
(True, [1, 3, 5, 7], True, {"1", "3"}),
(True, [1, 3, 5, 7], None, {"1", "3", "5"}),
(True, [2, 4, 6, 8], True, {"2", "4", "8"}),
(True, [2, 4, 6, 8], None, {"2", "4", "8"}),
(False, [], False, {"6"}),
(False, [], True, {"7"}),
(False, [], None, {"6", "7"}),
(False, [1, 3, 5, 7], False, set()),
(False, [1, 3, 5, 7], True, {"7"}),
(False, [1, 3, 5, 7], None, {"7"}),
(None, [], False, ["s5", "s6"]),
(None, [], True, ["s1", "s2", "s3", "s4", "s7", "s8"]),
(None, [], None, ["s1", "s2", "s3", "s4", "s5", "s6", "s7", "s8"]),
(None, ["s1", "s3", "s5", "s7"], False, ["s5"]),
(None, ["s1", "s3", "s5", "s7"], True, ["s1", "s3", "s7"]),
(None, ["s1", "s3", "s5", "s7"], None, ["s1", "s3", "s5", "s7"]),
(True, [], False, ["s5"]),
(True, [], True, ["s1", "s2", "s3", "s4", "s8"]),
(True, [], None, ["s1", "s2", "s3", "s4", "s5", "s8"]),
(True, ["s1", "s3", "s5", "s7"], False, ["s5"]),
(True, ["s1", "s3", "s5", "s7"], True, ["s1", "s3"]),
(True, ["s1", "s3", "s5", "s7"], None, ["s1", "s3", "s5"]),
(True, ["s2", "s4", "s6", "s8"], True, ["s2", "s4", "s8"]),
(True, ["s2", "s4", "s6", "s8"], None, ["s2", "s4", "s8"]),
(False, [], False, ["s6"]),
(False, [], True, ["s7"]),
(False, [], None, ["s6", "s7"]),
(False, ["s1", "s3", "s5", "s7"], False, []),
(False, ["s1", "s3", "s5", "s7"], True, ["s7"]),
(False, ["s1", "s3", "s5", "s7"], None, ["s7"]),
],
)
def test_get_all__general_case(
db_session: Session,
managed: t.Union[bool, None],
study_ids: t.Sequence[str],
study_names: t.Sequence[str],
exists: t.Union[bool, None],
expected_ids: t.Set[str],
expected_names: t.Sequence[str],
) -> None:
test_workspace = "test-repository"
icache: Mock = Mock(spec=ICache)
repository = StudyMetadataRepository(cache_service=icache, session=db_session)

study_1 = VariantStudy(id=1)
study_2 = VariantStudy(id=2)
study_3 = VariantStudy(id=3)
study_4 = VariantStudy(id=4)
study_5 = RawStudy(id=5, missing=datetime.datetime.now(), workspace=DEFAULT_WORKSPACE_NAME)
study_6 = RawStudy(id=6, missing=datetime.datetime.now(), workspace=test_workspace)
study_7 = RawStudy(id=7, missing=None, workspace=test_workspace)
study_8 = RawStudy(id=8, missing=None, workspace=DEFAULT_WORKSPACE_NAME)

db_session.add_all([study_1, study_2, study_3, study_4, study_5, study_6, study_7, study_8])
study_1 = VariantStudy(name="s1")
study_2 = VariantStudy(name="s2")
study_3 = VariantStudy(name="s3")
study_4 = VariantStudy(name="s4")
study_5 = RawStudy(name="s5", missing=datetime.datetime.now(), workspace=DEFAULT_WORKSPACE_NAME)
study_6 = RawStudy(name="s6", missing=datetime.datetime.now(), workspace=test_workspace)
study_7 = RawStudy(name="s7", missing=None, workspace=test_workspace)
study_8 = RawStudy(name="s8", missing=None, workspace=DEFAULT_WORKSPACE_NAME)

my_studies = [study_1, study_2, study_3, study_4, study_5, study_6, study_7, study_8]
db_session.add_all(my_studies)
db_session.commit()

ids_by_names = {s.name: s.id for s in my_studies}

# use the db recorder to check that:
# 1- retrieving all studies requires only 1 query
# 2- accessing studies attributes does not require additional queries to db
# 3- having an exact total of queries equals to 1
study_filter = StudyFilter(
managed=managed, study_ids=study_ids, exists=exists, access_permissions=AccessPermissions(is_admin=True)
managed=managed,
study_ids=[ids_by_names[name] for name in study_names],
exists=exists,
access_permissions=AccessPermissions(is_admin=True),
)
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(study_filter=study_filter)
all_studies = repository.get_all(
study_filter=study_filter,
sort_by=StudySortBy.NAME_ASC,
)
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
_ = [s.tags for s in all_studies]
assert len(db_recorder.sql_statements) == 1, str(db_recorder)

# test that the expected studies are returned
if expected_ids is not None:
assert {s.id for s in all_studies} == expected_ids
assert [s.name for s in all_studies] == expected_names

# test pagination
# -- test pagination
page_nb = 1
page_size = 2
page_slice = slice(page_nb * page_size, (page_nb + 1) * page_size)

# test pagination in normal order
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(
study_filter=study_filter,
pagination=StudyPagination(page_nb=1, page_size=2),
sort_by=StudySortBy.NAME_ASC,
pagination=StudyPagination(page_nb=page_nb, page_size=page_size),
)
assert len(db_recorder.sql_statements) == 1, str(db_recorder)
assert [s.name for s in all_studies] == expected_names[page_slice]

# test pagination in reverse order
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(
study_filter=study_filter,
sort_by=StudySortBy.NAME_DESC,
pagination=StudyPagination(page_nb=page_nb, page_size=page_size),
)
assert len(all_studies) == max(0, min(len(expected_ids) - 2, 2))
assert len(db_recorder.sql_statements) == 1, str(db_recorder)
assert [s.name for s in all_studies] == expected_names[::-1][page_slice]


def test_get_all__incompatible_case(
Expand Down

0 comments on commit 62a5908

Please sign in to comment.