Skip to content

Commit

Permalink
Ensure txdb.conn is not closed on tx context cancelation (#10935)
Browse files Browse the repository at this point in the history
* Ensure txdb.conn is not closed on tx context cancelation

txdb.conn implements driver.Validation & driver.SessionResetter to prevent database/sql from closing connection if context of transaction was cancelled

* Fix race of c.closed
  • Loading branch information
dhaidashenko authored Oct 13, 2023
1 parent eab0984 commit 9cc576f
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 3 deletions.
31 changes: 28 additions & 3 deletions core/internal/testutils/pgtest/txdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ func init() {

var _ driver.Conn = &conn{}

var _ driver.Validator = &conn{}
var _ driver.SessionResetter = &conn{}

// txDriver is an sql driver which runs on single transaction
// when the Close is called, transaction is rolled back
type txDriver struct {
Expand Down Expand Up @@ -98,7 +101,7 @@ func (d *txDriver) Open(dsn string) (driver.Conn, error) {
if err != nil {
return nil, err
}
c = &conn{tx: tx, opened: 1}
c = &conn{tx: tx, opened: 1, dsn: dsn}
c.removeSelf = func() error {
return d.deleteConn(c)
}
Expand Down Expand Up @@ -147,8 +150,9 @@ func (c *conn) Begin() (driver.Tx, error) {
}

// Implement the "ConnBeginTx" interface
func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
// TODO: Fix context handling
func (c *conn) BeginTx(_ context.Context, opts driver.TxOptions) (driver.Tx, error) {
// Context is ignored, because single transaction is shared by all callers, thus caller should not be able to
// control it with local context
return c.Begin()
}

Expand Down Expand Up @@ -176,6 +180,27 @@ func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
return &stmt{st, c}, nil
}

// IsValid is called prior to placing the connection into the
// connection pool by database/sql. The connection will be discarded if false is returned.
func (c *conn) IsValid() bool {
c.Lock()
defer c.Unlock()
return !c.closed
}

func (c *conn) ResetSession(ctx context.Context) error {
// Ensure bad connections are reported: From database/sql/driver:
// If a connection is never returned to the connection pool but immediately reused, then
// ResetSession is called prior to reuse but IsValid is not called.
c.Lock()
defer c.Unlock()
if c.closed {
return driver.ErrBadConn
}

return nil
}

// pgx returns nil
func (c *conn) CheckNamedValue(nv *driver.NamedValue) error {
return nil
Expand Down
48 changes: 48 additions & 0 deletions core/internal/testutils/pgtest/txdb_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package pgtest

import (
"context"
"testing"
"time"

"github.com/google/uuid"
"github.com/smartcontractkit/sqlx"
"github.com/stretchr/testify/assert"
)

func TestTxDBDriver(t *testing.T) {
db := NewSqlxDB(t)
dropTable := func() error {
_, err := db.Exec(`DROP TABLE IF EXISTS txdb_test`)
return err
}
// clean up, if previous tests failed
err := dropTable()
assert.NoError(t, err)
_, err = db.Exec(`CREATE TABLE txdb_test (id TEXT NOT NULL)`)
assert.NoError(t, err)
t.Cleanup(func() {
_ = dropTable()
})
_, err = db.Exec(`INSERT INTO txdb_test VALUES ($1)`, uuid.New().String())
assert.NoError(t, err)
ensureValuesPresent := func(t *testing.T, db *sqlx.DB) {
var ids []string
err = db.Select(&ids, `SELECT id from txdb_test`)
assert.NoError(t, err)
assert.Len(t, ids, 1)
}

ensureValuesPresent(t, db)
t.Run("Cancel of tx's context does not trigger rollback of driver's tx", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
_, err := db.BeginTx(ctx, nil)
assert.NoError(t, err)
cancel()
// BeginTx spawns separate goroutine that rollbacks the tx and tries to close underlying connection, unless
// db driver says that connection is still active.
// This approach is not ideal, but there is no better way to wait for independent goroutine to complete
time.Sleep(time.Second * 10)
ensureValuesPresent(t, db)
})
}

0 comments on commit 9cc576f

Please sign in to comment.