Skip to content

Commit

Permalink
vtgate/engine: Fix race condition in join logic (#14435)
Browse files Browse the repository at this point in the history
Signed-off-by: Dirkjan Bussink <[email protected]>
  • Loading branch information
dbussink authored Nov 2, 2023
1 parent 806f777 commit 369b6a1
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 13 deletions.
39 changes: 28 additions & 11 deletions go/vt/vtgate/engine/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"fmt"
"strings"
"sync"
"sync/atomic"

"vitess.io/vitess/go/sqltypes"
Expand Down Expand Up @@ -115,22 +116,31 @@ func bindvarForType(t querypb.Type) *querypb.BindVariable {

// TryStreamExecute performs a streaming exec.
func (jn *Join) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
var fieldNeeded atomic.Bool
fieldNeeded.Store(wantfields)
err := vcursor.StreamExecutePrimitive(ctx, jn.Left, bindVars, fieldNeeded.Load(), func(lresult *sqltypes.Result) error {
var mu sync.Mutex
// We need to use this atomic since we're also reading this
// value outside of it being locked with the mu lock.
// This is still racy, but worst case it means that we may
// retrieve the right hand side fields twice instead of once.
var fieldsSent atomic.Bool
fieldsSent.Store(!wantfields)
err := vcursor.StreamExecutePrimitive(ctx, jn.Left, bindVars, wantfields, func(lresult *sqltypes.Result) error {
joinVars := make(map[string]*querypb.BindVariable)
for _, lrow := range lresult.Rows {
for k, col := range jn.Vars {
joinVars[k] = sqltypes.ValueBindVariable(lrow[col])
}
var rowSent atomic.Bool
err := vcursor.StreamExecutePrimitive(ctx, jn.Right, combineVars(bindVars, joinVars), fieldNeeded.Load(), func(rresult *sqltypes.Result) error {
err := vcursor.StreamExecutePrimitive(ctx, jn.Right, combineVars(bindVars, joinVars), !fieldsSent.Load(), func(rresult *sqltypes.Result) error {
// This needs to be locking since it's not safe to just use
// fieldsSent. This is because we can't have a race between
// checking fieldsSent and then actually calling the callback
// and in parallel another goroutine doing the same. That
// can lead to out of order execution of the callback. So the callback
// itself and the check need to be covered by the same lock.
mu.Lock()
defer mu.Unlock()
result := &sqltypes.Result{}
if fieldNeeded.Load() {
// This code is currently unreachable because the first result
// will always be just the field info, which will cause the outer
// wantfields code path to be executed. But this may change in the future.
fieldNeeded.Store(false)
if fieldsSent.CompareAndSwap(false, true) {
result.Fields = joinFields(lresult.Fields, rresult.Fields, jn.Cols)
}
for _, rrow := range rresult.Rows {
Expand All @@ -154,8 +164,15 @@ func (jn *Join) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars
return callback(result)
}
}
if fieldNeeded.Load() {
fieldNeeded.Store(false)
// This needs to be locking since it's not safe to just use
// fieldsSent. This is because we can't have a race between
// checking fieldsSent and then actually calling the callback
// and in parallel another goroutine doing the same. That
// can lead to out of order execution of the callback. So the callback
// itself and the check need to be covered by the same lock.
mu.Lock()
defer mu.Unlock()
if fieldsSent.CompareAndSwap(false, true) {
for k := range jn.Vars {
joinVars[k] = sqltypes.NullBindVariable
}
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/engine/scalar_aggregation.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func (sa *ScalarAggregate) TryStreamExecute(ctx context.Context, vcursor VCursor
var mu sync.Mutex
var agg aggregationState
var fields []*querypb.Field
var fieldsSent bool
fieldsSent := !wantfields

err := vcursor.StreamExecutePrimitive(ctx, sa.Input, bindVars, wantfields, func(result *sqltypes.Result) error {
// as the underlying primitive call is not sync
Expand All @@ -121,7 +121,7 @@ func (sa *ScalarAggregate) TryStreamExecute(ctx context.Context, vcursor VCursor
mu.Lock()
defer mu.Unlock()

if agg == nil {
if agg == nil && len(result.Fields) != 0 {
var err error
agg, fields, err = newAggregation(result.Fields, sa.Aggregates)
if err != nil {
Expand Down

0 comments on commit 369b6a1

Please sign in to comment.