Skip to content

Commit

Permalink
PR corrections done with DB migration
Browse files Browse the repository at this point in the history
  • Loading branch information
olfamizen committed Feb 12, 2024
1 parent 188099c commit 4fe8ae1
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 25 deletions.
31 changes: 29 additions & 2 deletions antarest/study/repository.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import datetime
import enum
import typing as t
from typing import List, Tuple, Union

from pydantic import BaseModel, NonNegativeInt
from sqlalchemy import func, not_, or_ # type: ignore
from sqlalchemy import select, func
from sqlalchemy.orm import Session, joinedload, with_polymorphic # type: ignore

from antarest.core.interfaces.cache import ICache
Expand Down Expand Up @@ -273,9 +275,17 @@ def get_all_raw(self, exists: t.Optional[bool] = None) -> t.Sequence[RawStudy]:
return studies

def delete(self, id: str) -> None:
def delete(self, ids: Union[str, List[str]]) -> None:
logger.debug(f"Deleting study {id}")
if isinstance(ids, str):
# if id is str, convert it to list with one element
ids = [ids]
session = self.session
u: Study = session.query(Study).get(id)
session.delete(u)
for study_id in ids:
study: Study = session.query(Study).get(study_id)
if study:
session.delete(study)
self._remove_study_from_cache_listing(study_id)
session.commit()

def update_tags(self, study: Study, new_tags: t.Sequence[str]) -> None:
Expand All @@ -292,3 +302,20 @@ def update_tags(self, study: Study, new_tags: t.Sequence[str]) -> None:
study.tags = [Tag(label=tag) for tag in new_labels] + existing_tags
self.session.merge(study)
self.session.commit()

def list_duplicates(self) -> List[Tuple[str, str]]:
"""
Get list of duplicates as tuples (id, path).
"""
session = self.session
subquery = (
session.query(Study.path)
.group_by(Study.path)
.having(func.count()>1)
.subquery()
)
query = (
session.query(Study.id, Study.path)
.filter(Study.path.in_(subquery))
)
return query.all()
29 changes: 15 additions & 14 deletions antarest/study/service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import base64
import collections
import typing as t
import contextlib
import io
import json
Expand Down Expand Up @@ -696,20 +698,19 @@ def get_input_matrix_startdate(
return get_start_date(file_study, output_id, level)

def remove_duplicates(self) -> None:
study_paths: t.Dict[str, t.List[str]] = {}
for study in self.repository.get_all():
if isinstance(study, RawStudy) and not study.archived:
path = str(study.path)
if path not in study_paths:
study_paths[path] = []
study_paths[path].append(study.id)

for studies_with_same_path in study_paths.values():
if len(studies_with_same_path) > 1:
logger.info(f"Found studies {studies_with_same_path} with same path, de duplicating")
for study_name in studies_with_same_path[1:]:
logger.info(f"Removing study {study_name}")
self.repository.delete(study_name)
duplicates: List[Tuple[str, str]] = self.repository.list_duplicates()
ids: List[str] = []
# ids with same path
duplicates_by_path: t.Dict[str, t.List[str]] = collections.defaultdict(list)
for study_id, path in duplicates:
duplicates_by_path[path].append(study_id)
for path, study_ids in duplicates_by_path.items():
ids.extend(study_ids[1:])
# delete list ids
self.repository.delete(ids)
#db.session.query(RawStudy).filter(RawStudy.id.in_(ids)).delete(synchronize_session=False)
db.session.commit()


def sync_studies_on_disk(self, folders: t.List[StudyFolder], directory: t.Optional[Path] = None) -> None:
"""
Expand Down
27 changes: 18 additions & 9 deletions tests/storage/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,19 +350,28 @@ def test_partial_sync_studies_from_disk() -> None:
)


@pytest.mark.unit_test
def test_remove_duplicate() -> None:
ma = RawStudy(id="a", path="a")
mb = RawStudy(id="b", path="a")
@with_db_context
def test_remove_duplicate(db_session: Session) -> None:
with db_session:
db_session.add(RawStudy(id="a", path="/path/to/a"))
db_session.add(RawStudy(id="b", path="/path/to/a"))
db_session.add(RawStudy(id="c", path="/path/to/c"))
db_session.commit()
study_count = db_session.query(RawStudy).filter(RawStudy.path == "/path/to/a").count()
assert study_count == 2 # there are 2 studies with same path before removing duplicates

repository = Mock()
repository.get_all.return_value = [ma, mb]
repository = StudyMetadataRepository(Mock(), db_session)
config = Config(storage=StorageConfig(workspaces={DEFAULT_WORKSPACE_NAME: WorkspaceConfig()}))
service = build_study_service(Mock(), repository, config)

service.remove_duplicates()
repository.delete.assert_called_once_with(mb.id)

# example with 1 duplicate with same path
with db_session:
study_count = db_session.query(RawStudy).filter(RawStudy.path == "/path/to/a").count()
assert study_count == 1
# example with no duplicates with same path
with db_session:
study_count = db_session.query(RawStudy).filter(RawStudy.path == "/path/to/c").count()
assert study_count == 1

# noinspection PyArgumentList
@pytest.mark.unit_test
Expand Down

0 comments on commit 4fe8ae1

Please sign in to comment.