diff --git a/alembic/versions/fd73601a9075_add_delete_cascade_studies.py b/alembic/versions/fd73601a9075_add_delete_cascade_studies.py new file mode 100644 index 0000000000..3f9f42c684 --- /dev/null +++ b/alembic/versions/fd73601a9075_add_delete_cascade_studies.py @@ -0,0 +1,86 @@ +""" +Add delete cascade constraint to study foreign keys + +Revision ID: fd73601a9075 +Revises: 3c70366b10ea +Create Date: 2024-02-12 17:27:37.314443 +""" +import sqlalchemy as sa # type: ignore +from alembic import op + +# revision identifiers, used by Alembic. +revision = "fd73601a9075" +down_revision = "dae93f1d9110" +branch_labels = None +depends_on = None + +# noinspection SpellCheckingInspection +RAWSTUDY_FK = "rawstudy_id_fkey" + +# noinspection SpellCheckingInspection +VARIANTSTUDY_FK = "variantstudy_id_fkey" + +# noinspection SpellCheckingInspection +STUDY_ADDITIONAL_DATA_FK = "study_additional_data_study_id_fkey" + + +def upgrade() -> None: + dialect_name: str = op.get_context().dialect.name + + # SQLite doesn't support dropping foreign keys, so we need to ignore it here + if dialect_name == "postgresql": + with op.batch_alter_table("rawstudy", schema=None) as batch_op: + batch_op.drop_constraint(RAWSTUDY_FK, type_="foreignkey") + batch_op.create_foreign_key(RAWSTUDY_FK, "study", ["id"], ["id"], ondelete="CASCADE") + + with op.batch_alter_table("study_additional_data", schema=None) as batch_op: + batch_op.drop_constraint(STUDY_ADDITIONAL_DATA_FK, type_="foreignkey") + batch_op.create_foreign_key(STUDY_ADDITIONAL_DATA_FK, "study", ["study_id"], ["id"], ondelete="CASCADE") + + with op.batch_alter_table("variantstudy", schema=None) as batch_op: + batch_op.drop_constraint(VARIANTSTUDY_FK, type_="foreignkey") + batch_op.create_foreign_key(VARIANTSTUDY_FK, "study", ["id"], ["id"], ondelete="CASCADE") + + with op.batch_alter_table("group_metadata", schema=None) as batch_op: + batch_op.alter_column("group_id", existing_type=sa.VARCHAR(length=36), nullable=False) + batch_op.alter_column("study_id", existing_type=sa.VARCHAR(length=36), nullable=False) + batch_op.create_index(batch_op.f("ix_group_metadata_group_id"), ["group_id"], unique=False) + batch_op.create_index(batch_op.f("ix_group_metadata_study_id"), ["study_id"], unique=False) + if dialect_name == "postgresql": + batch_op.drop_constraint("group_metadata_group_id_fkey", type_="foreignkey") + batch_op.drop_constraint("group_metadata_study_id_fkey", type_="foreignkey") + batch_op.create_foreign_key( + "group_metadata_group_id_fkey", "groups", ["group_id"], ["id"], ondelete="CASCADE" + ) + batch_op.create_foreign_key( + "group_metadata_study_id_fkey", "study", ["study_id"], ["id"], ondelete="CASCADE" + ) + + +def downgrade() -> None: + dialect_name: str = op.get_context().dialect.name + # SQLite doesn't support dropping foreign keys, so we need to ignore it here + if dialect_name == "postgresql": + with op.batch_alter_table("rawstudy", schema=None) as batch_op: + batch_op.drop_constraint(RAWSTUDY_FK, type_="foreignkey") + batch_op.create_foreign_key(RAWSTUDY_FK, "study", ["id"], ["id"]) + + with op.batch_alter_table("study_additional_data", schema=None) as batch_op: + batch_op.drop_constraint(STUDY_ADDITIONAL_DATA_FK, type_="foreignkey") + batch_op.create_foreign_key(STUDY_ADDITIONAL_DATA_FK, "study", ["study_id"], ["id"]) + + with op.batch_alter_table("variantstudy", schema=None) as batch_op: + batch_op.drop_constraint(VARIANTSTUDY_FK, type_="foreignkey") + batch_op.create_foreign_key(VARIANTSTUDY_FK, "study", ["id"], ["id"]) + + with op.batch_alter_table("group_metadata", schema=None) as batch_op: + # SQLite doesn't support dropping foreign keys, so we need to ignore it here + if dialect_name == "postgresql": + batch_op.drop_constraint("group_metadata_study_id_fkey", type_="foreignkey") + batch_op.drop_constraint("group_metadata_group_id_fkey", type_="foreignkey") + batch_op.create_foreign_key("group_metadata_study_id_fkey", "study", ["study_id"], ["id"]) + batch_op.create_foreign_key("group_metadata_group_id_fkey", "groups", ["group_id"], ["id"]) + batch_op.drop_index(batch_op.f("ix_group_metadata_study_id")) + batch_op.drop_index(batch_op.f("ix_group_metadata_group_id")) + batch_op.alter_column("study_id", existing_type=sa.VARCHAR(length=36), nullable=True) + batch_op.alter_column("group_id", existing_type=sa.VARCHAR(length=36), nullable=True) diff --git a/antarest/study/model.py b/antarest/study/model.py index fe10b4f211..5079198296 100644 --- a/antarest/study/model.py +++ b/antarest/study/model.py @@ -16,7 +16,6 @@ Integer, PrimaryKeyConstraint, String, - Table, ) from sqlalchemy.orm import relationship # type: ignore @@ -50,12 +49,31 @@ NEW_DEFAULT_STUDY_VERSION: str = "860" -groups_metadata = Table( - "group_metadata", - Base.metadata, - Column("group_id", String(36), ForeignKey("groups.id")), - Column("study_id", String(36), ForeignKey("study.id")), -) + +class StudyGroup(Base): # type:ignore + """ + A table to manage the many-to-many relationship between `Study` and `Group` + + Attributes: + study_id: The ID of the study associated with the group. + group_id: The IS of the group associated with the study. + """ + + __tablename__ = "group_metadata" + __table_args__ = (PrimaryKeyConstraint("study_id", "group_id"),) + + group_id: str = Column(String(36), ForeignKey("groups.id", ondelete="CASCADE"), index=True, nullable=False) + study_id: str = Column(String(36), ForeignKey("study.id", ondelete="CASCADE"), index=True, nullable=False) + + def __str__(self) -> str: # pragma: no cover + cls_name = self.__class__.__name__ + return f"[{cls_name}] study_id={self.study_id}, group={self.group_id}" + + def __repr__(self) -> str: # pragma: no cover + cls_name = self.__class__.__name__ + study_id = self.study_id + group_id = self.group_id + return f"{cls_name}({study_id=}, {group_id=})" class StudyTag(Base): # type:ignore @@ -63,8 +81,8 @@ class StudyTag(Base): # type:ignore A table to manage the many-to-many relationship between `Study` and `Tag` Attributes: - study_id (str): The ID of the study associated with the tag. - tag_label (str): The label of the tag associated with the study. + study_id: The ID of the study associated with the tag. + tag_label: The label of the tag associated with the study. """ __tablename__ = "study_tag" @@ -74,7 +92,8 @@ class StudyTag(Base): # type:ignore tag_label: str = Column(String(40), ForeignKey("tag.label", ondelete="CASCADE"), index=True, nullable=False) def __str__(self) -> str: # pragma: no cover - return f"[StudyTag] study_id={self.study_id}, tag={self.tag}" + cls_name = self.__class__.__name__ + return f"[{cls_name}] study_id={self.study_id}, tag={self.tag}" def __repr__(self) -> str: # pragma: no cover cls_name = self.__class__.__name__ @@ -90,8 +109,8 @@ class Tag(Base): # type:ignore This class is used to store tags associated with studies. Attributes: - label (str): The label of the tag. - color (str): The color code associated with the tag. + label: The label of the tag. + color: The color code associated with the tag. """ __tablename__ = "tag" @@ -130,7 +149,7 @@ class StudyAdditionalData(Base): # type:ignore study_id = Column( String(36), - ForeignKey("study.id"), + ForeignKey("study.id", ondelete="CASCADE"), primary_key=True, ) author = Column(String(255), default="Unknown") @@ -174,7 +193,7 @@ class Study(Base): # type: ignore tags: t.List[Tag] = relationship(Tag, secondary=StudyTag.__table__, back_populates="studies") owner = relationship(Identity, uselist=False) - groups = relationship(Group, secondary=lambda: groups_metadata, cascade="") + groups = relationship(Group, secondary=StudyGroup.__table__, cascade="") additional_data = relationship( StudyAdditionalData, uselist=False, @@ -230,7 +249,7 @@ class RawStudy(Study): id = Column( String(36), - ForeignKey("study.id"), + ForeignKey("study.id", ondelete="CASCADE"), primary_key=True, ) content_status = Column(Enum(StudyContentStatus)) diff --git a/antarest/study/repository.py b/antarest/study/repository.py index 3aa6e60681..9d6c9317fb 100644 --- a/antarest/study/repository.py +++ b/antarest/study/repository.py @@ -138,7 +138,7 @@ def save( def refresh(self, metadata: Study) -> None: self.session.refresh(metadata) - def get(self, id: str) -> t.Optional[Study]: + def get(self, study_id: str) -> t.Optional[Study]: """Get the study by ID or return `None` if not found in database.""" # todo: I think we should use a `entity = with_polymorphic(Study, "*")` # to make sure RawStudy and VariantStudy fields are also fetched. @@ -146,13 +146,11 @@ def get(self, id: str) -> t.Optional[Study]: # When we fetch a study, we also need to fetch the associated owner and groups # to check the permissions of the current user efficiently. study: Study = ( - # fmt: off self.session.query(Study) .options(joinedload(Study.owner)) .options(joinedload(Study.groups)) .options(joinedload(Study.tags)) - .get(id) - # fmt: on + .get(study_id) ) return study @@ -272,10 +270,10 @@ def get_all_raw(self, exists: t.Optional[bool] = None) -> t.Sequence[RawStudy]: studies: t.Sequence[RawStudy] = query.all() return studies - def delete(self, id: str) -> None: + def delete(self, id_: str, *ids: str) -> None: + ids = (id_,) + ids session = self.session - u: Study = session.query(Study).get(id) - session.delete(u) + session.query(Study).filter(Study.id.in_(ids)).delete(synchronize_session=False) session.commit() def update_tags(self, study: Study, new_tags: t.Sequence[str]) -> None: @@ -292,3 +290,12 @@ def update_tags(self, study: Study, new_tags: t.Sequence[str]) -> None: study.tags = [Tag(label=tag) for tag in new_labels] + existing_tags self.session.merge(study) self.session.commit() + + def list_duplicates(self) -> t.List[t.Tuple[str, str]]: + """ + Get list of duplicates as tuples (id, path). + """ + session = self.session + subquery = session.query(Study.path).group_by(Study.path).having(func.count() > 1).subquery() + query = session.query(Study.id, Study.path).filter(Study.path.in_(subquery)) + return t.cast(t.List[t.Tuple[str, str]], query.all()) diff --git a/antarest/study/service.py b/antarest/study/service.py index 910cf62027..9b22ae7638 100644 --- a/antarest/study/service.py +++ b/antarest/study/service.py @@ -1,4 +1,5 @@ import base64 +import collections import contextlib import io import json @@ -701,20 +702,16 @@ def get_input_matrix_startdate( return get_start_date(file_study, output_id, level) def remove_duplicates(self) -> None: - study_paths: t.Dict[str, t.List[str]] = {} - for study in self.repository.get_all(): - if isinstance(study, RawStudy) and not study.archived: - path = str(study.path) - if path not in study_paths: - study_paths[path] = [] - study_paths[path].append(study.id) - - for studies_with_same_path in study_paths.values(): - if len(studies_with_same_path) > 1: - logger.info(f"Found studies {studies_with_same_path} with same path, de duplicating") - for study_name in studies_with_same_path[1:]: - logger.info(f"Removing study {study_name}") - self.repository.delete(study_name) + duplicates = self.repository.list_duplicates() + ids: t.List[str] = [] + # ids with same path + duplicates_by_path = collections.defaultdict(list) + for study_id, path in duplicates: + duplicates_by_path[path].append(study_id) + for path, study_ids in duplicates_by_path.items(): + ids.extend(study_ids[1:]) + if ids: # Check if ids is not empty + self.repository.delete(*ids) def sync_studies_on_disk(self, folders: t.List[StudyFolder], directory: t.Optional[Path] = None) -> None: """ diff --git a/antarest/study/storage/variantstudy/model/dbmodel.py b/antarest/study/storage/variantstudy/model/dbmodel.py index bbe264f89f..9272eb797f 100644 --- a/antarest/study/storage/variantstudy/model/dbmodel.py +++ b/antarest/study/storage/variantstudy/model/dbmodel.py @@ -77,7 +77,7 @@ class VariantStudy(Study): id: str = Column( String(36), - ForeignKey("study.id"), + ForeignKey("study.id", ondelete="CASCADE"), primary_key=True, ) generation_task: t.Optional[str] = Column(String(), nullable=True) diff --git a/tests/integration/studies_blueprint/test_synthesis.py b/tests/integration/studies_blueprint/test_synthesis.py index 70f5f0c907..9afd66be9b 100644 --- a/tests/integration/studies_blueprint/test_synthesis.py +++ b/tests/integration/studies_blueprint/test_synthesis.py @@ -108,4 +108,4 @@ def test_variant_study( ) assert res.status_code == 200, res.json() duration = time.time() - start - assert 0 <= duration <= 0.1, f"Duration is {duration} seconds" + assert 0 <= duration <= 0.2, f"Duration is {duration} seconds" diff --git a/tests/storage/repository/test_study.py b/tests/storage/repository/test_study.py index f865ab613a..bb4ea795e3 100644 --- a/tests/storage/repository/test_study.py +++ b/tests/storage/repository/test_study.py @@ -1,22 +1,24 @@ from datetime import datetime +from sqlalchemy.orm import Session # type: ignore + from antarest.core.cache.business.local_chache import LocalCache from antarest.core.model import PublicMode from antarest.login.model import Group, User from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, Study, StudyContentStatus from antarest.study.repository import StudyMetadataRepository from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy -from tests.helpers import with_db_context -@with_db_context -def test_lifecycle() -> None: - user = User(id=0, name="admin") +def test_lifecycle(db_session: Session) -> None: + repo = StudyMetadataRepository(LocalCache(), session=db_session) + + user = User(id=1, name="admin") group = Group(id="my-group", name="group") - repo = StudyMetadataRepository(LocalCache()) + a = Study( name="a", - version="42", + version="820", author="John Smith", created_at=datetime.utcnow(), updated_at=datetime.utcnow(), @@ -26,7 +28,7 @@ def test_lifecycle() -> None: ) b = RawStudy( name="b", - version="43", + version="830", author="Morpheus", created_at=datetime.utcnow(), updated_at=datetime.utcnow(), @@ -36,7 +38,7 @@ def test_lifecycle() -> None: ) c = RawStudy( name="c", - version="43", + version="830", author="Trinity", created_at=datetime.utcnow(), updated_at=datetime.utcnow(), @@ -47,7 +49,7 @@ def test_lifecycle() -> None: ) d = VariantStudy( name="d", - version="43", + version="830", author="Mr. Anderson", created_at=datetime.utcnow(), updated_at=datetime.utcnow(), @@ -57,30 +59,32 @@ def test_lifecycle() -> None: ) a = repo.save(a) - b = repo.save(b) + a_id = a.id + + repo.save(b) repo.save(c) repo.save(d) - assert b.id - c = repo.one(a.id) - assert a == c + + c = repo.one(a_id) + assert a_id == c.id assert len(repo.get_all()) == 4 assert len(repo.get_all_raw(exists=True)) == 1 assert len(repo.get_all_raw(exists=False)) == 1 assert len(repo.get_all_raw()) == 2 - repo.delete(a.id) - assert repo.get(a.id) is None + repo.delete(a_id) + assert repo.get(a_id) is None + +def test_study_inheritance(db_session: Session) -> None: + repo = StudyMetadataRepository(LocalCache(), session=db_session) -@with_db_context -def test_study_inheritance() -> None: user = User(id=0, name="admin") group = Group(id="my-group", name="group") - repo = StudyMetadataRepository(LocalCache()) a = RawStudy( name="a", - version="42", + version="820", author="John Smith", created_at=datetime.utcnow(), updated_at=datetime.utcnow(), diff --git a/tests/storage/test_service.py b/tests/storage/test_service.py index 12f61e6489..fa7ed5c62d 100644 --- a/tests/storage/test_service.py +++ b/tests/storage/test_service.py @@ -350,18 +350,30 @@ def test_partial_sync_studies_from_disk() -> None: ) -@pytest.mark.unit_test -def test_remove_duplicate() -> None: - ma = RawStudy(id="a", path="a") - mb = RawStudy(id="b", path="a") +@with_db_context +def test_remove_duplicate(db_session: Session) -> None: + with db_session: + db_session.add(RawStudy(id="a", path="/path/to/a")) + db_session.add(RawStudy(id="b", path="/path/to/a")) + db_session.add(RawStudy(id="c", path="/path/to/c")) + db_session.commit() + study_count = db_session.query(RawStudy).filter(RawStudy.path == "/path/to/a").count() + assert study_count == 2 # there are 2 studies with same path before removing duplicates - repository = Mock() - repository.get_all.return_value = [ma, mb] - config = Config(storage=StorageConfig(workspaces={DEFAULT_WORKSPACE_NAME: WorkspaceConfig()})) - service = build_study_service(Mock(), repository, config) + with db_session: + repository = StudyMetadataRepository(Mock(), db_session) + config = Config(storage=StorageConfig(workspaces={DEFAULT_WORKSPACE_NAME: WorkspaceConfig()})) + service = build_study_service(Mock(), repository, config) + service.remove_duplicates() - service.remove_duplicates() - repository.delete.assert_called_once_with(mb.id) + # example with 1 duplicate with same path + with db_session: + study_count = db_session.query(RawStudy).filter(RawStudy.path == "/path/to/a").count() + assert study_count == 1 + # example with no duplicates with same path + with db_session: + study_count = db_session.query(RawStudy).filter(RawStudy.path == "/path/to/c").count() + assert study_count == 1 # noinspection PyArgumentList