
Commit

Use any instead of interface{}
rowanseymour committed Aug 28, 2023
1 parent 1ee7448 commit fcbdb06
Showing 10 changed files with 49 additions and 43 deletions.
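For context: since Go 1.18, any is a predeclared alias for interface{} (declared in the builtin package as type any = interface{}), so the two spellings are fully interchangeable and this change does not alter behaviour. A minimal standalone sketch, not from this repository, illustrating the equivalence:

package main

import "fmt"

// describe accepts values of any type; any here is identical to interface{}.
func describe(vals ...any) {
	for _, v := range vals {
		fmt.Printf("%v (%T)\n", v, v)
	}
}

func main() {
	var a interface{} = "hello"
	var b any = 42
	b = a // assignable in both directions with no conversion, since the types are identical
	describe(a, b)
}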
2 changes: 1 addition & 1 deletion dates/date.go
@@ -85,7 +85,7 @@ func (d Date) Value() (driver.Value, error) {
}

// Scan scans from the db value
func (d *Date) Scan(value interface{}) error {
func (d *Date) Scan(value any) error {
*d = ExtractDate(value.(time.Time))
return nil
}
10 changes: 5 additions & 5 deletions dbutil/assertdb/assert.go
@@ -8,7 +8,7 @@ import (
)

// Query creates a new query on which one can assert things
func Query(t *testing.T, db *sqlx.DB, sql string, args ...interface{}) *TestQuery {
func Query(t *testing.T, db *sqlx.DB, sql string, args ...any) *TestQuery {
return &TestQuery{t, db, sql, args}
}

@@ -17,11 +17,11 @@ type TestQuery struct {
t *testing.T
db *sqlx.DB
sql string
args []interface{}
args []any
}

// Returns asserts that the query returns a single value
func (q *TestQuery) Returns(expected interface{}, msgAndArgs ...interface{}) {
func (q *TestQuery) Returns(expected any, msgAndArgs ...any) {
q.t.Helper()

// get a variable of same type to hold actual result
@@ -40,10 +40,10 @@ func (q *TestQuery) Returns(expected interface{}, msgAndArgs ...interface{}) {
}

// Columns asserts that the query returns the given column values
func (q *TestQuery) Columns(expected map[string]interface{}, msgAndArgs ...interface{}) {
func (q *TestQuery) Columns(expected map[string]any, msgAndArgs ...any) {
q.t.Helper()

actual := make(map[string]interface{}, len(expected))
actual := make(map[string]any, len(expected))

err := q.db.QueryRowx(q.sql, q.args...).MapScan(actual)
assert.NoError(q.t, err, msgAndArgs...)
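A hedged usage sketch of the assertdb helpers above (the foo table, values and types are hypothetical, echoing the tests later in this commit; t and db come from a test):

assertdb.Query(t, db, `SELECT count(*) FROM foo WHERE name = 'Jim'`).Returns(1)
assertdb.Query(t, db, `SELECT id, age FROM foo WHERE name = 'Jim'`).Columns(map[string]any{"id": int64(3), "age": int64(54)})

Returns compares a single scanned value, while Columns compares a whole row read via MapScan, so integer columns are expected back as int64.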
6 changes: 3 additions & 3 deletions dbutil/errors.go
@@ -21,7 +21,7 @@ type QueryError struct {
cause error
message string
sql string
sqlArgs []interface{}
sqlArgs []any
}

func (e *QueryError) Error() string {
@@ -32,11 +32,11 @@ func (e *QueryError) Unwrap() error {
return e.cause
}

func (e *QueryError) Query() (string, []interface{}) {
func (e *QueryError) Query() (string, []any) {
return e.sql, e.sqlArgs
}

func NewQueryErrorf(cause error, sql string, sqlArgs []interface{}, message string, msgArgs ...interface{}) error {
func NewQueryErrorf(cause error, sql string, sqlArgs []any, message string, msgArgs ...any) error {
return &QueryError{
cause: cause,
message: fmt.Sprintf(message, msgArgs...),
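For context, a caller can recover the failed SQL and args from a wrapped error via errors.As; a minimal hypothetical sketch (assuming the standard errors and log imports, and that err came from a helper in this package):

var qerr *dbutil.QueryError
if errors.As(err, &qerr) {
	sql, args := qerr.Query()
	log.Printf("query failed: %s (args: %v)", sql, args)
}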
4 changes: 2 additions & 2 deletions dbutil/errors_test.go
@@ -21,7 +21,7 @@ func TestIsUniqueViolation(t *testing.T) {
func TestQueryError(t *testing.T) {
var err error = &pq.Error{Code: pq.ErrorCode("22025"), Message: "unsupported Unicode escape sequence"}

qerr := dbutil.NewQueryErrorf(err, "SELECT * FROM foo WHERE id = $1", []interface{}{234}, "error selecting foo %d", 234)
qerr := dbutil.NewQueryErrorf(err, "SELECT * FROM foo WHERE id = $1", []any{234}, "error selecting foo %d", 234)
assert.Error(t, qerr)
assert.Equal(t, `error selecting foo 234: pq: unsupported Unicode escape sequence`, qerr.Error())

@@ -41,5 +41,5 @@ func TestQueryError(t *testing.T) {

query, params := unwrapped.Query()
assert.Equal(t, "SELECT * FROM foo WHERE id = $1", query)
assert.Equal(t, []interface{}{234}, params)
assert.Equal(t, []any{234}, params)
}
4 changes: 2 additions & 2 deletions dbutil/json.go
@@ -11,7 +11,7 @@ import (
var validate = validator.New()

// ScanJSON scans a row which is JSON into a destination struct
func ScanJSON(rows *sqlx.Rows, destination interface{}) error {
func ScanJSON(rows *sqlx.Rows, destination any) error {
var raw json.RawMessage
err := rows.Scan(&raw)
if err != nil {
@@ -27,7 +27,7 @@ func ScanJSON(rows *sqlx.Rows, destination interface{}) error {
}

// ScanAndValidateJSON scans a row which is JSON into a destination struct and validates it
func ScanAndValidateJSON(rows *sqlx.Rows, destination interface{}) error {
func ScanAndValidateJSON(rows *sqlx.Rows, destination any) error {
if err := ScanJSON(rows, destination); err != nil {
return err
}
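A hedged sketch of how these scanners are typically called (the ROW_TO_JSON query and foo table are hypothetical, mirroring the test file below; ctx and db assumed from surrounding code):

rows, err := db.QueryxContext(ctx, `SELECT ROW_TO_JSON(f) FROM (SELECT name, age FROM foo WHERE id = $1) f`, 1)
if err == nil && rows.Next() {
	f := &struct {
		Name string `json:"name"`
		Age  int    `json:"age" validate:"min=0"`
	}{}
	err = dbutil.ScanAndValidateJSON(rows, f) // unmarshals the JSON row into f, then applies the validate tags
}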
2 changes: 1 addition & 1 deletion dbutil/json_test.go
@@ -29,7 +29,7 @@ func TestScanJSON(t *testing.T) {
Age int `json:"age" validate:"min=0"`
}

queryRows := func(sql string, args ...interface{}) *sqlx.Rows {
queryRows := func(sql string, args ...any) *sqlx.Rows {
rows, err := db.QueryxContext(ctx, sql, args...)
require.NoError(t, err)
require.True(t, rows.Next())
6 changes: 3 additions & 3 deletions dbutil/query.go
@@ -11,7 +11,7 @@ import (
// Queryer is the DB/TX functionality needed for operations in this package
type Queryer interface {
Rebind(query string) string
QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error)
QueryxContext(ctx context.Context, query string, args ...any) (*sqlx.Rows, error)
}

// BulkQuery runs the query as a bulk operation with the given structs
@@ -61,7 +61,7 @@ func BulkQuery[T any](ctx context.Context, tx Queryer, query string, structs []T

// BulkSQL takes a query which uses VALUES with struct bindings and rewrites it as a bulk operation.
// It returns the new SQL query and the args to pass to it.
func BulkSQL[T any](tx Queryer, sql string, structs []T) (string, []interface{}, error) {
func BulkSQL[T any](tx Queryer, sql string, structs []T) (string, []any, error) {
if len(structs) == 0 {
return "", nil, errors.New("can't generate bulk sql with zero structs")
}
@@ -71,7 +71,7 @@ func BulkSQL[T any](tx Queryer, sql string, structs []T) (string, []interface{},
values.Grow(7 * len(structs))

// this will be each of the arguments to match the positional values above
args := make([]interface{}, 0, len(structs)*5)
args := make([]any, 0, len(structs)*5)

// for each value we build a bound SQL statement, then extract the values clause
for i, value := range structs {
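For context, the named bindings in BulkSQL come from sqlx db struct tags; a minimal hypothetical sketch of the rewrite (struct, table and values invented, assuming db is a Postgres-bound *sqlx.DB):

type contact struct {
	ID   int    `db:"id"`
	Name string `db:"name"`
}

query, args, _ := dbutil.BulkSQL(db, `INSERT INTO contact (id, name) VALUES(:id, :name)`, []contact{{1, "Bob"}, {2, "Cathy"}})
fmt.Println(query) // INSERT INTO contact (id, name) VALUES($1, $2),($3, $4)
fmt.Println(args)  // [1 Bob 2 Cathy]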
26 changes: 16 additions & 10 deletions dbutil/query_test.go
@@ -27,30 +27,36 @@ func TestBulkSQL(t *testing.T) {
}

// error if we use a query without a VALUES clause
_, _, err := dbutil.BulkSQL(db, `UPDATE foo SET name = :name WHERE id = :id`, []interface{}{contact{ID: 1, Name: "Bob"}})
_, _, err := dbutil.BulkSQL(db, `UPDATE foo SET name = :name WHERE id = :id`, []any{contact{ID: 1, Name: "Bob"}})
assert.EqualError(t, err, "error extracting VALUES from sql: UPDATE foo SET name = ? WHERE id = ?")

// try with missing parentheses
_, _, err = dbutil.BulkSQL(db, `INSERT INTO foo (id, name) VALUES(:id, :name`, []interface{}{contact{ID: 1, Name: "Bob"}})
_, _, err = dbutil.BulkSQL(db, `INSERT INTO foo (id, name) VALUES(:id, :name`, []any{contact{ID: 1, Name: "Bob"}})
assert.EqualError(t, err, "error extracting VALUES from sql: INSERT INTO foo (id, name) VALUES(?, ?")

sql := `INSERT INTO foo (id, name) VALUES(:id, :name)`

// try with zero structs
_, _, err = dbutil.BulkSQL(db, sql, []interface{}{})
_, _, err = dbutil.BulkSQL(db, sql, []any{})
assert.EqualError(t, err, "can't generate bulk sql with zero structs")

// try with one struct
query, args, err := dbutil.BulkSQL(db, sql, []interface{}{contact{ID: 1, Name: "Bob"}})
query, args, err := dbutil.BulkSQL(db, sql, []any{contact{ID: 1, Name: "Bob"}})
assert.NoError(t, err)
assert.Equal(t, `INSERT INTO foo (id, name) VALUES($1, $2)`, query)
assert.Equal(t, []interface{}{1, "Bob"}, args)
assert.Equal(t, []any{1, "Bob"}, args)

// try with multiple...
query, args, err = dbutil.BulkSQL(db, sql, []interface{}{contact{ID: 1, Name: "Bob"}, contact{ID: 2, Name: "Cathy"}, contact{ID: 3, Name: "George"}})
query, args, err = dbutil.BulkSQL(db, sql, []any{contact{ID: 1, Name: "Bob"}, contact{ID: 2, Name: "Cathy"}, contact{ID: 3, Name: "George"}})
assert.NoError(t, err)
assert.Equal(t, `INSERT INTO foo (id, name) VALUES($1, $2),($3, $4),($5, $6)`, query)
assert.Equal(t, []interface{}{1, "Bob", 2, "Cathy", 3, "George"}, args)
assert.Equal(t, []any{1, "Bob", 2, "Cathy", 3, "George"}, args)

// try with multiple...
query, args, err = dbutil.BulkSQL(db, sql, []any{contact{ID: 1, Name: "Bob"}, contact{ID: 2, Name: "Cathy"}, contact{ID: 3, Name: "George"}})
assert.NoError(t, err)
assert.Equal(t, `INSERT INTO foo (id, name) VALUES($1, $2),($3, $4),($5, $6)`, query)
assert.Equal(t, []any{1, "Bob", 2, "Cathy", 3, "George"}, args)
}

func TestBulkQuery(t *testing.T) {
@@ -73,7 +79,7 @@ func TestBulkQuery(t *testing.T) {
foo2 := &foo{Name: "Jon", Age: 34}

// error if no VALUES clause
err := dbutil.BulkQuery(ctx, db, `INSERT INTO foo (name, age) RETURNING id`, []interface{}{foo1, foo2})
err := dbutil.BulkQuery(ctx, db, `INSERT INTO foo (name, age) RETURNING id`, []any{foo1, foo2})
assert.EqualError(t, err, "error extracting VALUES from sql: INSERT INTO foo (name, age) RETURNING id")

sql := `INSERT INTO foo (name, age) VALUES(:name, :age) RETURNING id`
@@ -93,15 +99,15 @@

// returning ids is optional
foo3 := &foo{Name: "Jim", Age: 54}
err = dbutil.BulkQuery(ctx, db, `INSERT INTO foo (name, age) VALUES(:name, :age)`, []interface{}{foo3})
err = dbutil.BulkQuery(ctx, db, `INSERT INTO foo (name, age) VALUES(:name, :age)`, []any{foo3})
assert.NoError(t, err)
assert.Equal(t, 0, foo3.ID)

assertdb.Query(t, db, `SELECT count(*) FROM foo WHERE name = 'Jim' AND age = 54`).Returns(1)

// try with a struct that is invalid
foo4 := &foo{Name: "Jonny", Age: 34}
err = dbutil.BulkQuery(ctx, db, `INSERT INTO foo (name, age) VALUES(:name, :age)`, []interface{}{foo4})
err = dbutil.BulkQuery(ctx, db, `INSERT INTO foo (name, age) VALUES(:name, :age)`, []any{foo4})
assert.EqualError(t, err, "error making bulk query: pq: value too long for type character varying(3)")
assert.Equal(t, 0, foo4.ID)
}
20 changes: 10 additions & 10 deletions jsonx/json.go
@@ -7,17 +7,17 @@ import (
)

// Marshal marshals the given object to JSON
func Marshal(v interface{}) ([]byte, error) {
func Marshal(v any) ([]byte, error) {
return marshal(v, "")
}

// MarshalPretty marshals the given object to pretty JSON
func MarshalPretty(v interface{}) ([]byte, error) {
func MarshalPretty(v any) ([]byte, error) {
return marshal(v, " ")
}

// MarshalMerged marshals the properties of two objects as one object
func MarshalMerged(v1 interface{}, v2 interface{}) ([]byte, error) {
func MarshalMerged(v1 any, v2 any) ([]byte, error) {
b1, err := marshal(v1, "")
if err != nil {
return nil, err
@@ -32,15 +32,15 @@ func MarshalMerged(v1 interface{}, v2 interface{}) ([]byte, error) {
}

// MustMarshal marshals the given object to JSON, panicking on an error
func MustMarshal(v interface{}) []byte {
func MustMarshal(v any) []byte {
data, err := marshal(v, "")
if err != nil {
panic(err)
}
return data
}

func marshal(v interface{}, indent string) ([]byte, error) {
func marshal(v any, indent string) ([]byte, error) {
buffer := &bytes.Buffer{}
encoder := json.NewEncoder(buffer)
encoder.SetEscapeHTML(false) // see https://github.com/golang/go/issues/8592
@@ -57,7 +57,7 @@ func marshal(v interface{}, indent string) ([]byte, error) {
}

// Unmarshal is just a shortcut for json.Unmarshal so all calls can be made via the jsonx package
func Unmarshal(data json.RawMessage, v interface{}) error {
func Unmarshal(data json.RawMessage, v any) error {
return json.Unmarshal(data, v)
}

@@ -69,7 +69,7 @@ func UnmarshalArray(data json.RawMessage) ([]json.RawMessage, error) {
}

// UnmarshalWithLimit unmarshals a struct with a limit on how many bytes can be read from the given reader
func UnmarshalWithLimit(reader io.ReadCloser, s interface{}, limit int64) error {
func UnmarshalWithLimit(reader io.ReadCloser, s any, limit int64) error {
body, err := io.ReadAll(io.LimitReader(reader, limit))
if err != nil {
return err
@@ -81,15 +81,15 @@ func UnmarshalWithLimit(reader io.ReadCloser, s interface{}, limit int64) error
}

// MustUnmarshal unmarshals the given JSON, panicking on an error
func MustUnmarshal(data json.RawMessage, v interface{}) {
func MustUnmarshal(data json.RawMessage, v any) {
if err := json.Unmarshal(data, v); err != nil {
panic(err)
}
}

// DecodeGeneric decodes the given JSON as a generic map or slice
func DecodeGeneric(data []byte) (interface{}, error) {
var asGeneric interface{}
func DecodeGeneric(data []byte) (any, error) {
var asGeneric any
decoder := json.NewDecoder(bytes.NewBuffer(data))
decoder.UseNumber()
return asGeneric, decoder.Decode(&asGeneric)
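MarshalMerged is not exercised by the tests shown in this commit; a hedged sketch of its intent, with the output shape inferred from the doc comment above (types and field names invented, assuming both values marshal to JSON objects):

type header struct {
	Type string `json:"type"`
}
type body struct {
	Text string `json:"text"`
}

data, _ := jsonx.MarshalMerged(header{Type: "msg"}, body{Text: "hi"})
fmt.Println(string(data)) // expected: {"type":"msg","text":"hi"}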
12 changes: 6 additions & 6 deletions jsonx/json_test.go
@@ -116,19 +116,19 @@ func TestDecodeGeneric(t *testing.T) {
vals, err := jsonx.DecodeGeneric(data)
assert.NoError(t, err)

asMap := vals.(map[string]interface{})
asMap := vals.(map[string]any)
assert.Equal(t, true, asMap["bool"])
assert.Equal(t, json.Number("123.34"), asMap["number"])
assert.Equal(t, "hello", asMap["text"])
assert.Equal(t, map[string]interface{}{"foo": "bar"}, asMap["object"])
assert.Equal(t, []interface{}{json.Number("1"), "x"}, asMap["array"])
assert.Equal(t, map[string]any{"foo": "bar"}, asMap["object"])
assert.Equal(t, []any{json.Number("1"), "x"}, asMap["array"])

// parse a JSON array into a slice
data = []byte(`[{"foo": 123}, {"foo": 456}]`)
vals, err = jsonx.DecodeGeneric(data)
assert.NoError(t, err)

asSlice := vals.([]interface{})
assert.Equal(t, map[string]interface{}{"foo": json.Number("123")}, asSlice[0])
assert.Equal(t, map[string]interface{}{"foo": json.Number("456")}, asSlice[1])
asSlice := vals.([]any)
assert.Equal(t, map[string]any{"foo": json.Number("123")}, asSlice[0])
assert.Equal(t, map[string]any{"foo": json.Number("456")}, asSlice[1])
}
