Skip to content

Commit

Permalink
comments
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <[email protected]>
  • Loading branch information
systay committed Oct 23, 2023
1 parent c77b904 commit 1689ea8
Showing 1 changed file with 28 additions and 14 deletions.
42 changes: 28 additions & 14 deletions go/vt/vtgate/engine/concatenate.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,49 +238,61 @@ func (c *Concatenate) TryStreamExecute(ctx context.Context, vcursor VCursor, bin
}

func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, in func(*sqltypes.Result) error) error {
// Scoped context; any early exit triggers cancel() to clean up ongoing work.
ctx, cancel := context.WithCancel(inCtx)
defer cancel()

// Mutexes for dealing with concurrent access to shared state.
var (
muCallback sync.Mutex
muFields sync.Mutex
condFields = sync.NewCond(&muFields)
wg errgroup.Group
rest = make([]*sqltypes.Result, len(c.Sources))
fields []*querypb.Field
muCallback sync.Mutex // Protects callback
muFields sync.Mutex // Protects field state
condFields = sync.NewCond(&muFields) // Condition var for field arrival
wg errgroup.Group // Wait group for all streaming goroutines
rest = make([]*sqltypes.Result, len(c.Sources)) // Collects first result from each source to derive fields
fields []*querypb.Field // Cached final field types
)

// Process each result chunk, considering type coercion.
callback := func(res *sqltypes.Result, srcIdx int) error {
muCallback.Lock()
defer muCallback.Unlock()

// Check if type coercion needed for this source.
// We only need to check if fields are not in NoNeedToTypeCheck set.
needsCoercion := false
for idx, field := range rest[srcIdx].Fields {
_, ok := c.NoNeedToTypeCheck[idx]
if !ok && fields[idx].Type != field.Type {
_, skip := c.NoNeedToTypeCheck[idx]
if !skip && fields[idx].Type != field.Type {
needsCoercion = true
break
}
}

// Apply type coercion if needed.
if needsCoercion {
for _, row := range res.Rows {
err := c.coerceValuesTo(row, fields)
if err != nil {
if err := c.coerceValuesTo(row, fields); err != nil {
return err
}
}
}
return in(res)
}

// Start streaming query execution in parallel for all sources.
for i, source := range c.Sources {
currIndex, currSource := i, source
wg.Go(func() error {
err := vcursor.StreamExecutePrimitive(ctx, currSource, bindVars, true, func(resultChunk *sqltypes.Result) error {
// Process fields when they arrive; coordinate field agreement across sources.
if resultChunk.Fields != nil {
muFields.Lock()

// Capture the initial result chunk to determine field types later.
if rest[currIndex] == nil {
rest[currIndex] = resultChunk

// If this was the last source to report its fields, derive the final output fields.
if !slices.Contains(rest, nil) {
muFields.Unlock()

Expand All @@ -296,24 +308,25 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor,
return callback(resultChunk, currIndex)
}
}

// Wait for fields from all sources.
for slices.Contains(rest, nil) {
condFields.Wait()
}
muFields.Unlock()
}

// If we get here, all the fields have been received
// Context check to avoid extra work.
if ctx.Err() != nil {
return nil
}
return callback(resultChunk, currIndex)
})

// Error handling and context cleanup for this source.
if err != nil {
muFields.Lock()
if rest[currIndex] == nil {
// In case we haven't received any fields yet, we need to set it
// empty, or otherwise we will keep waiting forever.
// Signal that this source is done, even if by failure, to unblock field waiting.
rest[currIndex] = &sqltypes.Result{}
}
cancel()
Expand All @@ -323,6 +336,7 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor,
return err
})
}
// Wait for all sources to complete.
return wg.Wait()
}

Expand Down

0 comments on commit 1689ea8

Please sign in to comment.