Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for rebinding queries. #25

Merged
merged 1 commit into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ module go.step.sm/sequel
go 1.21

require (
github.com/go-sqlx/sqlx v1.3.8
github.com/jackc/pgx/v5 v5.5.5
github.com/stretchr/testify v1.9.0
github.com/testcontainers/testcontainers-go v0.30.0
github.com/testcontainers/testcontainers-go/modules/postgres v0.30.0
go.step.sm/qb v1.4.0
)

require (
Expand All @@ -17,6 +20,7 @@ require (
github.com/containerd/containerd v1.7.12 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/cpuguy83/dockercfg v0.3.1 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/distribution/reference v0.5.0 // indirect
github.com/docker/docker v25.0.5+incompatible // indirect
github.com/docker/go-connections v0.5.0 // indirect
Expand All @@ -42,6 +46,7 @@ require (
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/rogpeppe/go-internal v1.11.0 // indirect
github.com/shirou/gopsutil/v3 v3.23.12 // indirect
Expand All @@ -64,13 +69,5 @@ require (
google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 // indirect
google.golang.org/grpc v1.58.3 // indirect
google.golang.org/protobuf v1.33.0 // indirect
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-sqlx/sqlx v1.3.8
github.com/jackc/pgx/v5 v5.5.5
github.com/pmezard/go-difflib v1.0.0 // indirect
go.step.sm/qb v1.3.0
gopkg.in/yaml.v3 v3.0.1 // indirect
)
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y
go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU=
go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I=
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
go.step.sm/qb v1.3.0 h1:+EM326xJQsX9vHPATMTAa2YqIeYlpvUuKPuZj8V/xM4=
go.step.sm/qb v1.3.0/go.mod h1:V+B0IstgCmVyM27km7OImv90GsvtuUhNczxRkIDW4qo=
go.step.sm/qb v1.4.0 h1:U0jCLu7UADtJeZbd19ZFonu+a7mFfoL6KPKGPWRYW+c=
go.step.sm/qb v1.4.0/go.mod h1:V+B0IstgCmVyM27km7OImv90GsvtuUhNczxRkIDW4qo=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
Expand Down
2 changes: 1 addition & 1 deletion helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,6 @@ func NullString(s string) sql.NullString {
// if the zero value is given.
func NullTime(t time.Time) sql.NullTime {
return sql.NullTime{
Time: t, Valid: t.IsZero(),
Time: t, Valid: !t.IsZero(),
}
}
36 changes: 36 additions & 0 deletions helpers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package sequel

import (
"database/sql"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestHelpers(t *testing.T) {
assert.Equal(t, sql.NullBool{Bool: true, Valid: true}, NullBool(true))
assert.Equal(t, sql.NullBool{Bool: false, Valid: false}, NullBool(false))

assert.Equal(t, sql.NullByte{Byte: 1, Valid: true}, NullByte(1))
assert.Equal(t, sql.NullByte{Byte: 0, Valid: false}, NullByte(0))

assert.Equal(t, sql.NullFloat64{Float64: 1.1, Valid: true}, NullFloat64(1.1))
assert.Equal(t, sql.NullFloat64{Float64: 0, Valid: false}, NullFloat64(0))

assert.Equal(t, sql.NullInt16{Int16: 1, Valid: true}, NullInt16(1))
assert.Equal(t, sql.NullInt16{Int16: 0, Valid: false}, NullInt16(0))

assert.Equal(t, sql.NullInt32{Int32: 1, Valid: true}, NullInt32(1))
assert.Equal(t, sql.NullInt32{Int32: 0, Valid: false}, NullInt32(0))

assert.Equal(t, sql.NullInt64{Int64: 1, Valid: true}, NullInt64(1))
assert.Equal(t, sql.NullInt64{Int64: 0, Valid: false}, NullInt64(0))

assert.Equal(t, sql.NullString{String: "abc", Valid: true}, NullString("abc"))
assert.Equal(t, sql.NullString{String: "", Valid: false}, NullString(""))

now := time.Now()
assert.Equal(t, sql.NullTime{Time: now, Valid: true}, NullTime(now))
assert.Equal(t, sql.NullTime{Time: time.Time{}, Valid: false}, NullTime(time.Time{}))
}
125 changes: 122 additions & 3 deletions sequel.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ type DB struct {
}

type options struct {
Clock clock.Clock
Clock clock.Clock
DriverName string
}

// Option is the type of options that can be used to modify the database. This
Expand All @@ -42,17 +43,27 @@ func WithClock(c clock.Clock) Option {
}
}

// WithDriver defines the driver to use, defaults to pgx/v5. This default driver
// is automatically loaded by this package, any other driver must be loaded by
// the user.
func WithDriver(driverName string) Option {
return func(o *options) {
o.DriverName = driverName
}
}

// New creates a new DB. It will fail if it cannot ping it.
func New(dataSourceName string, opts ...Option) (*DB, error) {
options := &options{
Clock: clock.New(),
Clock: clock.New(),
DriverName: "pgx/v5",
}
for _, fn := range opts {
fn(options)
}

// Connect opens the database and verifies with a ping
db, err := sqlx.Connect("pgx/v5", dataSourceName)
db, err := sqlx.Connect(options.DriverName, dataSourceName)
if err != nil {
return nil, fmt.Errorf("error connecting to the database: %w", err)
}
Expand Down Expand Up @@ -118,6 +129,11 @@ func (d *DB) Close() error {
return d.db.Close()
}

// Rebind transforms a query from `?` to the DB driver's bind type.
func (d *DB) Rebind(query string) string {
return d.db.Rebind(query)
}

// Query executes a query that returns rows, typically a SELECT. The args are
// for any placeholder parameters in the query.
func (d *DB) Query(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
Expand All @@ -141,6 +157,44 @@ func (d *DB) Exec(ctx context.Context, query string, args ...any) (sql.Result, e
return d.db.ExecContext(ctx, query, args...)
}

// Query executes a query that returns rows, typically a SELECT. The query is
// rebound from `?` to the DB driver's bind type. The args are for any
// placeholder parameters in the query.
func (d *DB) RebindQuery(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
return d.db.QueryContext(ctx, d.db.Rebind(query), args...)
}

// QueryRow executes a query that is expected to return at most one row. The
// query is rebound from `?` to the DB driver's bind type. QueryRowContext
// always returns a non-nil value. Errors are deferred until Row's Scan method
// is called.
//
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
// Otherwise, the *Row's Scan scans the first selected row and discards the
// rest.
func (d *DB) RebindQueryRow(ctx context.Context, query string, args ...any) *sql.Row {
return d.db.QueryRowContext(ctx, d.db.Rebind(query), args...)
}

// Exec executes a query without returning any rows. The query is rebound from
// `?` to the DB driver's bind type. The args are for any placeholder parameters
// in the query.
func (d *DB) RebindExec(ctx context.Context, query string, args ...any) (sql.Result, error) {
return d.db.ExecContext(ctx, d.db.Rebind(query), args...)
}

// NamedQuery executes a query that returns rows. Any named placeholder
// parameters are replaced with fields from arg.
func (d *DB) NamedQuery(ctx context.Context, query string, arg any) (*sqlx.Rows, error) {
return d.db.NamedQueryContext(ctx, query, arg)
}

// NamedExec using executes a query without returning any rows. Any named
// placeholder parameters are replaced with fields from arg.
func (d *DB) NamedExec(ctx context.Context, query string, arg any) (sql.Result, error) {
return d.db.NamedExecContext(ctx, query, arg)
}

// Get populates the given model for the result of the given select query.
func (d *DB) Get(ctx context.Context, dest Model, query string, args ...any) error {
return d.db.GetContext(ctx, dest, query, args...)
Expand Down Expand Up @@ -295,6 +349,11 @@ func (d *DB) Begin(ctx context.Context) (*Tx, error) {
}, nil
}

// Rebind transforms a query from QUESTION to the DB driver's bind type.
func (t *Tx) Rebind(query string) string {
return t.tx.Rebind(query)
}

// Commit commits the transaction.
func (t *Tx) Commit() error {
return t.tx.Commit()
Expand All @@ -305,12 +364,72 @@ func (t *Tx) Rollback() error {
return t.tx.Rollback()
}

// Query executes a query that returns rows, typically a SELECT. The args are
// for any placeholder parameters in the query.
func (t *Tx) Query(query string, args ...any) (*sql.Rows, error) {
return t.tx.Query(query, args...)
}

// QueryRow executes a query that is expected to return at most one row.
// QueryRowContext always returns a non-nil value. Errors are deferred until
// Row's Scan method is called.
//
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
// Otherwise, the *Row's Scan scans the first selected row and discards the
// rest.
func (t *Tx) QueryRow(query string, args ...any) *sql.Row {
return t.tx.QueryRow(query, args...)
}

// Exec executes a query without returning any rows. The args are for any
// placeholder parameters in the query.
func (t *Tx) Exec(query string, args ...any) (sql.Result, error) {
return t.tx.Exec(query, args...)
}

// Query executes a query that returns rows, typically a SELECT. The query is
// rebound from `?` to the DB driver's bind type. The args are for any
// placeholder parameters in the query.
func (t *Tx) RebindQuery(query string, args ...any) (*sql.Rows, error) {
return t.tx.Query(t.tx.Rebind(query), args...)
}

// QueryRow executes a query that is expected to return at most one row. The
// query is rebound from `?` to the DB driver's bind type. QueryRowContext
// always returns a non-nil value. Errors are deferred until Row's Scan method
// is called.
//
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
// Otherwise, the *Row's Scan scans the first selected row and discards the
// rest.
func (t *Tx) RebindQueryRow(query string, args ...any) *sql.Row {
return t.tx.QueryRow(t.tx.Rebind(query), args...)
}

// Exec executes a query without returning any rows. The query is rebound from
// `?` to the DB driver's bind type. The args are for any placeholder parameters
// in the query.
func (t *Tx) RebindExec(query string, args ...any) (sql.Result, error) {
return t.tx.Exec(t.tx.Rebind(query), args...)
}

// NamedQuery executes a query that returns rows. Any named placeholder
// parameters are replaced with fields from arg.
func (t *Tx) NamedQuery(query string, arg any) (*sqlx.Rows, error) {
return t.tx.NamedQuery(query, arg)
}

// NamedExec using executes a query without returning any rows. Any named
// placeholder parameters are replaced with fields from arg.
func (t *Tx) NamedExec(query string, arg any) (sql.Result, error) {
return t.tx.NamedExec(query, arg)
}

// Get populates the given model for the result of the given select query.
func (t *Tx) Get(dest Model, query string, args ...any) error {
return t.tx.Get(dest, query, args...)
}

// Insert adds a new insert query for the given model in the transaction.
func (t *Tx) Insert(arg Model) error {
var id string
Expand Down
Loading
Loading