Skip to content

Commit

Permalink
support types that implement encoding.TextUnmarshaler
Browse files Browse the repository at this point in the history
  • Loading branch information
profclems committed Oct 12, 2024
1 parent edb8dfb commit 159a6c9
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 20 deletions.
4 changes: 2 additions & 2 deletions bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"github.com/profclems/go-dotenv"
)

func BenchmarkDotenv_LoadConfig(b *testing.B) {
func BenchmarkDotenv_Load(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
config := dotenv.New()
Expand Down Expand Up @@ -41,7 +41,7 @@ func BenchmarkDotenv_Init_GetSet(b *testing.B) {
})
}

func BenchmarkDotenv_LoadConfig_GetSet(b *testing.B) {
func BenchmarkDotenv_Load_GetSet(b *testing.B) {
dotenv.SetConfigFile("fixtures/large.env")
err := dotenv.Load()
if err != nil {
Expand Down
57 changes: 39 additions & 18 deletions dotenv.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package dotenv

import (
"bytes"
"encoding"
"fmt"
"os"
"path/filepath"
Expand Down Expand Up @@ -175,12 +176,22 @@ func (e *DotEnv) SetConfigFile(configFile string) {
}

// UnMarshal unmarshals the config file into a struct.
// Recognizes the following struct tags:
// - env:"KEY" to specify the key name to look up in the config file
// - default:"value" to specify a default value if the key is not found
func UnMarshal(v any) error {
return d.Unmarshal(v)
}

func (e *DotEnv) Unmarshal(v any) error {
val := reflect.ValueOf(v).Elem()
func (e *DotEnv) Unmarshal(v any) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("%v", r)
}
}()

vPtr := reflect.ValueOf(v)
val := vPtr.Elem()

if vk := val.Kind(); vk != reflect.Struct {
return fmt.Errorf("expected a struct, got %T", vk)
Expand All @@ -191,28 +202,38 @@ func (e *DotEnv) Unmarshal(v any) error {
field := typ.Field(i)
fieldVal := val.Field(i)

if field.Type.Kind() == reflect.Struct {
if err := e.Unmarshal(fieldVal.Addr().Interface()); err != nil {
return err
getConfigVal := func() string {
tag := field.Tag.Get("env")
if tag != "" {
if envVal := e.GetString(tag); envVal != "" {
return envVal
}
}
continue
// set default value
if def := field.Tag.Get("default"); def != "" {
return def
}
return ""
}

tag := field.Tag.Get("env")
var configVal string
if tag != "" {
if envVal := e.GetString(tag); envVal != "" {
configVal = envVal
if fieldVal.CanAddr() {
if m, ok := fieldVal.Addr().Interface().(encoding.TextUnmarshaler); ok {
if err := m.UnmarshalText([]byte(getConfigVal())); err != nil {
return err
}
continue
}
}
if configVal == "" {
// set default value
if def := field.Tag.Get("default"); def != "" {
configVal = def
} else {
continue

if field.Type.Kind() == reflect.Struct {
if err := e.Unmarshal(fieldVal.Addr().Interface()); err != nil {
return err
}
continue
}

configVal := getConfigVal()

// set the value based on the field type
switch field.Type {
case reflect.TypeOf(time.Time{}):
Expand Down Expand Up @@ -241,7 +262,7 @@ func (e *DotEnv) Unmarshal(v any) error {
}
}

return nil
return err
}

// Get can retrieve any value given the key to use.
Expand Down
35 changes: 35 additions & 0 deletions dotenv_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package dotenv_test

import (
"encoding"
"log"
"testing"
"time"

Expand Down Expand Up @@ -169,3 +171,36 @@ func TestUnMarshal(t *testing.T) {

require.Equal(t, expectedConfig, config)
}

type customDuration struct {
value time.Duration
}

// check that it implements encoding.TextUnmarshaler
var _ encoding.TextUnmarshaler = (*customDuration)(nil)

func (c *customDuration) UnmarshalText(text []byte) error {
d, err := time.ParseDuration(string(text))
if err != nil {
return err
}
log.Println(d)
c.value = d
return nil
}

func TestUnMarshal_fieldWithTextUnmarshaler(t *testing.T) {
type config struct {
Interval customDuration `env:"INTERVAL" default:"15m"`
}

expectedConfig := config{
Interval: customDuration{value: 15 * time.Minute},
}
cfg := config{}

dotenv := dotenv.New()
err := dotenv.Unmarshal(&cfg)
require.NoError(t, err)
require.Equal(t, expectedConfig, cfg)
}

0 comments on commit 159a6c9

Please sign in to comment.