From 23ad643a69e2e58f79255560a767743fb7db5954 Mon Sep 17 00:00:00 2001 From: Mohamed Abdel Wedoud Date: Tue, 6 Feb 2024 19:22:51 +0100 Subject: [PATCH] test(tags-db): integration tests for study tags filtering and updating --- antarest/study/repository.py | 4 +- antarest/study/service.py | 6 +- .../studies_blueprint/test_get_studies.py | 171 +++++++++++++++--- 3 files changed, 151 insertions(+), 30 deletions(-) diff --git a/antarest/study/repository.py b/antarest/study/repository.py index 077d4ccd3c..7a6e99cae2 100644 --- a/antarest/study/repository.py +++ b/antarest/study/repository.py @@ -298,6 +298,8 @@ def update_tags(self, study: Study, new_tags: t.List[str]) -> None: """ logger.debug(f"Updating tags for study: {study.id}") - study.tags = [Tag(label=tag) for tag in new_tags] + existing_tags = self.session.query(Tag).filter(Tag.label.in_(new_tags)).all() + new_labels = set(new_tags) - set([tag.label for tag in existing_tags]) + study.tags = [Tag(label=tag) for tag in new_labels] + existing_tags self.session.merge(study) self.session.commit() diff --git a/antarest/study/service.py b/antarest/study/service.py index 3d10b2bef1..20d4640ffe 100644 --- a/antarest/study/service.py +++ b/antarest/study/service.py @@ -551,6 +551,9 @@ def update_study_information( if metadata_patch.horizon: study.additional_data.horizon = metadata_patch.horizon + new_tags = metadata_patch.tags + self.repository.update_tags(study, new_tags) + new_metadata = self.storage_service.get_storage(study).patch_update_study_metadata(study, metadata_patch) self.event_bus.push( @@ -561,9 +564,6 @@ def update_study_information( ) ) - new_tags = new_metadata.tags - self.repository.update_tags(study, new_tags) - return new_metadata def check_study_access( diff --git a/tests/integration/studies_blueprint/test_get_studies.py b/tests/integration/studies_blueprint/test_get_studies.py index b134406e50..c9bf08bb62 100644 --- a/tests/integration/studies_blueprint/test_get_studies.py +++ b/tests/integration/studies_blueprint/test_get_studies.py @@ -266,24 +266,6 @@ def test_study_listing( assert res.status_code in CREATE_STATUS_CODES, res.json() archived_raw_850_id = res.json() - # create a variant study version 840 - res = client.post( - f"{STUDIES_URL}/{archived_raw_840_id}/variants", - headers={"Authorization": f"Bearer {admin_access_token}"}, - params={"name": "archived-variant-840", "version": "840"}, - ) - assert res.status_code in CREATE_STATUS_CODES, res.json() - archived_variant_840_id = res.json() - - # create a variant study version 850 to be archived - res = client.post( - f"{STUDIES_URL}/{archived_raw_850_id}/variants", - headers={"Authorization": f"Bearer {admin_access_token}"}, - params={"name": "archived-variant-850", "version": "850"}, - ) - assert res.status_code in CREATE_STATUS_CODES, res.json() - archived_variant_850_id = res.json() - # create a raw study to be transferred in folder1 zip_path = ASSETS_DIR / "STA-mini.zip" res = client.post( @@ -337,6 +319,120 @@ def test_study_listing( task = wait_task_completion(client, admin_access_token, archiving_study_task_id) assert task.status == TaskStatus.COMPLETED, task + # create a raw study version 840 to be tagged with `winter_transition` + res = client.post( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": "winter-transition-raw-840", "version": "840"}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + tagged_raw_840_id = res.json() + res = client.put( + f"{STUDIES_URL}/{tagged_raw_840_id}", + headers={"Authorization": f"Bearer {admin_access_token}"}, + json={"tags": ["winter_transition"]}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": "winter-transition-raw-840"}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map: t.Dict[str, t.Dict[str, t.Any]] = res.json() + assert len(study_map) == 1 + assert set(study_map.get(tagged_raw_840_id).get("tags")) == {"winter_transition"} + + # create a raw study version 850 to be tagged with `decennial` + res = client.post( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": "decennial-raw-850", "version": "850"}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + tagged_raw_850_id = res.json() + res = client.put( + f"{STUDIES_URL}/{tagged_raw_850_id}", + headers={"Authorization": f"Bearer {admin_access_token}"}, + json={"tags": ["decennial"]}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + res = client.get( + STUDIES_URL, headers={"Authorization": f"Bearer {admin_access_token}"}, params={"name": "decennial-raw-850"} + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map: t.Dict[str, t.Dict[str, t.Any]] = res.json() + assert len(study_map) == 1 + assert set(study_map.get(tagged_raw_850_id).get("tags")) == {"decennial"} + + # create a variant study version 840 to be tagged with `decennial` + res = client.post( + f"{STUDIES_URL}/{tagged_raw_840_id}/variants", + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": "decennial-variant-840", "version": "840"}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + tagged_variant_840_id = res.json() + res = client.put( + f"{STUDIES_URL}/{tagged_variant_840_id}/generate", + headers={"Authorization": f"Bearer {admin_access_token}"}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + generation_task_id = res.json() + task = wait_task_completion(client, admin_access_token, generation_task_id) + assert task.status == TaskStatus.COMPLETED, task + res = client.put( + f"{STUDIES_URL}/{tagged_variant_840_id}", + headers={"Authorization": f"Bearer {admin_access_token}"}, + json={"tags": ["decennial"]}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": "decennial-variant-840"}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map: t.Dict[str, t.Dict[str, t.Any]] = res.json() + assert len(study_map) == 1 + assert set(study_map.get(tagged_variant_840_id).get("tags")) == {"decennial"} + + # create a variant study version 850 to be tagged with `winter_transition` + res = client.post( + f"{STUDIES_URL}/{tagged_raw_850_id}/variants", + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": "winter-transition-variant-850", "version": "850"}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + tagged_variant_850_id = res.json() + res = client.put( + f"{STUDIES_URL}/{tagged_variant_850_id}/generate", + headers={"Authorization": f"Bearer {admin_access_token}"}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + generation_task_id = res.json() + task = wait_task_completion(client, admin_access_token, generation_task_id) + assert task.status == TaskStatus.COMPLETED, task + res = client.put( + f"{STUDIES_URL}/{tagged_variant_850_id}", + headers={"Authorization": f"Bearer {admin_access_token}"}, + json={"tags": ["winter_transition"]}, + ) + assert res.status_code in CREATE_STATUS_CODES, res.json() + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"name": "winter-transition-variant-850"}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map: t.Dict[str, t.Dict[str, t.Any]] = res.json() + assert len(study_map) == 1 + assert set(study_map.get(tagged_variant_850_id).get("tags")) == {"winter_transition"} + + # ========================== + # 2. Filtering testing + # ========================== + # the testing studies set all_studies = { raw_840_id, @@ -350,10 +446,12 @@ def test_study_listing( variant_860_id, archived_raw_840_id, archived_raw_850_id, - archived_variant_840_id, - archived_variant_850_id, folder1_study_id, to_be_deleted_id, + tagged_raw_840_id, + tagged_raw_850_id, + tagged_variant_840_id, + tagged_variant_850_id, } pm = operator.itemgetter("public_mode") @@ -392,15 +490,14 @@ def test_study_listing( [e for k, e in study_map.items() if k not in james_bond_studies], ) ) - # #TODO you need to update the permission for James Bond bot + # #TODO you need to update the permission for James Bond bot above # test 1.d for a user bot with access to select studies res = client.get( STUDIES_URL, headers={"Authorization": f"Bearer {james_bond_bot_token}"}, ) assert res.status_code == LIST_STATUS_CODE, res.json() - # #TODO add the correct test assertions # ] = res.json() # assert not set(james_bond_studies).difference(study_map) @@ -477,8 +574,8 @@ def test_study_listing( variant_840_id, variant_850_id, variant_860_id, - archived_variant_840_id, - archived_variant_850_id, + tagged_variant_840_id, + tagged_variant_850_id, } # test 5.a get variant studies res = client.get( @@ -507,7 +604,8 @@ def test_study_listing( non_managed_850_id, variant_850_id, archived_raw_850_id, - archived_variant_850_id, + tagged_variant_850_id, + tagged_raw_850_id, } studies_version_860 = { raw_860_id, @@ -579,10 +677,31 @@ def test_study_listing( assert not all_studies.difference(group_x_studies.union(group_y_studies)).intersection(study_map) assert not group_x_studies.union(group_y_studies).difference(study_map) - # TODO you need to add filtering through tags to the search engine # tests (9) for tags filtering # test 9.a filtering for one tag: decennial + decennial_tagged_studies = {tagged_raw_850_id, tagged_variant_840_id} + winter_transition_tagged_studies = {tagged_raw_840_id, tagged_variant_850_id} + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"tags": f"decennial"}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + assert not all_studies.difference(decennial_tagged_studies).intersection(study_map) + assert not decennial_tagged_studies.difference(study_map) # test 9.b filtering for two tags: decennial,winter_transition + res = client.get( + STUDIES_URL, + headers={"Authorization": f"Bearer {admin_access_token}"}, + params={"tags": f"decennial,winter_transition"}, + ) + assert res.status_code == LIST_STATUS_CODE, res.json() + study_map = res.json() + assert not all_studies.difference( + decennial_tagged_studies.union(winter_transition_tagged_studies) + ).intersection(study_map) + assert not decennial_tagged_studies.union(winter_transition_tagged_studies).difference(study_map) # tests (10) for studies uuids sequence filtering # test 10.a filter for one uuid