Skip to content

Commit

Permalink
feat: simplify the code for parallel execution
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <[email protected]>
  • Loading branch information
systay committed Mar 19, 2024
1 parent 661fe19 commit 05b33a9
Showing 1 changed file with 55 additions and 72 deletions.
127 changes: 55 additions & 72 deletions go/vt/vtgate/engine/concatenate.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ package engine

import (
"context"
"slices"
"sync"
"sync/atomic"

"golang.org/x/sync/errgroup"

Expand All @@ -40,6 +40,13 @@ type Concatenate struct {
// These column offsets do not need to be typed checked - they usually contain weight_string()
// columns that are not going to be returned to the user
NoNeedToTypeCheck map[int]any

// the following fields are written to only once, and can then be shared between all users of this plan
typeLoading sync.Once
typesLoaded atomic.Bool
fields []*querypb.Field
fieldTypes []evalengine.Type
typeError error
}

// NewConcatenate creates a Concatenate primitive. The ignoreCols slice contains the offsets that
Expand Down Expand Up @@ -96,13 +103,13 @@ func (c *Concatenate) TryExecute(ctx context.Context, vcursor VCursor, bindVars
return nil, err
}

fields, fieldTypes, err := c.getFieldTypes(vcursor, res)
err = c.loadTypes(vcursor, res)
if err != nil {
return nil, err
}

var rows [][]sqltypes.Value
err = c.coerceAndVisitResults(res, fieldTypes, func(result *sqltypes.Result) error {
err = c.coerceAndVisitResults(res, c.fieldTypes, func(result *sqltypes.Result) error {
rows = append(rows, result.Rows...)
return nil
}, evalengine.ParseSQLMode(vcursor.SQLMode()))
Expand All @@ -111,11 +118,18 @@ func (c *Concatenate) TryExecute(ctx context.Context, vcursor VCursor, bindVars
}

return &sqltypes.Result{
Fields: fields,
Fields: c.fields,
Rows: rows,
}, nil
}

// loadTypes computes the output fields and evalengine types for this
// Concatenate exactly once, caching them on the plan (c.fields,
// c.fieldTypes) so that every later execution — including concurrent
// users of the shared plan — reuses the same result. Any error raised
// while calculating the types is cached in c.typeError and returned on
// every call; callers must check it before reading the cached fields.
//
// NOTE(review): a type error from the first execution is sticky — later
// executions that would produce valid results still return it. Confirm
// this is the intended behavior for cached plans.
func (c *Concatenate) loadTypes(vcursor VCursor, res []*sqltypes.Result) error {
	c.typeLoading.Do(func() { c.getFieldTypes(vcursor, res) })
	return c.typeError
}

func (c *Concatenate) coerceValuesTo(row sqltypes.Row, fieldTypes []evalengine.Type, sqlmode evalengine.SQLMode) error {
if len(row) != len(fieldTypes) {
return errWrongNumberOfColumnsInSelect
Expand All @@ -136,9 +150,9 @@ func (c *Concatenate) coerceValuesTo(row sqltypes.Row, fieldTypes []evalengine.T
return nil
}

func (c *Concatenate) getFieldTypes(vcursor VCursor, res []*sqltypes.Result) ([]*querypb.Field, []evalengine.Type, error) {
func (c *Concatenate) getFieldTypes(vcursor VCursor, res []*sqltypes.Result) {
if len(res) == 0 {
return nil, nil, nil
return
}

typers := make([]evalengine.TypeAggregator, len(res[0].Fields))
Expand All @@ -149,11 +163,13 @@ func (c *Concatenate) getFieldTypes(vcursor VCursor, res []*sqltypes.Result) ([]
continue
}
if len(r.Fields) != len(typers) {
return nil, nil, errWrongNumberOfColumnsInSelect
c.typeError = errWrongNumberOfColumnsInSelect
return
}
for idx, field := range r.Fields {
if err := typers[idx].AddField(field, collations); err != nil {
return nil, nil, err
c.typeError = err
return
}
}
}
Expand All @@ -173,7 +189,9 @@ func (c *Concatenate) getFieldTypes(vcursor VCursor, res []*sqltypes.Result) ([]
fields = append(fields, t.ToField(f.Name))
types = append(types, t)
}
return fields, types, nil
c.fields = fields
c.fieldTypes = types
c.typesLoaded.Store(true)
}

func (c *Concatenate) execSources(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) ([]*sqltypes.Result, error) {
Expand Down Expand Up @@ -229,7 +247,7 @@ func (c *Concatenate) sequentialExec(ctx context.Context, vcursor VCursor, bindV
// TryStreamExecute performs a streaming exec.
func (c *Concatenate) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, _ bool, callback func(*sqltypes.Result) error) error {
sqlmode := evalengine.ParseSQLMode(vcursor.SQLMode())
if vcursor.Session().InTransaction() {
if vcursor.Session().InTransaction() || !c.typesLoaded.Load() {
// If we are inside a transaction, all queries must run on the single
// connection that holds it. We also take the sequential path when the
// field types have not been loaded yet (!c.typesLoaded.Load()), since
// the parallel path assumes c.fieldTypes is already populated.
return c.sequentialStreamExec(ctx, vcursor, bindVars, callback, sqlmode)
Expand All @@ -238,82 +256,52 @@ func (c *Concatenate) TryStreamExecute(ctx context.Context, vcursor VCursor, bin
return c.parallelStreamExec(ctx, vcursor, bindVars, callback, sqlmode)
}

func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, in func(*sqltypes.Result) error, sqlmode evalengine.SQLMode) error {
// parallelStreamExec runs and returns the sub queries in parallel
// it assumes the field types have been loaded
func (c *Concatenate) parallelStreamExec(
inCtx context.Context,
vcursor VCursor,
bindVars map[string]*querypb.BindVariable,
in func(*sqltypes.Result) error,
sqlmode evalengine.SQLMode,
) error {
// Scoped context; any early exit triggers cancel() to clean up ongoing work.
ctx, cancel := context.WithCancel(inCtx)
defer cancel()

// Mutexes for dealing with concurrent access to shared state.
var (
muCallback sync.Mutex // Protects callback
muFields sync.Mutex // Protects field state
condFields = sync.NewCond(&muFields) // Condition var for field arrival
wg errgroup.Group // Wait group for all streaming goroutines
rest = make([]*sqltypes.Result, len(c.Sources)) // Collects first result from each source to derive fields
fieldTypes []evalengine.Type // Cached final field types
)

// Process each result chunk, considering type coercion.
callback := func(res *sqltypes.Result, srcIdx int) error {
muCallback.Lock()
defer muCallback.Unlock()

if len(res.Rows) == 0 {
return in(res)
}
// Check if type coercion needed for this source.
// We only need to check if fields are not in NoNeedToTypeCheck set.
needsCoercion := false
for idx, field := range rest[srcIdx].Fields {
for idx, field := range c.fieldTypes {
_, skip := c.NoNeedToTypeCheck[idx]
if !skip && fieldTypes[idx].Type() != field.Type {
if !skip && field.Type() != res.Fields[idx].Type {
needsCoercion = true
break
}
}

// Apply type coercion if needed.
// TODO: we should be able to do this only once as well, and remember if we need coercing here or not
if needsCoercion {
for _, row := range res.Rows {
if err := c.coerceValuesTo(row, fieldTypes, sqlmode); err != nil {
if err := c.coerceValuesTo(row, c.fieldTypes, sqlmode); err != nil {
return err
}
}
}
return in(res)
}

var wg errgroup.Group
// Start streaming query execution in parallel for all sources.
for i, source := range c.Sources {
currIndex, currSource := i, source
wg.Go(func() error {
err := vcursor.StreamExecutePrimitive(ctx, currSource, bindVars, true, func(resultChunk *sqltypes.Result) error {
muFields.Lock()

// Process fields when they arrive; coordinate field agreement across sources.
if resultChunk.Fields != nil && rest[currIndex] == nil {
// Capture the initial result chunk to determine field types later.
rest[currIndex] = resultChunk

// If this was the last source to report its fields, derive the final output fields.
if !slices.Contains(rest, nil) {
// We have received fields from all sources. We can now calculate the output types
var err error
resultChunk.Fields, fieldTypes, err = c.getFieldTypes(vcursor, rest)
if err != nil {
muFields.Unlock()
return err
}

muFields.Unlock()
defer condFields.Broadcast()
return callback(resultChunk, currIndex)
}
}

// Wait for fields from all sources.
for slices.Contains(rest, nil) {
condFields.Wait()
}
muFields.Unlock()

// Context check to avoid extra work.
if ctx.Err() != nil {
return nil
Expand All @@ -323,14 +311,7 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor,

// Error handling and context cleanup for this source.
if err != nil {
muFields.Lock()
if rest[currIndex] == nil {
// Signal that this source is done, even if by failure, to unblock field waiting.
rest[currIndex] = &sqltypes.Result{}
}
cancel()
condFields.Broadcast()
muFields.Unlock()
}
return err
})
Expand Down Expand Up @@ -363,17 +344,19 @@ func (c *Concatenate) sequentialStreamExec(ctx context.Context, vcursor VCursor,
}
}

firsts := make([]*sqltypes.Result, len(c.Sources))
for i, result := range results {
firsts[i] = result[0]
c.typeLoading.Do(func() {
firsts := make([]*sqltypes.Result, len(c.Sources))
for i, result := range results {
firsts[i] = result[0]
}
c.getFieldTypes(vcursor, firsts)
})
if c.typeError != nil {
return c.typeError
}

_, fieldTypes, err := c.getFieldTypes(vcursor, firsts)
if err != nil {
return err
}
for _, res := range results {
if err = c.coerceAndVisitResults(res, fieldTypes, callback, sqlmode); err != nil {
if err := c.coerceAndVisitResults(res, c.fieldTypes, callback, sqlmode); err != nil {
return err
}
}
Expand Down

0 comments on commit 05b33a9

Please sign in to comment.