From 8fd8394c4d4695cc723a7a3d30f5c466cde30000 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 23 Dec 2023 22:55:30 +0200 Subject: [PATCH] Add query helper to dbutil --- dbutil/queryhelper.go | 113 ++++++++++++++++++++++++++++++++++++++++++ go.mod | 1 + go.sum | 2 + 3 files changed, 116 insertions(+) create mode 100644 dbutil/queryhelper.go diff --git a/dbutil/queryhelper.go b/dbutil/queryhelper.go new file mode 100644 index 0000000..10de3cd --- /dev/null +++ b/dbutil/queryhelper.go @@ -0,0 +1,113 @@ +// Copyright (c) 2023 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package dbutil + +import ( + "context" + "database/sql" + "errors" + + "golang.org/x/exp/constraints" +) + +// DataStruct is an interface for structs that represent a single database row. +type DataStruct[T any] interface { + Scan(row Scannable) (T, error) +} + +// QueryHelper is a generic helper struct for SQL query execution boilerplate. +// +// After implementing the Scan and Init methods in a data struct, the query +// helper allows writing query functions in a single line. +type QueryHelper[T DataStruct[T]] struct { + db *Database + newFunc func(qh *QueryHelper[T]) T +} + +func MakeQueryHelper[T DataStruct[T]](db *Database, new func(qh *QueryHelper[T]) T) *QueryHelper[T] { + return &QueryHelper[T]{db: db, newFunc: new} +} + +// ValueOrErr is a helper function that returns the value if err is nil, or +// returns nil and the error if err is not nil. It can be used to avoid +// `if err != nil { return nil, err }` boilerplate in certain cases like +// DataStruct.Scan implementations. +func ValueOrErr[T any](val *T, err error) (*T, error) { + if err != nil { + return nil, err + } + return val, nil +} + +// StrPtr returns a pointer to the given string, or nil if the string is empty. +func StrPtr[T ~string, T2 *T](val T) *string { + if val == "" { + return nil + } + strVal := string(val) + return &strVal +} + +// NumPtr returns a pointer to the given number, or nil if the number is zero. +func NumPtr[T constraints.Integer | constraints.Float](val T) *T { + if val == 0 { + return nil + } + return &val +} + +func (qh *QueryHelper[T]) GetDB() *Database { + return qh.db +} + +func (qh *QueryHelper[T]) New() T { + return qh.newFunc(qh) +} + +// Exec executes a query with ExecContext and returns the error. +// +// It omits the sql.Result return value, as it is rarely used. When the result +// is wanted, use `qh.GetDB().Conn(ctx).ExecContext(...)` instead, which is +// otherwise equivalent. +func (qh *QueryHelper[T]) Exec(ctx context.Context, query string, args ...any) error { + _, err := qh.db.Conn(ctx).ExecContext(ctx, query, args...) + return err +} + +// QueryOne executes a query with QueryRowContext, uses the associated DataStruct +// to scan it, and returns the value. If the query returns no rows, it returns nil +// and no error. +func (qh *QueryHelper[T]) QueryOne(ctx context.Context, query string, args ...any) (val T, err error) { + val, err = qh.New().Scan(qh.db.Conn(ctx).QueryRowContext(ctx, query, args...)) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + return val, err +} + +// QueryMany executes a query with QueryContext, uses the associated DataStruct +// to scan each row, and returns the values. If the query returns no rows, it +// returns a non-nil zero-length slice and no error. +func (qh *QueryHelper[T]) QueryMany(ctx context.Context, query string, args ...any) ([]T, error) { + rows, err := qh.db.Conn(ctx).QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer func() { + _ = rows.Close() + }() + items := make([]T, 0) + var item T + for rows.Next() { + item, err = qh.New().Scan(rows) + if err != nil { + return nil, err + } + items = append(items, item) + } + return items, rows.Err() +} diff --git a/go.mod b/go.mod index 1715991..141a40b 100644 --- a/go.mod +++ b/go.mod @@ -15,5 +15,6 @@ require ( github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.19 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/exp v0.0.0-20231219180239-dc181d75b848 // indirect golang.org/x/sys v0.12.0 // indirect ) diff --git a/go.sum b/go.sum index 0163640..e852add 100644 --- a/go.sum +++ b/go.sum @@ -19,6 +19,8 @@ github.com/rs/zerolog v1.31.0 h1:FcTR3NnLWW+NnTwwhFWiJSZr4ECLpqCm6QsEnyvbV4A= github.com/rs/zerolog v1.31.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +golang.org/x/exp v0.0.0-20231219180239-dc181d75b848 h1:+iq7lrkxmFNBM7xx+Rae2W6uyPfhPeDWD+n+JgppptE= +golang.org/x/exp v0.0.0-20231219180239-dc181d75b848/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=