diff --git a/src/zenml/zen_stores/migrations/versions/c22561cbb3a9_add_artifact_unique_constraints.py b/src/zenml/zen_stores/migrations/versions/c22561cbb3a9_add_artifact_unique_constraints.py index bee39e31ae3..f18722a1aea 100644 --- a/src/zenml/zen_stores/migrations/versions/c22561cbb3a9_add_artifact_unique_constraints.py +++ b/src/zenml/zen_stores/migrations/versions/c22561cbb3a9_add_artifact_unique_constraints.py @@ -6,9 +6,10 @@ """ -from alembic import op +from collections import defaultdict + import sqlalchemy as sa -import sqlmodel +from alembic import op # revision identifiers, used by Alembic. revision = "c22561cbb3a9" @@ -17,39 +18,47 @@ depends_on = None +def resolve_duplicate_versions() -> None: + """Resolve duplicate artifact versions.""" + connection = op.get_bind() + + meta = sa.MetaData() + meta.reflect( + bind=op.get_bind(), + only=("artifact_version",), + ) + + artifact_version_table = sa.Table("artifact_version", meta) + query = sa.select( + artifact_version_table.c.id, + artifact_version_table.c.artifact_id, + artifact_version_table.c.version, + ) + + versions_per_artifact = defaultdict(set) + + for id, artifact_id, version in connection.execute(query).fetchall(): + versions = versions_per_artifact[artifact_id] + if version in versions: + for suffix_length in range(4, len(id)): + new_version = f"{version}-{id[:suffix_length]}" + if new_version not in versions: + version = new_version + break + + connection.execute( + sa.update(artifact_version_table) + .where(artifact_version_table.c.id == id) + .values(version=version) + ) + + versions.add(version) + + def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" # ### commands auto generated by Alembic - please adjust! ### - bind = op.get_bind() - session = sqlmodel.Session(bind=bind) - - avs = session.exec( - sa.text( - """ - SELECT artifact_id, version - FROM artifact_version - """ - ) - ).all() - - print(avs) - from collections import defaultdict - mapping = defaultdict(set) - - for artifact_id, version in avs: - if version in mapping[artifact_id]: - artifact_name = session.exec( - sa.text( - """ - SELECT name - FROM artifact - WHERE id = :id_ - """ - ), params={"id_": artifact_id} - ).one() - print(f"Found duplicate for artifact version {artifact_name} (version {version})") - - mapping[artifact_id].add(version) + resolve_duplicate_versions() with op.batch_alter_table("artifact", schema=None) as batch_op: batch_op.create_unique_constraint("unique_artifact_name", ["name"])