diff --git a/antarest/study/repository.py b/antarest/study/repository.py index 6e89d33d5a..dd60282a43 100644 --- a/antarest/study/repository.py +++ b/antarest/study/repository.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, NonNegativeInt from sqlalchemy import and_, func, not_, or_, sql # type: ignore -from sqlalchemy.orm import Query, Session, joinedload, with_polymorphic # type: ignore +from sqlalchemy.orm import Query, Session, joinedload, subqueryload, with_polymorphic # type: ignore from antarest.core.interfaces.cache import ICache from antarest.core.jwt import JWTUser @@ -304,10 +304,15 @@ def _search_studies( q = q.filter(RawStudy.missing.is_(None)) else: q = q.filter(not_(RawStudy.missing.is_(None))) - q = q.options(joinedload(entity.owner)) - q = q.options(joinedload(entity.groups)) + + if study_filter.users is not None: + q = q.options(joinedload(entity.owner)) + if study_filter.groups is not None: + q = q.options(joinedload(entity.groups)) + if study_filter.tags is not None: + q = q.options(joinedload(entity.tags)) q = q.options(joinedload(entity.additional_data)) - q = q.options(joinedload(entity.tags)) + if study_filter.managed is not None: if study_filter.managed: q = q.filter(or_(entity.type == "variantstudy", RawStudy.workspace == DEFAULT_WORKSPACE_NAME)) diff --git a/tests/study/test_repository.py b/tests/study/test_repository.py index cd8b6c790c..e6becac349 100644 --- a/tests/study/test_repository.py +++ b/tests/study/test_repository.py @@ -9,68 +9,83 @@ from antarest.core.model import PublicMode from antarest.login.model import Group, User from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, Tag -from antarest.study.repository import AccessPermissions, StudyFilter, StudyMetadataRepository, StudyPagination +from antarest.study.repository import ( + AccessPermissions, + StudyFilter, + StudyMetadataRepository, + StudyPagination, + StudySortBy, +) from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy from tests.db_statement_recorder import DBStatementRecorder @pytest.mark.parametrize( - "managed, study_ids, exists, expected_ids", + "managed, study_names, exists, expected_names", [ - (None, [], False, {"5", "6"}), - (None, [], True, {"1", "2", "3", "4", "7", "8"}), - (None, [], None, {"1", "2", "3", "4", "5", "6", "7", "8"}), - (None, [1, 3, 5, 7], False, {"5"}), - (None, [1, 3, 5, 7], True, {"1", "3", "7"}), - (None, [1, 3, 5, 7], None, {"1", "3", "5", "7"}), - (True, [], False, {"5"}), - (True, [], True, {"1", "2", "3", "4", "8"}), - (True, [], None, {"1", "2", "3", "4", "5", "8"}), - (True, [1, 3, 5, 7], False, {"5"}), - (True, [1, 3, 5, 7], True, {"1", "3"}), - (True, [1, 3, 5, 7], None, {"1", "3", "5"}), - (True, [2, 4, 6, 8], True, {"2", "4", "8"}), - (True, [2, 4, 6, 8], None, {"2", "4", "8"}), - (False, [], False, {"6"}), - (False, [], True, {"7"}), - (False, [], None, {"6", "7"}), - (False, [1, 3, 5, 7], False, set()), - (False, [1, 3, 5, 7], True, {"7"}), - (False, [1, 3, 5, 7], None, {"7"}), + (None, [], False, ["s5", "s6"]), + (None, [], True, ["s1", "s2", "s3", "s4", "s7", "s8"]), + (None, [], None, ["s1", "s2", "s3", "s4", "s5", "s6", "s7", "s8"]), + (None, ["s1", "s3", "s5", "s7"], False, ["s5"]), + (None, ["s1", "s3", "s5", "s7"], True, ["s1", "s3", "s7"]), + (None, ["s1", "s3", "s5", "s7"], None, ["s1", "s3", "s5", "s7"]), + (True, [], False, ["s5"]), + (True, [], True, ["s1", "s2", "s3", "s4", "s8"]), + (True, [], None, ["s1", "s2", "s3", "s4", "s5", "s8"]), + (True, ["s1", "s3", "s5", "s7"], False, ["s5"]), + (True, ["s1", "s3", "s5", "s7"], True, ["s1", "s3"]), + (True, ["s1", "s3", "s5", "s7"], None, ["s1", "s3", "s5"]), + (True, ["s2", "s4", "s6", "s8"], True, ["s2", "s4", "s8"]), + (True, ["s2", "s4", "s6", "s8"], None, ["s2", "s4", "s8"]), + (False, [], False, ["s6"]), + (False, [], True, ["s7"]), + (False, [], None, ["s6", "s7"]), + (False, ["s1", "s3", "s5", "s7"], False, []), + (False, ["s1", "s3", "s5", "s7"], True, ["s7"]), + (False, ["s1", "s3", "s5", "s7"], None, ["s7"]), ], ) def test_get_all__general_case( db_session: Session, managed: t.Union[bool, None], - study_ids: t.Sequence[str], + study_names: t.Sequence[str], exists: t.Union[bool, None], - expected_ids: t.Set[str], + expected_names: t.Sequence[str], ) -> None: test_workspace = "test-repository" icache: Mock = Mock(spec=ICache) repository = StudyMetadataRepository(cache_service=icache, session=db_session) - study_1 = VariantStudy(id=1) - study_2 = VariantStudy(id=2) - study_3 = VariantStudy(id=3) - study_4 = VariantStudy(id=4) - study_5 = RawStudy(id=5, missing=datetime.datetime.now(), workspace=DEFAULT_WORKSPACE_NAME) - study_6 = RawStudy(id=6, missing=datetime.datetime.now(), workspace=test_workspace) - study_7 = RawStudy(id=7, missing=None, workspace=test_workspace) - study_8 = RawStudy(id=8, missing=None, workspace=DEFAULT_WORKSPACE_NAME) - - db_session.add_all([study_1, study_2, study_3, study_4, study_5, study_6, study_7, study_8]) + study_1 = VariantStudy(name="s1") + study_2 = VariantStudy(name="s2") + study_3 = VariantStudy(name="s3") + study_4 = VariantStudy(name="s4") + study_5 = RawStudy(name="s5", missing=datetime.datetime.now(), workspace=DEFAULT_WORKSPACE_NAME) + study_6 = RawStudy(name="s6", missing=datetime.datetime.now(), workspace=test_workspace) + study_7 = RawStudy(name="s7", missing=None, workspace=test_workspace) + study_8 = RawStudy(name="s8", missing=None, workspace=DEFAULT_WORKSPACE_NAME) + + my_studies = [study_1, study_2, study_3, study_4, study_5, study_6, study_7, study_8] + db_session.add_all(my_studies) db_session.commit() + ids_by_names = {s.name: s.id for s in my_studies} + # use the db recorder to check that: # 1- retrieving all studies requires only 1 query # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 study_filter = StudyFilter( - managed=managed, study_ids=study_ids, exists=exists, access_permissions=AccessPermissions(is_admin=True) + managed=managed, + study_ids=[ids_by_names[name] for name in study_names], + exists=exists, + access_permissions=AccessPermissions(is_admin=True), ) with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=study_filter) + all_studies = repository.get_all( + study_filter=study_filter, + sort_by=StudySortBy.NAME_ASC, + ) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -78,17 +93,32 @@ def test_get_all__general_case( assert len(db_recorder.sql_statements) == 1, str(db_recorder) # test that the expected studies are returned - if expected_ids is not None: - assert {s.id for s in all_studies} == expected_ids + assert [s.name for s in all_studies] == expected_names - # test pagination + # -- test pagination + page_nb = 1 + page_size = 2 + page_slice = slice(page_nb * page_size, (page_nb + 1) * page_size) + + # test pagination in normal order with DBStatementRecorder(db_session.bind) as db_recorder: all_studies = repository.get_all( study_filter=study_filter, - pagination=StudyPagination(page_nb=1, page_size=2), + sort_by=StudySortBy.NAME_ASC, + pagination=StudyPagination(page_nb=page_nb, page_size=page_size), + ) + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + assert [s.name for s in all_studies] == expected_names[page_slice] + + # test pagination in reverse order + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all( + study_filter=study_filter, + sort_by=StudySortBy.NAME_DESC, + pagination=StudyPagination(page_nb=page_nb, page_size=page_size), ) - assert len(all_studies) == max(0, min(len(expected_ids) - 2, 2)) assert len(db_recorder.sql_statements) == 1, str(db_recorder) + assert [s.name for s in all_studies] == expected_names[::-1][page_slice] def test_get_all__incompatible_case(