Skip to content

Commit

Permalink
fix data race in join engine primitive olap streaming mode execution (#…
Browse files Browse the repository at this point in the history
…14012)

Signed-off-by: Harshit Gangal <[email protected]>
  • Loading branch information
harshit-gangal authored Sep 19, 2023
1 parent 723c3fc commit bed92ac
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 10 deletions.
23 changes: 13 additions & 10 deletions go/vt/vtgate/engine/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"fmt"
"strings"
"sync/atomic"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
Expand Down Expand Up @@ -114,34 +115,36 @@ func bindvarForType(t querypb.Type) *querypb.BindVariable {

// TryStreamExecute performs a streaming exec.
func (jn *Join) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
joinVars := make(map[string]*querypb.BindVariable)
err := vcursor.StreamExecutePrimitive(ctx, jn.Left, bindVars, wantfields, func(lresult *sqltypes.Result) error {
var fieldNeeded atomic.Bool
fieldNeeded.Store(wantfields)
err := vcursor.StreamExecutePrimitive(ctx, jn.Left, bindVars, fieldNeeded.Load(), func(lresult *sqltypes.Result) error {
joinVars := make(map[string]*querypb.BindVariable)
for _, lrow := range lresult.Rows {
for k, col := range jn.Vars {
joinVars[k] = sqltypes.ValueBindVariable(lrow[col])
}
rowSent := false
err := vcursor.StreamExecutePrimitive(ctx, jn.Right, combineVars(bindVars, joinVars), wantfields, func(rresult *sqltypes.Result) error {
var rowSent atomic.Bool
err := vcursor.StreamExecutePrimitive(ctx, jn.Right, combineVars(bindVars, joinVars), fieldNeeded.Load(), func(rresult *sqltypes.Result) error {
result := &sqltypes.Result{}
if wantfields {
if fieldNeeded.Load() {
// This code is currently unreachable because the first result
// will always be just the field info, which will cause the outer
// wantfields code path to be executed. But this may change in the future.
wantfields = false
fieldNeeded.Store(false)
result.Fields = joinFields(lresult.Fields, rresult.Fields, jn.Cols)
}
for _, rrow := range rresult.Rows {
result.Rows = append(result.Rows, joinRows(lrow, rrow, jn.Cols))
}
if len(rresult.Rows) != 0 {
rowSent = true
rowSent.Store(true)
}
return callback(result)
})
if err != nil {
return err
}
if jn.Opcode == LeftJoin && !rowSent {
if jn.Opcode == LeftJoin && !rowSent.Load() {
result := &sqltypes.Result{}
result.Rows = [][]sqltypes.Value{joinRows(
lrow,
Expand All @@ -151,8 +154,8 @@ func (jn *Join) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars
return callback(result)
}
}
if wantfields {
wantfields = false
if fieldNeeded.Load() {
fieldNeeded.Store(false)
for k := range jn.Vars {
joinVars[k] = sqltypes.NullBindVariable
}
Expand Down
35 changes: 35 additions & 0 deletions go/vt/vtgate/executor_select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4156,3 +4156,38 @@ func TestMain(m *testing.M) {
_flag.ParseFlagsForTest()
os.Exit(m.Run())
}

func TestStreamJoinQuery(t *testing.T) {
ctx := utils.LeakCheckContext(t)

// Special setup: Don't use createExecutorEnv.
cell := "aa"
hc := discovery.NewFakeHealthCheck(nil)
u := createSandbox(KsTestUnsharded)
s := createSandbox(KsTestSharded)
s.VSchema = executorVSchema
u.VSchema = unshardedVSchema
serv := newSandboxForCells(ctx, []string{cell})
resolver := newTestResolver(ctx, hc, serv, cell)
shards := []string{"-20", "20-40", "40-60", "60-80", "80-a0", "a0-c0", "c0-e0", "e0-"}
for _, shard := range shards {
_ = hc.AddTestTablet(cell, shard, 1, "TestExecutor", shard, topodatapb.TabletType_PRIMARY, true, 1, nil)
}
executor := createExecutor(ctx, serv, cell, resolver)
defer executor.Close()

sql := "select u.foo, u.apa, ue.bar, ue.apa from user u join user_extra ue on u.foo = ue.bar"
result, err := executorStream(ctx, executor, sql)
require.NoError(t, err)
wantResult := &sqltypes.Result{
Fields: append(sandboxconn.SingleRowResult.Fields, sandboxconn.SingleRowResult.Fields...),
}
wantRow := append(sandboxconn.StreamRowResult.Rows[0], sandboxconn.StreamRowResult.Rows[0]...)
for i := 0; i < 64; i++ {
wantResult.Rows = append(wantResult.Rows, wantRow)
}
require.Equal(t, len(wantResult.Rows), len(result.Rows))
for idx := 0; idx < 64; idx++ {
utils.MustMatch(t, wantResult.Rows[idx], result.Rows[idx], "mismatched on: ", strconv.Itoa(idx))
}
}

0 comments on commit bed92ac

Please sign in to comment.