Skip to content

Commit

Permalink
move lastInsertId setting to vcursor
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <[email protected]>
  • Loading branch information
systay committed Dec 20, 2024
1 parent 6fedc44 commit b9d4b7a
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
11 changes: 4 additions & 7 deletions go/vt/vtgate/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,9 @@ func (e *Executor) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConn
stmtType, result, err := e.execute(ctx, mysqlCtx, safeSession, sql, bindVars, logStats)
logStats.Error = err
if result == nil {
saveSessionStats(safeSession, stmtType, false, 0, 0, 0, err)
saveSessionStats(safeSession, stmtType, 0, 0, err)
} else {
saveSessionStats(safeSession, stmtType, result.InsertIDChanged, result.InsertID, result.RowsAffected, len(result.Rows), err)
saveSessionStats(safeSession, stmtType, result.RowsAffected, len(result.Rows), err)
}
if result != nil && len(result.Rows) > warnMemoryRows {
warnings.Add("ResultsExceeded", 1)
Expand Down Expand Up @@ -381,7 +381,7 @@ func (e *Executor) StreamExecute(
err = e.newExecute(ctx, mysqlCtx, safeSession, sql, bindVars, logStats, resultHandler, srr.storeResultStats)

logStats.Error = err
saveSessionStats(safeSession, srr.stmtType, srr.insertIDChanged, srr.insertID, srr.rowsAffected, srr.rowsReturned, err)
saveSessionStats(safeSession, srr.stmtType, srr.rowsAffected, srr.rowsReturned, err)
if srr.rowsReturned > warnMemoryRows {
warnings.Add("ResultsExceeded", 1)
piiSafeSQL, err := e.env.Parser().RedactSQLQuery(sql)
Expand Down Expand Up @@ -414,17 +414,14 @@ func canReturnRows(stmtType sqlparser.StatementType) bool {
}
}

func saveSessionStats(safeSession *econtext.SafeSession, stmtType sqlparser.StatementType, insertIDChanged bool, insertID, rowsAffected uint64, rowsReturned int, err error) {
func saveSessionStats(safeSession *econtext.SafeSession, stmtType sqlparser.StatementType, rowsAffected uint64, rowsReturned int, err error) {
safeSession.RowCount = -1
if err != nil {
return
}
if !safeSession.IsFoundRowsHandled() {
safeSession.FoundRows = uint64(rowsReturned)
}
if insertID != 0 || insertIDChanged {
safeSession.LastInsertId = insertID
}
switch stmtType {
case sqlparser.StmtInsert, sqlparser.StmtReplace, sqlparser.StmtUpdate, sqlparser.StmtDelete:
safeSession.RowCount = int64(rowsAffected)
Expand Down
9 changes: 9 additions & 0 deletions go/vt/vtgate/executorcontext/vcursor_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,9 @@ func (vc *VCursorImpl) wrapCallback(callback func(*sqltypes.Result) error, primi
}

return func(result *sqltypes.Result) error {
if result.InsertIDChanged {
vc.SafeSession.LastInsertId = result.InsertID
}
vc.logOpTraffic(primitive, result)
return callback(result)
}
Expand Down Expand Up @@ -764,6 +767,9 @@ func (vc *VCursorImpl) ExecuteMultiShard(ctx context.Context, primitive engine.P
qr, errs := vc.executor.ExecuteMultiShard(ctx, primitive, rss, commentedShardQueries(queries, vc.marginComments), vc.SafeSession, canAutocommit, vc.ignoreMaxMemoryRows, vc.observer, fetchLastInsertID)
vc.setRollbackOnPartialExecIfRequired(len(errs) != len(rss), rollbackOnError)
vc.logShardsQueried(primitive, len(rss))
if qr.InsertIDChanged {
vc.SafeSession.LastInsertId = qr.InsertID
}
return qr, errs
}

Expand Down Expand Up @@ -803,6 +809,9 @@ func (vc *VCursorImpl) ExecuteStandalone(ctx context.Context, primitive engine.P
// execute DMLs through ExecuteStandalone.
qr, errs := vc.executor.ExecuteMultiShard(ctx, primitive, rss, bqs, NewAutocommitSession(vc.SafeSession.Session), false /* autocommit */, vc.ignoreMaxMemoryRows, vc.observer, fetchLastInsertID)
vc.logShardsQueried(primitive, len(rss))
if qr.InsertIDChanged {
vc.SafeSession.LastInsertId = qr.InsertID
}
return qr, vterrors.Aggregate(errs)
}

Expand Down

0 comments on commit b9d4b7a

Please sign in to comment.