Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
yinloo-ola committed Feb 16, 2024
1 parent a8df0f2 commit 7b52a63
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 96 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module github.com/jordan-bonecutter/goption
module github.com/olachat/goption

go 1.19

Expand Down
202 changes: 109 additions & 93 deletions sql.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package goption

import (
"bytes"
"database/sql"
"database/sql/driver"
"errors"
Expand All @@ -13,100 +14,83 @@ import (
// Scan implements sql.Scanner for Options
func (o *Option[T]) Scan(src any) error {
if src == nil {
*o = None[T]()
return nil
}

// Try scanning
var maybeScanner any = &o.t
if scanner, isScanner := maybeScanner.(sql.Scanner); isScanner {
o.ok = true
return scanner.Scan(src)
}

// Try reflecting
srcVal := reflect.ValueOf(src)
tType := reflect.TypeOf(o.t)
if srcVal.CanConvert(tType) {
reflect.ValueOf(&o.t).Elem().Set(srcVal.Convert(tType))
o.ok = true
o.ok, o.t = false, *new(T)
return nil
}

o.ok = true
return convertAssign(&o.t, src)
}

type errNotAScanner struct{}

func (errNotAScanner) Error() string {
return "Not a scanner"
}

var ErrNotAScanner errNotAScanner

func (o Option[T]) Value() (driver.Value, error) {
if !o.ok {
return nil, nil
}

var maybeValuer any = o.t
if valuer, isValuer := maybeValuer.(driver.Valuer); isValuer {
return valuer.Value()
}

tVal := reflect.ValueOf(o.t)
switch tVal.Kind() {
func convertValue(v any) (any, error) {
rv := reflect.ValueOf(v)
switch rv.Kind() {
case reflect.Pointer:
// indirect pointers
if rv.IsNil() {
return nil, nil
} else {
return driver.DefaultParameterConverter.ConvertValue(rv.Elem().Interface())
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return tVal.Int(), nil
return rv.Int(), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
return int64(tVal.Uint()), nil
return int64(rv.Uint()), nil
case reflect.Uint64:
u64 := tVal.Uint()
u64 := rv.Uint()
if u64 >= 1<<63 {
return nil, fmt.Errorf("uint64 values with high bit set are not supported")
}
return int64(u64), nil
case reflect.Float32, reflect.Float64:
return tVal.Float(), nil
return rv.Float(), nil
case reflect.Bool:
return tVal.Bool(), nil
return rv.Bool(), nil
case reflect.Slice:
ek := tVal.Type().Elem().Kind()
ek := rv.Type().Elem().Kind()
if ek == reflect.Uint8 {
return tVal.Bytes(), nil
return rv.Bytes(), nil
}
return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek)
case reflect.String:
return tVal.String(), nil
}

int64Type := reflect.TypeOf(int64(0))
if tVal.CanConvert(int64Type) {
return tVal.Convert(int64Type).Interface(), nil
}
f64Type := reflect.TypeOf(float64(0))
if tVal.CanConvert(f64Type) {
return tVal.Convert(f64Type).Interface(), nil
}
boolType := reflect.TypeOf(false)
if tVal.CanConvert(boolType) {
return tVal.Convert(boolType).Interface(), nil
return rv.String(), nil
}
bytesType := reflect.TypeOf([]byte(nil))
if tVal.CanConvert(bytesType) {
return tVal.Convert(bytesType).Interface(), nil
if val, isTime := v.(time.Time); isTime {
return val, nil
}
stringType := reflect.TypeOf("")
if tVal.CanConvert(stringType) {
return tVal.Convert(stringType).Interface(), nil
return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
}

func (o Option[T]) Value() (driver.Value, error) {
if !o.ok {
return nil, nil
}
timeType := reflect.TypeOf(time.Time{})
if tVal.CanConvert(timeType) {
return tVal.Convert(timeType).Interface(), nil

var maybeValuer any = o.t
if valuer, isValuer := maybeValuer.(driver.Valuer); isValuer {
return valuer.Value()
}

return o.t, nil
return convertValue(o.t)
}

var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error

type decimalDecompose interface {
// Decompose returns the internal decimal state in parts.
// If the provided buf has sufficient capacity, buf may be returned as the coefficient with
// the value set and length set as appropriate.
Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32)
}

type decimalCompose interface {
// Compose sets the internal decimal value from parts. If the value cannot be
// represented then an error should be returned.
Compose(form byte, negative bool, coefficient []byte, exponent int32) error
}

type RawBytes []byte

// convertAssign copies to dest the value in src, converting it if possible.
// An error is returned if the copy would result in loss of information.
// dest should be a pointer type. If rows is passed in, the rows will
Expand All @@ -119,36 +103,48 @@ func convertAssign(dest, src any) error {
switch d := dest.(type) {
case *string:
if d == nil {
return ErrNotAScanner
return errNilPtr
}
*d = s
return nil
case *[]byte:
if d == nil {
return ErrNotAScanner
return errNilPtr
}
*d = []byte(s)
return nil
case *RawBytes:
if d == nil {
return errNilPtr
}
*d = append((*d)[:0], s...)
return nil
}
case []byte:
switch d := dest.(type) {
case *string:
if d == nil {
return ErrNotAScanner
return errNilPtr
}
*d = string(s)
return nil
case *any:
if d == nil {
return ErrNotAScanner
return errNilPtr
}
*d = cloneBytes(s)
*d = bytes.Clone(s)
return nil
case *[]byte:
if d == nil {
return ErrNotAScanner
return errNilPtr
}
*d = cloneBytes(s)
*d = bytes.Clone(s)
return nil
case *RawBytes:
if d == nil {
return errNilPtr
}
*d = s
return nil
}
case time.Time:
Expand All @@ -161,22 +157,39 @@ func convertAssign(dest, src any) error {
return nil
case *[]byte:
if d == nil {
return ErrNotAScanner
return errNilPtr
}
*d = []byte(s.Format(time.RFC3339Nano))
return nil
case *RawBytes:
if d == nil {
return errNilPtr
}
*d = s.AppendFormat((*d)[:0], time.RFC3339Nano)
return nil
}
case decimalDecompose:
switch d := dest.(type) {
case decimalCompose:
return d.Compose(s.Decompose(nil))
}
case nil:
switch d := dest.(type) {
case *any:
if d == nil {
return ErrNotAScanner
return errNilPtr
}
*d = nil
return nil
case *[]byte:
if d == nil {
return ErrNotAScanner
return errNilPtr
}
*d = nil
return nil
case *RawBytes:
if d == nil {
return errNilPtr
}
*d = nil
return nil
Expand All @@ -202,6 +215,12 @@ func convertAssign(dest, src any) error {
*d = b
return nil
}
case *RawBytes:
sv = reflect.ValueOf(src)
if b, ok := asBytes([]byte(*d)[:0], sv); ok {
*d = RawBytes(b)
return nil
}
case *bool:
bv, err := driver.Bool.ConvertValue(src)
if err == nil {
Expand All @@ -213,12 +232,16 @@ func convertAssign(dest, src any) error {
return nil
}

if scanner, ok := dest.(sql.Scanner); ok {
return scanner.Scan(src)
}

dpv := reflect.ValueOf(dest)
if dpv.Kind() != reflect.Pointer {
return errors.New("destination not a pointer")
}
if dpv.IsNil() {
return ErrNotAScanner
return errNilPtr
}

if !sv.IsValid() {
Expand All @@ -229,7 +252,7 @@ func convertAssign(dest, src any) error {
if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) {
switch b := src.(type) {
case []byte:
dv.Set(reflect.ValueOf(cloneBytes(b)))
dv.Set(reflect.ValueOf(bytes.Clone(b)))
default:
dv.Set(sv)
}
Expand All @@ -249,7 +272,7 @@ func convertAssign(dest, src any) error {
switch dv.Kind() {
case reflect.Pointer:
if src == nil {
dv.Set(reflect.Zero(dv.Type()))
dv.SetZero()
return nil
}
dv.Set(reflect.New(dv.Type().Elem()))
Expand Down Expand Up @@ -307,13 +330,11 @@ func convertAssign(dest, src any) error {
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest)
}

func cloneBytes(b []byte) []byte {
if b == nil {
return nil
func strconvErr(err error) error {
if ne, ok := err.(*strconv.NumError); ok {
return ne.Err
}
c := make([]byte, len(b))
copy(c, b)
return c
return err
}

func asString(src any) string {
Expand All @@ -338,6 +359,7 @@ func asString(src any) string {
}
return fmt.Sprintf("%v", src)
}

func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) {
switch rv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
Expand All @@ -356,9 +378,3 @@ func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) {
}
return
}
func strconvErr(err error) error {
if ne, ok := err.(*strconv.NumError); ok {
return ne.Err
}
return err
}
4 changes: 2 additions & 2 deletions sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ func TestSQLScanner(t *testing.T) {
}

if _, err := db.Exec(`CREATE SCHEMA test;`); err != nil {
t.Fatalf("Failed creating test schema: %s", err.Error())
t.Logf("Failed creating test schema: %s", err.Error())
}

if _, err := db.Exec(`CREATE TABLE test(
key integer not null,
maybe_empty integer,
ts timestamptz
);`); err != nil {
t.Fatalf("Failed creating test table: %s", err.Error())
t.Logf("Failed creating test table: %s", err.Error())
}

var testImplementsValuer any = Some[int](0)
Expand Down

0 comments on commit 7b52a63

Please sign in to comment.