Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Data race in semi-join #17417

Merged
merged 4 commits into from
Dec 28, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions go/test/endtoend/vtgate/queries/misc/misc_test.go
Original file line number Diff line number Diff line change
@@ -574,3 +574,23 @@ func TestTimeZones(t *testing.T) {
})
}
}

// TestSemiJoin tests that the semi join works as intended.
func TestSemiJoin(t *testing.T) {
mcmp, closer := start(t)
defer closer()

for i := 1; i <= 1000; i++ {
mcmp.Exec(fmt.Sprintf("insert into t1(id1, id2) values (%d, %d)", i, 2*i))
mcmp.Exec(fmt.Sprintf("insert into tbl(id, unq_col, nonunq_col) values (%d, %d, %d)", i, 2*i, 3*i))
}

// Test that the semi join works as intended
for _, mode := range []string{"oltp", "olap"} {
mcmp.Run(mode, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", mode))

mcmp.Exec("select id1, id2 from t1 where exists (select id from tbl where nonunq_col = t1.id2) order by id1")
})
}
}
7 changes: 5 additions & 2 deletions go/vt/vtgate/engine/fake_primitive_test.go
Original file line number Diff line number Diff line change
@@ -40,7 +40,8 @@ type fakePrimitive struct {
// sendErr is sent at the end of the stream if it's set.
sendErr error

log []string
noLog bool
log []string

allResultsInOneCall bool

@@ -85,7 +86,9 @@ func (f *fakePrimitive) TryExecute(ctx context.Context, vcursor VCursor, bindVar
}

func (f *fakePrimitive) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
f.log = append(f.log, fmt.Sprintf("StreamExecute %v %v", printBindVars(bindVars), wantfields))
if !f.noLog {
f.log = append(f.log, fmt.Sprintf("StreamExecute %v %v", printBindVars(bindVars), wantfields))
}
if f.results == nil {
return f.sendErr
}
13 changes: 8 additions & 5 deletions go/vt/vtgate/engine/semi_join.go
Original file line number Diff line number Diff line change
@@ -18,6 +18,7 @@ package engine

import (
"context"
"sync/atomic"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
@@ -62,24 +63,26 @@ func (jn *SemiJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars ma

// TryStreamExecute performs a streaming exec.
func (jn *SemiJoin) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
joinVars := make(map[string]*querypb.BindVariable)
err := vcursor.StreamExecutePrimitive(ctx, jn.Left, bindVars, wantfields, func(lresult *sqltypes.Result) error {
joinVars := make(map[string]*querypb.BindVariable)
result := &sqltypes.Result{Fields: lresult.Fields}
for _, lrow := range lresult.Rows {
for k, col := range jn.Vars {
joinVars[k] = sqltypes.ValueBindVariable(lrow[col])
}
rowAdded := false
var rowAdded atomic.Bool
err := vcursor.StreamExecutePrimitive(ctx, jn.Right, combineVars(bindVars, joinVars), false, func(rresult *sqltypes.Result) error {
GuptaManan100 marked this conversation as resolved.
Show resolved Hide resolved
if len(rresult.Rows) > 0 && !rowAdded {
result.Rows = append(result.Rows, lrow)
rowAdded = true
if len(rresult.Rows) > 0 {
rowAdded.Store(true)
}
return nil
})
if err != nil {
return err
}
if rowAdded.Load() {
GuptaManan100 marked this conversation as resolved.
Show resolved Hide resolved
result.Rows = append(result.Rows, lrow)
}
}
return callback(result)
})
79 changes: 79 additions & 0 deletions go/vt/vtgate/engine/semi_join_test.go
Original file line number Diff line number Diff line change
@@ -18,6 +18,7 @@ package engine

import (
"context"
"sync"
"testing"

"vitess.io/vitess/go/test/utils"
@@ -159,3 +160,81 @@ func TestSemiJoinStreamExecute(t *testing.T) {
"4|d|dd",
))
}

// TestSemiJoinStreamExecuteParallelExecution tests SemiJoin stream execution with parallel execution
// to ensure we have no data races.
func TestSemiJoinStreamExecuteParallelExecution(t *testing.T) {
leftPrim := &fakePrimitive{
results: []*sqltypes.Result{
sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
"col1|col2|col3",
"int64|varchar|varchar",
),
"1|a|aa",
"2|b|bb",
), sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
"col1|col2|col3",
"int64|varchar|varchar",
),
"3|c|cc",
"4|d|dd",
),
},
async: true,
}
rightFields := sqltypes.MakeTestFields(
"col4|col5|col6",
"int64|varchar|varchar",
)
rightPrim := &fakePrimitive{
// we'll return non-empty results for rows 2 and 4
results: sqltypes.MakeTestStreamingResults(rightFields,
"4|d|dd",
"---",
"---",
"5|e|ee",
"6|f|ff",
"7|g|gg",
),
async: true,
noLog: true,
}

jn := &SemiJoin{
Left: leftPrim,
Right: rightPrim,
Vars: map[string]int{
"bv": 1,
},
}
var res *sqltypes.Result
var mu sync.Mutex
err := jn.TryStreamExecute(context.Background(), &noopVCursor{}, map[string]*querypb.BindVariable{}, true, func(result *sqltypes.Result) error {
mu.Lock()
defer mu.Unlock()
if res == nil {
res = result
} else {
res.Rows = append(res.Rows, result.Rows...)
}
return nil
})
require.NoError(t, err)
leftPrim.ExpectLog(t, []string{
`StreamExecute true`,
})
// We'll get all the rows back in left primitive, since we're returning the same set of rows
// from the right primitive that makes them all qualify.
expectResultAnyOrder(t, res, sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
"col1|col2|col3",
"int64|varchar|varchar",
),
"1|a|aa",
"2|b|bb",
"3|c|cc",
"4|d|dd",
))
}