Skip to content

Commit

Permalink
reset auto increment counter on dolt_reset('--hard') (#8319)
Browse files Browse the repository at this point in the history
  • Loading branch information
jycor authored Sep 4, 2024
1 parent 913d6b5 commit 8dca4a5
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 49 deletions.
5 changes: 5 additions & 0 deletions go/libraries/doltcore/sqle/dprocedures/dolt_reset.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,11 @@ func doDoltReset(ctx *sql.Context, args []string) (int, error) {
if err != nil {
return 1, err
}
err = dSess.ResetGlobals(ctx, dbName, roots.Working)
if err != nil {
return 1, err
}

} else if apr.Contains(cli.SoftResetParam) {
arg := ""
if apr.NArg() > 1 {
Expand Down
83 changes: 41 additions & 42 deletions go/libraries/doltcore/sqle/dsess/autoincrement_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,41 +61,7 @@ func NewAutoIncrementTracker(ctx context.Context, dbName string, roots ...doltdb
sequences: &sync.Map{},
mm: mutexmap.NewMutexMap(),
}

for _, root := range roots {
root, err := root.ResolveRootValue(ctx)
if err != nil {
return &AutoIncrementTracker{}, err
}

err = root.IterTables(ctx, func(tableName doltdb.TableName, table *doltdb.Table, sch schema.Schema) (bool, error) {
ok := schema.HasAutoIncrement(sch)
if !ok {
return false, nil
}

tableName = tableName.ToLower()

seq, err := table.GetAutoIncrementValue(ctx)
if err != nil {
return true, err
}

// TODO: support schema name as part of the key
tableNameStr := tableName.Name
oldValue, loaded := ait.sequences.LoadOrStore(tableNameStr, seq)
if loaded && seq > oldValue.(uint64) {
ait.sequences.Store(tableNameStr, seq)
}

return false, nil
})

if err != nil {
return &AutoIncrementTracker{}, err
}
}

ait.InitWithRoots(ctx, roots...)
return &ait, nil
}

Expand All @@ -109,13 +75,13 @@ func loadAutoIncValue(sequences *sync.Map, tableName string) uint64 {
}

// Current returns the next value to be generated in the auto increment sequence for the table named
func (a AutoIncrementTracker) Current(tableName string) uint64 {
func (a *AutoIncrementTracker) Current(tableName string) uint64 {
return loadAutoIncValue(a.sequences, tableName)
}

// Next returns the next auto increment value for the table named using the provided value from an insert (which may
// be null or 0, in which case it will be generated from the sequence).
func (a AutoIncrementTracker) Next(tbl string, insertVal interface{}) (uint64, error) {
func (a *AutoIncrementTracker) Next(tbl string, insertVal interface{}) (uint64, error) {
tbl = strings.ToLower(tbl)

given, err := CoerceAutoIncrementValue(insertVal)
Expand Down Expand Up @@ -145,7 +111,7 @@ func (a AutoIncrementTracker) Next(tbl string, insertVal interface{}) (uint64, e
return given, nil
}

func (a AutoIncrementTracker) CoerceAutoIncrementValue(val interface{}) (uint64, error) {
func (a *AutoIncrementTracker) CoerceAutoIncrementValue(val interface{}) (uint64, error) {
return CoerceAutoIncrementValue(val)
}

Expand All @@ -172,7 +138,7 @@ func CoerceAutoIncrementValue(val interface{}) (uint64, error) {
// Set sets the auto increment value for the table named, if it's greater than the one already registered for this
// table. Otherwise, the update is silently disregarded. So far this matches the MySQL behavior, but Dolt uses the
// maximum value for this table across all branches.
func (a AutoIncrementTracker) Set(ctx *sql.Context, tableName string, table *doltdb.Table, ws ref.WorkingSetRef, newAutoIncVal uint64) (*doltdb.Table, error) {
func (a *AutoIncrementTracker) Set(ctx *sql.Context, tableName string, table *doltdb.Table, ws ref.WorkingSetRef, newAutoIncVal uint64) (*doltdb.Table, error) {
tableName = strings.ToLower(tableName)

release := a.mm.Lock(tableName)
Expand All @@ -190,7 +156,7 @@ func (a AutoIncrementTracker) Set(ctx *sql.Context, tableName string, table *dol

// deepSet sets the auto increment value for the table named, if it's greater than the one on any branch head for this
// database, ignoring the current in-memory tracker value
func (a AutoIncrementTracker) deepSet(ctx *sql.Context, tableName string, table *doltdb.Table, ws ref.WorkingSetRef, newAutoIncVal uint64) (*doltdb.Table, error) {
func (a *AutoIncrementTracker) deepSet(ctx *sql.Context, tableName string, table *doltdb.Table, ws ref.WorkingSetRef, newAutoIncVal uint64) (*doltdb.Table, error) {
sess := DSessFromSess(ctx.Session)
db, ok := sess.Provider().BaseDatabase(ctx, a.dbName)

Expand Down Expand Up @@ -371,7 +337,7 @@ func getMaxIndexValue(ctx context.Context, indexData durable.Index) (uint64, err
}

// AddNewTable initializes a new table with an auto increment column to the tracker, as necessary
func (a AutoIncrementTracker) AddNewTable(tableName string) {
func (a *AutoIncrementTracker) AddNewTable(tableName string) {
tableName = strings.ToLower(tableName)
// only initialize the sequence for this table if no other branch has such a table
a.sequences.LoadOrStore(tableName, uint64(1))
Expand All @@ -380,7 +346,7 @@ func (a AutoIncrementTracker) AddNewTable(tableName string) {
// DropTable drops the table with the name given.
// To establish the new auto increment value, callers must also pass all other working sets in scope that may include
// a table with the same name, omitting the working set that just deleted the table named.
func (a AutoIncrementTracker) DropTable(ctx *sql.Context, tableName string, wses ...*doltdb.WorkingSet) error {
func (a *AutoIncrementTracker) DropTable(ctx *sql.Context, tableName string, wses ...*doltdb.WorkingSet) error {
tableName = strings.ToLower(tableName)

release := a.mm.Lock(tableName)
Expand Down Expand Up @@ -430,3 +396,36 @@ func (a *AutoIncrementTracker) AcquireTableLock(ctx *sql.Context, tableName stri
a.lockMode = lockMode
return a.mm.Lock(tableName), nil
}

func (a *AutoIncrementTracker) InitWithRoots(ctx context.Context, roots ...doltdb.Rootish) error {
for _, root := range roots {
r, err := root.ResolveRootValue(ctx)
if err != nil {
return err
}

err = r.IterTables(ctx, func(tableName doltdb.TableName, table *doltdb.Table, sch schema.Schema) (bool, error) {
if !schema.HasAutoIncrement(sch) {
return false, nil
}

seq, err := table.GetAutoIncrementValue(ctx)
if err != nil {
return true, err
}

tableNameStr := tableName.ToLower().Name
if oldValue, loaded := a.sequences.LoadOrStore(tableNameStr, seq); loaded && seq > oldValue.(uint64) {
a.sequences.Store(tableNameStr, seq)
}

return false, nil
})

if err != nil {
return err
}
}

return nil
}
32 changes: 26 additions & 6 deletions go/libraries/doltcore/sqle/dsess/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -1018,7 +1018,7 @@ func (d *DoltSession) SetStagingRoot(ctx *sql.Context, dbName string, newRoot do
// via setRoot. This method is for clients that need to update more of the session state, such as the dolt_ functions.
// Unlike setting the working root, this method always marks the database state dirty.
func (d *DoltSession) SetRoots(ctx *sql.Context, dbName string, roots doltdb.Roots) error {
sessionState, _, err := d.LookupDbState(ctx, dbName)
sessionState, _, err := d.lookupDbState(ctx, dbName)
if err != nil {
return err
}
Expand All @@ -1031,6 +1031,25 @@ func (d *DoltSession) SetRoots(ctx *sql.Context, dbName string, roots doltdb.Roo
return d.SetWorkingSet(ctx, dbName, workingSet)
}

func (d *DoltSession) ResetGlobals(ctx *sql.Context, dbName string, root doltdb.RootValue) error {
sessionState, _, err := d.lookupDbState(ctx, dbName)
if err != nil {
return err
}

tracker, err := sessionState.dbState.globalState.AutoIncrementTracker(ctx)
if err != nil {
return err
}

err = tracker.InitWithRoots(ctx, root)
if err != nil {
return err
}

return nil
}

func (d *DoltSession) SetFileSystem(fs filesys.Filesys) {
d.fs = fs
}
Expand Down Expand Up @@ -1059,8 +1078,8 @@ func (d *DoltSession) SetWorkingSet(ctx *sql.Context, dbName string, ws *doltdb.
return err
}

if writeSess := branchState.WriteSession(); writeSess != nil {
err = writeSess.SetWorkingSet(ctx, ws)
if branchState.writeSession != nil {
err = branchState.writeSession.SetWorkingSet(ctx, ws)
if err != nil {
return err
}
Expand Down Expand Up @@ -1484,9 +1503,10 @@ func (d *DoltSession) dbSessionVarsStale(ctx *sql.Context, state *branchState) b
return d.dbCache.CacheSessionVars(state, dtx)
}

func (d DoltSession) WithGlobals(conf config.ReadWriteConfig) *DoltSession {
d.globalsConf = conf
return &d
func (d *DoltSession) WithGlobals(conf config.ReadWriteConfig) *DoltSession {
nd := *d
nd.globalsConf = conf
return &nd
}

// PersistGlobal implements sql.PersistableSession
Expand Down
83 changes: 83 additions & 0 deletions go/libraries/doltcore/sqle/enginetest/dolt_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -5637,6 +5637,89 @@ var DoltAutoIncrementTests = []queries.ScriptTest{
},
},
},
{
Name: "hard reset dropped table restores auto increment",
SetUpScript: []string{
"create table t (a int primary key auto_increment, b int)",
"insert into t (b) values (1), (2)",
"call dolt_commit('-Am', 'initialize table')",
"drop table t",
"call dolt_reset('--hard')",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "insert into t(b) values (3)",
Expected: []sql.Row{
{types.OkResult{RowsAffected: 1, InsertID: 3}},
},
},
{
Query: "select * from t order by a",
Expected: []sql.Row{
{1, 1},
{2, 2},
{3, 3},
},
},
},
},
{
// this behavior aligns with how we treat branches
Name: "hard reset inserted rows continues auto increment",
SetUpScript: []string{
"create table t (a int primary key auto_increment, b int)",
"insert into t (b) values (1), (2)",
"call dolt_commit('-Am', 'initialize table')",
"insert into t (b) values (3), (4)",
"call dolt_reset('--hard')",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "insert into t(b) values (5)",
Expected: []sql.Row{
{types.OkResult{RowsAffected: 1, InsertID: 5}},
},
},
{
Query: "select * from t order by a",
Expected: []sql.Row{
{1, 1},
{2, 2},
{5, 5},
},
},
},
},
{
Name: "hard reset dropped table with branch restores auto increment",
SetUpScript: []string{
"create table t (a int primary key auto_increment, b int)",
"insert into t (b) values (1), (2)",
"call dolt_commit('-Am', 'initialize table')",
"call dolt_checkout('-b', 'branch1')",
"insert into t values (100, 100)",
"call dolt_commit('-Am', 'other')",
"call dolt_checkout('main')",
"drop table t",
"call dolt_reset('--hard')",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "insert into t(b) values (101)",
Expected: []sql.Row{
{types.OkResult{RowsAffected: 1, InsertID: 101}},
},
},
{
Query: "select * from t order by a",
Expected: []sql.Row{
{1, 1},
{2, 2},
{101, 101},
},
},
},
},
}

var DoltCherryPickTests = []queries.ScriptTest{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
package globalstate

import (
"context"

"github.com/dolthub/go-mysql-server/sql"

"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
Expand All @@ -38,9 +40,10 @@ type AutoIncrementTracker interface {
// below the current value for this table. The table in the provided working set is assumed to already have the value
// given, so the new global maximum is computed without regard for its value in that working set.
Set(ctx *sql.Context, tableName string, table *doltdb.Table, ws ref.WorkingSetRef, newAutoIncVal uint64) (*doltdb.Table, error)

// AcquireTableLock acquires the auto increment lock on a table, and returns a callback function to release the lock.
// Depending on the value of the `innodb_autoinc_lock_mode` system variable, the engine may need to acquire and hold
// the lock for the duration of an insert statement.
AcquireTableLock(ctx *sql.Context, tableName string) (func(), error)
// InitWithRoots fills the AutoIncrementTracker with values pulled from each root in order.
InitWithRoots(ctx context.Context, roots ...doltdb.Rootish) error
}

0 comments on commit 8dca4a5

Please sign in to comment.