Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dbutil.ScanAllSlice and ScanAllMap #94

Merged
merged 3 commits into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions dbutil/json.go → dbutil/scan.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
package dbutil

import (
"database/sql"
"encoding/json"

"github.com/go-playground/validator/v10"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
)

var validate = validator.New()

// ScanJSON scans a row which is JSON into a destination struct
func ScanJSON(rows *sqlx.Rows, destination any) error {
func ScanJSON(rows *sql.Rows, destination any) error {
var raw json.RawMessage
err := rows.Scan(&raw)
if err != nil {
Expand All @@ -27,7 +27,7 @@
}

// ScanAndValidateJSON scans a row which is JSON into a destination struct and validates it
func ScanAndValidateJSON(rows *sqlx.Rows, destination any) error {
func ScanAndValidateJSON(rows *sql.Rows, destination any) error {
if err := ScanJSON(rows, destination); err != nil {
return err
}
Expand All @@ -39,3 +39,34 @@

return nil
}

// ScanAllSlice scans a single value from each single column row into the given slice
func ScanAllSlice[V any](rows *sql.Rows, s []V) ([]V, error) {
defer rows.Close()

var v V

for rows.Next() {
if err := rows.Scan(&v); err != nil {
return nil, err
}

Check warning on line 52 in dbutil/scan.go

View check run for this annotation

Codecov / codecov/patch

dbutil/scan.go#L51-L52

Added lines #L51 - L52 were not covered by tests
s = append(s, v)
}
return s, rows.Err()
}

// ScanAllMap scans a key and value from each two column row into the given map
func ScanAllMap[K comparable, V any](rows *sql.Rows, m map[K]V) error {
defer rows.Close()

var k K
var v V

for rows.Next() {
if err := rows.Scan(&k, &v); err != nil {
return err
}

Check warning on line 68 in dbutil/scan.go

View check run for this annotation

Codecov / codecov/patch

dbutil/scan.go#L67-L68

Added lines #L67 - L68 were not covered by tests
m[k] = v
}
return rows.Err()
}
62 changes: 58 additions & 4 deletions dbutil/json_test.go → dbutil/scan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ package dbutil_test

import (
"context"
"database/sql"
"testing"

"github.com/jmoiron/sqlx"
"github.com/nyaruka/gocommon/dbutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand All @@ -29,8 +29,8 @@ func TestScanJSON(t *testing.T) {
Age int `json:"age" validate:"min=0"`
}

queryRows := func(sql string, args ...any) *sqlx.Rows {
rows, err := db.QueryxContext(ctx, sql, args...)
queryRows := func(sql string, args ...any) *sql.Rows {
rows, err := db.QueryContext(ctx, sql, args...)
require.NoError(t, err)
require.True(t, rows.Next())
return rows
Expand Down Expand Up @@ -65,7 +65,7 @@ func TestScanJSON(t *testing.T) {
assert.EqualError(t, err, "error unmarshalling row JSON: json: cannot unmarshal string into Go struct field foo.age of type int")

// error if rows aren't ready to be scanned - e.g. next hasn't been called
rows, err = db.QueryxContext(ctx, `SELECT ROW_TO_JSON(r) FROM (SELECT f.uuid as uuid, f.name AS name FROM foo f WHERE id = 1) r`)
rows, err = db.QueryContext(ctx, `SELECT ROW_TO_JSON(r) FROM (SELECT f.uuid as uuid, f.name AS name FROM foo f WHERE id = 1) r`)
require.NoError(t, err)
err = dbutil.ScanAndValidateJSON(rows, f)
assert.EqualError(t, err, "error scanning row JSON: sql: Scan called without calling Next")
Expand All @@ -85,3 +85,57 @@ func TestScanJSON(t *testing.T) {
assert.Equal(t, "George", f.Name)
assert.Equal(t, -1, f.Age)
}

func TestScanAllSlice(t *testing.T) {
db := getTestDB()

defer func() { db.MustExec(`DROP TABLE foo`) }()

db.MustExec(`CREATE TABLE foo (id serial NOT NULL PRIMARY KEY, name VARCHAR(10))`)
db.MustExec(`INSERT INTO foo (name) VALUES('Ann')`)
db.MustExec(`INSERT INTO foo (name) VALUES('Bob')`)
db.MustExec(`INSERT INTO foo (name) VALUES('Cat')`)

rows, err := db.Query(`SELECT id FROM foo ORDER BY id`)
require.NoError(t, err)

ids := make([]int, 0, 2)
ids, err = dbutil.ScanAllSlice(rows, ids)
require.NoError(t, err)
assert.Equal(t, []int{1, 2, 3}, ids)

rows, err = db.Query(`SELECT name FROM foo ORDER BY id DESC`)
require.NoError(t, err)

names := make([]string, 0, 2)
names, err = dbutil.ScanAllSlice(rows, names)
require.NoError(t, err)
assert.Equal(t, []string{"Cat", "Bob", "Ann"}, names)
}

func TestScanAllMap(t *testing.T) {
db := getTestDB()

defer func() { db.MustExec(`DROP TABLE foo`) }()

db.MustExec(`CREATE TABLE foo (id serial NOT NULL PRIMARY KEY, name VARCHAR(10))`)
db.MustExec(`INSERT INTO foo (name) VALUES('Ann')`)
db.MustExec(`INSERT INTO foo (name) VALUES('Bob')`)
db.MustExec(`INSERT INTO foo (name) VALUES('Cat')`)

rows, err := db.Query(`SELECT id, name FROM foo`)
require.NoError(t, err)

nameByID := make(map[int]string, 2)
err = dbutil.ScanAllMap(rows, nameByID)
require.NoError(t, err)
assert.Equal(t, map[int]string{1: "Ann", 2: "Bob", 3: "Cat"}, nameByID)

rows, err = db.Query(`SELECT name, id FROM foo`)
require.NoError(t, err)

idByName := make(map[string]int, 2)
err = dbutil.ScanAllMap(rows, idByName)
require.NoError(t, err)
assert.Equal(t, map[string]int{"Ann": 1, "Bob": 2, "Cat": 3}, idByName)
}
Loading