Skip to content

Commit

Permalink
Metadata mapping OpenAPI/gRPC for Bool and Struct (kubeflow#84)
Browse files Browse the repository at this point in the history
  • Loading branch information
tarilabs authored Oct 24, 2023
1 parent 2c36e23 commit 23fe7ed
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 8 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ require (
github.com/urfave/cli/v2 v2.25.5 // indirect
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect
github.com/yusufpapurcu/wmi v1.2.3 // indirect
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 // indirect
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1
golang.org/x/mod v0.10.0 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/sys v0.13.0 // indirect
Expand Down
49 changes: 42 additions & 7 deletions internal/core/mapper/mlmd_mapper.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package mapper

import (
"encoding/base64"
"encoding/json"
"fmt"
"log"
"strconv"

"github.com/opendatahub-io/model-registry/internal/ml_metadata/proto"
"github.com/opendatahub-io/model-registry/internal/model/openapi"
"google.golang.org/protobuf/types/known/structpb"
)

type Mapper struct {
Expand Down Expand Up @@ -50,12 +52,14 @@ func (m *Mapper) MapToProperties(data map[string]openapi.MetadataValue) (map[str
value := proto.Value{}

switch {
// bool value
case v.MetadataBoolValue != nil:
value.Value = &proto.Value_BoolValue{BoolValue: *v.MetadataBoolValue.BoolValue}
// int value
case v.MetadataIntValue != nil:
intValue, err := IdToInt64(*v.MetadataIntValue.IntValue)
if err != nil {
log.Printf("Skipping mapping for %s:%v", key, v)
continue
return nil, fmt.Errorf("unable to decode as int64 %w for key %s", err, key)
}
value.Value = &proto.Value_IntValue{IntValue: *intValue}
// double value
Expand All @@ -64,9 +68,26 @@ func (m *Mapper) MapToProperties(data map[string]openapi.MetadataValue) (map[str
// string value
case v.MetadataStringValue != nil:
value.Value = &proto.Value_StringValue{StringValue: *v.MetadataStringValue.StringValue}
// struct value
case v.MetadataStructValue != nil:
data, err := base64.StdEncoding.DecodeString(*v.MetadataStructValue.StructValue)
if err != nil {
return nil, fmt.Errorf("unable to decode %w for key %s", err, key)
}
var asMap map[string]interface{}
err = json.Unmarshal(data, &asMap)
if err != nil {
return nil, fmt.Errorf("unable to decode %w for key %s", err, key)
}
asStruct, err := structpb.NewStruct(asMap)
if err != nil {
return nil, fmt.Errorf("unable to decode %w for key %s", err, key)
}
value.Value = &proto.Value_StructValue{
StructValue: asStruct,
}
default:
log.Printf("Type mapping not found for %s:%v", key, v)
continue
return nil, fmt.Errorf("type mapping not found for %s:%v", key, v)
}

props[key] = &value
Expand Down Expand Up @@ -168,6 +189,10 @@ func (m *Mapper) MapFromProperties(props map[string]*proto.Value) (map[string]op
customValue := openapi.MetadataValue{}

switch typedValue := v.Value.(type) {
case *proto.Value_BoolValue:
customValue.MetadataBoolValue = &openapi.MetadataBoolValue{
BoolValue: &typedValue.BoolValue,
}
case *proto.Value_IntValue:
customValue.MetadataIntValue = &openapi.MetadataIntValue{
IntValue: IdToString(typedValue.IntValue),
Expand All @@ -180,9 +205,19 @@ func (m *Mapper) MapFromProperties(props map[string]*proto.Value) (map[string]op
customValue.MetadataStringValue = &openapi.MetadataStringValue{
StringValue: &typedValue.StringValue,
}
case *proto.Value_StructValue:
sv := typedValue.StructValue
asMap := sv.AsMap()
asJSON, err := json.Marshal(asMap)
if err != nil {
return nil, err
}
b64 := base64.StdEncoding.EncodeToString(asJSON)
customValue.MetadataStructValue = &openapi.MetadataStructValue{
StructValue: &b64,
}
default:
log.Printf("Type mapping not found for %s:%v", key, v)
continue
return nil, fmt.Errorf("type mapping not found for %s:%v", key, v)
}

data[key] = customValue
Expand Down
125 changes: 125 additions & 0 deletions internal/core/mapper/mlmd_mapper_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
package mapper_test

import (
"encoding/base64"
"encoding/json"
"testing"

"github.com/opendatahub-io/model-registry/internal/core/mapper"
"github.com/opendatahub-io/model-registry/internal/model/openapi"
"github.com/stretchr/testify/assert"
"golang.org/x/exp/maps"
)

func TestMetadataValueBool(t *testing.T) {
data := make(map[string]openapi.MetadataValue)
key := "my bool"
mdValue := true
data[key] = openapi.MetadataBoolValueAsMetadataValue(&openapi.MetadataBoolValue{BoolValue: &mdValue})

roundTripAndAssert(t, data, key)
}

func TestMetadataValueInt(t *testing.T) {
data := make(map[string]openapi.MetadataValue)
key := "my int"
mdValue := "987"
data[key] = openapi.MetadataIntValueAsMetadataValue(&openapi.MetadataIntValue{IntValue: &mdValue})

roundTripAndAssert(t, data, key)
}

func TestMetadataValueIntFailure(t *testing.T) {
data := make(map[string]openapi.MetadataValue)
key := "my int"
mdValue := "not a number"
data[key] = openapi.MetadataIntValueAsMetadataValue(&openapi.MetadataIntValue{IntValue: &mdValue})

mapper, assert := setup(t)
asGRPC, err := mapper.MapToProperties(data)
if err == nil {
assert.Fail("Did not expected a converted value but an error: %v", asGRPC)
}
}

func TestMetadataValueDouble(t *testing.T) {
data := make(map[string]openapi.MetadataValue)
key := "my double"
mdValue := 3.1415
data[key] = openapi.MetadataDoubleValueAsMetadataValue(&openapi.MetadataDoubleValue{DoubleValue: &mdValue})

roundTripAndAssert(t, data, key)
}

func TestMetadataValueString(t *testing.T) {
data := make(map[string]openapi.MetadataValue)
key := "my string"
mdValue := "Hello, World!"
data[key] = openapi.MetadataStringValueAsMetadataValue(&openapi.MetadataStringValue{StringValue: &mdValue})

roundTripAndAssert(t, data, key)
}

func TestMetadataValueStruct(t *testing.T) {
data := make(map[string]openapi.MetadataValue)
key := "my struct"

myMap := make(map[string]interface{})
myMap["name"] = "John Doe"
myMap["age"] = 47
asJSON, err := json.Marshal(myMap)
if err != nil {
t.Error(err)
}
b64 := base64.StdEncoding.EncodeToString(asJSON)
data[key] = openapi.MetadataStructValueAsMetadataValue(&openapi.MetadataStructValue{StructValue: &b64})

roundTripAndAssert(t, data, key)
}

func TestMetadataValueProtoUnsupported(t *testing.T) {
data := make(map[string]openapi.MetadataValue)
key := "my proto"

myMap := make(map[string]interface{})
myMap["name"] = "John Doe"
myMap["age"] = 47
asJSON, err := json.Marshal(myMap)
if err != nil {
t.Error(err)
}
b64 := base64.StdEncoding.EncodeToString(asJSON)
typeDef := "map[string]openapi.MetadataValue"
data[key] = openapi.MetadataProtoValueAsMetadataValue(&openapi.MetadataProtoValue{
Type: &typeDef,
ProtoValue: &b64,
})

mapper, assert := setup(t)
asGRPC, err := mapper.MapToProperties(data)
if err == nil {
assert.Fail("Did not expected a converted value but an error: %v", asGRPC)
}
}

func roundTripAndAssert(t *testing.T, data map[string]openapi.MetadataValue, key string) {
mapper, assert := setup(t)

// first half
asGRPC, err := mapper.MapToProperties(data)
if err != nil {
t.Error(err)
}
assert.Contains(maps.Keys(asGRPC), key)

// second half
unmarshall, err := mapper.MapFromProperties(asGRPC)
if err != nil {
t.Error(err)
}
assert.Equal(data, unmarshall, "result of round-trip shall be equal to original data")
}

func setup(t *testing.T) (*mapper.Mapper, *assert.Assertions) {
return mapper.NewMapper(1, 2, 3), assert.New(t)
}

0 comments on commit 23fe7ed

Please sign in to comment.