Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[release-18.0] vtgate/engine: Fix race condition in join logic (#14435) #14441

Merged
merged 1 commit into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 28 additions & 11 deletions go/vt/vtgate/engine/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"fmt"
"strings"
"sync"
"sync/atomic"

"vitess.io/vitess/go/sqltypes"
Expand Down Expand Up @@ -115,22 +116,31 @@ func bindvarForType(t querypb.Type) *querypb.BindVariable {

// TryStreamExecute performs a streaming exec.
func (jn *Join) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
var fieldNeeded atomic.Bool
fieldNeeded.Store(wantfields)
err := vcursor.StreamExecutePrimitive(ctx, jn.Left, bindVars, fieldNeeded.Load(), func(lresult *sqltypes.Result) error {
var mu sync.Mutex
// We need to use this atomic since we're also reading this
// value outside of it being locked with the mu lock.
// This is still racy, but worst case it means that we may
// retrieve the right hand side fields twice instead of once.
var fieldsSent atomic.Bool
fieldsSent.Store(!wantfields)
err := vcursor.StreamExecutePrimitive(ctx, jn.Left, bindVars, wantfields, func(lresult *sqltypes.Result) error {
joinVars := make(map[string]*querypb.BindVariable)
for _, lrow := range lresult.Rows {
for k, col := range jn.Vars {
joinVars[k] = sqltypes.ValueBindVariable(lrow[col])
}
var rowSent atomic.Bool
err := vcursor.StreamExecutePrimitive(ctx, jn.Right, combineVars(bindVars, joinVars), fieldNeeded.Load(), func(rresult *sqltypes.Result) error {
err := vcursor.StreamExecutePrimitive(ctx, jn.Right, combineVars(bindVars, joinVars), !fieldsSent.Load(), func(rresult *sqltypes.Result) error {
// This needs to be locking since it's not safe to just use
// fieldsSent. This is because we can't have a race between
// checking fieldsSent and then actually calling the callback
// and in parallel another goroutine doing the same. That
// can lead to out of order execution of the callback. So the callback
// itself and the check need to be covered by the same lock.
mu.Lock()
defer mu.Unlock()
result := &sqltypes.Result{}
if fieldNeeded.Load() {
// This code is currently unreachable because the first result
// will always be just the field info, which will cause the outer
// wantfields code path to be executed. But this may change in the future.
fieldNeeded.Store(false)
if fieldsSent.CompareAndSwap(false, true) {
result.Fields = joinFields(lresult.Fields, rresult.Fields, jn.Cols)
}
for _, rrow := range rresult.Rows {
Expand All @@ -154,8 +164,15 @@ func (jn *Join) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars
return callback(result)
}
}
if fieldNeeded.Load() {
fieldNeeded.Store(false)
// This needs to be locking since it's not safe to just use
// fieldsSent. This is because we can't have a race between
// checking fieldsSent and then actually calling the callback
// and in parallel another goroutine doing the same. That
// can lead to out of order execution of the callback. So the callback
// itself and the check need to be covered by the same lock.
mu.Lock()
defer mu.Unlock()
if fieldsSent.CompareAndSwap(false, true) {
for k := range jn.Vars {
joinVars[k] = sqltypes.NullBindVariable
}
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/engine/scalar_aggregation.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func (sa *ScalarAggregate) TryStreamExecute(ctx context.Context, vcursor VCursor
var mu sync.Mutex
var agg aggregationState
var fields []*querypb.Field
var fieldsSent bool
fieldsSent := !wantfields

err := vcursor.StreamExecutePrimitive(ctx, sa.Input, bindVars, wantfields, func(result *sqltypes.Result) error {
// as the underlying primitive call is not sync
Expand All @@ -121,7 +121,7 @@ func (sa *ScalarAggregate) TryStreamExecute(ctx context.Context, vcursor VCursor
mu.Lock()
defer mu.Unlock()

if agg == nil {
if agg == nil && len(result.Fields) != 0 {
var err error
agg, fields, err = newAggregation(result.Fields, sa.Aggregates)
if err != nil {
Expand Down
Loading