From f0f224c9afb082cdb82da9fb1851f732b748d6a4 Mon Sep 17 00:00:00 2001 From: Rowan Seymour Date: Tue, 19 Sep 2023 13:35:53 -0500 Subject: [PATCH] Add dbutil.ScanAllJSON --- dbutil/scan.go | 17 ++++++++++++++++- dbutil/scan_test.go | 25 ++++++++++++++++++++++--- 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/dbutil/scan.go b/dbutil/scan.go index a66fcfb..e83b326 100644 --- a/dbutil/scan.go +++ b/dbutil/scan.go @@ -40,7 +40,22 @@ func ScanAndValidateJSON(rows *sql.Rows, destination any) error { return nil } -// ScanAllSlice scans a single value from each single column row into the given slice +// ScanAllJSON scans all rows as a single column containing JSON that be unmarshalled into instances of V. +func ScanAllJSON[V any](rows *sql.Rows, s []V) ([]V, error) { + defer rows.Close() + + var v V + + for rows.Next() { + if err := ScanJSON(rows, &v); err != nil { + return nil, err + } + s = append(s, v) + } + return s, rows.Err() +} + +// ScanAllSlice scans all rows as a single value and returns them in the given slice. func ScanAllSlice[V any](rows *sql.Rows, s []V) ([]V, error) { defer rows.Close() diff --git a/dbutil/scan_test.go b/dbutil/scan_test.go index f35c6c5..7907a31 100644 --- a/dbutil/scan_test.go +++ b/dbutil/scan_test.go @@ -32,12 +32,12 @@ func TestScanJSON(t *testing.T) { queryRows := func(sql string, args ...any) *sql.Rows { rows, err := db.QueryContext(ctx, sql, args...) require.NoError(t, err) - require.True(t, rows.Next()) return rows } // if query returns valid JSON which can be unmarshaled into our struct, all good rows := queryRows(`SELECT ROW_TO_JSON(r) FROM (SELECT f.uuid, f.name, f.age FROM foo f WHERE id = 1) r`) + require.True(t, rows.Next()) f := &foo{} err := dbutil.ScanAndValidateJSON(rows, f) @@ -47,6 +47,7 @@ func TestScanJSON(t *testing.T) { assert.Equal(t, 40, f.Age) rows = queryRows(`SELECT ROW_TO_JSON(r) FROM (SELECT f.uuid, f.name, f.age FROM foo f WHERE id = 2) r`) + require.True(t, rows.Next()) err = dbutil.ScanAndValidateJSON(rows, f) assert.NoError(t, err) @@ -56,34 +57,52 @@ func TestScanJSON(t *testing.T) { // error if row value is not JSON rows = queryRows(`SELECT id FROM foo f WHERE id = 1`) + require.True(t, rows.Next()) + err = dbutil.ScanAndValidateJSON(rows, f) assert.EqualError(t, err, `error scanning row JSON: sql: Scan error on column index 0, name "id": unsupported Scan, storing driver.Value type int64 into type *json.RawMessage`) // error if we can't marshal into the struct rows = queryRows(`SELECT ROW_TO_JSON(r) FROM (SELECT f.uuid as uuid, f.name AS age FROM foo f WHERE id = 1) r`) + require.True(t, rows.Next()) + err = dbutil.ScanAndValidateJSON(rows, f) assert.EqualError(t, err, "error unmarshalling row JSON: json: cannot unmarshal string into Go struct field foo.age of type int") // error if rows aren't ready to be scanned - e.g. next hasn't been called - rows, err = db.QueryContext(ctx, `SELECT ROW_TO_JSON(r) FROM (SELECT f.uuid as uuid, f.name AS name FROM foo f WHERE id = 1) r`) - require.NoError(t, err) + rows = queryRows(`SELECT ROW_TO_JSON(r) FROM (SELECT f.uuid as uuid, f.name AS name FROM foo f WHERE id = 1) r`) + err = dbutil.ScanAndValidateJSON(rows, f) assert.EqualError(t, err, "error scanning row JSON: sql: Scan called without calling Next") // error if we request validation and returned JSON is invalid rows = queryRows(`SELECT ROW_TO_JSON(r) FROM (SELECT f.uuid, f.name, f.age FROM foo f WHERE id = 3) r`) + require.True(t, rows.Next()) err = dbutil.ScanAndValidateJSON(rows, f) assert.EqualError(t, err, "error validating unmarsalled JSON: Key: 'foo.Age' Error:Field validation for 'Age' failed on the 'min' tag") + rows.Close() + // no error if we don't do validation rows = queryRows(`SELECT ROW_TO_JSON(r) FROM (SELECT f.uuid, f.name, f.age FROM foo f WHERE id = 3) r`) + require.True(t, rows.Next()) err = dbutil.ScanJSON(rows, f) assert.NoError(t, err) assert.Equal(t, "a5850c89-dd29-46f6-9de1-d068b3c2db94", f.UUID) assert.Equal(t, "George", f.Name) assert.Equal(t, -1, f.Age) + + rows.Close() + + // can all scan all rows with ScanAllJSON + rows = queryRows(`SELECT ROW_TO_JSON(r) FROM (SELECT f.uuid, f.name, f.age FROM foo f) r`) + + var foos []*foo + foos, err = dbutil.ScanAllJSON(rows, foos) + assert.NoError(t, err) + assert.Len(t, foos, 3) } func TestScanAllSlice(t *testing.T) {