diff --git a/go/libraries/doltcore/sqle/dprocedures/dolt_reset.go b/go/libraries/doltcore/sqle/dprocedures/dolt_reset.go index 2a183625a5..f35b96d3f6 100644 --- a/go/libraries/doltcore/sqle/dprocedures/dolt_reset.go +++ b/go/libraries/doltcore/sqle/dprocedures/dolt_reset.go @@ -114,6 +114,11 @@ func doDoltReset(ctx *sql.Context, args []string) (int, error) { if err != nil { return 1, err } + err = dSess.ResetGlobals(ctx, dbName, roots.Working) + if err != nil { + return 1, err + } + } else if apr.Contains(cli.SoftResetParam) { arg := "" if apr.NArg() > 1 { diff --git a/go/libraries/doltcore/sqle/dsess/autoincrement_tracker.go b/go/libraries/doltcore/sqle/dsess/autoincrement_tracker.go index cdca2458e7..4dda09c08a 100644 --- a/go/libraries/doltcore/sqle/dsess/autoincrement_tracker.go +++ b/go/libraries/doltcore/sqle/dsess/autoincrement_tracker.go @@ -61,41 +61,7 @@ func NewAutoIncrementTracker(ctx context.Context, dbName string, roots ...doltdb sequences: &sync.Map{}, mm: mutexmap.NewMutexMap(), } - - for _, root := range roots { - root, err := root.ResolveRootValue(ctx) - if err != nil { - return &AutoIncrementTracker{}, err - } - - err = root.IterTables(ctx, func(tableName doltdb.TableName, table *doltdb.Table, sch schema.Schema) (bool, error) { - ok := schema.HasAutoIncrement(sch) - if !ok { - return false, nil - } - - tableName = tableName.ToLower() - - seq, err := table.GetAutoIncrementValue(ctx) - if err != nil { - return true, err - } - - // TODO: support schema name as part of the key - tableNameStr := tableName.Name - oldValue, loaded := ait.sequences.LoadOrStore(tableNameStr, seq) - if loaded && seq > oldValue.(uint64) { - ait.sequences.Store(tableNameStr, seq) - } - - return false, nil - }) - - if err != nil { - return &AutoIncrementTracker{}, err - } - } - + ait.InitWithRoots(ctx, roots...) return &ait, nil } @@ -109,13 +75,13 @@ func loadAutoIncValue(sequences *sync.Map, tableName string) uint64 { } // Current returns the next value to be generated in the auto increment sequence for the table named -func (a AutoIncrementTracker) Current(tableName string) uint64 { +func (a *AutoIncrementTracker) Current(tableName string) uint64 { return loadAutoIncValue(a.sequences, tableName) } // Next returns the next auto increment value for the table named using the provided value from an insert (which may // be null or 0, in which case it will be generated from the sequence). -func (a AutoIncrementTracker) Next(tbl string, insertVal interface{}) (uint64, error) { +func (a *AutoIncrementTracker) Next(tbl string, insertVal interface{}) (uint64, error) { tbl = strings.ToLower(tbl) given, err := CoerceAutoIncrementValue(insertVal) @@ -145,7 +111,7 @@ func (a AutoIncrementTracker) Next(tbl string, insertVal interface{}) (uint64, e return given, nil } -func (a AutoIncrementTracker) CoerceAutoIncrementValue(val interface{}) (uint64, error) { +func (a *AutoIncrementTracker) CoerceAutoIncrementValue(val interface{}) (uint64, error) { return CoerceAutoIncrementValue(val) } @@ -172,7 +138,7 @@ func CoerceAutoIncrementValue(val interface{}) (uint64, error) { // Set sets the auto increment value for the table named, if it's greater than the one already registered for this // table. Otherwise, the update is silently disregarded. So far this matches the MySQL behavior, but Dolt uses the // maximum value for this table across all branches. -func (a AutoIncrementTracker) Set(ctx *sql.Context, tableName string, table *doltdb.Table, ws ref.WorkingSetRef, newAutoIncVal uint64) (*doltdb.Table, error) { +func (a *AutoIncrementTracker) Set(ctx *sql.Context, tableName string, table *doltdb.Table, ws ref.WorkingSetRef, newAutoIncVal uint64) (*doltdb.Table, error) { tableName = strings.ToLower(tableName) release := a.mm.Lock(tableName) @@ -190,7 +156,7 @@ func (a AutoIncrementTracker) Set(ctx *sql.Context, tableName string, table *dol // deepSet sets the auto increment value for the table named, if it's greater than the one on any branch head for this // database, ignoring the current in-memory tracker value -func (a AutoIncrementTracker) deepSet(ctx *sql.Context, tableName string, table *doltdb.Table, ws ref.WorkingSetRef, newAutoIncVal uint64) (*doltdb.Table, error) { +func (a *AutoIncrementTracker) deepSet(ctx *sql.Context, tableName string, table *doltdb.Table, ws ref.WorkingSetRef, newAutoIncVal uint64) (*doltdb.Table, error) { sess := DSessFromSess(ctx.Session) db, ok := sess.Provider().BaseDatabase(ctx, a.dbName) @@ -371,7 +337,7 @@ func getMaxIndexValue(ctx context.Context, indexData durable.Index) (uint64, err } // AddNewTable initializes a new table with an auto increment column to the tracker, as necessary -func (a AutoIncrementTracker) AddNewTable(tableName string) { +func (a *AutoIncrementTracker) AddNewTable(tableName string) { tableName = strings.ToLower(tableName) // only initialize the sequence for this table if no other branch has such a table a.sequences.LoadOrStore(tableName, uint64(1)) @@ -380,7 +346,7 @@ func (a AutoIncrementTracker) AddNewTable(tableName string) { // DropTable drops the table with the name given. // To establish the new auto increment value, callers must also pass all other working sets in scope that may include // a table with the same name, omitting the working set that just deleted the table named. -func (a AutoIncrementTracker) DropTable(ctx *sql.Context, tableName string, wses ...*doltdb.WorkingSet) error { +func (a *AutoIncrementTracker) DropTable(ctx *sql.Context, tableName string, wses ...*doltdb.WorkingSet) error { tableName = strings.ToLower(tableName) release := a.mm.Lock(tableName) @@ -430,3 +396,36 @@ func (a *AutoIncrementTracker) AcquireTableLock(ctx *sql.Context, tableName stri a.lockMode = lockMode return a.mm.Lock(tableName), nil } + +func (a *AutoIncrementTracker) InitWithRoots(ctx context.Context, roots ...doltdb.Rootish) error { + for _, root := range roots { + r, err := root.ResolveRootValue(ctx) + if err != nil { + return err + } + + err = r.IterTables(ctx, func(tableName doltdb.TableName, table *doltdb.Table, sch schema.Schema) (bool, error) { + if !schema.HasAutoIncrement(sch) { + return false, nil + } + + seq, err := table.GetAutoIncrementValue(ctx) + if err != nil { + return true, err + } + + tableNameStr := tableName.ToLower().Name + if oldValue, loaded := a.sequences.LoadOrStore(tableNameStr, seq); loaded && seq > oldValue.(uint64) { + a.sequences.Store(tableNameStr, seq) + } + + return false, nil + }) + + if err != nil { + return err + } + } + + return nil +} diff --git a/go/libraries/doltcore/sqle/dsess/session.go b/go/libraries/doltcore/sqle/dsess/session.go index 1c02000155..5544f2ef8d 100644 --- a/go/libraries/doltcore/sqle/dsess/session.go +++ b/go/libraries/doltcore/sqle/dsess/session.go @@ -1018,7 +1018,7 @@ func (d *DoltSession) SetStagingRoot(ctx *sql.Context, dbName string, newRoot do // via setRoot. This method is for clients that need to update more of the session state, such as the dolt_ functions. // Unlike setting the working root, this method always marks the database state dirty. func (d *DoltSession) SetRoots(ctx *sql.Context, dbName string, roots doltdb.Roots) error { - sessionState, _, err := d.LookupDbState(ctx, dbName) + sessionState, _, err := d.lookupDbState(ctx, dbName) if err != nil { return err } @@ -1031,6 +1031,25 @@ func (d *DoltSession) SetRoots(ctx *sql.Context, dbName string, roots doltdb.Roo return d.SetWorkingSet(ctx, dbName, workingSet) } +func (d *DoltSession) ResetGlobals(ctx *sql.Context, dbName string, root doltdb.RootValue) error { + sessionState, _, err := d.lookupDbState(ctx, dbName) + if err != nil { + return err + } + + tracker, err := sessionState.dbState.globalState.AutoIncrementTracker(ctx) + if err != nil { + return err + } + + err = tracker.InitWithRoots(ctx, root) + if err != nil { + return err + } + + return nil +} + func (d *DoltSession) SetFileSystem(fs filesys.Filesys) { d.fs = fs } @@ -1059,8 +1078,8 @@ func (d *DoltSession) SetWorkingSet(ctx *sql.Context, dbName string, ws *doltdb. return err } - if writeSess := branchState.WriteSession(); writeSess != nil { - err = writeSess.SetWorkingSet(ctx, ws) + if branchState.writeSession != nil { + err = branchState.writeSession.SetWorkingSet(ctx, ws) if err != nil { return err } @@ -1484,9 +1503,10 @@ func (d *DoltSession) dbSessionVarsStale(ctx *sql.Context, state *branchState) b return d.dbCache.CacheSessionVars(state, dtx) } -func (d DoltSession) WithGlobals(conf config.ReadWriteConfig) *DoltSession { - d.globalsConf = conf - return &d +func (d *DoltSession) WithGlobals(conf config.ReadWriteConfig) *DoltSession { + nd := *d + nd.globalsConf = conf + return &nd } // PersistGlobal implements sql.PersistableSession diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_queries.go b/go/libraries/doltcore/sqle/enginetest/dolt_queries.go index 52c61dbbcd..816d028076 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_queries.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_queries.go @@ -5637,6 +5637,89 @@ var DoltAutoIncrementTests = []queries.ScriptTest{ }, }, }, + { + Name: "hard reset dropped table restores auto increment", + SetUpScript: []string{ + "create table t (a int primary key auto_increment, b int)", + "insert into t (b) values (1), (2)", + "call dolt_commit('-Am', 'initialize table')", + "drop table t", + "call dolt_reset('--hard')", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "insert into t(b) values (3)", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1, InsertID: 3}}, + }, + }, + { + Query: "select * from t order by a", + Expected: []sql.Row{ + {1, 1}, + {2, 2}, + {3, 3}, + }, + }, + }, + }, + { + // this behavior aligns with how we treat branches + Name: "hard reset inserted rows continues auto increment", + SetUpScript: []string{ + "create table t (a int primary key auto_increment, b int)", + "insert into t (b) values (1), (2)", + "call dolt_commit('-Am', 'initialize table')", + "insert into t (b) values (3), (4)", + "call dolt_reset('--hard')", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "insert into t(b) values (5)", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1, InsertID: 5}}, + }, + }, + { + Query: "select * from t order by a", + Expected: []sql.Row{ + {1, 1}, + {2, 2}, + {5, 5}, + }, + }, + }, + }, + { + Name: "hard reset dropped table with branch restores auto increment", + SetUpScript: []string{ + "create table t (a int primary key auto_increment, b int)", + "insert into t (b) values (1), (2)", + "call dolt_commit('-Am', 'initialize table')", + "call dolt_checkout('-b', 'branch1')", + "insert into t values (100, 100)", + "call dolt_commit('-Am', 'other')", + "call dolt_checkout('main')", + "drop table t", + "call dolt_reset('--hard')", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "insert into t(b) values (101)", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1, InsertID: 101}}, + }, + }, + { + Query: "select * from t order by a", + Expected: []sql.Row{ + {1, 1}, + {2, 2}, + {101, 101}, + }, + }, + }, + }, } var DoltCherryPickTests = []queries.ScriptTest{ diff --git a/go/libraries/doltcore/sqle/globalstate/auto_increment_tracker.go b/go/libraries/doltcore/sqle/globalstate/auto_increment_tracker.go index 0bf07cd491..3c92f6e0a0 100644 --- a/go/libraries/doltcore/sqle/globalstate/auto_increment_tracker.go +++ b/go/libraries/doltcore/sqle/globalstate/auto_increment_tracker.go @@ -15,6 +15,8 @@ package globalstate import ( + "context" + "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" @@ -38,9 +40,10 @@ type AutoIncrementTracker interface { // below the current value for this table. The table in the provided working set is assumed to already have the value // given, so the new global maximum is computed without regard for its value in that working set. Set(ctx *sql.Context, tableName string, table *doltdb.Table, ws ref.WorkingSetRef, newAutoIncVal uint64) (*doltdb.Table, error) - // AcquireTableLock acquires the auto increment lock on a table, and returns a callback function to release the lock. // Depending on the value of the `innodb_autoinc_lock_mode` system variable, the engine may need to acquire and hold // the lock for the duration of an insert statement. AcquireTableLock(ctx *sql.Context, tableName string) (func(), error) + // InitWithRoots fills the AutoIncrementTracker with values pulled from each root in order. + InitWithRoots(ctx context.Context, roots ...doltdb.Rootish) error }