Skip to content

Commit

Permalink
Merge pull request #94 from nyaruka/dbutil_scan
Browse files Browse the repository at this point in the history
Add `dbutil.ScanAllSlice` and `ScanAllMap`
  • Loading branch information
rowanseymour authored Aug 28, 2023
2 parents c616cdf + b5adbb4 commit f5df8d1
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 7 deletions.
37 changes: 34 additions & 3 deletions dbutil/json.go → dbutil/scan.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
package dbutil

import (
"database/sql"
"encoding/json"

"github.com/go-playground/validator/v10"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
)

var validate = validator.New()

// ScanJSON scans a row which is JSON into a destination struct
func ScanJSON(rows *sqlx.Rows, destination any) error {
func ScanJSON(rows *sql.Rows, destination any) error {
var raw json.RawMessage
err := rows.Scan(&raw)
if err != nil {
Expand All @@ -27,7 +27,7 @@ func ScanJSON(rows *sqlx.Rows, destination any) error {
}

// ScanAndValidateJSON scans a row which is JSON into a destination struct and validates it
func ScanAndValidateJSON(rows *sqlx.Rows, destination any) error {
func ScanAndValidateJSON(rows *sql.Rows, destination any) error {
if err := ScanJSON(rows, destination); err != nil {
return err
}
Expand All @@ -39,3 +39,34 @@ func ScanAndValidateJSON(rows *sqlx.Rows, destination any) error {

return nil
}

// ScanAllSlice scans a single value from each single column row into the given slice
func ScanAllSlice[V any](rows *sql.Rows, s []V) ([]V, error) {
defer rows.Close()

var v V

for rows.Next() {
if err := rows.Scan(&v); err != nil {
return nil, err
}
s = append(s, v)
}
return s, rows.Err()
}

// ScanAllMap scans a key and value from each two column row into the given map
func ScanAllMap[K comparable, V any](rows *sql.Rows, m map[K]V) error {
defer rows.Close()

var k K
var v V

for rows.Next() {
if err := rows.Scan(&k, &v); err != nil {
return err
}
m[k] = v
}
return rows.Err()
}
62 changes: 58 additions & 4 deletions dbutil/json_test.go → dbutil/scan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ package dbutil_test

import (
"context"
"database/sql"
"testing"

"github.com/jmoiron/sqlx"
"github.com/nyaruka/gocommon/dbutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand All @@ -29,8 +29,8 @@ func TestScanJSON(t *testing.T) {
Age int `json:"age" validate:"min=0"`
}

queryRows := func(sql string, args ...any) *sqlx.Rows {
rows, err := db.QueryxContext(ctx, sql, args...)
queryRows := func(sql string, args ...any) *sql.Rows {
rows, err := db.QueryContext(ctx, sql, args...)
require.NoError(t, err)
require.True(t, rows.Next())
return rows
Expand Down Expand Up @@ -65,7 +65,7 @@ func TestScanJSON(t *testing.T) {
assert.EqualError(t, err, "error unmarshalling row JSON: json: cannot unmarshal string into Go struct field foo.age of type int")

// error if rows aren't ready to be scanned - e.g. next hasn't been called
rows, err = db.QueryxContext(ctx, `SELECT ROW_TO_JSON(r) FROM (SELECT f.uuid as uuid, f.name AS name FROM foo f WHERE id = 1) r`)
rows, err = db.QueryContext(ctx, `SELECT ROW_TO_JSON(r) FROM (SELECT f.uuid as uuid, f.name AS name FROM foo f WHERE id = 1) r`)
require.NoError(t, err)
err = dbutil.ScanAndValidateJSON(rows, f)
assert.EqualError(t, err, "error scanning row JSON: sql: Scan called without calling Next")
Expand All @@ -85,3 +85,57 @@ func TestScanJSON(t *testing.T) {
assert.Equal(t, "George", f.Name)
assert.Equal(t, -1, f.Age)
}

func TestScanAllSlice(t *testing.T) {
db := getTestDB()

defer func() { db.MustExec(`DROP TABLE foo`) }()

db.MustExec(`CREATE TABLE foo (id serial NOT NULL PRIMARY KEY, name VARCHAR(10))`)
db.MustExec(`INSERT INTO foo (name) VALUES('Ann')`)
db.MustExec(`INSERT INTO foo (name) VALUES('Bob')`)
db.MustExec(`INSERT INTO foo (name) VALUES('Cat')`)

rows, err := db.Query(`SELECT id FROM foo ORDER BY id`)
require.NoError(t, err)

ids := make([]int, 0, 2)
ids, err = dbutil.ScanAllSlice(rows, ids)
require.NoError(t, err)
assert.Equal(t, []int{1, 2, 3}, ids)

rows, err = db.Query(`SELECT name FROM foo ORDER BY id DESC`)
require.NoError(t, err)

names := make([]string, 0, 2)
names, err = dbutil.ScanAllSlice(rows, names)
require.NoError(t, err)
assert.Equal(t, []string{"Cat", "Bob", "Ann"}, names)
}

func TestScanAllMap(t *testing.T) {
db := getTestDB()

defer func() { db.MustExec(`DROP TABLE foo`) }()

db.MustExec(`CREATE TABLE foo (id serial NOT NULL PRIMARY KEY, name VARCHAR(10))`)
db.MustExec(`INSERT INTO foo (name) VALUES('Ann')`)
db.MustExec(`INSERT INTO foo (name) VALUES('Bob')`)
db.MustExec(`INSERT INTO foo (name) VALUES('Cat')`)

rows, err := db.Query(`SELECT id, name FROM foo`)
require.NoError(t, err)

nameByID := make(map[int]string, 2)
err = dbutil.ScanAllMap(rows, nameByID)
require.NoError(t, err)
assert.Equal(t, map[int]string{1: "Ann", 2: "Bob", 3: "Cat"}, nameByID)

rows, err = db.Query(`SELECT name, id FROM foo`)
require.NoError(t, err)

idByName := make(map[string]int, 2)
err = dbutil.ScanAllMap(rows, idByName)
require.NoError(t, err)
assert.Equal(t, map[string]int{"Ann": 1, "Bob": 2, "Cat": 3}, idByName)
}

0 comments on commit f5df8d1

Please sign in to comment.