Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Native serialization of JSON-like slices & maps #13

Merged
merged 8 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dispatchhttp/client.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//go:build !durable

package dispatchhttp

import (
Expand Down
2 changes: 2 additions & 0 deletions dispatchhttp/header.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//go:build !durable

package dispatchhttp

import (
Expand Down
2 changes: 2 additions & 0 deletions dispatchhttp/request.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//go:build !durable

package dispatchhttp

import (
Expand Down
2 changes: 2 additions & 0 deletions dispatchhttp/response.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//go:build !durable

package dispatchhttp

import (
Expand Down
156 changes: 153 additions & 3 deletions dispatchproto/any.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,11 @@ func Duration(v time.Duration) Any {
// Primitive values (booleans, integers, floats, strings, bytes, timestamps,
// durations) are supported, along with values that implement either
// proto.Message, json.Marshaler, encoding.TextMarshaler or
// encoding.BinaryMarshaler.
// encoding.BinaryMarshaler. Slices and maps are also supported, as long
// as they are JSON-like in shape.
func Marshal(v any) (Any, error) {
if rv := reflect.ValueOf(v); rv.Kind() == reflect.Pointer && rv.IsNil() {
rv := reflect.ValueOf(v)
if rv.Kind() == reflect.Pointer && rv.IsNil() {
return Nil(), nil
}
var m proto.Message
Expand Down Expand Up @@ -160,7 +162,10 @@ func Marshal(v any) (Any, error) {
case []byte:
m = wrapperspb.Bytes(vv)
default:
return Any{}, fmt.Errorf("cannot serialize %v (%T)", v, v)
var err error
if m, err = newStructpbValue(rv); err != nil {
return Any{}, fmt.Errorf("cannot serialize %v: %w", v, err)
}
}

proto, err := anypb.New(m)
Expand Down Expand Up @@ -386,6 +391,10 @@ func (a Any) Unmarshal(v any) error {
}
}

if s, ok := m.(*structpb.Value); ok {
return fromStructpbValue(elem, s)
}

return fmt.Errorf("cannot deserialize %T into %v (%v kind)", m, elem.Type(), elem.Kind())
}

Expand All @@ -404,3 +413,144 @@ func (a Any) String() string {
func (a Any) Equal(other Any) bool {
return proto.Equal(a.proto, other.proto)
}

func newStructpbValue(rv reflect.Value) (*structpb.Value, error) {
switch rv.Kind() {
case reflect.Bool:
return structpb.NewBoolValue(rv.Bool()), nil
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
n := rv.Int()
f := float64(n)
if int64(f) != n {
return nil, fmt.Errorf("cannot serialize %d as number structpb.Value (%v) without losing information", n, f)
}
return structpb.NewNumberValue(f), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
n := rv.Uint()
f := float64(n)
if uint64(f) != n {
return nil, fmt.Errorf("cannot serialize %d as number structpb.Value (%v) without losing information", n, f)
}
return structpb.NewNumberValue(f), nil
case reflect.Float32, reflect.Float64:
return structpb.NewNumberValue(rv.Float()), nil
case reflect.String:
return structpb.NewStringValue(rv.String()), nil
case reflect.Interface:
if rv.NumMethod() == 0 { // interface{} aka. any
v := rv.Interface()
if v == nil {
return structpb.NewNullValue(), nil
}
return newStructpbValue(reflect.ValueOf(v))
}
case reflect.Slice:
list := &structpb.ListValue{Values: make([]*structpb.Value, rv.Len())}
for i := range list.Values {
elem := rv.Index(i)
var err error
list.Values[i], err = newStructpbValue(elem)
if err != nil {
return nil, err
}
}
return structpb.NewListValue(list), nil
case reflect.Map:
strct := &structpb.Struct{Fields: make(map[string]*structpb.Value, rv.Len())}
iter := rv.MapRange()
for iter.Next() {
k := iter.Key()

var strKey string
var hasStrKey bool
switch k.Kind() {
case reflect.String:
strKey = k.String()
hasStrKey = true
case reflect.Interface:
if s, ok := k.Interface().(string); ok {
strKey = s
hasStrKey = true
}
}
if !hasStrKey {
return nil, fmt.Errorf("cannot serialize map with %s (%s) key", k.Type(), k.Kind())
}

v, err := newStructpbValue(iter.Value())
if err != nil {
return nil, err
}
strct.Fields[strKey] = v
}
return structpb.NewStructValue(strct), nil
}
return nil, fmt.Errorf("not implemented: %s", rv.Type())
}

func fromStructpbValue(rv reflect.Value, s *structpb.Value) error {
switch rv.Kind() {
case reflect.Bool:
if b, ok := s.Kind.(*structpb.Value_BoolValue); ok {
rv.SetBool(b.BoolValue)
return nil
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if n, ok := s.Kind.(*structpb.Value_NumberValue); ok {
rv.SetInt(int64(n.NumberValue))
return nil
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if n, ok := s.Kind.(*structpb.Value_NumberValue); ok {
rv.SetUint(uint64(n.NumberValue))
return nil
}
case reflect.Float32, reflect.Float64:
if n, ok := s.Kind.(*structpb.Value_NumberValue); ok {
rv.SetFloat(n.NumberValue)
return nil
}
case reflect.String:
if str, ok := s.Kind.(*structpb.Value_StringValue); ok {
rv.SetString(str.StringValue)
return nil
}
case reflect.Slice:
if l, ok := s.Kind.(*structpb.Value_ListValue); ok {
values := l.ListValue.GetValues()
rv.Grow(len(values))
rv.SetLen(len(values))
for i, value := range values {
if err := fromStructpbValue(rv.Index(i), value); err != nil {
return err
}
}
return nil
}
case reflect.Map:
if strct, ok := s.Kind.(*structpb.Value_StructValue); ok {
fields := strct.StructValue.Fields
rv.Set(reflect.MakeMapWithSize(rv.Type(), len(fields)))
valueType := rv.Type().Elem()
for key, value := range fields {
mv := reflect.New(valueType).Elem()
if err := fromStructpbValue(mv, value); err != nil {
return err
}
rv.SetMapIndex(reflect.ValueOf(key), mv)
}
return nil
}
case reflect.Interface:
if rv.NumMethod() == 0 { // interface{} aka. any
v := s.AsInterface()
if v == nil {
rv.SetZero()
} else {
rv.Set(reflect.ValueOf(s.AsInterface()))
}
return nil
}
}
return fmt.Errorf("cannot deserialize %T into %v (%v kind)", s, rv.Type(), rv.Kind())
}
15 changes: 15 additions & 0 deletions dispatchproto/any_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"math"
"net/http"
"reflect"
"strings"
"testing"
Expand Down Expand Up @@ -355,6 +356,20 @@ func TestAny(t *testing.T) {
List: []any{nil, false, []any{"foo", "bar"}, map[string]any{"abc": "xyz"}},
Object: map[string]any{"n": 3.14, "flag": true, "tags": []any{"x", "y", "z"}},
}},

// slices
[]string{"foo", "bar"},
[]int{-1, 1, 111},
[]bool{true, false, true},
[]float64{3.14, 1.25},
[][]string{{"foo", "bar"}, {"abc", "xyz"}},
[]any{3.14, true, "x", nil},

// maps
map[string]string{"abc": "xyz", "foo": "bar"},
map[string]int{"n": 3},
map[string]http.Header{"original": {"X-Foo": []string{"bar"}}},
map[any]any{"foo": "bar", "pi": 3.14},
} {
t.Run(fmt.Sprintf("%v", v), func(t *testing.T) {
boxed, err := dispatchproto.Marshal(v)
Expand Down
29 changes: 7 additions & 22 deletions examples/fanout/main.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//go:build !durable

package main

import (
Expand All @@ -12,14 +14,14 @@ import (

func main() {
getRepo := dispatch.Func("getRepo", func(ctx context.Context, name string) (*dispatchhttp.Response, error) {
return dispatchhttp.Get(context.Background(), "https://api.github.com/repos/dispatchrun/"+name)
return dispatchhttp.Get(ctx, "https://api.github.com/repos/dispatchrun/"+name)
})

getStargazers := dispatch.Func("getStargazers", func(ctx context.Context, url string) (*dispatchhttp.Response, error) {
return dispatchhttp.Get(context.Background(), url)
return dispatchhttp.Get(ctx, url)
})

reduceStargazers := dispatch.Func("reduceStargazers", func(ctx context.Context, stargazerURLs strings) (strings, error) {
reduceStargazers := dispatch.Func("reduceStargazers", func(ctx context.Context, stargazerURLs []string) ([]string, error) {
responses, err := getStargazers.Gather(stargazerURLs)
if err != nil {
return nil, err
Expand All @@ -39,7 +41,7 @@ func main() {
return maps.Keys(stargazers), nil
})

fanout := dispatch.Func("fanout", func(ctx context.Context, repoNames strings) (strings, error) {
fanout := dispatch.Func("fanout", func(ctx context.Context, repoNames []string) ([]string, error) {
responses, err := getRepo.Gather(repoNames)
if err != nil {
return nil, err
Expand All @@ -65,7 +67,7 @@ func main() {
}

go func() {
if _, err := fanout.Dispatch(context.Background(), strings{"coroutine", "dispatch-py"}); err != nil {
if _, err := fanout.Dispatch(context.Background(), []string{"coroutine", "dispatch-py"}); err != nil {
log.Fatalf("failed to dispatch call: %v", err)
}
}()
Expand All @@ -74,20 +76,3 @@ func main() {
log.Fatalf("failed to serve endpoint: %v", err)
}
}

// TODO: update dispatchproto.Marshal to support serializing slices/maps
// natively (if they can be sent on the wire as structpb.Value)
type strings []string

func (s strings) MarshalJSON() ([]byte, error) {
return json.Marshal([]string(s))
}

func (s *strings) UnmarshalJSON(b []byte) error {
var c []string
if err := json.Unmarshal(b, &c); err != nil {
return err
}
*s = c
return nil
}
Loading