Skip to content

Commit

Permalink
fix: take lock on critical section on engine primitive stream execute…
Browse files Browse the repository at this point in the history
… apis

Signed-off-by: Harshit Gangal <[email protected]>
  • Loading branch information
harshit-gangal committed Nov 22, 2023
1 parent 40f314c commit fd5f549
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 3 deletions.
6 changes: 5 additions & 1 deletion go/vt/vtgate/engine/distinct.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package engine
import (
"context"
"fmt"
"sync"

"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/sqltypes"
Expand Down Expand Up @@ -197,13 +198,16 @@ func (d *Distinct) TryExecute(ctx context.Context, vcursor VCursor, bindVars map

// TryStreamExecute implements the Primitive interface
func (d *Distinct) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
pt := newProbeTable(d.CheckCols)
var mu sync.Mutex

pt := newProbeTable(d.CheckCols)
err := vcursor.StreamExecutePrimitive(ctx, d.Source, bindVars, wantfields, func(input *sqltypes.Result) error {
result := &sqltypes.Result{
Fields: input.Fields,
InsertID: input.InsertID,
}
mu.Lock()
defer mu.Unlock()
for _, row := range input.Rows {
exists, err := pt.exists(row)
if err != nil {
Expand Down
44 changes: 42 additions & 2 deletions go/vt/vtgate/engine/fake_primitive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ import (
"strings"
"testing"

"vitess.io/vitess/go/sqltypes"
"golang.org/x/sync/errgroup"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
)

Expand All @@ -41,6 +42,8 @@ type fakePrimitive struct {
log []string

allResultsInOneCall bool

async bool
}

func (f *fakePrimitive) Inputs() ([]Primitive, []map[string]any) {
Expand Down Expand Up @@ -86,6 +89,13 @@ func (f *fakePrimitive) TryStreamExecute(ctx context.Context, vcursor VCursor, b
return f.sendErr
}

if f.async {
return f.asyncCall(callback)
}
return f.syncCall(wantfields, callback)
}

func (f *fakePrimitive) syncCall(wantfields bool, callback func(*sqltypes.Result) error) error {
readMoreResults := true
for readMoreResults && f.curResult < len(f.results) {
readMoreResults = f.allResultsInOneCall
Expand Down Expand Up @@ -116,9 +126,39 @@ func (f *fakePrimitive) TryStreamExecute(ctx context.Context, vcursor VCursor, b
}
}
}

return nil
}

func (f *fakePrimitive) asyncCall(callback func(*sqltypes.Result) error) error {
var g errgroup.Group
for _, res := range f.results {
qr := res
g.Go(func() error {
if qr == nil {
return f.sendErr
}
result := &sqltypes.Result{}
for i := 0; i < len(qr.Rows); i++ {
result.Rows = append(result.Rows, qr.Rows[i])
// Send only two rows at a time.
if i%2 == 1 {
if err := callback(result); err != nil {
return err
}
result = &sqltypes.Result{}
}
}
if len(result.Rows) != 0 {
if err := callback(result); err != nil {
return err
}
}
return nil
})
}
return g.Wait()
}

func (f *fakePrimitive) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
f.log = append(f.log, fmt.Sprintf("GetFields %v", printBindVars(bindVars)))
return f.TryExecute(ctx, vcursor, bindVars, true /* wantfields */)
Expand Down
6 changes: 6 additions & 0 deletions go/vt/vtgate/engine/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package engine

import (
"context"
"sync"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
Expand Down Expand Up @@ -78,9 +79,14 @@ func (f *Filter) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[s

// TryStreamExecute satisfies the Primitive interface.
func (f *Filter) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
var mu sync.Mutex

env := evalengine.NewExpressionEnv(ctx, bindVars, vcursor)
filter := func(results *sqltypes.Result) error {
var rows [][]sqltypes.Value

mu.Lock()
defer mu.Unlock()
for _, row := range results.Rows {
env.Row = row
evalResult, err := env.Evaluate(f.Predicate)
Expand Down
4 changes: 4 additions & 0 deletions go/vt/vtgate/engine/limit.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"fmt"
"io"
"strconv"
"sync"

"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vtgate/evalengine"
Expand Down Expand Up @@ -97,6 +98,7 @@ func (l *Limit) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars
// the offset in memory from the result of the scatter query with count + offset.
bindVars["__upper_limit"] = sqltypes.Int64BindVariable(int64(count + offset))

var mu sync.Mutex
err = vcursor.StreamExecutePrimitive(ctx, l.Input, bindVars, wantfields, func(qr *sqltypes.Result) error {
if len(qr.Fields) != 0 {
if err := callback(&sqltypes.Result{Fields: qr.Fields}); err != nil {
Expand All @@ -108,6 +110,8 @@ func (l *Limit) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars
return nil
}

mu.Lock()
defer mu.Unlock()
// we've still not seen all rows we need to see before we can return anything to the client
if offset > 0 {
if inputSize <= offset {
Expand Down
67 changes: 67 additions & 0 deletions go/vt/vtgate/engine/limit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,73 @@ func TestLimitStreamExecute(t *testing.T) {
}
}

func TestLimitStreamExecuteAsync(t *testing.T) {
bindVars := make(map[string]*querypb.BindVariable)
fields := sqltypes.MakeTestFields(
"col1|col2",
"int64|varchar",
)
inputResults := sqltypes.MakeTestStreamingResults(
fields,
"a|1",
"b|2",
"d|3",
"e|4",
"a|1",
"b|2",
"d|3",
"e|4",
"---",
"c|7",
"x|8",
"y|9",
"c|7",
"x|8",
"y|9",
"c|7",
"x|8",
"y|9",
"---",
"l|4",
"m|5",
"n|6",
"l|4",
"m|5",
"n|6",
"l|4",
"m|5",
"n|6",
)
fp := &fakePrimitive{
results: inputResults,
async: true,
}

const maxCount = 26
for i := 0; i <= maxCount*20; i++ {
expRows := i
l := &Limit{
Count: evalengine.NewLiteralInt(int64(expRows)),
Input: fp,
}
// Test with limit smaller than input.
results := &sqltypes.Result{}

err := l.TryStreamExecute(context.Background(), &noopVCursor{}, bindVars, true, func(qr *sqltypes.Result) error {
if qr != nil {
results.Rows = append(results.Rows, qr.Rows...)
}
return nil
})
require.NoError(t, err)
if expRows > maxCount {
expRows = maxCount
}
require.Len(t, results.Rows, expRows)
}

}

func TestOffsetStreamExecute(t *testing.T) {
bindVars := make(map[string]*querypb.BindVariable)
fields := sqltypes.MakeTestFields(
Expand Down
5 changes: 5 additions & 0 deletions go/vt/vtgate/engine/memory_sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"reflect"
"strconv"
"strings"
"sync"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
Expand Down Expand Up @@ -101,12 +102,16 @@ func (ms *MemorySort) TryStreamExecute(ctx context.Context, vcursor VCursor, bin
Compare: ms.OrderBy,
Limit: count,
}

var mu sync.Mutex
err = vcursor.StreamExecutePrimitive(ctx, ms.Input, bindVars, wantfields, func(qr *sqltypes.Result) error {
if len(qr.Fields) != 0 {
if err := cb(&sqltypes.Result{Fields: qr.Fields}); err != nil {
return err
}
}
mu.Lock()
defer mu.Unlock()
for _, row := range qr.Rows {
sorter.Push(row)
}
Expand Down
3 changes: 3 additions & 0 deletions go/vt/vtgate/engine/projection.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ func (p *Projection) TryStreamExecute(ctx context.Context, vcursor VCursor, bind
env := evalengine.NewExpressionEnv(ctx, bindVars, vcursor)
var once sync.Once
var fields []*querypb.Field
var mu sync.Mutex
return vcursor.StreamExecutePrimitive(ctx, p.Input, bindVars, wantfields, func(qr *sqltypes.Result) error {
var err error
if wantfields {
Expand All @@ -107,6 +108,8 @@ func (p *Projection) TryStreamExecute(ctx context.Context, vcursor VCursor, bind
return err
}
resultRows := make([]sqltypes.Row, 0, len(qr.Rows))
mu.Lock()
defer mu.Unlock()
for _, r := range qr.Rows {
resultRow := make(sqltypes.Row, 0, len(p.Exprs))
env.Row = r
Expand Down

0 comments on commit fd5f549

Please sign in to comment.