From b9d4b7a1fdfa84e1edd1eeee1252ed68110c7972 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Fri, 20 Dec 2024 08:27:06 +0100 Subject: [PATCH] move lastInsertId setting to vcursor Signed-off-by: Andres Taylor --- go/vt/vtgate/executor.go | 11 ++++------- go/vt/vtgate/executorcontext/vcursor_impl.go | 9 +++++++++ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/go/vt/vtgate/executor.go b/go/vt/vtgate/executor.go index 5d27f4ebdec..2cc3043dfdf 100644 --- a/go/vt/vtgate/executor.go +++ b/go/vt/vtgate/executor.go @@ -234,9 +234,9 @@ func (e *Executor) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConn stmtType, result, err := e.execute(ctx, mysqlCtx, safeSession, sql, bindVars, logStats) logStats.Error = err if result == nil { - saveSessionStats(safeSession, stmtType, false, 0, 0, 0, err) + saveSessionStats(safeSession, stmtType, 0, 0, err) } else { - saveSessionStats(safeSession, stmtType, result.InsertIDChanged, result.InsertID, result.RowsAffected, len(result.Rows), err) + saveSessionStats(safeSession, stmtType, result.RowsAffected, len(result.Rows), err) } if result != nil && len(result.Rows) > warnMemoryRows { warnings.Add("ResultsExceeded", 1) @@ -381,7 +381,7 @@ func (e *Executor) StreamExecute( err = e.newExecute(ctx, mysqlCtx, safeSession, sql, bindVars, logStats, resultHandler, srr.storeResultStats) logStats.Error = err - saveSessionStats(safeSession, srr.stmtType, srr.insertIDChanged, srr.insertID, srr.rowsAffected, srr.rowsReturned, err) + saveSessionStats(safeSession, srr.stmtType, srr.rowsAffected, srr.rowsReturned, err) if srr.rowsReturned > warnMemoryRows { warnings.Add("ResultsExceeded", 1) piiSafeSQL, err := e.env.Parser().RedactSQLQuery(sql) @@ -414,7 +414,7 @@ func canReturnRows(stmtType sqlparser.StatementType) bool { } } -func saveSessionStats(safeSession *econtext.SafeSession, stmtType sqlparser.StatementType, insertIDChanged bool, insertID, rowsAffected uint64, rowsReturned int, err error) { +func saveSessionStats(safeSession *econtext.SafeSession, stmtType sqlparser.StatementType, rowsAffected uint64, rowsReturned int, err error) { safeSession.RowCount = -1 if err != nil { return @@ -422,9 +422,6 @@ func saveSessionStats(safeSession *econtext.SafeSession, stmtType sqlparser.Stat if !safeSession.IsFoundRowsHandled() { safeSession.FoundRows = uint64(rowsReturned) } - if insertID != 0 || insertIDChanged { - safeSession.LastInsertId = insertID - } switch stmtType { case sqlparser.StmtInsert, sqlparser.StmtReplace, sqlparser.StmtUpdate, sqlparser.StmtDelete: safeSession.RowCount = int64(rowsAffected) diff --git a/go/vt/vtgate/executorcontext/vcursor_impl.go b/go/vt/vtgate/executorcontext/vcursor_impl.go index 1896b3f267a..e51056a3367 100644 --- a/go/vt/vtgate/executorcontext/vcursor_impl.go +++ b/go/vt/vtgate/executorcontext/vcursor_impl.go @@ -679,6 +679,9 @@ func (vc *VCursorImpl) wrapCallback(callback func(*sqltypes.Result) error, primi } return func(result *sqltypes.Result) error { + if result.InsertIDChanged { + vc.SafeSession.LastInsertId = result.InsertID + } vc.logOpTraffic(primitive, result) return callback(result) } @@ -764,6 +767,9 @@ func (vc *VCursorImpl) ExecuteMultiShard(ctx context.Context, primitive engine.P qr, errs := vc.executor.ExecuteMultiShard(ctx, primitive, rss, commentedShardQueries(queries, vc.marginComments), vc.SafeSession, canAutocommit, vc.ignoreMaxMemoryRows, vc.observer, fetchLastInsertID) vc.setRollbackOnPartialExecIfRequired(len(errs) != len(rss), rollbackOnError) vc.logShardsQueried(primitive, len(rss)) + if qr.InsertIDChanged { + vc.SafeSession.LastInsertId = qr.InsertID + } return qr, errs } @@ -803,6 +809,9 @@ func (vc *VCursorImpl) ExecuteStandalone(ctx context.Context, primitive engine.P // execute DMLs through ExecuteStandalone. qr, errs := vc.executor.ExecuteMultiShard(ctx, primitive, rss, bqs, NewAutocommitSession(vc.SafeSession.Session), false /* autocommit */, vc.ignoreMaxMemoryRows, vc.observer, fetchLastInsertID) vc.logShardsQueried(primitive, len(rss)) + if qr.InsertIDChanged { + vc.SafeSession.LastInsertId = qr.InsertID + } return qr, vterrors.Aggregate(errs) }