From fd5f549acb63e592617f732c3d6f21ac6ed9615d Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Wed, 22 Nov 2023 23:43:51 +0530 Subject: [PATCH 1/5] fix: take lock on critical section on engine primitive stream execute apis Signed-off-by: Harshit Gangal --- go/vt/vtgate/engine/distinct.go | 6 +- go/vt/vtgate/engine/fake_primitive_test.go | 44 +++++++++++++- go/vt/vtgate/engine/filter.go | 6 ++ go/vt/vtgate/engine/limit.go | 4 ++ go/vt/vtgate/engine/limit_test.go | 67 ++++++++++++++++++++++ go/vt/vtgate/engine/memory_sort.go | 5 ++ go/vt/vtgate/engine/projection.go | 3 + 7 files changed, 132 insertions(+), 3 deletions(-) diff --git a/go/vt/vtgate/engine/distinct.go b/go/vt/vtgate/engine/distinct.go index 7e55138e27e..2d263464a2e 100644 --- a/go/vt/vtgate/engine/distinct.go +++ b/go/vt/vtgate/engine/distinct.go @@ -19,6 +19,7 @@ package engine import ( "context" "fmt" + "sync" "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" @@ -197,13 +198,16 @@ func (d *Distinct) TryExecute(ctx context.Context, vcursor VCursor, bindVars map // TryStreamExecute implements the Primitive interface func (d *Distinct) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { - pt := newProbeTable(d.CheckCols) + var mu sync.Mutex + pt := newProbeTable(d.CheckCols) err := vcursor.StreamExecutePrimitive(ctx, d.Source, bindVars, wantfields, func(input *sqltypes.Result) error { result := &sqltypes.Result{ Fields: input.Fields, InsertID: input.InsertID, } + mu.Lock() + defer mu.Unlock() for _, row := range input.Rows { exists, err := pt.exists(row) if err != nil { diff --git a/go/vt/vtgate/engine/fake_primitive_test.go b/go/vt/vtgate/engine/fake_primitive_test.go index dcec32f1ffd..b76658c97fc 100644 --- a/go/vt/vtgate/engine/fake_primitive_test.go +++ b/go/vt/vtgate/engine/fake_primitive_test.go @@ -23,8 +23,9 @@ import ( "strings" "testing" - "vitess.io/vitess/go/sqltypes" + "golang.org/x/sync/errgroup" + "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" ) @@ -41,6 +42,8 @@ type fakePrimitive struct { log []string allResultsInOneCall bool + + async bool } func (f *fakePrimitive) Inputs() ([]Primitive, []map[string]any) { @@ -86,6 +89,13 @@ func (f *fakePrimitive) TryStreamExecute(ctx context.Context, vcursor VCursor, b return f.sendErr } + if f.async { + return f.asyncCall(callback) + } + return f.syncCall(wantfields, callback) +} + +func (f *fakePrimitive) syncCall(wantfields bool, callback func(*sqltypes.Result) error) error { readMoreResults := true for readMoreResults && f.curResult < len(f.results) { readMoreResults = f.allResultsInOneCall @@ -116,9 +126,39 @@ func (f *fakePrimitive) TryStreamExecute(ctx context.Context, vcursor VCursor, b } } } - return nil } + +func (f *fakePrimitive) asyncCall(callback func(*sqltypes.Result) error) error { + var g errgroup.Group + for _, res := range f.results { + qr := res + g.Go(func() error { + if qr == nil { + return f.sendErr + } + result := &sqltypes.Result{} + for i := 0; i < len(qr.Rows); i++ { + result.Rows = append(result.Rows, qr.Rows[i]) + // Send only two rows at a time. + if i%2 == 1 { + if err := callback(result); err != nil { + return err + } + result = &sqltypes.Result{} + } + } + if len(result.Rows) != 0 { + if err := callback(result); err != nil { + return err + } + } + return nil + }) + } + return g.Wait() +} + func (f *fakePrimitive) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { f.log = append(f.log, fmt.Sprintf("GetFields %v", printBindVars(bindVars))) return f.TryExecute(ctx, vcursor, bindVars, true /* wantfields */) diff --git a/go/vt/vtgate/engine/filter.go b/go/vt/vtgate/engine/filter.go index c0a54f2b6ac..78e9dde4ee2 100644 --- a/go/vt/vtgate/engine/filter.go +++ b/go/vt/vtgate/engine/filter.go @@ -18,6 +18,7 @@ package engine import ( "context" + "sync" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" @@ -78,9 +79,14 @@ func (f *Filter) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[s // TryStreamExecute satisfies the Primitive interface. func (f *Filter) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { + var mu sync.Mutex + env := evalengine.NewExpressionEnv(ctx, bindVars, vcursor) filter := func(results *sqltypes.Result) error { var rows [][]sqltypes.Value + + mu.Lock() + defer mu.Unlock() for _, row := range results.Rows { env.Row = row evalResult, err := env.Evaluate(f.Predicate) diff --git a/go/vt/vtgate/engine/limit.go b/go/vt/vtgate/engine/limit.go index 4ef809ad1fa..58e2ec5780d 100644 --- a/go/vt/vtgate/engine/limit.go +++ b/go/vt/vtgate/engine/limit.go @@ -21,6 +21,7 @@ import ( "fmt" "io" "strconv" + "sync" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/evalengine" @@ -97,6 +98,7 @@ func (l *Limit) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars // the offset in memory from the result of the scatter query with count + offset. bindVars["__upper_limit"] = sqltypes.Int64BindVariable(int64(count + offset)) + var mu sync.Mutex err = vcursor.StreamExecutePrimitive(ctx, l.Input, bindVars, wantfields, func(qr *sqltypes.Result) error { if len(qr.Fields) != 0 { if err := callback(&sqltypes.Result{Fields: qr.Fields}); err != nil { @@ -108,6 +110,8 @@ func (l *Limit) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars return nil } + mu.Lock() + defer mu.Unlock() // we've still not seen all rows we need to see before we can return anything to the client if offset > 0 { if inputSize <= offset { diff --git a/go/vt/vtgate/engine/limit_test.go b/go/vt/vtgate/engine/limit_test.go index 15bda20ace7..8b91dadecb5 100644 --- a/go/vt/vtgate/engine/limit_test.go +++ b/go/vt/vtgate/engine/limit_test.go @@ -451,6 +451,73 @@ func TestLimitStreamExecute(t *testing.T) { } } +func TestLimitStreamExecuteAsync(t *testing.T) { + bindVars := make(map[string]*querypb.BindVariable) + fields := sqltypes.MakeTestFields( + "col1|col2", + "int64|varchar", + ) + inputResults := sqltypes.MakeTestStreamingResults( + fields, + "a|1", + "b|2", + "d|3", + "e|4", + "a|1", + "b|2", + "d|3", + "e|4", + "---", + "c|7", + "x|8", + "y|9", + "c|7", + "x|8", + "y|9", + "c|7", + "x|8", + "y|9", + "---", + "l|4", + "m|5", + "n|6", + "l|4", + "m|5", + "n|6", + "l|4", + "m|5", + "n|6", + ) + fp := &fakePrimitive{ + results: inputResults, + async: true, + } + + const maxCount = 26 + for i := 0; i <= maxCount*20; i++ { + expRows := i + l := &Limit{ + Count: evalengine.NewLiteralInt(int64(expRows)), + Input: fp, + } + // Test with limit smaller than input. + results := &sqltypes.Result{} + + err := l.TryStreamExecute(context.Background(), &noopVCursor{}, bindVars, true, func(qr *sqltypes.Result) error { + if qr != nil { + results.Rows = append(results.Rows, qr.Rows...) + } + return nil + }) + require.NoError(t, err) + if expRows > maxCount { + expRows = maxCount + } + require.Len(t, results.Rows, expRows) + } + +} + func TestOffsetStreamExecute(t *testing.T) { bindVars := make(map[string]*querypb.BindVariable) fields := sqltypes.MakeTestFields( diff --git a/go/vt/vtgate/engine/memory_sort.go b/go/vt/vtgate/engine/memory_sort.go index b896b303923..c5511575fd2 100644 --- a/go/vt/vtgate/engine/memory_sort.go +++ b/go/vt/vtgate/engine/memory_sort.go @@ -23,6 +23,7 @@ import ( "reflect" "strconv" "strings" + "sync" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" @@ -101,12 +102,16 @@ func (ms *MemorySort) TryStreamExecute(ctx context.Context, vcursor VCursor, bin Compare: ms.OrderBy, Limit: count, } + + var mu sync.Mutex err = vcursor.StreamExecutePrimitive(ctx, ms.Input, bindVars, wantfields, func(qr *sqltypes.Result) error { if len(qr.Fields) != 0 { if err := cb(&sqltypes.Result{Fields: qr.Fields}); err != nil { return err } } + mu.Lock() + defer mu.Unlock() for _, row := range qr.Rows { sorter.Push(row) } diff --git a/go/vt/vtgate/engine/projection.go b/go/vt/vtgate/engine/projection.go index 166dd88f477..12c28b185da 100644 --- a/go/vt/vtgate/engine/projection.go +++ b/go/vt/vtgate/engine/projection.go @@ -88,6 +88,7 @@ func (p *Projection) TryStreamExecute(ctx context.Context, vcursor VCursor, bind env := evalengine.NewExpressionEnv(ctx, bindVars, vcursor) var once sync.Once var fields []*querypb.Field + var mu sync.Mutex return vcursor.StreamExecutePrimitive(ctx, p.Input, bindVars, wantfields, func(qr *sqltypes.Result) error { var err error if wantfields { @@ -107,6 +108,8 @@ func (p *Projection) TryStreamExecute(ctx context.Context, vcursor VCursor, bind return err } resultRows := make([]sqltypes.Row, 0, len(qr.Rows)) + mu.Lock() + defer mu.Unlock() for _, r := range qr.Rows { resultRow := make(sqltypes.Row, 0, len(p.Exprs)) env.Row = r From 932cfdde30a2f81c712637a9aa90d0eee88a7188 Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Thu, 23 Nov 2023 11:31:14 +0530 Subject: [PATCH 2/5] test: added filter and memory sort aynsc call test Signed-off-by: Harshit Gangal --- go/vt/vtgate/engine/filter_test.go | 60 +++++++++++++++++++++++++ go/vt/vtgate/engine/memory_sort_test.go | 54 ++++++++++++++++++++++ 2 files changed, 114 insertions(+) diff --git a/go/vt/vtgate/engine/filter_test.go b/go/vt/vtgate/engine/filter_test.go index 9a8335e4d7e..7d3c2cd0696 100644 --- a/go/vt/vtgate/engine/filter_test.go +++ b/go/vt/vtgate/engine/filter_test.go @@ -83,3 +83,63 @@ func TestFilterPass(t *testing.T) { }) } } + +func TestFilterStreaming(t *testing.T) { + utf8mb4Bin := collationEnv.LookupByName("utf8mb4_bin") + predicate := &sqlparser.ComparisonExpr{ + Operator: sqlparser.GreaterThanOp, + Left: sqlparser.NewColName("left"), + Right: sqlparser.NewColName("right"), + } + + tcases := []struct { + name string + res []*sqltypes.Result + expRes string + }{{ + name: "int32", + res: sqltypes.MakeTestStreamingResults(sqltypes.MakeTestFields("left|right", "int32|int32"), "0|1", "---", "1|0", "2|3"), + expRes: `[[INT32(1) INT32(0)]]`, + }, { + name: "uint16", + res: sqltypes.MakeTestStreamingResults(sqltypes.MakeTestFields("left|right", "uint16|uint16"), "0|1", "1|0", "---", "2|3"), + expRes: `[[UINT16(1) UINT16(0)]]`, + }, { + name: "uint64_int64", + res: sqltypes.MakeTestStreamingResults(sqltypes.MakeTestFields("left|right", "uint64|int64"), "0|1", "---", "1|0", "2|3"), + expRes: `[[UINT64(1) INT64(0)]]`, + }, { + name: "int32_uint32", + res: sqltypes.MakeTestStreamingResults(sqltypes.MakeTestFields("left|right", "int32|uint32"), "0|1", "---", "1|0", "---", "2|3"), + expRes: `[[INT32(1) UINT32(0)]]`, + }, { + name: "uint16_int8", + res: sqltypes.MakeTestStreamingResults(sqltypes.MakeTestFields("left|right", "uint16|int8"), "0|1", "1|0", "2|3", "---"), + expRes: `[[UINT16(1) INT8(0)]]`, + }, { + name: "uint64_int32", + res: sqltypes.MakeTestStreamingResults(sqltypes.MakeTestFields("left|right", "uint64|int32"), "0|1", "1|0", "2|3", "---", "0|1", "1|3", "5|3"), + expRes: `[[UINT64(1) INT32(0)] [UINT64(5) INT32(3)]]`, + }} + for _, tc := range tcases { + t.Run(tc.name, func(t *testing.T) { + pred, err := evalengine.Translate(predicate, &evalengine.Config{ + Collation: utf8mb4Bin, + ResolveColumn: evalengine.FieldResolver(tc.res[0].Fields).Column, + }) + require.NoError(t, err) + + filter := &Filter{ + Predicate: pred, + Input: &fakePrimitive{results: tc.res, async: true}, + } + qr := &sqltypes.Result{} + err = filter.TryStreamExecute(context.Background(), &noopVCursor{}, nil, false, func(result *sqltypes.Result) error { + qr.Rows = append(qr.Rows, result.Rows...) + return nil + }) + require.NoError(t, err) + require.NoError(t, sqltypes.RowsEqualsStr(tc.expRes, qr.Rows)) + }) + } +} diff --git a/go/vt/vtgate/engine/memory_sort_test.go b/go/vt/vtgate/engine/memory_sort_test.go index 4f3d39ca851..3ec7b247029 100644 --- a/go/vt/vtgate/engine/memory_sort_test.go +++ b/go/vt/vtgate/engine/memory_sort_test.go @@ -657,3 +657,57 @@ func TestMemorySortExecuteNoVarChar(t *testing.T) { t.Errorf("StreamExecute err: %v, want %v", err, want) } } + +func TestMemorySortStreamAsync(t *testing.T) { + fields := sqltypes.MakeTestFields( + "c1|c2", + "varbinary|decimal", + ) + fp := &fakePrimitive{ + results: sqltypes.MakeTestStreamingResults( + fields, + "a|1", + "g|2", + "a|1", + "---", + "c|3", + "g|2", + "a|1", + "---", + "c|4", + "c|3", + "g|2", + "a|1", + "---", + "c|4", + "c|3", + "g|2", + "a|1", + "---", + "c|4", + "c|3", + ), + async: true, + } + + ms := &MemorySort{ + OrderBy: []evalengine.OrderByParams{{ + WeightStringCol: -1, + Col: 1, + }}, + Input: fp, + } + + qr := &sqltypes.Result{} + err := ms.TryStreamExecute(context.Background(), &noopVCursor{}, nil, true, func(res *sqltypes.Result) error { + qr.Rows = append(qr.Rows, res.Rows...) + return nil + }) + require.NoError(t, err) + require.NoError(t, sqltypes.RowsEqualsStr( + `[[VARBINARY("a") DECIMAL(1)] [VARBINARY("a") DECIMAL(1)] [VARBINARY("a") DECIMAL(1)] [VARBINARY("a") DECIMAL(1)] [VARBINARY("a") DECIMAL(1)] +[VARBINARY("g") DECIMAL(2)] [VARBINARY("g") DECIMAL(2)] [VARBINARY("g") DECIMAL(2)] [VARBINARY("g") DECIMAL(2)] +[VARBINARY("c") DECIMAL(3)] [VARBINARY("c") DECIMAL(3)] [VARBINARY("c") DECIMAL(3)] [VARBINARY("c") DECIMAL(3)] +[VARBINARY("c") DECIMAL(4)] [VARBINARY("c") DECIMAL(4)] [VARBINARY("c") DECIMAL(4)]]`, + qr.Rows)) +} From 40878b9e370e6524955bd134498ad77cb2c09216 Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Thu, 23 Nov 2023 13:12:24 +0530 Subject: [PATCH 3/5] fix: projection to check for results field before evaluating, added projectio and distinct streaming async test Signed-off-by: Harshit Gangal --- go/vt/vtgate/engine/distinct_test.go | 53 +++++++++++++++++++++++++ go/vt/vtgate/engine/projection.go | 2 +- go/vt/vtgate/engine/projection_test.go | 55 +++++++++++++++++++++++--- 3 files changed, 103 insertions(+), 7 deletions(-) diff --git a/go/vt/vtgate/engine/distinct_test.go b/go/vt/vtgate/engine/distinct_test.go index 4ec3c1c0f4a..a1591403f92 100644 --- a/go/vt/vtgate/engine/distinct_test.go +++ b/go/vt/vtgate/engine/distinct_test.go @@ -130,6 +130,59 @@ func TestDistinct(t *testing.T) { } } +func TestDistinctStreamAsync(t *testing.T) { + distinct := &Distinct{ + Source: &fakePrimitive{ + results: sqltypes.MakeTestStreamingResults(sqltypes.MakeTestFields("myid|id|num|name", "varchar|int64|int64|varchar"), + "a|1|1|a", + "a|1|1|a", + "a|1|1|a", + "a|1|1|a", + "---", + "c|1|1|a", + "a|1|1|a", + "z|1|1|a", + "a|1|1|t", + "a|1|1|a", + "a|1|1|a", + "a|1|1|a", + "---", + "c|1|1|a", + "a|1|1|a", + "---", + "c|1|1|a", + "a|1|1|a", + "a|1|1|a", + "c|1|1|a", + "a|1|1|a", + "a|1|1|a", + "---", + "c|1|1|a", + "a|1|1|a", + ), + async: true, + }, + CheckCols: []CheckCol{ + {Col: 0, Type: evalengine.NewType(sqltypes.VarChar, collations.CollationUtf8mb4ID)}, + {Col: 1, Type: evalengine.NewType(sqltypes.Int64, collations.CollationBinaryID)}, + {Col: 2, Type: evalengine.NewType(sqltypes.Int64, collations.CollationBinaryID)}, + {Col: 3, Type: evalengine.NewType(sqltypes.VarChar, collations.CollationUtf8mb4ID)}, + }, + } + + qr := &sqltypes.Result{} + err := distinct.TryStreamExecute(context.Background(), &noopVCursor{}, nil, true, func(result *sqltypes.Result) error { + qr.Rows = append(qr.Rows, result.Rows...) + return nil + }) + require.NoError(t, err) + require.NoError(t, sqltypes.RowsEqualsStr(` +[[VARCHAR("c") INT64(1) INT64(1) VARCHAR("a")] +[VARCHAR("a") INT64(1) INT64(1) VARCHAR("a")] +[VARCHAR("z") INT64(1) INT64(1) VARCHAR("a")] +[VARCHAR("a") INT64(1) INT64(1) VARCHAR("t")]]`, qr.Rows)) +} + func TestWeightStringFallBack(t *testing.T) { offsetOne := 1 checkCols := []CheckCol{{ diff --git a/go/vt/vtgate/engine/projection.go b/go/vt/vtgate/engine/projection.go index 12c28b185da..db493aee53b 100644 --- a/go/vt/vtgate/engine/projection.go +++ b/go/vt/vtgate/engine/projection.go @@ -91,7 +91,7 @@ func (p *Projection) TryStreamExecute(ctx context.Context, vcursor VCursor, bind var mu sync.Mutex return vcursor.StreamExecutePrimitive(ctx, p.Input, bindVars, wantfields, func(qr *sqltypes.Result) error { var err error - if wantfields { + if wantfields && qr.Fields != nil { once.Do(func() { fields, err = p.evalFields(env, qr.Fields) if err != nil { diff --git a/go/vt/vtgate/engine/projection_test.go b/go/vt/vtgate/engine/projection_test.go index 37d1730e2e1..2d260e901ea 100644 --- a/go/vt/vtgate/engine/projection_test.go +++ b/go/vt/vtgate/engine/projection_test.go @@ -72,6 +72,49 @@ func TestMultiply(t *testing.T) { assert.Equal(t, "[[UINT64(6)] [UINT64(0)] [UINT64(2)]]", fmt.Sprintf("%v", qr.Rows)) } +func TestProjectionStreaming(t *testing.T) { + expr := &sqlparser.BinaryExpr{ + Operator: sqlparser.MultOp, + Left: &sqlparser.Offset{V: 0}, + Right: &sqlparser.Offset{V: 1}, + } + evalExpr, err := evalengine.Translate(expr, nil) + require.NoError(t, err) + fp := &fakePrimitive{ + results: sqltypes.MakeTestStreamingResults( + sqltypes.MakeTestFields("a|b", "uint64|uint64"), + "3|2", + "1|0", + "6|2", + "---", + "3|2", + "---", + "1|0", + "---", + "1|2", + "4|2", + "---", + "5|5", + "4|10", + ), + async: true, + } + proj := &Projection{ + Cols: []string{"apa"}, + Exprs: []evalengine.Expr{evalExpr}, + Input: fp, + } + + qr := &sqltypes.Result{} + err = proj.TryStreamExecute(context.Background(), &noopVCursor{}, nil, true, func(result *sqltypes.Result) error { + qr.Rows = append(qr.Rows, result.Rows...) + return nil + }) + require.NoError(t, err) + require.NoError(t, sqltypes.RowsEqualsStr(`[[UINT64(25)] [UINT64(40)] [UINT64(6)] [UINT64(2)] [UINT64(8)] [UINT64(0)] [UINT64(6)] [UINT64(0)] [UINT64(12)]]`, + qr.Rows)) +} + func TestEmptyInput(t *testing.T) { expr := &sqlparser.BinaryExpr{ Operator: sqlparser.MultOp, @@ -93,18 +136,18 @@ func TestEmptyInput(t *testing.T) { require.NoError(t, err) assert.Equal(t, "[]", fmt.Sprintf("%v", qr.Rows)) - //fp = &fakePrimitive{ + // fp = &fakePrimitive{ // results: []*sqltypes.Result{sqltypes.MakeTestResult( // sqltypes.MakeTestFields("a|b", "uint64|uint64"), // "3|2", // "1|0", // "1|2", // )}, - //} - //proj.Input = fp - //qr, err = wrapStreamExecute(proj, newNoopVCursor(context.Background()), nil, true) - //require.NoError(t, err) - //assert.Equal(t, "[[UINT64(6)] [UINT64(0)] [UINT64(2)]]", fmt.Sprintf("%v", qr.Rows)) + // } + // proj.Input = fp + // qr, err = wrapStreamExecute(proj, newNoopVCursor(context.Background()), nil, true) + // require.NoError(t, err) + // assert.Equal(t, "[[UINT64(6)] [UINT64(0)] [UINT64(2)]]", fmt.Sprintf("%v", qr.Rows)) } func TestHexAndBinaryArgument(t *testing.T) { From 745b0f744d80e87710bb2fe6d4183a7fa6de62ea Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Thu, 23 Nov 2023 10:42:07 +0100 Subject: [PATCH 4/5] Fix the test primitive and locking This ensures that we project fields properly and removes racy behavior from the limit logic as well by moving up the lock. Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/engine/fake_primitive_test.go | 7 +++++++ go/vt/vtgate/engine/limit.go | 6 +++--- go/vt/vtgate/engine/projection.go | 6 +++--- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/go/vt/vtgate/engine/fake_primitive_test.go b/go/vt/vtgate/engine/fake_primitive_test.go index b76658c97fc..e992c2a4623 100644 --- a/go/vt/vtgate/engine/fake_primitive_test.go +++ b/go/vt/vtgate/engine/fake_primitive_test.go @@ -131,12 +131,19 @@ func (f *fakePrimitive) syncCall(wantfields bool, callback func(*sqltypes.Result func (f *fakePrimitive) asyncCall(callback func(*sqltypes.Result) error) error { var g errgroup.Group + var fields []*querypb.Field + if len(f.results) > 0 { + fields = f.results[0].Fields + } for _, res := range f.results { qr := res g.Go(func() error { if qr == nil { return f.sendErr } + if err := callback(&sqltypes.Result{Fields: fields}); err != nil { + return err + } result := &sqltypes.Result{} for i := 0; i < len(qr.Rows); i++ { result.Rows = append(result.Rows, qr.Rows[i]) diff --git a/go/vt/vtgate/engine/limit.go b/go/vt/vtgate/engine/limit.go index 58e2ec5780d..8be186f66bd 100644 --- a/go/vt/vtgate/engine/limit.go +++ b/go/vt/vtgate/engine/limit.go @@ -100,7 +100,9 @@ func (l *Limit) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars var mu sync.Mutex err = vcursor.StreamExecutePrimitive(ctx, l.Input, bindVars, wantfields, func(qr *sqltypes.Result) error { - if len(qr.Fields) != 0 { + mu.Lock() + defer mu.Unlock() + if wantfields && len(qr.Fields) != 0 { if err := callback(&sqltypes.Result{Fields: qr.Fields}); err != nil { return err } @@ -110,8 +112,6 @@ func (l *Limit) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars return nil } - mu.Lock() - defer mu.Unlock() // we've still not seen all rows we need to see before we can return anything to the client if offset > 0 { if inputSize <= offset { diff --git a/go/vt/vtgate/engine/projection.go b/go/vt/vtgate/engine/projection.go index db493aee53b..aaa49061c8e 100644 --- a/go/vt/vtgate/engine/projection.go +++ b/go/vt/vtgate/engine/projection.go @@ -91,7 +91,9 @@ func (p *Projection) TryStreamExecute(ctx context.Context, vcursor VCursor, bind var mu sync.Mutex return vcursor.StreamExecutePrimitive(ctx, p.Input, bindVars, wantfields, func(qr *sqltypes.Result) error { var err error - if wantfields && qr.Fields != nil { + mu.Lock() + defer mu.Unlock() + if wantfields { once.Do(func() { fields, err = p.evalFields(env, qr.Fields) if err != nil { @@ -108,8 +110,6 @@ func (p *Projection) TryStreamExecute(ctx context.Context, vcursor VCursor, bind return err } resultRows := make([]sqltypes.Row, 0, len(qr.Rows)) - mu.Lock() - defer mu.Unlock() for _, r := range qr.Rows { resultRow := make(sqltypes.Row, 0, len(p.Exprs)) env.Row = r From 9e338a384806fff792bedf22a36fb174ef1851f5 Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Thu, 23 Nov 2023 10:52:19 +0100 Subject: [PATCH 5/5] Move the lock up Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/engine/memory_sort.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/go/vt/vtgate/engine/memory_sort.go b/go/vt/vtgate/engine/memory_sort.go index c5511575fd2..8a4cd9188ac 100644 --- a/go/vt/vtgate/engine/memory_sort.go +++ b/go/vt/vtgate/engine/memory_sort.go @@ -105,13 +105,13 @@ func (ms *MemorySort) TryStreamExecute(ctx context.Context, vcursor VCursor, bin var mu sync.Mutex err = vcursor.StreamExecutePrimitive(ctx, ms.Input, bindVars, wantfields, func(qr *sqltypes.Result) error { + mu.Lock() + defer mu.Unlock() if len(qr.Fields) != 0 { if err := cb(&sqltypes.Result{Fields: qr.Fields}); err != nil { return err } } - mu.Lock() - defer mu.Unlock() for _, row := range qr.Rows { sorter.Push(row) }