From abd6e5095c69d6aeeb068c4ca75c56dedcfb7512 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 21 Oct 2024 13:04:29 +0200 Subject: [PATCH] Add debugging log for DB migration --- ...1cbb3a9_add_artifact_unique_constraints.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/src/zenml/zen_stores/migrations/versions/c22561cbb3a9_add_artifact_unique_constraints.py b/src/zenml/zen_stores/migrations/versions/c22561cbb3a9_add_artifact_unique_constraints.py index 7246f3d5095..bee39e31ae3 100644 --- a/src/zenml/zen_stores/migrations/versions/c22561cbb3a9_add_artifact_unique_constraints.py +++ b/src/zenml/zen_stores/migrations/versions/c22561cbb3a9_add_artifact_unique_constraints.py @@ -7,6 +7,8 @@ """ from alembic import op +import sqlalchemy as sa +import sqlmodel # revision identifiers, used by Alembic. revision = "c22561cbb3a9" @@ -18,6 +20,37 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" # ### commands auto generated by Alembic - please adjust! ### + bind = op.get_bind() + session = sqlmodel.Session(bind=bind) + + avs = session.exec( + sa.text( + """ + SELECT artifact_id, version + FROM artifact_version + """ + ) + ).all() + + print(avs) + from collections import defaultdict + mapping = defaultdict(set) + + for artifact_id, version in avs: + if version in mapping[artifact_id]: + artifact_name = session.exec( + sa.text( + """ + SELECT name + FROM artifact + WHERE id = :id_ + """ + ), params={"id_": artifact_id} + ).one() + print(f"Found duplicate for artifact version {artifact_name} (version {version})") + + mapping[artifact_id].add(version) + with op.batch_alter_table("artifact", schema=None) as batch_op: batch_op.create_unique_constraint("unique_artifact_name", ["name"])