Skip to content

Commit

Permalink
core: enable standalone artifacts
Browse files Browse the repository at this point in the history
Fixes: #231
Signed-off-by: Isabella do Amaral <[email protected]>
  • Loading branch information
isinyaaa committed Sep 10, 2024
1 parent 67bd67d commit bae8a24
Show file tree
Hide file tree
Showing 5 changed files with 391 additions and 208 deletions.
6 changes: 3 additions & 3 deletions internal/server/openapi/api_model_registry_service_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func (s *ModelRegistryServiceAPIService) CreateModelArtifact(ctx context.Context
return ErrorResponse(http.StatusBadRequest, err), err
}

result, err := s.coreApi.UpsertModelArtifact(entity, nil)
result, err := s.coreApi.UpsertModelArtifact(entity)
if err != nil {
return ErrorResponse(api.ErrToStatus(err), err), err
}
Expand All @@ -107,7 +107,7 @@ func (s *ModelRegistryServiceAPIService) CreateModelVersion(ctx context.Context,

// CreateModelVersionArtifact - Create an Artifact in a ModelVersion
func (s *ModelRegistryServiceAPIService) CreateModelVersionArtifact(ctx context.Context, modelversionId string, artifact model.Artifact) (ImplResponse, error) {
result, err := s.coreApi.UpsertArtifact(&artifact, &modelversionId)
result, err := s.coreApi.UpsertModelVersionArtifact(&artifact, modelversionId)
if err != nil {
return ErrorResponse(api.ErrToStatus(err), err), err
}
Expand Down Expand Up @@ -445,7 +445,7 @@ func (s *ModelRegistryServiceAPIService) UpdateModelArtifact(ctx context.Context
if err != nil {
return ErrorResponse(http.StatusBadRequest, err), err
}
result, err := s.coreApi.UpsertModelArtifact(&update, nil)
result, err := s.coreApi.UpsertModelArtifact(&update)
if err != nil {
return ErrorResponse(api.ErrToStatus(err), err), err
}
Expand Down
9 changes: 5 additions & 4 deletions pkg/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,18 @@ type ModelRegistryApi interface {

// ARTIFACT

UpsertArtifact(artifact *openapi.Artifact, modelVersionId *string) (*openapi.Artifact, error)
UpsertModelVersionArtifact(artifact *openapi.Artifact, modelVersionId string) (*openapi.Artifact, error)

UpsertArtifact(artifact *openapi.Artifact) (*openapi.Artifact, error)

GetArtifactById(id string) (*openapi.Artifact, error)

GetArtifacts(listOptions ListOptions, modelVersionId *string) (*openapi.ArtifactList, error)

// MODEL ARTIFACT

// UpsertModelArtifact create a new Artifact or update an Artifact associated to a specific
// ModelVersion identified by modelVersionId parameter
UpsertModelArtifact(modelArtifact *openapi.ModelArtifact, modelVersionId *string) (*openapi.ModelArtifact, error)
// UpsertModelArtifact creates or inserts an Artifact
UpsertModelArtifact(modelArtifact *openapi.ModelArtifact) (*openapi.ModelArtifact, error)

// GetModelArtifactById retrieve ModelArtifact by id
GetModelArtifactById(id string) (*openapi.ModelArtifact, error)
Expand Down
163 changes: 90 additions & 73 deletions pkg/core/artifact.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,62 @@ import (

// ARTIFACTS

// UpsertArtifact creates a new artifact if the provided artifact's ID is nil, or updates an existing artifact if the
// UpsertModelVersionArtifact creates a new artifact if the provided artifact's ID is nil, or updates an existing artifact if the
// ID is provided.
// A model version ID must be provided to disambiguate between artifacts.
// Upon creation, new artifacts will be associated with their corresponding model version.
func (serv *ModelRegistryService) UpsertArtifact(artifact *openapi.Artifact, modelVersionId *string) (*openapi.Artifact, error) {
func (serv *ModelRegistryService) UpsertModelVersionArtifact(artifact *openapi.Artifact, modelVersionId string) (*openapi.Artifact, error) {
art, err := serv.upsertArtifact(artifact, &modelVersionId)
if err != nil {
return nil, err
}
// upsertArtifact already validates modelVersion

var id *string
if art.ModelArtifact != nil {
id = art.ModelArtifact.Id
} else if art.DocArtifact != nil {
id = art.DocArtifact.Id
} else {
return nil, fmt.Errorf("unexpected artifact type: %v", art)
}

mv, _ := serv.getModelVersionByArtifactId(*id)
fmt.Printf("found associated mv: %v", mv)

if mv == nil {
// add explicit Attribution between Artifact and ModelVersion
modelVersionId, err := converter.StringToInt64(&modelVersionId)
if err != nil {
// unreachable
return nil, fmt.Errorf("%v", err)
}
artifactId, err := converter.StringToInt64(id)
if err != nil {
return nil, fmt.Errorf("%v", err)
}
attributions := []*proto.Attribution{}
attributions = append(attributions, &proto.Attribution{
ContextId: modelVersionId,
ArtifactId: artifactId,
})
_, err = serv.mlmdClient.PutAttributionsAndAssociations(context.Background(), &proto.PutAttributionsAndAssociationsRequest{
Attributions: attributions,
Associations: make([]*proto.Association, 0),
})
if err != nil {
return nil, err
}
}
return art, nil
}

func (serv *ModelRegistryService) upsertArtifact(artifact *openapi.Artifact, modelVersionId *string) (*openapi.Artifact, error) {
if artifact == nil {
return nil, fmt.Errorf("invalid artifact pointer, can't upsert nil")
}
creating := false
if ma := artifact.ModelArtifact; ma != nil {
if ma.Id == nil {
creating = true
glog.Info("Creating model artifact")
if modelVersionId == nil {
return nil, fmt.Errorf("missing model version id, cannot create artifact without model version: %w", api.ErrBadRequest)
}
_, err := serv.GetModelVersionById(*modelVersionId)
if err != nil {
return nil, fmt.Errorf("no model version found for id %s: %w", *modelVersionId, api.ErrNotFound)
}
} else {
glog.Info("Updating model artifact")
existing, err := serv.GetModelArtifactById(*ma.Id)
Expand All @@ -45,24 +81,11 @@ func (serv *ModelRegistryService) UpsertArtifact(artifact *openapi.Artifact, mod
if err != nil {
return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest)
}
ma = &withNotEditable

_, err = serv.getModelVersionByArtifactId(*ma.Id)
if err != nil {
return nil, err
}
artifact.ModelArtifact = &withNotEditable
}
} else if da := artifact.DocArtifact; da != nil {
if da.Id == nil {
creating = true
glog.Info("Creating doc artifact")
if modelVersionId == nil {
return nil, fmt.Errorf("missing model version id, cannot create artifact without model version: %w", api.ErrBadRequest)
}
_, err := serv.GetModelVersionById(*modelVersionId)
if err != nil {
return nil, fmt.Errorf("no model version found for id %s: %w", *modelVersionId, api.ErrNotFound)
}
} else {
glog.Info("Updating doc artifact")
existing, err := serv.GetArtifactById(*da.Id)
Expand All @@ -77,16 +100,16 @@ func (serv *ModelRegistryService) UpsertArtifact(artifact *openapi.Artifact, mod
if err != nil {
return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest)
}
da = &withNotEditable

_, err = serv.getModelVersionByArtifactId(*da.Id)
if err != nil {
return nil, err
}
artifact.DocArtifact = &withNotEditable
}
} else {
return nil, fmt.Errorf("invalid artifact type, must be either ModelArtifact or DocArtifact: %w", api.ErrBadRequest)
}
if modelVersionId != nil {
if _, err := serv.GetModelVersionById(*modelVersionId); err != nil {
return nil, fmt.Errorf("no model version found for id %s: %w", *modelVersionId, api.ErrNotFound)
}
}
pa, err := serv.mapper.MapFromArtifact(artifact, modelVersionId)
if err != nil {
return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest)
Expand All @@ -98,32 +121,16 @@ func (serv *ModelRegistryService) UpsertArtifact(artifact *openapi.Artifact, mod
return nil, err
}

if creating {
// add explicit Attribution between Artifact and ModelVersion
modelVersionId, err := converter.StringToInt64(modelVersionId)
if err != nil {
return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest)
}
attributions := []*proto.Attribution{}
for _, a := range artifactsResp.ArtifactIds {
attributions = append(attributions, &proto.Attribution{
ContextId: modelVersionId,
ArtifactId: &a,
})
}
_, err = serv.mlmdClient.PutAttributionsAndAssociations(context.Background(), &proto.PutAttributionsAndAssociationsRequest{
Attributions: attributions,
Associations: make([]*proto.Association, 0),
})
if err != nil {
return nil, err
}
}

idAsString := converter.Int64ToString(&artifactsResp.ArtifactIds[0])
return serv.GetArtifactById(*idAsString)
}

// UpsertArtifact creates a new artifact if the provided artifact's ID is nil, or updates an existing artifact if the
// ID is provided.
func (serv *ModelRegistryService) UpsertArtifact(artifact *openapi.Artifact) (*openapi.Artifact, error) {
return serv.upsertArtifact(artifact, nil)
}

func (serv *ModelRegistryService) GetArtifactById(id string) (*openapi.Artifact, error) {
idAsInt, err := converter.StringToInt64(&id)
if err != nil {
Expand All @@ -145,29 +152,39 @@ func (serv *ModelRegistryService) GetArtifactById(id string) (*openapi.Artifact,
return serv.mapper.MapToArtifact(artifactsResp.Artifacts[0])
}

// GetArtifacts retrieves a list of artifacts based on the provided list options and optional model version ID.
func (serv *ModelRegistryService) GetArtifacts(listOptions api.ListOptions, modelVersionId *string) (*openapi.ArtifactList, error) {
listOperationOptions, err := apiutils.BuildListOperationOptions(listOptions)
if err != nil {
return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest)
}

var artifacts []*proto.Artifact
var nextPageToken *string
if modelVersionId == nil {
return nil, fmt.Errorf("missing model version id, cannot get artifacts without model version: %w", api.ErrBadRequest)
}
ctxId, err := converter.StringToInt64(modelVersionId)
if err != nil {
return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest)
}
artifactsResp, err := serv.mlmdClient.GetArtifactsByContext(context.Background(), &proto.GetArtifactsByContextRequest{
ContextId: ctxId,
Options: listOperationOptions,
})
if err != nil {
return nil, err
if modelVersionId != nil {
ctxId, err := converter.StringToInt64(modelVersionId)
if err != nil {
return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest)
}
artifactsResp, err := serv.mlmdClient.GetArtifactsByContext(context.Background(), &proto.GetArtifactsByContextRequest{
ContextId: ctxId,
Options: listOperationOptions,
})
if err != nil {
return nil, err
}
artifacts = artifactsResp.Artifacts
nextPageToken = artifactsResp.NextPageToken
} else {
artifactsResp, err := serv.mlmdClient.GetArtifacts(context.Background(), &proto.GetArtifactsRequest{
Options: listOperationOptions,
})
if err != nil {
return nil, err
}
artifacts = artifactsResp.Artifacts
nextPageToken = artifactsResp.NextPageToken
}
artifacts = artifactsResp.Artifacts
nextPageToken = artifactsResp.NextPageToken

results := []openapi.Artifact{}
for _, a := range artifacts {
Expand All @@ -191,12 +208,10 @@ func (serv *ModelRegistryService) GetArtifacts(listOptions api.ListOptions, mode

// UpsertModelArtifact creates a new model artifact if the provided model artifact's ID is nil,
// or updates an existing model artifact if the ID is provided.
// If a model version ID is provided and the model artifact is newly created, establishes an
// explicit attribution between the model version and the created model artifact.
func (serv *ModelRegistryService) UpsertModelArtifact(modelArtifact *openapi.ModelArtifact, modelVersionId *string) (*openapi.ModelArtifact, error) {
func (serv *ModelRegistryService) UpsertModelArtifact(modelArtifact *openapi.ModelArtifact) (*openapi.ModelArtifact, error) {
art, err := serv.UpsertArtifact(&openapi.Artifact{
ModelArtifact: modelArtifact,
}, modelVersionId)
})
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -292,6 +307,8 @@ func (serv *ModelRegistryService) GetModelArtifacts(listOptions api.ListOptions,
if err != nil {
return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest)
}
typeQuery := fmt.Sprintf("type = '%v'", serv.nameConfig.ModelArtifactTypeName)
listOperationOptions.FilterQuery = &typeQuery
artifactsResp, err := serv.mlmdClient.GetArtifactsByContext(context.Background(), &proto.GetArtifactsByContextRequest{
ContextId: ctxId,
Options: listOperationOptions,
Expand Down
Loading

0 comments on commit bae8a24

Please sign in to comment.