diff --git a/go/vt/vtgate/engine/concatenate.go b/go/vt/vtgate/engine/concatenate.go index 4df774f9f5d..1e8cb655547 100644 --- a/go/vt/vtgate/engine/concatenate.go +++ b/go/vt/vtgate/engine/concatenate.go @@ -18,8 +18,10 @@ package engine import ( "context" + "slices" "sync" - "sync/atomic" + + "golang.org/x/sync/errgroup" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" @@ -236,34 +238,40 @@ func (c *Concatenate) TryStreamExecute(ctx context.Context, vcursor VCursor, bin } func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, in func(*sqltypes.Result) error) error { + // Scoped context; any early exit triggers cancel() to clean up ongoing work. ctx, cancel := context.WithCancel(inCtx) defer cancel() - var outerErr error - - var cbMu sync.Mutex - var wg, fieldMu sync.WaitGroup - var fieldRec atomic.Int64 - fieldRec.Store(int64(len(c.Sources))) - fieldMu.Add(1) - rest := make([]*sqltypes.Result, len(c.Sources)) - var fields []*querypb.Field + // Mutexes for dealing with concurrent access to shared state. + var ( + muCallback sync.Mutex // Protects callback + muFields sync.Mutex // Protects field state + condFields = sync.NewCond(&muFields) // Condition var for field arrival + wg errgroup.Group // Wait group for all streaming goroutines + rest = make([]*sqltypes.Result, len(c.Sources)) // Collects first result from each source to derive fields + fields []*querypb.Field // Cached final field types + ) + + // Process each result chunk, considering type coercion. callback := func(res *sqltypes.Result, srcIdx int) error { - cbMu.Lock() - defer cbMu.Unlock() + muCallback.Lock() + defer muCallback.Unlock() + // Check if type coercion needed for this source. + // We only need to check if fields are not in NoNeedToTypeCheck set. needsCoercion := false for idx, field := range rest[srcIdx].Fields { - _, ok := c.NoNeedToTypeCheck[idx] - if !ok && fields[idx].Type != field.Type { + _, skip := c.NoNeedToTypeCheck[idx] + if !skip && fields[idx].Type != field.Type { needsCoercion = true break } } + + // Apply type coercion if needed. if needsCoercion { for _, row := range res.Rows { - err := c.coerceValuesTo(row, fields) - if err != nil { + if err := c.coerceValuesTo(row, fields); err != nil { return err } } @@ -271,57 +279,65 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, return in(res) } - once := sync.Once{} - + // Start streaming query execution in parallel for all sources. for i, source := range c.Sources { - wg.Add(1) currIndex, currSource := i, source - - go func() { - defer wg.Done() + wg.Go(func() error { err := vcursor.StreamExecutePrimitive(ctx, currSource, bindVars, true, func(resultChunk *sqltypes.Result) error { - // if we have fields to compare, make sure all the fields are all the same - if fieldRec.Load() > 0 && resultChunk.Fields != nil { - rest[currIndex] = resultChunk - res := fieldRec.Add(-1) - if res == 0 { - // We have received fields from all sources. We can now calculate the output types - var err error - fields, err = c.getFields(rest) - if err != nil { - return err + // Process fields when they arrive; coordinate field agreement across sources. + if resultChunk.Fields != nil { + muFields.Lock() + + // Capture the initial result chunk to determine field types later. + if rest[currIndex] == nil { + rest[currIndex] = resultChunk + + // If this was the last source to report its fields, derive the final output fields. + if !slices.Contains(rest, nil) { + muFields.Unlock() + + // We have received fields from all sources. We can now calculate the output types + var err error + fields, err = c.getFields(rest) + if err != nil { + return err + } + resultChunk.Fields = fields + + defer condFields.Broadcast() + return callback(resultChunk, currIndex) } - resultChunk.Fields = fields - defer once.Do(func() { - fieldMu.Done() - }) - - return callback(resultChunk, currIndex) - } else { - fieldMu.Wait() } + // Wait for fields from all sources. + for slices.Contains(rest, nil) { + condFields.Wait() + } + muFields.Unlock() } - // If we get here, all the fields have been received - select { - case <-ctx.Done(): + // Context check to avoid extra work. + if ctx.Err() != nil { return nil - default: - return callback(resultChunk, currIndex) } + return callback(resultChunk, currIndex) }) + + // Error handling and context cleanup for this source. if err != nil { - outerErr = err + muFields.Lock() + if rest[currIndex] == nil { + // Signal that this source is done, even if by failure, to unblock field waiting. + rest[currIndex] = &sqltypes.Result{} + } cancel() - once.Do(func() { - fieldMu.Done() - }) + condFields.Broadcast() + muFields.Unlock() } - }() - + return err + }) } - wg.Wait() - return outerErr + // Wait for all sources to complete. + return wg.Wait() } func (c *Concatenate) sequentialStreamExec(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) error {