diff --git a/go-runtime/encoding/encoding.go b/go-runtime/encoding/encoding.go index 819a68c21..6b7a07559 100644 --- a/go-runtime/encoding/encoding.go +++ b/go-runtime/encoding/encoding.go @@ -4,22 +4,18 @@ package encoding import ( "bytes" - "encoding" "encoding/base64" "encoding/json" "fmt" + "io" "reflect" + "strings" + "time" + "unicode" "github.com/TBD54566975/ftl/backend/schema/strcase" ) -var ( - textMarshaler = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() - textUnmarshaler = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() - jsonMarshaler = reflect.TypeOf((*json.Marshaler)(nil)).Elem() - jsonUnmarshaler = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() -) - func Marshal(v any) ([]byte, error) { w := &bytes.Buffer{} err := encodeValue(reflect.ValueOf(v), w) @@ -31,37 +27,29 @@ func encodeValue(v reflect.Value, w *bytes.Buffer) error { w.WriteString("null") return nil } + t := v.Type() - switch { - case t.Kind() == reflect.Ptr && t.Elem().Implements(jsonMarshaler): - v = v.Elem() - fallthrough - case t.Implements(jsonMarshaler): - enc := v.Interface().(json.Marshaler) //nolint:forcetypeassert - data, err := enc.MarshalJSON() + // Special-cased types + switch { + case t == reflect.TypeFor[time.Time](): + data, err := json.Marshal(v.Interface().(time.Time)) if err != nil { return err } w.Write(data) return nil - case t.Kind() == reflect.Ptr && t.Elem().Implements(textMarshaler): - v = v.Elem() - fallthrough - - case t.Implements(textMarshaler): - enc := v.Interface().(encoding.TextMarshaler) //nolint:forcetypeassert - data, err := enc.MarshalText() - if err != nil { - return err - } - data, err = json.Marshal(string(data)) + case t == reflect.TypeFor[json.RawMessage](): + data, err := json.Marshal(v.Interface().(json.RawMessage)) if err != nil { return err } w.Write(data) return nil + + case isOption(v.Type()): + return encodeOption(v, w) } switch v.Kind() { @@ -107,6 +95,24 @@ func encodeValue(v reflect.Value, w *bytes.Buffer) error { } } +var ftlOptionTypePath = "github.com/TBD54566975/ftl/go-runtime/ftl.Option" + +func isOption(t reflect.Type) bool { + return strings.HasPrefix(t.PkgPath()+"."+t.Name(), ftlOptionTypePath) +} + +func encodeOption(v reflect.Value, w *bytes.Buffer) error { + if v.NumField() != 2 { + return fmt.Errorf("value cannot have type ftl.Option since it has %d fields rather than 2: %v", v.NumField(), v) + } + optionOk := v.Field(1).Bool() + if !optionOk { + w.WriteString("null") + return nil + } + return encodeValue(v.Field(0), w) +} + func encodeStruct(v reflect.Value, w *bytes.Buffer) error { w.WriteRune('{') afterFirst := false @@ -213,36 +219,18 @@ func Unmarshal(data []byte, v any) error { func decodeValue(d *json.Decoder, v reflect.Value) error { if !v.CanSet() { - return fmt.Errorf("cannot set value") + allBytes, _ := io.ReadAll(d.Buffered()) + return fmt.Errorf("cannot set value: %v", string(allBytes)) } t := v.Type() - switch { - case v.Kind() != reflect.Ptr && v.CanAddr() && v.Addr().Type().Implements(jsonUnmarshaler): - v = v.Addr() - fallthrough - - case t.Implements(jsonUnmarshaler): - if v.IsNil() { - v.Set(reflect.New(t.Elem())) - } - o := v.Interface() - return d.Decode(&o) - - case v.Kind() != reflect.Ptr && v.CanAddr() && v.Addr().Type().Implements(textUnmarshaler): - v = v.Addr() - fallthrough - case t.Implements(textUnmarshaler): - if v.IsNil() { - v.Set(reflect.New(t.Elem())) - } - dec := v.Interface().(encoding.TextUnmarshaler) //nolint:forcetypeassert - var s string - if err := d.Decode(&s); err != nil { - return err - } - return dec.UnmarshalText([]byte(s)) + // Special-case types + switch { + case t == reflect.TypeFor[time.Time](): + return d.Decode(v.Addr().Interface()) + case isOption(v.Type()): + return decodeOption(d, v) } switch v.Kind() { @@ -250,13 +238,15 @@ func decodeValue(d *json.Decoder, v reflect.Value) error { return decodeStruct(d, v) case reflect.Ptr: - if token, err := d.Token(); err != nil { - return err - } else if token == nil { + return handleIfNextTokenIsNull(d, func(d *json.Decoder) error { v.Set(reflect.Zero(v.Type())) return nil - } - return decodeValue(d, v.Elem()) + }, func(d *json.Decoder) error { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + return decodeValue(d, v.Elem()) + }) case reflect.Slice: if v.Type().Elem().Kind() == reflect.Uint8 { @@ -278,6 +268,63 @@ func decodeValue(d *json.Decoder, v reflect.Value) error { } } +func handleIfNextTokenIsNull(d *json.Decoder, ifNullFn func(*json.Decoder) error, elseFn func(*json.Decoder) error) error { + isNull, err := isNextTokenNull(d) + if err != nil { + return err + } + if isNull { + err = ifNullFn(d) + if err != nil { + return err + } + // Consume the null token + _, err := d.Token() + if err != nil { + return err + } + return nil + } + return elseFn(d) +} + +// isNextTokenNull implements a cheap/dirty version of `Peek()`, which json.Decoder does +// not support. +func isNextTokenNull(d *json.Decoder) (bool, error) { + s, err := io.ReadAll(d.Buffered()) + if err != nil { + return false, err + } + if len(s) == 0 { + return false, fmt.Errorf("cannot check emptystring for token \"null\"") + } + if s[0] != ':' { + return false, fmt.Errorf("cannot check emptystring for token \"null\"") + } + i := 1 + for len(s) > i && unicode.IsSpace(rune(s[i])) { + i++ + } + if len(s) < i+4 { + return false, nil + } + return string(s[i:i+4]) == "null", nil +} + +func decodeOption(d *json.Decoder, v reflect.Value) error { + return handleIfNextTokenIsNull(d, func(d *json.Decoder) error { + v.FieldByName("Okay").SetBool(false) + return nil + }, func(d *json.Decoder) error { + err := decodeValue(d, v.FieldByName("Val")) + if err != nil { + return err + } + v.FieldByName("Okay").SetBool(true) + return nil + }) +} + func decodeStruct(d *json.Decoder, v reflect.Value) error { if err := expectDelim(d, '{'); err != nil { return err diff --git a/go-runtime/encoding/encoding_test.go b/go-runtime/encoding/encoding_test.go index 4fc6105fe..346ea921b 100644 --- a/go-runtime/encoding/encoding_test.go +++ b/go-runtime/encoding/encoding_test.go @@ -3,6 +3,7 @@ package encoding_test import ( "reflect" "testing" + "time" "github.com/alecthomas/assert/v2" @@ -31,6 +32,8 @@ func TestMarshal(t *testing.T) { {name: "SliceOfStrings", input: struct{ Slice []string }{[]string{"hello", "world"}}, expected: `{"slice":["hello","world"]}`}, {name: "Map", input: struct{ Map map[string]int }{map[string]int{"foo": 42}}, expected: `{"map":{"foo":42}}`}, {name: "Option", input: struct{ Option ftl.Option[int] }{ftl.Some(42)}, expected: `{"option":42}`}, + {name: "OptionNull", input: struct{ Option ftl.Option[int] }{ftl.None[int]()}, expected: `{"option":null}`}, + {name: "OptionZero", input: struct{ Option ftl.Option[int] }{ftl.Some(0)}, expected: `{"option":0}`}, {name: "OptionPtr", input: struct{ Option *ftl.Option[int] }{&somePtr}, expected: `{"option":42}`}, {name: "OptionStruct", input: struct{ Option ftl.Option[inner] }{ftl.Some(inner{"foo"})}, expected: `{"option":{"fooBar":"foo"}}`}, {name: "Unit", input: ftl.Unit{}, expected: `{}`}, @@ -69,6 +72,9 @@ func TestUnmarshal(t *testing.T) { {name: "Slice", input: `{"slice":[1,2,3]}`, expected: struct{ Slice []int }{[]int{1, 2, 3}}}, {name: "SliceOfStrings", input: `{"slice":["hello","world"]}`, expected: struct{ Slice []string }{[]string{"hello", "world"}}}, {name: "Map", input: `{"map":{"foo":42}}`, expected: struct{ Map map[string]int }{map[string]int{"foo": 42}}}, + {name: "OptionNull", input: `{"option":null}`, expected: struct{ Option ftl.Option[int] }{ftl.None[int]()}}, + {name: "OptionNullWhitespace", input: `{"option": null}`, expected: struct{ Option ftl.Option[int] }{ftl.None[int]()}}, + {name: "OptionZero", input: `{"option":0}`, expected: struct{ Option ftl.Option[int] }{ftl.Some(0)}}, {name: "Option", input: `{"option":42}`, expected: struct{ Option ftl.Option[int] }{ftl.Some(42)}}, {name: "OptionPtr", input: `{"option":42}`, expected: struct{ Option *ftl.Option[int] }{&somePtr}}, {name: "OptionStruct", input: `{"option":{"fooBar":"foo"}}`, expected: struct{ Option ftl.Option[inner] }{ftl.Some(inner{"foo"})}}, @@ -77,6 +83,12 @@ func TestUnmarshal(t *testing.T) { String string Unit ftl.Unit }{String: "something", Unit: ftl.Unit{}}}, + // Whitespaces after each `:` and multiple fields to test handling of the + // two potential terminal delimiters: `}` and `,` + {name: "ComplexFormatting", input: `{"option": null, "bool": true}`, expected: struct { + Option ftl.Option[int] + Bool bool + }{ftl.None[int](), true}}, } for _, tt := range tests { @@ -111,7 +123,9 @@ func TestRoundTrip(t *testing.T) { {name: "Slice", input: struct{ Slice []int }{[]int{1, 2, 3}}}, {name: "SliceOfStrings", input: struct{ Slice []string }{[]string{"hello", "world"}}}, {name: "Map", input: struct{ Map map[string]int }{map[string]int{"foo": 42}}}, + {name: "Time", input: struct{ Time time.Time }{time.Date(2009, time.November, 29, 21, 33, 0, 0, time.UTC)}}, {name: "Option", input: struct{ Option ftl.Option[int] }{ftl.Some(42)}}, + {name: "OptionNull", input: struct{ Option ftl.Option[int] }{ftl.None[int]()}}, {name: "OptionPtr", input: struct{ Option *ftl.Option[int] }{&somePtr}}, {name: "OptionStruct", input: struct{ Option ftl.Option[inner] }{ftl.Some(inner{"foo"})}}, {name: "Unit", input: ftl.Unit{}}, diff --git a/go-runtime/ftl/option.go b/go-runtime/ftl/option.go index 73e4b5bee..dfba72592 100644 --- a/go-runtime/ftl/option.go +++ b/go-runtime/ftl/option.go @@ -5,25 +5,20 @@ import ( "database/sql" "database/sql/driver" "encoding" - "encoding/json" "fmt" "reflect" - - ftlencoding "github.com/TBD54566975/ftl/go-runtime/encoding" ) // Stdlib interfaces types implement. type stdlib interface { fmt.Stringer fmt.GoStringer - json.Marshaler - json.Unmarshaler } // An Option type is a type that can contain a value or nothing. type Option[T any] struct { - value T - ok bool + Val T + Okay bool } var _ driver.Valuer = (*Option[int])(nil) @@ -31,69 +26,69 @@ var _ sql.Scanner = (*Option[int])(nil) func (o *Option[T]) Scan(src any) error { if src == nil { - o.ok = false + o.Okay = false var zero T - o.value = zero + o.Val = zero return nil } if value, ok := src.(T); ok { - o.value = value - o.ok = true + o.Val = value + o.Okay = true return nil } var value T switch scan := any(&value).(type) { case sql.Scanner: if err := scan.Scan(src); err != nil { - return fmt.Errorf("cannot scan %T into Option[%T]: %w", src, o.value, err) + return fmt.Errorf("cannot scan %T into Option[%T]: %w", src, o.Val, err) } - o.value = value - o.ok = true + o.Val = value + o.Okay = true case encoding.TextUnmarshaler: switch src := src.(type) { case string: if err := scan.UnmarshalText([]byte(src)); err != nil { - return fmt.Errorf("unmarshal from %T into Option[%T] failed: %w", src, o.value, err) + return fmt.Errorf("unmarshal from %T into Option[%T] failed: %w", src, o.Val, err) } - o.value = value - o.ok = true + o.Val = value + o.Okay = true case []byte: if err := scan.UnmarshalText(src); err != nil { - return fmt.Errorf("cannot scan %T into Option[%T]: %w", src, o.value, err) + return fmt.Errorf("cannot scan %T into Option[%T]: %w", src, o.Val, err) } - o.value = value - o.ok = true + o.Val = value + o.Okay = true default: - return fmt.Errorf("cannot unmarshal %T into Option[%T]", src, o.value) + return fmt.Errorf("cannot unmarshal %T into Option[%T]", src, o.Val) } default: - return fmt.Errorf("no decoding mechanism found for %T into Option[%T]", src, o.value) + return fmt.Errorf("no decoding mechanism found for %T into Option[%T]", src, o.Val) } return nil } func (o Option[T]) Value() (driver.Value, error) { - if !o.ok { + if !o.Okay { return nil, nil } - switch value := any(o.value).(type) { + switch value := any(o.Val).(type) { case driver.Valuer: return value.Value() case encoding.TextMarshaler: return value.MarshalText() } - return o.value, nil + return o.Val, nil } var _ stdlib = (*Option[int])(nil) // Some returns an Option that contains a value. -func Some[T any](value T) Option[T] { return Option[T]{value: value, ok: true} } +func Some[T any](value T) Option[T] { return Option[T]{Val: value, Okay: true} } // None returns an Option that contains nothing. func None[T any]() Option[T] { return Option[T]{} } @@ -137,65 +132,46 @@ func Zero[T any](value T) Option[T] { // Ptr returns a pointer to the value if the Option contains a value, otherwise nil. func (o Option[T]) Ptr() *T { - if o.ok { - return &o.value + if o.Okay { + return &o.Val } return nil } // Ok returns true if the Option contains a value. -func (o Option[T]) Ok() bool { return o.ok } +func (o Option[T]) Ok() bool { return o.Okay } // MustGet returns the value. It panics if the Option contains nothing. func (o Option[T]) MustGet() T { - if !o.ok { + if !o.Okay { var t T panic(fmt.Sprintf("Option[%T] contains nothing", t)) } - return o.value + return o.Val } // Get returns the value and a boolean indicating if the Option contains a value. -func (o Option[T]) Get() (T, bool) { return o.value, o.ok } +func (o Option[T]) Get() (T, bool) { return o.Val, o.Okay } // Default returns the Option value if it is present, otherwise it returns the // value passed. func (o Option[T]) Default(value T) T { - if o.ok { - return o.value + if o.Okay { + return o.Val } return value } -func (o Option[T]) MarshalJSON() ([]byte, error) { - if o.ok { - return ftlencoding.Marshal(o.value) - } - return []byte("null"), nil -} - -func (o *Option[T]) UnmarshalJSON(data []byte) error { - if string(data) == "null" { - o.ok = false - return nil - } - if err := ftlencoding.Unmarshal(data, &o.value); err != nil { - return err - } - o.ok = true - return nil -} - func (o Option[T]) String() string { - if o.ok { - return fmt.Sprintf("%v", o.value) + if o.Okay { + return fmt.Sprintf("%v", o.Val) } return "None" } func (o Option[T]) GoString() string { - if o.ok { - return fmt.Sprintf("Some[%T](%#v)", o.value, o.value) + if o.Okay { + return fmt.Sprintf("Some[%T](%#v)", o.Val, o.Val) } - return fmt.Sprintf("None[%T]()", o.value) + return fmt.Sprintf("None[%T]()", o.Val) } diff --git a/go-runtime/ftl/option_test.go b/go-runtime/ftl/option_test.go index 13b104fc1..d07eb0427 100644 --- a/go-runtime/ftl/option_test.go +++ b/go-runtime/ftl/option_test.go @@ -2,7 +2,6 @@ package ftl import ( "database/sql" - "encoding/json" "testing" "github.com/alecthomas/assert/v2" @@ -20,27 +19,6 @@ func TestOptionGet(t *testing.T) { assert.False(t, ok) } -func TestOptionMarshalJSON(t *testing.T) { - o := Some(1) - b, err := o.MarshalJSON() - assert.NoError(t, err) - assert.Equal(t, "1", string(b)) - - o = None[int]() - b, err = o.MarshalJSON() - assert.NoError(t, err) - assert.Equal(t, "null", string(b)) -} - -func TestOptionUnmarshalJSON(t *testing.T) { - o := Option[int]{} - err := json.Unmarshal([]byte("1"), &o) - assert.NoError(t, err) - b, ok := o.Get() - assert.True(t, ok) - assert.Equal(t, 1, b) -} - func TestOptionString(t *testing.T) { o := Some(1) assert.Equal(t, "1", o.String())