Skip to content

Commit

Permalink
Add resultsObserver to ScatterConn (#16638)
Browse files Browse the repository at this point in the history
Signed-off-by: Rafer Hazen <[email protected]>
Signed-off-by: Andres Taylor <[email protected]>
Co-authored-by: Andres Taylor <[email protected]>
  • Loading branch information
rafer and systay authored Aug 29, 2024
1 parent e799299 commit 977b9a3
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 84 deletions.
10 changes: 5 additions & 5 deletions go/vt/vtgate/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ func (e *Executor) executeSPInAllSessions(ctx context.Context, safeSession *Safe
})
queries = append(queries, &querypb.BoundQuery{Sql: sql})
}
qr, errs = e.ExecuteMultiShard(ctx, nil, rss, queries, safeSession, false /*autocommit*/, ignoreMaxMemoryRows)
qr, errs = e.ExecuteMultiShard(ctx, nil, rss, queries, safeSession, false /*autocommit*/, ignoreMaxMemoryRows, nullResultsObserver{})
err := vterrors.Aggregate(errs)
if err != nil {
return nil, err
Expand Down Expand Up @@ -1454,13 +1454,13 @@ func parseAndValidateQuery(query string, parser *sqlparser.Parser) (sqlparser.St
}

// ExecuteMultiShard implements the IExecutor interface
func (e *Executor) ExecuteMultiShard(ctx context.Context, primitive engine.Primitive, rss []*srvtopo.ResolvedShard, queries []*querypb.BoundQuery, session *SafeSession, autocommit bool, ignoreMaxMemoryRows bool) (qr *sqltypes.Result, errs []error) {
return e.scatterConn.ExecuteMultiShard(ctx, primitive, rss, queries, session, autocommit, ignoreMaxMemoryRows)
func (e *Executor) ExecuteMultiShard(ctx context.Context, primitive engine.Primitive, rss []*srvtopo.ResolvedShard, queries []*querypb.BoundQuery, session *SafeSession, autocommit bool, ignoreMaxMemoryRows bool, resultsObserver resultsObserver) (qr *sqltypes.Result, errs []error) {
return e.scatterConn.ExecuteMultiShard(ctx, primitive, rss, queries, session, autocommit, ignoreMaxMemoryRows, resultsObserver)
}

// StreamExecuteMulti implements the IExecutor interface
func (e *Executor) StreamExecuteMulti(ctx context.Context, primitive engine.Primitive, query string, rss []*srvtopo.ResolvedShard, vars []map[string]*querypb.BindVariable, session *SafeSession, autocommit bool, callback func(reply *sqltypes.Result) error) []error {
return e.scatterConn.StreamExecuteMulti(ctx, primitive, query, rss, vars, session, autocommit, callback)
func (e *Executor) StreamExecuteMulti(ctx context.Context, primitive engine.Primitive, query string, rss []*srvtopo.ResolvedShard, vars []map[string]*querypb.BindVariable, session *SafeSession, autocommit bool, callback func(reply *sqltypes.Result) error, resultsObserver resultsObserver) []error {
return e.scatterConn.StreamExecuteMulti(ctx, primitive, query, rss, vars, session, autocommit, callback, resultsObserver)
}

// ExecuteLock implements the IExecutor interface
Expand Down
54 changes: 40 additions & 14 deletions go/vt/vtgate/legacy_scatter_conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func TestLegacyExecuteFailOnAutocommit(t *testing.T) {
},
Autocommit: false,
}
_, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, NewSafeSession(session), true /*autocommit*/, false)
_, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, NewSafeSession(session), true /*autocommit*/, false, nullResultsObserver{})
err := vterrors.Aggregate(errs)
require.Error(t, err)
require.Contains(t, err.Error(), "in autocommit mode, transactionID should be zero but was: 123")
Expand All @@ -123,7 +123,7 @@ func TestScatterConnExecuteMulti(t *testing.T) {
}
}

qr, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, NewSafeSession(nil), false /*autocommit*/, false)
qr, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, NewSafeSession(nil), false /*autocommit*/, false, nullResultsObserver{})
return qr, vterrors.Aggregate(errs)
})
}
Expand All @@ -143,7 +143,7 @@ func TestScatterConnStreamExecuteMulti(t *testing.T) {
defer mu.Unlock()
qr.AppendResult(r)
return nil
})
}, nullResultsObserver{})
return qr, vterrors.Aggregate(errors)
})
}
Expand Down Expand Up @@ -310,7 +310,7 @@ func TestMaxMemoryRows(t *testing.T) {
sbc0.SetResults([]*sqltypes.Result{tworows, tworows})
sbc1.SetResults([]*sqltypes.Result{tworows, tworows})

_, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, session, false, test.ignoreMaxMemoryRows)
_, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, session, false, test.ignoreMaxMemoryRows, nullResultsObserver{})
if test.ignoreMaxMemoryRows {
require.NoError(t, err)
} else {
Expand Down Expand Up @@ -342,7 +342,7 @@ func TestLegaceHealthCheckFailsOnReservedConnections(t *testing.T) {
})
}

_, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, session, false, false)
_, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, session, false, false, nullResultsObserver{})
require.Error(t, vterrors.Aggregate(errs))
}

Expand All @@ -365,10 +365,21 @@ func executeOnShardsReturnsErr(t *testing.T, ctx context.Context, res *srvtopo.R
})
}

_, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, session, false, false)
_, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, session, false, false, nullResultsObserver{})
return vterrors.Aggregate(errs)
}

type recordingResultsObserver struct {
mu sync.Mutex
recorded []*sqltypes.Result
}

func (o *recordingResultsObserver) observe(result *sqltypes.Result) {
mu.Lock()
o.recorded = append(o.recorded, result)
mu.Unlock()
}

func TestMultiExecs(t *testing.T) {
ctx := utils.LeakCheckContext(t)
createSandbox("TestMultiExecs")
Expand Down Expand Up @@ -409,9 +420,17 @@ func TestMultiExecs(t *testing.T) {
},
},
}
results := []*sqltypes.Result{
{Info: "r0"},
{Info: "r1"},
}
sbc0.SetResults(results[0:1])
sbc1.SetResults(results[1:2])

observer := recordingResultsObserver{}

session := NewSafeSession(&vtgatepb.Session{})
_, err := sc.ExecuteMultiShard(ctx, nil, rss, queries, session, false, false)
_, err := sc.ExecuteMultiShard(ctx, nil, rss, queries, session, false, false, &observer)
require.NoError(t, vterrors.Aggregate(err))
if len(sbc0.Queries) == 0 || len(sbc1.Queries) == 0 {
t.Fatalf("didn't get expected query")
Expand All @@ -428,8 +447,12 @@ func TestMultiExecs(t *testing.T) {
if !reflect.DeepEqual(sbc1.Queries[0].BindVariables, wantVars1) {
t.Errorf("got %+v, want %+v", sbc0.Queries[0].BindVariables, wantVars1)
}
assert.ElementsMatch(t, results, observer.recorded)

sbc0.Queries = nil
sbc1.Queries = nil
sbc0.SetResults(results[0:1])
sbc1.SetResults(results[1:2])

rss = []*srvtopo.ResolvedShard{
{
Expand All @@ -455,15 +478,18 @@ func TestMultiExecs(t *testing.T) {
"bv1": sqltypes.Int64BindVariable(1),
},
}

observer = recordingResultsObserver{}
_ = sc.StreamExecuteMulti(ctx, nil, "query", rss, bvs, session, false /* autocommit */, func(*sqltypes.Result) error {
return nil
})
}, &observer)
if !reflect.DeepEqual(sbc0.Queries[0].BindVariables, wantVars0) {
t.Errorf("got %+v, want %+v", sbc0.Queries[0].BindVariables, wantVars0)
}
if !reflect.DeepEqual(sbc1.Queries[0].BindVariables, wantVars1) {
t.Errorf("got %+v, want %+v", sbc0.Queries[0].BindVariables, wantVars1)
}
assert.ElementsMatch(t, results, observer.recorded)
}

func TestScatterConnSingleDB(t *testing.T) {
Expand All @@ -487,27 +513,27 @@ func TestScatterConnSingleDB(t *testing.T) {
// TransactionMode_SINGLE in session
session := NewSafeSession(&vtgatepb.Session{InTransaction: true, TransactionMode: vtgatepb.TransactionMode_SINGLE})
queries := []*querypb.BoundQuery{{Sql: "query1"}}
_, errors := sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false)
_, errors := sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false, nullResultsObserver{})
require.Empty(t, errors)
_, errors = sc.ExecuteMultiShard(ctx, nil, rss1, queries, session, false, false)
_, errors = sc.ExecuteMultiShard(ctx, nil, rss1, queries, session, false, false, nullResultsObserver{})
require.Error(t, errors[0])
assert.Contains(t, errors[0].Error(), want)

// TransactionMode_SINGLE in txconn
sc.txConn.mode = vtgatepb.TransactionMode_SINGLE
session = NewSafeSession(&vtgatepb.Session{InTransaction: true})
_, errors = sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false)
_, errors = sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false, nullResultsObserver{})
require.Empty(t, errors)
_, errors = sc.ExecuteMultiShard(ctx, nil, rss1, queries, session, false, false)
_, errors = sc.ExecuteMultiShard(ctx, nil, rss1, queries, session, false, false, nullResultsObserver{})
require.Error(t, errors[0])
assert.Contains(t, errors[0].Error(), want)

// TransactionMode_MULTI in txconn. Should not fail.
sc.txConn.mode = vtgatepb.TransactionMode_MULTI
session = NewSafeSession(&vtgatepb.Session{InTransaction: true})
_, errors = sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false)
_, errors = sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false, nullResultsObserver{})
require.Empty(t, errors)
_, errors = sc.ExecuteMultiShard(ctx, nil, rss1, queries, session, false, false)
_, errors = sc.ExecuteMultiShard(ctx, nil, rss1, queries, session, false, false, nullResultsObserver{})
require.Empty(t, errors)
}

Expand Down
33 changes: 27 additions & 6 deletions go/vt/vtgate/scatter_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,15 @@ type shardActionFunc func(rs *srvtopo.ResolvedShard, i int) error
// the results and errors for the caller.
type shardActionTransactionFunc func(rs *srvtopo.ResolvedShard, i int, shardActionInfo *shardActionInfo) (*shardActionInfo, error)

type (
resultsObserver interface {
observe(*sqltypes.Result)
}
nullResultsObserver struct{}
)

func (nullResultsObserver) observe(*sqltypes.Result) {}

// NewScatterConn creates a new ScatterConn.
func NewScatterConn(statsName string, txConn *TxConn, gw *TabletGateway) *ScatterConn {
// this only works with TabletGateway
Expand Down Expand Up @@ -146,6 +155,7 @@ func (stc *ScatterConn) ExecuteMultiShard(
session *SafeSession,
autocommit bool,
ignoreMaxMemoryRows bool,
resultsObserver resultsObserver,
) (qr *sqltypes.Result, errs []error) {

if len(rss) != len(queries) {
Expand Down Expand Up @@ -260,6 +270,10 @@ func (stc *ScatterConn) ExecuteMultiShard(
mu.Lock()
defer mu.Unlock()

if innerqr != nil {
resultsObserver.observe(innerqr)
}

// Don't append more rows if row count is exceeded.
if ignoreMaxMemoryRows || len(qr.Rows) <= maxMemoryRows {
qr.AppendResult(innerqr)
Expand Down Expand Up @@ -354,11 +368,18 @@ func (stc *ScatterConn) StreamExecuteMulti(
session *SafeSession,
autocommit bool,
callback func(reply *sqltypes.Result) error,
resultsObserver resultsObserver,
) []error {
if session.InLockSession() && session.TriggerLockHeartBeat() {
go stc.runLockQuery(ctx, session)
}

observedCallback := func(reply *sqltypes.Result) error {
if reply != nil {
resultsObserver.observe(reply)
}
return callback(reply)
}
allErrors := stc.multiGoTransaction(
ctx,
"StreamExecute",
Expand Down Expand Up @@ -407,41 +428,41 @@ func (stc *ScatterConn) StreamExecuteMulti(

switch info.actionNeeded {
case nothing:
err = qs.StreamExecute(ctx, rs.Target, query, bindVars[i], transactionID, reservedID, opts, callback)
err = qs.StreamExecute(ctx, rs.Target, query, bindVars[i], transactionID, reservedID, opts, observedCallback)
if err != nil {
retryRequest(func() {
// we seem to have lost our connection. it was a reserved connection, let's try to recreate it
info.actionNeeded = reserve
var state queryservice.ReservedState
state, err = qs.ReserveStreamExecute(ctx, rs.Target, session.SetPreQueries(), query, bindVars[i], 0 /*transactionId*/, opts, callback)
state, err = qs.ReserveStreamExecute(ctx, rs.Target, session.SetPreQueries(), query, bindVars[i], 0 /*transactionId*/, opts, observedCallback)
reservedID = state.ReservedID
alias = state.TabletAlias
})
}
case begin:
var state queryservice.TransactionState
state, err = qs.BeginStreamExecute(ctx, rs.Target, session.SavePoints(), query, bindVars[i], reservedID, opts, callback)
state, err = qs.BeginStreamExecute(ctx, rs.Target, session.SavePoints(), query, bindVars[i], reservedID, opts, observedCallback)
transactionID = state.TransactionID
alias = state.TabletAlias
if err != nil {
retryRequest(func() {
// we seem to have lost our connection. it was a reserved connection, let's try to recreate it
info.actionNeeded = reserveBegin
var state queryservice.ReservedTransactionState
state, err = qs.ReserveBeginStreamExecute(ctx, rs.Target, session.SetPreQueries(), session.SavePoints(), query, bindVars[i], opts, callback)
state, err = qs.ReserveBeginStreamExecute(ctx, rs.Target, session.SetPreQueries(), session.SavePoints(), query, bindVars[i], opts, observedCallback)
transactionID = state.TransactionID
reservedID = state.ReservedID
alias = state.TabletAlias
})
}
case reserve:
var state queryservice.ReservedState
state, err = qs.ReserveStreamExecute(ctx, rs.Target, session.SetPreQueries(), query, bindVars[i], transactionID, opts, callback)
state, err = qs.ReserveStreamExecute(ctx, rs.Target, session.SetPreQueries(), query, bindVars[i], transactionID, opts, observedCallback)
reservedID = state.ReservedID
alias = state.TabletAlias
case reserveBegin:
var state queryservice.ReservedTransactionState
state, err = qs.ReserveBeginStreamExecute(ctx, rs.Target, session.SetPreQueries(), session.SavePoints(), query, bindVars[i], opts, callback)
state, err = qs.ReserveBeginStreamExecute(ctx, rs.Target, session.SetPreQueries(), session.SavePoints(), query, bindVars[i], opts, observedCallback)
transactionID = state.TransactionID
reservedID = state.ReservedID
alias = state.TabletAlias
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/scatter_conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func TestExecuteFailOnAutocommit(t *testing.T) {
},
Autocommit: false,
}
_, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, NewSafeSession(session), true /*autocommit*/, false)
_, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, NewSafeSession(session), true /*autocommit*/, false, nullResultsObserver{})
err := vterrors.Aggregate(errs)
require.Error(t, err)
require.Contains(t, err.Error(), "in autocommit mode, transactionID should be zero but was: 123")
Expand Down Expand Up @@ -183,7 +183,7 @@ func TestExecutePanic(t *testing.T) {
require.Contains(t, logMessage, "(*ScatterConn).multiGoTransaction")
}()

_, _ = sc.ExecuteMultiShard(ctx, nil, rss, queries, NewSafeSession(session), true /*autocommit*/, false)
_, _ = sc.ExecuteMultiShard(ctx, nil, rss, queries, NewSafeSession(session), true /*autocommit*/, false, nullResultsObserver{})

}

Expand Down
Loading

0 comments on commit 977b9a3

Please sign in to comment.