diff --git a/antarest/study/repository.py b/antarest/study/repository.py index ac7f730fca..deb92eb5e2 100644 --- a/antarest/study/repository.py +++ b/antarest/study/repository.py @@ -1,7 +1,8 @@ import datetime import logging import typing as t +from sqlalchemy import func from sqlalchemy import and_, or_ # type: ignore from sqlalchemy.orm import Session, joinedload, with_polymorphic # type: ignore @@ -135,13 +136,18 @@ def get_all_raw(self, show_missing: bool = True) -> t.List[RawStudy]: studies: t.List[RawStudy] = query.all() return studies - def delete(self, id: str) -> None: - logger.debug(f"Deleting study {id}") + def delete(self, ids: t.Union[str, t.List[str]]) -> None: + logger.debug(f"Deleting studies {ids}") + if isinstance(ids, str): + # if ids is a str, convert it to a list with one element + ids = [ids] session = self.session - u: Study = session.query(Study).get(id) - session.delete(u) + for study_id in ids: + study: Study = session.query(Study).get(study_id) + if study: + session.delete(study) + self._remove_study_from_cache_listing(study_id) session.commit() - self._remove_study_from_cache_listing(id) def _remove_study_from_cache_listing(self, study_id: str) -> None: try: @@ -172,3 +178,20 @@ def _update_study_from_cache_listing(self, study: Study) -> None: self.cache_service.invalidate(CacheConstants.STUDY_LISTING.value) except Exception as e: logger.error("Failed to invalidate listing cache", exc_info=e) + + def list_duplicates(self) -> t.List[t.Tuple[str, str]]: + """ + Get list of duplicates as tuples (id, path).
+ """ + session = self.session + subquery = ( + session.query(Study.path) + .group_by(Study.path) + .having(func.count()>1) + .subquery() + ) + query = ( + session.query(Study.id, Study.path) + .filter(Study.path.in_(subquery)) + ) + return query.all() diff --git a/antarest/study/service.py b/antarest/study/service.py index 569d41ad1d..ebcd04d9d4 100644 --- a/antarest/study/service.py +++ b/antarest/study/service.py @@ -1,4 +1,6 @@ import base64 +import collections +import typing as t import contextlib import io import json @@ -697,20 +699,19 @@ def get_input_matrix_startdate(self, study_id: str, path: Optional[str], params: return get_start_date(file_study, output_id, level) def remove_duplicates(self) -> None: - study_paths: Dict[str, List[str]] = {} - for study in self.repository.get_all(): - if isinstance(study, RawStudy) and not study.archived: - path = str(study.path) - if path not in study_paths: - study_paths[path] = [] - study_paths[path].append(study.id) - - for studies_with_same_path in study_paths.values(): - if len(studies_with_same_path) > 1: - logger.info(f"Found studies {studies_with_same_path} with same path, de duplicating") - for study_name in studies_with_same_path[1:]: - logger.info(f"Removing study {study_name}") - self.repository.delete(study_name) + duplicates: List[Tuple[str, str]] = self.repository.list_duplicates() + ids: List[str] = [] + # ids with same path + duplicates_by_path: t.Dict[str, t.List[str]] = collections.defaultdict(list) + for study_id, path in duplicates: + duplicates_by_path[path].append(study_id) + for path, study_ids in duplicates_by_path.items(): + ids.extend(study_ids[1:]) + # delete list ids + self.repository.delete(ids) + #db.session.query(RawStudy).filter(RawStudy.id.in_(ids)).delete(synchronize_session=False) + db.session.commit() + def sync_studies_on_disk(self, folders: List[StudyFolder], directory: Optional[Path] = None) -> None: """ diff --git a/tests/storage/test_service.py b/tests/storage/test_service.py 
index b509445052..04356dc5f7 100644 --- a/tests/storage/test_service.py +++ b/tests/storage/test_service.py @@ -334,19 +334,29 @@ def test_partial_sync_studies_from_disk() -> None: ) -@pytest.mark.unit_test -def test_remove_duplicate() -> None: - ma = RawStudy(id="a", path="a") - mb = RawStudy(id="b", path="a") +@with_db_context +def test_remove_duplicate(db_session: Session) -> None: + with db_session: + db_session.add(RawStudy(id="a", path="/path/to/a")) + db_session.add(RawStudy(id="b", path="/path/to/a")) + db_session.add(RawStudy(id="c", path="/path/to/c")) + db_session.commit() + study_count = db_session.query(RawStudy).filter(RawStudy.path == "/path/to/a").count() + assert study_count == 2 # there are 2 studies with same path before removing duplicates - repository = Mock() - repository.get_all.return_value = [ma, mb] + repository = StudyMetadataRepository(Mock(), db_session) config = Config(storage=StorageConfig(workspaces={DEFAULT_WORKSPACE_NAME: WorkspaceConfig()})) service = build_study_service(Mock(), repository, config) - service.remove_duplicates() - repository.delete.assert_called_once_with(mb.id) - + service.remove_duplicates() + # example with 1 duplicate with same path + with db_session: + study_count = db_session.query(RawStudy).filter(RawStudy.path == "/path/to/a").count() + assert study_count == 1 + # example with no duplicates with same path + with db_session: + study_count = db_session.query(RawStudy).filter(RawStudy.path == "/path/to/c").count() + assert study_count == 1 # noinspection PyArgumentList @pytest.mark.unit_test