Skip to content

Commit

Permalink
test(tags-db): unittests for repository tags filter
Browse files Browse the repository at this point in the history
  • Loading branch information
mabw-rte committed Feb 6, 2024
1 parent e561322 commit 10293e0
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 14 deletions.
2 changes: 1 addition & 1 deletion antarest/study/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def get_all(
if study_filter.groups:
q = q.join(entity.groups).filter(Group.id.in_(study_filter.groups))
if study_filter.tags:
q = q.join(entity.tags).filter(Tag.id.in_(study_filter.tags))
q = q.join(entity.tags).filter(Tag.label.in_(study_filter.tags))
if study_filter.archived is not None:
q = q.filter(entity.archived == study_filter.archived)
if study_filter.name:
Expand Down
92 changes: 79 additions & 13 deletions tests/study/test_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from antarest.core.interfaces.cache import ICache
from antarest.login.model import Group, User
from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy
from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, Tag
from antarest.study.repository import StudyFilter, StudyMetadataRepository
from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy
from tests.db_statement_recorder import DBStatementRecorder
Expand Down Expand Up @@ -63,13 +63,14 @@ def test_repository_get_all__general_case(

# use the db recorder to check that:
# 1- retrieving all studies requires only 1 query
# 2- accessing studies attributes does require additional queries to db
# 2- accessing studies attributes does not require additional queries to db
# 3- having an exact total of queries equals to 1
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(study_filter=StudyFilter(managed=managed, study_ids=study_ids, exists=exists))
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
_ = [s.tags for s in all_studies]
assert len(db_recorder.sql_statements) == 1, str(db_recorder)

if expected_ids is not None:
Expand Down Expand Up @@ -102,6 +103,7 @@ def test_repository_get_all__incompatible_case(
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
_ = [s.tags for s in all_studies]
assert len(db_recorder.sql_statements) == 1, str(db_recorder)
assert not {s.id for s in all_studies}

Expand All @@ -112,6 +114,7 @@ def test_repository_get_all__incompatible_case(
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
_ = [s.tags for s in all_studies]
assert len(db_recorder.sql_statements) == 1, str(db_recorder)
assert not {s.id for s in all_studies}

Expand All @@ -122,6 +125,7 @@ def test_repository_get_all__incompatible_case(
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
_ = [s.tags for s in all_studies]
assert len(db_recorder.sql_statements) == 1, str(db_recorder)
assert not {s.id for s in all_studies}

Expand Down Expand Up @@ -162,13 +166,14 @@ def test_repository_get_all__study_name_filter(

# use the db recorder to check that:
# 1- retrieving all studies requires only 1 query
# 2- accessing studies attributes does require additional queries to db
# 2- accessing studies attributes does not require additional queries to db
# 3- having an exact total of queries equals to 1
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(study_filter=StudyFilter(name=name))
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
_ = [s.tags for s in all_studies]
assert len(db_recorder.sql_statements) == 1, str(db_recorder)

if expected_ids is not None:
Expand Down Expand Up @@ -206,13 +211,14 @@ def test_repository_get_all__managed_study_filter(

# use the db recorder to check that:
# 1- retrieving all studies requires only 1 query
# 2- accessing studies attributes does require additional queries to db
# 2- accessing studies attributes does not require additional queries to db
# 3- having an exact total of queries equals to 1
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(study_filter=StudyFilter(managed=managed))
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
_ = [s.tags for s in all_studies]
assert len(db_recorder.sql_statements) == 1, str(db_recorder)

if expected_ids is not None:
Expand Down Expand Up @@ -245,13 +251,14 @@ def test_repository_get_all__archived_study_filter(

# use the db recorder to check that:
# 1- retrieving all studies requires only 1 query
# 2- accessing studies attributes does require additional queries to db
# 2- accessing studies attributes does not require additional queries to db
# 3- having an exact total of queries equals to 1
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(study_filter=StudyFilter(archived=archived))
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
_ = [s.tags for s in all_studies]
assert len(db_recorder.sql_statements) == 1, str(db_recorder)

if expected_ids is not None:
Expand Down Expand Up @@ -284,13 +291,14 @@ def test_repository_get_all__variant_study_filter(

# use the db recorder to check that:
# 1- retrieving all studies requires only 1 query
# 2- accessing studies attributes does require additional queries to db
# 2- accessing studies attributes does not require additional queries to db
# 3- having an exact total of queries equals to 1
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(study_filter=StudyFilter(variant=variant))
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
_ = [s.tags for s in all_studies]
assert len(db_recorder.sql_statements) == 1, str(db_recorder)

if expected_ids is not None:
Expand Down Expand Up @@ -325,13 +333,14 @@ def test_repository_get_all__study_version_filter(

# use the db recorder to check that:
# 1- retrieving all studies requires only 1 query
# 2- accessing studies attributes does require additional queries to db
# 2- accessing studies attributes does not require additional queries to db
# 3- having an exact total of queries equals to 1
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(study_filter=StudyFilter(versions=versions))
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
_ = [s.tags for s in all_studies]
assert len(db_recorder.sql_statements) == 1, str(db_recorder)

if expected_ids is not None:
Expand Down Expand Up @@ -372,13 +381,14 @@ def test_repository_get_all__study_users_filter(

# use the db recorder to check that:
# 1- retrieving all studies requires only 1 query
# 2- accessing studies attributes does require additional queries to db
# 2- accessing studies attributes does not require additional queries to db
# 3- having an exact total of queries equals to 1
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(study_filter=StudyFilter(users=users))
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
_ = [s.tags for s in all_studies]
assert len(db_recorder.sql_statements) == 1, str(db_recorder)

if expected_ids is not None:
Expand Down Expand Up @@ -419,13 +429,14 @@ def test_repository_get_all__study_groups_filter(

# use the db recorder to check that:
# 1- retrieving all studies requires only 1 query
# 2- accessing studies attributes does require additional queries to db
# 2- accessing studies attributes does not require additional queries to db
# 3- having an exact total of queries equals to 1
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(study_filter=StudyFilter(groups=groups))
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
_ = [s.tags for s in all_studies]
assert len(db_recorder.sql_statements) == 1, str(db_recorder)

if expected_ids is not None:
Expand Down Expand Up @@ -461,13 +472,14 @@ def test_repository_get_all__study_ids_filter(

# use the db recorder to check that:
# 1- retrieving all studies requires only 1 query
# 2- accessing studies attributes does require additional queries to db
# 2- accessing studies attributes does not require additional queries to db
# 3- having an exact total of queries equals to 1
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(study_filter=StudyFilter(study_ids=study_ids))
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
_ = [s.tags for s in all_studies]
assert len(db_recorder.sql_statements) == 1, str(db_recorder)

if expected_ids is not None:
Expand Down Expand Up @@ -500,13 +512,14 @@ def test_repository_get_all__study_existence_filter(

# use the db recorder to check that:
# 1- retrieving all studies requires only 1 query
# 2- accessing studies attributes does require additional queries to db
# 2- accessing studies attributes does not require additional queries to db
# 3- having an exact total of queries equals to 1
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(study_filter=StudyFilter(exists=exists))
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
_ = [s.tags for s in all_studies]
assert len(db_recorder.sql_statements) == 1, str(db_recorder)

if expected_ids is not None:
Expand Down Expand Up @@ -540,13 +553,14 @@ def test_repository_get_all__study_workspace_filter(

# use the db recorder to check that:
# 1- retrieving all studies requires only 1 query
# 2- accessing studies attributes does require additional queries to db
# 2- accessing studies attributes does not require additional queries to db
# 3- having an exact total of queries equals to 1
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(study_filter=StudyFilter(workspace=workspace))
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
_ = [s.tags for s in all_studies]
assert len(db_recorder.sql_statements) == 1, str(db_recorder)

if expected_ids is not None:
Expand Down Expand Up @@ -582,13 +596,65 @@ def test_repository_get_all__study_folder_filter(

# use the db recorder to check that:
# 1- retrieving all studies requires only 1 query
# 2- accessing studies attributes does require additional queries to db
# 2- accessing studies attributes does not require additional queries to db
# 3- having an exact total of queries equals to 1
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(study_filter=StudyFilter(folder=folder))
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
_ = [s.tags for s in all_studies]

assert len(db_recorder.sql_statements) == 1, str(db_recorder)

if expected_ids is not None:
assert {s.id for s in all_studies} == expected_ids


@pytest.mark.parametrize(
"tags, expected_ids",
[
([], {"1", "2", "3", "4", "5", "6", "7", "8"}),
(["decennial"], {"2", "4", "6", "8"}),
(["winter_transition"], {"3", "4", "7", "8"}),
(["decennial", "winter_transition"], {"2", "3", "4", "6", "7", "8"}),
(["no-study-tag"], set()),
],
)
def test_repository_get_all__study_tags_filter(
db_session: Session,
tags: t.List[str],
expected_ids: t.Set[str],
) -> None:
icache: Mock = Mock(spec=ICache)
repository = StudyMetadataRepository(cache_service=icache, session=db_session)

test_tag_1 = Tag(label="hidden-tag")
test_tag_2 = Tag(label="decennial")
test_tag_3 = Tag(label="winter_transition")

study_1 = VariantStudy(id=1, tags=[test_tag_1])
study_2 = VariantStudy(id=2, tags=[test_tag_2])
study_3 = VariantStudy(id=3, tags=[test_tag_3])
study_4 = VariantStudy(id=4, tags=[test_tag_2, test_tag_3])
study_5 = RawStudy(id=5, tags=[test_tag_1])
study_6 = RawStudy(id=6, tags=[test_tag_2])
study_7 = RawStudy(id=7, tags=[test_tag_3])
study_8 = RawStudy(id=8, tags=[test_tag_2, test_tag_3])

db_session.add_all([study_1, study_2, study_3, study_4, study_5, study_6, study_7, study_8])
db_session.commit()

# use the db recorder to check that:
# 1- retrieving all studies requires only 1 query
# 2- accessing studies attributes does not require additional queries to db
# 3- having an exact total of queries equals to 1
with DBStatementRecorder(db_session.bind) as db_recorder:
all_studies = repository.get_all(study_filter=StudyFilter(tags=tags))
_ = [s.owner for s in all_studies]
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
_ = [s.tags for s in all_studies]
assert len(db_recorder.sql_statements) == 1, str(db_recorder)

if expected_ids is not None:
Expand Down

0 comments on commit 10293e0

Please sign in to comment.