Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix concurrency on stream execute engine primitives #14586

Merged
merged 6 commits into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion go/vt/vtgate/engine/distinct.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package engine
import (
"context"
"fmt"
"sync"

"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/sqltypes"
Expand Down Expand Up @@ -197,13 +198,16 @@ func (d *Distinct) TryExecute(ctx context.Context, vcursor VCursor, bindVars map

// TryStreamExecute implements the Primitive interface
func (d *Distinct) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
pt := newProbeTable(d.CheckCols)
var mu sync.Mutex

pt := newProbeTable(d.CheckCols)
err := vcursor.StreamExecutePrimitive(ctx, d.Source, bindVars, wantfields, func(input *sqltypes.Result) error {
result := &sqltypes.Result{
Fields: input.Fields,
InsertID: input.InsertID,
}
mu.Lock()
defer mu.Unlock()
for _, row := range input.Rows {
exists, err := pt.exists(row)
if err != nil {
Expand Down
44 changes: 42 additions & 2 deletions go/vt/vtgate/engine/fake_primitive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ import (
"strings"
"testing"

"vitess.io/vitess/go/sqltypes"
"golang.org/x/sync/errgroup"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
)

Expand All @@ -41,6 +42,8 @@ type fakePrimitive struct {
log []string

allResultsInOneCall bool

async bool
}

func (f *fakePrimitive) Inputs() ([]Primitive, []map[string]any) {
Expand Down Expand Up @@ -86,6 +89,13 @@ func (f *fakePrimitive) TryStreamExecute(ctx context.Context, vcursor VCursor, b
return f.sendErr
}

if f.async {
return f.asyncCall(callback)
}
return f.syncCall(wantfields, callback)
}

func (f *fakePrimitive) syncCall(wantfields bool, callback func(*sqltypes.Result) error) error {
readMoreResults := true
for readMoreResults && f.curResult < len(f.results) {
readMoreResults = f.allResultsInOneCall
Expand Down Expand Up @@ -116,9 +126,39 @@ func (f *fakePrimitive) TryStreamExecute(ctx context.Context, vcursor VCursor, b
}
}
}

return nil
}

func (f *fakePrimitive) asyncCall(callback func(*sqltypes.Result) error) error {
var g errgroup.Group
for _, res := range f.results {
qr := res
g.Go(func() error {
if qr == nil {
return f.sendErr
}
result := &sqltypes.Result{}
for i := 0; i < len(qr.Rows); i++ {
result.Rows = append(result.Rows, qr.Rows[i])
// Send only two rows at a time.
if i%2 == 1 {
if err := callback(result); err != nil {
return err
}
result = &sqltypes.Result{}
}
}
if len(result.Rows) != 0 {
if err := callback(result); err != nil {
return err
}
}
return nil
})
}
return g.Wait()
}

func (f *fakePrimitive) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
f.log = append(f.log, fmt.Sprintf("GetFields %v", printBindVars(bindVars)))
return f.TryExecute(ctx, vcursor, bindVars, true /* wantfields */)
Expand Down
6 changes: 6 additions & 0 deletions go/vt/vtgate/engine/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package engine

import (
"context"
"sync"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
Expand Down Expand Up @@ -78,9 +79,14 @@ func (f *Filter) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[s

// TryStreamExecute satisfies the Primitive interface.
func (f *Filter) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
var mu sync.Mutex

env := evalengine.NewExpressionEnv(ctx, bindVars, vcursor)
filter := func(results *sqltypes.Result) error {
var rows [][]sqltypes.Value

mu.Lock()
defer mu.Unlock()
dbussink marked this conversation as resolved.
Show resolved Hide resolved
for _, row := range results.Rows {
env.Row = row
evalResult, err := env.Evaluate(f.Predicate)
Expand Down
4 changes: 4 additions & 0 deletions go/vt/vtgate/engine/limit.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"fmt"
"io"
"strconv"
"sync"

"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vtgate/evalengine"
Expand Down Expand Up @@ -97,6 +98,7 @@ func (l *Limit) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars
// the offset in memory from the result of the scatter query with count + offset.
bindVars["__upper_limit"] = sqltypes.Int64BindVariable(int64(count + offset))

var mu sync.Mutex
err = vcursor.StreamExecutePrimitive(ctx, l.Input, bindVars, wantfields, func(qr *sqltypes.Result) error {
if len(qr.Fields) != 0 {
if err := callback(&sqltypes.Result{Fields: qr.Fields}); err != nil {
Expand All @@ -108,6 +110,8 @@ func (l *Limit) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars
return nil
}

mu.Lock()
defer mu.Unlock()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like need this for the count, but wonder, should we do this with an atomic counter instead? The usage might be complicated though, so the lock might be a lot easier for now and then to optimize later.

// we've still not seen all rows we need to see before we can return anything to the client
if offset > 0 {
if inputSize <= offset {
Expand Down
67 changes: 67 additions & 0 deletions go/vt/vtgate/engine/limit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,73 @@ func TestLimitStreamExecute(t *testing.T) {
}
}

func TestLimitStreamExecuteAsync(t *testing.T) {
bindVars := make(map[string]*querypb.BindVariable)
fields := sqltypes.MakeTestFields(
"col1|col2",
"int64|varchar",
)
inputResults := sqltypes.MakeTestStreamingResults(
fields,
"a|1",
"b|2",
"d|3",
"e|4",
"a|1",
"b|2",
"d|3",
"e|4",
"---",
"c|7",
"x|8",
"y|9",
"c|7",
"x|8",
"y|9",
"c|7",
"x|8",
"y|9",
"---",
"l|4",
"m|5",
"n|6",
"l|4",
"m|5",
"n|6",
"l|4",
"m|5",
"n|6",
)
fp := &fakePrimitive{
results: inputResults,
async: true,
}

const maxCount = 26
for i := 0; i <= maxCount*20; i++ {
expRows := i
l := &Limit{
Count: evalengine.NewLiteralInt(int64(expRows)),
Input: fp,
}
// Test with limit smaller than input.
results := &sqltypes.Result{}

err := l.TryStreamExecute(context.Background(), &noopVCursor{}, bindVars, true, func(qr *sqltypes.Result) error {
if qr != nil {
results.Rows = append(results.Rows, qr.Rows...)
}
return nil
})
require.NoError(t, err)
if expRows > maxCount {
expRows = maxCount
}
require.Len(t, results.Rows, expRows)
}

}

func TestOffsetStreamExecute(t *testing.T) {
bindVars := make(map[string]*querypb.BindVariable)
fields := sqltypes.MakeTestFields(
Expand Down
5 changes: 5 additions & 0 deletions go/vt/vtgate/engine/memory_sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"reflect"
"strconv"
"strings"
"sync"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
Expand Down Expand Up @@ -101,12 +102,16 @@ func (ms *MemorySort) TryStreamExecute(ctx context.Context, vcursor VCursor, bin
Compare: ms.OrderBy,
Limit: count,
}

var mu sync.Mutex
err = vcursor.StreamExecutePrimitive(ctx, ms.Input, bindVars, wantfields, func(qr *sqltypes.Result) error {
if len(qr.Fields) != 0 {
if err := cb(&sqltypes.Result{Fields: qr.Fields}); err != nil {
return err
}
}
mu.Lock()
defer mu.Unlock()
for _, row := range qr.Rows {
sorter.Push(row)
}
Expand Down
3 changes: 3 additions & 0 deletions go/vt/vtgate/engine/projection.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ func (p *Projection) TryStreamExecute(ctx context.Context, vcursor VCursor, bind
env := evalengine.NewExpressionEnv(ctx, bindVars, vcursor)
var once sync.Once
var fields []*querypb.Field
var mu sync.Mutex
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@harshit-gangal Is this one necessary? There's already locking in the implementation here which should afaik handle any concurrency issues.

@vmg Afaik the evalengine.NewExpressionEnv should be concurrent usage safe, or is it not?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it definitely isn't, it contains the VM that is used to execute the expressions. You cannot evaluate two expressions at once in the same VM.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah ok, so we gotta lock here then as well for now unless we'd want to move creating the env inside the callback then.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The next plan is to create env per shard.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That refactor would not be part of the backport and will be made as a separate PR.

return vcursor.StreamExecutePrimitive(ctx, p.Input, bindVars, wantfields, func(qr *sqltypes.Result) error {
var err error
if wantfields {
Expand All @@ -107,6 +108,8 @@ func (p *Projection) TryStreamExecute(ctx context.Context, vcursor VCursor, bind
return err
}
resultRows := make([]sqltypes.Row, 0, len(qr.Rows))
mu.Lock()
defer mu.Unlock()
for _, r := range qr.Rows {
resultRow := make(sqltypes.Row, 0, len(p.Exprs))
env.Row = r
Expand Down
Loading