Skip to content

Commit

Permalink
Add support for rebinding queries.
Browse files Browse the repository at this point in the history
This commit adds new methods for rebinding queries. This allows to
support multiple databases using queries with the bind parameter `?`.
  • Loading branch information
maraino committed Apr 19, 2024
1 parent ced317a commit b3a10de
Show file tree
Hide file tree
Showing 6 changed files with 453 additions and 15 deletions.
13 changes: 5 additions & 8 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ module go.step.sm/sequel
go 1.21

require (
github.com/go-sqlx/sqlx v1.3.8
github.com/jackc/pgx/v5 v5.5.5
github.com/stretchr/testify v1.9.0
github.com/testcontainers/testcontainers-go v0.30.0
github.com/testcontainers/testcontainers-go/modules/postgres v0.30.0
go.step.sm/qb v1.4.0
)

require (
Expand All @@ -17,6 +20,7 @@ require (
github.com/containerd/containerd v1.7.12 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/cpuguy83/dockercfg v0.3.1 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/distribution/reference v0.5.0 // indirect
github.com/docker/docker v25.0.5+incompatible // indirect
github.com/docker/go-connections v0.5.0 // indirect
Expand All @@ -42,6 +46,7 @@ require (
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/rogpeppe/go-internal v1.11.0 // indirect
github.com/shirou/gopsutil/v3 v3.23.12 // indirect
Expand All @@ -64,13 +69,5 @@ require (
google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 // indirect
google.golang.org/grpc v1.58.3 // indirect
google.golang.org/protobuf v1.33.0 // indirect
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-sqlx/sqlx v1.3.8
github.com/jackc/pgx/v5 v5.5.5
github.com/pmezard/go-difflib v1.0.0 // indirect
go.step.sm/qb v1.3.0
gopkg.in/yaml.v3 v3.0.1 // indirect
)
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y
go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU=
go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I=
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
go.step.sm/qb v1.3.0 h1:+EM326xJQsX9vHPATMTAa2YqIeYlpvUuKPuZj8V/xM4=
go.step.sm/qb v1.3.0/go.mod h1:V+B0IstgCmVyM27km7OImv90GsvtuUhNczxRkIDW4qo=
go.step.sm/qb v1.4.0 h1:U0jCLu7UADtJeZbd19ZFonu+a7mFfoL6KPKGPWRYW+c=
go.step.sm/qb v1.4.0/go.mod h1:V+B0IstgCmVyM27km7OImv90GsvtuUhNczxRkIDW4qo=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
Expand Down
2 changes: 1 addition & 1 deletion helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,6 @@ func NullString(s string) sql.NullString {
// if the zero value is given.
func NullTime(t time.Time) sql.NullTime {
return sql.NullTime{
Time: t, Valid: t.IsZero(),
Time: t, Valid: !t.IsZero(),
}
}
36 changes: 36 additions & 0 deletions helpers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package sequel

import (
"database/sql"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestHelpers(t *testing.T) {
assert.Equal(t, sql.NullBool{Bool: true, Valid: true}, NullBool(true))
assert.Equal(t, sql.NullBool{Bool: false, Valid: false}, NullBool(false))

assert.Equal(t, sql.NullByte{Byte: 1, Valid: true}, NullByte(1))
assert.Equal(t, sql.NullByte{Byte: 0, Valid: false}, NullByte(0))

assert.Equal(t, sql.NullFloat64{Float64: 1.1, Valid: true}, NullFloat64(1.1))
assert.Equal(t, sql.NullFloat64{Float64: 0, Valid: false}, NullFloat64(0))

assert.Equal(t, sql.NullInt16{Int16: 1, Valid: true}, NullInt16(1))
assert.Equal(t, sql.NullInt16{Int16: 0, Valid: false}, NullInt16(0))

assert.Equal(t, sql.NullInt32{Int32: 1, Valid: true}, NullInt32(1))
assert.Equal(t, sql.NullInt32{Int32: 0, Valid: false}, NullInt32(0))

assert.Equal(t, sql.NullInt64{Int64: 1, Valid: true}, NullInt64(1))
assert.Equal(t, sql.NullInt64{Int64: 0, Valid: false}, NullInt64(0))

assert.Equal(t, sql.NullString{String: "abc", Valid: true}, NullString("abc"))
assert.Equal(t, sql.NullString{String: "", Valid: false}, NullString(""))

now := time.Now()
assert.Equal(t, sql.NullTime{Time: now, Valid: true}, NullTime(now))
assert.Equal(t, sql.NullTime{Time: time.Time{}, Valid: false}, NullTime(time.Time{}))
}
125 changes: 122 additions & 3 deletions sequel.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ type DB struct {
}

type options struct {
Clock clock.Clock
Clock clock.Clock
DriverName string
}

// Option is the type of options that can be used to modify the database. This
Expand All @@ -42,17 +43,27 @@ func WithClock(c clock.Clock) Option {
}
}

// WithDriver defines the driver to use, defaults to pgx/v5. This default driver
// is automatically loaded by this package, any other driver must be loaded by
// the user.
func WithDriver(driverName string) Option {
return func(o *options) {
o.DriverName = driverName
}
}

// New creates a new DB. It will fail if it cannot ping it.
func New(dataSourceName string, opts ...Option) (*DB, error) {
options := &options{
Clock: clock.New(),
Clock: clock.New(),
DriverName: "pgx/v5",
}
for _, fn := range opts {
fn(options)
}

// Connect opens the database and verifies with a ping
db, err := sqlx.Connect("pgx/v5", dataSourceName)
db, err := sqlx.Connect(options.DriverName, dataSourceName)
if err != nil {
return nil, fmt.Errorf("error connecting to the database: %w", err)
}
Expand Down Expand Up @@ -118,6 +129,11 @@ func (d *DB) Close() error {
return d.db.Close()
}

// Rebind transforms a query from `?` to the DB driver's bind type.
func (d *DB) Rebind(query string) string {
return d.db.Rebind(query)
}

// Query executes a query that returns rows, typically a SELECT. The args are
// for any placeholder parameters in the query.
func (d *DB) Query(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
Expand All @@ -141,6 +157,44 @@ func (d *DB) Exec(ctx context.Context, query string, args ...any) (sql.Result, e
return d.db.ExecContext(ctx, query, args...)
}

// Query executes a query that returns rows, typically a SELECT. The query is
// rebound from `?` to the DB driver's bind type. The args are for any
// placeholder parameters in the query.
func (d *DB) RebindQuery(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
return d.db.QueryContext(ctx, d.db.Rebind(query), args...)
}

// QueryRow executes a query that is expected to return at most one row. The
// query is rebound from `?` to the DB driver's bind type. QueryRowContext
// always returns a non-nil value. Errors are deferred until Row's Scan method
// is called.
//
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
// Otherwise, the *Row's Scan scans the first selected row and discards the
// rest.
func (d *DB) RebindQueryRow(ctx context.Context, query string, args ...any) *sql.Row {
return d.db.QueryRowContext(ctx, d.db.Rebind(query), args...)
}

// Exec executes a query without returning any rows. The query is rebound from
// `?` to the DB driver's bind type. The args are for any placeholder parameters
// in the query.
func (d *DB) RebindExec(ctx context.Context, query string, args ...any) (sql.Result, error) {
return d.db.ExecContext(ctx, d.db.Rebind(query), args...)
}

// NamedQuery executes a query that returns rows. Any named placeholder
// parameters are replaced with fields from arg.
func (d *DB) NamedQuery(ctx context.Context, query string, arg any) (*sqlx.Rows, error) {
return d.db.NamedQueryContext(ctx, query, arg)
}

// NamedExec using executes a query without returning any rows. Any named
// placeholder parameters are replaced with fields from arg.
func (d *DB) NamedExec(ctx context.Context, query string, arg any) (sql.Result, error) {
return d.db.NamedExecContext(ctx, query, arg)
}

// Get populates the given model for the result of the given select query.
func (d *DB) Get(ctx context.Context, dest Model, query string, args ...any) error {
return d.db.GetContext(ctx, dest, query, args...)
Expand Down Expand Up @@ -295,6 +349,11 @@ func (d *DB) Begin(ctx context.Context) (*Tx, error) {
}, nil
}

// Rebind transforms a query from QUESTION to the DB driver's bind type.
func (t *Tx) Rebind(query string) string {
return t.tx.Rebind(query)
}

// Commit commits the transaction.
func (t *Tx) Commit() error {
return t.tx.Commit()
Expand All @@ -305,12 +364,72 @@ func (t *Tx) Rollback() error {
return t.tx.Rollback()
}

// Query executes a query that returns rows, typically a SELECT. The args are
// for any placeholder parameters in the query.
func (t *Tx) Query(query string, args ...any) (*sql.Rows, error) {
return t.tx.Query(query, args...)
}

// QueryRow executes a query that is expected to return at most one row.
// QueryRowContext always returns a non-nil value. Errors are deferred until
// Row's Scan method is called.
//
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
// Otherwise, the *Row's Scan scans the first selected row and discards the
// rest.
func (t *Tx) QueryRow(query string, args ...any) *sql.Row {
return t.tx.QueryRow(query, args...)
}

// Exec executes a query without returning any rows. The args are for any
// placeholder parameters in the query.
func (t *Tx) Exec(query string, args ...any) (sql.Result, error) {
return t.tx.Exec(query, args...)
}

// Query executes a query that returns rows, typically a SELECT. The query is
// rebound from `?` to the DB driver's bind type. The args are for any
// placeholder parameters in the query.
func (t *Tx) RebindQuery(query string, args ...any) (*sql.Rows, error) {
return t.tx.Query(t.tx.Rebind(query), args...)
}

// QueryRow executes a query that is expected to return at most one row. The
// query is rebound from `?` to the DB driver's bind type. QueryRowContext
// always returns a non-nil value. Errors are deferred until Row's Scan method
// is called.
//
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
// Otherwise, the *Row's Scan scans the first selected row and discards the
// rest.
func (t *Tx) RebindQueryRow(query string, args ...any) *sql.Row {
return t.tx.QueryRow(t.tx.Rebind(query), args...)
}

// Exec executes a query without returning any rows. The query is rebound from
// `?` to the DB driver's bind type. The args are for any placeholder parameters
// in the query.
func (t *Tx) RebindExec(query string, args ...any) (sql.Result, error) {
return t.tx.Exec(t.tx.Rebind(query), args...)
}

// NamedQuery executes a query that returns rows. Any named placeholder
// parameters are replaced with fields from arg.
func (t *Tx) NamedQuery(query string, arg any) (*sqlx.Rows, error) {
return t.tx.NamedQuery(query, arg)
}

// NamedExec using executes a query without returning any rows. Any named
// placeholder parameters are replaced with fields from arg.
func (t *Tx) NamedExec(query string, arg any) (sql.Result, error) {
return t.tx.NamedExec(query, arg)
}

// Get populates the given model for the result of the given select query.
func (t *Tx) Get(dest Model, query string, args ...any) error {
return t.tx.Get(dest, query, args...)
}

// Insert adds a new insert query for the given model in the transaction.
func (t *Tx) Insert(arg Model) error {
var id string
Expand Down
Loading

0 comments on commit b3a10de

Please sign in to comment.