Skip to content

Commit

Permalink
PR corrections done with DB migration
Browse files Browse the repository at this point in the history
  • Loading branch information
olfamizen committed Feb 12, 2024
1 parent 188099c commit a0ec27f
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 32 deletions.
50 changes: 50 additions & 0 deletions alembic/versions/fd73601a9075_add_delete_cascade_studies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""add-delete-cascade-studies
Revision ID: fd73601a9075
Revises: 3c70366b10ea
Create Date: 2024-02-12 17:27:37.314443
"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = 'fd73601a9075'
down_revision = '3c70366b10ea'
branch_labels = None
depends_on = None


def upgrade():
    """Recreate the foreign keys pointing at ``study.id`` with ON DELETE CASCADE.

    The three child tables (``rawstudy``, ``study_additional_data``,
    ``variantstudy``) each hold a FK to ``study.id``; after this migration,
    deleting a ``study`` row also deletes its dependent rows at the DB level.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # NOTE(review): drop_constraint(None, ...) asks Alembic to drop an unnamed
    # constraint; whether this works depends on the backend and on the original
    # constraints having been reflected — on SQLite batch mode with unnamed FKs
    # this can raise. Confirm against the target database(s).
    with op.batch_alter_table('rawstudy', schema=None) as batch_op:
        batch_op.drop_constraint(None, type_='foreignkey')
        batch_op.create_foreign_key(None, 'study', ['id'], ['id'], ondelete='CASCADE')

    with op.batch_alter_table('study_additional_data', schema=None) as batch_op:
        batch_op.drop_constraint(None, type_='foreignkey')
        batch_op.create_foreign_key(None, 'study', ['study_id'], ['id'], ondelete='CASCADE')

    with op.batch_alter_table('variantstudy', schema=None) as batch_op:
        batch_op.drop_constraint(None, type_='foreignkey')
        batch_op.create_foreign_key(None, 'study', ['id'], ['id'], ondelete='CASCADE')

    # ### end Alembic commands ###


def downgrade():
    """Revert :func:`upgrade`: recreate the same foreign keys without CASCADE.

    Tables are processed in the reverse order of ``upgrade`` (``variantstudy``,
    ``study_additional_data``, ``rawstudy``).
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # NOTE(review): as in upgrade(), drop_constraint(None, ...) relies on the
    # constraint being resolvable without a name — verify on the target DB.
    with op.batch_alter_table('variantstudy', schema=None) as batch_op:
        batch_op.drop_constraint(None, type_='foreignkey')
        batch_op.create_foreign_key(None, 'study', ['id'], ['id'])

    with op.batch_alter_table('study_additional_data', schema=None) as batch_op:
        batch_op.drop_constraint(None, type_='foreignkey')
        batch_op.create_foreign_key(None, 'study', ['study_id'], ['id'])

    with op.batch_alter_table('rawstudy', schema=None) as batch_op:
        batch_op.drop_constraint(None, type_='foreignkey')
        batch_op.create_foreign_key(None, 'study', ['id'], ['id'])

    # ### end Alembic commands ###
4 changes: 2 additions & 2 deletions antarest/study/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class StudyAdditionalData(Base): # type:ignore

study_id = Column(
String(36),
ForeignKey("study.id"),
ForeignKey("study.id", ondelete="CASCADE"),
primary_key=True,
)
author = Column(String(255), default="Unknown")
Expand Down Expand Up @@ -230,7 +230,7 @@ class RawStudy(Study):

id = Column(
String(36),
ForeignKey("study.id"),
ForeignKey("study.id", ondelete="CASCADE"),
primary_key=True,
)
content_status = Column(Enum(StudyContentStatus))
Expand Down
20 changes: 16 additions & 4 deletions antarest/study/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import typing as t

from pydantic import BaseModel, NonNegativeInt
from sqlalchemy import func, not_, or_ # type: ignore
from sqlalchemy import func, not_, or_, select # type: ignore
from sqlalchemy.orm import Session, joinedload, with_polymorphic # type: ignore

from antarest.core.interfaces.cache import ICache
Expand Down Expand Up @@ -272,10 +272,13 @@ def get_all_raw(self, exists: t.Optional[bool] = None) -> t.Sequence[RawStudy]:
studies: t.Sequence[RawStudy] = query.all()
return studies

def delete(self, id_: str, *ids: str) -> None:
    """Delete one or more studies by primary key.

    Args:
        id_: identifier of the first study to delete (at least one is required).
        *ids: identifiers of any additional studies to delete.
    """
    ids = (id_,) + ids
    session = self.session
    # Bulk delete in a single statement. Dependent rows in the child tables
    # are removed by the DB thanks to the ON DELETE CASCADE foreign keys
    # added by migration fd73601a9075.
    session.query(Study).filter(Study.id.in_(ids)).delete(synchronize_session=False)
    session.commit()

def update_tags(self, study: Study, new_tags: t.Sequence[str]) -> None:
Expand All @@ -292,3 +295,12 @@ def update_tags(self, study: Study, new_tags: t.Sequence[str]) -> None:
study.tags = [Tag(label=tag) for tag in new_labels] + existing_tags
self.session.merge(study)
self.session.commit()

def list_duplicates(self) -> t.List[t.Tuple[str, str]]:
    """
    Return ``(id, path)`` tuples for every study whose path is shared
    by more than one study.
    """
    # Subquery selecting the paths that occur more than once.
    duplicated_paths = (
        self.session.query(Study.path)
        .group_by(Study.path)
        .having(func.count() > 1)
        .subquery()
    )
    rows = (
        self.session.query(Study.id, Study.path)
        .filter(Study.path.in_(duplicated_paths))
        .all()
    )
    return t.cast(t.List[t.Tuple[str, str]], rows)
25 changes: 11 additions & 14 deletions antarest/study/service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import base64
import collections
import contextlib
import io
import json
Expand Down Expand Up @@ -696,20 +697,16 @@ def get_input_matrix_startdate(
return get_start_date(file_study, output_id, level)

def remove_duplicates(self) -> None:
    """Remove duplicated studies: for each path shared by several studies,
    keep the first study and delete the others.
    """
    duplicates = self.repository.list_duplicates()
    # Group study ids by path so we can keep exactly one study per path.
    duplicates_by_path: t.Dict[str, t.List[str]] = collections.defaultdict(list)
    for study_id, path in duplicates:
        duplicates_by_path[path].append(study_id)
    ids: t.List[str] = []
    for study_ids in duplicates_by_path.values():
        # Keep the first id; every other study with the same path is removed.
        ids.extend(study_ids[1:])
    # `delete` requires at least one positional id: calling it with an empty
    # list would raise a TypeError, so skip the call when nothing matched.
    if ids:
        self.repository.delete(*ids)

def sync_studies_on_disk(self, folders: t.List[StudyFolder], directory: t.Optional[Path] = None) -> None:
"""
Expand Down
2 changes: 1 addition & 1 deletion antarest/study/storage/variantstudy/model/dbmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class VariantStudy(Study):

id: str = Column(
String(36),
ForeignKey("study.id"),
ForeignKey("study.id", ondelete="CASCADE"),
primary_key=True,
)
generation_task: t.Optional[str] = Column(String(), nullable=True)
Expand Down
2 changes: 1 addition & 1 deletion scripts/rollback.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ CUR_DIR=$(cd "$(dirname "$0")" && pwd)
BASE_DIR=$(dirname "$CUR_DIR")

cd "$BASE_DIR"
alembic downgrade 1f5db5dfad80
alembic downgrade 3c70366b10ea
cd -
32 changes: 22 additions & 10 deletions tests/storage/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,18 +350,30 @@ def test_partial_sync_studies_from_disk() -> None:
)


@with_db_context
def test_remove_duplicate(db_session: Session) -> None:
    """End-to-end check of duplicate removal against a real DB session."""
    # Setup: two studies share the same path, a third one is unique.
    with db_session:
        db_session.add(RawStudy(id="a", path="/path/to/a"))
        db_session.add(RawStudy(id="b", path="/path/to/a"))
        db_session.add(RawStudy(id="c", path="/path/to/c"))
        db_session.commit()
        study_count = db_session.query(RawStudy).filter(RawStudy.path == "/path/to/a").count()
        assert study_count == 2  # there are 2 studies with same path before removing duplicates

    with db_session:
        repository = StudyMetadataRepository(Mock(), db_session)
        config = Config(storage=StorageConfig(workspaces={DEFAULT_WORKSPACE_NAME: WorkspaceConfig()}))
        service = build_study_service(Mock(), repository, config)
        service.remove_duplicates()

    # Exactly one study per path must remain: the duplicate of "/path/to/a"
    # was deleted, the unique study at "/path/to/c" was kept untouched.
    with db_session:
        study_count = db_session.query(RawStudy).filter(RawStudy.path == "/path/to/a").count()
        assert study_count == 1
    with db_session:
        study_count = db_session.query(RawStudy).filter(RawStudy.path == "/path/to/c").count()
        assert study_count == 1


# noinspection PyArgumentList
Expand Down

0 comments on commit a0ec27f

Please sign in to comment.