From f332d228e7b4d8ef90a188f43acc93d93a0042e0 Mon Sep 17 00:00:00 2001 From: Rowan Seymour Date: Mon, 28 Aug 2023 13:28:47 -0500 Subject: [PATCH 1/2] Add dbutil.SliceScan and MapScan --- dbutil/scan.go | 34 ++++++++++++++++++++++++ dbutil/scan_test.go | 63 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+) create mode 100644 dbutil/scan.go create mode 100644 dbutil/scan_test.go diff --git a/dbutil/scan.go b/dbutil/scan.go new file mode 100644 index 0000000..f33884a --- /dev/null +++ b/dbutil/scan.go @@ -0,0 +1,34 @@ +package dbutil + +import "database/sql" + +// SliceScan scans a single value from each row into the given slice +func SliceScan[V any](rows *sql.Rows, s []V) ([]V, error) { + defer rows.Close() + + var v V + + for rows.Next() { + if err := rows.Scan(&v); err != nil { + return nil, err + } + s = append(s, v) + } + return s, rows.Err() +} + +// MapScan scans a key and value from each row into the given map +func MapScan[K comparable, V any](rows *sql.Rows, m map[K]V) error { + defer rows.Close() + + var k K + var v V + + for rows.Next() { + if err := rows.Scan(&k, &v); err != nil { + return err + } + m[k] = v + } + return rows.Err() +} diff --git a/dbutil/scan_test.go b/dbutil/scan_test.go new file mode 100644 index 0000000..c778767 --- /dev/null +++ b/dbutil/scan_test.go @@ -0,0 +1,63 @@ +package dbutil_test + +import ( + "testing" + + "github.com/nyaruka/gocommon/dbutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSliceScan(t *testing.T) { + db := getTestDB() + + defer func() { db.MustExec(`DROP TABLE foo`) }() + + db.MustExec(`CREATE TABLE foo (id serial NOT NULL PRIMARY KEY, name VARCHAR(10))`) + db.MustExec(`INSERT INTO foo (name) VALUES('Ann')`) + db.MustExec(`INSERT INTO foo (name) VALUES('Bob')`) + db.MustExec(`INSERT INTO foo (name) VALUES('Cat')`) + + rows, err := db.Query(`SELECT id FROM foo ORDER BY id`) + require.NoError(t, err) + + ids := make([]int, 0, 2) + ids, err = dbutil.SliceScan(rows, ids) + require.NoError(t, err) + assert.Equal(t, []int{1, 2, 3}, ids) + + rows, err = db.Query(`SELECT name FROM foo ORDER BY id DESC`) + require.NoError(t, err) + + names := make([]string, 0, 2) + names, err = dbutil.SliceScan(rows, names) + require.NoError(t, err) + assert.Equal(t, []string{"Cat", "Bob", "Ann"}, names) +} + +func TestMapScan(t *testing.T) { + db := getTestDB() + + defer func() { db.MustExec(`DROP TABLE foo`) }() + + db.MustExec(`CREATE TABLE foo (id serial NOT NULL PRIMARY KEY, name VARCHAR(10))`) + db.MustExec(`INSERT INTO foo (name) VALUES('Ann')`) + db.MustExec(`INSERT INTO foo (name) VALUES('Bob')`) + db.MustExec(`INSERT INTO foo (name) VALUES('Cat')`) + + rows, err := db.Query(`SELECT id, name FROM foo`) + require.NoError(t, err) + + nameByID := make(map[int]string, 2) + err = dbutil.MapScan(rows, nameByID) + require.NoError(t, err) + assert.Equal(t, map[int]string{1: "Ann", 2: "Bob", 3: "Cat"}, nameByID) + + rows, err = db.Query(`SELECT name, id FROM foo`) + require.NoError(t, err) + + idByName := make(map[string]int, 2) + err = dbutil.MapScan(rows, idByName) + require.NoError(t, err) + assert.Equal(t, map[string]int{"Ann": 1, "Bob": 2, "Cat": 3}, idByName) +} From b5adbb4721d7a5060e75aeaaa5395d89b4311804 Mon Sep 17 00:00:00 2001 From: Rowan Seymour Date: Mon, 28 Aug 2023 14:38:06 -0500 Subject: [PATCH 2/2] Merge dbutil/scan with dbutil/json and rework json functions to use sql instead of sqlx --- dbutil/json.go | 41 --------------------- dbutil/json_test.go | 87 ------------------------------------------- dbutil/scan.go | 48 +++++++++++++++++++++--- dbutil/scan_test.go | 90 ++++++++++++++++++++++++++++++++++++++++++--- 4 files changed, 127 insertions(+), 139 deletions(-) delete mode 100644 dbutil/json.go delete mode 100644 dbutil/json_test.go diff --git a/dbutil/json.go b/dbutil/json.go deleted file mode 100644 index b657368..0000000 --- a/dbutil/json.go +++ /dev/null @@ -1,41 +0,0 @@ -package dbutil - -import ( - "encoding/json" - - "github.com/go-playground/validator/v10" - "github.com/jmoiron/sqlx" - "github.com/pkg/errors" -) - -var validate = validator.New() - -// ScanJSON scans a row which is JSON into a destination struct -func ScanJSON(rows *sqlx.Rows, destination any) error { - var raw json.RawMessage - err := rows.Scan(&raw) - if err != nil { - return errors.Wrap(err, "error scanning row JSON") - } - - err = json.Unmarshal(raw, destination) - if err != nil { - return errors.Wrap(err, "error unmarshalling row JSON") - } - - return nil -} - -// ScanAndValidateJSON scans a row which is JSON into a destination struct and validates it -func ScanAndValidateJSON(rows *sqlx.Rows, destination any) error { - if err := ScanJSON(rows, destination); err != nil { - return err - } - - err := validate.Struct(destination) - if err != nil { - return errors.Wrapf(err, "error validating unmarsalled JSON") - } - - return nil -} diff --git a/dbutil/json_test.go b/dbutil/json_test.go deleted file mode 100644 index a9af471..0000000 --- a/dbutil/json_test.go +++ /dev/null @@ -1,87 +0,0 @@ -package dbutil_test - -import ( - "context" - "testing" - - "github.com/jmoiron/sqlx" - "github.com/nyaruka/gocommon/dbutil" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestScanJSON(t *testing.T) { - ctx := context.Background() - db := getTestDB() - - defer func() { - db.MustExec(`DROP TABLE foo`) - }() - - db.MustExec(`CREATE TABLE foo (id serial NOT NULL PRIMARY KEY, uuid UUID NOT NULL, name VARCHAR(10), age INT)`) - db.MustExec(`INSERT INTO foo (uuid, name, age) VALUES('11163af6-a2ee-486d-b6dc-984174f10eec', 'Bob', 40)`) - db.MustExec(`INSERT INTO foo (uuid, name, age) VALUES('57d3f887-9ae1-4292-8fa4-ffc11e31e2f7', 'Cathy', 30)`) - db.MustExec(`INSERT INTO foo (uuid, name, age) VALUES('a5850c89-dd29-46f6-9de1-d068b3c2db94', 'George', -1)`) - - type foo struct { - UUID string `json:"uuid" validate:"required"` - Name string `json:"name"` - Age int `json:"age" validate:"min=0"` - } - - queryRows := func(sql string, args ...any) *sqlx.Rows { - rows, err := db.QueryxContext(ctx, sql, args...) - require.NoError(t, err) - require.True(t, rows.Next()) - return rows - } - - // if query returns valid JSON which can be unmarshaled into our struct, all good - rows := queryRows(`SELECT ROW_TO_JSON(r) FROM (SELECT f.uuid, f.name, f.age FROM foo f WHERE id = 1) r`) - - f := &foo{} - err := dbutil.ScanAndValidateJSON(rows, f) - assert.NoError(t, err) - assert.Equal(t, "11163af6-a2ee-486d-b6dc-984174f10eec", f.UUID) - assert.Equal(t, "Bob", f.Name) - assert.Equal(t, 40, f.Age) - - rows = queryRows(`SELECT ROW_TO_JSON(r) FROM (SELECT f.uuid, f.name, f.age FROM foo f WHERE id = 2) r`) - - err = dbutil.ScanAndValidateJSON(rows, f) - assert.NoError(t, err) - assert.Equal(t, "57d3f887-9ae1-4292-8fa4-ffc11e31e2f7", f.UUID) - assert.Equal(t, "Cathy", f.Name) - assert.Equal(t, 30, f.Age) - - // error if row value is not JSON - rows = queryRows(`SELECT id FROM foo f WHERE id = 1`) - err = dbutil.ScanAndValidateJSON(rows, f) - assert.EqualError(t, err, `error scanning row JSON: sql: Scan error on column index 0, name "id": unsupported Scan, storing driver.Value type int64 into type *json.RawMessage`) - - // error if we can't marshal into the struct - rows = queryRows(`SELECT ROW_TO_JSON(r) FROM (SELECT f.uuid as uuid, f.name AS age FROM foo f WHERE id = 1) r`) - err = dbutil.ScanAndValidateJSON(rows, f) - assert.EqualError(t, err, "error unmarshalling row JSON: json: cannot unmarshal string into Go struct field foo.age of type int") - - // error if rows aren't ready to be scanned - e.g. next hasn't been called - rows, err = db.QueryxContext(ctx, `SELECT ROW_TO_JSON(r) FROM (SELECT f.uuid as uuid, f.name AS name FROM foo f WHERE id = 1) r`) - require.NoError(t, err) - err = dbutil.ScanAndValidateJSON(rows, f) - assert.EqualError(t, err, "error scanning row JSON: sql: Scan called without calling Next") - - // error if we request validation and returned JSON is invalid - rows = queryRows(`SELECT ROW_TO_JSON(r) FROM (SELECT f.uuid, f.name, f.age FROM foo f WHERE id = 3) r`) - - err = dbutil.ScanAndValidateJSON(rows, f) - assert.EqualError(t, err, "error validating unmarsalled JSON: Key: 'foo.Age' Error:Field validation for 'Age' failed on the 'min' tag") - - // no error if we don't do validation - rows = queryRows(`SELECT ROW_TO_JSON(r) FROM (SELECT f.uuid, f.name, f.age FROM foo f WHERE id = 3) r`) - - err = dbutil.ScanJSON(rows, f) - assert.NoError(t, err) - assert.Equal(t, "a5850c89-dd29-46f6-9de1-d068b3c2db94", f.UUID) - assert.Equal(t, "George", f.Name) - assert.Equal(t, -1, f.Age) -} diff --git a/dbutil/scan.go b/dbutil/scan.go index f33884a..a66fcfb 100644 --- a/dbutil/scan.go +++ b/dbutil/scan.go @@ -1,9 +1,47 @@ package dbutil -import "database/sql" +import ( + "database/sql" + "encoding/json" -// SliceScan scans a single value from each row into the given slice -func SliceScan[V any](rows *sql.Rows, s []V) ([]V, error) { + "github.com/go-playground/validator/v10" + "github.com/pkg/errors" +) + +var validate = validator.New() + +// ScanJSON scans a row which is JSON into a destination struct +func ScanJSON(rows *sql.Rows, destination any) error { + var raw json.RawMessage + err := rows.Scan(&raw) + if err != nil { + return errors.Wrap(err, "error scanning row JSON") + } + + err = json.Unmarshal(raw, destination) + if err != nil { + return errors.Wrap(err, "error unmarshalling row JSON") + } + + return nil +} + +// ScanAndValidateJSON scans a row which is JSON into a destination struct and validates it +func ScanAndValidateJSON(rows *sql.Rows, destination any) error { + if err := ScanJSON(rows, destination); err != nil { + return err + } + + err := validate.Struct(destination) + if err != nil { + return errors.Wrapf(err, "error validating unmarsalled JSON") + } + + return nil +} + +// ScanAllSlice scans a single value from each single column row into the given slice +func ScanAllSlice[V any](rows *sql.Rows, s []V) ([]V, error) { defer rows.Close() var v V @@ -17,8 +55,8 @@ func SliceScan[V any](rows *sql.Rows, s []V) ([]V, error) { return s, rows.Err() } -// MapScan scans a key and value from each row into the given map -func MapScan[K comparable, V any](rows *sql.Rows, m map[K]V) error { +// ScanAllMap scans a key and value from each two column row into the given map +func ScanAllMap[K comparable, V any](rows *sql.Rows, m map[K]V) error { defer rows.Close() var k K diff --git a/dbutil/scan_test.go b/dbutil/scan_test.go index c778767..f35c6c5 100644 --- a/dbutil/scan_test.go +++ b/dbutil/scan_test.go @@ -1,6 +1,8 @@ package dbutil_test import ( + "context" + "database/sql" "testing" "github.com/nyaruka/gocommon/dbutil" @@ -8,7 +10,83 @@ import ( "github.com/stretchr/testify/require" ) -func TestSliceScan(t *testing.T) { +func TestScanJSON(t *testing.T) { + ctx := context.Background() + db := getTestDB() + + defer func() { + db.MustExec(`DROP TABLE foo`) + }() + + db.MustExec(`CREATE TABLE foo (id serial NOT NULL PRIMARY KEY, uuid UUID NOT NULL, name VARCHAR(10), age INT)`) + db.MustExec(`INSERT INTO foo (uuid, name, age) VALUES('11163af6-a2ee-486d-b6dc-984174f10eec', 'Bob', 40)`) + db.MustExec(`INSERT INTO foo (uuid, name, age) VALUES('57d3f887-9ae1-4292-8fa4-ffc11e31e2f7', 'Cathy', 30)`) + db.MustExec(`INSERT INTO foo (uuid, name, age) VALUES('a5850c89-dd29-46f6-9de1-d068b3c2db94', 'George', -1)`) + + type foo struct { + UUID string `json:"uuid" validate:"required"` + Name string `json:"name"` + Age int `json:"age" validate:"min=0"` + } + + queryRows := func(sql string, args ...any) *sql.Rows { + rows, err := db.QueryContext(ctx, sql, args...) + require.NoError(t, err) + require.True(t, rows.Next()) + return rows + } + + // if query returns valid JSON which can be unmarshaled into our struct, all good + rows := queryRows(`SELECT ROW_TO_JSON(r) FROM (SELECT f.uuid, f.name, f.age FROM foo f WHERE id = 1) r`) + + f := &foo{} + err := dbutil.ScanAndValidateJSON(rows, f) + assert.NoError(t, err) + assert.Equal(t, "11163af6-a2ee-486d-b6dc-984174f10eec", f.UUID) + assert.Equal(t, "Bob", f.Name) + assert.Equal(t, 40, f.Age) + + rows = queryRows(`SELECT ROW_TO_JSON(r) FROM (SELECT f.uuid, f.name, f.age FROM foo f WHERE id = 2) r`) + + err = dbutil.ScanAndValidateJSON(rows, f) + assert.NoError(t, err) + assert.Equal(t, "57d3f887-9ae1-4292-8fa4-ffc11e31e2f7", f.UUID) + assert.Equal(t, "Cathy", f.Name) + assert.Equal(t, 30, f.Age) + + // error if row value is not JSON + rows = queryRows(`SELECT id FROM foo f WHERE id = 1`) + err = dbutil.ScanAndValidateJSON(rows, f) + assert.EqualError(t, err, `error scanning row JSON: sql: Scan error on column index 0, name "id": unsupported Scan, storing driver.Value type int64 into type *json.RawMessage`) + + // error if we can't marshal into the struct + rows = queryRows(`SELECT ROW_TO_JSON(r) FROM (SELECT f.uuid as uuid, f.name AS age FROM foo f WHERE id = 1) r`) + err = dbutil.ScanAndValidateJSON(rows, f) + assert.EqualError(t, err, "error unmarshalling row JSON: json: cannot unmarshal string into Go struct field foo.age of type int") + + // error if rows aren't ready to be scanned - e.g. next hasn't been called + rows, err = db.QueryContext(ctx, `SELECT ROW_TO_JSON(r) FROM (SELECT f.uuid as uuid, f.name AS name FROM foo f WHERE id = 1) r`) + require.NoError(t, err) + err = dbutil.ScanAndValidateJSON(rows, f) + assert.EqualError(t, err, "error scanning row JSON: sql: Scan called without calling Next") + + // error if we request validation and returned JSON is invalid + rows = queryRows(`SELECT ROW_TO_JSON(r) FROM (SELECT f.uuid, f.name, f.age FROM foo f WHERE id = 3) r`) + + err = dbutil.ScanAndValidateJSON(rows, f) + assert.EqualError(t, err, "error validating unmarsalled JSON: Key: 'foo.Age' Error:Field validation for 'Age' failed on the 'min' tag") + + // no error if we don't do validation + rows = queryRows(`SELECT ROW_TO_JSON(r) FROM (SELECT f.uuid, f.name, f.age FROM foo f WHERE id = 3) r`) + + err = dbutil.ScanJSON(rows, f) + assert.NoError(t, err) + assert.Equal(t, "a5850c89-dd29-46f6-9de1-d068b3c2db94", f.UUID) + assert.Equal(t, "George", f.Name) + assert.Equal(t, -1, f.Age) +} + +func TestScanAllSlice(t *testing.T) { db := getTestDB() defer func() { db.MustExec(`DROP TABLE foo`) }() @@ -22,7 +100,7 @@ func TestSliceScan(t *testing.T) { require.NoError(t, err) ids := make([]int, 0, 2) - ids, err = dbutil.SliceScan(rows, ids) + ids, err = dbutil.ScanAllSlice(rows, ids) require.NoError(t, err) assert.Equal(t, []int{1, 2, 3}, ids) @@ -30,12 +108,12 @@ func TestSliceScan(t *testing.T) { require.NoError(t, err) names := make([]string, 0, 2) - names, err = dbutil.SliceScan(rows, names) + names, err = dbutil.ScanAllSlice(rows, names) require.NoError(t, err) assert.Equal(t, []string{"Cat", "Bob", "Ann"}, names) } -func TestMapScan(t *testing.T) { +func TestScanAllMap(t *testing.T) { db := getTestDB() defer func() { db.MustExec(`DROP TABLE foo`) }() @@ -49,7 +127,7 @@ func TestMapScan(t *testing.T) { require.NoError(t, err) nameByID := make(map[int]string, 2) - err = dbutil.MapScan(rows, nameByID) + err = dbutil.ScanAllMap(rows, nameByID) require.NoError(t, err) assert.Equal(t, map[int]string{1: "Ann", 2: "Bob", 3: "Cat"}, nameByID) @@ -57,7 +135,7 @@ func TestMapScan(t *testing.T) { require.NoError(t, err) idByName := make(map[string]int, 2) - err = dbutil.MapScan(rows, idByName) + err = dbutil.ScanAllMap(rows, idByName) require.NoError(t, err) assert.Equal(t, map[string]int{"Ann": 1, "Bob": 2, "Cat": 3}, idByName) }