Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Serialize compound values #10

Merged
merged 7 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 156 additions & 33 deletions dispatchproto/any.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,29 @@
package dispatchproto

import (
"encoding"
"encoding/json"
"fmt"
"reflect"
"time"

"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/timestamppb"
"google.golang.org/protobuf/types/known/wrapperspb"
)

// Any represents any value.
type Any struct{ proto *anypb.Any }

// Nil creates an Any that contains nil/null.
func Nil() Any {
return knownAny(&emptypb.Empty{})
}

// Bool creates an Any that contains a boolean value.
func Bool(v bool) Any {
return knownAny(wrapperspb.Bool(v))
Expand Down Expand Up @@ -70,16 +79,58 @@ func Duration(v time.Duration) Any {
return knownAny(durationpb.New(v))
}

// NewAny creates an Any from a proto.Message.
func NewAny(v any) (Any, error) {
// Marshal packages a Go value into an Any, for use as input
// to or output from a Dispatch function.
//
// Primitive values (booleans, integers, floats, strings, bytes, timestamps,
// durations) are supported, along with values that implement either
// proto.Message, json.Marshaler, encoding.TextMarshaler or
// encoding.BinaryMarshaler.
func Marshal(v any) (Any, error) {
if rv := reflect.ValueOf(v); rv.Kind() == reflect.Pointer && rv.IsNil() {
return Nil(), nil
}
var m proto.Message
switch vv := v.(type) {
case nil:
m = &emptypb.Empty{}
case proto.Message:
m = vv
case time.Time:
m = timestamppb.New(vv)
case time.Duration:
m = durationpb.New(vv)
case json.Marshaler:
// Obviously not ideal going to bytes, then to any, then
// to structpb.Value! It would be more efficient to use
// a json.Decoder, and/or to use a third-party JSON library.
b, err := vv.MarshalJSON()
if err != nil {
return Any{}, err
}
var v any
if err := json.Unmarshal(b, &v); err != nil {
return Any{}, err
}
m, err = structpb.NewValue(v)
if err != nil {
return Any{}, err
}

case encoding.TextMarshaler:
b, err := vv.MarshalText()
if err != nil {
return Any{}, err
}
m = wrapperspb.String(string(b))
case encoding.BinaryMarshaler:
b, err := vv.MarshalBinary()
if err != nil {
return Any{}, err
}
m = wrapperspb.Bytes(b)
case bool:
m = wrapperspb.Bool(vv)

case int:
m = wrapperspb.Int64(int64(vv))
case int8:
Expand All @@ -90,7 +141,6 @@ func NewAny(v any) (Any, error) {
m = wrapperspb.Int64(int64(vv))
case int64:
m = wrapperspb.Int64(vv)

case uint:
m = wrapperspb.UInt64(uint64(vv))
case uint8:
Expand All @@ -101,26 +151,16 @@ func NewAny(v any) (Any, error) {
m = wrapperspb.UInt64(uint64(vv))
case uint64:
m = wrapperspb.UInt64(uint64(vv))

case float32:
m = wrapperspb.Double(float64(vv))
case float64:
m = wrapperspb.Double(vv)

case string:
m = wrapperspb.String(vv)

case []byte:
m = wrapperspb.Bytes(vv)

case time.Time:
m = timestamppb.New(vv)
case time.Duration:
m = durationpb.New(vv)

default:
// TODO: support more types
return Any{}, fmt.Errorf("unsupported type: %T", v)
return Any{}, fmt.Errorf("cannot serialize %v (%T)", v, v)
}

proto, err := anypb.New(m)
Expand All @@ -131,7 +171,7 @@ func NewAny(v any) (Any, error) {
}

func knownAny(v any) Any {
any, err := NewAny(v)
any, err := Marshal(v)
if err != nil {
panic(err)
}
Expand All @@ -141,6 +181,10 @@ func knownAny(v any) Any {
var (
timeType = reflect.TypeFor[time.Time]()
durationType = reflect.TypeFor[time.Duration]()

jsonUnmarshalerType = reflect.TypeFor[json.Unmarshaler]()
textUnmarshalerType = reflect.TypeFor[encoding.TextUnmarshaler]()
binaryUnmarshalerType = reflect.TypeFor[encoding.BinaryUnmarshaler]()
)

// Unmarshal unmarshals the value.
Expand All @@ -151,21 +195,91 @@ func (a Any) Unmarshal(v any) error {

rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Pointer || rv.IsNil() {
panic("Any.Unmarshal expects a pointer")
panic("Any.Unmarshal expects a pointer to a non-nil object")
}
elem := rv.Elem()

m, err := a.proto.UnmarshalNew()
if err != nil {
return err
}
rm := reflect.ValueOf(m)

switch elem.Type() {
case rm.Type(): // e.g. a proto.Message impl
// Check for an exact match on type (v is a proto.Message).
rm := reflect.ValueOf(m)
if elem.Type() == rm.Type() {
elem.Set(rm)
return nil
}

// Check for:
// - structpb.Value => json.Unmarshaler
// - wrapperspb.StringValue => encoding.TextUnmarshaler
// - wrapperspb.BytesValue => encoding.BinaryUnmarshaler
switch mm := m.(type) {
case *structpb.Value:
var target reflect.Value
if elem.Type().Implements(jsonUnmarshalerType) {
if elem.Kind() == reflect.Pointer && elem.IsNil() {
elem.Set(reflect.New(elem.Type().Elem()))
}
target = elem
} else if rv.Type().Implements(jsonUnmarshalerType) {
target = rv
}
if target != (reflect.Value{}) {
unmarshalJSON := target.MethodByName("UnmarshalJSON")
b, err := mm.MarshalJSON()
if err != nil {
return err
}
res := unmarshalJSON.Call([]reflect.Value{reflect.ValueOf(b)})
if err := res[0].Interface(); err != nil {
return err.(error)
}
return nil
}

case *wrapperspb.StringValue:
var target reflect.Value
if elem.Type().Implements(textUnmarshalerType) {
if elem.Kind() == reflect.Pointer && elem.IsNil() {
elem.Set(reflect.New(elem.Type().Elem()))
}
target = elem
} else if rv.Type().Implements(textUnmarshalerType) {
target = rv
}
if target != (reflect.Value{}) {
unmarshalText := target.MethodByName("UnmarshalText")
b := []byte(mm.Value)
res := unmarshalText.Call([]reflect.Value{reflect.ValueOf(b)})
if err := res[0].Interface(); err != nil {
return err.(error)
}
return nil
}

case *wrapperspb.BytesValue:
var target reflect.Value
if elem.Type().Implements(binaryUnmarshalerType) {
if elem.Kind() == reflect.Pointer && elem.IsNil() {
elem.Set(reflect.New(elem.Type().Elem()))
}
target = elem
} else if rv.Type().Implements(binaryUnmarshalerType) {
target = rv
}
if target != (reflect.Value{}) {
unmarshalBinary := target.MethodByName("UnmarshalBinary")
res := unmarshalBinary.Call([]reflect.Value{reflect.ValueOf(mm.Value)})
if err := res[0].Interface(); err != nil {
return err.(error)
}
return nil
}
}

switch elem.Type() {
case timeType:
v, ok := m.(*timestamppb.Timestamp)
if !ok {
Expand Down Expand Up @@ -249,20 +363,30 @@ func (a Any) Unmarshal(v any) error {
elem.SetString(v.Value)
return nil

default:
// Special case for []byte. Other reflect.Slice values aren't supported at this time.
if elem.Kind() == reflect.Slice && elem.Type().Elem().Kind() == reflect.Uint8 {
v, ok := m.(*wrapperspb.BytesValue)
if !ok {
return fmt.Errorf("cannot unmarshal %T into []byte", m)
case reflect.Interface:
if elem.NumMethod() == 0 {
if _, ok := m.(*emptypb.Empty); ok {
elem.SetZero()
return nil
}
elem.SetBytes(v.Value)
}

case reflect.Pointer:
if _, ok := m.(*emptypb.Empty); ok {
elem.Set(reflect.New(elem.Type()).Elem())
return nil
}

// TODO: support more types
return fmt.Errorf("unsupported type: %v (%v kind)", elem.Type(), elem.Kind())
case reflect.Slice:
if elem.Type().Elem().Kind() == reflect.Uint8 {
if v, ok := m.(*wrapperspb.BytesValue); ok {
elem.SetBytes(v.Value)
return nil
}
}
}

return fmt.Errorf("cannot deserialize %T into %v (%v kind)", m, elem.Type(), elem.Kind())
}

// TypeURL is a URL that uniquely identifies the type of the
Expand All @@ -271,10 +395,9 @@ func (a Any) TypeURL() string {
return a.proto.GetTypeUrl()
}

func (a Any) Format(f fmt.State, verb rune) {
// Implement fmt.Formatter rather than fmt.Stringer
// so that we can use String() to extract the string value.
_, _ = f.Write([]byte(fmt.Sprintf("Any(%s)", a.proto)))
// String is the string representation of the any value.
func (a Any) String() string {
return fmt.Sprintf("Any(%s)", a.proto)
}

// Equal is true if this Any is equal to another.
Expand Down
Loading
Loading