From b3a10debcc8b6ca5e32fc766ea8aa40525ef7416 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Fri, 19 Apr 2024 15:40:51 -0700 Subject: [PATCH] Add support for rebinding queries. This commit adds new methods for rebinding queries. This allows to support multiple databases using queries with the bind parameter `?`. --- go.mod | 13 +-- go.sum | 4 +- helpers.go | 2 +- helpers_test.go | 36 ++++++ sequel.go | 125 ++++++++++++++++++++- sequel_test.go | 288 +++++++++++++++++++++++++++++++++++++++++++++++- 6 files changed, 453 insertions(+), 15 deletions(-) create mode 100644 helpers_test.go diff --git a/go.mod b/go.mod index 14d1cea..2d71393 100644 --- a/go.mod +++ b/go.mod @@ -3,9 +3,12 @@ module go.step.sm/sequel go 1.21 require ( + github.com/go-sqlx/sqlx v1.3.8 + github.com/jackc/pgx/v5 v5.5.5 github.com/stretchr/testify v1.9.0 github.com/testcontainers/testcontainers-go v0.30.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.30.0 + go.step.sm/qb v1.4.0 ) require ( @@ -17,6 +20,7 @@ require ( github.com/containerd/containerd v1.7.12 // indirect github.com/containerd/log v0.1.0 // indirect github.com/cpuguy83/dockercfg v0.3.1 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/distribution/reference v0.5.0 // indirect github.com/docker/docker v25.0.5+incompatible // indirect github.com/docker/go-connections v0.5.0 // indirect @@ -42,6 +46,7 @@ require ( github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.0 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect github.com/rogpeppe/go-internal v1.11.0 // indirect github.com/shirou/gopsutil/v3 v3.23.12 // indirect @@ -64,13 +69,5 @@ require ( google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 // indirect google.golang.org/grpc v1.58.3 // indirect google.golang.org/protobuf v1.33.0 // indirect -) - -require ( - github.com/davecgh/go-spew v1.1.1 // indirect - github.com/go-sqlx/sqlx v1.3.8 - github.com/jackc/pgx/v5 v5.5.5 - github.com/pmezard/go-difflib v1.0.0 // indirect - go.step.sm/qb v1.3.0 gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 0b4f608..72971a3 100644 --- a/go.sum +++ b/go.sum @@ -148,8 +148,8 @@ go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU= go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I= go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM= -go.step.sm/qb v1.3.0 h1:+EM326xJQsX9vHPATMTAa2YqIeYlpvUuKPuZj8V/xM4= -go.step.sm/qb v1.3.0/go.mod h1:V+B0IstgCmVyM27km7OImv90GsvtuUhNczxRkIDW4qo= +go.step.sm/qb v1.4.0 h1:U0jCLu7UADtJeZbd19ZFonu+a7mFfoL6KPKGPWRYW+c= +go.step.sm/qb v1.4.0/go.mod h1:V+B0IstgCmVyM27km7OImv90GsvtuUhNczxRkIDW4qo= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= diff --git a/helpers.go b/helpers.go index d057c78..c635cf7 100644 --- a/helpers.go +++ b/helpers.go @@ -65,6 +65,6 @@ func NullString(s string) sql.NullString { // if the zero value is given. func NullTime(t time.Time) sql.NullTime { return sql.NullTime{ - Time: t, Valid: t.IsZero(), + Time: t, Valid: !t.IsZero(), } } diff --git a/helpers_test.go b/helpers_test.go new file mode 100644 index 0000000..4bf2e42 --- /dev/null +++ b/helpers_test.go @@ -0,0 +1,36 @@ +package sequel + +import ( + "database/sql" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestHelpers(t *testing.T) { + assert.Equal(t, sql.NullBool{Bool: true, Valid: true}, NullBool(true)) + assert.Equal(t, sql.NullBool{Bool: false, Valid: false}, NullBool(false)) + + assert.Equal(t, sql.NullByte{Byte: 1, Valid: true}, NullByte(1)) + assert.Equal(t, sql.NullByte{Byte: 0, Valid: false}, NullByte(0)) + + assert.Equal(t, sql.NullFloat64{Float64: 1.1, Valid: true}, NullFloat64(1.1)) + assert.Equal(t, sql.NullFloat64{Float64: 0, Valid: false}, NullFloat64(0)) + + assert.Equal(t, sql.NullInt16{Int16: 1, Valid: true}, NullInt16(1)) + assert.Equal(t, sql.NullInt16{Int16: 0, Valid: false}, NullInt16(0)) + + assert.Equal(t, sql.NullInt32{Int32: 1, Valid: true}, NullInt32(1)) + assert.Equal(t, sql.NullInt32{Int32: 0, Valid: false}, NullInt32(0)) + + assert.Equal(t, sql.NullInt64{Int64: 1, Valid: true}, NullInt64(1)) + assert.Equal(t, sql.NullInt64{Int64: 0, Valid: false}, NullInt64(0)) + + assert.Equal(t, sql.NullString{String: "abc", Valid: true}, NullString("abc")) + assert.Equal(t, sql.NullString{String: "", Valid: false}, NullString("")) + + now := time.Now() + assert.Equal(t, sql.NullTime{Time: now, Valid: true}, NullTime(now)) + assert.Equal(t, sql.NullTime{Time: time.Time{}, Valid: false}, NullTime(time.Time{})) +} diff --git a/sequel.go b/sequel.go index bd91c4a..15d58ad 100644 --- a/sequel.go +++ b/sequel.go @@ -28,7 +28,8 @@ type DB struct { } type options struct { - Clock clock.Clock + Clock clock.Clock + DriverName string } // Option is the type of options that can be used to modify the database. This @@ -42,17 +43,27 @@ func WithClock(c clock.Clock) Option { } } +// WithDriver defines the driver to use, defaults to pgx/v5. This default driver +// is automatically loaded by this package, any other driver must be loaded by +// the user. +func WithDriver(driverName string) Option { + return func(o *options) { + o.DriverName = driverName + } +} + // New creates a new DB. It will fail if it cannot ping it. func New(dataSourceName string, opts ...Option) (*DB, error) { options := &options{ - Clock: clock.New(), + Clock: clock.New(), + DriverName: "pgx/v5", } for _, fn := range opts { fn(options) } // Connect opens the database and verifies with a ping - db, err := sqlx.Connect("pgx/v5", dataSourceName) + db, err := sqlx.Connect(options.DriverName, dataSourceName) if err != nil { return nil, fmt.Errorf("error connecting to the database: %w", err) } @@ -118,6 +129,11 @@ func (d *DB) Close() error { return d.db.Close() } +// Rebind transforms a query from `?` to the DB driver's bind type. +func (d *DB) Rebind(query string) string { + return d.db.Rebind(query) +} + // Query executes a query that returns rows, typically a SELECT. The args are // for any placeholder parameters in the query. func (d *DB) Query(ctx context.Context, query string, args ...any) (*sql.Rows, error) { @@ -141,6 +157,44 @@ func (d *DB) Exec(ctx context.Context, query string, args ...any) (sql.Result, e return d.db.ExecContext(ctx, query, args...) } +// Query executes a query that returns rows, typically a SELECT. The query is +// rebound from `?` to the DB driver's bind type. The args are for any +// placeholder parameters in the query. +func (d *DB) RebindQuery(ctx context.Context, query string, args ...any) (*sql.Rows, error) { + return d.db.QueryContext(ctx, d.db.Rebind(query), args...) +} + +// QueryRow executes a query that is expected to return at most one row. The +// query is rebound from `?` to the DB driver's bind type. QueryRowContext +// always returns a non-nil value. Errors are deferred until Row's Scan method +// is called. +// +// If the query selects no rows, the *Row's Scan will return ErrNoRows. +// Otherwise, the *Row's Scan scans the first selected row and discards the +// rest. +func (d *DB) RebindQueryRow(ctx context.Context, query string, args ...any) *sql.Row { + return d.db.QueryRowContext(ctx, d.db.Rebind(query), args...) +} + +// Exec executes a query without returning any rows. The query is rebound from +// `?` to the DB driver's bind type. The args are for any placeholder parameters +// in the query. +func (d *DB) RebindExec(ctx context.Context, query string, args ...any) (sql.Result, error) { + return d.db.ExecContext(ctx, d.db.Rebind(query), args...) +} + +// NamedQuery executes a query that returns rows. Any named placeholder +// parameters are replaced with fields from arg. +func (d *DB) NamedQuery(ctx context.Context, query string, arg any) (*sqlx.Rows, error) { + return d.db.NamedQueryContext(ctx, query, arg) +} + +// NamedExec using executes a query without returning any rows. Any named +// placeholder parameters are replaced with fields from arg. +func (d *DB) NamedExec(ctx context.Context, query string, arg any) (sql.Result, error) { + return d.db.NamedExecContext(ctx, query, arg) +} + // Get populates the given model for the result of the given select query. func (d *DB) Get(ctx context.Context, dest Model, query string, args ...any) error { return d.db.GetContext(ctx, dest, query, args...) @@ -295,6 +349,11 @@ func (d *DB) Begin(ctx context.Context) (*Tx, error) { }, nil } +// Rebind transforms a query from QUESTION to the DB driver's bind type. +func (t *Tx) Rebind(query string) string { + return t.tx.Rebind(query) +} + // Commit commits the transaction. func (t *Tx) Commit() error { return t.tx.Commit() @@ -305,12 +364,72 @@ func (t *Tx) Rollback() error { return t.tx.Rollback() } +// Query executes a query that returns rows, typically a SELECT. The args are +// for any placeholder parameters in the query. +func (t *Tx) Query(query string, args ...any) (*sql.Rows, error) { + return t.tx.Query(query, args...) +} + +// QueryRow executes a query that is expected to return at most one row. +// QueryRowContext always returns a non-nil value. Errors are deferred until +// Row's Scan method is called. +// +// If the query selects no rows, the *Row's Scan will return ErrNoRows. +// Otherwise, the *Row's Scan scans the first selected row and discards the +// rest. +func (t *Tx) QueryRow(query string, args ...any) *sql.Row { + return t.tx.QueryRow(query, args...) +} + // Exec executes a query without returning any rows. The args are for any // placeholder parameters in the query. func (t *Tx) Exec(query string, args ...any) (sql.Result, error) { return t.tx.Exec(query, args...) } +// Query executes a query that returns rows, typically a SELECT. The query is +// rebound from `?` to the DB driver's bind type. The args are for any +// placeholder parameters in the query. +func (t *Tx) RebindQuery(query string, args ...any) (*sql.Rows, error) { + return t.tx.Query(t.tx.Rebind(query), args...) +} + +// QueryRow executes a query that is expected to return at most one row. The +// query is rebound from `?` to the DB driver's bind type. QueryRowContext +// always returns a non-nil value. Errors are deferred until Row's Scan method +// is called. +// +// If the query selects no rows, the *Row's Scan will return ErrNoRows. +// Otherwise, the *Row's Scan scans the first selected row and discards the +// rest. +func (t *Tx) RebindQueryRow(query string, args ...any) *sql.Row { + return t.tx.QueryRow(t.tx.Rebind(query), args...) +} + +// Exec executes a query without returning any rows. The query is rebound from +// `?` to the DB driver's bind type. The args are for any placeholder parameters +// in the query. +func (t *Tx) RebindExec(query string, args ...any) (sql.Result, error) { + return t.tx.Exec(t.tx.Rebind(query), args...) +} + +// NamedQuery executes a query that returns rows. Any named placeholder +// parameters are replaced with fields from arg. +func (t *Tx) NamedQuery(query string, arg any) (*sqlx.Rows, error) { + return t.tx.NamedQuery(query, arg) +} + +// NamedExec using executes a query without returning any rows. Any named +// placeholder parameters are replaced with fields from arg. +func (t *Tx) NamedExec(query string, arg any) (sql.Result, error) { + return t.tx.NamedExec(query, arg) +} + +// Get populates the given model for the result of the given select query. +func (t *Tx) Get(dest Model, query string, args ...any) error { + return t.tx.Get(dest, query, args...) +} + // Insert adds a new insert query for the given model in the transaction. func (t *Tx) Insert(arg Model) error { var id string diff --git a/sequel_test.go b/sequel_test.go index d58a959..554fc19 100644 --- a/sequel_test.go +++ b/sequel_test.go @@ -9,11 +9,12 @@ import ( "testing" "time" + "github.com/go-sqlx/sqlx" "github.com/jackc/pgx/v5/pgconn" + _ "github.com/jackc/pgx/v5/stdlib" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.step.sm/qb" - "go.step.sm/sequel/clock" ) @@ -66,6 +67,7 @@ func TestNew(t *testing.T) { }{ {"ok", args{postgresDataSource, nil}, assert.NoError}, {"ok with clock", args{postgresDataSource, []Option{WithClock(clock.NewMock(time.Now()))}}, assert.NoError}, + {"ok with driver", args{postgresDataSource, []Option{WithDriver("pgx/v5")}}, assert.NoError}, {"fail ping", args{strings.ReplaceAll(postgresDataSource, dbUser, "foo"), nil}, assert.Error}, } for _, tt := range tests { @@ -212,6 +214,11 @@ func TestDBQueries(t *testing.T) { ctx := context.Background() + t.Run("rebind", func(t *testing.T) { + query := db.Rebind("SELECT * FROM person_test WHERE name = ? AND email = ?") + assert.Equal(t, "SELECT * FROM person_test WHERE name = $1 AND email = $2", query) + }) + t.Run("insert", func(t *testing.T) { assert.NoError(t, db.Insert(ctx, p1)) assert.NoError(t, db.InsertBatch(ctx, []Model{p2, p3, p4})) @@ -254,11 +261,57 @@ func TestDBQueries(t *testing.T) { t.Run("queryRow", func(t *testing.T) { var p personModel row := db.QueryRow(ctx, "SELECT * FROM person_test WHERE id = $1", p1.GetID()) + assert.NoError(t, row.Err()) + assert.NoError(t, row.Scan(&p.ID, &p.CreatedAt, &p.UpdatedAt, &p.DeletedAt, &p.Name, &p.Email)) + equalPerson(t, p1, &p) + }) + + t.Run("rebindQuery", func(t *testing.T) { + rows, err := db.RebindQuery(ctx, "SELECT * FROM person_test WHERE id = ?", p1.GetID()) assert.NoError(t, err) + for rows.Next() { + var p personModel + assert.NoError(t, rows.Scan(&p.ID, &p.CreatedAt, &p.UpdatedAt, &p.DeletedAt, &p.Name, &p.Email)) + equalPerson(t, p1, &p) + } + assert.NoError(t, rows.Err()) + assert.NoError(t, rows.Close()) //nolint:sqlclosecheck // no defer for testing purposes + }) + + t.Run("rebindQueryRow", func(t *testing.T) { + var p personModel + row := db.RebindQueryRow(ctx, "SELECT * FROM person_test WHERE id = ?", p1.GetID()) + assert.NoError(t, row.Err()) assert.NoError(t, row.Scan(&p.ID, &p.CreatedAt, &p.UpdatedAt, &p.DeletedAt, &p.Name, &p.Email)) equalPerson(t, p1, &p) }) + t.Run("namedQuery", func(t *testing.T) { + rows, err := db.NamedQuery(ctx, "SELECT * FROM person_test WHERE id = :id", p1) + assert.NoError(t, err) + for rows.Next() { + var p personModel + assert.NoError(t, rows.Scan(&p.ID, &p.CreatedAt, &p.UpdatedAt, &p.DeletedAt, &p.Name, &p.Email)) + equalPerson(t, p1, &p) + } + assert.NoError(t, rows.Err()) + assert.NoError(t, rows.Close()) //nolint:sqlclosecheck // no defer for testing purposes + }) + + t.Run("namedQuery withMap", func(t *testing.T) { + rows, err := db.NamedQuery(ctx, "SELECT * FROM person_test WHERE id = :id", map[string]any{ + "id": p1.GetID(), + }) + assert.NoError(t, err) + for rows.Next() { + var p personModel + assert.NoError(t, rows.Scan(&p.ID, &p.CreatedAt, &p.UpdatedAt, &p.DeletedAt, &p.Name, &p.Email)) + equalPerson(t, p1, &p) + } + assert.NoError(t, rows.Err()) + assert.NoError(t, rows.Close()) //nolint:sqlclosecheck // no defer for testing purposes + }) + t.Run("get", func(t *testing.T) { var pp1, pp2 personModel assert.NoError(t, db.Get(ctx, &pp1, "SELECT * FROM person_test WHERE id = $1", p1.GetID())) @@ -306,14 +359,73 @@ func TestDBQueries(t *testing.T) { assert.Error(t, db.Select(ctx, &pp, p5.GetID())) }) + t.Run("rebindExec", func(t *testing.T) { + var pp personModel + p1.DeletedAt = sql.NullTime{ + Valid: true, + Time: time.Now().UTC().Truncate(time.Second), + } + res, err := db.RebindExec(ctx, "UPDATE person_test SET deleted_at = ? WHERE id = ?", p1.DeletedAt, p1.ID) + assert.NoError(t, err) + assert.NoError(t, RowsAffected(res, 1)) + assert.NoError(t, db.Get(ctx, &pp, "SELECT * FROM person_test WHERE id = $1", p1.GetID())) + equalPerson(t, p1, &pp) + }) + + t.Run("namedExec", func(t *testing.T) { + var pp personModel + p1.DeletedAt = sql.NullTime{ + Valid: true, + Time: time.Now().UTC().Truncate(time.Second), + } + res, err := db.NamedExec(ctx, "UPDATE person_test SET deleted_at = :deleted_at WHERE id = :id", p1) + assert.NoError(t, err) + assert.NoError(t, RowsAffected(res, 1)) + assert.NoError(t, db.Get(ctx, &pp, "SELECT * FROM person_test WHERE id = $1", p1.GetID())) + equalPerson(t, p1, &pp) + }) + + t.Run("namedExec map", func(t *testing.T) { + var pp personModel + p1.DeletedAt = sql.NullTime{ + Valid: true, + Time: time.Now().UTC().Truncate(time.Second), + } + res, err := db.NamedExec(ctx, "UPDATE person_test SET deleted_at = :deleted_at WHERE id = :id", map[string]any{ + "deleted_at": p1.DeletedAt.Time, + "id": p1.ID, + }) + assert.NoError(t, err) + assert.NoError(t, RowsAffected(res, 1)) + assert.NoError(t, db.Get(ctx, &pp, "SELECT * FROM person_test WHERE id = $1", p1.GetID())) + equalPerson(t, p1, &pp) + }) + t.Run("exec (clear table)", func(t *testing.T) { res, err := db.Exec(ctx, "DELETE FROM person_test") assert.NoError(t, err) assert.NoError(t, RowsAffected(res, 4)) // p1 to p4, p5 is hard deleted }) + } func TestTxQueries(t *testing.T) { + equalPerson := func(t *testing.T, want, got *personModel) bool { + t.Helper() + if got != nil { + got.CreatedAt = got.CreatedAt.UTC().Truncate(time.Second) + got.UpdatedAt = got.UpdatedAt.UTC().Truncate(time.Second) + if got.DeletedAt.Valid { + got.DeletedAt = NullTime(got.DeletedAt.Time.UTC().Truncate(time.Second)) + } + } + want.CreatedAt = want.CreatedAt.Truncate(time.Second) + want.UpdatedAt = want.UpdatedAt.Truncate(time.Second) + if want.DeletedAt.Valid { + want.DeletedAt = NullTime(want.DeletedAt.Time.Truncate(time.Second)) + } + return assert.Equal(t, want, got) + } db, err := New(postgresDataSource) require.NoError(t, err) t.Cleanup(func() { @@ -342,6 +454,14 @@ func TestTxQueries(t *testing.T) { }, } + t.Run("rebind", func(t *testing.T) { + tx, err := db.Begin(ctx) + require.NoError(t, err) + query := tx.Rebind("SELECT * FROM person_test WHERE name = ? AND email = ?") + assert.Equal(t, "SELECT * FROM person_test WHERE name = $1 AND email = $2", query) + assert.NoError(t, tx.Rollback()) + }) + t.Run("insert", func(t *testing.T) { tx, err := db.Begin(ctx) require.NoError(t, err) @@ -361,6 +481,83 @@ func TestTxQueries(t *testing.T) { assert.NoError(t, tx.Rollback()) }) + t.Run("query", func(t *testing.T) { + tx, err := db.Begin(ctx) + require.NoError(t, err) + rows, err := tx.Query("SELECT * FROM person_test WHERE id = $1", p1.GetID()) + assert.NoError(t, err) + for rows.Next() { + var p personModel + assert.NoError(t, rows.Scan(&p.ID, &p.CreatedAt, &p.UpdatedAt, &p.DeletedAt, &p.Name, &p.Email)) + equalPerson(t, p1, &p) + } + assert.NoError(t, rows.Err()) + assert.NoError(t, rows.Close()) //nolint:sqlclosecheck // no defer for testing purposes + assert.NoError(t, tx.Commit()) + }) + + t.Run("queryRow", func(t *testing.T) { + var p personModel + tx, err := db.Begin(ctx) + require.NoError(t, err) + row := tx.QueryRow("SELECT * FROM person_test WHERE id = $1", p1.GetID()) + assert.NoError(t, row.Err()) + assert.NoError(t, row.Scan(&p.ID, &p.CreatedAt, &p.UpdatedAt, &p.DeletedAt, &p.Name, &p.Email)) + equalPerson(t, p1, &p) + assert.NoError(t, tx.Commit()) + }) + + t.Run("rebindQuery", func(t *testing.T) { + tx, err := db.Begin(ctx) + require.NoError(t, err) + rows, err := tx.RebindQuery("SELECT * FROM person_test WHERE id = ?", p1.GetID()) + assert.NoError(t, err) + for rows.Next() { + var p personModel + assert.NoError(t, rows.Scan(&p.ID, &p.CreatedAt, &p.UpdatedAt, &p.DeletedAt, &p.Name, &p.Email)) + equalPerson(t, p1, &p) + } + assert.NoError(t, rows.Err()) + assert.NoError(t, rows.Close()) //nolint:sqlclosecheck // no defer for testing purposes + assert.NoError(t, tx.Commit()) + }) + + t.Run("rebindQueryRow", func(t *testing.T) { + var p personModel + tx, err := db.Begin(ctx) + require.NoError(t, err) + row := tx.RebindQueryRow("SELECT * FROM person_test WHERE id = ?", p1.GetID()) + assert.NoError(t, row.Err()) + assert.NoError(t, row.Scan(&p.ID, &p.CreatedAt, &p.UpdatedAt, &p.DeletedAt, &p.Name, &p.Email)) + equalPerson(t, p1, &p) + assert.NoError(t, tx.Commit()) + }) + + t.Run("namedQuery", func(t *testing.T) { + tx, err := db.Begin(ctx) + require.NoError(t, err) + rows, err := tx.NamedQuery("SELECT * FROM person_test WHERE id = :id", p1) + assert.NoError(t, err) + for rows.Next() { + var p personModel + assert.NoError(t, rows.Scan(&p.ID, &p.CreatedAt, &p.UpdatedAt, &p.DeletedAt, &p.Name, &p.Email)) + equalPerson(t, p1, &p) + } + assert.NoError(t, rows.Err()) + assert.NoError(t, rows.Close()) //nolint:sqlclosecheck // no defer for testing purposes + assert.NoError(t, tx.Commit()) + }) + + t.Run("get", func(t *testing.T) { + var p personModel + tx, err := db.Begin(ctx) + require.NoError(t, err) + err = tx.Get(&p, "SELECT * FROM person_test WHERE id = $1", p1.GetID()) + assert.NoError(t, err) + equalPerson(t, p1, &p) + assert.NoError(t, tx.Commit()) + }) + t.Run("update", func(t *testing.T) { tx, err := db.Begin(ctx) require.NoError(t, err) @@ -411,6 +608,68 @@ func TestTxQueries(t *testing.T) { assert.NoError(t, tx.Rollback()) }) + t.Run("rebindExec", func(t *testing.T) { + var p personModel + tx, err := db.Begin(ctx) + require.NoError(t, err) + defer func() { + assert.Error(t, tx.Rollback()) + }() + + p1.DeletedAt = sql.NullTime{ + Time: time.Now().UTC().Truncate(time.Second), + Valid: true, + } + + res, err := tx.RebindExec("UPDATE person_test SET deleted_at = ? WHERE id = ?", p1.DeletedAt, p1.ID) + assert.NoError(t, err) + n, err := res.RowsAffected() + assert.NoError(t, err) + assert.Equal(t, int64(1), n) + // In transaction + row := tx.RebindQueryRow("SELECT * FROM person_test WHERE id = ?", p1.GetID()) + assert.NoError(t, row.Err()) + assert.NoError(t, row.Scan(&p.ID, &p.CreatedAt, &p.UpdatedAt, &p.DeletedAt, &p.Name, &p.Email)) + equalPerson(t, p1, &p) + assert.NoError(t, tx.Commit()) + // After commit + row = db.RebindQueryRow(ctx, "SELECT * FROM person_test WHERE id = ?", p1.GetID()) + assert.NoError(t, row.Err()) + assert.NoError(t, row.Scan(&p.ID, &p.CreatedAt, &p.UpdatedAt, &p.DeletedAt, &p.Name, &p.Email)) + equalPerson(t, p1, &p) + }) + + t.Run("namedExec", func(t *testing.T) { + var p personModel + tx, err := db.Begin(ctx) + require.NoError(t, err) + defer func() { + assert.Error(t, tx.Rollback()) + }() + + p1.DeletedAt = sql.NullTime{ + Time: time.Now().UTC().Truncate(time.Second), + Valid: true, + } + + res, err := tx.NamedExec("UPDATE person_test SET deleted_at = :deleted_at WHERE id = :id", p1) + assert.NoError(t, err) + n, err := res.RowsAffected() + assert.NoError(t, err) + assert.Equal(t, int64(1), n) + // In transaction + row := tx.QueryRow("SELECT * FROM person_test WHERE id = $1", p1.GetID()) + assert.NoError(t, row.Err()) + assert.NoError(t, row.Scan(&p.ID, &p.CreatedAt, &p.UpdatedAt, &p.DeletedAt, &p.Name, &p.Email)) + equalPerson(t, p1, &p) + assert.NoError(t, tx.Commit()) + // After commit + row = db.QueryRow(ctx, "SELECT * FROM person_test WHERE id = $1", p1.GetID()) + assert.NoError(t, row.Err()) + assert.NoError(t, row.Scan(&p.ID, &p.CreatedAt, &p.UpdatedAt, &p.DeletedAt, &p.Name, &p.Email)) + equalPerson(t, p1, &p) + }) + t.Run("exec", func(t *testing.T) { tx, err := db.Begin(ctx) require.NoError(t, err) @@ -426,3 +685,30 @@ func TestTxQueries(t *testing.T) { assert.NoError(t, tx.Commit()) }) } + +func TestDB_Rebind(t *testing.T) { + type fields struct { + db *sqlx.DB + clock clock.Clock + } + type args struct { + query string + } + tests := []struct { + name string + fields fields + args args + want string + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := &DB{ + db: tt.fields.db, + clock: tt.fields.clock, + } + assert.Equal(t, tt.want, d.Rebind(tt.args.query)) + }) + } +}