diff --git a/transform/flatten_mangler.go b/transform/flatten_mangler.go index fcde82c..65a9641 100644 --- a/transform/flatten_mangler.go +++ b/transform/flatten_mangler.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/fatih/structtag" + "github.com/vimeo/dials/common" "github.com/vimeo/dials/tagformat/caseconversion" ) @@ -206,7 +207,8 @@ func (f *FlattenMangler) getTag(sf *reflect.StructField, tags, flattenedPath []s func (f *FlattenMangler) Unmangle(sf reflect.StructField, vs []FieldValueTuple) (reflect.Value, error) { val := reflect.New(sf.Type).Elem() - output, err := populateStruct(val, vs, 0) + + output, _, err := populateStruct(val, vs, 0) if err != nil { return val, err } @@ -218,18 +220,29 @@ func (f *FlattenMangler) Unmangle(sf reflect.StructField, vs []FieldValueTuple) return val, nil } +func isNil(val reflect.Value) bool { + switch val.Kind() { + case reflect.Pointer, reflect.Chan, reflect.Slice, reflect.Map, reflect.Func, reflect.Interface: + return val.IsNil() + default: + return false + } +} + // populateStruct populates the original value with values from the flattend values -func populateStruct(originalVal reflect.Value, vs []FieldValueTuple, inputIndex int) (int, error) { +// bool return value indicates whether any inner fields were non-nil (ignore if error-set) +func populateStruct(originalVal reflect.Value, vs []FieldValueTuple, inputIndex int) (int, bool, error) { if !originalVal.CanSet() { - return inputIndex, fmt.Errorf("error unmangling %s. Need addressable type, actual %q", originalVal, originalVal.Type().Kind()) + return inputIndex, false, fmt.Errorf("error unmangling %s. Need addressable type, actual %q", originalVal, originalVal.Type().Kind()) } kind, vt := getUnderlyingKindType(originalVal.Type()) + anyChildSet := false switch kind { case reflect.Struct: // go through each field if the struct doesn't implement TextUnmarshaler - if vt.Implements(textMReflectType) || reflect.PtrTo(vt).Implements(textMReflectType) { + if vt.Implements(textMReflectType) || reflect.PointerTo(vt).Implements(textMReflectType) { break } // the originalVal is a pointer and to go through the fields, we need @@ -247,36 +260,45 @@ func populateStruct(originalVal reflect.Value, vs []FieldValueTuple, inputIndex switch kind { case reflect.Struct: // don't flatten if the struct implements TextUnmarshaler - if t.Implements(textMReflectType) || reflect.PtrTo(t).Implements(textMReflectType) { + if t.Implements(textMReflectType) || reflect.PointerTo(t).Implements(textMReflectType) { break // break out of the case, still stays within the for loop } var err error - inputIndex, err = populateStruct(nestedVal, vs, inputIndex) + var nestedAnySet bool + inputIndex, nestedAnySet, err = populateStruct(nestedVal, vs, inputIndex) if err != nil { - return inputIndex, err + return inputIndex, false, err } + anyChildSet = anyChildSet || nestedAnySet continue default: } if !nestedVal.CanSet() { - return inputIndex, fmt.Errorf("nested value %s under %s cannot be set", nestedVal, originalVal) + return inputIndex, false, fmt.Errorf("nested value %s under %s cannot be set", nestedVal, originalVal) } if !vs[inputIndex].Value.Type().AssignableTo(nestedVal.Type()) { - return inputIndex, fmt.Errorf("error unmangling. Expected type %s. Actual type %s", vs[inputIndex].Value.Type(), nestedVal.Type()) + return inputIndex, false, fmt.Errorf("error unmangling. Expected type %s. Actual type %s", vs[inputIndex].Value.Type(), nestedVal.Type()) + } + if !isNil(vs[inputIndex].Value) { + nestedVal.Set(vs[inputIndex].Value) + anyChildSet = true } - nestedVal.Set(vs[inputIndex].Value) inputIndex++ } - setVal.Elem().Set(val) - originalVal.Set(setVal) - return inputIndex, nil - default: + if anyChildSet { + setVal.Elem().Set(val) + originalVal.Set(setVal) + } + return inputIndex, anyChildSet, nil + } + val := vs[inputIndex].Value + if !isNil(val) { + originalVal.Set(val) + anyChildSet = true } - originalVal.Set(vs[inputIndex].Value) inputIndex++ - - return inputIndex, nil + return inputIndex, anyChildSet, nil } // ShouldRecurse returns false because Mangle walks through nested structs and doesn't need Transform's recursion diff --git a/transform/flatten_mangler_test.go b/transform/flatten_mangler_test.go index a2066ae..5a9d837 100644 --- a/transform/flatten_mangler_test.go +++ b/transform/flatten_mangler_test.go @@ -148,8 +148,17 @@ func TestFlattenMangler(t *testing.T) { }, modify: func(t testing.TB, val reflect.Value) {}, assertion: func(t testing.TB, i interface{}) { - // should be empty struct since none of the fields are exposed - assert.Equal(t, struct{}{}, *i.(*struct{})) + if i == nil { + t.Error("nil Unmangle output") + } + s, ok := i.(*struct{}) + if !ok { + t.Errorf("unexpected type %T; expected *struct{}", i) + return + } + if s != nil { + t.Errorf("non-nil Unmangle output for empty struct (with type %T) %+[1]v", s) + } }, }, { @@ -207,6 +216,66 @@ func TestFlattenMangler(t *testing.T) { assert.Equal(t, st, i) }, }, + { + name: "nil nested struct", + testStruct: b, + modify: func(t testing.TB, val reflect.Value) { + + expectedDialsTags := []string{ + "config_field_Name", + "config_field_Foobar_Location", + "config_field_Foobar_Coordinates", + "config_field_Foobar_some_time", + "config_field_AnotherField", + } + + expectedFieldTags := []string{ + "ConfigField,Name", + "ConfigField,Foobar,Location", + "ConfigField,Foobar,Coordinates", + "ConfigField,Foobar,SomeTime", + "ConfigField,AnotherField", + } + + for i := 0; i < val.Type().NumField(); i++ { + f := val.Type().Field(i) + assert.EqualValues(t, expectedDialsTags[i], f.Tag.Get(common.DialsTagName)) + assert.EqualValues(t, expectedFieldTags[i], f.Tag.Get(dialsFieldPathTag)) + if f.Type.Kind() != reflect.Pointer { + t.Errorf("field %d has kind %s, not %s", i, f.Type.Kind(), reflect.Pointer) + } + } + + s1 := "test" + i2 := 42 + + val.Field(0).Set(reflect.ValueOf(&s1)) + val.Field(1).Set(reflect.Zero(reflect.TypeOf((*string)(nil)))) + val.Field(2).Set(reflect.Zero(reflect.TypeOf((*int)(nil)))) + val.Field(3).Set(reflect.Zero(reflect.TypeOf((*time.Duration)(nil)))) + val.Field(4).Set(reflect.ValueOf(&i2)) + }, + assertion: func(t testing.TB, i interface{}) { + // all the fields are pointerified because of call to Pointerify + s1 := "test" + i2 := 42 + b := struct { + Name *string `dials:"Name"` + Foobar *struct { + Location *string `dials:"Location"` + Coordinates *int `dials:"Coordinates"` + SomeTime *time.Duration + } `dials:"Foobar"` + AnotherField *int `dials:"AnotherField"` + }{ + Name: &s1, + Foobar: nil, + AnotherField: &i2, + } + + assert.EqualValues(t, &b, i) + }, + }, { name: "multilevel nested struct", testStruct: b,