diff --git a/go.mod b/go.mod index e96d7555..c6384b8e 100644 --- a/go.mod +++ b/go.mod @@ -91,7 +91,7 @@ require ( github.com/urfave/cli/v2 v2.25.5 // indirect github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect github.com/yusufpapurcu/wmi v1.2.3 // indirect - golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 // indirect + golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 golang.org/x/mod v0.10.0 // indirect golang.org/x/net v0.17.0 // indirect golang.org/x/sys v0.13.0 // indirect diff --git a/internal/core/mapper/mlmd_mapper.go b/internal/core/mapper/mlmd_mapper.go index d9572423..71b7e108 100644 --- a/internal/core/mapper/mlmd_mapper.go +++ b/internal/core/mapper/mlmd_mapper.go @@ -1,12 +1,14 @@ package mapper import ( + "encoding/base64" + "encoding/json" "fmt" - "log" "strconv" "github.com/opendatahub-io/model-registry/internal/ml_metadata/proto" "github.com/opendatahub-io/model-registry/internal/model/openapi" + "google.golang.org/protobuf/types/known/structpb" ) type Mapper struct { @@ -50,12 +52,14 @@ func (m *Mapper) MapToProperties(data map[string]openapi.MetadataValue) (map[str value := proto.Value{} switch { + // bool value + case v.MetadataBoolValue != nil: + value.Value = &proto.Value_BoolValue{BoolValue: *v.MetadataBoolValue.BoolValue} // int value case v.MetadataIntValue != nil: intValue, err := IdToInt64(*v.MetadataIntValue.IntValue) if err != nil { - log.Printf("Skipping mapping for %s:%v", key, v) - continue + return nil, fmt.Errorf("unable to decode as int64 %w for key %s", err, key) } value.Value = &proto.Value_IntValue{IntValue: *intValue} // double value @@ -64,9 +68,26 @@ func (m *Mapper) MapToProperties(data map[string]openapi.MetadataValue) (map[str // string value case v.MetadataStringValue != nil: value.Value = &proto.Value_StringValue{StringValue: *v.MetadataStringValue.StringValue} + // struct value + case v.MetadataStructValue != nil: + data, err := base64.StdEncoding.DecodeString(*v.MetadataStructValue.StructValue) + if err != nil { + return nil, fmt.Errorf("unable to decode %w for key %s", err, key) + } + var asMap map[string]interface{} + err = json.Unmarshal(data, &asMap) + if err != nil { + return nil, fmt.Errorf("unable to decode %w for key %s", err, key) + } + asStruct, err := structpb.NewStruct(asMap) + if err != nil { + return nil, fmt.Errorf("unable to decode %w for key %s", err, key) + } + value.Value = &proto.Value_StructValue{ + StructValue: asStruct, + } default: - log.Printf("Type mapping not found for %s:%v", key, v) - continue + return nil, fmt.Errorf("type mapping not found for %s:%v", key, v) } props[key] = &value @@ -168,6 +189,10 @@ func (m *Mapper) MapFromProperties(props map[string]*proto.Value) (map[string]op customValue := openapi.MetadataValue{} switch typedValue := v.Value.(type) { + case *proto.Value_BoolValue: + customValue.MetadataBoolValue = &openapi.MetadataBoolValue{ + BoolValue: &typedValue.BoolValue, + } case *proto.Value_IntValue: customValue.MetadataIntValue = &openapi.MetadataIntValue{ IntValue: IdToString(typedValue.IntValue), @@ -180,9 +205,19 @@ func (m *Mapper) MapFromProperties(props map[string]*proto.Value) (map[string]op customValue.MetadataStringValue = &openapi.MetadataStringValue{ StringValue: &typedValue.StringValue, } + case *proto.Value_StructValue: + sv := typedValue.StructValue + asMap := sv.AsMap() + asJSON, err := json.Marshal(asMap) + if err != nil { + return nil, err + } + b64 := base64.StdEncoding.EncodeToString(asJSON) + customValue.MetadataStructValue = &openapi.MetadataStructValue{ + StructValue: &b64, + } default: - log.Printf("Type mapping not found for %s:%v", key, v) - continue + return nil, fmt.Errorf("type mapping not found for %s:%v", key, v) } data[key] = customValue diff --git a/internal/core/mapper/mlmd_mapper_test.go b/internal/core/mapper/mlmd_mapper_test.go new file mode 100644 index 00000000..d9203475 --- /dev/null +++ b/internal/core/mapper/mlmd_mapper_test.go @@ -0,0 +1,125 @@ +package mapper_test + +import ( + "encoding/base64" + "encoding/json" + "testing" + + "github.com/opendatahub-io/model-registry/internal/core/mapper" + "github.com/opendatahub-io/model-registry/internal/model/openapi" + "github.com/stretchr/testify/assert" + "golang.org/x/exp/maps" +) + +func TestMetadataValueBool(t *testing.T) { + data := make(map[string]openapi.MetadataValue) + key := "my bool" + mdValue := true + data[key] = openapi.MetadataBoolValueAsMetadataValue(&openapi.MetadataBoolValue{BoolValue: &mdValue}) + + roundTripAndAssert(t, data, key) +} + +func TestMetadataValueInt(t *testing.T) { + data := make(map[string]openapi.MetadataValue) + key := "my int" + mdValue := "987" + data[key] = openapi.MetadataIntValueAsMetadataValue(&openapi.MetadataIntValue{IntValue: &mdValue}) + + roundTripAndAssert(t, data, key) +} + +func TestMetadataValueIntFailure(t *testing.T) { + data := make(map[string]openapi.MetadataValue) + key := "my int" + mdValue := "not a number" + data[key] = openapi.MetadataIntValueAsMetadataValue(&openapi.MetadataIntValue{IntValue: &mdValue}) + + mapper, assert := setup(t) + asGRPC, err := mapper.MapToProperties(data) + if err == nil { + assert.Fail("Did not expected a converted value but an error: %v", asGRPC) + } +} + +func TestMetadataValueDouble(t *testing.T) { + data := make(map[string]openapi.MetadataValue) + key := "my double" + mdValue := 3.1415 + data[key] = openapi.MetadataDoubleValueAsMetadataValue(&openapi.MetadataDoubleValue{DoubleValue: &mdValue}) + + roundTripAndAssert(t, data, key) +} + +func TestMetadataValueString(t *testing.T) { + data := make(map[string]openapi.MetadataValue) + key := "my string" + mdValue := "Hello, World!" + data[key] = openapi.MetadataStringValueAsMetadataValue(&openapi.MetadataStringValue{StringValue: &mdValue}) + + roundTripAndAssert(t, data, key) +} + +func TestMetadataValueStruct(t *testing.T) { + data := make(map[string]openapi.MetadataValue) + key := "my struct" + + myMap := make(map[string]interface{}) + myMap["name"] = "John Doe" + myMap["age"] = 47 + asJSON, err := json.Marshal(myMap) + if err != nil { + t.Error(err) + } + b64 := base64.StdEncoding.EncodeToString(asJSON) + data[key] = openapi.MetadataStructValueAsMetadataValue(&openapi.MetadataStructValue{StructValue: &b64}) + + roundTripAndAssert(t, data, key) +} + +func TestMetadataValueProtoUnsupported(t *testing.T) { + data := make(map[string]openapi.MetadataValue) + key := "my proto" + + myMap := make(map[string]interface{}) + myMap["name"] = "John Doe" + myMap["age"] = 47 + asJSON, err := json.Marshal(myMap) + if err != nil { + t.Error(err) + } + b64 := base64.StdEncoding.EncodeToString(asJSON) + typeDef := "map[string]openapi.MetadataValue" + data[key] = openapi.MetadataProtoValueAsMetadataValue(&openapi.MetadataProtoValue{ + Type: &typeDef, + ProtoValue: &b64, + }) + + mapper, assert := setup(t) + asGRPC, err := mapper.MapToProperties(data) + if err == nil { + assert.Fail("Did not expected a converted value but an error: %v", asGRPC) + } +} + +func roundTripAndAssert(t *testing.T, data map[string]openapi.MetadataValue, key string) { + mapper, assert := setup(t) + + // first half + asGRPC, err := mapper.MapToProperties(data) + if err != nil { + t.Error(err) + } + assert.Contains(maps.Keys(asGRPC), key) + + // second half + unmarshall, err := mapper.MapFromProperties(asGRPC) + if err != nil { + t.Error(err) + } + assert.Equal(data, unmarshall, "result of round-trip shall be equal to original data") +} + +func setup(t *testing.T) (*mapper.Mapper, *assert.Assertions) { + return mapper.NewMapper(1, 2, 3), assert.New(t) +}