Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix concurrency on stream execute engine primitives #14586

Merged
merged 6 commits into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion go/vt/vtgate/engine/distinct.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package engine
import (
"context"
"fmt"
"sync"

"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/sqltypes"
Expand Down Expand Up @@ -197,13 +198,16 @@ func (d *Distinct) TryExecute(ctx context.Context, vcursor VCursor, bindVars map

// TryStreamExecute implements the Primitive interface
func (d *Distinct) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
pt := newProbeTable(d.CheckCols)
var mu sync.Mutex

pt := newProbeTable(d.CheckCols)
err := vcursor.StreamExecutePrimitive(ctx, d.Source, bindVars, wantfields, func(input *sqltypes.Result) error {
result := &sqltypes.Result{
Fields: input.Fields,
InsertID: input.InsertID,
}
mu.Lock()
defer mu.Unlock()
for _, row := range input.Rows {
exists, err := pt.exists(row)
if err != nil {
Expand Down
53 changes: 53 additions & 0 deletions go/vt/vtgate/engine/distinct_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,59 @@ func TestDistinct(t *testing.T) {
}
}

func TestDistinctStreamAsync(t *testing.T) {
distinct := &Distinct{
Source: &fakePrimitive{
results: sqltypes.MakeTestStreamingResults(sqltypes.MakeTestFields("myid|id|num|name", "varchar|int64|int64|varchar"),
"a|1|1|a",
"a|1|1|a",
"a|1|1|a",
"a|1|1|a",
"---",
"c|1|1|a",
"a|1|1|a",
"z|1|1|a",
"a|1|1|t",
"a|1|1|a",
"a|1|1|a",
"a|1|1|a",
"---",
"c|1|1|a",
"a|1|1|a",
"---",
"c|1|1|a",
"a|1|1|a",
"a|1|1|a",
"c|1|1|a",
"a|1|1|a",
"a|1|1|a",
"---",
"c|1|1|a",
"a|1|1|a",
),
async: true,
},
CheckCols: []CheckCol{
{Col: 0, Type: evalengine.NewType(sqltypes.VarChar, collations.CollationUtf8mb4ID)},
{Col: 1, Type: evalengine.NewType(sqltypes.Int64, collations.CollationBinaryID)},
{Col: 2, Type: evalengine.NewType(sqltypes.Int64, collations.CollationBinaryID)},
{Col: 3, Type: evalengine.NewType(sqltypes.VarChar, collations.CollationUtf8mb4ID)},
},
}

qr := &sqltypes.Result{}
err := distinct.TryStreamExecute(context.Background(), &noopVCursor{}, nil, true, func(result *sqltypes.Result) error {
qr.Rows = append(qr.Rows, result.Rows...)
return nil
})
require.NoError(t, err)
require.NoError(t, sqltypes.RowsEqualsStr(`
[[VARCHAR("c") INT64(1) INT64(1) VARCHAR("a")]
[VARCHAR("a") INT64(1) INT64(1) VARCHAR("a")]
[VARCHAR("z") INT64(1) INT64(1) VARCHAR("a")]
[VARCHAR("a") INT64(1) INT64(1) VARCHAR("t")]]`, qr.Rows))
}

func TestWeightStringFallBack(t *testing.T) {
offsetOne := 1
checkCols := []CheckCol{{
Expand Down
51 changes: 49 additions & 2 deletions go/vt/vtgate/engine/fake_primitive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ import (
"strings"
"testing"

"vitess.io/vitess/go/sqltypes"
"golang.org/x/sync/errgroup"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
)

Expand All @@ -41,6 +42,8 @@ type fakePrimitive struct {
log []string

allResultsInOneCall bool

async bool
}

func (f *fakePrimitive) Inputs() ([]Primitive, []map[string]any) {
Expand Down Expand Up @@ -86,6 +89,13 @@ func (f *fakePrimitive) TryStreamExecute(ctx context.Context, vcursor VCursor, b
return f.sendErr
}

if f.async {
return f.asyncCall(callback)
}
return f.syncCall(wantfields, callback)
}

func (f *fakePrimitive) syncCall(wantfields bool, callback func(*sqltypes.Result) error) error {
readMoreResults := true
for readMoreResults && f.curResult < len(f.results) {
readMoreResults = f.allResultsInOneCall
Expand Down Expand Up @@ -116,9 +126,46 @@ func (f *fakePrimitive) TryStreamExecute(ctx context.Context, vcursor VCursor, b
}
}
}

return nil
}

func (f *fakePrimitive) asyncCall(callback func(*sqltypes.Result) error) error {
var g errgroup.Group
var fields []*querypb.Field
if len(f.results) > 0 {
fields = f.results[0].Fields
}
for _, res := range f.results {
qr := res
g.Go(func() error {
if qr == nil {
return f.sendErr
}
if err := callback(&sqltypes.Result{Fields: fields}); err != nil {
return err
}
result := &sqltypes.Result{}
for i := 0; i < len(qr.Rows); i++ {
result.Rows = append(result.Rows, qr.Rows[i])
// Send only two rows at a time.
if i%2 == 1 {
if err := callback(result); err != nil {
return err
}
result = &sqltypes.Result{}
}
}
if len(result.Rows) != 0 {
if err := callback(result); err != nil {
return err
}
}
return nil
})
}
return g.Wait()
}

func (f *fakePrimitive) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
f.log = append(f.log, fmt.Sprintf("GetFields %v", printBindVars(bindVars)))
return f.TryExecute(ctx, vcursor, bindVars, true /* wantfields */)
Expand Down
6 changes: 6 additions & 0 deletions go/vt/vtgate/engine/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package engine

import (
"context"
"sync"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
Expand Down Expand Up @@ -78,9 +79,14 @@ func (f *Filter) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[s

// TryStreamExecute satisfies the Primitive interface.
func (f *Filter) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
var mu sync.Mutex

env := evalengine.NewExpressionEnv(ctx, bindVars, vcursor)
filter := func(results *sqltypes.Result) error {
var rows [][]sqltypes.Value

mu.Lock()
defer mu.Unlock()
dbussink marked this conversation as resolved.
Show resolved Hide resolved
for _, row := range results.Rows {
env.Row = row
evalResult, err := env.Evaluate(f.Predicate)
Expand Down
60 changes: 60 additions & 0 deletions go/vt/vtgate/engine/filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,63 @@ func TestFilterPass(t *testing.T) {
})
}
}

func TestFilterStreaming(t *testing.T) {
utf8mb4Bin := collationEnv.LookupByName("utf8mb4_bin")
predicate := &sqlparser.ComparisonExpr{
Operator: sqlparser.GreaterThanOp,
Left: sqlparser.NewColName("left"),
Right: sqlparser.NewColName("right"),
}

tcases := []struct {
name string
res []*sqltypes.Result
expRes string
}{{
name: "int32",
res: sqltypes.MakeTestStreamingResults(sqltypes.MakeTestFields("left|right", "int32|int32"), "0|1", "---", "1|0", "2|3"),
expRes: `[[INT32(1) INT32(0)]]`,
}, {
name: "uint16",
res: sqltypes.MakeTestStreamingResults(sqltypes.MakeTestFields("left|right", "uint16|uint16"), "0|1", "1|0", "---", "2|3"),
expRes: `[[UINT16(1) UINT16(0)]]`,
}, {
name: "uint64_int64",
res: sqltypes.MakeTestStreamingResults(sqltypes.MakeTestFields("left|right", "uint64|int64"), "0|1", "---", "1|0", "2|3"),
expRes: `[[UINT64(1) INT64(0)]]`,
}, {
name: "int32_uint32",
res: sqltypes.MakeTestStreamingResults(sqltypes.MakeTestFields("left|right", "int32|uint32"), "0|1", "---", "1|0", "---", "2|3"),
expRes: `[[INT32(1) UINT32(0)]]`,
}, {
name: "uint16_int8",
res: sqltypes.MakeTestStreamingResults(sqltypes.MakeTestFields("left|right", "uint16|int8"), "0|1", "1|0", "2|3", "---"),
expRes: `[[UINT16(1) INT8(0)]]`,
}, {
name: "uint64_int32",
res: sqltypes.MakeTestStreamingResults(sqltypes.MakeTestFields("left|right", "uint64|int32"), "0|1", "1|0", "2|3", "---", "0|1", "1|3", "5|3"),
expRes: `[[UINT64(1) INT32(0)] [UINT64(5) INT32(3)]]`,
}}
for _, tc := range tcases {
t.Run(tc.name, func(t *testing.T) {
pred, err := evalengine.Translate(predicate, &evalengine.Config{
Collation: utf8mb4Bin,
ResolveColumn: evalengine.FieldResolver(tc.res[0].Fields).Column,
})
require.NoError(t, err)

filter := &Filter{
Predicate: pred,
Input: &fakePrimitive{results: tc.res, async: true},
}
qr := &sqltypes.Result{}
err = filter.TryStreamExecute(context.Background(), &noopVCursor{}, nil, false, func(result *sqltypes.Result) error {
qr.Rows = append(qr.Rows, result.Rows...)
return nil
})
require.NoError(t, err)
require.NoError(t, sqltypes.RowsEqualsStr(tc.expRes, qr.Rows))
})
}
}
6 changes: 5 additions & 1 deletion go/vt/vtgate/engine/limit.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"fmt"
"io"
"strconv"
"sync"

"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vtgate/evalengine"
Expand Down Expand Up @@ -97,8 +98,11 @@ func (l *Limit) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars
// the offset in memory from the result of the scatter query with count + offset.
bindVars["__upper_limit"] = sqltypes.Int64BindVariable(int64(count + offset))

var mu sync.Mutex
err = vcursor.StreamExecutePrimitive(ctx, l.Input, bindVars, wantfields, func(qr *sqltypes.Result) error {
if len(qr.Fields) != 0 {
mu.Lock()
defer mu.Unlock()
if wantfields && len(qr.Fields) != 0 {
if err := callback(&sqltypes.Result{Fields: qr.Fields}); err != nil {
return err
}
Expand Down
67 changes: 67 additions & 0 deletions go/vt/vtgate/engine/limit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,73 @@ func TestLimitStreamExecute(t *testing.T) {
}
}

func TestLimitStreamExecuteAsync(t *testing.T) {
bindVars := make(map[string]*querypb.BindVariable)
fields := sqltypes.MakeTestFields(
"col1|col2",
"int64|varchar",
)
inputResults := sqltypes.MakeTestStreamingResults(
fields,
"a|1",
"b|2",
"d|3",
"e|4",
"a|1",
"b|2",
"d|3",
"e|4",
"---",
"c|7",
"x|8",
"y|9",
"c|7",
"x|8",
"y|9",
"c|7",
"x|8",
"y|9",
"---",
"l|4",
"m|5",
"n|6",
"l|4",
"m|5",
"n|6",
"l|4",
"m|5",
"n|6",
)
fp := &fakePrimitive{
results: inputResults,
async: true,
}

const maxCount = 26
for i := 0; i <= maxCount*20; i++ {
expRows := i
l := &Limit{
Count: evalengine.NewLiteralInt(int64(expRows)),
Input: fp,
}
// Test with limit smaller than input.
results := &sqltypes.Result{}

err := l.TryStreamExecute(context.Background(), &noopVCursor{}, bindVars, true, func(qr *sqltypes.Result) error {
if qr != nil {
results.Rows = append(results.Rows, qr.Rows...)
}
return nil
})
require.NoError(t, err)
if expRows > maxCount {
expRows = maxCount
}
require.Len(t, results.Rows, expRows)
}

}

func TestOffsetStreamExecute(t *testing.T) {
bindVars := make(map[string]*querypb.BindVariable)
fields := sqltypes.MakeTestFields(
Expand Down
5 changes: 5 additions & 0 deletions go/vt/vtgate/engine/memory_sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"reflect"
"strconv"
"strings"
"sync"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
Expand Down Expand Up @@ -101,7 +102,11 @@ func (ms *MemorySort) TryStreamExecute(ctx context.Context, vcursor VCursor, bin
Compare: ms.OrderBy,
Limit: count,
}

var mu sync.Mutex
err = vcursor.StreamExecutePrimitive(ctx, ms.Input, bindVars, wantfields, func(qr *sqltypes.Result) error {
mu.Lock()
defer mu.Unlock()
if len(qr.Fields) != 0 {
if err := cb(&sqltypes.Result{Fields: qr.Fields}); err != nil {
return err
Expand Down
Loading
Loading