From fd5f549acb63e592617f732c3d6f21ac6ed9615d Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Wed, 22 Nov 2023 23:43:51 +0530 Subject: [PATCH] fix: take lock on critical section on engine primitive stream execute apis Signed-off-by: Harshit Gangal --- go/vt/vtgate/engine/distinct.go | 6 +- go/vt/vtgate/engine/fake_primitive_test.go | 44 +++++++++++++- go/vt/vtgate/engine/filter.go | 6 ++ go/vt/vtgate/engine/limit.go | 4 ++ go/vt/vtgate/engine/limit_test.go | 67 ++++++++++++++++++++++ go/vt/vtgate/engine/memory_sort.go | 5 ++ go/vt/vtgate/engine/projection.go | 3 + 7 files changed, 132 insertions(+), 3 deletions(-) diff --git a/go/vt/vtgate/engine/distinct.go b/go/vt/vtgate/engine/distinct.go index 7e55138e27e..2d263464a2e 100644 --- a/go/vt/vtgate/engine/distinct.go +++ b/go/vt/vtgate/engine/distinct.go @@ -19,6 +19,7 @@ package engine import ( "context" "fmt" + "sync" "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" @@ -197,13 +198,16 @@ func (d *Distinct) TryExecute(ctx context.Context, vcursor VCursor, bindVars map // TryStreamExecute implements the Primitive interface func (d *Distinct) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { - pt := newProbeTable(d.CheckCols) + var mu sync.Mutex + pt := newProbeTable(d.CheckCols) err := vcursor.StreamExecutePrimitive(ctx, d.Source, bindVars, wantfields, func(input *sqltypes.Result) error { result := &sqltypes.Result{ Fields: input.Fields, InsertID: input.InsertID, } + mu.Lock() + defer mu.Unlock() for _, row := range input.Rows { exists, err := pt.exists(row) if err != nil { diff --git a/go/vt/vtgate/engine/fake_primitive_test.go b/go/vt/vtgate/engine/fake_primitive_test.go index dcec32f1ffd..b76658c97fc 100644 --- a/go/vt/vtgate/engine/fake_primitive_test.go +++ b/go/vt/vtgate/engine/fake_primitive_test.go @@ -23,8 +23,9 @@ import ( "strings" "testing" - "vitess.io/vitess/go/sqltypes" + "golang.org/x/sync/errgroup" + "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" ) @@ -41,6 +42,8 @@ type fakePrimitive struct { log []string allResultsInOneCall bool + + async bool } func (f *fakePrimitive) Inputs() ([]Primitive, []map[string]any) { @@ -86,6 +89,13 @@ func (f *fakePrimitive) TryStreamExecute(ctx context.Context, vcursor VCursor, b return f.sendErr } + if f.async { + return f.asyncCall(callback) + } + return f.syncCall(wantfields, callback) +} + +func (f *fakePrimitive) syncCall(wantfields bool, callback func(*sqltypes.Result) error) error { readMoreResults := true for readMoreResults && f.curResult < len(f.results) { readMoreResults = f.allResultsInOneCall @@ -116,9 +126,39 @@ func (f *fakePrimitive) TryStreamExecute(ctx context.Context, vcursor VCursor, b } } } - return nil } + +func (f *fakePrimitive) asyncCall(callback func(*sqltypes.Result) error) error { + var g errgroup.Group + for _, res := range f.results { + qr := res + g.Go(func() error { + if qr == nil { + return f.sendErr + } + result := &sqltypes.Result{} + for i := 0; i < len(qr.Rows); i++ { + result.Rows = append(result.Rows, qr.Rows[i]) + // Send only two rows at a time. + if i%2 == 1 { + if err := callback(result); err != nil { + return err + } + result = &sqltypes.Result{} + } + } + if len(result.Rows) != 0 { + if err := callback(result); err != nil { + return err + } + } + return nil + }) + } + return g.Wait() +} + func (f *fakePrimitive) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { f.log = append(f.log, fmt.Sprintf("GetFields %v", printBindVars(bindVars))) return f.TryExecute(ctx, vcursor, bindVars, true /* wantfields */) diff --git a/go/vt/vtgate/engine/filter.go b/go/vt/vtgate/engine/filter.go index c0a54f2b6ac..78e9dde4ee2 100644 --- a/go/vt/vtgate/engine/filter.go +++ b/go/vt/vtgate/engine/filter.go @@ -18,6 +18,7 @@ package engine import ( "context" + "sync" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" @@ -78,9 +79,14 @@ func (f *Filter) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[s // TryStreamExecute satisfies the Primitive interface. func (f *Filter) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { + var mu sync.Mutex + env := evalengine.NewExpressionEnv(ctx, bindVars, vcursor) filter := func(results *sqltypes.Result) error { var rows [][]sqltypes.Value + + mu.Lock() + defer mu.Unlock() for _, row := range results.Rows { env.Row = row evalResult, err := env.Evaluate(f.Predicate) diff --git a/go/vt/vtgate/engine/limit.go b/go/vt/vtgate/engine/limit.go index 4ef809ad1fa..58e2ec5780d 100644 --- a/go/vt/vtgate/engine/limit.go +++ b/go/vt/vtgate/engine/limit.go @@ -21,6 +21,7 @@ import ( "fmt" "io" "strconv" + "sync" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/evalengine" @@ -97,6 +98,7 @@ func (l *Limit) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars // the offset in memory from the result of the scatter query with count + offset. bindVars["__upper_limit"] = sqltypes.Int64BindVariable(int64(count + offset)) + var mu sync.Mutex err = vcursor.StreamExecutePrimitive(ctx, l.Input, bindVars, wantfields, func(qr *sqltypes.Result) error { if len(qr.Fields) != 0 { if err := callback(&sqltypes.Result{Fields: qr.Fields}); err != nil { @@ -108,6 +110,8 @@ func (l *Limit) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars return nil } + mu.Lock() + defer mu.Unlock() // we've still not seen all rows we need to see before we can return anything to the client if offset > 0 { if inputSize <= offset { diff --git a/go/vt/vtgate/engine/limit_test.go b/go/vt/vtgate/engine/limit_test.go index 15bda20ace7..8b91dadecb5 100644 --- a/go/vt/vtgate/engine/limit_test.go +++ b/go/vt/vtgate/engine/limit_test.go @@ -451,6 +451,73 @@ func TestLimitStreamExecute(t *testing.T) { } } +func TestLimitStreamExecuteAsync(t *testing.T) { + bindVars := make(map[string]*querypb.BindVariable) + fields := sqltypes.MakeTestFields( + "col1|col2", + "int64|varchar", + ) + inputResults := sqltypes.MakeTestStreamingResults( + fields, + "a|1", + "b|2", + "d|3", + "e|4", + "a|1", + "b|2", + "d|3", + "e|4", + "---", + "c|7", + "x|8", + "y|9", + "c|7", + "x|8", + "y|9", + "c|7", + "x|8", + "y|9", + "---", + "l|4", + "m|5", + "n|6", + "l|4", + "m|5", + "n|6", + "l|4", + "m|5", + "n|6", + ) + fp := &fakePrimitive{ + results: inputResults, + async: true, + } + + const maxCount = 26 + for i := 0; i <= maxCount*20; i++ { + expRows := i + l := &Limit{ + Count: evalengine.NewLiteralInt(int64(expRows)), + Input: fp, + } + // Test with limit smaller than input. + results := &sqltypes.Result{} + + err := l.TryStreamExecute(context.Background(), &noopVCursor{}, bindVars, true, func(qr *sqltypes.Result) error { + if qr != nil { + results.Rows = append(results.Rows, qr.Rows...) + } + return nil + }) + require.NoError(t, err) + if expRows > maxCount { + expRows = maxCount + } + require.Len(t, results.Rows, expRows) + } + +} + func TestOffsetStreamExecute(t *testing.T) { bindVars := make(map[string]*querypb.BindVariable) fields := sqltypes.MakeTestFields( diff --git a/go/vt/vtgate/engine/memory_sort.go b/go/vt/vtgate/engine/memory_sort.go index b896b303923..c5511575fd2 100644 --- a/go/vt/vtgate/engine/memory_sort.go +++ b/go/vt/vtgate/engine/memory_sort.go @@ -23,6 +23,7 @@ import ( "reflect" "strconv" "strings" + "sync" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" @@ -101,12 +102,16 @@ func (ms *MemorySort) TryStreamExecute(ctx context.Context, vcursor VCursor, bin Compare: ms.OrderBy, Limit: count, } + + var mu sync.Mutex err = vcursor.StreamExecutePrimitive(ctx, ms.Input, bindVars, wantfields, func(qr *sqltypes.Result) error { if len(qr.Fields) != 0 { if err := cb(&sqltypes.Result{Fields: qr.Fields}); err != nil { return err } } + mu.Lock() + defer mu.Unlock() for _, row := range qr.Rows { sorter.Push(row) } diff --git a/go/vt/vtgate/engine/projection.go b/go/vt/vtgate/engine/projection.go index 166dd88f477..12c28b185da 100644 --- a/go/vt/vtgate/engine/projection.go +++ b/go/vt/vtgate/engine/projection.go @@ -88,6 +88,7 @@ func (p *Projection) TryStreamExecute(ctx context.Context, vcursor VCursor, bind env := evalengine.NewExpressionEnv(ctx, bindVars, vcursor) var once sync.Once var fields []*querypb.Field + var mu sync.Mutex return vcursor.StreamExecutePrimitive(ctx, p.Input, bindVars, wantfields, func(qr *sqltypes.Result) error { var err error if wantfields { @@ -107,6 +108,8 @@ func (p *Projection) TryStreamExecute(ctx context.Context, vcursor VCursor, bind return err } resultRows := make([]sqltypes.Row, 0, len(qr.Rows)) + mu.Lock() + defer mu.Unlock() for _, r := range qr.Rows { resultRow := make(sqltypes.Row, 0, len(p.Exprs)) env.Row = r