Skip to content

Commit

Permalink
Fix core library issue (kubeflow#154)
Browse files Browse the repository at this point in the history
* Manage *ByParams errors

* Fix runtime movement to inference service

* Add mapper test
  • Loading branch information
lampajr authored Nov 16, 2023
1 parent 8005f07 commit 9ed849d
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 33 deletions.
4 changes: 2 additions & 2 deletions internal/converter/mlmd_openapi_converter_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,10 @@ func MapNameFromOwned(source *string) *string {
// MODEL ARTIFACT

func MapArtifactType(source *proto.Artifact) (string, error) {
if *source.Type == ModelArtifactTypeName {
if source.Type != nil && *source.Type == ModelArtifactTypeName {
return "model-artifact", nil
}
return "", fmt.Errorf("invalid artifact type found")
return "", fmt.Errorf("invalid artifact type found: %v", source.Type)
}

func MapMLMDModelArtifactState(source *proto.Artifact_State) *openapi.ArtifactState {
Expand Down
24 changes: 6 additions & 18 deletions internal/mapper/mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,27 +131,15 @@ func (m *Mapper) MapFromServeModel(serveModel *openapi.ServeModel, inferenceServ
// Utilities for MLMD --> OpenAPI mapping, make use of generated Converters

func (m *Mapper) MapToRegisteredModel(ctx *proto.Context) (*openapi.RegisteredModel, error) {
if ctx.GetTypeId() != m.RegisteredModelTypeId {
return nil, fmt.Errorf("invalid TypeId, expected %d but received %d", m.RegisteredModelTypeId, ctx.GetTypeId())
}

return m.MLMDConverter.ConvertRegisteredModel(ctx)
return mapTo(ctx, m.RegisteredModelTypeId, m.MLMDConverter.ConvertRegisteredModel)
}

func (m *Mapper) MapToModelVersion(ctx *proto.Context) (*openapi.ModelVersion, error) {
if ctx.GetTypeId() != m.ModelVersionTypeId {
return nil, fmt.Errorf("invalid TypeId, expected %d but received %d", m.ModelVersionTypeId, ctx.GetTypeId())
}

return m.MLMDConverter.ConvertModelVersion(ctx)
return mapTo(ctx, m.ModelVersionTypeId, m.MLMDConverter.ConvertModelVersion)
}

func (m *Mapper) MapToModelArtifact(artifact *proto.Artifact) (*openapi.ModelArtifact, error) {
if artifact.GetTypeId() != m.ModelArtifactTypeId {
return nil, fmt.Errorf("invalid TypeId, expected %d but received %d", m.ModelArtifactTypeId, artifact.GetTypeId())
}

return m.MLMDConverter.ConvertModelArtifact(artifact)
func (m *Mapper) MapToModelArtifact(art *proto.Artifact) (*openapi.ModelArtifact, error) {
return mapTo(art, m.ModelArtifactTypeId, m.MLMDConverter.ConvertModelArtifact)
}

func (m *Mapper) MapToServingEnvironment(ctx *proto.Context) (*openapi.ServingEnvironment, error) {
Expand All @@ -162,8 +150,8 @@ func (m *Mapper) MapToInferenceService(ctx *proto.Context) (*openapi.InferenceSe
return mapTo(ctx, m.InferenceServiceTypeId, m.MLMDConverter.ConvertInferenceService)
}

func (m *Mapper) MapToServeModel(ctx *proto.Execution) (*openapi.ServeModel, error) {
return mapTo(ctx, m.ServeModelTypeId, m.MLMDConverter.ConvertServeModel)
func (m *Mapper) MapToServeModel(ex *proto.Execution) (*openapi.ServeModel, error) {
return mapTo(ex, m.ServeModelTypeId, m.MLMDConverter.ConvertServeModel)
}

type getTypeIder interface {
Expand Down
147 changes: 147 additions & 0 deletions internal/mapper/mapper_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package mapper

import (
"fmt"
"testing"

"github.com/opendatahub-io/model-registry/internal/ml_metadata/proto"
"github.com/stretchr/testify/assert"
)

const (
invalidTypeId = int64(9999)
registeredModelTypeId = int64(1)
modelVersionTypeId = int64(2)
modelArtifactTypeId = int64(3)
servingEnvironmentTypeId = int64(4)
inferenceServiceTypeId = int64(5)
serveModelTypeId = int64(6)
)

func setup(t *testing.T) (*assert.Assertions, *Mapper) {
return assert.New(t), NewMapper(
registeredModelTypeId,
modelVersionTypeId,
modelArtifactTypeId,
servingEnvironmentTypeId,
inferenceServiceTypeId,
serveModelTypeId,
)
}

func TestMapToRegisteredModel(t *testing.T) {
assertion, m := setup(t)
_, err := m.MapToRegisteredModel(&proto.Context{
TypeId: of(registeredModelTypeId),
})
assertion.Nil(err)
}

func TestMapToRegisteredModelInvalid(t *testing.T) {
assertion, m := setup(t)
_, err := m.MapToRegisteredModel(&proto.Context{
TypeId: of(invalidTypeId),
})
assertion.NotNil(err)
assertion.Equal(fmt.Sprintf("invalid TypeId, expected %d but received %d", registeredModelTypeId, invalidTypeId), err.Error())
}

func TestMapToModelVersion(t *testing.T) {
assertion, m := setup(t)
_, err := m.MapToModelVersion(&proto.Context{
TypeId: of(modelVersionTypeId),
})
assertion.Nil(err)
}

func TestMapToModelVersionInvalid(t *testing.T) {
assertion, m := setup(t)
_, err := m.MapToModelVersion(&proto.Context{
TypeId: of(invalidTypeId),
})
assertion.NotNil(err)
assertion.Equal(fmt.Sprintf("invalid TypeId, expected %d but received %d", modelVersionTypeId, invalidTypeId), err.Error())
}

func TestMapToModelArtifact(t *testing.T) {
assertion, m := setup(t)
_, err := m.MapToModelArtifact(&proto.Artifact{
TypeId: of(modelArtifactTypeId),
Type: of("odh.ModelArtifact"),
})
assertion.Nil(err)
}

func TestMapToModelArtifactMissingType(t *testing.T) {
assertion, m := setup(t)
_, err := m.MapToModelArtifact(&proto.Artifact{
TypeId: of(modelArtifactTypeId),
})
assertion.NotNil(err)
assertion.Equal("error setting field ArtifactType: invalid artifact type found: <nil>", err.Error())
}

func TestMapToModelArtifactInvalid(t *testing.T) {
assertion, m := setup(t)
_, err := m.MapToModelArtifact(&proto.Artifact{
TypeId: of(invalidTypeId),
})
assertion.NotNil(err)
assertion.Equal(fmt.Sprintf("invalid TypeId, expected %d but received %d", modelArtifactTypeId, invalidTypeId), err.Error())
}

func TestMapToServingEnvironment(t *testing.T) {
assertion, m := setup(t)
_, err := m.MapToServingEnvironment(&proto.Context{
TypeId: of(servingEnvironmentTypeId),
})
assertion.Nil(err)
}

func TestMapToServingEnvironmentInvalid(t *testing.T) {
assertion, m := setup(t)
_, err := m.MapToServingEnvironment(&proto.Context{
TypeId: of(invalidTypeId),
})
assertion.NotNil(err)
assertion.Equal(fmt.Sprintf("invalid TypeId, expected %d but received %d", servingEnvironmentTypeId, invalidTypeId), err.Error())
}

func TestMapToInferenceService(t *testing.T) {
assertion, m := setup(t)
_, err := m.MapToInferenceService(&proto.Context{
TypeId: of(inferenceServiceTypeId),
})
assertion.Nil(err)
}

func TestMapToInferenceServiceInvalid(t *testing.T) {
assertion, m := setup(t)
_, err := m.MapToInferenceService(&proto.Context{
TypeId: of(invalidTypeId),
})
assertion.NotNil(err)
assertion.Equal(fmt.Sprintf("invalid TypeId, expected %d but received %d", inferenceServiceTypeId, invalidTypeId), err.Error())
}

func TestMapToServeModel(t *testing.T) {
assertion, m := setup(t)
_, err := m.MapToServeModel(&proto.Execution{
TypeId: of(serveModelTypeId),
})
assertion.Nil(err)
}

func TestMapToServeModelInvalid(t *testing.T) {
assertion, m := setup(t)
_, err := m.MapToServeModel(&proto.Execution{
TypeId: of(invalidTypeId),
})
assertion.NotNil(err)
assertion.Equal(fmt.Sprintf("invalid TypeId, expected %d but received %d", serveModelTypeId, invalidTypeId), err.Error())
}

// of returns a pointer to the provided literal/const input
func of[E any](e E) *E {
return &e
}
32 changes: 24 additions & 8 deletions pkg/core/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ func NewModelRegistryService(cc grpc.ClientConnInterface) (api.ModelRegistryApi,
"registered_model_id": proto.PropertyType_INT,
// same information tracked using ParentContext association
"serving_environment_id": proto.PropertyType_INT,
"runtime": proto.PropertyType_STRING,
},
},
}
Expand All @@ -99,7 +100,6 @@ func NewModelRegistryService(cc grpc.ClientConnInterface) (api.ModelRegistryApi,
Properties: map[string]proto.PropertyType{
"description": proto.PropertyType_STRING,
"model_version_id": proto.PropertyType_INT,
"runtime": proto.PropertyType_STRING,
},
},
}
Expand Down Expand Up @@ -285,8 +285,12 @@ func (serv *modelRegistryService) GetRegisteredModelByParams(name *string, exter
return nil, err
}

if len(getByParamsResp.Contexts) != 1 {
return nil, fmt.Errorf("multiple registered models found for name=%v, externalId=%v", *name, *externalId)
if len(getByParamsResp.Contexts) > 1 {
return nil, fmt.Errorf("multiple registered models found for name=%v, externalId=%v", apiutils.ZeroIfNil(name), apiutils.ZeroIfNil(externalId))
}

if len(getByParamsResp.Contexts) == 0 {
return nil, fmt.Errorf("no registered models found for name=%v, externalId=%v", apiutils.ZeroIfNil(name), apiutils.ZeroIfNil(externalId))
}

regModel, err := serv.mapper.MapToRegisteredModel(getByParamsResp.Contexts[0])
Expand Down Expand Up @@ -506,10 +510,14 @@ func (serv *modelRegistryService) GetModelVersionByParams(versionName *string, p
return nil, err
}

if len(getByParamsResp.Contexts) != 1 {
if len(getByParamsResp.Contexts) > 1 {
return nil, fmt.Errorf("multiple model versions found for versionName=%v, parentResourceId=%v, externalId=%v", apiutils.ZeroIfNil(versionName), apiutils.ZeroIfNil(parentResourceId), apiutils.ZeroIfNil(externalId))
}

if len(getByParamsResp.Contexts) == 0 {
return nil, fmt.Errorf("no model versions found for versionName=%v, parentResourceId=%v, externalId=%v", apiutils.ZeroIfNil(versionName), apiutils.ZeroIfNil(parentResourceId), apiutils.ZeroIfNil(externalId))
}

modelVer, err := serv.mapper.MapToModelVersion(getByParamsResp.Contexts[0])
if err != nil {
return nil, err
Expand Down Expand Up @@ -848,8 +856,12 @@ func (serv *modelRegistryService) GetServingEnvironmentByParams(name *string, ex
return nil, err
}

if len(getByParamsResp.Contexts) != 1 {
return nil, fmt.Errorf("could not find exactly one Context matching criteria: %v", getByParamsResp.Contexts)
if len(getByParamsResp.Contexts) > 1 {
return nil, fmt.Errorf("multiple serving environments found for name=%v, externalId=%v", apiutils.ZeroIfNil(name), apiutils.ZeroIfNil(externalId))
}

if len(getByParamsResp.Contexts) == 0 {
return nil, fmt.Errorf("no serving environments found for name=%v, externalId=%v", apiutils.ZeroIfNil(name), apiutils.ZeroIfNil(externalId))
}

openapiModel, err := serv.mapper.MapToServingEnvironment(getByParamsResp.Contexts[0])
Expand Down Expand Up @@ -1052,8 +1064,12 @@ func (serv *modelRegistryService) GetInferenceServiceByParams(name *string, pare
return nil, err
}

if len(getByParamsResp.Contexts) != 1 {
return nil, fmt.Errorf("multiple InferenceServices found for name=%v, parentResourceId=%v, externalId=%v", apiutils.ZeroIfNil(name), apiutils.ZeroIfNil(parentResourceId), apiutils.ZeroIfNil(externalId))
if len(getByParamsResp.Contexts) > 1 {
return nil, fmt.Errorf("multiple inference services found for versionName=%v, parentResourceId=%v, externalId=%v", apiutils.ZeroIfNil(name), apiutils.ZeroIfNil(parentResourceId), apiutils.ZeroIfNil(externalId))
}

if len(getByParamsResp.Contexts) == 0 {
return nil, fmt.Errorf("no inference services found for versionName=%v, parentResourceId=%v, externalId=%v", apiutils.ZeroIfNil(name), apiutils.ZeroIfNil(parentResourceId), apiutils.ZeroIfNil(externalId))
}

toReturn, err := serv.mapper.MapToInferenceService(getByParamsResp.Contexts[0])
Expand Down
25 changes: 20 additions & 5 deletions pkg/core/core_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1141,10 +1141,14 @@ func TestCreateModelArtifact(t *testing.T) {
modelVersionId := registerModelVersion(assertion, service, nil, nil, nil, nil)

modelArtifact := &openapi.ModelArtifact{
Name: &artifactName,
State: (*openapi.ArtifactState)(&artifactState),
Uri: &artifactUri,
Description: &artifactDescription,
Name: &artifactName,
State: (*openapi.ArtifactState)(&artifactState),
Uri: &artifactUri,
Description: &artifactDescription,
ModelFormatName: of("onnx"),
ModelFormatVersion: of("1"),
StorageKey: of("aws-connection-models"),
StoragePath: of("bucket"),
CustomProperties: &map[string]openapi.MetadataValue{
"author": {
MetadataStringValue: &openapi.MetadataStringValue{
Expand All @@ -1163,6 +1167,10 @@ func TestCreateModelArtifact(t *testing.T) {
assertion.Equal(*state, *createdArtifact.State)
assertion.Equal(artifactUri, *createdArtifact.Uri)
assertion.Equal(artifactDescription, *createdArtifact.Description)
assertion.Equal("onnx", *createdArtifact.ModelFormatName)
assertion.Equal("1", *createdArtifact.ModelFormatVersion)
assertion.Equal("aws-connection-models", *createdArtifact.StorageKey)
assertion.Equal("bucket", *createdArtifact.StoragePath)
assertion.Equal(author, *(*createdArtifact.CustomProperties)["author"].MetadataStringValue.StringValue)

createdArtifactId, _ := converter.StringToInt64(createdArtifact.Id)
Expand All @@ -1176,6 +1184,10 @@ func TestCreateModelArtifact(t *testing.T) {
assertion.Equal(string(*createdArtifact.State), getById.Artifacts[0].State.String())
assertion.Equal(*createdArtifact.Uri, *getById.Artifacts[0].Uri)
assertion.Equal(*createdArtifact.Description, getById.Artifacts[0].Properties["description"].GetStringValue())
assertion.Equal(*createdArtifact.ModelFormatName, getById.Artifacts[0].Properties["model_format_name"].GetStringValue())
assertion.Equal(*createdArtifact.ModelFormatVersion, getById.Artifacts[0].Properties["model_format_version"].GetStringValue())
assertion.Equal(*createdArtifact.StorageKey, getById.Artifacts[0].Properties["storage_key"].GetStringValue())
assertion.Equal(*createdArtifact.StoragePath, getById.Artifacts[0].Properties["storage_path"].GetStringValue())
assertion.Equal(*(*createdArtifact.CustomProperties)["author"].MetadataStringValue.StringValue, getById.Artifacts[0].CustomProperties["author"].GetStringValue())

modelVersionIdAsInt, _ := converter.StringToInt64(&modelVersionId)
Expand Down Expand Up @@ -1937,13 +1949,15 @@ func TestCreateInferenceService(t *testing.T) {

parentResourceId := registerServingEnvironment(assertion, service, nil, nil)
registeredModelId := registerModel(assertion, service, nil, nil)
runtime := "model-server"

eut := &openapi.InferenceService{
Name: &entityName,
ExternalID: &entityExternalId2,
Description: &entityDescription,
ServingEnvironmentId: parentResourceId,
RegisteredModelId: registeredModelId,
Runtime: &runtime,
CustomProperties: &map[string]openapi.MetadataValue{
"author": {
MetadataStringValue: &openapi.MetadataStringValue{
Expand All @@ -1954,7 +1968,7 @@ func TestCreateInferenceService(t *testing.T) {
}

createdEntity, err := service.UpsertInferenceService(eut)
assertion.Nilf(err, "error creating new eut for %v", parentResourceId)
assertion.Nilf(err, "error creating new eut for %s: %v", parentResourceId, err)

assertion.NotNilf(createdEntity.Id, "created eut should not have nil Id")

Expand All @@ -1973,6 +1987,7 @@ func TestCreateInferenceService(t *testing.T) {
assertion.Equal(entityExternalId2, *byId.Contexts[0].ExternalId, "saved external id should match the provided one")
assertion.Equal(author, byId.Contexts[0].CustomProperties["author"].GetStringValue(), "saved author custom property should match the provided one")
assertion.Equal(entityDescription, byId.Contexts[0].Properties["description"].GetStringValue(), "saved description should match the provided one")
assertion.Equal(runtime, byId.Contexts[0].Properties["runtime"].GetStringValue(), "saved runtime should match the provided one")
assertion.Equalf(*inferenceServiceTypeName, *byId.Contexts[0].Type, "saved context should be of type of %s", *inferenceServiceTypeName)

getAllResp, err := client.GetContexts(context.Background(), &proto.GetContextsRequest{})
Expand Down

0 comments on commit 9ed849d

Please sign in to comment.