From 36d1630b7901b27bfe1642f873731292c385702e Mon Sep 17 00:00:00 2001 From: Cesar <137245636+nytzuga@users.noreply.github.com> Date: Wed, 1 Nov 2023 22:44:14 -0300 Subject: [PATCH] Add nullable option to codec (#2171) Signed-off-by: Cesar <137245636+nytzuga@users.noreply.github.com> Co-authored-by: Stephen Buttolph Co-authored-by: Dan Laine --- codec/reflectcodec/struct_fielder.go | 19 ++- codec/reflectcodec/type_codec.go | 217 ++++++++++++++++++-------- codec/reflectcodec/type_codec_test.go | 30 ++++ codec/test_codec.go | 94 +++++++++-- 4 files changed, 283 insertions(+), 77 deletions(-) create mode 100644 codec/reflectcodec/type_codec_test.go diff --git a/codec/reflectcodec/struct_fielder.go b/codec/reflectcodec/struct_fielder.go index 27edb58a1793..d266b60a3ebf 100644 --- a/codec/reflectcodec/struct_fielder.go +++ b/codec/reflectcodec/struct_fielder.go @@ -18,6 +18,10 @@ const ( // TagValue is the value the tag must have to be serialized. TagValue = "true" + + // TagValue is the value the tag must have to be serialized, this variant + // includes the nullable option + TagWithNullableValue = "true,nullable" ) var _ StructFielder = (*structFielder)(nil) @@ -25,6 +29,7 @@ var _ StructFielder = (*structFielder)(nil) type FieldDesc struct { Index int MaxSliceLen uint32 + Nullable bool } // StructFielder handles discovery of serializable fields in a struct. @@ -82,10 +87,19 @@ func (s *structFielder) GetSerializedFields(t reflect.Type) ([]FieldDesc, error) // Multiple tags per fields can be specified. // Serialize/Deserialize field if it has // any tag with the right value - captureField := false + var ( + captureField bool + nullable bool + ) for _, tag := range s.tags { - if field.Tag.Get(tag) == TagValue { + switch field.Tag.Get(tag) { + case TagValue: + captureField = true + case TagWithNullableValue: captureField = true + nullable = true + } + if captureField { break } } @@ -107,6 +121,7 @@ func (s *structFielder) GetSerializedFields(t reflect.Type) ([]FieldDesc, error) serializedFields = append(serializedFields, FieldDesc{ Index: i, MaxSliceLen: maxSliceLen, + Nullable: nullable, }) } s.serializedFieldIndices[t] = serializedFields // cache result diff --git a/codec/reflectcodec/type_codec.go b/codec/reflectcodec/type_codec.go index ac9ca25c16e7..9f9037f43d4e 100644 --- a/codec/reflectcodec/type_codec.go +++ b/codec/reflectcodec/type_codec.go @@ -13,18 +13,23 @@ import ( "golang.org/x/exp/slices" "github.com/ava-labs/avalanchego/codec" + "github.com/ava-labs/avalanchego/utils/set" "github.com/ava-labs/avalanchego/utils/wrappers" ) -// DefaultTagName that enables serialization. -const DefaultTagName = "serialize" +const ( + // DefaultTagName that enables serialization. + DefaultTagName = "serialize" + initialSliceLen = 16 +) var ( _ codec.Codec = (*genericCodec)(nil) - errMarshalNil = errors.New("can't marshal nil pointer or interface") - errUnmarshalNil = errors.New("can't unmarshal nil") - errNeedPointer = errors.New("argument to unmarshal must be a pointer") + errMarshalNil = errors.New("can't marshal nil pointer or interface") + errUnmarshalNil = errors.New("can't unmarshal nil") + errNeedPointer = errors.New("argument to unmarshal must be a pointer") + errRecursiveInterfaceTypes = errors.New("recursive interface types") ) type TypeCodec interface { @@ -85,12 +90,18 @@ func (c *genericCodec) Size(value interface{}) (int, error) { return 0, errMarshalNil // can't marshal nil } - size, _, err := c.size(reflect.ValueOf(value)) + size, _, err := c.size(reflect.ValueOf(value), false /*=nullable*/, nil /*=typeStack*/) return size, err } -// size returns the size of the value along with whether the value is constant sized. -func (c *genericCodec) size(value reflect.Value) (int, bool, error) { +// size returns the size of the value along with whether the value is constant +// sized. This function takes into account a `nullable` property which allows +// pointers and interfaces to serialize nil values +func (c *genericCodec) size( + value reflect.Value, + nullable bool, + typeStack set.Set[reflect.Type], +) (int, bool, error) { switch valueKind := value.Kind(); valueKind { case reflect.Uint8: return wrappers.ByteLen, true, nil @@ -114,24 +125,41 @@ func (c *genericCodec) size(value reflect.Value) (int, bool, error) { return wrappers.StringLen(value.String()), false, nil case reflect.Ptr: if value.IsNil() { - // Can't marshal nil pointers (but nil slices are fine) - return 0, false, errMarshalNil + if !nullable { + return 0, false, errMarshalNil + } + return wrappers.BoolLen, false, nil + } + + size, constSize, err := c.size(value.Elem(), false /*=nullable*/, typeStack) + if nullable { + return wrappers.BoolLen + size, false, err } - return c.size(value.Elem()) + return size, constSize, err case reflect.Interface: if value.IsNil() { - // Can't marshal nil interfaces (but nil slices are fine) - return 0, false, errMarshalNil + if !nullable { + return 0, false, errMarshalNil + } + return wrappers.BoolLen, false, nil } + underlyingValue := value.Interface() underlyingType := reflect.TypeOf(underlyingValue) + if typeStack.Contains(underlyingType) { + return 0, false, fmt.Errorf("%w: %s", errRecursiveInterfaceTypes, underlyingType) + } + typeStack.Add(underlyingType) + prefixSize := c.typer.PrefixSize(underlyingType) - valueSize, _, err := c.size(value.Elem()) - if err != nil { - return 0, false, err + valueSize, _, err := c.size(value.Elem(), false /*=nullable*/, typeStack) + + typeStack.Remove(underlyingType) + if nullable { + return wrappers.BoolLen + prefixSize + valueSize, false, err } - return prefixSize + valueSize, false, nil + return prefixSize + valueSize, false, err case reflect.Slice: numElts := value.Len() @@ -139,7 +167,7 @@ func (c *genericCodec) size(value reflect.Value) (int, bool, error) { return wrappers.IntLen, false, nil } - size, constSize, err := c.size(value.Index(0)) + size, constSize, err := c.size(value.Index(0), nullable, typeStack) if err != nil { return 0, false, err } @@ -151,7 +179,7 @@ func (c *genericCodec) size(value reflect.Value) (int, bool, error) { } for i := 1; i < numElts; i++ { - innerSize, _, err := c.size(value.Index(i)) + innerSize, _, err := c.size(value.Index(i), nullable, typeStack) if err != nil { return 0, false, err } @@ -165,7 +193,7 @@ func (c *genericCodec) size(value reflect.Value) (int, bool, error) { return 0, true, nil } - size, constSize, err := c.size(value.Index(0)) + size, constSize, err := c.size(value.Index(0), nullable, typeStack) if err != nil { return 0, false, err } @@ -177,7 +205,7 @@ func (c *genericCodec) size(value reflect.Value) (int, bool, error) { } for i := 1; i < numElts; i++ { - innerSize, _, err := c.size(value.Index(i)) + innerSize, _, err := c.size(value.Index(i), nullable, typeStack) if err != nil { return 0, false, err } @@ -196,7 +224,7 @@ func (c *genericCodec) size(value reflect.Value) (int, bool, error) { constSize = true ) for _, fieldDesc := range serializedFields { - innerSize, innerConstSize, err := c.size(value.Field(fieldDesc.Index)) + innerSize, innerConstSize, err := c.size(value.Field(fieldDesc.Index), fieldDesc.Nullable, typeStack) if err != nil { return 0, false, err } @@ -211,11 +239,11 @@ func (c *genericCodec) size(value reflect.Value) (int, bool, error) { return wrappers.IntLen, false, nil } - keySize, keyConstSize, err := c.size(iter.Key()) + keySize, keyConstSize, err := c.size(iter.Key(), false /*=nullable*/, typeStack) if err != nil { return 0, false, err } - valueSize, valueConstSize, err := c.size(iter.Value()) + valueSize, valueConstSize, err := c.size(iter.Value(), nullable, typeStack) if err != nil { return 0, false, err } @@ -230,7 +258,7 @@ func (c *genericCodec) size(value reflect.Value) (int, bool, error) { totalValueSize = valueSize ) for iter.Next() { - valueSize, _, err := c.size(iter.Value()) + valueSize, _, err := c.size(iter.Value(), nullable, typeStack) if err != nil { return 0, false, err } @@ -244,7 +272,7 @@ func (c *genericCodec) size(value reflect.Value) (int, bool, error) { totalKeySize = keySize ) for iter.Next() { - keySize, _, err := c.size(iter.Key()) + keySize, _, err := c.size(iter.Key(), false /*=nullable*/, typeStack) if err != nil { return 0, false, err } @@ -255,11 +283,11 @@ func (c *genericCodec) size(value reflect.Value) (int, bool, error) { default: totalSize := wrappers.IntLen + keySize + valueSize for iter.Next() { - keySize, _, err := c.size(iter.Key()) + keySize, _, err := c.size(iter.Key(), false /*=nullable*/, typeStack) if err != nil { return 0, false, err } - valueSize, _, err := c.size(iter.Value()) + valueSize, _, err := c.size(iter.Value(), nullable, typeStack) if err != nil { return 0, false, err } @@ -279,13 +307,19 @@ func (c *genericCodec) MarshalInto(value interface{}, p *wrappers.Packer) error return errMarshalNil // can't marshal nil } - return c.marshal(reflect.ValueOf(value), p, c.maxSliceLen) + return c.marshal(reflect.ValueOf(value), p, c.maxSliceLen, false /*=nullable*/, nil /*=typeStack*/) } // marshal writes the byte representation of [value] to [p] -// [value]'s underlying value must not be a nil pointer or interface +// // c.lock should be held for the duration of this function -func (c *genericCodec) marshal(value reflect.Value, p *wrappers.Packer, maxSliceLen uint32) error { +func (c *genericCodec) marshal( + value reflect.Value, + p *wrappers.Packer, + maxSliceLen uint32, + nullable bool, + typeStack set.Set[reflect.Type], +) error { switch valueKind := value.Kind(); valueKind { case reflect.Uint8: p.PackByte(uint8(value.Uint())) @@ -318,22 +352,41 @@ func (c *genericCodec) marshal(value reflect.Value, p *wrappers.Packer, maxSlice p.PackBool(value.Bool()) return p.Err case reflect.Ptr: - if value.IsNil() { // Can't marshal nil (except nil slices) + isNil := value.IsNil() + if nullable { + p.PackBool(isNil) + if isNil || p.Err != nil { + return p.Err + } + } else if isNil { return errMarshalNil } - return c.marshal(value.Elem(), p, c.maxSliceLen) + + return c.marshal(value.Elem(), p, c.maxSliceLen, false /*=nullable*/, typeStack) case reflect.Interface: - if value.IsNil() { // Can't marshal nil (except nil slices) + isNil := value.IsNil() + if nullable { + p.PackBool(isNil) + if isNil || p.Err != nil { + return p.Err + } + } else if isNil { return errMarshalNil } + underlyingValue := value.Interface() underlyingType := reflect.TypeOf(underlyingValue) + if typeStack.Contains(underlyingType) { + return fmt.Errorf("%w: %s", errRecursiveInterfaceTypes, underlyingType) + } + typeStack.Add(underlyingType) if err := c.typer.PackPrefix(p, underlyingType); err != nil { return err } - if err := c.marshal(value.Elem(), p, c.maxSliceLen); err != nil { + if err := c.marshal(value.Elem(), p, c.maxSliceLen, false /*=nullable*/, typeStack); err != nil { return err } + typeStack.Remove(underlyingType) return p.Err case reflect.Slice: numElts := value.Len() // # elements in the slice/array. 0 if this slice is nil. @@ -361,7 +414,7 @@ func (c *genericCodec) marshal(value reflect.Value, p *wrappers.Packer, maxSlice return p.Err } for i := 0; i < numElts; i++ { // Process each element in the slice - if err := c.marshal(value.Index(i), p, c.maxSliceLen); err != nil { + if err := c.marshal(value.Index(i), p, c.maxSliceLen, nullable, typeStack); err != nil { return err } } @@ -381,7 +434,7 @@ func (c *genericCodec) marshal(value reflect.Value, p *wrappers.Packer, maxSlice ) } for i := 0; i < numElts; i++ { // Process each element in the array - if err := c.marshal(value.Index(i), p, c.maxSliceLen); err != nil { + if err := c.marshal(value.Index(i), p, c.maxSliceLen, nullable, typeStack); err != nil { return err } } @@ -392,7 +445,7 @@ func (c *genericCodec) marshal(value reflect.Value, p *wrappers.Packer, maxSlice return err } for _, fieldDesc := range serializedFields { // Go through all fields of this struct that are serialized - if err := c.marshal(value.Field(fieldDesc.Index), p, fieldDesc.MaxSliceLen); err != nil { // Serialize the field and write to byte array + if err := c.marshal(value.Field(fieldDesc.Index), p, fieldDesc.MaxSliceLen, fieldDesc.Nullable, typeStack); err != nil { // Serialize the field and write to byte array return err } } @@ -423,7 +476,7 @@ func (c *genericCodec) marshal(value reflect.Value, p *wrappers.Packer, maxSlice startOffset := p.Offset endOffset := p.Offset for i, key := range keys { - if err := c.marshal(key, p, c.maxSliceLen); err != nil { + if err := c.marshal(key, p, c.maxSliceLen, false /*=nullable*/, typeStack); err != nil { return err } if p.Err != nil { @@ -456,7 +509,7 @@ func (c *genericCodec) marshal(value reflect.Value, p *wrappers.Packer, maxSlice } // serialize and pack value - if err := c.marshal(value.MapIndex(key.key), p, c.maxSliceLen); err != nil { + if err := c.marshal(value.MapIndex(key.key), p, c.maxSliceLen, nullable, typeStack); err != nil { return err } } @@ -467,8 +520,8 @@ func (c *genericCodec) marshal(value reflect.Value, p *wrappers.Packer, maxSlice } } -// Unmarshal unmarshals [bytes] into [dest], where -// [dest] must be a pointer or interface +// Unmarshal unmarshals [bytes] into [dest], where [dest] must be a pointer or +// interface func (c *genericCodec) Unmarshal(bytes []byte, dest interface{}) error { if dest == nil { return errUnmarshalNil @@ -481,7 +534,7 @@ func (c *genericCodec) Unmarshal(bytes []byte, dest interface{}) error { if destPtr.Kind() != reflect.Ptr { return errNeedPointer } - if err := c.unmarshal(&p, destPtr.Elem(), c.maxSliceLen); err != nil { + if err := c.unmarshal(&p, destPtr.Elem(), c.maxSliceLen, false /*=nullable*/, nil /*=typeStack*/); err != nil { return err } if p.Offset != len(bytes) { @@ -495,8 +548,19 @@ func (c *genericCodec) Unmarshal(bytes []byte, dest interface{}) error { } // Unmarshal from p.Bytes into [value]. [value] must be addressable. +// +// The [nullable] property affects how pointers and interfaces are unmarshalled, +// as an extra byte would be used to unmarshal nil values for pointers and +// interaces +// // c.lock should be held for the duration of this function -func (c *genericCodec) unmarshal(p *wrappers.Packer, value reflect.Value, maxSliceLen uint32) error { +func (c *genericCodec) unmarshal( + p *wrappers.Packer, + value reflect.Value, + maxSliceLen uint32, + nullable bool, + typeStack set.Set[reflect.Type], +) error { switch value.Kind() { case reflect.Uint8: value.SetUint(uint64(p.UnpackByte())) @@ -573,18 +637,22 @@ func (c *genericCodec) unmarshal(p *wrappers.Packer, value reflect.Value, maxSli } numElts := int(numElts32) + sliceType := value.Type() + innerType := sliceType.Elem() + // If this is a slice of bytes, manually unpack the bytes rather // than calling unmarshal on each byte. This improves performance. - if elemKind := value.Type().Elem().Kind(); elemKind == reflect.Uint8 { + if elemKind := innerType.Kind(); elemKind == reflect.Uint8 { value.SetBytes(p.UnpackFixedBytes(numElts)) return p.Err } - // set [value] to be a slice of the appropriate type/capacity (right now it is nil) - value.Set(reflect.MakeSlice(value.Type(), numElts, numElts)) - // Unmarshal each element into the appropriate index of the slice + // Unmarshal each element and append it into the slice. + value.Set(reflect.MakeSlice(sliceType, 0, initialSliceLen)) + zeroValue := reflect.Zero(innerType) for i := 0; i < numElts; i++ { - if err := c.unmarshal(p, value.Index(i), c.maxSliceLen); err != nil { - return fmt.Errorf("couldn't unmarshal slice element: %w", err) + value.Set(reflect.Append(value, zeroValue)) + if err := c.unmarshal(p, value.Index(i), c.maxSliceLen, nullable, typeStack); err != nil { + return err } } return nil @@ -601,8 +669,8 @@ func (c *genericCodec) unmarshal(p *wrappers.Packer, value reflect.Value, maxSli return nil } for i := 0; i < numElts; i++ { - if err := c.unmarshal(p, value.Index(i), c.maxSliceLen); err != nil { - return fmt.Errorf("couldn't unmarshal array element: %w", err) + if err := c.unmarshal(p, value.Index(i), c.maxSliceLen, nullable, typeStack); err != nil { + return err } } return nil @@ -613,15 +681,29 @@ func (c *genericCodec) unmarshal(p *wrappers.Packer, value reflect.Value, maxSli } return nil case reflect.Interface: + if nullable { + isNil := p.UnpackBool() + if isNil || p.Err != nil { + return p.Err + } + } + intfImplementor, err := c.typer.UnpackPrefix(p, value.Type()) if err != nil { return err } + intfImplementorType := intfImplementor.Type() + if typeStack.Contains(intfImplementorType) { + return fmt.Errorf("%w: %s", errRecursiveInterfaceTypes, intfImplementorType) + } + typeStack.Add(intfImplementorType) + // Unmarshal into the struct - if err := c.unmarshal(p, intfImplementor, c.maxSliceLen); err != nil { - return fmt.Errorf("couldn't unmarshal interface: %w", err) + if err := c.unmarshal(p, intfImplementor, c.maxSliceLen, false /*=nullable*/, typeStack); err != nil { + return err } - // And assign the filled struct to the value + + typeStack.Remove(intfImplementorType) value.Set(intfImplementor) return nil case reflect.Struct: @@ -632,19 +714,26 @@ func (c *genericCodec) unmarshal(p *wrappers.Packer, value reflect.Value, maxSli } // Go through the fields and umarshal into them for _, fieldDesc := range serializedFieldIndices { - if err := c.unmarshal(p, value.Field(fieldDesc.Index), fieldDesc.MaxSliceLen); err != nil { - return fmt.Errorf("couldn't unmarshal struct: %w", err) + if err := c.unmarshal(p, value.Field(fieldDesc.Index), fieldDesc.MaxSliceLen, fieldDesc.Nullable, typeStack); err != nil { + return err } } return nil case reflect.Ptr: + if nullable { + isNil := p.UnpackBool() + if isNil || p.Err != nil { + return p.Err + } + } + // Get the type this pointer points to t := value.Type().Elem() // Create a new pointer to a new value of the underlying type v := reflect.New(t) // Fill the value - if err := c.unmarshal(p, v.Elem(), c.maxSliceLen); err != nil { - return fmt.Errorf("couldn't unmarshal pointer: %w", err) + if err := c.unmarshal(p, v.Elem(), c.maxSliceLen, false /*=nullable*/, typeStack); err != nil { + return err } // Assign to the top-level struct's member value.Set(v) @@ -671,15 +760,15 @@ func (c *genericCodec) unmarshal(p *wrappers.Packer, value reflect.Value, maxSli ) // Set [value] to be a new map of the appropriate type. - value.Set(reflect.MakeMapWithSize(mapType, numElts)) + value.Set(reflect.MakeMap(mapType)) for i := 0; i < numElts; i++ { mapKey := reflect.New(mapKeyType).Elem() keyStartOffset := p.Offset - if err := c.unmarshal(p, mapKey, c.maxSliceLen); err != nil { - return fmt.Errorf("couldn't unmarshal map key (%s): %w", mapKeyType, err) + if err := c.unmarshal(p, mapKey, c.maxSliceLen, false /*=nullable*/, typeStack); err != nil { + return err } // Get the key's byte representation and check that the new key is @@ -696,8 +785,8 @@ func (c *genericCodec) unmarshal(p *wrappers.Packer, value reflect.Value, maxSli // Get the value mapValue := reflect.New(mapValueType).Elem() - if err := c.unmarshal(p, mapValue, c.maxSliceLen); err != nil { - return fmt.Errorf("couldn't unmarshal map value for key %s: %w", mapKey, err) + if err := c.unmarshal(p, mapValue, c.maxSliceLen, nullable, typeStack); err != nil { + return err } // Assign the key-value pair in the map diff --git a/codec/reflectcodec/type_codec_test.go b/codec/reflectcodec/type_codec_test.go new file mode 100644 index 000000000000..42b256c4a6c9 --- /dev/null +++ b/codec/reflectcodec/type_codec_test.go @@ -0,0 +1,30 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package reflectcodec + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSizeWithNil(t *testing.T) { + require := require.New(t) + var x *int32 + y := int32(1) + c := genericCodec{} + _, _, err := c.size(reflect.ValueOf(x), false /*=nullable*/, nil /*=typeStack*/) + require.ErrorIs(err, errMarshalNil) + len, _, err := c.size(reflect.ValueOf(x), true /*=nullable*/, nil /*=typeStack*/) + require.Empty(err) + require.Equal(1, len) + x = &y + len, _, err = c.size(reflect.ValueOf(y), true /*=nullable*/, nil /*=typeStack*/) + require.Empty(err) + require.Equal(4, len) + len, _, err = c.size(reflect.ValueOf(x), true /*=nullable*/, nil /*=typeStack*/) + require.Empty(err) + require.Equal(5, len) +} diff --git a/codec/test_codec.go b/codec/test_codec.go index 7177e08e81b1..341912a823af 100644 --- a/codec/test_codec.go +++ b/codec/test_codec.go @@ -23,6 +23,7 @@ var ( TestBigArray, TestPointerToStruct, TestSliceOfStruct, + TestStructWithNullable, TestInterface, TestSliceOfInterface, TestArrayOfInterface, @@ -63,7 +64,8 @@ type Foo interface { } type MyInnerStruct struct { - Str string `serialize:"true"` + Str string `serialize:"true"` + NumberNotProvided *int32 `serialize:"true,nullable"` } func (*MyInnerStruct) Foo() int { @@ -86,6 +88,15 @@ type MyInnerStruct3 struct { F Foo `serialize:"true"` } +type MyStructWithNullable struct { + Interface any `serialize:"true,nullable"` + Int32 *int32 `serialize:"true,nullable"` + Int64 *int64 `serialize:"true,nullable"` + Int32Slice []*int32 `serialize:"true,nullable"` + Int32Array [2]*int32 `serialize:"true,nullable"` + Int32Map map[int32]*int32 `serialize:"true,nullable"` +} + type myStruct struct { InnerStruct MyInnerStruct `serialize:"true"` InnerStruct2 *MyInnerStruct `serialize:"true"` @@ -145,21 +156,23 @@ func TestStruct(codec GeneralCodec, t testing.TB) { myMap7["key"] = "value" myMap7[int32(1)] = int32(2) + number := int32(8) + myStructInstance := myStruct{ - InnerStruct: MyInnerStruct{"hello"}, - InnerStruct2: &MyInnerStruct{"yello"}, + InnerStruct: MyInnerStruct{"hello", nil}, + InnerStruct2: &MyInnerStruct{"yello", nil}, Member1: 1, Member2: 2, MySlice: []byte{1, 2, 3, 4}, MySlice2: []string{"one", "two", "three"}, - MySlice3: []MyInnerStruct{{"abc"}, {"ab"}, {"c"}}, + MySlice3: []MyInnerStruct{{"abc", nil}, {"ab", &number}, {"c", nil}}, MySlice4: []*MyInnerStruct2{{true}, {}}, MySlice5: []Foo{&MyInnerStruct2{true}, &MyInnerStruct2{}}, MyArray: [4]byte{5, 6, 7, 8}, MyArray2: [5]string{"four", "five", "six", "seven"}, - MyArray3: [3]MyInnerStruct{{"d"}, {"e"}, {"f"}}, + MyArray3: [3]MyInnerStruct{{"d", nil}, {"e", nil}, {"f", nil}}, MyArray4: [2]*MyInnerStruct2{{}, {true}}, - MyInterface: &MyInnerStruct{"yeet"}, + MyInterface: &MyInnerStruct{"yeet", &number}, InnerStruct3: MyInnerStruct3{ Str: "str", M1: MyInnerStruct{ @@ -414,20 +427,79 @@ func TestPointerToStruct(codec GeneralCodec, t testing.TB) { require.Equal(myPtr, myPtrUnmarshaled) } +func TestStructWithNullable(codec GeneralCodec, t testing.TB) { + require := require.New(t) + n1 := int32(5) + n2 := int64(10) + struct1 := MyStructWithNullable{ + Interface: nil, + Int32: &n1, + Int64: &n2, + Int32Slice: []*int32{ + nil, + nil, + &n1, + }, + Int32Array: [2]*int32{ + nil, + &n1, + }, + Int32Map: map[int32]*int32{ + 1: nil, + 2: &n1, + }, + } + + require.NoError(codec.RegisterType(&MyStructWithNullable{})) + manager := NewDefaultManager() + require.NoError(manager.RegisterCodec(0, codec)) + + bytes, err := manager.Marshal(0, struct1) + require.NoError(err) + + bytesLen, err := manager.Size(0, struct1) + require.NoError(err) + require.Len(bytes, bytesLen) + + var struct1Unmarshaled MyStructWithNullable + version, err := manager.Unmarshal(bytes, &struct1Unmarshaled) + require.NoError(err) + require.Zero(version) + require.Equal(struct1, struct1Unmarshaled) + + struct1 = MyStructWithNullable{ + Int32Slice: []*int32{}, + Int32Map: map[int32]*int32{}, + } + bytes, err = manager.Marshal(0, struct1) + require.NoError(err) + + bytesLen, err = manager.Size(0, struct1) + require.NoError(err) + require.Len(bytes, bytesLen) + + var struct1Unmarshaled2 MyStructWithNullable + version, err = manager.Unmarshal(bytes, &struct1Unmarshaled2) + require.NoError(err) + require.Zero(version) + require.Equal(struct1, struct1Unmarshaled2) +} + // Test marshalling a slice of structs func TestSliceOfStruct(codec GeneralCodec, t testing.TB) { require := require.New(t) - + n1 := int32(-1) + n2 := int32(0xff) mySlice := []MyInnerStruct3{ { Str: "One", - M1: MyInnerStruct{"Two"}, - F: &MyInnerStruct{"Three"}, + M1: MyInnerStruct{"Two", &n1}, + F: &MyInnerStruct{"Three", &n2}, }, { Str: "Four", - M1: MyInnerStruct{"Five"}, - F: &MyInnerStruct{"Six"}, + M1: MyInnerStruct{"Five", nil}, + F: &MyInnerStruct{"Six", nil}, }, } require.NoError(codec.RegisterType(&MyInnerStruct{}))