diff --git a/dbutil/bulk.go b/dbutil/bulk.go index 1b0e448..52e0c9c 100644 --- a/dbutil/bulk.go +++ b/dbutil/bulk.go @@ -71,20 +71,23 @@ func BulkQuery[T any](ctx context.Context, db BulkQueryer, query string, structs rows, err := db.QueryxContext(ctx, bulkQuery, args...) if err != nil { - return NewQueryErrorf(err, bulkQuery, args, "error making bulk query") + return QueryErrorWrapf(err, bulkQuery, args, "error making bulk query") } defer rows.Close() // if have a returning clause, read them back and try to map them if strings.Contains(strings.ToUpper(query), "RETURNING") { - for _, s := range structs { + for i, s := range structs { if !rows.Next() { - return errors.Errorf("did not receive expected number of rows on insert") + if rows.Err() != nil { + return QueryErrorWrapf(rows.Err(), bulkQuery, args, "missing returned row for struct %d", i) + } + return QueryErrorf(bulkQuery, args, "missing returned row for struct %d", i) } err = rows.StructScan(s) if err != nil { - return errors.Wrap(err, "error scanning for returned values") + return QueryErrorWrapf(err, bulkQuery, args, "error scanning returned row %d", i) } } } @@ -94,11 +97,7 @@ func BulkQuery[T any](ctx context.Context, db BulkQueryer, query string, structs } // check for any error - if rows.Err() != nil { - return NewQueryErrorf(rows.Err(), bulkQuery, args, "error during row iteration") - } - - return nil + return QueryErrorWrapf(rows.Err(), bulkQuery, args, "error during row iteration") } // extractValues is just a simple utility method that extracts the portion between `VALUE(` diff --git a/dbutil/errors.go b/dbutil/errors.go index cb56c66..5524ce7 100644 --- a/dbutil/errors.go +++ b/dbutil/errors.go @@ -25,7 +25,10 @@ type QueryError struct { } func (e *QueryError) Error() string { - return e.message + ": " + e.cause.Error() + if e.cause != nil { + return e.message + ": " + e.cause.Error() + } + return e.message } func (e *QueryError) Unwrap() error { @@ -36,7 +39,18 @@ func (e *QueryError) Query() (string, []any) { return e.sql, e.sqlArgs } -func NewQueryErrorf(cause error, sql string, sqlArgs []any, message string, msgArgs ...any) error { +func QueryErrorWrapf(cause error, sql string, sqlArgs []any, message string, msgArgs ...any) error { + if cause == nil { + return nil + } + return newQueryErrorf(cause, sql, sqlArgs, message, msgArgs...) +} + +func QueryErrorf(sql string, sqlArgs []any, message string, msgArgs ...any) error { + return newQueryErrorf(nil, sql, sqlArgs, message, msgArgs...) +} + +func newQueryErrorf(cause error, sql string, sqlArgs []any, message string, msgArgs ...any) error { return &QueryError{ cause: cause, message: fmt.Sprintf(message, msgArgs...), diff --git a/dbutil/errors_test.go b/dbutil/errors_test.go index 505bdb0..f96d3ab 100644 --- a/dbutil/errors_test.go +++ b/dbutil/errors_test.go @@ -1,6 +1,7 @@ package dbutil_test import ( + "fmt" "testing" "github.com/lib/pq" @@ -19,11 +20,18 @@ func TestIsUniqueViolation(t *testing.T) { } func TestQueryError(t *testing.T) { + qerr := dbutil.QueryErrorf("SELECT * FROM foo WHERE id = $1", []any{234}, "error selecting foo %d", 234) + assert.Error(t, qerr) + assert.Equal(t, `error selecting foo 234`, qerr.Error()) + assert.Equal(t, `error selecting foo 234`, fmt.Sprintf("%s", qerr)) + + // can also wrap an existing error var err error = &pq.Error{Code: pq.ErrorCode("22025"), Message: "unsupported Unicode escape sequence"} - qerr := dbutil.NewQueryErrorf(err, "SELECT * FROM foo WHERE id = $1", []any{234}, "error selecting foo %d", 234) + qerr = dbutil.QueryErrorWrapf(err, "SELECT * FROM foo WHERE id = $1", []any{234}, "error selecting foo %d", 234) assert.Error(t, qerr) assert.Equal(t, `error selecting foo 234: pq: unsupported Unicode escape sequence`, qerr.Error()) + assert.Equal(t, `error selecting foo 234: pq: unsupported Unicode escape sequence`, fmt.Sprintf("%s", qerr)) // can unwrap to the original error var pqerr *pq.Error @@ -42,4 +50,7 @@ func TestQueryError(t *testing.T) { query, params := unwrapped.Query() assert.Equal(t, "SELECT * FROM foo WHERE id = $1", query) assert.Equal(t, []any{234}, params) + + // wrapping a nil error returns nil + assert.Nil(t, dbutil.QueryErrorWrapf(nil, "SELECT", nil, "ooh")) }