diff --git a/dispatchproto/any.go b/dispatchproto/any.go index 6da878c..9e5f7ba 100644 --- a/dispatchproto/any.go +++ b/dispatchproto/any.go @@ -3,6 +3,8 @@ package dispatchproto import ( + "encoding" + "encoding/json" "fmt" "reflect" "time" @@ -10,6 +12,8 @@ import ( "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/emptypb" + "google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/wrapperspb" ) @@ -17,6 +21,11 @@ import ( // Any represents any value. type Any struct{ proto *anypb.Any } +// Nil creates an Any that contains nil/null. +func Nil() Any { + return knownAny(&emptypb.Empty{}) +} + // Bool creates an Any that contains a boolean value. func Bool(v bool) Any { return knownAny(wrapperspb.Bool(v)) @@ -70,16 +79,58 @@ func Duration(v time.Duration) Any { return knownAny(durationpb.New(v)) } -// NewAny creates an Any from a proto.Message. -func NewAny(v any) (Any, error) { +// Marshal packages a Go value into an Any, for use as input +// to or output from a Dispatch function. +// +// Primitive values (booleans, integers, floats, strings, bytes, timestamps, +// durations) are supported, along with values that implement either +// proto.Message, json.Marshaler, encoding.TextMarshaler or +// encoding.BinaryMarshaler. +func Marshal(v any) (Any, error) { + if rv := reflect.ValueOf(v); rv.Kind() == reflect.Pointer && rv.IsNil() { + return Nil(), nil + } var m proto.Message switch vv := v.(type) { + case nil: + m = &emptypb.Empty{} case proto.Message: m = vv + case time.Time: + m = timestamppb.New(vv) + case time.Duration: + m = durationpb.New(vv) + case json.Marshaler: + // Obviously not ideal going to bytes, then to any, then + // to structpb.Value! It would be more efficient to use + // a json.Decoder, and/or to use a third-party JSON library. + b, err := vv.MarshalJSON() + if err != nil { + return Any{}, err + } + var v any + if err := json.Unmarshal(b, &v); err != nil { + return Any{}, err + } + m, err = structpb.NewValue(v) + if err != nil { + return Any{}, err + } + case encoding.TextMarshaler: + b, err := vv.MarshalText() + if err != nil { + return Any{}, err + } + m = wrapperspb.String(string(b)) + case encoding.BinaryMarshaler: + b, err := vv.MarshalBinary() + if err != nil { + return Any{}, err + } + m = wrapperspb.Bytes(b) case bool: m = wrapperspb.Bool(vv) - case int: m = wrapperspb.Int64(int64(vv)) case int8: @@ -90,7 +141,6 @@ func NewAny(v any) (Any, error) { m = wrapperspb.Int64(int64(vv)) case int64: m = wrapperspb.Int64(vv) - case uint: m = wrapperspb.UInt64(uint64(vv)) case uint8: @@ -101,26 +151,16 @@ func NewAny(v any) (Any, error) { m = wrapperspb.UInt64(uint64(vv)) case uint64: m = wrapperspb.UInt64(uint64(vv)) - case float32: m = wrapperspb.Double(float64(vv)) case float64: m = wrapperspb.Double(vv) - case string: m = wrapperspb.String(vv) - case []byte: m = wrapperspb.Bytes(vv) - - case time.Time: - m = timestamppb.New(vv) - case time.Duration: - m = durationpb.New(vv) - default: - // TODO: support more types - return Any{}, fmt.Errorf("unsupported type: %T", v) + return Any{}, fmt.Errorf("cannot serialize %v (%T)", v, v) } proto, err := anypb.New(m) @@ -131,7 +171,7 @@ func NewAny(v any) (Any, error) { } func knownAny(v any) Any { - any, err := NewAny(v) + any, err := Marshal(v) if err != nil { panic(err) } @@ -141,6 +181,10 @@ func knownAny(v any) Any { var ( timeType = reflect.TypeFor[time.Time]() durationType = reflect.TypeFor[time.Duration]() + + jsonUnmarshalerType = reflect.TypeFor[json.Unmarshaler]() + textUnmarshalerType = reflect.TypeFor[encoding.TextUnmarshaler]() + binaryUnmarshalerType = reflect.TypeFor[encoding.BinaryUnmarshaler]() ) // Unmarshal unmarshals the value. @@ -151,7 +195,7 @@ func (a Any) Unmarshal(v any) error { rv := reflect.ValueOf(v) if rv.Kind() != reflect.Pointer || rv.IsNil() { - panic("Any.Unmarshal expects a pointer") + panic("Any.Unmarshal expects a pointer to a non-nil object") } elem := rv.Elem() @@ -159,13 +203,83 @@ func (a Any) Unmarshal(v any) error { if err != nil { return err } - rm := reflect.ValueOf(m) - switch elem.Type() { - case rm.Type(): // e.g. a proto.Message impl + // Check for an exact match on type (v is a proto.Message). + rm := reflect.ValueOf(m) + if elem.Type() == rm.Type() { elem.Set(rm) return nil + } + // Check for: + // - structpb.Value => json.Unmarshaler + // - wrapperspb.StringValue => encoding.TextUnmarshaler + // - wrapperspb.BytesValue => encoding.BinaryUnmarshaler + switch mm := m.(type) { + case *structpb.Value: + var target reflect.Value + if elem.Type().Implements(jsonUnmarshalerType) { + if elem.Kind() == reflect.Pointer && elem.IsNil() { + elem.Set(reflect.New(elem.Type().Elem())) + } + target = elem + } else if rv.Type().Implements(jsonUnmarshalerType) { + target = rv + } + if target != (reflect.Value{}) { + unmarshalJSON := target.MethodByName("UnmarshalJSON") + b, err := mm.MarshalJSON() + if err != nil { + return err + } + res := unmarshalJSON.Call([]reflect.Value{reflect.ValueOf(b)}) + if err := res[0].Interface(); err != nil { + return err.(error) + } + return nil + } + + case *wrapperspb.StringValue: + var target reflect.Value + if elem.Type().Implements(textUnmarshalerType) { + if elem.Kind() == reflect.Pointer && elem.IsNil() { + elem.Set(reflect.New(elem.Type().Elem())) + } + target = elem + } else if rv.Type().Implements(textUnmarshalerType) { + target = rv + } + if target != (reflect.Value{}) { + unmarshalText := target.MethodByName("UnmarshalText") + b := []byte(mm.Value) + res := unmarshalText.Call([]reflect.Value{reflect.ValueOf(b)}) + if err := res[0].Interface(); err != nil { + return err.(error) + } + return nil + } + + case *wrapperspb.BytesValue: + var target reflect.Value + if elem.Type().Implements(binaryUnmarshalerType) { + if elem.Kind() == reflect.Pointer && elem.IsNil() { + elem.Set(reflect.New(elem.Type().Elem())) + } + target = elem + } else if rv.Type().Implements(binaryUnmarshalerType) { + target = rv + } + if target != (reflect.Value{}) { + unmarshalBinary := target.MethodByName("UnmarshalBinary") + res := unmarshalBinary.Call([]reflect.Value{reflect.ValueOf(mm.Value)}) + if err := res[0].Interface(); err != nil { + return err.(error) + } + return nil + } + } + + switch elem.Type() { case timeType: v, ok := m.(*timestamppb.Timestamp) if !ok { @@ -249,20 +363,30 @@ func (a Any) Unmarshal(v any) error { elem.SetString(v.Value) return nil - default: - // Special case for []byte. Other reflect.Slice values aren't supported at this time. - if elem.Kind() == reflect.Slice && elem.Type().Elem().Kind() == reflect.Uint8 { - v, ok := m.(*wrapperspb.BytesValue) - if !ok { - return fmt.Errorf("cannot unmarshal %T into []byte", m) + case reflect.Interface: + if elem.NumMethod() == 0 { + if _, ok := m.(*emptypb.Empty); ok { + elem.SetZero() + return nil } - elem.SetBytes(v.Value) + } + + case reflect.Pointer: + if _, ok := m.(*emptypb.Empty); ok { + elem.Set(reflect.New(elem.Type()).Elem()) return nil } - // TODO: support more types - return fmt.Errorf("unsupported type: %v (%v kind)", elem.Type(), elem.Kind()) + case reflect.Slice: + if elem.Type().Elem().Kind() == reflect.Uint8 { + if v, ok := m.(*wrapperspb.BytesValue); ok { + elem.SetBytes(v.Value) + return nil + } + } } + + return fmt.Errorf("cannot deserialize %T into %v (%v kind)", m, elem.Type(), elem.Kind()) } // TypeURL is a URL that uniquely identifies the type of the @@ -271,10 +395,9 @@ func (a Any) TypeURL() string { return a.proto.GetTypeUrl() } -func (a Any) Format(f fmt.State, verb rune) { - // Implement fmt.Formatter rather than fmt.Stringer - // so that we can use String() to extract the string value. - _, _ = f.Write([]byte(fmt.Sprintf("Any(%s)", a.proto))) +// String is the string representation of the any value. +func (a Any) String() string { + return fmt.Sprintf("Any(%s)", a.proto) } // Equal is true if this Any is equal to another. diff --git a/dispatchproto/any_test.go b/dispatchproto/any_test.go index 6158a8e..ed780c4 100644 --- a/dispatchproto/any_test.go +++ b/dispatchproto/any_test.go @@ -2,6 +2,8 @@ package dispatchproto_test import ( "bytes" + "encoding" + "encoding/json" "fmt" "math" "reflect" @@ -9,17 +11,40 @@ import ( "testing" "time" - dispatch "github.com/dispatchrun/dispatch-go/dispatchproto" + "github.com/dispatchrun/dispatch-go/dispatchproto" + "github.com/google/go-cmp/cmp" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/emptypb" + "google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/wrapperspb" ) +func TestAnyNil(t *testing.T) { + boxed := dispatchproto.Nil() + + // Check nil any can be deserialized. + var got any + if err := boxed.Unmarshal(&got); err != nil { + t.Fatal(err) + } else if got != nil { + t.Errorf("unexpected nil: got %v, want %v", got, nil) + } + + // Check null pointers can be deserialized. + now := time.Now() + tp := &now // set to something, then check it gets cleared + if err := boxed.Unmarshal(&tp); err != nil { + t.Fatal(err) + } else if tp != nil { + t.Errorf("unexpected nil: got %v, want %v", tp, nil) + } +} + func TestAnyBool(t *testing.T) { for _, v := range []bool{true, false} { - boxed := dispatch.Bool(v) + boxed := dispatchproto.Bool(v) var got bool if err := boxed.Unmarshal(&got); err != nil { t.Fatal(err) @@ -31,7 +56,7 @@ func TestAnyBool(t *testing.T) { func TestAnyInt(t *testing.T) { for _, v := range []int64{0, 11, -1, 2, math.MinInt, math.MaxInt} { - boxed := dispatch.Int(v) + boxed := dispatchproto.Int(v) var got int64 if err := boxed.Unmarshal(&got); err != nil { t.Fatal(err) @@ -43,7 +68,7 @@ func TestAnyInt(t *testing.T) { func TestAnyUint(t *testing.T) { for _, v := range []uint64{0, 11, 2, math.MaxUint} { - boxed := dispatch.Uint(v) + boxed := dispatchproto.Uint(v) var got uint64 if err := boxed.Unmarshal(&got); err != nil { t.Fatal(err) @@ -55,7 +80,7 @@ func TestAnyUint(t *testing.T) { func TestAnyFloat(t *testing.T) { for _, v := range []float64{0, 3.14, 11.11, math.MaxFloat64} { - boxed := dispatch.Float(v) + boxed := dispatchproto.Float(v) var got float64 if err := boxed.Unmarshal(&got); err != nil { t.Fatal(err) @@ -67,7 +92,7 @@ func TestAnyFloat(t *testing.T) { func TestAnyString(t *testing.T) { for _, v := range []string{"", "x", "foobar", strings.Repeat("abc", 100)} { - boxed := dispatch.String(v) + boxed := dispatchproto.String(v) var got string if err := boxed.Unmarshal(&got); err != nil { t.Fatal(err) @@ -79,7 +104,7 @@ func TestAnyString(t *testing.T) { func TestAnyBytes(t *testing.T) { for _, v := range [][]byte{nil, []byte("foobar"), bytes.Repeat([]byte("abc"), 100)} { - boxed := dispatch.Bytes(v) + boxed := dispatchproto.Bytes(v) var got []byte if err := boxed.Unmarshal(&got); err != nil { t.Fatal(err) @@ -91,7 +116,7 @@ func TestAnyBytes(t *testing.T) { func TestAnyTime(t *testing.T) { for _, v := range []time.Time{time.Now(), { /*zero*/ }, time.Date(2024, time.June, 10, 11, 30, 1, 2, time.UTC)} { - boxed := dispatch.Time(v) + boxed := dispatchproto.Time(v) var got time.Time if err := boxed.Unmarshal(&got); err != nil { t.Fatal(err) @@ -103,7 +128,7 @@ func TestAnyTime(t *testing.T) { func TestAnyDuration(t *testing.T) { for _, v := range []time.Duration{0, time.Second, 10 * time.Hour} { - boxed := dispatch.Duration(v) + boxed := dispatchproto.Duration(v) var got time.Duration if err := boxed.Unmarshal(&got); err != nil { t.Fatal(err) @@ -113,50 +138,156 @@ func TestAnyDuration(t *testing.T) { } } +func TestAnyTextMarshaler(t *testing.T) { + v := &textMarshaler{Value: "foobar"} + boxed, err := dispatchproto.Marshal(v) + if err != nil { + t.Fatal(err) + } + + var v2 *textMarshaler // (pointer) + if err := boxed.Unmarshal(&v2); err != nil { + t.Fatal(err) + } else if v2.Value != v.Value { + t.Errorf("unexpected serialized value: %v", v2.Value) + } + + var v3 textMarshaler // (not a pointer) + if err := boxed.Unmarshal(&v3); err != nil { + t.Fatal(err) + } else if v3.Value != v.Value { + t.Errorf("unexpected serialized value: %v", v3.Value) + } + + // Check a string is sent on the wire. + var v4 string + if err := boxed.Unmarshal(&v4); err != nil { + t.Fatal(err) + } else if v4 != v.Value { + t.Errorf("unexpected serialized value: %v", v4) + } +} + +func TestAnyBinaryMarshaler(t *testing.T) { + v := &binaryMarshaler{Value: []byte("foobar")} + boxed, err := dispatchproto.Marshal(v) + if err != nil { + t.Fatal(err) + } + + var v2 *binaryMarshaler // (pointer) + if err := boxed.Unmarshal(&v2); err != nil { + t.Fatal(err) + } else if !bytes.Equal(v2.Value, v.Value) { + t.Errorf("unexpected serialized value: %v", v2.Value) + } + + var v3 binaryMarshaler // (not a pointer) + if err := boxed.Unmarshal(&v3); err != nil { + t.Fatal(err) + } else if !bytes.Equal(v3.Value, v.Value) { + t.Errorf("unexpected serialized value: %v", v3.Value) + } + + // Check bytes are sent on the wire. + var v4 []byte + if err := boxed.Unmarshal(&v4); err != nil { + t.Fatal(err) + } else if !bytes.Equal(v4, v.Value) { + t.Errorf("unexpected serialized value: %v", v4) + } +} + +func TestAnyJsonMarshaler(t *testing.T) { + v := &jsonMarshaler{Value: jsonValue{ + Bool: true, + Int: 11, + Float: 3.14, + String: "foo", + List: []any{nil, false, []any{"foo", "bar"}, map[string]any{"abc": "xyz"}}, + Object: map[string]any{"n": 3.14, "flag": true, "tags": []any{"x", "y", "z"}}, + }} + boxed, err := dispatchproto.Marshal(v) + if err != nil { + t.Fatal(err) + } + + var v2 *jsonMarshaler // (pointer) + if err := boxed.Unmarshal(&v2); err != nil { + t.Fatal(err) + } else if diff := cmp.Diff(v2.Value, v.Value); diff != "" { + t.Errorf("unexpected serialized value: %v", diff) + } + + var v3 *jsonMarshaler // (not a pointer) + if err := boxed.Unmarshal(&v3); err != nil { + t.Fatal(err) + } else if diff := cmp.Diff(v3.Value, v.Value); diff != "" { + t.Errorf("unexpected serialized value: %v", diff) + } + + // Check a structpb.Value is sent on the wire. + var v4 *structpb.Value + want := map[string]any{ + "null": nil, + "bool": true, + "int": float64(11), // (there's only one "number" type) + "float": 3.14, + "string": "foo", + "list": []any{nil, false, []any{"foo", "bar"}, map[string]any{"abc": "xyz"}}, + "object": map[string]any{"n": 3.14, "flag": true, "tags": []any{"x", "y", "z"}}, + } + if err := boxed.Unmarshal(&v4); err != nil { + t.Fatal(err) + } else if diff := cmp.Diff(v4.AsInterface(), want); diff != "" { + t.Errorf("unexpected serialized value: %v", diff) + } +} + func TestOverflow(t *testing.T) { var i8 int8 - if err := dispatch.Int(math.MinInt8 - 1).Unmarshal(&i8); err == nil || err.Error() != "cannot unmarshal *wrapperspb.Int64Value of -129 into int8" { + if err := dispatchproto.Int(math.MinInt8 - 1).Unmarshal(&i8); err == nil || err.Error() != "cannot unmarshal *wrapperspb.Int64Value of -129 into int8" { t.Errorf("unexpected error: %v", err) } - if err := dispatch.Int(math.MaxInt8 + 1).Unmarshal(&i8); err == nil || err.Error() != "cannot unmarshal *wrapperspb.Int64Value of 128 into int8" { + if err := dispatchproto.Int(math.MaxInt8 + 1).Unmarshal(&i8); err == nil || err.Error() != "cannot unmarshal *wrapperspb.Int64Value of 128 into int8" { t.Errorf("unexpected error: %v", err) } var i16 int16 - if err := dispatch.Int(math.MinInt16 - 1).Unmarshal(&i16); err == nil || err.Error() != "cannot unmarshal *wrapperspb.Int64Value of -32769 into int16" { + if err := dispatchproto.Int(math.MinInt16 - 1).Unmarshal(&i16); err == nil || err.Error() != "cannot unmarshal *wrapperspb.Int64Value of -32769 into int16" { t.Errorf("unexpected error: %v", err) } - if err := dispatch.Int(math.MaxInt16 + 1).Unmarshal(&i16); err == nil || err.Error() != "cannot unmarshal *wrapperspb.Int64Value of 32768 into int16" { + if err := dispatchproto.Int(math.MaxInt16 + 1).Unmarshal(&i16); err == nil || err.Error() != "cannot unmarshal *wrapperspb.Int64Value of 32768 into int16" { t.Errorf("unexpected error: %v", err) } var i32 int32 - if err := dispatch.Int(math.MinInt32 - 1).Unmarshal(&i32); err == nil || err.Error() != "cannot unmarshal *wrapperspb.Int64Value of -2147483649 into int32" { + if err := dispatchproto.Int(math.MinInt32 - 1).Unmarshal(&i32); err == nil || err.Error() != "cannot unmarshal *wrapperspb.Int64Value of -2147483649 into int32" { t.Errorf("unexpected error: %v", err) } - if err := dispatch.Int(math.MaxInt32 + 1).Unmarshal(&i32); err == nil || err.Error() != "cannot unmarshal *wrapperspb.Int64Value of 2147483648 into int32" { + if err := dispatchproto.Int(math.MaxInt32 + 1).Unmarshal(&i32); err == nil || err.Error() != "cannot unmarshal *wrapperspb.Int64Value of 2147483648 into int32" { t.Errorf("unexpected error: %v", err) } var u8 uint8 - if err := dispatch.Uint(math.MaxUint8 + 1).Unmarshal(&u8); err == nil || err.Error() != "cannot unmarshal *wrapperspb.UInt64Value of 256 into uint8" { + if err := dispatchproto.Uint(math.MaxUint8 + 1).Unmarshal(&u8); err == nil || err.Error() != "cannot unmarshal *wrapperspb.UInt64Value of 256 into uint8" { t.Errorf("unexpected error: %v", err) } var u16 uint16 - if err := dispatch.Uint(math.MaxUint16 + 1).Unmarshal(&u16); err == nil || err.Error() != "cannot unmarshal *wrapperspb.UInt64Value of 65536 into uint16" { + if err := dispatchproto.Uint(math.MaxUint16 + 1).Unmarshal(&u16); err == nil || err.Error() != "cannot unmarshal *wrapperspb.UInt64Value of 65536 into uint16" { t.Errorf("unexpected error: %v", err) } var u32 uint32 - if err := dispatch.Uint(math.MaxUint32 + 1).Unmarshal(&u32); err == nil || err.Error() != "cannot unmarshal *wrapperspb.UInt64Value of 4294967296 into uint32" { + if err := dispatchproto.Uint(math.MaxUint32 + 1).Unmarshal(&u32); err == nil || err.Error() != "cannot unmarshal *wrapperspb.UInt64Value of 4294967296 into uint32" { t.Errorf("unexpected error: %v", err) } var f32 float32 - if err := dispatch.Float(math.MaxFloat32 + math.MaxFloat32).Unmarshal(&f32); err == nil || err.Error() != "cannot unmarshal *wrapperspb.DoubleValue of 6.805646932770577e+38 into float32" { + if err := dispatchproto.Float(math.MaxFloat32 + math.MaxFloat32).Unmarshal(&f32); err == nil || err.Error() != "cannot unmarshal *wrapperspb.DoubleValue of 6.805646932770577e+38 into float32" { t.Errorf("unexpected error: %v", err) } - badTime, err := dispatch.NewAny(×tamppb.Timestamp{Seconds: math.MinInt64}) + badTime, err := dispatchproto.Marshal(×tamppb.Timestamp{Seconds: math.MinInt64}) if err != nil { t.Fatal(err) } @@ -165,7 +296,7 @@ func TestOverflow(t *testing.T) { t.Error("expected an error") } - badDuration, err := dispatch.NewAny(&durationpb.Duration{Seconds: math.MaxInt64}) + badDuration, err := dispatchproto.Marshal(&durationpb.Duration{Seconds: math.MaxInt64}) if err != nil { t.Fatal(err) } @@ -177,6 +308,9 @@ func TestOverflow(t *testing.T) { func TestAny(t *testing.T) { for _, v := range []any{ + nil, + (*time.Time)(nil), + true, false, @@ -207,30 +341,103 @@ func TestAny(t *testing.T) { // Raw proto.Message &emptypb.Empty{}, &wrapperspb.Int32Value{Value: 11}, + + // encoding.{Text,Binary}Marshaler + &textMarshaler{Value: "foobar"}, + &binaryMarshaler{Value: []byte("foobar")}, + + // json.Marshaler + &jsonMarshaler{Value: jsonValue{ + Bool: true, + Int: 11, + Float: 3.14, + String: "foo", + List: []any{nil, false, []any{"foo", "bar"}, map[string]any{"abc": "xyz"}}, + Object: map[string]any{"n": 3.14, "flag": true, "tags": []any{"x", "y", "z"}}, + }}, } { t.Run(fmt.Sprintf("%v", v), func(t *testing.T) { - boxed, err := dispatch.NewAny(v) + boxed, err := dispatchproto.Marshal(v) if err != nil { - t.Fatalf("NewAny(%v): %v", v, err) + t.Fatalf("Marshal(%v): %v", v, err) } - rv := reflect.New(reflect.TypeOf(v)) + var rt reflect.Type + if v == nil { + rt = reflect.ValueOf(&v).Elem().Type() + } else { + rt = reflect.ValueOf(v).Type() + } + rv := reflect.New(rt) if err := boxed.Unmarshal(rv.Interface()); err != nil { t.Fatal(err) } got := rv.Elem().Interface() - want := reflect.ValueOf(v).Interface() - var equal bool - if wantProto, ok := want.(proto.Message); ok { - equal = proto.Equal(got.(proto.Message), wantProto) - } else { - equal = reflect.DeepEqual(got, want) + var want any + if v != nil { + want = reflect.ValueOf(v).Interface() } - if !equal { - t.Errorf("unexpected NewAny(%v).Unmarshal result: %#v", v, got) + + if wantProto, ok := want.(proto.Message); ok { + if equal := proto.Equal(got.(proto.Message), wantProto); !equal { + t.Errorf("unexpected Marshal(%v).Unmarshal result: %#v", v, got) + } + } else if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("unexpected Marshal(%v).Unmarshal result: %v", v, diff) } }) } } + +type textMarshaler struct{ Value string } + +func (t *textMarshaler) MarshalText() ([]byte, error) { + return []byte(t.Value), nil +} + +func (t *textMarshaler) UnmarshalText(b []byte) error { + t.Value = string(b) + return nil +} + +var _ encoding.TextMarshaler = (*textMarshaler)(nil) +var _ encoding.TextUnmarshaler = (*textMarshaler)(nil) + +type binaryMarshaler struct{ Value []byte } + +func (t *binaryMarshaler) MarshalBinary() ([]byte, error) { + return t.Value, nil +} + +func (t *binaryMarshaler) UnmarshalBinary(b []byte) error { + t.Value = b + return nil +} + +var _ encoding.BinaryMarshaler = (*binaryMarshaler)(nil) +var _ encoding.BinaryUnmarshaler = (*binaryMarshaler)(nil) + +type jsonMarshaler struct{ Value jsonValue } + +type jsonValue struct { + Null *any `json:"null"` + Bool bool `json:"bool"` + String string `json:"string"` + Int int64 `json:"int"` + Float float64 `json:"float"` + List []any `json:"list"` + Object map[string]any `json:"object"` +} + +func (j *jsonMarshaler) MarshalJSON() ([]byte, error) { + return json.Marshal(j.Value) +} + +func (j *jsonMarshaler) UnmarshalJSON(b []byte) error { + return json.Unmarshal(b, &j.Value) +} + +var _ json.Marshaler = (*jsonMarshaler)(nil) +var _ json.Unmarshaler = (*jsonMarshaler)(nil) diff --git a/function.go b/function.go index 347580e..7e04ee8 100644 --- a/function.go +++ b/function.go @@ -35,7 +35,7 @@ func (f *Function[I, O]) Name() string { // BuildCall creates (but does not dispatch) a Call for the function. func (f *Function[I, O]) BuildCall(input I, opts ...dispatchproto.CallOption) (dispatchproto.Call, error) { - boxedInput, err := dispatchproto.NewAny(input) + boxedInput, err := dispatchproto.Marshal(input) if err != nil { return dispatchproto.Call{}, fmt.Errorf("cannot serialize input: %v", err) } @@ -145,7 +145,7 @@ func (f *Function[I, O]) tearDown(id dispatchcoro.InstanceID, coro dispatchcoro. func (f *Function[I, O]) serialize(id dispatchcoro.InstanceID, coro dispatchcoro.Coroutine) (dispatchproto.Any, error) { // In volatile mode, serialize a reference to the coroutine instance. if !coroutine.Durable { - return dispatchproto.NewAny(id) + return dispatchproto.Marshal(id) } // In durable mode, serialize the state of the coroutine. @@ -198,7 +198,7 @@ func (c *Function[I, O]) entrypoint(input I) func() dispatchproto.Response { // TODO: include output if not nil return dispatchproto.NewResponseError(err) } - boxedOutput, err := dispatchproto.NewAny(output) + boxedOutput, err := dispatchproto.Marshal(output) if err != nil { return dispatchproto.NewResponseErrorf("%w: invalid output %v: %v", ErrInvalidResponse, output, err) } diff --git a/go.mod b/go.mod index 0b52a2c..90c1c13 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,7 @@ require ( github.com/bufbuild/protovalidate-go v0.6.2 // indirect github.com/dunglas/httpsfv v1.0.2 // indirect github.com/google/cel-go v0.20.1 // indirect + github.com/google/go-cmp v0.6.0 // indirect github.com/stoewer/go-strcase v1.3.0 // indirect golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 // indirect golang.org/x/text v0.16.0 // indirect