Skip to content

Commit

Permalink
Prevent accidentally mixing transactions of different databases
Browse files Browse the repository at this point in the history
  • Loading branch information
tulir committed Jun 2, 2024
1 parent a1d4796 commit 603e3d7
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 6 deletions.
6 changes: 6 additions & 0 deletions dbutil/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ type Database struct {
Dialect Dialect
UpgradeTable UpgradeTable

txnCtxKey contextKey

IgnoreForeignTables bool
IgnoreUnsupportedDatabase bool
}
Expand Down Expand Up @@ -132,6 +134,8 @@ func (db *Database) Child(versionTable string, upgradeTable UpgradeTable, log Da
Log: log,
Dialect: db.Dialect,

txnCtxKey: db.txnCtxKey,

IgnoreForeignTables: true,
IgnoreUnsupportedDatabase: db.IgnoreUnsupportedDatabase,
}
Expand All @@ -149,6 +153,8 @@ func NewWithDB(db *sql.DB, rawDialect string) (*Database, error) {

IgnoreForeignTables: true,
VersionTable: "version",

txnCtxKey: contextKey(nextContextKeyDatabaseTransaction.Add(1)),
}
wrappedDB.LoggingDB.UnderlyingExecable = db
wrappedDB.LoggingDB.db = wrappedDB
Expand Down
18 changes: 12 additions & 6 deletions dbutil/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"errors"
"fmt"
"runtime"
"sync/atomic"
"time"

"github.com/rs/zerolog"
Expand All @@ -26,13 +27,18 @@ var (
ErrTxnCommit = fmt.Errorf("%w: commit", ErrTxn)
)

type contextKey int
type contextKey int64

const (
ContextKeyDatabaseTransaction contextKey = iota
ContextKeyDoTxnCallerSkip
ContextKeyDoTxnCallerSkip contextKey = 1
)

var nextContextKeyDatabaseTransaction atomic.Uint64

func init() {
nextContextKeyDatabaseTransaction.Store(1 << 61)
}

func (db *Database) Exec(ctx context.Context, query string, args ...any) (sql.Result, error) {
return db.Conn(ctx).ExecContext(ctx, query, args...)
}
Expand All @@ -56,7 +62,7 @@ func (db *Database) DoTxn(ctx context.Context, opts *sql.TxOptions, fn func(ctx
if ctx == nil {
panic("DoTxn() called with nil ctx")
}
if ctx.Value(ContextKeyDatabaseTransaction) != nil {
if ctx.Value(db.txnCtxKey) != nil {
zerolog.Ctx(ctx).Trace().Msg("Already in a transaction, not creating a new one")
return fn(ctx)
}
Expand Down Expand Up @@ -106,7 +112,7 @@ func (db *Database) DoTxn(ctx context.Context, opts *sql.TxOptions, fn func(ctx
log.Trace().Msg("Transaction started")
tx.noTotalLog = true
ctx = log.WithContext(ctx)
ctx = context.WithValue(ctx, ContextKeyDatabaseTransaction, tx)
ctx = context.WithValue(ctx, db.txnCtxKey, tx)
err = fn(ctx)
if err != nil {
log.Trace().Err(err).Msg("Database transaction failed, rolling back")
Expand All @@ -131,7 +137,7 @@ func (db *Database) Conn(ctx context.Context) Execable {
if ctx == nil {
panic("Conn() called with nil ctx")
}
txn, ok := ctx.Value(ContextKeyDatabaseTransaction).(Transaction)
txn, ok := ctx.Value(db.txnCtxKey).(Transaction)
if ok {
return txn
}
Expand Down
2 changes: 2 additions & 0 deletions dbutil/upgrades_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ func testUpgrade(dialect Dialect) func(t *testing.T) {
VersionTable: "version",
Dialect: dialect,
UpgradeTable: makeTable(),
txnCtxKey: contextKey(nextContextKeyDatabaseTransaction.Add(1)),

IgnoreForeignTables: true,
}
Expand Down Expand Up @@ -107,6 +108,7 @@ func testCompatCheck(dialect Dialect) func(t *testing.T) {
VersionTable: "version",
Dialect: dialect,
UpgradeTable: makeTable(),
txnCtxKey: contextKey(nextContextKeyDatabaseTransaction.Add(1)),

IgnoreForeignTables: true,
}
Expand Down

0 comments on commit 603e3d7

Please sign in to comment.