From 10293e07b06c91021db9af7489bf62d91c34912b Mon Sep 17 00:00:00 2001 From: Mohamed Abdel Wedoud Date: Tue, 6 Feb 2024 14:48:06 +0100 Subject: [PATCH] test(tags-db): unittests for repository tags filter --- antarest/study/repository.py | 2 +- tests/study/test_repository.py | 92 +++++++++++++++++++++++++++++----- 2 files changed, 80 insertions(+), 14 deletions(-) diff --git a/antarest/study/repository.py b/antarest/study/repository.py index d8e2675cd6..05b88c31f2 100644 --- a/antarest/study/repository.py +++ b/antarest/study/repository.py @@ -228,7 +228,7 @@ def get_all( if study_filter.groups: q = q.join(entity.groups).filter(Group.id.in_(study_filter.groups)) if study_filter.tags: - q = q.join(entity.tags).filter(Tag.id.in_(study_filter.tags)) + q = q.join(entity.tags).filter(Tag.label.in_(study_filter.tags)) if study_filter.archived is not None: q = q.filter(entity.archived == study_filter.archived) if study_filter.name: diff --git a/tests/study/test_repository.py b/tests/study/test_repository.py index 77e4f1554c..a5d08ca658 100644 --- a/tests/study/test_repository.py +++ b/tests/study/test_repository.py @@ -7,7 +7,7 @@ from antarest.core.interfaces.cache import ICache from antarest.login.model import Group, User -from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy +from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, Tag from antarest.study.repository import StudyFilter, StudyMetadataRepository from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy from tests.db_statement_recorder import DBStatementRecorder @@ -63,13 +63,14 @@ def test_repository_get_all__general_case( # use the db recorder to check that: # 1- retrieving all studies requires only 1 query - # 2- accessing studies attributes does require additional queries to db + # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: all_studies = repository.get_all(study_filter=StudyFilter(managed=managed, study_ids=study_ids, exists=exists)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] assert len(db_recorder.sql_statements) == 1, str(db_recorder) if expected_ids is not None: @@ -102,6 +103,7 @@ def test_repository_get_all__incompatible_case( _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] assert len(db_recorder.sql_statements) == 1, str(db_recorder) assert not {s.id for s in all_studies} @@ -112,6 +114,7 @@ def test_repository_get_all__incompatible_case( _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] assert len(db_recorder.sql_statements) == 1, str(db_recorder) assert not {s.id for s in all_studies} @@ -122,6 +125,7 @@ def test_repository_get_all__incompatible_case( _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] assert len(db_recorder.sql_statements) == 1, str(db_recorder) assert not {s.id for s in all_studies} @@ -162,13 +166,14 @@ def test_repository_get_all__study_name_filter( # use the db recorder to check that: # 1- retrieving all studies requires only 1 query - # 2- accessing studies attributes does require additional queries to db + # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: all_studies = repository.get_all(study_filter=StudyFilter(name=name)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] assert len(db_recorder.sql_statements) == 1, str(db_recorder) if expected_ids is not None: @@ -206,13 +211,14 @@ def test_repository_get_all__managed_study_filter( # use the db recorder to check that: # 1- retrieving all studies requires only 1 query - # 2- accessing studies attributes does require additional queries to db + # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: all_studies = repository.get_all(study_filter=StudyFilter(managed=managed)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] assert len(db_recorder.sql_statements) == 1, str(db_recorder) if expected_ids is not None: @@ -245,13 +251,14 @@ def test_repository_get_all__archived_study_filter( # use the db recorder to check that: # 1- retrieving all studies requires only 1 query - # 2- accessing studies attributes does require additional queries to db + # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: all_studies = repository.get_all(study_filter=StudyFilter(archived=archived)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] assert len(db_recorder.sql_statements) == 1, str(db_recorder) if expected_ids is not None: @@ -284,13 +291,14 @@ def test_repository_get_all__variant_study_filter( # use the db recorder to check that: # 1- retrieving all studies requires only 1 query - # 2- accessing studies attributes does require additional queries to db + # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: all_studies = repository.get_all(study_filter=StudyFilter(variant=variant)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] assert len(db_recorder.sql_statements) == 1, str(db_recorder) if expected_ids is not None: @@ -325,13 +333,14 @@ def test_repository_get_all__study_version_filter( # use the db recorder to check that: # 1- retrieving all studies requires only 1 query - # 2- accessing studies attributes does require additional queries to db + # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: all_studies = repository.get_all(study_filter=StudyFilter(versions=versions)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] assert len(db_recorder.sql_statements) == 1, str(db_recorder) if expected_ids is not None: @@ -372,13 +381,14 @@ def test_repository_get_all__study_users_filter( # use the db recorder to check that: # 1- retrieving all studies requires only 1 query - # 2- accessing studies attributes does require additional queries to db + # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: all_studies = repository.get_all(study_filter=StudyFilter(users=users)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] assert len(db_recorder.sql_statements) == 1, str(db_recorder) if expected_ids is not None: @@ -419,13 +429,14 @@ def test_repository_get_all__study_groups_filter( # use the db recorder to check that: # 1- retrieving all studies requires only 1 query - # 2- accessing studies attributes does require additional queries to db + # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: all_studies = repository.get_all(study_filter=StudyFilter(groups=groups)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] assert len(db_recorder.sql_statements) == 1, str(db_recorder) if expected_ids is not None: @@ -461,13 +472,14 @@ def test_repository_get_all__study_ids_filter( # use the db recorder to check that: # 1- retrieving all studies requires only 1 query - # 2- accessing studies attributes does require additional queries to db + # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: all_studies = repository.get_all(study_filter=StudyFilter(study_ids=study_ids)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] assert len(db_recorder.sql_statements) == 1, str(db_recorder) if expected_ids is not None: @@ -500,13 +512,14 @@ def test_repository_get_all__study_existence_filter( # use the db recorder to check that: # 1- retrieving all studies requires only 1 query - # 2- accessing studies attributes does require additional queries to db + # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: all_studies = repository.get_all(study_filter=StudyFilter(exists=exists)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] assert len(db_recorder.sql_statements) == 1, str(db_recorder) if expected_ids is not None: @@ -540,13 +553,14 @@ def test_repository_get_all__study_workspace_filter( # use the db recorder to check that: # 1- retrieving all studies requires only 1 query - # 2- accessing studies attributes does require additional queries to db + # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: all_studies = repository.get_all(study_filter=StudyFilter(workspace=workspace)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] assert len(db_recorder.sql_statements) == 1, str(db_recorder) if expected_ids is not None: @@ -582,13 +596,65 @@ def test_repository_get_all__study_folder_filter( # use the db recorder to check that: # 1- retrieving all studies requires only 1 query - # 2- accessing studies attributes does require additional queries to db + # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: all_studies = repository.get_all(study_filter=StudyFilter(folder=folder)) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] + + assert len(db_recorder.sql_statements) == 1, str(db_recorder) + + if expected_ids is not None: + assert {s.id for s in all_studies} == expected_ids + + +@pytest.mark.parametrize( + "tags, expected_ids", + [ + ([], {"1", "2", "3", "4", "5", "6", "7", "8"}), + (["decennial"], {"2", "4", "6", "8"}), + (["winter_transition"], {"3", "4", "7", "8"}), + (["decennial", "winter_transition"], {"2", "3", "4", "6", "7", "8"}), + (["no-study-tag"], set()), + ], +) +def test_repository_get_all__study_tags_filter( + db_session: Session, + tags: t.List[str], + expected_ids: t.Set[str], +) -> None: + icache: Mock = Mock(spec=ICache) + repository = StudyMetadataRepository(cache_service=icache, session=db_session) + + test_tag_1 = Tag(label="hidden-tag") + test_tag_2 = Tag(label="decennial") + test_tag_3 = Tag(label="winter_transition") + + study_1 = VariantStudy(id=1, tags=[test_tag_1]) + study_2 = VariantStudy(id=2, tags=[test_tag_2]) + study_3 = VariantStudy(id=3, tags=[test_tag_3]) + study_4 = VariantStudy(id=4, tags=[test_tag_2, test_tag_3]) + study_5 = RawStudy(id=5, tags=[test_tag_1]) + study_6 = RawStudy(id=6, tags=[test_tag_2]) + study_7 = RawStudy(id=7, tags=[test_tag_3]) + study_8 = RawStudy(id=8, tags=[test_tag_2, test_tag_3]) + + db_session.add_all([study_1, study_2, study_3, study_4, study_5, study_6, study_7, study_8]) + db_session.commit() + + # use the db recorder to check that: + # 1- retrieving all studies requires only 1 query + # 2- accessing studies attributes does not require additional queries to db + # 3- having an exact total of queries equals to 1 + with DBStatementRecorder(db_session.bind) as db_recorder: + all_studies = repository.get_all(study_filter=StudyFilter(tags=tags)) + _ = [s.owner for s in all_studies] + _ = [s.groups for s in all_studies] + _ = [s.additional_data for s in all_studies] + _ = [s.tags for s in all_studies] assert len(db_recorder.sql_statements) == 1, str(db_recorder) if expected_ids is not None: