From 369b6a1e55aecd98c3cf6d4366cfbcee0477c474 Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Thu, 2 Nov 2023 16:58:28 +0100 Subject: [PATCH] vtgate/engine: Fix race condition in join logic (#14435) Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/engine/join.go | 39 ++++++++++++++++------- go/vt/vtgate/engine/scalar_aggregation.go | 4 +-- 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/go/vt/vtgate/engine/join.go b/go/vt/vtgate/engine/join.go index 1c3adc1f5c9..ef50389c989 100644 --- a/go/vt/vtgate/engine/join.go +++ b/go/vt/vtgate/engine/join.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "strings" + "sync" "sync/atomic" "vitess.io/vitess/go/sqltypes" @@ -115,22 +116,31 @@ func bindvarForType(t querypb.Type) *querypb.BindVariable { // TryStreamExecute performs a streaming exec. func (jn *Join) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { - var fieldNeeded atomic.Bool - fieldNeeded.Store(wantfields) - err := vcursor.StreamExecutePrimitive(ctx, jn.Left, bindVars, fieldNeeded.Load(), func(lresult *sqltypes.Result) error { + var mu sync.Mutex + // We need to use this atomic since we're also reading this + // value outside of it being locked with the mu lock. + // This is still racy, but worst case it means that we may + // retrieve the right hand side fields twice instead of once. + var fieldsSent atomic.Bool + fieldsSent.Store(!wantfields) + err := vcursor.StreamExecutePrimitive(ctx, jn.Left, bindVars, wantfields, func(lresult *sqltypes.Result) error { joinVars := make(map[string]*querypb.BindVariable) for _, lrow := range lresult.Rows { for k, col := range jn.Vars { joinVars[k] = sqltypes.ValueBindVariable(lrow[col]) } var rowSent atomic.Bool - err := vcursor.StreamExecutePrimitive(ctx, jn.Right, combineVars(bindVars, joinVars), fieldNeeded.Load(), func(rresult *sqltypes.Result) error { + err := vcursor.StreamExecutePrimitive(ctx, jn.Right, combineVars(bindVars, joinVars), !fieldsSent.Load(), func(rresult *sqltypes.Result) error { + // This needs to be locking since it's not safe to just use + // fieldsSent. This is because we can't have a race between + // checking fieldsSent and then actually calling the callback + // and in parallel another goroutine doing the same. That + // can lead to out of order execution of the callback. So the callback + // itself and the check need to be covered by the same lock. + mu.Lock() + defer mu.Unlock() result := &sqltypes.Result{} - if fieldNeeded.Load() { - // This code is currently unreachable because the first result - // will always be just the field info, which will cause the outer - // wantfields code path to be executed. But this may change in the future. - fieldNeeded.Store(false) + if fieldsSent.CompareAndSwap(false, true) { result.Fields = joinFields(lresult.Fields, rresult.Fields, jn.Cols) } for _, rrow := range rresult.Rows { @@ -154,8 +164,15 @@ func (jn *Join) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars return callback(result) } } - if fieldNeeded.Load() { - fieldNeeded.Store(false) + // This needs to be locking since it's not safe to just use + // fieldsSent. This is because we can't have a race between + // checking fieldsSent and then actually calling the callback + // and in parallel another goroutine doing the same. That + // can lead to out of order execution of the callback. So the callback + // itself and the check need to be covered by the same lock. + mu.Lock() + defer mu.Unlock() + if fieldsSent.CompareAndSwap(false, true) { for k := range jn.Vars { joinVars[k] = sqltypes.NullBindVariable } diff --git a/go/vt/vtgate/engine/scalar_aggregation.go b/go/vt/vtgate/engine/scalar_aggregation.go index 6190e2e5fd6..85e90420ff9 100644 --- a/go/vt/vtgate/engine/scalar_aggregation.go +++ b/go/vt/vtgate/engine/scalar_aggregation.go @@ -112,7 +112,7 @@ func (sa *ScalarAggregate) TryStreamExecute(ctx context.Context, vcursor VCursor var mu sync.Mutex var agg aggregationState var fields []*querypb.Field - var fieldsSent bool + fieldsSent := !wantfields err := vcursor.StreamExecutePrimitive(ctx, sa.Input, bindVars, wantfields, func(result *sqltypes.Result) error { // as the underlying primitive call is not sync @@ -121,7 +121,7 @@ func (sa *ScalarAggregate) TryStreamExecute(ctx context.Context, vcursor VCursor mu.Lock() defer mu.Unlock() - if agg == nil { + if agg == nil && len(result.Fields) != 0 { var err error agg, fields, err = newAggregation(result.Fields, sa.Aggregates) if err != nil {