diff --git a/interpolate.go b/interpolate.go index 372164f..00be848 100644 --- a/interpolate.go +++ b/interpolate.go @@ -94,6 +94,20 @@ var ( typeTime = reflect.TypeOf(time.Time{}) ) +func isInterfaceInvalidOrNil(value interface{}) bool { + v := reflect.ValueOf(value) + if !v.IsValid() { + return true + } + + switch v.Kind() { + case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice: + return v.IsNil() + default: + return false + } +} + func (i *interpolator) encodePlaceholder(value interface{}, topLevel bool) error { if builder, ok := value.(Builder); ok { pbuf := NewBuffer() @@ -120,9 +134,13 @@ func (i *interpolator) encodePlaceholder(value interface{}, topLevel bool) error } if valuer, ok := value.(driver.Valuer); ok { - // get driver.Valuer's data var err error - value, err = valuer.Value() + if isInterfaceInvalidOrNil(valuer) { + value = nil + } else { + // get driver.Valuer's data + value, err = valuer.Value() + } if err != nil { return err } diff --git a/interpolate_test.go b/interpolate_test.go index bb45aed..ca1ad0e 100644 --- a/interpolate_test.go +++ b/interpolate_test.go @@ -1,6 +1,7 @@ package dbr import ( + "database/sql/driver" "strings" "testing" "time" @@ -9,6 +10,13 @@ import ( "github.com/stretchr/testify/require" ) +type TestValuer struct { +} + +func (m TestValuer) Value() (driver.Value, error) { + return nil, nil +} + func TestInterpolateIgnoreBinary(t *testing.T) { for _, test := range []struct { query string @@ -22,6 +30,12 @@ func TestInterpolateIgnoreBinary(t *testing.T) { wantQuery: "1", wantValue: nil, }, + { + query: "?", + value: []interface{}{(*TestValuer)(nil)}, + wantQuery: "NULL", + wantValue: nil, + }, { query: "?", value: []interface{}{[]byte{1, 2, 3}},