diff --git a/alembic/versions/dae93f1d9110_populate_tag_and_study_tag_tables_with_.py b/alembic/versions/dae93f1d9110_populate_tag_and_study_tag_tables_with_.py index c6c29d3716..6fbb060115 100644 --- a/alembic/versions/dae93f1d9110_populate_tag_and_study_tag_tables_with_.py +++ b/alembic/versions/dae93f1d9110_populate_tag_and_study_tag_tables_with_.py @@ -9,6 +9,7 @@ import itertools import json import secrets +import typing as t import sqlalchemy as sa # type: ignore from alembic import op @@ -23,6 +24,22 @@ depends_on = None +def _avoid_duplicates(tags: t.Iterable[str]) -> t.Sequence[str]: + """Avoid duplicate tags (case insensitive)""" + + upper_tags = {tag.upper(): tag for tag in tags} + return list(upper_tags.values()) + + +def _load_patch_obj(patch: t.Optional[str]) -> t.MutableMapping[str, t.Any]: + """Load the patch object from the `patch` field in the `study_additional_data` table.""" + + obj: t.MutableMapping[str, t.Any] = json.loads(patch or "{}") + obj["study"] = obj.get("study") or {} + obj["study"]["tags"] = _avoid_duplicates(obj["study"].get("tags") or []) + return obj + + def upgrade() -> None: """ Populate `tag` and `study_tag` tables from `patch` field in `study_additional_data` table @@ -39,27 +56,31 @@ def upgrade() -> None: connexion: Connection = op.get_bind() # retrieve the tags and the study-tag pairs from the db - study_tags = connexion.execute("SELECT study_id,patch FROM study_additional_data") - tags_by_ids = {} + study_tags = connexion.execute("SELECT study_id, patch FROM study_additional_data") + tags_by_ids: t.MutableMapping[str, t.Set[str]] = {} for study_id, patch in study_tags: - obj = json.loads(patch or "{}") - study = obj.get("study") or {} - tags = frozenset(study.get("tags") or ()) - tags_by_ids[study_id] = tags + obj = _load_patch_obj(patch) + tags_by_ids[study_id] = obj["study"]["tags"] # delete rows in tables `tag` and `study_tag` connexion.execute("DELETE FROM study_tag") connexion.execute("DELETE FROM tag") # insert the tags in the `tag` table - labels = set(itertools.chain.from_iterable(tags_by_ids.values())) - bulk_tags = [{"label": label, "color": secrets.choice(COLOR_NAMES)} for label in labels] + all_labels = {lbl.upper(): lbl for lbl in itertools.chain.from_iterable(tags_by_ids.values())} + bulk_tags = [{"label": label, "color": secrets.choice(COLOR_NAMES)} for label in all_labels.values()] if bulk_tags: sql = sa.text("INSERT INTO tag (label, color) VALUES (:label, :color)") connexion.execute(sql, *bulk_tags) # Create relationships between studies and tags in the `study_tag` table - bulk_study_tags = [{"study_id": id_, "tag_label": lbl} for id_, tags in tags_by_ids.items() for lbl in tags] + bulk_study_tags = [ + # fmt: off + {"study_id": id_, "tag_label": all_labels[lbl.upper()]} + for id_, tags in tags_by_ids.items() + for lbl in tags + # fmt: on + ] if bulk_study_tags: sql = sa.text("INSERT INTO study_tag (study_id, tag_label) VALUES (:study_id, :tag_label)") connexion.execute(sql, *bulk_study_tags) @@ -78,7 +99,7 @@ def downgrade() -> None: connexion: Connection = op.get_bind() # Creating the `tags_by_ids` mapping from data in the `study_tags` table - tags_by_ids = collections.defaultdict(set) + tags_by_ids: t.MutableMapping[str, t.Set[str]] = collections.defaultdict(set) study_tags = connexion.execute("SELECT study_id, tag_label FROM study_tag") for study_id, tag_label in study_tags: tags_by_ids[study_id].add(tag_label) @@ -87,10 +108,8 @@ def downgrade() -> None: objects_by_ids = {} study_tags = connexion.execute("SELECT study_id, patch FROM study_additional_data") for study_id, patch in study_tags: - obj = json.loads(patch or "{}") - obj["study"] = obj.get("study") or {} - obj["study"]["tags"] = obj["study"].get("tags") or [] - obj["study"]["tags"] = sorted(tags_by_ids[study_id] | set(obj["study"]["tags"])) + obj = _load_patch_obj(patch) + obj["study"]["tags"] = _avoid_duplicates(tags_by_ids[study_id] | set(obj["study"]["tags"])) objects_by_ids[study_id] = obj # Updating objects in the `study_additional_data` table diff --git a/antarest/study/model.py b/antarest/study/model.py index 5079198296..a102f327ac 100644 --- a/antarest/study/model.py +++ b/antarest/study/model.py @@ -6,7 +6,7 @@ from datetime import datetime, timedelta from pathlib import Path -from pydantic import BaseModel +from pydantic import BaseModel, validator from sqlalchemy import ( # type: ignore Boolean, Column, @@ -351,7 +351,18 @@ class StudyMetadataPatchDTO(BaseModel): scenario: t.Optional[str] = None status: t.Optional[str] = None doc: t.Optional[str] = None - tags: t.List[str] = [] + tags: t.Sequence[str] = () + + @validator("tags", each_item=True) + def _normalize_tags(cls, v: str) -> str: + """Remove leading and trailing whitespaces, and replace consecutive whitespaces by a single one.""" + tag = " ".join(v.split()) + if not tag: + raise ValueError("Tag cannot be empty") + elif len(tag) > 40: + raise ValueError(f"Tag is too long: {tag!r}") + else: + return tag class StudySimSettingsDTO(BaseModel): diff --git a/antarest/study/repository.py b/antarest/study/repository.py index 9d6c9317fb..25cdd77dd7 100644 --- a/antarest/study/repository.py +++ b/antarest/study/repository.py @@ -222,7 +222,8 @@ def get_all( if study_filter.groups: q = q.join(entity.groups).filter(Group.id.in_(study_filter.groups)) if study_filter.tags: - q = q.join(entity.tags).filter(Tag.label.in_(study_filter.tags)) + upper_tags = [tag.upper() for tag in study_filter.tags] + q = q.join(entity.tags).filter(func.upper(Tag.label).in_(upper_tags)) if study_filter.archived is not None: q = q.filter(entity.archived == study_filter.archived) if study_filter.name: @@ -279,17 +280,25 @@ def delete(self, id_: str, *ids: str) -> None: def update_tags(self, study: Study, new_tags: t.Sequence[str]) -> None: """ Updates the tags associated with a given study in the database, - replacing existing tags with new ones. + replacing existing tags with new ones (case-insensitive). Args: study: The pre-existing study to be updated with the new tags. new_tags: The new tags to be associated with the input study in the database. """ - existing_tags = self.session.query(Tag).filter(Tag.label.in_(new_tags)).all() - new_labels = set(new_tags) - set([tag.label for tag in existing_tags]) - study.tags = [Tag(label=tag) for tag in new_labels] + existing_tags - self.session.merge(study) - self.session.commit() + new_upper_tags = {tag.upper(): tag for tag in new_tags} + session = self.session + existing_tags = session.query(Tag).filter(func.upper(Tag.label).in_(new_upper_tags)).all() + for tag in existing_tags: + if tag.label.upper() in new_upper_tags: + new_upper_tags.pop(tag.label.upper()) + study.tags = [Tag(label=tag) for tag in new_upper_tags.values()] + existing_tags + session.merge(study) + session.commit() + # Delete any tag that is not associated with any study. + # Note: If tags are to be associated with objects other than Study, this code must be updated. + session.query(Tag).filter(~Tag.studies.any()).delete(synchronize_session=False) # type: ignore + session.commit() def list_duplicates(self) -> t.List[t.Tuple[str, str]]: """ diff --git a/tests/integration/studies_blueprint/test_get_studies.py b/tests/integration/studies_blueprint/test_get_studies.py index 54292345f5..579ad2dfe7 100644 --- a/tests/integration/studies_blueprint/test_get_studies.py +++ b/tests/integration/studies_blueprint/test_get_studies.py @@ -319,7 +319,7 @@ def test_study_listing( task = wait_task_completion(client, admin_access_token, archiving_study_task_id) assert task.status == TaskStatus.COMPLETED, task - # create a raw study version 840 to be tagged with `winter_transition` + # create a raw study version 840 to be tagged with `Winter_Transition` res = client.post( STUDIES_URL, headers={"Authorization": f"Bearer {admin_access_token}"}, @@ -330,7 +330,7 @@ def test_study_listing( res = client.put( f"{STUDIES_URL}/{tagged_raw_840_id}", headers={"Authorization": f"Bearer {admin_access_token}"}, - json={"tags": ["winter_transition"]}, + json={"tags": ["Winter_Transition"]}, ) assert res.status_code in CREATE_STATUS_CODES, res.json() res = client.get( @@ -341,7 +341,7 @@ def test_study_listing( assert res.status_code == LIST_STATUS_CODE, res.json() study_map: t.Dict[str, t.Dict[str, t.Any]] = res.json() assert len(study_map) == 1 - assert set(study_map[tagged_raw_840_id]["tags"]) == {"winter_transition"} + assert set(study_map[tagged_raw_840_id]["tags"]) == {"Winter_Transition"} # create a raw study version 850 to be tagged with `decennial` res = client.post( @@ -391,7 +391,8 @@ def test_study_listing( assert len(study_map) == 1 assert set(study_map[tagged_variant_840_id]["tags"]) == {"decennial"} - # create a variant study version 850 to be tagged with `winter_transition` + # create a variant study version 850 to be tagged with `winter_transition`. + # also test that the tag label is case-insensitive. res = client.post( f"{STUDIES_URL}/{tagged_raw_850_id}/variants", headers={"Authorization": f"Bearer {admin_access_token}"}, @@ -402,7 +403,7 @@ def test_study_listing( res = client.put( f"{STUDIES_URL}/{tagged_variant_850_id}", headers={"Authorization": f"Bearer {admin_access_token}"}, - json={"tags": ["winter_transition"]}, + json={"tags": ["winter_transition"]}, # note the tag label is in lower case ) assert res.status_code in CREATE_STATUS_CODES, res.json() res = client.get( @@ -413,7 +414,7 @@ def test_study_listing( assert res.status_code == LIST_STATUS_CODE, res.json() study_map = res.json() assert len(study_map) == 1 - assert set(study_map[tagged_variant_850_id]["tags"]) == {"winter_transition"} + assert set(study_map[tagged_variant_850_id]["tags"]) == {"Winter_Transition"} # ========================== # 2. Filtering testing @@ -670,7 +671,7 @@ def test_study_listing( res = client.get( STUDIES_URL, headers={"Authorization": f"Bearer {admin_access_token}"}, - params={"tags": "decennial"}, + params={"tags": "DECENNIAL"}, ) assert res.status_code == LIST_STATUS_CODE, res.json() study_map = res.json() diff --git a/tests/integration/studies_blueprint/test_update_tags.py b/tests/integration/studies_blueprint/test_update_tags.py new file mode 100644 index 0000000000..a65ece2f11 --- /dev/null +++ b/tests/integration/studies_blueprint/test_update_tags.py @@ -0,0 +1,95 @@ +from starlette.testclient import TestClient + + +class TestupdateStudyMetadata: + """ + Test the study tags update through the `update_study_metadata` API endpoint. + """ + + def test_update_tags( + self, + client: TestClient, + user_access_token: str, + study_id: str, + ) -> None: + """ + This test verifies that we can update the tags of a study. + It also tests the tags normalization. + """ + + # Classic usage: set some tags to a study + study_tags = ["Tag1", "Tag2"] + res = client.put( + f"/v1/studies/{study_id}", + headers={"Authorization": f"Bearer {user_access_token}"}, + json={"tags": study_tags}, + ) + assert res.status_code == 200, res.json() + actual = res.json() + assert set(actual["tags"]) == set(study_tags) + + # Update the tags with already existing tags (case-insensitive): + # - "Tag1" is preserved, but with the same case as the existing one. + # - "Tag2" is replaced by "Tag3". + study_tags = ["tag1", "Tag3"] + res = client.put( + f"/v1/studies/{study_id}", + headers={"Authorization": f"Bearer {user_access_token}"}, + json={"tags": study_tags}, + ) + assert res.status_code == 200, res.json() + actual = res.json() + assert set(actual["tags"]) != set(study_tags) # not the same case + assert set(tag.upper() for tag in actual["tags"]) == {"TAG1", "TAG3"} + + # String normalization: whitespaces are stripped and + # consecutive whitespaces are replaced by a single one. + study_tags = [" \xa0Foo \t Bar \n ", " \t Baz\xa0\xa0"] + res = client.put( + f"/v1/studies/{study_id}", + headers={"Authorization": f"Bearer {user_access_token}"}, + json={"tags": study_tags}, + ) + assert res.status_code == 200, res.json() + actual = res.json() + assert set(actual["tags"]) == {"Foo Bar", "Baz"} + + # We can have symbols in the tags + study_tags = ["Foo-Bar", ":Baz%"] + res = client.put( + f"/v1/studies/{study_id}", + headers={"Authorization": f"Bearer {user_access_token}"}, + json={"tags": study_tags}, + ) + assert res.status_code == 200, res.json() + actual = res.json() + assert set(actual["tags"]) == {"Foo-Bar", ":Baz%"} + + def test_update_tags__invalid_tags( + self, + client: TestClient, + user_access_token: str, + study_id: str, + ) -> None: + # We cannot have empty tags + study_tags = [""] + res = client.put( + f"/v1/studies/{study_id}", + headers={"Authorization": f"Bearer {user_access_token}"}, + json={"tags": study_tags}, + ) + assert res.status_code == 422, res.json() + description = res.json()["description"] + assert "Tag cannot be empty" in description + + # We cannot have tags longer than 40 characters + study_tags = ["very long tags, very long tags, very long tags"] + assert len(study_tags[0]) > 40 + res = client.put( + f"/v1/studies/{study_id}", + headers={"Authorization": f"Bearer {user_access_token}"}, + json={"tags": study_tags}, + ) + assert res.status_code == 422, res.json() + description = res.json()["description"] + assert "Tag is too long" in description diff --git a/tests/study/test_repository.py b/tests/study/test_repository.py index d30c051a6a..0a6063fac5 100644 --- a/tests/study/test_repository.py +++ b/tests/study/test_repository.py @@ -631,7 +631,7 @@ def test_repository_get_all__study_tags_filter( test_tag_1 = Tag(label="hidden-tag") test_tag_2 = Tag(label="decennial") - test_tag_3 = Tag(label="winter_transition") + test_tag_3 = Tag(label="Winter_Transition") # note the different case study_1 = VariantStudy(id=1, tags=[test_tag_1]) study_2 = VariantStudy(id=2, tags=[test_tag_2]) @@ -655,7 +655,41 @@ def test_repository_get_all__study_tags_filter( _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] _ = [s.tags for s in all_studies] + assert len(db_recorder.sql_statements) == 1, str(db_recorder) if expected_ids is not None: assert {s.id for s in all_studies} == expected_ids + + +def test_update_tags( + db_session: Session, +) -> None: + icache: Mock = Mock(spec=ICache) + repository = StudyMetadataRepository(cache_service=icache, session=db_session) + + study_id = 1 + study = RawStudy(id=study_id, tags=[]) + db_session.add(study) + db_session.commit() + + # use the db recorder to check that: + # 1- finding existing tags requires 1 query + # 2- updating the study tags requires 4 queries (2 selects, 2 inserts) + # 3- deleting orphan tags requires 1 query + with DBStatementRecorder(db_session.bind) as db_recorder: + repository.update_tags(study, ["Tag1", "Tag2"]) + assert len(db_recorder.sql_statements) == 6, str(db_recorder) + + # Check that when we change the tags to ["TAG1", "Tag3"], + # "Tag1" is preserved, "Tag2" is deleted and "Tag3" is created + # 1- finding existing tags requires 1 query + # 2- updating the study tags requires 4 queries (2 selects, 2 inserts, 1 delete) + # 3- deleting orphan tags requires 1 query + with DBStatementRecorder(db_session.bind) as db_recorder: + repository.update_tags(study, ["TAG1", "Tag3"]) + assert len(db_recorder.sql_statements) == 7, str(db_recorder) + + # Check that only "Tag1" and "Tag3" are present in the database + tags = db_session.query(Tag).all() + assert {tag.label for tag in tags} == {"Tag1", "Tag3"} diff --git a/webapp/src/utils/studiesUtils.ts b/webapp/src/utils/studiesUtils.ts index 1fa14e7698..4072e09c51 100644 --- a/webapp/src/utils/studiesUtils.ts +++ b/webapp/src/utils/studiesUtils.ts @@ -64,7 +64,9 @@ const tagsPredicate = R.curry( if (!study.tags || study.tags.length === 0) { return false; } - return R.intersection(study.tags, tags).length > 0; + const upperCaseTags = tags.map((tag) => tag.toUpperCase()); + const upperCaseStudyTags = study.tags.map((tag) => tag.toUpperCase()); + return R.intersection(upperCaseStudyTags, upperCaseTags).length > 0; }, );