Skip to content

Commit

Permalink
test(tags-db): integration tests for study tags filtering and updating
Browse files Browse the repository at this point in the history
  • Loading branch information
mabw-rte committed Feb 7, 2024
1 parent dac11fd commit 23ad643
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 30 deletions.
4 changes: 3 additions & 1 deletion antarest/study/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,8 @@ def update_tags(self, study: Study, new_tags: t.List[str]) -> None:
"""
logger.debug(f"Updating tags for study: {study.id}")
study.tags = [Tag(label=tag) for tag in new_tags]
existing_tags = self.session.query(Tag).filter(Tag.label.in_(new_tags)).all()
new_labels = set(new_tags) - set([tag.label for tag in existing_tags])
study.tags = [Tag(label=tag) for tag in new_labels] + existing_tags
self.session.merge(study)
self.session.commit()
6 changes: 3 additions & 3 deletions antarest/study/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,9 @@ def update_study_information(
if metadata_patch.horizon:
study.additional_data.horizon = metadata_patch.horizon

new_tags = metadata_patch.tags
self.repository.update_tags(study, new_tags)

new_metadata = self.storage_service.get_storage(study).patch_update_study_metadata(study, metadata_patch)

self.event_bus.push(
Expand All @@ -561,9 +564,6 @@ def update_study_information(
)
)

new_tags = new_metadata.tags
self.repository.update_tags(study, new_tags)

return new_metadata

def check_study_access(
Expand Down
171 changes: 145 additions & 26 deletions tests/integration/studies_blueprint/test_get_studies.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,24 +266,6 @@ def test_study_listing(
assert res.status_code in CREATE_STATUS_CODES, res.json()
archived_raw_850_id = res.json()

# create a variant study version 840
res = client.post(
f"{STUDIES_URL}/{archived_raw_840_id}/variants",
headers={"Authorization": f"Bearer {admin_access_token}"},
params={"name": "archived-variant-840", "version": "840"},
)
assert res.status_code in CREATE_STATUS_CODES, res.json()
archived_variant_840_id = res.json()

# create a variant study version 850 to be archived
res = client.post(
f"{STUDIES_URL}/{archived_raw_850_id}/variants",
headers={"Authorization": f"Bearer {admin_access_token}"},
params={"name": "archived-variant-850", "version": "850"},
)
assert res.status_code in CREATE_STATUS_CODES, res.json()
archived_variant_850_id = res.json()

# create a raw study to be transferred in folder1
zip_path = ASSETS_DIR / "STA-mini.zip"
res = client.post(
Expand Down Expand Up @@ -337,6 +319,120 @@ def test_study_listing(
task = wait_task_completion(client, admin_access_token, archiving_study_task_id)
assert task.status == TaskStatus.COMPLETED, task

# create a raw study version 840 to be tagged with `winter_transition`
res = client.post(
STUDIES_URL,
headers={"Authorization": f"Bearer {admin_access_token}"},
params={"name": "winter-transition-raw-840", "version": "840"},
)
assert res.status_code in CREATE_STATUS_CODES, res.json()
tagged_raw_840_id = res.json()
res = client.put(
f"{STUDIES_URL}/{tagged_raw_840_id}",
headers={"Authorization": f"Bearer {admin_access_token}"},
json={"tags": ["winter_transition"]},
)
assert res.status_code in CREATE_STATUS_CODES, res.json()
res = client.get(
STUDIES_URL,
headers={"Authorization": f"Bearer {admin_access_token}"},
params={"name": "winter-transition-raw-840"},
)
assert res.status_code == LIST_STATUS_CODE, res.json()
study_map: t.Dict[str, t.Dict[str, t.Any]] = res.json()
assert len(study_map) == 1
assert set(study_map.get(tagged_raw_840_id).get("tags")) == {"winter_transition"}

# create a raw study version 850 to be tagged with `decennial`
res = client.post(
STUDIES_URL,
headers={"Authorization": f"Bearer {admin_access_token}"},
params={"name": "decennial-raw-850", "version": "850"},
)
assert res.status_code in CREATE_STATUS_CODES, res.json()
tagged_raw_850_id = res.json()
res = client.put(
f"{STUDIES_URL}/{tagged_raw_850_id}",
headers={"Authorization": f"Bearer {admin_access_token}"},
json={"tags": ["decennial"]},
)
assert res.status_code in CREATE_STATUS_CODES, res.json()
res = client.get(
STUDIES_URL, headers={"Authorization": f"Bearer {admin_access_token}"}, params={"name": "decennial-raw-850"}
)
assert res.status_code == LIST_STATUS_CODE, res.json()
study_map: t.Dict[str, t.Dict[str, t.Any]] = res.json()
assert len(study_map) == 1
assert set(study_map.get(tagged_raw_850_id).get("tags")) == {"decennial"}

# create a variant study version 840 to be tagged with `decennial`
res = client.post(
f"{STUDIES_URL}/{tagged_raw_840_id}/variants",
headers={"Authorization": f"Bearer {admin_access_token}"},
params={"name": "decennial-variant-840", "version": "840"},
)
assert res.status_code in CREATE_STATUS_CODES, res.json()
tagged_variant_840_id = res.json()
res = client.put(
f"{STUDIES_URL}/{tagged_variant_840_id}/generate",
headers={"Authorization": f"Bearer {admin_access_token}"},
)
assert res.status_code == LIST_STATUS_CODE, res.json()
generation_task_id = res.json()
task = wait_task_completion(client, admin_access_token, generation_task_id)
assert task.status == TaskStatus.COMPLETED, task
res = client.put(
f"{STUDIES_URL}/{tagged_variant_840_id}",
headers={"Authorization": f"Bearer {admin_access_token}"},
json={"tags": ["decennial"]},
)
assert res.status_code in CREATE_STATUS_CODES, res.json()
res = client.get(
STUDIES_URL,
headers={"Authorization": f"Bearer {admin_access_token}"},
params={"name": "decennial-variant-840"},
)
assert res.status_code == LIST_STATUS_CODE, res.json()
study_map: t.Dict[str, t.Dict[str, t.Any]] = res.json()
assert len(study_map) == 1
assert set(study_map.get(tagged_variant_840_id).get("tags")) == {"decennial"}

# create a variant study version 850 to be tagged with `winter_transition`
res = client.post(
f"{STUDIES_URL}/{tagged_raw_850_id}/variants",
headers={"Authorization": f"Bearer {admin_access_token}"},
params={"name": "winter-transition-variant-850", "version": "850"},
)
assert res.status_code in CREATE_STATUS_CODES, res.json()
tagged_variant_850_id = res.json()
res = client.put(
f"{STUDIES_URL}/{tagged_variant_850_id}/generate",
headers={"Authorization": f"Bearer {admin_access_token}"},
)
assert res.status_code == LIST_STATUS_CODE, res.json()
generation_task_id = res.json()
task = wait_task_completion(client, admin_access_token, generation_task_id)
assert task.status == TaskStatus.COMPLETED, task
res = client.put(
f"{STUDIES_URL}/{tagged_variant_850_id}",
headers={"Authorization": f"Bearer {admin_access_token}"},
json={"tags": ["winter_transition"]},
)
assert res.status_code in CREATE_STATUS_CODES, res.json()
res = client.get(
STUDIES_URL,
headers={"Authorization": f"Bearer {admin_access_token}"},
params={"name": "winter-transition-variant-850"},
)
assert res.status_code == LIST_STATUS_CODE, res.json()
study_map: t.Dict[str, t.Dict[str, t.Any]] = res.json()
assert len(study_map) == 1
assert set(study_map.get(tagged_variant_850_id).get("tags")) == {"winter_transition"}

# ==========================
# 2. Filtering testing
# ==========================

# the testing studies set
all_studies = {
raw_840_id,
Expand All @@ -350,10 +446,12 @@ def test_study_listing(
variant_860_id,
archived_raw_840_id,
archived_raw_850_id,
archived_variant_840_id,
archived_variant_850_id,
folder1_study_id,
to_be_deleted_id,
tagged_raw_840_id,
tagged_raw_850_id,
tagged_variant_840_id,
tagged_variant_850_id,
}

pm = operator.itemgetter("public_mode")
Expand Down Expand Up @@ -392,15 +490,14 @@ def test_study_listing(
[e for k, e in study_map.items() if k not in james_bond_studies],
)
)
# #TODO you need to update the permission for James Bond bot

# #TODO you need to update the permission for James Bond bot above
# test 1.d for a user bot with access to select studies
res = client.get(
STUDIES_URL,
headers={"Authorization": f"Bearer {james_bond_bot_token}"},
)
assert res.status_code == LIST_STATUS_CODE, res.json()

# #TODO add the correct test assertions
# ] = res.json()
# assert not set(james_bond_studies).difference(study_map)
Expand Down Expand Up @@ -477,8 +574,8 @@ def test_study_listing(
variant_840_id,
variant_850_id,
variant_860_id,
archived_variant_840_id,
archived_variant_850_id,
tagged_variant_840_id,
tagged_variant_850_id,
}
# test 5.a get variant studies
res = client.get(
Expand Down Expand Up @@ -507,7 +604,8 @@ def test_study_listing(
non_managed_850_id,
variant_850_id,
archived_raw_850_id,
archived_variant_850_id,
tagged_variant_850_id,
tagged_raw_850_id,
}
studies_version_860 = {
raw_860_id,
Expand Down Expand Up @@ -579,10 +677,31 @@ def test_study_listing(
assert not all_studies.difference(group_x_studies.union(group_y_studies)).intersection(study_map)
assert not group_x_studies.union(group_y_studies).difference(study_map)

# TODO you need to add filtering through tags to the search engine
# tests (9) for tags filtering
# test 9.a filtering for one tag: decennial
decennial_tagged_studies = {tagged_raw_850_id, tagged_variant_840_id}
winter_transition_tagged_studies = {tagged_raw_840_id, tagged_variant_850_id}
res = client.get(
STUDIES_URL,
headers={"Authorization": f"Bearer {admin_access_token}"},
params={"tags": f"decennial"},
)
assert res.status_code == LIST_STATUS_CODE, res.json()
study_map = res.json()
assert not all_studies.difference(decennial_tagged_studies).intersection(study_map)
assert not decennial_tagged_studies.difference(study_map)
# test 9.b filtering for two tags: decennial,winter_transition
res = client.get(
STUDIES_URL,
headers={"Authorization": f"Bearer {admin_access_token}"},
params={"tags": f"decennial,winter_transition"},
)
assert res.status_code == LIST_STATUS_CODE, res.json()
study_map = res.json()
assert not all_studies.difference(
decennial_tagged_studies.union(winter_transition_tagged_studies)
).intersection(study_map)
assert not decennial_tagged_studies.union(winter_transition_tagged_studies).difference(study_map)

# tests (10) for studies uuids sequence filtering
# test 10.a filter for one uuid
Expand Down

0 comments on commit 23ad643

Please sign in to comment.