Skip to content

Commit

Permalink
Store ID
Browse files Browse the repository at this point in the history
  • Loading branch information
schustmi committed Oct 21, 2024
1 parent abd6e50 commit ea66a4d
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/zenml/zen_stores/sql_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2778,6 +2778,8 @@ def create_artifact_version(

assert artifact_version.artifact_id

artifact_version_id = None

if artifact_version.version is None:
# No explicit version in the request -> We will try to
# auto-increment the numeric version of the artifact version
Expand All @@ -2800,6 +2802,7 @@ def create_artifact_version(
)
session.add(artifact_version_schema)
session.commit()
artifact_version_id = artifact_version_id.id
except IntegrityError:
if remaining_tries == 0:
raise EntityExistsError(
Expand Down Expand Up @@ -2838,6 +2841,7 @@ def create_artifact_version(
)
session.add(artifact_version_schema)
session.commit()
artifact_version_id = artifact_version_id.id
except IntegrityError:
raise EntityExistsError(
f"Unable to create artifact version "
Expand All @@ -2852,15 +2856,15 @@ def create_artifact_version(
for vis in artifact_version.visualizations:
vis_schema = ArtifactVisualizationSchema.from_model(
artifact_visualization_request=vis,
artifact_version_id=artifact_version_schema.id,
artifact_version_id=artifact_version_id,
)
session.add(vis_schema)

# Save tags of the artifact
if artifact_version.tags:
self._attach_tags_to_resource(
tag_names=artifact_version.tags,
resource_id=artifact_version_schema.id,
resource_id=artifact_version_id,
resource_type=TaggableResourceTypes.ARTIFACT_VERSION,
)

Expand All @@ -2870,7 +2874,7 @@ def create_artifact_version(
run_metadata_schema = RunMetadataSchema(
workspace_id=artifact_version.workspace,
user_id=artifact_version.user,
resource_id=artifact_version_schema.id,
resource_id=artifact_version_id,
resource_type=MetadataResourceTypes.ARTIFACT_VERSION,
key=key,
value=json.dumps(value),
Expand All @@ -2881,7 +2885,7 @@ def create_artifact_version(
session.commit()
artifact_version_schema = session.exec(
select(ArtifactVersionSchema).where(
ArtifactVersionSchema.id == artifact_version_schema.id
ArtifactVersionSchema.id == artifact_version_id
)
).one()

Expand Down

0 comments on commit ea66a4d

Please sign in to comment.