Skip to content

Commit

Permalink
feature(study-search): add tag and study_tag tables to the db (mi…
Browse files Browse the repository at this point in the history
…gration), update tags related services and endpoints
  • Loading branch information
mabw-rte committed Jan 31, 2024
1 parent dc817cf commit 42a0575
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 3 deletions.
56 changes: 56 additions & 0 deletions alembic/versions/6a6634ed2c2f_add_tag_and_study_tag_tables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Add tag and study_tag tables
Revision ID: 6a6634ed2c2f
Revises: 1f5db5dfad80
Create Date: 2024-01-30 18:07:11.116834
"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '6a6634ed2c2f'
down_revision = '1f5db5dfad80'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('tag',
sa.Column('label', sa.String(), nullable=False),
sa.Column('color', sa.String(length=7), nullable=True),
sa.PrimaryKeyConstraint('label')
)
with op.batch_alter_table('tag', schema=None) as batch_op:
batch_op.create_index(batch_op.f('ix_tag_color'), ['color'], unique=False)
batch_op.create_index(batch_op.f('ix_tag_label'), ['label'], unique=False)

op.create_table('study_tag',
sa.Column('study_id', sa.String(length=36), nullable=False),
sa.Column('tag', sa.String(), nullable=False),
sa.ForeignKeyConstraint(['study_id'], ['study.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['tag'], ['tag.label'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('study_id', 'tag')
)
with op.batch_alter_table('study_tag', schema=None) as batch_op:
batch_op.create_index(batch_op.f('ix_study_tag_study_id'), ['study_id'], unique=False)
batch_op.create_index(batch_op.f('ix_study_tag_tag'), ['tag'], unique=False)

# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('study_tag', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('ix_study_tag_tag'))
batch_op.drop_index(batch_op.f('ix_study_tag_study_id'))

op.drop_table('study_tag')
with op.batch_alter_table('tag', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('ix_tag_label'))
batch_op.drop_index(batch_op.f('ix_tag_color'))

op.drop_table('tag')
# ### end Alembic commands ###
2 changes: 1 addition & 1 deletion antarest/study/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class Tag(Base): # type:ignore

__tablename__ = "tag"

label: str = Column(String, primary_key=True, index=True)
label = Column(String, primary_key=True, index=True)
color: str = Column(String(7), index=True)

def __str__(self) -> str:
Expand Down
13 changes: 12 additions & 1 deletion antarest/study/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from antarest.core.utils.fastapi_sqlalchemy import db
from antarest.login.model import Group
from antarest.study.common.utils import get_study_information
from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, Study, StudyAdditionalData
from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, Study, StudyAdditionalData, Tag

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -159,6 +159,7 @@ def get(self, id: str) -> t.Optional[Study]:
self.session.query(Study)
.options(joinedload(Study.owner))
.options(joinedload(Study.groups))
.options(joinedload(Study.tags))
.get(id)
# fmt: on
)
Expand Down Expand Up @@ -218,6 +219,7 @@ def get_all(
q = q.options(joinedload(entity.owner))
q = q.options(joinedload(entity.groups))
q = q.options(joinedload(entity.additional_data))
q = q.options(joinedload(entity.tags))
if study_filter.managed is not None:
if study_filter.managed:
q = q.filter(or_(entity.type == "variantstudy", RawStudy.workspace == DEFAULT_WORKSPACE_NAME))
Expand All @@ -230,6 +232,8 @@ def get_all(
q = q.filter(entity.owner_id.in_(study_filter.users))
if study_filter.groups:
q = q.join(entity.groups).filter(Group.id.in_(study_filter.groups))
if study_filter.tags:
q = q.join(entity.tags).filter(Tag.id.in_(study_filter.tags))
if study_filter.archived is not None:
q = q.filter(entity.archived == study_filter.archived)
if study_filter.name:
Expand Down Expand Up @@ -314,3 +318,10 @@ def _update_study_from_cache_listing(self, study: Study) -> None:
self.cache_service.invalidate(CacheConstants.STUDY_LISTING.value)
except Exception as e:
logger.error("Failed to invalidate listing cache", exc_info=e)

def update_tags(self, tags: t.Set[str]) -> None:
if tags:
existing_tags = self.session.query(Tag).filter(Tag.label.in_(tags)).all()
new_tags = tags.difference(map(lambda x: x.label, existing_tags))
self.session.add_all([Tag(label=tag, color=Tag.generate_random_color_code(tag)) for tag in new_tags])
self.session.commit()
13 changes: 13 additions & 0 deletions antarest/study/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
StudyMetadataDTO,
StudyMetadataPatchDTO,
StudySimResultDTO,
StudyTag,
Tag,
)
from antarest.study.repository import StudyFilter, StudyMetadataRepository, StudyPagination, StudySortBy
Expand Down Expand Up @@ -561,6 +562,15 @@ def update_study_information(

new_metadata = self.storage_service.get_storage(study).patch_update_study_metadata(study, metadata_patch)

self.repository.session.query(StudyTag).filter(StudyTag.study_id == uuid).delete()
self.repository.session.commit()
existing_tags = self.repository.session.query(Tag).filter(Tag.label.in_(new_metadata.tags)).all()
new_tags = set(new_metadata.tags).difference(map(lambda x: x.label, existing_tags))
self.repository.session.add_all([Tag(label=tag, color=Tag.generate_random_color_code(tag)) for tag in new_tags])
self.repository.session.commit()
self.repository.session.add_all([StudyTag(study_id=uuid, tag=tag) for tag in metadata_patch.tags])
self.repository.session.commit()

self.event_bus.push(
Event(
type=EventType.STUDY_DATA_EDITED,
Expand Down Expand Up @@ -625,6 +635,9 @@ def create_study(

author = self.get_user_name(params)

if tags:
self.repository.update_tags(tags)

raw = RawStudy(
id=sid,
name=study_name,
Expand Down
2 changes: 1 addition & 1 deletion scripts/rollback.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ CUR_DIR=$(cd "$(dirname "$0")" && pwd)
BASE_DIR=$(dirname "$CUR_DIR")

cd "$BASE_DIR"
alembic downgrade 782a481f3414
alembic downgrade 1f5db5dfad80
cd -
1 change: 1 addition & 0 deletions tests/storage/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ def test_create_study() -> None:
}
}
study_service.create.return_value = expected
repository.update_tags.return_value = None
config = Config(storage=StorageConfig(workspaces={DEFAULT_WORKSPACE_NAME: WorkspaceConfig()}))
service = build_study_service(study_service, repository, config, user_service=user_service)

Expand Down

0 comments on commit 42a0575

Please sign in to comment.