Skip to content

Commit

Permalink
fix concurrency on stream execute engine primitives (#14586)
Browse files Browse the repository at this point in the history
Signed-off-by: Harshit Gangal <[email protected]>
Signed-off-by: Dirkjan Bussink <[email protected]>
Co-authored-by: Dirkjan Bussink <[email protected]>
  • Loading branch information
harshit-gangal and dbussink committed Nov 23, 2023
1 parent 53dfd30 commit e86b889
Show file tree
Hide file tree
Showing 12 changed files with 527 additions and 14 deletions.
174 changes: 174 additions & 0 deletions go/sqltypes/parse_rows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
/*
Copyright 2023 The Vitess Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package sqltypes

import (
"fmt"
"io"
"reflect"
"strconv"
"strings"
"text/scanner"

querypb "vitess.io/vitess/go/vt/proto/query"
)

// ParseRows parses the textual form of a []Row, such as the output generated by
// fmt.Sprintf("%v", rows), and reifies the original []sqltypes.Row.
//
// The accepted input is a bracketed list of bracketed rows, for example:
//
//	[[VARCHAR("a") INT64(1) NULL] [VARCHAR("b") INT64(-2) NULL]]
//
// Every value is either the bare identifier NULL or TYPE(literal), where TYPE
// must be a key of querypb.Type_value and literal is a Go-quoted string, an
// integer or a float. A numeric literal may be preceded by a single '-'.
//
// Parsing stops at the ']' that closes the outer list; anything after it is
// ignored. If the input ends before that point, io.ErrUnexpectedEOF is
// returned. Any other malformed input yields an error that includes the
// scanner position of the offending token.
//
// NOTE: This is not meant for production use!
func ParseRows(input string) ([]Row, error) {
	type state int
	const (
		// stInvalid is the zero value: a transition that leaves `next`
		// untouched is reported as an unexpected token.
		stInvalid state = iota
		stInit          // expecting the '[' that opens the list of rows
		stBeginRow      // expecting '[' (new row) or ']' (end of list)
		stInRow         // expecting a type name, NULL, or ']' (end of row)
		stInValue0      // expecting '(' after a type name
		stInValue1      // expecting the literal (optionally preceded by '-')
		stInValue2      // expecting ')' after the literal
	)

	var (
		scan     scanner.Scanner
		result   []Row
		row      Row
		vtype    int32
		negative bool // a '-' was seen and applies to the next numeric literal
		st       = stInit
	)

	scan.Init(strings.NewReader(input))

	for tok := scan.Scan(); tok != scanner.EOF; tok = scan.Scan() {
		var next state

		switch st {
		case stInit:
			if tok == '[' {
				next = stBeginRow
			}
		case stBeginRow:
			switch tok {
			case '[':
				next = stInRow
			case ']':
				return result, nil
			}
		case stInRow:
			switch tok {
			case ']':
				result = append(result, row)
				row = nil
				next = stBeginRow
			case scanner.Ident:
				ident := scan.TokenText()

				if ident == "NULL" {
					row = append(row, NULL)
					// Stay in stInRow: NULL carries no parenthesised literal.
					continue
				}

				var ok bool
				vtype, ok = querypb.Type_value[ident]
				if !ok {
					return nil, fmt.Errorf("unknown SQL type %q at %s", ident, scan.Position)
				}
				next = stInValue0
			}
		case stInValue0:
			if tok == '(' {
				next = stInValue1
			}
		case stInValue1:
			switch tok {
			case '-':
				// text/scanner never folds a sign into a number token, so the
				// minus arrives on its own. Accept exactly one and remember it
				// for the numeric literal that must follow.
				if !negative {
					negative = true
					next = stInValue1
				}
			case scanner.String:
				if negative {
					// A sign in front of a quoted string is malformed; leaving
					// `next` invalid reports it as an unexpected token.
					break
				}
				literal, err := strconv.Unquote(scan.TokenText())
				if err != nil {
					return nil, fmt.Errorf("failed to parse literal string at %s: %w", scan.Position, err)
				}
				row = append(row, MakeTrusted(Type(vtype), []byte(literal)))
				next = stInValue2
			case scanner.Int, scanner.Float:
				literal := scan.TokenText()
				if negative {
					literal = "-" + literal
					negative = false
				}
				row = append(row, MakeTrusted(Type(vtype), []byte(literal)))
				next = stInValue2
			}
		case stInValue2:
			if tok == ')' {
				next = stInRow
			}
		}
		if next == stInvalid {
			return nil, fmt.Errorf("unexpected token '%s' at %s", scan.TokenText(), scan.Position)
		}
		st = next
	}
	return nil, io.ErrUnexpectedEOF
}

// RowMismatchError is the error returned by RowsEquals when two row sets
// differ. It carries the specific reason together with both row sets so the
// rendered message shows everything needed to diagnose the mismatch.
type RowMismatchError struct {
	// err describes why the comparison failed.
	err error
	// want is the expected set of rows.
	want []Row
	// got is the set of rows that was actually produced.
	got []Row
}

// Error implements the error interface. The message contains the underlying
// reason followed by the expected and actual rows on separate lines.
func (e *RowMismatchError) Error() string {
	return fmt.Sprintf("results differ: %v\n\twant: %v\n\tgot: %v", e.err, e.want, e.got)
}

// Unwrap returns the underlying reason for the mismatch so that errors.Is and
// errors.As can inspect it through the wrapper.
func (e *RowMismatchError) Unwrap() error {
	return e.err
}

// RowsEquals checks that want and got contain the same rows, ignoring order.
// Rows are compared with reflect.DeepEqual and duplicates are significant:
// every row in want must pair up with its own distinct row in got.
//
// It returns nil when the two sets match. Otherwise it returns a
// *RowMismatchError describing either the difference in row count or the first
// row of want that has no unclaimed counterpart in got.
func RowsEquals(want, got []Row) error {
	if len(want) != len(got) {
		return &RowMismatchError{
			err:  fmt.Errorf("expected %d rows in result, got %d", len(want), len(got)),
			want: want,
			got:  got,
		}
	}

	// matched[i] records that got[i] has already been paired with a row of
	// want, so duplicate rows each need their own partner.
	matched := make([]bool, len(got))

nextWant:
	for _, aa := range want {
		for i, bb := range got {
			if !matched[i] && reflect.DeepEqual(aa, bb) {
				matched[i] = true
				continue nextWant
			}
		}
		return &RowMismatchError{
			err:  fmt.Errorf("row %v is missing from result", aa),
			want: want,
			got:  got,
		}
	}

	// Both slices have the same length and every row of want claimed a
	// distinct row of got, so every row of got is accounted for as well; no
	// further check is needed.
	return nil
}

// RowsEqualsStr parses wantStr with ParseRows and compares the resulting rows
// against got using RowsEquals. A wantStr that cannot be parsed is reported as
// a malformed assertion wrapping the parse error.
func RowsEqualsStr(wantStr string, got []Row) error {
	expected, parseErr := ParseRows(wantStr)
	if parseErr == nil {
		return RowsEquals(expected, got)
	}
	return fmt.Errorf("malformed row assertion: %w", parseErr)
}
6 changes: 5 additions & 1 deletion go/vt/vtgate/engine/distinct.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package engine
import (
"context"
"fmt"
"sync"

"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/sqltypes"
Expand Down Expand Up @@ -197,13 +198,16 @@ func (d *Distinct) TryExecute(ctx context.Context, vcursor VCursor, bindVars map

// TryStreamExecute implements the Primitive interface
func (d *Distinct) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
pt := newProbeTable(d.CheckCols)
var mu sync.Mutex

pt := newProbeTable(d.CheckCols)
err := vcursor.StreamExecutePrimitive(ctx, d.Source, bindVars, wantfields, func(input *sqltypes.Result) error {
result := &sqltypes.Result{
Fields: input.Fields,
InsertID: input.InsertID,
}
mu.Lock()
defer mu.Unlock()
for _, row := range input.Rows {
exists, err := pt.exists(row)
if err != nil {
Expand Down
53 changes: 53 additions & 0 deletions go/vt/vtgate/engine/distinct_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,59 @@ func TestDistinct(t *testing.T) {
}
}

func TestDistinctStreamAsync(t *testing.T) {
distinct := &Distinct{
Source: &fakePrimitive{
results: sqltypes.MakeTestStreamingResults(sqltypes.MakeTestFields("myid|id|num|name", "varchar|int64|int64|varchar"),
"a|1|1|a",
"a|1|1|a",
"a|1|1|a",
"a|1|1|a",
"---",
"c|1|1|a",
"a|1|1|a",
"z|1|1|a",
"a|1|1|t",
"a|1|1|a",
"a|1|1|a",
"a|1|1|a",
"---",
"c|1|1|a",
"a|1|1|a",
"---",
"c|1|1|a",
"a|1|1|a",
"a|1|1|a",
"c|1|1|a",
"a|1|1|a",
"a|1|1|a",
"---",
"c|1|1|a",
"a|1|1|a",
),
async: true,
},
CheckCols: []CheckCol{
{Col: 0, Collation: collations.CollationUtf8mb4ID},
{Col: 1, Collation: collations.CollationBinaryID},
{Col: 2, Collation: collations.CollationBinaryID},
{Col: 3, Collation: collations.CollationUtf8mb4ID},
},
}

qr := &sqltypes.Result{}
err := distinct.TryStreamExecute(context.Background(), &noopVCursor{}, nil, true, func(result *sqltypes.Result) error {
qr.Rows = append(qr.Rows, result.Rows...)
return nil
})
require.NoError(t, err)
require.NoError(t, sqltypes.RowsEqualsStr(`
[[VARCHAR("c") INT64(1) INT64(1) VARCHAR("a")]
[VARCHAR("a") INT64(1) INT64(1) VARCHAR("a")]
[VARCHAR("z") INT64(1) INT64(1) VARCHAR("a")]
[VARCHAR("a") INT64(1) INT64(1) VARCHAR("t")]]`, qr.Rows))
}

func TestWeightStringFallBack(t *testing.T) {
offsetOne := 1
checkCols := []CheckCol{{
Expand Down
51 changes: 49 additions & 2 deletions go/vt/vtgate/engine/fake_primitive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ import (
"strings"
"testing"

"vitess.io/vitess/go/sqltypes"
"golang.org/x/sync/errgroup"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
)

Expand All @@ -41,6 +42,8 @@ type fakePrimitive struct {
log []string

allResultsInOneCall bool

async bool
}

func (f *fakePrimitive) Inputs() []Primitive {
Expand Down Expand Up @@ -86,6 +89,13 @@ func (f *fakePrimitive) TryStreamExecute(ctx context.Context, vcursor VCursor, b
return f.sendErr
}

if f.async {
return f.asyncCall(callback)
}
return f.syncCall(wantfields, callback)
}

func (f *fakePrimitive) syncCall(wantfields bool, callback func(*sqltypes.Result) error) error {
readMoreResults := true
for readMoreResults && f.curResult < len(f.results) {
readMoreResults = f.allResultsInOneCall
Expand Down Expand Up @@ -116,9 +126,46 @@ func (f *fakePrimitive) TryStreamExecute(ctx context.Context, vcursor VCursor, b
}
}
}

return nil
}

// asyncCall streams every configured result to callback from its own
// goroutine and waits for all of them to finish, returning the first error
// reported by the group.
//
// Each goroutine first sends a rows-less result carrying the fields of the
// first configured result, then sends the rows of its own result in batches of
// at most two. A nil entry in f.results makes its goroutine return f.sendErr.
func (f *fakePrimitive) asyncCall(callback func(*sqltypes.Result) error) error {
	// Every goroutine announces the same fields, taken from the first result.
	var fields []*querypb.Field
	if len(f.results) > 0 {
		fields = f.results[0].Fields
	}

	var group errgroup.Group
	for _, res := range f.results {
		current := res // per-iteration copy captured by the closure below
		group.Go(func() error {
			if current == nil {
				return f.sendErr
			}

			// Fields-only message goes out before any rows.
			if err := callback(&sqltypes.Result{Fields: fields}); err != nil {
				return err
			}

			batch := &sqltypes.Result{}
			for idx, row := range current.Rows {
				batch.Rows = append(batch.Rows, row)
				if idx%2 == 0 {
					// Batch holds a single row so far; wait for its pair.
					continue
				}
				// Send only two rows at a time.
				if err := callback(batch); err != nil {
					return err
				}
				batch = &sqltypes.Result{}
			}

			// Flush the odd row left over, if any.
			if len(batch.Rows) == 0 {
				return nil
			}
			return callback(batch)
		})
	}
	return group.Wait()
}

func (f *fakePrimitive) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
f.log = append(f.log, fmt.Sprintf("GetFields %v", printBindVars(bindVars)))
return f.TryExecute(ctx, vcursor, bindVars, true /* wantfields */)
Expand Down
7 changes: 6 additions & 1 deletion go/vt/vtgate/engine/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package engine

import (
"context"
"sync"

"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vtgate/evalengine"
Expand Down Expand Up @@ -79,10 +80,14 @@ func (f *Filter) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[s

// TryStreamExecute satisfies the Primitive interface.
func (f *Filter) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
var mu sync.Mutex

env := evalengine.EnvWithBindVars(bindVars, vcursor.ConnCollation())
filter := func(results *sqltypes.Result) error {
var rows [][]sqltypes.Value
env.Fields = results.Fields

mu.Lock()
defer mu.Unlock()
for _, row := range results.Rows {
env.Row = row
evalResult, err := env.Evaluate(f.Predicate)
Expand Down
Loading

0 comments on commit e86b889

Please sign in to comment.