From 1689ea80cef51528ceca697df29514a04cb04f09 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Mon, 23 Oct 2023 11:02:09 +0200 Subject: [PATCH] comments Signed-off-by: Andres Taylor --- go/vt/vtgate/engine/concatenate.go | 42 ++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/go/vt/vtgate/engine/concatenate.go b/go/vt/vtgate/engine/concatenate.go index 2692f34695d..1e8cb655547 100644 --- a/go/vt/vtgate/engine/concatenate.go +++ b/go/vt/vtgate/engine/concatenate.go @@ -238,34 +238,40 @@ func (c *Concatenate) TryStreamExecute(ctx context.Context, vcursor VCursor, bin } func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, in func(*sqltypes.Result) error) error { + // Scoped context; any early exit triggers cancel() to clean up ongoing work. ctx, cancel := context.WithCancel(inCtx) defer cancel() + // Mutexes for dealing with concurrent access to shared state. var ( - muCallback sync.Mutex - muFields sync.Mutex - condFields = sync.NewCond(&muFields) - wg errgroup.Group - rest = make([]*sqltypes.Result, len(c.Sources)) - fields []*querypb.Field + muCallback sync.Mutex // Protects callback + muFields sync.Mutex // Protects field state + condFields = sync.NewCond(&muFields) // Condition var for field arrival + wg errgroup.Group // Wait group for all streaming goroutines + rest = make([]*sqltypes.Result, len(c.Sources)) // Collects first result from each source to derive fields + fields []*querypb.Field // Cached final field types ) + // Process each result chunk, considering type coercion. callback := func(res *sqltypes.Result, srcIdx int) error { muCallback.Lock() defer muCallback.Unlock() + // Check if type coercion needed for this source. + // We only need to check if fields are not in NoNeedToTypeCheck set. needsCoercion := false for idx, field := range rest[srcIdx].Fields { - _, ok := c.NoNeedToTypeCheck[idx] - if !ok && fields[idx].Type != field.Type { + _, skip := c.NoNeedToTypeCheck[idx] + if !skip && fields[idx].Type != field.Type { needsCoercion = true break } } + + // Apply type coercion if needed. if needsCoercion { for _, row := range res.Rows { - err := c.coerceValuesTo(row, fields) - if err != nil { + if err := c.coerceValuesTo(row, fields); err != nil { return err } } @@ -273,14 +279,20 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, return in(res) } + // Start streaming query execution in parallel for all sources. for i, source := range c.Sources { currIndex, currSource := i, source wg.Go(func() error { err := vcursor.StreamExecutePrimitive(ctx, currSource, bindVars, true, func(resultChunk *sqltypes.Result) error { + // Process fields when they arrive; coordinate field agreement across sources. if resultChunk.Fields != nil { muFields.Lock() + + // Capture the initial result chunk to determine field types later. if rest[currIndex] == nil { rest[currIndex] = resultChunk + + // If this was the last source to report its fields, derive the final output fields. if !slices.Contains(rest, nil) { muFields.Unlock() @@ -296,24 +308,25 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, return callback(resultChunk, currIndex) } } - + // Wait for fields from all sources. for slices.Contains(rest, nil) { condFields.Wait() } muFields.Unlock() } - // If we get here, all the fields have been received + // Context check to avoid extra work. if ctx.Err() != nil { return nil } return callback(resultChunk, currIndex) }) + + // Error handling and context cleanup for this source. if err != nil { muFields.Lock() if rest[currIndex] == nil { - // In case we haven't received any fields yet, we need to set it - // empty, or otherwise we will keep waiting forever. + // Signal that this source is done, even if by failure, to unblock field waiting. rest[currIndex] = &sqltypes.Result{} } cancel() @@ -323,6 +336,7 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, return err }) } + // Wait for all sources to complete. return wg.Wait() }