From 7b52a637cf7b783b41c5806592355c3cea80b1ed Mon Sep 17 00:00:00 2001 From: tanyinloo Date: Fri, 16 Feb 2024 12:56:18 +0800 Subject: [PATCH] refactor --- go.mod | 2 +- sql.go | 202 ++++++++++++++++++++++++++++------------------------ sql_test.go | 4 +- 3 files changed, 112 insertions(+), 96 deletions(-) diff --git a/go.mod b/go.mod index 0938562..0a03008 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/jordan-bonecutter/goption +module github.com/olachat/goption go 1.19 diff --git a/sql.go b/sql.go index f190f15..ace5aa4 100644 --- a/sql.go +++ b/sql.go @@ -1,6 +1,7 @@ package goption import ( + "bytes" "database/sql" "database/sql/driver" "errors" @@ -13,100 +14,83 @@ import ( // Scan implements sql.Scanner for Options func (o *Option[T]) Scan(src any) error { if src == nil { - *o = None[T]() - return nil - } - - // Try scanning - var maybeScanner any = &o.t - if scanner, isScanner := maybeScanner.(sql.Scanner); isScanner { - o.ok = true - return scanner.Scan(src) - } - - // Try reflecting - srcVal := reflect.ValueOf(src) - tType := reflect.TypeOf(o.t) - if srcVal.CanConvert(tType) { - reflect.ValueOf(&o.t).Elem().Set(srcVal.Convert(tType)) - o.ok = true + o.ok, o.t = false, *new(T) return nil } + o.ok = true return convertAssign(&o.t, src) } -type errNotAScanner struct{} - -func (errNotAScanner) Error() string { - return "Not a scanner" -} - -var ErrNotAScanner errNotAScanner - -func (o Option[T]) Value() (driver.Value, error) { - if !o.ok { - return nil, nil - } - - var maybeValuer any = o.t - if valuer, isValuer := maybeValuer.(driver.Valuer); isValuer { - return valuer.Value() - } - - tVal := reflect.ValueOf(o.t) - switch tVal.Kind() { +func convertValue(v any) (any, error) { + rv := reflect.ValueOf(v) + switch rv.Kind() { + case reflect.Pointer: + // indirect pointers + if rv.IsNil() { + return nil, nil + } else { + return driver.DefaultParameterConverter.ConvertValue(rv.Elem().Interface()) + } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return tVal.Int(), nil + return rv.Int(), nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: - return int64(tVal.Uint()), nil + return int64(rv.Uint()), nil case reflect.Uint64: - u64 := tVal.Uint() + u64 := rv.Uint() if u64 >= 1<<63 { return nil, fmt.Errorf("uint64 values with high bit set are not supported") } return int64(u64), nil case reflect.Float32, reflect.Float64: - return tVal.Float(), nil + return rv.Float(), nil case reflect.Bool: - return tVal.Bool(), nil + return rv.Bool(), nil case reflect.Slice: - ek := tVal.Type().Elem().Kind() + ek := rv.Type().Elem().Kind() if ek == reflect.Uint8 { - return tVal.Bytes(), nil + return rv.Bytes(), nil } + return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek) case reflect.String: - return tVal.String(), nil - } - - int64Type := reflect.TypeOf(int64(0)) - if tVal.CanConvert(int64Type) { - return tVal.Convert(int64Type).Interface(), nil - } - f64Type := reflect.TypeOf(float64(0)) - if tVal.CanConvert(f64Type) { - return tVal.Convert(f64Type).Interface(), nil - } - boolType := reflect.TypeOf(false) - if tVal.CanConvert(boolType) { - return tVal.Convert(boolType).Interface(), nil + return rv.String(), nil } - bytesType := reflect.TypeOf([]byte(nil)) - if tVal.CanConvert(bytesType) { - return tVal.Convert(bytesType).Interface(), nil + if val, isTime := v.(time.Time); isTime { + return val, nil } - stringType := reflect.TypeOf("") - if tVal.CanConvert(stringType) { - return tVal.Convert(stringType).Interface(), nil + return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind()) +} + +func (o Option[T]) Value() (driver.Value, error) { + if !o.ok { + return nil, nil } - timeType := reflect.TypeOf(time.Time{}) - if tVal.CanConvert(timeType) { - return tVal.Convert(timeType).Interface(), nil + + var maybeValuer any = o.t + if valuer, isValuer := maybeValuer.(driver.Valuer); isValuer { + return valuer.Value() } - return o.t, nil + return convertValue(o.t) +} + +var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error + +type decimalDecompose interface { + // Decompose returns the internal decimal state in parts. + // If the provided buf has sufficient capacity, buf may be returned as the coefficient with + // the value set and length set as appropriate. + Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32) +} + +type decimalCompose interface { + // Compose sets the internal decimal value from parts. If the value cannot be + // represented then an error should be returned. + Compose(form byte, negative bool, coefficient []byte, exponent int32) error } +type RawBytes []byte + // convertAssign copies to dest the value in src, converting it if possible. // An error is returned if the copy would result in loss of information. // dest should be a pointer type. If rows is passed in, the rows will @@ -119,36 +103,48 @@ func convertAssign(dest, src any) error { switch d := dest.(type) { case *string: if d == nil { - return ErrNotAScanner + return errNilPtr } *d = s return nil case *[]byte: if d == nil { - return ErrNotAScanner + return errNilPtr } *d = []byte(s) return nil + case *RawBytes: + if d == nil { + return errNilPtr + } + *d = append((*d)[:0], s...) + return nil } case []byte: switch d := dest.(type) { case *string: if d == nil { - return ErrNotAScanner + return errNilPtr } *d = string(s) return nil case *any: if d == nil { - return ErrNotAScanner + return errNilPtr } - *d = cloneBytes(s) + *d = bytes.Clone(s) return nil case *[]byte: if d == nil { - return ErrNotAScanner + return errNilPtr } - *d = cloneBytes(s) + *d = bytes.Clone(s) + return nil + case *RawBytes: + if d == nil { + return errNilPtr + } + *d = s return nil } case time.Time: @@ -161,22 +157,39 @@ func convertAssign(dest, src any) error { return nil case *[]byte: if d == nil { - return ErrNotAScanner + return errNilPtr } *d = []byte(s.Format(time.RFC3339Nano)) return nil + case *RawBytes: + if d == nil { + return errNilPtr + } + *d = s.AppendFormat((*d)[:0], time.RFC3339Nano) + return nil + } + case decimalDecompose: + switch d := dest.(type) { + case decimalCompose: + return d.Compose(s.Decompose(nil)) } case nil: switch d := dest.(type) { case *any: if d == nil { - return ErrNotAScanner + return errNilPtr } *d = nil return nil case *[]byte: if d == nil { - return ErrNotAScanner + return errNilPtr + } + *d = nil + return nil + case *RawBytes: + if d == nil { + return errNilPtr } *d = nil return nil @@ -202,6 +215,12 @@ func convertAssign(dest, src any) error { *d = b return nil } + case *RawBytes: + sv = reflect.ValueOf(src) + if b, ok := asBytes([]byte(*d)[:0], sv); ok { + *d = RawBytes(b) + return nil + } case *bool: bv, err := driver.Bool.ConvertValue(src) if err == nil { @@ -213,12 +232,16 @@ func convertAssign(dest, src any) error { return nil } + if scanner, ok := dest.(sql.Scanner); ok { + return scanner.Scan(src) + } + dpv := reflect.ValueOf(dest) if dpv.Kind() != reflect.Pointer { return errors.New("destination not a pointer") } if dpv.IsNil() { - return ErrNotAScanner + return errNilPtr } if !sv.IsValid() { @@ -229,7 +252,7 @@ func convertAssign(dest, src any) error { if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) { switch b := src.(type) { case []byte: - dv.Set(reflect.ValueOf(cloneBytes(b))) + dv.Set(reflect.ValueOf(bytes.Clone(b))) default: dv.Set(sv) } @@ -249,7 +272,7 @@ func convertAssign(dest, src any) error { switch dv.Kind() { case reflect.Pointer: if src == nil { - dv.Set(reflect.Zero(dv.Type())) + dv.SetZero() return nil } dv.Set(reflect.New(dv.Type().Elem())) @@ -307,13 +330,11 @@ func convertAssign(dest, src any) error { return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest) } -func cloneBytes(b []byte) []byte { - if b == nil { - return nil +func strconvErr(err error) error { + if ne, ok := err.(*strconv.NumError); ok { + return ne.Err } - c := make([]byte, len(b)) - copy(c, b) - return c + return err } func asString(src any) string { @@ -338,6 +359,7 @@ func asString(src any) string { } return fmt.Sprintf("%v", src) } + func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { switch rv.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: @@ -356,9 +378,3 @@ func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { } return } -func strconvErr(err error) error { - if ne, ok := err.(*strconv.NumError); ok { - return ne.Err - } - return err -} diff --git a/sql_test.go b/sql_test.go index 5136e7c..8364a43 100644 --- a/sql_test.go +++ b/sql_test.go @@ -21,7 +21,7 @@ func TestSQLScanner(t *testing.T) { } if _, err := db.Exec(`CREATE SCHEMA test;`); err != nil { - t.Fatalf("Failed creating test schema: %s", err.Error()) + t.Logf("Failed creating test schema: %s", err.Error()) } if _, err := db.Exec(`CREATE TABLE test( @@ -29,7 +29,7 @@ func TestSQLScanner(t *testing.T) { maybe_empty integer, ts timestamptz );`); err != nil { - t.Fatalf("Failed creating test table: %s", err.Error()) + t.Logf("Failed creating test table: %s", err.Error()) } var testImplementsValuer any = Some[int](0)