diff --git a/db.go b/db.go index 5ca6d12..546395c 100644 --- a/db.go +++ b/db.go @@ -139,18 +139,33 @@ func (db *DB) Wrap(sqlTx *sql.Tx) *Tx { // Transactional starts a transaction and executes the given function. // If the function returns an error, the transaction will be rolled back. // Otherwise, the transaction will be committed. -func (db *DB) Transactional(f func(*Tx) error) error { +func (db *DB) Transactional(f func(*Tx) error) (err error) { tx, err := db.Begin() if err != nil { return err } - if err := f(tx); err != nil { - if e := tx.Rollback(); e != nil { - return Errors{err, e} + + defer func() { + if p := recover(); p != nil { + tx.Rollback() + panic(p) + } else if err != nil { + if err2 := tx.Rollback(); err2 != nil { + if err2 == sql.ErrTxDone { + return + } + err = Errors{err, err2} + } + } else { + if err = tx.Commit(); err == sql.ErrTxDone { + err = nil + } } - return err - } - return tx.Commit() + }() + + err = f(tx) + + return err } // DriverName returns the name of the DB driver. diff --git a/db_test.go b/db_test.go index 81463a0..59436e5 100644 --- a/db_test.go +++ b/db_test.go @@ -273,6 +273,42 @@ func TestDB_Transactional(t *testing.T) { db.NewQuery("SELECT name FROM item WHERE id=2").Row(&name) assert.Equal(t, "Go in Action", name) } + + // Rollback called within Transactional and return error + err = db.Transactional(func(tx *Tx) error { + _, err := tx.NewQuery("DELETE FROM item WHERE id=2").Execute() + if err != nil { + return err + } + _, err = tx.NewQuery("DELETE FROM items WHERE id=2").Execute() + if err != nil { + tx.Rollback() + return err + } + return nil + }) + if assert.NotNil(t, err) { + db.NewQuery("SELECT name FROM item WHERE id=2").Row(&name) + assert.Equal(t, "Go in Action", name) + } + + // Rollback called within Transactional without returning error + err = db.Transactional(func(tx *Tx) error { + _, err := tx.NewQuery("DELETE FROM item WHERE id=2").Execute() + if err != nil { + return err + } + _, err = tx.NewQuery("DELETE FROM items WHERE id=2").Execute() + if err != nil { + tx.Rollback() + return nil + } + return nil + }) + if assert.Nil(t, err) { + db.NewQuery("SELECT name FROM item WHERE id=2").Row(&name) + assert.Equal(t, "Go in Action", name) + } } func TestErrors_Error(t *testing.T) {