diff --git a/go/mysql/query_test.go b/go/mysql/query_test.go index 21d4629ffdc..01cbd6a8f30 100644 --- a/go/mysql/query_test.go +++ b/go/mysql/query_test.go @@ -699,7 +699,7 @@ func checkQueryInternal(t *testing.T, query string, sConn, cConn *Conn, result * got = &sqltypes.Result{} got.RowsAffected = result.RowsAffected got.InsertID = result.InsertID - got.InsertIDChanged = result.InsertIDChanged + got.InsertIDChanged = result.InsertIDUpdated() got.Fields, err = cConn.Fields() if err != nil { fatalError = fmt.Sprintf("Fields(%v) failed: %v", query, err) diff --git a/go/sqltypes/proto3.go b/go/sqltypes/proto3.go index 0ca03b153cf..a9c66c45a4c 100644 --- a/go/sqltypes/proto3.go +++ b/go/sqltypes/proto3.go @@ -103,7 +103,7 @@ func ResultToProto3(qr *Result) *querypb.QueryResult { Fields: qr.Fields, RowsAffected: qr.RowsAffected, InsertId: qr.InsertID, - InsertIdChanged: qr.InsertIDChanged, + InsertIdChanged: qr.InsertIDUpdated(), Rows: RowsToProto3(qr.Rows), Info: qr.Info, SessionStateChanges: qr.SessionStateChanges, diff --git a/go/sqltypes/result.go b/go/sqltypes/result.go index 3c56c8b7eea..4fd8f29d57a 100644 --- a/go/sqltypes/result.go +++ b/go/sqltypes/result.go @@ -93,7 +93,7 @@ func (result *Result) Copy() *Result { out := &Result{ RowsAffected: result.RowsAffected, InsertID: result.InsertID, - InsertIDChanged: result.InsertIDChanged, + InsertIDChanged: result.InsertIDUpdated(), SessionStateChanges: result.SessionStateChanges, StatusFlags: result.StatusFlags, Info: result.Info, @@ -132,7 +132,7 @@ func (result *Result) Metadata() *Result { return &Result{ Fields: result.Fields, InsertID: result.InsertID, - InsertIDChanged: result.InsertIDChanged, + InsertIDChanged: result.InsertIDUpdated(), RowsAffected: result.RowsAffected, Info: result.Info, SessionStateChanges: result.SessionStateChanges, @@ -157,7 +157,7 @@ func (result *Result) Truncate(l int) *Result { out := &Result{ InsertID: result.InsertID, - InsertIDChanged: result.InsertIDChanged, + InsertIDChanged: result.InsertIDUpdated(), RowsAffected: result.RowsAffected, Info: result.Info, SessionStateChanges: result.SessionStateChanges, @@ -331,10 +331,10 @@ func (result *Result) StripMetadata(incl querypb.ExecuteOptions_IncludedFields) // if two results have different fields.We will enhance this function. func (result *Result) AppendResult(src *Result) { result.RowsAffected += src.RowsAffected - if src.InsertID != 0 || src.InsertIDChanged { + if src.InsertIDUpdated() { result.InsertID = src.InsertID } - result.InsertIDChanged = result.InsertIDChanged || src.InsertIDChanged + result.InsertIDChanged = result.InsertIDUpdated() || src.InsertIDUpdated() if result.Fields == nil { result.Fields = src.Fields } @@ -355,3 +355,7 @@ func (result *Result) IsMoreResultsExists() bool { func (result *Result) IsInTransaction() bool { return result.StatusFlags&ServerStatusInTrans == ServerStatusInTrans } + +func (result *Result) InsertIDUpdated() bool { + return result.InsertIDChanged || result.InsertID > 0 +} diff --git a/go/vt/vtgate/engine/insert.go b/go/vt/vtgate/engine/insert.go index 2cae911ccd5..10a4048572f 100644 --- a/go/vt/vtgate/engine/insert.go +++ b/go/vt/vtgate/engine/insert.go @@ -178,6 +178,7 @@ func (ins *Insert) executeInsertQueries( if insertID != 0 { result.InsertID = insertID + result.InsertIDChanged = true } return result, nil } diff --git a/go/vt/vtgate/engine/insert_common.go b/go/vt/vtgate/engine/insert_common.go index 9a4d71eca58..629d848d978 100644 --- a/go/vt/vtgate/engine/insert_common.go +++ b/go/vt/vtgate/engine/insert_common.go @@ -171,6 +171,7 @@ func (ins *InsertCommon) executeUnshardedTableQuery(ctx context.Context, vcursor // values, we don't return an error because this behavior // is required to support migration. if insertID != 0 { + qr.InsertIDChanged = true qr.InsertID = insertID } return qr, nil diff --git a/go/vt/vtgate/executor.go b/go/vt/vtgate/executor.go index 2cc3043dfdf..d8e7ed208cf 100644 --- a/go/vt/vtgate/executor.go +++ b/go/vt/vtgate/executor.go @@ -276,10 +276,10 @@ func (s *streaminResultReceiver) storeResultStats(typ sqlparser.StatementType, q defer s.mu.Unlock() s.rowsAffected += qr.RowsAffected s.rowsReturned += len(qr.Rows) - if qr.InsertID != 0 || qr.InsertIDChanged { + if qr.InsertIDUpdated() { s.insertID = qr.InsertID } - s.insertIDChanged = s.insertIDChanged || qr.InsertIDChanged + s.insertIDChanged = s.insertIDChanged || qr.InsertIDUpdated() s.stmtType = typ return s.callback(qr) } diff --git a/go/vt/vtgate/executor_dml_test.go b/go/vt/vtgate/executor_dml_test.go index 792e197f48d..503b5e5cd8b 100644 --- a/go/vt/vtgate/executor_dml_test.go +++ b/go/vt/vtgate/executor_dml_test.go @@ -1812,8 +1812,9 @@ func TestInsertGeneratorSharded(t *testing.T) { Rows: [][]sqltypes.Value{{ sqltypes.NewInt64(1), }}, - RowsAffected: 1, - InsertID: 1, + RowsAffected: 1, + InsertIDChanged: true, + InsertID: 1, }}) session := &vtgatepb.Session{ TargetString: "@primary", @@ -1840,8 +1841,9 @@ func TestInsertGeneratorSharded(t *testing.T) { }} assertQueries(t, sbclookup, wantQueries) wantResult := &sqltypes.Result{ - InsertID: 1, - RowsAffected: 1, + InsertID: 1, + RowsAffected: 1, + InsertIDChanged: true, } utils.MustMatch(t, wantResult, result) } @@ -1854,8 +1856,9 @@ func TestInsertAutoincSharded(t *testing.T) { Rows: [][]sqltypes.Value{{ sqltypes.NewInt64(1), }}, - RowsAffected: 1, - InsertID: 2, + RowsAffected: 1, + InsertID: 2, + InsertIDChanged: true, } sbc.SetResults([]*sqltypes.Result{wantResult}) session := &vtgatepb.Session{ @@ -1894,8 +1897,9 @@ func TestInsertGeneratorUnsharded(t *testing.T) { }} assertQueries(t, sbclookup, wantQueries) wantResult := &sqltypes.Result{ - InsertID: 1, - RowsAffected: 1, + InsertID: 1, + InsertIDChanged: true, + RowsAffected: 1, } utils.MustMatch(t, wantResult, result) } @@ -1912,8 +1916,9 @@ func TestInsertAutoincUnsharded(t *testing.T) { Rows: [][]sqltypes.Value{{ sqltypes.NewInt64(1), }}, - RowsAffected: 1, - InsertID: 2, + RowsAffected: 1, + InsertID: 2, + InsertIDChanged: true, } sbclookup.SetResults([]*sqltypes.Result{wantResult}) @@ -1965,8 +1970,9 @@ func TestInsertLookupOwnedGenerator(t *testing.T) { Rows: [][]sqltypes.Value{{ sqltypes.NewInt64(4), }}, - RowsAffected: 1, - InsertID: 1, + RowsAffected: 1, + InsertID: 1, + InsertIDChanged: true, }}) session := &vtgatepb.Session{ TargetString: "@primary", @@ -1993,8 +1999,9 @@ func TestInsertLookupOwnedGenerator(t *testing.T) { }} assertQueries(t, sbclookup, wantQueries) wantResult := &sqltypes.Result{ - InsertID: 4, - RowsAffected: 1, + InsertID: 4, + InsertIDChanged: true, + RowsAffected: 1, } utils.MustMatch(t, wantResult, result) } @@ -2226,8 +2233,9 @@ func TestMultiInsertGenerator(t *testing.T) { Rows: [][]sqltypes.Value{{ sqltypes.NewInt64(1), }}, - RowsAffected: 1, - InsertID: 1, + RowsAffected: 1, + InsertIDChanged: true, + InsertID: 1, }}) session := &vtgatepb.Session{ TargetString: "@primary", @@ -2258,8 +2266,9 @@ func TestMultiInsertGenerator(t *testing.T) { }} assertQueries(t, sbclookup, wantQueries) wantResult := &sqltypes.Result{ - InsertID: 1, - RowsAffected: 1, + InsertIDChanged: true, + InsertID: 1, + RowsAffected: 1, } utils.MustMatch(t, wantResult, result) } @@ -2271,8 +2280,9 @@ func TestMultiInsertGeneratorSparse(t *testing.T) { Rows: [][]sqltypes.Value{{ sqltypes.NewInt64(1), }}, - RowsAffected: 1, - InsertID: 1, + RowsAffected: 1, + InsertIDChanged: true, + InsertID: 1, }}) session := &vtgatepb.Session{ TargetString: "@primary", @@ -2307,8 +2317,9 @@ func TestMultiInsertGeneratorSparse(t *testing.T) { }} assertQueries(t, sbclookup, wantQueries) wantResult := &sqltypes.Result{ - InsertID: 1, - RowsAffected: 1, + InsertIDChanged: true, + InsertID: 1, + RowsAffected: 1, } utils.MustMatch(t, wantResult, result) } diff --git a/go/vt/vtgate/executorcontext/vcursor_impl.go b/go/vt/vtgate/executorcontext/vcursor_impl.go index dcd27633c38..4414397f49d 100644 --- a/go/vt/vtgate/executorcontext/vcursor_impl.go +++ b/go/vt/vtgate/executorcontext/vcursor_impl.go @@ -676,7 +676,7 @@ func (vc *VCursorImpl) ExecutePrimitiveStandalone(ctx context.Context, primitive func (vc *VCursorImpl) wrapCallback(callback func(*sqltypes.Result) error, primitive engine.Primitive) func(*sqltypes.Result) error { if vc.interOpStats == nil { return func(r *sqltypes.Result) error { - if r.InsertIDChanged { + if r.InsertIDUpdated() { vc.SafeSession.LastInsertId = r.InsertID } return callback(r) @@ -684,7 +684,7 @@ func (vc *VCursorImpl) wrapCallback(callback func(*sqltypes.Result) error, primi } return func(r *sqltypes.Result) error { - if r.InsertIDChanged { + if r.InsertIDUpdated() { vc.SafeSession.LastInsertId = r.InsertID } vc.logOpTraffic(primitive, r) @@ -772,7 +772,7 @@ func (vc *VCursorImpl) ExecuteMultiShard(ctx context.Context, primitive engine.P qr, errs := vc.executor.ExecuteMultiShard(ctx, primitive, rss, commentedShardQueries(queries, vc.marginComments), vc.SafeSession, canAutocommit, vc.ignoreMaxMemoryRows, vc.observer, fetchLastInsertID) vc.setRollbackOnPartialExecIfRequired(len(errs) != len(rss), rollbackOnError) vc.logShardsQueried(primitive, len(rss)) - if qr.InsertIDChanged { + if qr.InsertIDUpdated() { vc.SafeSession.LastInsertId = qr.InsertID } return qr, errs @@ -814,7 +814,7 @@ func (vc *VCursorImpl) ExecuteStandalone(ctx context.Context, primitive engine.P // execute DMLs through ExecuteStandalone. qr, errs := vc.executor.ExecuteMultiShard(ctx, primitive, rss, bqs, NewAutocommitSession(vc.SafeSession.Session), false /* autocommit */, vc.ignoreMaxMemoryRows, vc.observer, fetchLastInsertID) vc.logShardsQueried(primitive, len(rss)) - if qr.InsertIDChanged { + if qr.InsertIDUpdated() { vc.SafeSession.LastInsertId = qr.InsertID } return qr, vterrors.Aggregate(errs) diff --git a/go/vt/vttablet/tabletserver/query_executor.go b/go/vt/vttablet/tabletserver/query_executor.go index 8a67844437e..21423dd28b9 100644 --- a/go/vt/vttablet/tabletserver/query_executor.go +++ b/go/vt/vttablet/tabletserver/query_executor.go @@ -1251,7 +1251,7 @@ func (qre *QueryExecutor) execStreamSQL(conn *connpool.PooledConn, isTransaction if err = qre.fetchLastInsertID(ctx, conn.Conn, res); err != nil { return err } - if res.InsertIDChanged { + if res.InsertIDUpdated() { return callback(res) } return nil