diff --git a/dispatchhttp/client.go b/dispatchhttp/client.go index 55de191..7254466 100644 --- a/dispatchhttp/client.go +++ b/dispatchhttp/client.go @@ -1,3 +1,5 @@ +//go:build !durable + package dispatchhttp import ( diff --git a/dispatchhttp/header.go b/dispatchhttp/header.go index 1aa4069..9a0a258 100644 --- a/dispatchhttp/header.go +++ b/dispatchhttp/header.go @@ -1,3 +1,5 @@ +//go:build !durable + package dispatchhttp import ( diff --git a/dispatchhttp/request.go b/dispatchhttp/request.go index 247871c..702a149 100644 --- a/dispatchhttp/request.go +++ b/dispatchhttp/request.go @@ -1,3 +1,5 @@ +//go:build !durable + package dispatchhttp import ( diff --git a/dispatchhttp/response.go b/dispatchhttp/response.go index 4f5fe7b..9b4e827 100644 --- a/dispatchhttp/response.go +++ b/dispatchhttp/response.go @@ -1,3 +1,5 @@ +//go:build !durable + package dispatchhttp import ( diff --git a/dispatchproto/any.go b/dispatchproto/any.go index 9e5f7ba..1b6b7a8 100644 --- a/dispatchproto/any.go +++ b/dispatchproto/any.go @@ -85,9 +85,11 @@ func Duration(v time.Duration) Any { // Primitive values (booleans, integers, floats, strings, bytes, timestamps, // durations) are supported, along with values that implement either // proto.Message, json.Marshaler, encoding.TextMarshaler or -// encoding.BinaryMarshaler. +// encoding.BinaryMarshaler. Slices and maps are also supported, as long +// as they are JSON-like in shape. func Marshal(v any) (Any, error) { - if rv := reflect.ValueOf(v); rv.Kind() == reflect.Pointer && rv.IsNil() { + rv := reflect.ValueOf(v) + if rv.Kind() == reflect.Pointer && rv.IsNil() { return Nil(), nil } var m proto.Message @@ -160,7 +162,10 @@ func Marshal(v any) (Any, error) { case []byte: m = wrapperspb.Bytes(vv) default: - return Any{}, fmt.Errorf("cannot serialize %v (%T)", v, v) + var err error + if m, err = newStructpbValue(rv); err != nil { + return Any{}, fmt.Errorf("cannot serialize %v: %w", v, err) + } } proto, err := anypb.New(m) @@ -386,6 +391,10 @@ func (a Any) Unmarshal(v any) error { } } + if s, ok := m.(*structpb.Value); ok { + return fromStructpbValue(elem, s) + } + return fmt.Errorf("cannot deserialize %T into %v (%v kind)", m, elem.Type(), elem.Kind()) } @@ -404,3 +413,144 @@ func (a Any) String() string { func (a Any) Equal(other Any) bool { return proto.Equal(a.proto, other.proto) } + +func newStructpbValue(rv reflect.Value) (*structpb.Value, error) { + switch rv.Kind() { + case reflect.Bool: + return structpb.NewBoolValue(rv.Bool()), nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + n := rv.Int() + f := float64(n) + if int64(f) != n { + return nil, fmt.Errorf("cannot serialize %d as number structpb.Value (%v) without losing information", n, f) + } + return structpb.NewNumberValue(f), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + n := rv.Uint() + f := float64(n) + if uint64(f) != n { + return nil, fmt.Errorf("cannot serialize %d as number structpb.Value (%v) without losing information", n, f) + } + return structpb.NewNumberValue(f), nil + case reflect.Float32, reflect.Float64: + return structpb.NewNumberValue(rv.Float()), nil + case reflect.String: + return structpb.NewStringValue(rv.String()), nil + case reflect.Interface: + if rv.NumMethod() == 0 { // interface{} aka. any + v := rv.Interface() + if v == nil { + return structpb.NewNullValue(), nil + } + return newStructpbValue(reflect.ValueOf(v)) + } + case reflect.Slice: + list := &structpb.ListValue{Values: make([]*structpb.Value, rv.Len())} + for i := range list.Values { + elem := rv.Index(i) + var err error + list.Values[i], err = newStructpbValue(elem) + if err != nil { + return nil, err + } + } + return structpb.NewListValue(list), nil + case reflect.Map: + strct := &structpb.Struct{Fields: make(map[string]*structpb.Value, rv.Len())} + iter := rv.MapRange() + for iter.Next() { + k := iter.Key() + + var strKey string + var hasStrKey bool + switch k.Kind() { + case reflect.String: + strKey = k.String() + hasStrKey = true + case reflect.Interface: + if s, ok := k.Interface().(string); ok { + strKey = s + hasStrKey = true + } + } + if !hasStrKey { + return nil, fmt.Errorf("cannot serialize map with %s (%s) key", k.Type(), k.Kind()) + } + + v, err := newStructpbValue(iter.Value()) + if err != nil { + return nil, err + } + strct.Fields[strKey] = v + } + return structpb.NewStructValue(strct), nil + } + return nil, fmt.Errorf("not implemented: %s", rv.Type()) +} + +func fromStructpbValue(rv reflect.Value, s *structpb.Value) error { + switch rv.Kind() { + case reflect.Bool: + if b, ok := s.Kind.(*structpb.Value_BoolValue); ok { + rv.SetBool(b.BoolValue) + return nil + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if n, ok := s.Kind.(*structpb.Value_NumberValue); ok { + rv.SetInt(int64(n.NumberValue)) + return nil + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if n, ok := s.Kind.(*structpb.Value_NumberValue); ok { + rv.SetUint(uint64(n.NumberValue)) + return nil + } + case reflect.Float32, reflect.Float64: + if n, ok := s.Kind.(*structpb.Value_NumberValue); ok { + rv.SetFloat(n.NumberValue) + return nil + } + case reflect.String: + if str, ok := s.Kind.(*structpb.Value_StringValue); ok { + rv.SetString(str.StringValue) + return nil + } + case reflect.Slice: + if l, ok := s.Kind.(*structpb.Value_ListValue); ok { + values := l.ListValue.GetValues() + rv.Grow(len(values)) + rv.SetLen(len(values)) + for i, value := range values { + if err := fromStructpbValue(rv.Index(i), value); err != nil { + return err + } + } + return nil + } + case reflect.Map: + if strct, ok := s.Kind.(*structpb.Value_StructValue); ok { + fields := strct.StructValue.Fields + rv.Set(reflect.MakeMapWithSize(rv.Type(), len(fields))) + valueType := rv.Type().Elem() + for key, value := range fields { + mv := reflect.New(valueType).Elem() + if err := fromStructpbValue(mv, value); err != nil { + return err + } + rv.SetMapIndex(reflect.ValueOf(key), mv) + } + return nil + } + case reflect.Interface: + if rv.NumMethod() == 0 { // interface{} aka. any + v := s.AsInterface() + if v == nil { + rv.SetZero() + } else { + rv.Set(reflect.ValueOf(s.AsInterface())) + } + return nil + } + } + return fmt.Errorf("cannot deserialize %T into %v (%v kind)", s, rv.Type(), rv.Kind()) +} diff --git a/dispatchproto/any_test.go b/dispatchproto/any_test.go index ed780c4..482cffe 100644 --- a/dispatchproto/any_test.go +++ b/dispatchproto/any_test.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "math" + "net/http" "reflect" "strings" "testing" @@ -355,6 +356,20 @@ func TestAny(t *testing.T) { List: []any{nil, false, []any{"foo", "bar"}, map[string]any{"abc": "xyz"}}, Object: map[string]any{"n": 3.14, "flag": true, "tags": []any{"x", "y", "z"}}, }}, + + // slices + []string{"foo", "bar"}, + []int{-1, 1, 111}, + []bool{true, false, true}, + []float64{3.14, 1.25}, + [][]string{{"foo", "bar"}, {"abc", "xyz"}}, + []any{3.14, true, "x", nil}, + + // maps + map[string]string{"abc": "xyz", "foo": "bar"}, + map[string]int{"n": 3}, + map[string]http.Header{"original": {"X-Foo": []string{"bar"}}}, + map[any]any{"foo": "bar", "pi": 3.14}, } { t.Run(fmt.Sprintf("%v", v), func(t *testing.T) { boxed, err := dispatchproto.Marshal(v) diff --git a/examples/fanout/main.go b/examples/fanout/main.go index f222ee3..99291b2 100644 --- a/examples/fanout/main.go +++ b/examples/fanout/main.go @@ -1,3 +1,5 @@ +//go:build !durable + package main import ( @@ -12,14 +14,14 @@ import ( func main() { getRepo := dispatch.Func("getRepo", func(ctx context.Context, name string) (*dispatchhttp.Response, error) { - return dispatchhttp.Get(context.Background(), "https://api.github.com/repos/dispatchrun/"+name) + return dispatchhttp.Get(ctx, "https://api.github.com/repos/dispatchrun/"+name) }) getStargazers := dispatch.Func("getStargazers", func(ctx context.Context, url string) (*dispatchhttp.Response, error) { - return dispatchhttp.Get(context.Background(), url) + return dispatchhttp.Get(ctx, url) }) - reduceStargazers := dispatch.Func("reduceStargazers", func(ctx context.Context, stargazerURLs strings) (strings, error) { + reduceStargazers := dispatch.Func("reduceStargazers", func(ctx context.Context, stargazerURLs []string) ([]string, error) { responses, err := getStargazers.Gather(stargazerURLs) if err != nil { return nil, err @@ -39,7 +41,7 @@ func main() { return maps.Keys(stargazers), nil }) - fanout := dispatch.Func("fanout", func(ctx context.Context, repoNames strings) (strings, error) { + fanout := dispatch.Func("fanout", func(ctx context.Context, repoNames []string) ([]string, error) { responses, err := getRepo.Gather(repoNames) if err != nil { return nil, err @@ -65,7 +67,7 @@ func main() { } go func() { - if _, err := fanout.Dispatch(context.Background(), strings{"coroutine", "dispatch-py"}); err != nil { + if _, err := fanout.Dispatch(context.Background(), []string{"coroutine", "dispatch-py"}); err != nil { log.Fatalf("failed to dispatch call: %v", err) } }() @@ -74,20 +76,3 @@ func main() { log.Fatalf("failed to serve endpoint: %v", err) } } - -// TODO: update dispatchproto.Marshal to support serializing slices/maps -// natively (if they can be sent on the wire as structpb.Value) -type strings []string - -func (s strings) MarshalJSON() ([]byte, error) { - return json.Marshal([]string(s)) -} - -func (s *strings) UnmarshalJSON(b []byte) error { - var c []string - if err := json.Unmarshal(b, &c); err != nil { - return err - } - *s = c - return nil -}