diff --git a/go/vt/vtgate/executor.go b/go/vt/vtgate/executor.go index d8e7ed208cf..d3d2ba8e8fd 100644 --- a/go/vt/vtgate/executor.go +++ b/go/vt/vtgate/executor.go @@ -262,13 +262,11 @@ func (e *Executor) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConn } type streaminResultReceiver struct { - mu sync.Mutex - stmtType sqlparser.StatementType - rowsAffected uint64 - rowsReturned int - insertID uint64 - insertIDChanged bool - callback func(*sqltypes.Result) error + mu sync.Mutex + stmtType sqlparser.StatementType + rowsAffected uint64 + rowsReturned int + callback func(*sqltypes.Result) error } func (s *streaminResultReceiver) storeResultStats(typ sqlparser.StatementType, qr *sqltypes.Result) error { @@ -276,10 +274,6 @@ func (s *streaminResultReceiver) storeResultStats(typ sqlparser.StatementType, q defer s.mu.Unlock() s.rowsAffected += qr.RowsAffected s.rowsReturned += len(qr.Rows) - if qr.InsertIDUpdated() { - s.insertID = qr.InsertID - } - s.insertIDChanged = s.insertIDChanged || qr.InsertIDUpdated() s.stmtType = typ return s.callback(qr) } diff --git a/go/vt/vtgate/scatter_conn.go b/go/vt/vtgate/scatter_conn.go index eb7f6a91823..70663402287 100644 --- a/go/vt/vtgate/scatter_conn.go +++ b/go/vt/vtgate/scatter_conn.go @@ -394,6 +394,11 @@ func (stc *ScatterConn) StreamExecuteMulti( } return callback(reply) } + + if session.Options != nil { + session.Options.FetchLastInsertId = fetchLastInsertID + } + allErrors := stc.multiGoTransaction( ctx, "StreamExecute", diff --git a/go/vt/vtgate/scatter_conn_test.go b/go/vt/vtgate/scatter_conn_test.go index af44cf0c3c1..2a4dddbd5b7 100644 --- a/go/vt/vtgate/scatter_conn_test.go +++ b/go/vt/vtgate/scatter_conn_test.go @@ -20,27 +20,23 @@ import ( "fmt" "testing" - "vitess.io/vitess/go/vt/log" - econtext "vitess.io/vitess/go/vt/vtgate/executorcontext" - - "vitess.io/vitess/go/mysql/sqlerror" - vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" - + "github.com/aws/smithy-go/ptr" "github.com/stretchr/testify/assert" - - "vitess.io/vitess/go/vt/key" - - "vitess.io/vitess/go/test/utils" - "github.com/stretchr/testify/require" + "vitess.io/vitess/go/mysql/sqlerror" "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/test/utils" "vitess.io/vitess/go/vt/discovery" + "vitess.io/vitess/go/vt/key" + "vitess.io/vitess/go/vt/log" querypb "vitess.io/vitess/go/vt/proto/query" topodatapb "vitess.io/vitess/go/vt/proto/topodata" vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/srvtopo" "vitess.io/vitess/go/vt/vterrors" + econtext "vitess.io/vitess/go/vt/vtgate/executorcontext" ) // This file uses the sandbox_test framework. @@ -110,55 +106,115 @@ func TestExecuteFailOnAutocommit(t *testing.T) { } func TestFetchLastInsertIDResets(t *testing.T) { - ctx := utils.LeakCheckContext(t) - + // This test verifies that the FetchLastInsertID flag is reset after a call to ExecuteMultiShard. ks := "TestFetchLastInsertIDResets" - createSandbox(ks) - hc := discovery.NewFakeHealthCheck(nil) - sc := newTestScatterConn(ctx, hc, newSandboxForCells(ctx, []string{"aa"}), "aa") - sbc0 := hc.AddTestTablet("aa", "0", 1, ks, "0", topodatapb.TabletType_PRIMARY, true, 1, nil) - sbc1 := hc.AddTestTablet("aa", "1", 1, ks, "1", topodatapb.TabletType_PRIMARY, true, 1, nil) - - rss := []*srvtopo.ResolvedShard{{ - Target: &querypb.Target{ - Keyspace: ks, - Shard: "0", - TabletType: topodatapb.TabletType_PRIMARY, + tests := []struct { + name string + initialSessionOpts *querypb.ExecuteOptions + fetchLastInsertID bool + expectSessionNil bool + expectFetchLastID *bool // nil means checkLastOptionNil, otherwise checkLastOption(*bool) + }{ + { + name: "no session options, fetchLastInsertID = false", + initialSessionOpts: nil, + fetchLastInsertID: false, + expectSessionNil: true, + expectFetchLastID: nil, }, - Gateway: sbc0, - }, { - Target: &querypb.Target{ - Keyspace: ks, - Shard: "1", - TabletType: topodatapb.TabletType_PRIMARY, + { + name: "no session options, fetchLastInsertID = true", + initialSessionOpts: nil, + fetchLastInsertID: true, + expectSessionNil: true, + + expectFetchLastID: ptr.Bool(true), }, - Gateway: sbc1, - }} - queries := []*querypb.BoundQuery{{ - // This will fail to go to shard. It will be rejected at vtgate. - Sql: "query1", - BindVariables: map[string]*querypb.BindVariable{ - "bv0": sqltypes.Int64BindVariable(0), + { + name: "session options set, fetchLastInsertID = false", + initialSessionOpts: &querypb.ExecuteOptions{}, + fetchLastInsertID: false, + expectSessionNil: false, + expectFetchLastID: ptr.Bool(false), }, - }, { - // This will go to shard. - Sql: "query2", - BindVariables: map[string]*querypb.BindVariable{ - "bv1": sqltypes.Int64BindVariable(1), + { + name: "session options set, fetchLastInsertID = true", + initialSessionOpts: &querypb.ExecuteOptions{}, + fetchLastInsertID: true, + expectSessionNil: false, + expectFetchLastID: ptr.Bool(true), }, - }} + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := utils.LeakCheckContext(t) + + createSandbox(ks) + hc := discovery.NewFakeHealthCheck(nil) + sc := newTestScatterConn(ctx, hc, newSandboxForCells(ctx, []string{"aa"}), "aa") + sbc0 := hc.AddTestTablet("aa", "0", 1, ks, "0", topodatapb.TabletType_PRIMARY, true, 1, nil) + sbc1 := hc.AddTestTablet("aa", "1", 1, ks, "1", topodatapb.TabletType_PRIMARY, true, 1, nil) + + rss := []*srvtopo.ResolvedShard{{ + Target: &querypb.Target{ + Keyspace: ks, + Shard: "0", + TabletType: topodatapb.TabletType_PRIMARY, + }, + Gateway: sbc0, + }, { + Target: &querypb.Target{ + Keyspace: ks, + Shard: "1", + TabletType: topodatapb.TabletType_PRIMARY, + }, + Gateway: sbc1, + }} + queries := []*querypb.BoundQuery{{ + Sql: "query1", + BindVariables: map[string]*querypb.BindVariable{ + "bv0": sqltypes.Int64BindVariable(0), + }, + }, { + Sql: "query2", + BindVariables: map[string]*querypb.BindVariable{ + "bv1": sqltypes.Int64BindVariable(1), + }, + }} + + session := econtext.NewSafeSession(nil) + session.Options = tt.initialSessionOpts + + checkLastOption := func(expected bool) { + require.Equal(t, 1, len(sbc0.Options)) + options := sbc0.Options[0] + assert.Equal(t, options.FetchLastInsertId, expected) + sbc0.Options = nil + } + checkLastOptionNil := func() { + require.Equal(t, 1, len(sbc0.Options)) + assert.Nil(t, sbc0.Options[0]) + sbc0.Options = nil + } - session := econtext.NewSafeSession(&vtgatepb.Session{Options: &querypb.ExecuteOptions{}}) + _, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, session, true /*autocommit*/, false, nullResultsObserver{}, tt.fetchLastInsertID) + require.NoError(t, vterrors.Aggregate(errs)) - fetchLastInsertID := true - _, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, session, true /*autocommit*/, false, nullResultsObserver{}, fetchLastInsertID) - require.NoError(t, vterrors.Aggregate(errs)) - assert.True(t, session.Options.FetchLastInsertId) + if tt.expectSessionNil { + assert.Nil(t, session.Options) + } else { + assert.NotNil(t, session.Options) + assert.Equal(t, tt.fetchLastInsertID, session.Options.FetchLastInsertId) + } - fetchLastInsertID = false - _, errs = sc.ExecuteMultiShard(ctx, nil, rss, queries, session, true /*autocommit*/, false, nullResultsObserver{}, fetchLastInsertID) - require.NoError(t, vterrors.Aggregate(errs)) - assert.False(t, session.Options.FetchLastInsertId) + if tt.expectFetchLastID == nil { + checkLastOptionNil() + } else { + checkLastOption(*tt.expectFetchLastID) + } + }) + } } func TestExecutePanic(t *testing.T) {