diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..2a5734d --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,25 @@ +name: Go + +on: + push: + branches: [ "master" ] + pull_request: + branches: [ "master" ] + +jobs: + + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: 1.22 + + - name: Build + run: go build -v ./... + + - name: Test + run: go test -v ./... diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 8c2998f..0000000 --- a/.travis.yml +++ /dev/null @@ -1,5 +0,0 @@ -language: go - -go: - - "1.10.2" - - master diff --git a/README.md b/README.md index c657eae..92bc39e 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # go-env -[![Build Status](https://travis-ci.org/Netflix/go-env.svg?branch=master)](https://travis-ci.org/Netflix/go-env) -[![GoDoc](https://godoc.org/github.com/Netflix/go-env?status.svg)](https://godoc.org/github.com/Netflix/go-env) +![Build Status](https://github.com/Netflix/go-env/actions/workflows/build.yml/badge.svg) +[![Go Reference](https://pkg.go.dev/badge/github.com/Netflix/go-env.svg)](https://pkg.go.dev/github.com/Netflix/go-env) [![NetflixOSS Lifecycle](https://img.shields.io/osslifecycle/Netflix/go-expect.svg)]() @@ -14,8 +14,9 @@ package main import ( "log" + "time" - env "github.com/Netflix/go-env" + "github.com/Netflix/go-env" ) type Environment struct { @@ -32,6 +33,11 @@ type Environment struct { } Extras env.EnvSet + + Duration time.Duration `env:"TYPE_DURATION"` + DefaultValue string `env:"MISSING_VAR,default=default_value"` + RequiredValue string `env:"IM_REQUIRED,required=true"` + ArrayValue []string `env:"ARRAY_VALUE,default=value1|value2|value3"` } func main() { @@ -45,7 +51,7 @@ func main() { // ... - es, err = env.Marshal(environment) + es, err = env.Marshal(&environment) if err != nil { log.Fatal(err) } @@ -59,11 +65,79 @@ func main() { es.Apply(cs) environment = Environment{} - err = env.Unmarshal(es, &environment) - if err != nil { + if err = env.Unmarshal(es, &environment); err != nil { log.Fatal(err) } environment.Extras = es } ``` + +This will initially throw an error if `IM_REQUIRED` is not set in the environment as part of the env struct validation. + +This error can be resolved by setting the `IM_REQUIRED` environment variable manually in the environment or by setting it in the +code prior to calling `UnmarshalFromEnviron` with: +```go +os.Setenv("IM_REQUIRED", "some_value") +``` + +## Custom Marshaler/Unmarshaler + +There is limited support for dictating how a field should be marshaled or unmarshaled. The following example +shows how you could marshal/unmarshal from JSON + +```go +package main + +import ( + "encoding/json" + "fmt" + "log" + + "github.com/Netflix/go-env" +) + +type SomeData struct { + SomeField int `json:"someField"` +} + +func (s *SomeData) UnmarshalEnvironmentValue(data string) error { + var tmp SomeData + if err := json.Unmarshal([]byte(data), &tmp); err != nil { + return err + } + *s = tmp + return nil +} + +func (s SomeData) MarshalEnvironmentValue() (string, error) { + bytes, err := json.Marshal(s) + if err != nil { + return "", err + } + return string(bytes), nil +} + +type Config struct { + SomeData *SomeData `env:"SOME_DATA"` +} + +func main() { + var cfg Config + if _, err := env.UnmarshalFromEnviron(&cfg); err != nil { + log.Fatal(err) + } + + if cfg.SomeData != nil && cfg.SomeData.SomeField == 42 { + fmt.Println("Got 42!") + } else { + fmt.Printf("Got nil or some other value: %v\n", cfg.SomeData) + } + + es, err := env.Marshal(&cfg) + if err != nil { + log.Fatal(err) + } + fmt.Printf("Got the following: %+v\n", es) +} +``` diff --git a/env.go b/env.go index 3ddc8c8..dc09f74 100644 --- a/env.go +++ b/env.go @@ -23,6 +23,18 @@ import ( "reflect" "strconv" "strings" + "time" +) + +const ( + // tagKeyDefault is the key used in the struct field tag to specify a default + tagKeyDefault = "default" + // tagKeyRequired is the key used in the struct field tag to specify that the + // field is required + tagKeyRequired = "required" + // tagKeySeparator is the key used in the struct field tag to specify a + // separator for slice fields + tagKeySeparator = "separator" ) var ( @@ -35,8 +47,22 @@ var ( // ErrUnexportedField returned when a field with tag "env" is not exported. ErrUnexportedField = errors.New("field must be exported") + + // unmarshalType is the reflect.Type element of the Unmarshaler interface + unmarshalType = reflect.TypeOf((*Unmarshaler)(nil)).Elem() ) +// ErrMissingRequiredValue returned when a field with required=true contains no value or default +type ErrMissingRequiredValue struct { + // Value is the type of value that is required to provide error context to + // the caller + Value string +} + +func (e ErrMissingRequiredValue) Error() string { + return fmt.Sprintf("value for this field is required [%s]", e.Value) +} + // Unmarshal parses an EnvSet and stores the result in the value pointed to by // v. Fields that are matched in v will be deleted from EnvSet, resulting in // an EnvSet with the remaining environment variables. If v is nil or not a @@ -60,17 +86,13 @@ func Unmarshal(es EnvSet, v interface{}) error { } t := rv.Type() - for i := 0; i < rv.NumField(); i++ { + for i := range rv.NumField() { valueField := rv.Field(i) - switch valueField.Kind() { - case reflect.Struct: + if valueField.Kind() == reflect.Struct { if !valueField.Addr().CanInterface() { continue } - - iface := valueField.Addr().Interface() - err := Unmarshal(es, iface) - if err != nil { + if err := Unmarshal(es, valueField.Addr().Interface()); err != nil { return err } } @@ -85,13 +107,13 @@ func Unmarshal(es EnvSet, v interface{}) error { return ErrUnexportedField } - envKeys := strings.Split(tag, ",") + envTag := parseTag(tag) var ( envValue string - ok bool + ok bool ) - for _, envKey := range envKeys { + for _, envKey := range envTag.Keys { envValue, ok = es[envKey] if ok { break @@ -99,11 +121,16 @@ func Unmarshal(es EnvSet, v interface{}) error { } if !ok { - continue + if envTag.Default != "" { + envValue = envTag.Default + } else if envTag.Required { + return &ErrMissingRequiredValue{Value: envTag.Keys[0]} + } else { + continue + } } - err := set(typeField.Type, valueField, envValue) - if err != nil { + if err := set(typeField.Type, valueField, envValue, envTag.Separator); err != nil { return err } delete(es, tag) @@ -112,12 +139,42 @@ func Unmarshal(es EnvSet, v interface{}) error { return nil } -func set(t reflect.Type, f reflect.Value, value string) error { +func set(t reflect.Type, f reflect.Value, value, sliceSeparator string) error { + // See if the type implements Unmarshaler and use that first, + // otherwise, fallback to the previous logic + var isUnmarshaler bool + isPtr := t.Kind() == reflect.Ptr + if isPtr { + isUnmarshaler = t.Implements(unmarshalType) && f.CanInterface() + } else if f.CanAddr() { + isUnmarshaler = f.Addr().Type().Implements(unmarshalType) && f.Addr().CanInterface() + } + + if isUnmarshaler { + var ptr reflect.Value + if isPtr { + // In the pointer case, we need to create a new element to have an + // address to point to + ptr = reflect.New(t.Elem()) + } else { + // And for scalars, we need the pointer to be able to modify the value + ptr = f.Addr() + } + if u, ok := ptr.Interface().(Unmarshaler); ok { + if err := u.UnmarshalEnvironmentValue(value); err != nil { + return err + } + if isPtr { + f.Set(ptr) + } + return nil + } + } + switch t.Kind() { case reflect.Ptr: ptr := reflect.New(t.Elem()) - err := set(t.Elem(), ptr.Elem(), value) - if err != nil { + if err := set(t.Elem(), ptr.Elem(), value, sliceSeparator); err != nil { return err } f.Set(ptr) @@ -129,12 +186,58 @@ func set(t reflect.Type, f reflect.Value, value string) error { return err } f.SetBool(v) + case reflect.Float32: + v, err := strconv.ParseFloat(value, 32) + if err != nil { + return err + } + f.SetFloat(v) + case reflect.Float64: + v, err := strconv.ParseFloat(value, 64) + if err != nil { + return err + } + f.SetFloat(v) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if t.PkgPath() == "time" && t.Name() == "Duration" { + duration, err := time.ParseDuration(value) + if err != nil { + return err + } + + f.Set(reflect.ValueOf(duration)) + break + } + v, err := strconv.Atoi(value) if err != nil { return err } f.SetInt(int64(v)) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v, err := strconv.ParseUint(value, 10, 64) + if err != nil { + return err + } + f.SetUint(v) + case reflect.Slice: + if sliceSeparator == "" { + sliceSeparator = "|" + } + values := strings.Split(value, sliceSeparator) + switch t.Elem().Kind() { + case reflect.String: + // already []string, just set directly + f.Set(reflect.ValueOf(values)) + default: + dest := reflect.MakeSlice(reflect.SliceOf(t.Elem()), len(values), len(values)) + for i, v := range values { + if err := set(t.Elem(), dest.Index(i), v, sliceSeparator); err != nil { + return err + } + } + f.Set(dest) + } default: return ErrUnsupportedType } @@ -182,16 +285,14 @@ func Marshal(v interface{}) (EnvSet, error) { es := make(EnvSet) t := rv.Type() - for i := 0; i < rv.NumField(); i++ { + for i := range rv.NumField() { valueField := rv.Field(i) - switch valueField.Kind() { - case reflect.Struct: + if valueField.Kind() == reflect.Struct { if !valueField.Addr().CanInterface() { continue } - iface := valueField.Addr().Interface() - nes, err := Marshal(iface) + nes, err := Marshal(valueField.Addr().Interface()) if err != nil { return nil, err } @@ -209,20 +310,78 @@ func Marshal(v interface{}) (EnvSet, error) { envKeys := strings.Split(tag, ",") - var envValue string + var el interface{} if typeField.Type.Kind() == reflect.Ptr { if valueField.IsNil() { continue } - envValue = fmt.Sprintf("%v", valueField.Elem().Interface()) + el = valueField.Elem().Interface() } else { - envValue = fmt.Sprintf("%v", valueField.Interface()) + el = valueField.Interface() + } + + var ( + err error + envValue string + ) + if m, ok := el.(Marshaler); ok { + envValue, err = m.MarshalEnvironmentValue() + if err != nil { + return nil, err + } + } else { + envValue = fmt.Sprintf("%v", el) } for _, envKey := range envKeys { + // Skip keys with '=', as they represent tag options and not environment variable names. + if strings.Contains(envKey, "=") { + switch strings.ToLower(strings.SplitN(envKey, "=", 2)[0]) { + case "separator", "required", "default": + continue + } + } es[envKey] = envValue } } return es, nil } + +// tag is a struct used to store the parsed "env" field tag when unmarshalling. +type tag struct { + // Keys is used to store the keys specified in the "env" field tag + Keys []string + // Default is used to specify a default value for the field + Default string + // Required is used to specify that the field is required + Required bool + // Separator is used to split the value of a slice field + Separator string +} + +// parseTag is used in the Unmarshal function to parse the "env" field tags +// into a tag struct for use in the set function. +func parseTag(tagString string) tag { + var t tag + envKeys := strings.Split(tagString, ",") + for _, key := range envKeys { + if !strings.Contains(key, "=") { + t.Keys = append(t.Keys, key) + continue + } + keyData := strings.SplitN(key, "=", 2) + switch strings.ToLower(keyData[0]) { + case tagKeyDefault: + t.Default = keyData[1] + case tagKeyRequired: + t.Required = strings.ToLower(keyData[1]) == "true" + case tagKeySeparator: + t.Separator = keyData[1] + default: + // just ignoring unsupported keys + continue + } + } + return t +} diff --git a/env_test.go b/env_test.go index 7dfae48..a7ea948 100644 --- a/env_test.go +++ b/env_test.go @@ -4,7 +4,7 @@ // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, @@ -14,7 +14,11 @@ package env import ( + "encoding/base64" + "encoding/json" + "errors" "os" + "reflect" "testing" "time" ) @@ -37,6 +41,9 @@ type ValidStruct struct { // PointerInt should work along with other supported types. PointerInt *int `env:"POINTER_INT"` + // PointerUint should work along with other supported types. + PointerUint *uint `env:"POINTER_UINT"` + // PointerPointerString should be recursed into. PointerPointerString **string `env:"POINTER_POINTER_STRING"` @@ -47,10 +54,31 @@ type ValidStruct struct { Extra string // Additional supported types - Int int `env:"INT"` - Bool bool `env:"BOOL"` + Int int `env:"INT"` + Uint uint `env:"UINT"` + Float32 float32 `env:"FLOAT32"` + Float64 float64 `env:"FLOAT64"` + Bool bool `env:"BOOL"` MultipleTags string `env:"npm_config_cache,NPM_CONFIG_CACHE"` + + MultipleTagsWithDefault string `env:"multiple_tags_with_default,MULTIPLE_TAGS_WITH_DEFAULT,default=default_tags_value"` + + TagWithDefault string `env:"tag_with_default,default=default_tag_value"` + + TagWithRequired string `env:"tag_with_required,required=false"` + + TagWithSeparator string `env:"tag_with_separator,separator=&"` + + // time.Duration is supported + Duration time.Duration `env:"TYPE_DURATION"` + + // Custom unmarshaler should support scalar types + Base64EncodedString Base64EncodedString `env:"BASE64_ENCODED_STRING"` + // Custom unmarshaler should support struct types + JSONData JSONData `env:"JSON_DATA"` + // Custom unmarshaler should support pointer types as well + PointerJSONData *JSONData `env:"POINTER_JSON_DATA"` } type UnsupportedStruct struct { @@ -61,20 +89,96 @@ type UnexportedStruct struct { home string `env:"HOME"` } -func TestUnmarshal(t *testing.T) { - environ := map[string]string{ - "HOME": "/home/test", - "WORKSPACE": "/mnt/builds/slave/workspace/test", - "EXTRA": "extra", - "INT": "1", - "BOOL": "true", - "npm_config_cache": "first", - "NPM_CONFIG_CACHE": "second", +type DefaultValueStruct struct { + DefaultString string `env:"MISSING_STRING,default=found"` + DefaultKeyValueString string `env:"MISSING_KVSTRING,default=key=value"` + DefaultBool bool `env:"MISSING_BOOL,default=true"` + DefaultInt int `env:"MISSING_INT,default=7"` + DefaultUint uint `env:"MISSING_UINT,default=4294967295"` + DefaultFloat32 float32 `env:"MISSING_FLOAT32,default=8.9"` + DefaultFloat64 float64 `env:"MISSING_FLOAT64,default=10.11"` + DefaultDuration time.Duration `env:"MISSING_DURATION,default=5s"` + DefaultStringSlice []string `env:"MISSING_STRING_SLICE,default=separate|values"` + DefaultSliceWithSeparator []string `env:"ANOTHER_MISSING_STRING_SLICE,default=separate&values,separator=&"` + DefaultRequiredSlice []string `env:"MISSING_REQUIRED_DEFAULT,default=other|things,required=true"` + DefaultWithOptionsMissing string `env:"MISSING_1,MISSING_2,default=present"` + DefaultWithOptionsPresent string `env:"MISSING_1,PRESENT,default=present"` +} + +type RequiredValueStruct struct { + Required string `env:"REQUIRED_VAL,required=true"` + RequiredMore string `env:"REQUIRED_VAL_MORE,required=true"` + RequiredWithDefault string `env:"REQUIRED_MISSING,default=myValue,required=true"` + NotRequired string `env:"NOT_REQUIRED,required=false"` + InvalidExtra string `env:"INVALID,invalid=invalid"` +} + +type Base64EncodedString string + +func (b *Base64EncodedString) UnmarshalEnvironmentValue(data string) error { + value, err := base64.StdEncoding.DecodeString(data) + if err != nil { + return err } + *b = Base64EncodedString(value) + return nil +} - var validStruct ValidStruct - err := Unmarshal(environ, &validStruct) +func (b Base64EncodedString) MarshalEnvironmentValue() (string, error) { + return base64.StdEncoding.EncodeToString([]byte(b)), nil +} + +type JSONData struct { + SomeField int `json:"someField"` +} + +func (j *JSONData) UnmarshalEnvironmentValue(data string) error { + var tmp JSONData + if err := json.Unmarshal([]byte(data), &tmp); err != nil { + return err + } + *j = tmp + return nil +} + +func (j JSONData) MarshalEnvironmentValue() (string, error) { + bytes, err := json.Marshal(j) if err != nil { + return "", err + } + return string(bytes), nil +} + +type IterValuesStruct struct { + StringSlice []string `env:"STRING"` + IntSlice []int `env:"INT"` + Int64Slice []int64 `env:"INT64"` + DurationSlice []time.Duration `env:"DURATION"` + BoolSlice []bool `env:"BOOL"` + KVStringSlice []string `env:"KV"` + WithSeparator []int `env:"SEPARATOR,separator=&"` +} + +func TestUnmarshal(t *testing.T) { + t.Parallel() + var ( + environ = map[string]string{ + "HOME": "/home/test", + "WORKSPACE": "/mnt/builds/slave/workspace/test", + "EXTRA": "extra", + "INT": "1", + "UINT": "4294967295", + "FLOAT32": "2.3", + "FLOAT64": "4.5", + "BOOL": "true", + "npm_config_cache": "first", + "NPM_CONFIG_CACHE": "second", + "TYPE_DURATION": "5s", + } + validStruct ValidStruct + ) + + if err := Unmarshal(environ, &validStruct); err != nil { t.Errorf("Expected no error but got '%s'", err) } @@ -98,6 +202,18 @@ func TestUnmarshal(t *testing.T) { t.Errorf("Expected field value to be '%d' but got '%d'", 1, validStruct.Int) } + if validStruct.Uint != 4294967295 { + t.Errorf("Expected field value to be '%d' but got '%d'", 4294967295, validStruct.Uint) + } + + if validStruct.Float32 != 2.3 { + t.Errorf("Expected field value to be '%f' but got '%f'", 2.3, validStruct.Float32) + } + + if validStruct.Float64 != 4.5 { + t.Errorf("Expected field value to be '%f' but got '%f'", 4.5, validStruct.Float64) + } + if validStruct.Bool != true { t.Errorf("Expected field value to be '%t' but got '%t'", true, validStruct.Bool) } @@ -106,13 +222,15 @@ func TestUnmarshal(t *testing.T) { t.Errorf("Expected field value to be '%s' but got '%s'", "first", validStruct.MultipleTags) } - v, ok := environ["HOME"] - if ok { + if validStruct.Duration != 5*time.Second { + t.Errorf("Expected field value to be '%s' but got '%s'", "5s", validStruct.Duration) + } + + if v, ok := environ["HOME"]; ok { t.Errorf("Expected field '%s' to not exist but got '%s'", "HOME", v) } - v, ok = environ["EXTRA"] - if !ok { + if v, ok := environ["EXTRA"]; !ok { t.Errorf("Expected field '%s' to exist but missing", "EXTRA") } else if v != "extra" { t.Errorf("Expected field value to be '%s' but got '%s'", "extra", v) @@ -120,15 +238,18 @@ func TestUnmarshal(t *testing.T) { } func TestUnmarshalPointer(t *testing.T) { - environ := map[string]string{ - "POINTER_STRING": "", - "POINTER_INT": "1", - "POINTER_POINTER_STRING": "", - } + t.Parallel() + var ( + environ = map[string]string{ + "POINTER_STRING": "", + "POINTER_INT": "1", + "POINTER_UINT": "4294967295", + "POINTER_POINTER_STRING": "", + } + validStruct ValidStruct + ) - var validStruct ValidStruct - err := Unmarshal(environ, &validStruct) - if err != nil { + if err := Unmarshal(environ, &validStruct); err != nil { t.Errorf("Expected no error but got '%s'", err) } @@ -144,6 +265,12 @@ func TestUnmarshalPointer(t *testing.T) { t.Errorf("Expected field value to be '%d' but got '%d'", 1, *validStruct.PointerInt) } + if validStruct.PointerUint == nil { + t.Errorf("Expected field value to be '%d' but got '%v'", 4294967295, nil) + } else if *validStruct.PointerUint != 4294967295 { + t.Errorf("Expected field value to be '%d' but got '%d'", 4294967295, *validStruct.PointerUint) + } + if validStruct.PointerPointerString == nil { t.Errorf("Expected field value to be '%s' but got '%v'", "", nil) } else { @@ -159,35 +286,68 @@ func TestUnmarshalPointer(t *testing.T) { } } +func TestCustomUnmarshal(t *testing.T) { + t.Parallel() + var ( + someValue = "some value" + environ = map[string]string{ + "BASE64_ENCODED_STRING": base64.StdEncoding.EncodeToString([]byte(someValue)), + "JSON_DATA": `{ "someField": 42 }`, + "POINTER_JSON_DATA": `{ "someField": 43 }`, + } + validStruct ValidStruct + ) + + if err := Unmarshal(environ, &validStruct); err != nil { + t.Errorf("Expected no error but got '%s'", err) + } + + if validStruct.Base64EncodedString != Base64EncodedString(someValue) { + t.Errorf("Expected field value to be '%s' but got '%s'", someValue, string(validStruct.Base64EncodedString)) + } + + if validStruct.PointerJSONData == nil { + t.Errorf("Expected field value to be '%s' but got '%v'", someValue, nil) + } else if validStruct.PointerJSONData.SomeField != 43 { + t.Errorf("Expected field value to be '%d' but got '%d'", 43, validStruct.PointerJSONData.SomeField) + } + + if validStruct.JSONData.SomeField != 42 { + t.Errorf("Expected field value to be '%d' but got '%d'", 42, validStruct.JSONData.SomeField) + } +} + func TestUnmarshalInvalid(t *testing.T) { - environ := make(map[string]string) + t.Parallel() + var ( + environ = make(map[string]string) + validStruct ValidStruct + ) - var validStruct ValidStruct - err := Unmarshal(environ, validStruct) - if err != ErrInvalidValue { + if err := Unmarshal(environ, validStruct); !errors.Is(err, ErrInvalidValue) { t.Errorf("Expected error 'ErrInvalidValue' but got '%s'", err) } ptr := &validStruct - err = Unmarshal(environ, &ptr) - if err != ErrInvalidValue { + if err := Unmarshal(environ, &ptr); !errors.Is(err, ErrInvalidValue) { t.Errorf("Expected error 'ErrInvalidValue' but got '%s'", err) } } func TestUnmarshalUnsupported(t *testing.T) { - environ := map[string]string{ - "TIMESTAMP": "2016-07-15T12:00:00.000Z", - } + t.Parallel() + var ( + environ = map[string]string{"TIMESTAMP": "2016-07-15T12:00:00.000Z"} + unsupportedStruct UnsupportedStruct + ) - var unsupportedStruct UnsupportedStruct - err := Unmarshal(environ, &unsupportedStruct) - if err != ErrUnsupportedType { + if err := Unmarshal(environ, &unsupportedStruct); !errors.Is(err, ErrUnsupportedType) { t.Errorf("Expected error 'ErrUnsupportedType' but got '%s'", err) } } func TestUnmarshalFromEnviron(t *testing.T) { + t.Parallel() environ := os.Environ() es, err := EnvironToEnvSet(environ) @@ -207,25 +367,134 @@ func TestUnmarshalFromEnviron(t *testing.T) { t.Errorf("Expected environment variable to be '%s' but got '%s'", home, validStruct.Home) } - v, ok := es["HOME"] - if ok { + if v, ok := es["HOME"]; ok { t.Errorf("Expected field '%s' to not exist but got '%s'", "HOME", v) } } func TestUnmarshalUnexported(t *testing.T) { - environ := map[string]string{ - "HOME": "/home/edgarl", - } + t.Parallel() + var ( + environ = map[string]string{"HOME": "/home/edgarl"} + unexportedStruct UnexportedStruct + ) - var unexportedStruct UnexportedStruct - err := Unmarshal(environ, &unexportedStruct) - if err != ErrUnexportedField { + if err := Unmarshal(environ, &unexportedStruct); !errors.Is(err, ErrUnexportedField) { t.Errorf("Expected error 'ErrUnexportedField' but got '%s'", err) } } +func TestUnmarshalSlice(t *testing.T) { + t.Parallel() + var ( + environ = map[string]string{ + "STRING": "separate|values", + "INT": "1|2", + "INT64": "3|4", + "DURATION": "60s|70h", + "BOOL": "true|false", + "KV": "k1=v1|k2=v2", + "SEPARATOR": "1&2", // struct has `separator=&` + } + iterValStruct IterValuesStruct + ) + + if err := Unmarshal(environ, &iterValStruct); err != nil { + t.Errorf("Expected no error but got %v", err) + return + } + + testCases := [][]interface{}{ + {iterValStruct.StringSlice, []string{"separate", "values"}}, + {iterValStruct.IntSlice, []int{1, 2}}, + {iterValStruct.Int64Slice, []int64{3, 4}}, + {iterValStruct.DurationSlice, []time.Duration{time.Second * 60, time.Hour * 70}}, + {iterValStruct.BoolSlice, []bool{true, false}}, + {iterValStruct.KVStringSlice, []string{"k1=v1", "k2=v2"}}, + {iterValStruct.WithSeparator, []int{1, 2}}, + } + for _, testCase := range testCases { + if !reflect.DeepEqual(testCase[0], testCase[1]) { + t.Errorf("Expected field value to be '%v' but got '%v'", testCase[1], testCase[0]) + } + } +} + +func TestUnmarshalDefaultValues(t *testing.T) { + t.Parallel() + var ( + environ = map[string]string{"PRESENT": "youFoundMe"} + defaultValueStruct DefaultValueStruct + ) + + if err := Unmarshal(environ, &defaultValueStruct); err != nil { + t.Errorf("Expected no error but got %s", err) + } + + testCases := [][]interface{}{ + {defaultValueStruct.DefaultInt, 7}, + {defaultValueStruct.DefaultUint, uint(4294967295)}, + {defaultValueStruct.DefaultFloat32, float32(8.9)}, + {defaultValueStruct.DefaultFloat64, 10.11}, + {defaultValueStruct.DefaultBool, true}, + {defaultValueStruct.DefaultString, "found"}, + {defaultValueStruct.DefaultKeyValueString, "key=value"}, + {defaultValueStruct.DefaultDuration, 5 * time.Second}, + {defaultValueStruct.DefaultStringSlice, []string{"separate", "values"}}, + {defaultValueStruct.DefaultSliceWithSeparator, []string{"separate", "values"}}, + {defaultValueStruct.DefaultRequiredSlice, []string{"other", "things"}}, + {defaultValueStruct.DefaultWithOptionsMissing, "present"}, + {defaultValueStruct.DefaultWithOptionsPresent, "youFoundMe"}, + } + for _, testCase := range testCases { + if !reflect.DeepEqual(testCase[0], testCase[1]) { + t.Errorf("Expected field value to be '%v' but got '%v'", testCase[1], testCase[0]) + } + } +} + +func TestUnmarshalRequiredValues(t *testing.T) { + t.Parallel() + var ( + environ = make(map[string]string) + requiredValuesStruct RequiredValueStruct + ) + + // Try missing REQUIRED_VAL and REQUIRED_VAL_MORE + err := Unmarshal(environ, &requiredValuesStruct) + if err == nil { + t.Errorf("Expected error 'ErrMissingRequiredValue' but got '%s'", err) + } + errMissing := ErrMissingRequiredValue{Value: "REQUIRED_VAL"} + if err.Error() != errMissing.Error() { + t.Errorf("Expected error 'ErrMissingRequiredValue' but got '%s'", err) + } + + // Fill REQUIRED_VAL and retry REQUIRED_VAL_MORE + environ["REQUIRED_VAL"] = "required" + err = Unmarshal(environ, &requiredValuesStruct) + if err == nil { + t.Errorf("Expected error 'ErrMissingRequiredValue' but got '%s'", err) + } + errMissing = ErrMissingRequiredValue{Value: "REQUIRED_VAL_MORE"} + if err.Error() != errMissing.Error() { + t.Errorf("Expected error 'ErrMissingRequiredValue' but got '%s'", err) + } + + environ["REQUIRED_VAL_MORE"] = "required" + if err = Unmarshal(environ, &requiredValuesStruct); err != nil { + t.Errorf("Expected no error but got '%s'", err) + } + if requiredValuesStruct.Required != "required" { + t.Errorf("Expected field value to be '%s' but got '%s'", "required", requiredValuesStruct.Required) + } + if requiredValuesStruct.RequiredWithDefault != "myValue" { + t.Errorf("Expected field value to be '%s' but got '%s'", "myValue", requiredValuesStruct.RequiredWithDefault) + } +} + func TestMarshal(t *testing.T) { + t.Parallel() validStruct := ValidStruct{ Home: "/home/test", Jenkins: struct { @@ -234,10 +503,18 @@ func TestMarshal(t *testing.T) { }{ Workspace: "/mnt/builds/slave/workspace/test", }, - Extra: "extra", - Int: 1, - Bool: true, - MultipleTags: "foobar", + Extra: "extra", + Int: 1, + Uint: 4294967295, + Float32: float32(2.3), + Float64: 4.5, + Bool: true, + MultipleTags: "foobar", + MultipleTagsWithDefault: "baz", + TagWithDefault: "bar", + TagWithRequired: "foo", + TagWithSeparator: "val1&val2", + Duration: 3 * time.Minute, } environ, err := Marshal(&validStruct) @@ -261,6 +538,18 @@ func TestMarshal(t *testing.T) { t.Errorf("Expected field value to be '%s' but got '%s'", "1", environ["INT"]) } + if environ["UINT"] != "4294967295" { + t.Errorf("Expected field value to be '%s' but got '%s'", "2", environ["UINT"]) + } + + if environ["FLOAT32"] != "2.3" { + t.Errorf("Expected field value to be '%s' but got '%s'", "2.3", environ["FLOAT32"]) + } + + if environ["FLOAT64"] != "4.5" { + t.Errorf("Expected field value to be '%s' but got '%s'", "4.5", environ["FLOAT64"]) + } + if environ["BOOL"] != "true" { t.Errorf("Expected field value to be '%s' but got '%s'", "true", environ["BOOL"]) } @@ -272,46 +561,115 @@ func TestMarshal(t *testing.T) { if environ["NPM_CONFIG_CACHE"] != "foobar" { t.Errorf("Expected field value to be '%s' but got '%s'", "foobar", environ["NPM_CONFIG_CACHE"]) } + + if environ["multiple_tags_with_default"] != "baz" { + t.Errorf("Expected field value to be '%s' but got '%s'", "baz", environ["multiple_tags_with_default"]) + } + + if environ["default=default_tags_value"] != "" { + t.Errorf("'default=default_tags_value' not expected to be a valid field value.") + } + + if environ["tag_with_default"] != "bar" { + t.Errorf("Expected field value to be '%s' but got '%s'", "bar", environ["tag_with_default"]) + } + + if environ["tag_with_required"] != "foo" { + t.Errorf("Expected field value to be '%s' but got '%s'", "foo", environ["tag_with_required"]) + } + + if environ["tag_with_separator"] != "val1&val2" { + t.Errorf("Expected field value to be '%s' but got '%s'", "val1&val2", environ["tag_with_separator"]) + } + + if environ["required=true"] != "" { + t.Errorf("'required=true' not expected to be a valid field value.") + } + + if environ["separator=&"] != "" { + t.Errorf("'separator=&' not expected to be a valid field value.") + } + + if environ["default=default_tag_value"] != "" { + t.Errorf("'default=default_tag_value' not expected to be a valid field value.") + } + + if environ["TYPE_DURATION"] != "3m0s" { + t.Errorf("Expected field value to be '%s' but got '%s'", "3m0s", environ["TYPE_DURATION"]) + } } func TestMarshalInvalid(t *testing.T) { + t.Parallel() var validStruct ValidStruct - _, err := Marshal(validStruct) - if err != ErrInvalidValue { + if _, err := Marshal(validStruct); !errors.Is(err, ErrInvalidValue) { t.Errorf("Expected error 'ErrInvalidValue' but got '%s'", err) } ptr := &validStruct - _, err = Marshal(&ptr) - if err != ErrInvalidValue { + if _, err := Marshal(&ptr); !errors.Is(err, ErrInvalidValue) { t.Errorf("Expected error 'ErrInvalidValue' but got '%s'", err) } } func TestMarshalPointer(t *testing.T) { - empty := "" - validStruct := ValidStruct{ - PointerString: &empty, - } + t.Parallel() + var ( + empty = "" + validStruct = ValidStruct{PointerString: &empty} + ) + es, err := Marshal(&validStruct) if err != nil { t.Errorf("Expected no error but got '%s'", err) } - v, ok := es["POINTER_STRING"] - if !ok { + if v, ok := es["POINTER_STRING"]; !ok { t.Errorf("Expected field '%s' to exist but missing", "POINTER_STRING") } else if v != "" { t.Errorf("Expected field value to be '%s' but got '%s'", "", v) } - v, ok = es["POINTER_MISSING"] - if ok { + if v, ok := es["POINTER_MISSING"]; ok { t.Errorf("Expected field '%s' to not exist but got '%s'", "POINTER_MISSING", v) } - v, ok = es["JENKINS_POINTER_MISSING"] - if ok { + if v, ok := es["JENKINS_POINTER_MISSING"]; ok { t.Errorf("Expected field '%s' to not exist but got '%s'", "JENKINS_POINTER_MISSING", v) } } + +func TestMarshalCustom(t *testing.T) { + t.Parallel() + var ( + someValue = Base64EncodedString("someValue") + validStruct = ValidStruct{ + Base64EncodedString: someValue, + JSONData: JSONData{SomeField: 42}, + PointerJSONData: &JSONData{SomeField: 43}, + } + ) + + es, err := Marshal(&validStruct) + if err != nil { + t.Errorf("Expected no error but got '%s'", err) + } + + if v, ok := es["BASE64_ENCODED_STRING"]; !ok { + t.Errorf("Expected field '%s' to exist but missing", "BASE64_ENCODED_STRING") + } else if v != base64.StdEncoding.EncodeToString([]byte(someValue)) { + t.Errorf("Expected field value to be '%s' but got '%s'", base64.StdEncoding.EncodeToString([]byte(someValue)), v) + } + + if v, ok := es["JSON_DATA"]; !ok { + t.Errorf("Expected field '%s' to exist but got '%s'", "JSON_DATA", v) + } else if v != `{"someField":42}` { + t.Errorf("Expected field value to be '%s' but got '%s'", `{"someField":42}`, v) + } + + if v, ok := es["POINTER_JSON_DATA"]; !ok { + t.Errorf("Expected field '%s' to exist but got '%s'", "POINTER_JSON_DATA", v) + } else if v != `{"someField":43}` { + t.Errorf("Expected field value to be '%s' but got '%s'", `{"someField":43}`, v) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..49cef6d --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/Netflix/go-env + +go 1.23 diff --git a/marshal.go b/marshal.go new file mode 100644 index 0000000..0e0aed5 --- /dev/null +++ b/marshal.go @@ -0,0 +1,6 @@ +package env + +// Marshaler is the interface implemented by types that can marshal themselves into valid environment variable values. +type Marshaler interface { + MarshalEnvironmentValue() (string, error) +} diff --git a/transform.go b/transform.go index 944e71d..63351f2 100644 --- a/transform.go +++ b/transform.go @@ -19,10 +19,8 @@ import ( "strings" ) -var ( - // ErrInvalidEnviron returned when environ has an incorrect format. - ErrInvalidEnviron = errors.New("items in environ must have format key=value") -) +// ErrInvalidEnviron returned when environ has an incorrect format. +var ErrInvalidEnviron = errors.New("items in environ must have format key=value") // EnvSet represents a set of environment variables. type EnvSet map[string]string @@ -37,10 +35,10 @@ func (e EnvSet) Apply(cs ChangeSet) { if v == nil { // Equivalent to os.Unsetenv delete(e, k) - } else { - // Equivalent to os.Setenv - e[k] = *v + continue } + // Equivalent to os.Setenv + e[k] = *v } } @@ -48,7 +46,10 @@ func (e EnvSet) Apply(cs ChangeSet) { // the corresponding EnvSet. If any item in environ does follow the format, // EnvironToEnvSet returns ErrInvalidEnviron. func EnvironToEnvSet(environ []string) (EnvSet, error) { - m := make(EnvSet) + // We error out the function on the first invalid item found, so we can + // optimistically pre-allocate the EnvSet map with the correct size and + // let the GC clean up in the invalid/exit case alongside the function call. + m := make(EnvSet, len(environ)) for _, v := range environ { parts := strings.SplitN(v, "=", 2) if len(parts) != 2 { @@ -62,7 +63,7 @@ func EnvironToEnvSet(environ []string) (EnvSet, error) { // EnvSetToEnviron transforms a EnvSet into a slice of strings with the format // "key=value". func EnvSetToEnviron(m EnvSet) []string { - var environ []string + environ := make([]string, 0, len(m)) for k, v := range m { environ = append(environ, fmt.Sprintf("%s=%s", k, v)) } diff --git a/transform_test.go b/transform_test.go index 77c0b8b..d523e55 100644 --- a/transform_test.go +++ b/transform_test.go @@ -14,11 +14,13 @@ package env import ( + "errors" "fmt" "testing" ) func TestEnvSetApply(t *testing.T) { + t.Parallel() es := EnvSet{ "HOME": "/home/edgarl", "WORKSPACE": "/mnt/builds/slave/workspace/test", @@ -46,6 +48,7 @@ func TestEnvSetApply(t *testing.T) { } func TestEnvironToEnvSet(t *testing.T) { + t.Parallel() environ := []string{"HOME=/home/edgarl", "WORKSPACE=/mnt/builds/slave/workspace/test"} m, err := EnvironToEnvSet(environ) @@ -63,15 +66,16 @@ func TestEnvironToEnvSet(t *testing.T) { } func TestEnvironToEnvSetInvalid(t *testing.T) { + t.Parallel() environ := []string{"INVALID"} - _, err := EnvironToEnvSet(environ) - if err != ErrInvalidEnviron { + if _, err := EnvironToEnvSet(environ); !errors.Is(err, ErrInvalidEnviron) { t.Errorf("Expected 'ErrInvalidEnviron' but got '%s'", err) } } func TestEnvironToEnvSetSplitN(t *testing.T) { + t.Parallel() environ := []string{"SPLIT=one=two"} m, err := EnvironToEnvSet(environ) @@ -85,6 +89,7 @@ func TestEnvironToEnvSetSplitN(t *testing.T) { } func TestEnvSetToEnviron(t *testing.T) { + t.Parallel() m := EnvSet{ "HOME": "/home/test", "WORKSPACE": "/mnt/builds/slave/workspace/test", diff --git a/unmarshal.go b/unmarshal.go new file mode 100644 index 0000000..be00dcb --- /dev/null +++ b/unmarshal.go @@ -0,0 +1,8 @@ +package env + +// Unmarshaler is the interface implemented by types that can unmarshal an +// environment variable value representation of themselves. The input can be +// assumed to be the raw string value stored in the environment. +type Unmarshaler interface { + UnmarshalEnvironmentValue(data string) error +}