From ea66a4d74a00fd523aa5d806503b6136f5b21151 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 21 Oct 2024 13:36:53 +0200 Subject: [PATCH] Store ID --- src/zenml/zen_stores/sql_zen_store.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index daeae60a8a2..c6f21283c7b 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -2778,6 +2778,8 @@ def create_artifact_version( assert artifact_version.artifact_id + artifact_version_id = None + if artifact_version.version is None: # No explicit version in the request -> We will try to # auto-increment the numeric version of the artifact version @@ -2800,6 +2802,7 @@ def create_artifact_version( ) session.add(artifact_version_schema) session.commit() + artifact_version_id = artifact_version_id.id except IntegrityError: if remaining_tries == 0: raise EntityExistsError( @@ -2838,6 +2841,7 @@ def create_artifact_version( ) session.add(artifact_version_schema) session.commit() + artifact_version_id = artifact_version_id.id except IntegrityError: raise EntityExistsError( f"Unable to create artifact version " @@ -2852,7 +2856,7 @@ def create_artifact_version( for vis in artifact_version.visualizations: vis_schema = ArtifactVisualizationSchema.from_model( artifact_visualization_request=vis, - artifact_version_id=artifact_version_schema.id, + artifact_version_id=artifact_version_id, ) session.add(vis_schema) @@ -2860,7 +2864,7 @@ def create_artifact_version( if artifact_version.tags: self._attach_tags_to_resource( tag_names=artifact_version.tags, - resource_id=artifact_version_schema.id, + resource_id=artifact_version_id, resource_type=TaggableResourceTypes.ARTIFACT_VERSION, ) @@ -2870,7 +2874,7 @@ def create_artifact_version( run_metadata_schema = RunMetadataSchema( workspace_id=artifact_version.workspace, user_id=artifact_version.user, - resource_id=artifact_version_schema.id, + resource_id=artifact_version_id, resource_type=MetadataResourceTypes.ARTIFACT_VERSION, key=key, value=json.dumps(value), @@ -2881,7 +2885,7 @@ def create_artifact_version( session.commit() artifact_version_schema = session.exec( select(ArtifactVersionSchema).where( - ArtifactVersionSchema.id == artifact_version_schema.id + ArtifactVersionSchema.id == artifact_version_id ) ).one()