Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for encoding.TextUnmarshaler #6

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
32 changes: 18 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@ AWS Lambda functions. It should be suitable for additional applications.

Set some parameters in [AWS Parameter Store](https://docs.aws.amazon.com/systems-manager/latest/userguide/systems-manager-parameter-store.html):

| Name | Value | Type | Key ID |
| ---------------------------- | -------------------- | ------------ | ------------- |
| /exmaple_service/prod/debug | false | String | - |
| /exmaple_service/prod/port | 8080 | String | - |
| /exmaple_service/prod/user | Ian | String | - |
| /exmaple_service/prod/rate | 0.5 | String | - |
| /exmaple_service/prod/secret | zOcZkAGB6aEjN7SAoVBT | SecureString | alias/aws/ssm |
| Name | Value | Type | Key ID |
| ---------------------------- | -------------------- | ------------ | ------------- |
| /exmaple_service/prod/debug | false | String | - |
| /exmaple_service/prod/port | 8080 | String | - |
| /exmaple_service/prod/user | Ian | String | - |
| /exmaple_service/prod/rate | 0.5 | String | - |
| /exmaple_service/prod/secret | zOcZkAGB6aEjN7SAoVBT | SecureString | alias/aws/ssm |
| /exmaple_service/prod/ts | 2020-04-14T21:26:00+02:00 | String | - |

Write some code:

Expand All @@ -36,11 +37,12 @@ import (
)

type Config struct {
Debug bool `smm:"debug" default:"true"`
Port int `smm:"port"`
User string `smm:"user"`
Rate float32 `smm:"rate"`
Secret string `smm:"secret" required:"true"`
Debug bool `smm:"debug" default:"true"`
Port int `smm:"port"`
User string `smm:"user"`
Rate float32 `smm:"rate"`
Secret string `smm:"secret" required:"true"`
Timestamp time.Time `smm:"ts" required:"true"`
}

func main() {
Expand All @@ -50,8 +52,8 @@ func main() {
log.Fatal(err.Error())
}

format := "Debug: %v\nPort: %d\nUser: %s\nRate: %f\nSecret: %s\n"
_, err = fmt.Printf(format, c.Debug, c.Port, c.User, c.Rate, c.Secret)
format := "Debug: %v\nPort: %d\nUser: %s\nRate: %f\nSecret: %s\nTimestamp: %s\n"
_, err = fmt.Printf(format, c.Debug, c.Port, c.User, c.Rate, c.Secret, c.Timestamp)
if err != nil {
log.Fatal(err.Error())
}
Expand All @@ -66,6 +68,7 @@ Port: 8080
User: Ian
Rate: 0.500000
Secret: zOcZkAGB6aEjN7SAoVBT
Timestamp: 2020-04-14 21:26:00 +0200 +0200
```

[Additional examples](https://godoc.org/github.com/ianlopshire/go-ssm-config#pkg-examples) can be found in godoc.
Expand Down Expand Up @@ -101,6 +104,7 @@ ssmconfig supports these struct field types:
* int, int8, int16, int32, int64
* bool
* float32, float64
* encoding.TextUnmarshaler

More supported types may be added in the future.

Expand Down
44 changes: 41 additions & 3 deletions ssmconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package ssmconfig

import (
"encoding"
"path"
"reflect"
"strconv"
Expand Down Expand Up @@ -33,6 +34,10 @@ type Provider struct {
SSM ssmiface.SSMAPI
}

const unmarshalTextMethod = "UnmarshalText"

var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()

// Process loads config values from smm (parameter store) into c. Encrypted parameters
// will automatically be decrypted. c must be a pointer to a struct.
//
Expand All @@ -59,7 +64,8 @@ func (p *Provider) Process(configPath string, c interface{}) error {
return errors.New("ssmconfig: c must be a pointer to a struct")
}

spec := buildStructSpec(configPath, v.Type())
t := v.Type()
spec := buildStructSpec(configPath, t)

params, invalidPrams, err := p.getParameters(spec)
if err != nil {
Expand All @@ -84,9 +90,16 @@ func (p *Provider) Process(configPath string, c interface{}) error {
continue
}

err = setValue(v.Field(i), value)
valueField := v.Field(i)
typeField := t.Field(i)
var err error
if isTextUnmarshaler(typeField) {
err = unmarshalText(typeField, valueField, value)
} else {
err = setValue(valueField, value)
}
Copy link
Owner

@ianlopshire ianlopshire May 1, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would move all of the logic for TextUnmarshaler into the setValue function.

We also need to handle the case where the value is already a pointer to the implementation of TextUnmarshaled.

Something like:

func setValue(v reflect.Value, s string) error {
	t := v.Type()

	if t.Implements(textUnmarshalerType) {
		return unmarshalText(t, v, s, /*shouldAddr*/ false)
	}
	if reflect.PtrTo(t).Implements(textUnmarshalerType) {
		return unmarshalText(t, v, s, /*shouldAddr*/ true)
	}
	
	//...
}

if err != nil {
return errors.Wrapf(err, "ssmconfig: error setting field %s", v.Type().Field(i).Name)
return errors.Wrapf(err, "ssmconfig: error setting field %s", typeField.Name)
}
}

Expand Down Expand Up @@ -129,6 +142,31 @@ func (p *Provider) getParameters(spec structSpec) (params map[string]string, inv
return params, invalidParams, nil
}

// Checks whether the value implements the TextUnmarshaler interface.
func isTextUnmarshaler(f reflect.StructField) bool {
return reflect.PtrTo(f.Type).Implements(textUnmarshalerType)
}

// Create a new instance of the field's type and call its UnmarshalText([]byte) method.
// Set the value after execution and fail if the method returns an error.
func unmarshalText(f reflect.StructField, v reflect.Value, s string) error {
Copy link
Owner

@ianlopshire ianlopshire May 1, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be simplified quite a bit. The Interface() method on reflect.Value allows you to access the value as a regular go interface.

We also need to handle the case where the value is already a pointer to the implementation of TextUnmarshaled. I would handle this by accepting a shouldAddr argument.

My implementation of this would probably look something like:

func unmarshalText(t reflect.Type, v reflect.Value, s string, shouldAddr bool) error {
	if shouldAddr {
		v = v.Addr()
	}
	if t.Kind() == reflect.Ptr && v.IsNil() {
		v.Set(reflect.New(t.Elem()))
	}
	return v.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(s))
}

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will definitely change the much simpler method invocation, thanks!

I didn't add support for pointers on purpose as all the other types don't support pointers yet. I thought this could be addressed in a new issue and possibly implemented using a more dynamic / recursive approach for all supported types.

What do you think? Should I add pointer support already or not?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmmmmm, Good point.

I think I would go ahead and add pointer support for TextUnmarshaler. It starts to get confusing what is and isn't supported when you don't allow pointers for interface types.

Copy link
Author

@qexpres qexpres May 4, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

✔️

I also added tests for empty non-required parameters and added another example to the readme.

newV := reflect.New(f.Type)
method := newV.MethodByName(unmarshalTextMethod)

args := []reflect.Value{reflect.ValueOf([]byte(s))}
values := method.Call(args)

v.Set(newV.Elem())

// implementation only returns an error value
if !values[0].IsNil() {
err := values[0].Elem().Interface().(error)
return errors.Errorf("could not decode %q into type %v: %v", s, f.Type.String(), err)
}

return nil
}

func setValue(v reflect.Value, s string) error {
switch v.Kind() {
case reflect.String:
Expand Down
70 changes: 60 additions & 10 deletions ssmconfig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package ssmconfig_test

import (
"errors"
"net"
"reflect"
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ssm"
Expand All @@ -29,16 +31,20 @@ func (c *mockSSMClient) GetParameters(input *ssm.GetParametersInput) (*ssm.GetPa
func TestProvider_Process(t *testing.T) {
t.Run("base case", func(t *testing.T) {
var s struct {
S1 string `ssm:"/strings/s1"`
S2 string `ssm:"/strings/s2" default:"string2"`
I1 int `ssm:"/int/i1"`
I2 int `ssm:"/int/i2" default:"42"`
B1 bool `ssm:"/bool/b1"`
B2 bool `ssm:"/bool/b2" default:"false"`
F321 float32 `ssm:"/float32/f321"`
F322 float32 `ssm:"/float32/f322" default:"42.42"`
F641 float64 `ssm:"/float64/f641"`
F642 float64 `ssm:"/float64/f642" default:"42.42"`
S1 string `ssm:"/strings/s1"`
S2 string `ssm:"/strings/s2" default:"string2"`
I1 int `ssm:"/int/i1"`
I2 int `ssm:"/int/i2" default:"42"`
B1 bool `ssm:"/bool/b1"`
B2 bool `ssm:"/bool/b2" default:"false"`
F321 float32 `ssm:"/float32/f321"`
F322 float32 `ssm:"/float32/f322" default:"42.42"`
F641 float64 `ssm:"/float64/f641"`
F642 float64 `ssm:"/float64/f642" default:"42.42"`
TU1 time.Time `ssm:"/text_unmarshaler/time1"`
TU2 time.Time `ssm:"/text_unmarshaler/time2" default:"2020-04-14T21:26:00+02:00"`
TU3 net.IP `ssm:"/text_unmarshaler/ipv41"`
TU4 net.IP `ssm:"/text_unmarshaler/ipv42" default:"127.0.0.1"`
Invalid string
}

Expand All @@ -65,6 +71,14 @@ func TestProvider_Process(t *testing.T) {
Name: aws.String("/base/float64/f641"),
Value: aws.String("42.42"),
},
{
Name: aws.String("/base/text_unmarshaler/time1"),
Value: aws.String("2020-04-14T21:26:00+02:00"),
},
{
Name: aws.String("/base/text_unmarshaler/ipv41"),
Value: aws.String("127.0.0.1"),
},
},
},
}
Expand Down Expand Up @@ -94,6 +108,10 @@ func TestProvider_Process(t *testing.T) {
"/base/float32/f322",
"/base/float64/f641",
"/base/float64/f642",
"/base/text_unmarshaler/time1",
"/base/text_unmarshaler/time2",
"/base/text_unmarshaler/ipv41",
"/base/text_unmarshaler/ipv42",
}

if !reflect.DeepEqual(names, expectedNames) {
Expand Down Expand Up @@ -130,6 +148,20 @@ func TestProvider_Process(t *testing.T) {
if s.F642 != 42.42 {
t.Errorf("Process() F642 unexpected value: want %f, have %f", 42.42, s.F642)
}
tm, _ := time.Parse(time.RFC3339, "2020-04-14T21:26:00+02:00")
if !s.TU1.Equal(tm) {
t.Errorf("Process() TU1 unexpected value: want %v, have %v", tm, s.TU1)
}
if !s.TU2.Equal(tm) {
t.Errorf("Process() TU2 unexpected value: want %v, have %v", tm, s.TU2)
}
ip := net.ParseIP("127.0.0.1")
if !s.TU3.Equal(ip) {
t.Errorf("Process() TU1 unexpected value: want %v, have %v", ip, s.TU3)
}
if !s.TU4.Equal(ip) {
t.Errorf("Process() TU2 unexpected value: want %v, have %v", ip, s.TU4)
}
if s.Invalid != "" {
t.Errorf("Process() Missing unexpected value: want %q, have %q", "", s.Invalid)
}
Expand Down Expand Up @@ -179,6 +211,24 @@ func TestProvider_Process(t *testing.T) {
client: &mockSSMClient{},
shouldErr: true,
},
{
name: "invalid unmarshal text",
configPath: "/base/",
c: &struct {
TU1 bool `ssm:"/text_unmarshaler/time1" default:"notATime"`
}{},
client: &mockSSMClient{},
shouldErr: true,
},
{
name: "invalid unmarshal text",
configPath: "/base/",
c: &struct {
TU3 bool `ssm:"/text_unmarshaler/ipv41" default:"notAnIP"`
}{},
client: &mockSSMClient{},
shouldErr: true,
},
{
name: "missing required parameter",
configPath: "/base/",
Expand Down