diff --git a/go/vt/vtgate/engine/concatenate.go b/go/vt/vtgate/engine/concatenate.go index 4df774f9f5d..04d08245b0e 100644 --- a/go/vt/vtgate/engine/concatenate.go +++ b/go/vt/vtgate/engine/concatenate.go @@ -18,8 +18,10 @@ package engine import ( "context" + "slices" "sync" - "sync/atomic" + + "golang.org/x/sync/errgroup" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" @@ -238,19 +240,19 @@ func (c *Concatenate) TryStreamExecute(ctx context.Context, vcursor VCursor, bin func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, in func(*sqltypes.Result) error) error { ctx, cancel := context.WithCancel(inCtx) defer cancel() - var outerErr error - var cbMu sync.Mutex - var wg, fieldMu sync.WaitGroup - var fieldRec atomic.Int64 - fieldRec.Store(int64(len(c.Sources))) - fieldMu.Add(1) + var ( + muCallback sync.Mutex + muFields sync.Mutex + condFields = sync.NewCond(&muFields) + wg errgroup.Group + rest = make([]*sqltypes.Result, len(c.Sources)) + fields []*querypb.Field + ) - rest := make([]*sqltypes.Result, len(c.Sources)) - var fields []*querypb.Field callback := func(res *sqltypes.Result, srcIdx int) error { - cbMu.Lock() - defer cbMu.Unlock() + muCallback.Lock() + defer muCallback.Unlock() needsCoercion := false for idx, field := range rest[srcIdx].Fields { @@ -271,57 +273,50 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, return in(res) } - once := sync.Once{} - for i, source := range c.Sources { - wg.Add(1) currIndex, currSource := i, source - - go func() { - defer wg.Done() + wg.Go(func() error { err := vcursor.StreamExecutePrimitive(ctx, currSource, bindVars, true, func(resultChunk *sqltypes.Result) error { - // if we have fields to compare, make sure all the fields are all the same - if fieldRec.Load() > 0 && resultChunk.Fields != nil { - rest[currIndex] = resultChunk - res := fieldRec.Add(-1) - if res == 0 { - // We have received fields from all sources. We can now calculate the output types - var err error - fields, err = c.getFields(rest) - if err != nil { - return err + if resultChunk.Fields != nil { + muFields.Lock() + if rest[currIndex] == nil { + rest[currIndex] = resultChunk + if !slices.Contains(rest, nil) { + muFields.Unlock() + + // We have received fields from all sources. We can now calculate the output types + var err error + fields, err = c.getFields(rest) + if err != nil { + return err + } + resultChunk.Fields = fields + + defer condFields.Broadcast() + return callback(resultChunk, currIndex) } - resultChunk.Fields = fields - defer once.Do(func() { - fieldMu.Done() - }) - - return callback(resultChunk, currIndex) - } else { - fieldMu.Wait() } + + for slices.Contains(rest, nil) { + condFields.Wait() + } + muFields.Unlock() } // If we get here, all the fields have been received - select { - case <-ctx.Done(): + if ctx.Err() != nil { return nil - default: - return callback(resultChunk, currIndex) } + return callback(resultChunk, currIndex) }) if err != nil { - outerErr = err cancel() - once.Do(func() { - fieldMu.Done() - }) + condFields.Broadcast() } - }() - + return err + }) } - wg.Wait() - return outerErr + return wg.Wait() } func (c *Concatenate) sequentialStreamExec(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) error {