From 603e3d7117c11f637dcb94b694dc34a469b61842 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 2 Jun 2024 19:32:13 +0300 Subject: [PATCH] Prevent accidentally mixing transactions of different databases --- dbutil/database.go | 6 ++++++ dbutil/transaction.go | 18 ++++++++++++------ dbutil/upgrades_test.go | 2 ++ 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/dbutil/database.go b/dbutil/database.go index bcaa8ea..bf57e2b 100644 --- a/dbutil/database.go +++ b/dbutil/database.go @@ -104,6 +104,8 @@ type Database struct { Dialect Dialect UpgradeTable UpgradeTable + txnCtxKey contextKey + IgnoreForeignTables bool IgnoreUnsupportedDatabase bool } @@ -132,6 +134,8 @@ func (db *Database) Child(versionTable string, upgradeTable UpgradeTable, log Da Log: log, Dialect: db.Dialect, + txnCtxKey: db.txnCtxKey, + IgnoreForeignTables: true, IgnoreUnsupportedDatabase: db.IgnoreUnsupportedDatabase, } @@ -149,6 +153,8 @@ func NewWithDB(db *sql.DB, rawDialect string) (*Database, error) { IgnoreForeignTables: true, VersionTable: "version", + + txnCtxKey: contextKey(nextContextKeyDatabaseTransaction.Add(1)), } wrappedDB.LoggingDB.UnderlyingExecable = db wrappedDB.LoggingDB.db = wrappedDB diff --git a/dbutil/transaction.go b/dbutil/transaction.go index e71b423..ecdc59e 100644 --- a/dbutil/transaction.go +++ b/dbutil/transaction.go @@ -12,6 +12,7 @@ import ( "errors" "fmt" "runtime" + "sync/atomic" "time" "github.com/rs/zerolog" @@ -26,13 +27,18 @@ var ( ErrTxnCommit = fmt.Errorf("%w: commit", ErrTxn) ) -type contextKey int +type contextKey int64 const ( - ContextKeyDatabaseTransaction contextKey = iota - ContextKeyDoTxnCallerSkip + ContextKeyDoTxnCallerSkip contextKey = 1 ) +var nextContextKeyDatabaseTransaction atomic.Uint64 + +func init() { + nextContextKeyDatabaseTransaction.Store(1 << 61) +} + func (db *Database) Exec(ctx context.Context, query string, args ...any) (sql.Result, error) { return db.Conn(ctx).ExecContext(ctx, query, args...) } @@ -56,7 +62,7 @@ func (db *Database) DoTxn(ctx context.Context, opts *sql.TxOptions, fn func(ctx if ctx == nil { panic("DoTxn() called with nil ctx") } - if ctx.Value(ContextKeyDatabaseTransaction) != nil { + if ctx.Value(db.txnCtxKey) != nil { zerolog.Ctx(ctx).Trace().Msg("Already in a transaction, not creating a new one") return fn(ctx) } @@ -106,7 +112,7 @@ func (db *Database) DoTxn(ctx context.Context, opts *sql.TxOptions, fn func(ctx log.Trace().Msg("Transaction started") tx.noTotalLog = true ctx = log.WithContext(ctx) - ctx = context.WithValue(ctx, ContextKeyDatabaseTransaction, tx) + ctx = context.WithValue(ctx, db.txnCtxKey, tx) err = fn(ctx) if err != nil { log.Trace().Err(err).Msg("Database transaction failed, rolling back") @@ -131,7 +137,7 @@ func (db *Database) Conn(ctx context.Context) Execable { if ctx == nil { panic("Conn() called with nil ctx") } - txn, ok := ctx.Value(ContextKeyDatabaseTransaction).(Transaction) + txn, ok := ctx.Value(db.txnCtxKey).(Transaction) if ok { return txn } diff --git a/dbutil/upgrades_test.go b/dbutil/upgrades_test.go index 1aa715f..b837f61 100644 --- a/dbutil/upgrades_test.go +++ b/dbutil/upgrades_test.go @@ -70,6 +70,7 @@ func testUpgrade(dialect Dialect) func(t *testing.T) { VersionTable: "version", Dialect: dialect, UpgradeTable: makeTable(), + txnCtxKey: contextKey(nextContextKeyDatabaseTransaction.Add(1)), IgnoreForeignTables: true, } @@ -107,6 +108,7 @@ func testCompatCheck(dialect Dialect) func(t *testing.T) { VersionTable: "version", Dialect: dialect, UpgradeTable: makeTable(), + txnCtxKey: contextKey(nextContextKeyDatabaseTransaction.Add(1)), IgnoreForeignTables: true, }