Skip to content

Commit

Permalink
Fix race conditions in the concatenate engine streaming (#16640)
Browse files Browse the repository at this point in the history
Signed-off-by: Manan Gupta <[email protected]>
  • Loading branch information
GuptaManan100 authored Aug 26, 2024
1 parent eb11918 commit d95e36f
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 44 deletions.
87 changes: 54 additions & 33 deletions go/vt/vtgate/engine/concatenate.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,14 @@ func (c *Concatenate) TryExecute(ctx context.Context, vcursor VCursor, bindVars
}

var rows [][]sqltypes.Value
err = c.coerceAndVisitResults(res, fieldTypes, func(result *sqltypes.Result) error {
callback := func(result *sqltypes.Result) error {
rows = append(rows, result.Rows...)
return nil
}, evalengine.ParseSQLMode(vcursor.SQLMode()))
if err != nil {
return nil, err
}
for _, r := range res {
if err = c.coerceAndVisitResultsForOneSource([]*sqltypes.Result{r}, fields, fieldTypes, callback, evalengine.ParseSQLMode(vcursor.SQLMode())); err != nil {
return nil, err
}
}

return &sqltypes.Result{
Expand Down Expand Up @@ -245,32 +247,23 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor,

// Mutexes for dealing with concurrent access to shared state.
var (
muCallback sync.Mutex // Protects callback
muFields sync.Mutex // Protects field state
condFields = sync.NewCond(&muFields) // Condition var for field arrival
wg errgroup.Group // Wait group for all streaming goroutines
rest = make([]*sqltypes.Result, len(c.Sources)) // Collects first result from each source to derive fields
fieldTypes []evalengine.Type // Cached final field types
muCallback sync.Mutex // Protects callback
muFields sync.Mutex // Protects field state
condFields = sync.NewCond(&muFields) // Condition var for field arrival
wg errgroup.Group // Wait group for all streaming goroutines
rest = make([]*sqltypes.Result, len(c.Sources)) // Collects first result from each source to derive fields
fieldTypes []evalengine.Type // Cached final field types
resultFields []*querypb.Field // Final fields that need to be set for the first result.
needsCoercion = make([]bool, len(c.Sources)) // Tracks if coercion is needed for each individual source
)

// Process each result chunk, considering type coercion.
callback := func(res *sqltypes.Result, srcIdx int) error {
muCallback.Lock()
defer muCallback.Unlock()

// Check if type coercion needed for this source.
// We only need to check if fields are not in NoNeedToTypeCheck set.
needsCoercion := false
for idx, field := range rest[srcIdx].Fields {
_, skip := c.NoNeedToTypeCheck[idx]
if !skip && fieldTypes[idx].Type() != field.Type {
needsCoercion = true
break
}
}

// Apply type coercion if needed.
if needsCoercion {
if needsCoercion[srcIdx] {
for _, row := range res.Rows {
if err := c.coerceValuesTo(row, fieldTypes, sqlmode); err != nil {
return err
Expand All @@ -296,12 +289,29 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor,
if !slices.Contains(rest, nil) {
// We have received fields from all sources. We can now calculate the output types
var err error
resultChunk.Fields, fieldTypes, err = c.getFieldTypes(vcursor, rest)
resultFields, fieldTypes, err = c.getFieldTypes(vcursor, rest)
if err != nil {
muFields.Unlock()
return err
}

// Check if we need coercion for each source.
for srcIdx, result := range rest {
srcNeedsCoercion := false
for idx, field := range result.Fields {
_, skip := c.NoNeedToTypeCheck[idx]
// We only need to check if fields are not in NoNeedToTypeCheck set.
if !skip && fieldTypes[idx].Type() != field.Type {
srcNeedsCoercion = true
break
}
}
needsCoercion[srcIdx] = srcNeedsCoercion
}

// We only need to send the fields in the first result.
// We set this field after the coercion check to avoid calculating incorrect needs coercion value.
resultChunk.Fields = resultFields
muFields.Unlock()
defer condFields.Broadcast()
return callback(resultChunk, currIndex)
Expand All @@ -310,8 +320,11 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor,

// Wait for fields from all sources.
for slices.Contains(rest, nil) {
// This wait call lets go of the muFields lock and acquires it again later after waiting.
condFields.Wait()
}
// We only need to send fields in the first result.
resultChunk.Fields = nil
muFields.Unlock()

// Context check to avoid extra work.
Expand Down Expand Up @@ -368,38 +381,46 @@ func (c *Concatenate) sequentialStreamExec(ctx context.Context, vcursor VCursor,
firsts[i] = result[0]
}

_, fieldTypes, err := c.getFieldTypes(vcursor, firsts)
fields, fieldTypes, err := c.getFieldTypes(vcursor, firsts)
if err != nil {
return err
}
for _, res := range results {
if err = c.coerceAndVisitResults(res, fieldTypes, callback, sqlmode); err != nil {
if err = c.coerceAndVisitResultsForOneSource(res, fields, fieldTypes, callback, sqlmode); err != nil {
return err
}
}

return nil
}

func (c *Concatenate) coerceAndVisitResults(
func (c *Concatenate) coerceAndVisitResultsForOneSource(
res []*sqltypes.Result,
fields []*querypb.Field,
fieldTypes []evalengine.Type,
callback func(*sqltypes.Result) error,
sqlmode evalengine.SQLMode,
) error {
if len(res) == 0 {
return nil
}
needsCoercion := false
for idx, field := range res[0].Fields {
if fieldTypes[idx].Type() != field.Type {
needsCoercion = true
break
}
}
if res[0].Fields != nil {
res[0].Fields = fields
}

for _, r := range res {
if len(r.Rows) > 0 &&
len(fieldTypes) != len(r.Rows[0]) {
return errWrongNumberOfColumnsInSelect
}

needsCoercion := false
for idx, field := range r.Fields {
if fieldTypes[idx].Type() != field.Type {
needsCoercion = true
break
}
}
if needsCoercion {
for _, row := range r.Rows {
err := c.coerceValuesTo(row, fieldTypes, sqlmode)
Expand Down
17 changes: 7 additions & 10 deletions go/vt/vtgate/engine/concatenate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,27 +124,24 @@ func TestConcatenate_NoErrors(t *testing.T) {
if !tx {
txStr = "NotInTx"
}
t.Run(fmt.Sprintf("%s-%s-Exec", txStr, tc.testName), func(t *testing.T) {
qr, err := concatenate.TryExecute(context.Background(), vcursor, nil, true)
checkResult := func(t *testing.T, qr *sqltypes.Result, err error) {
if tc.expectedError == "" {
require.NoError(t, err)
utils.MustMatch(t, tc.expectedResult.Fields, qr.Fields, "fields")
utils.MustMatch(t, tc.expectedResult.Rows, qr.Rows)
require.NoError(t, sqltypes.RowsEquals(tc.expectedResult.Rows, qr.Rows))
} else {
require.Error(t, err)
require.Contains(t, err.Error(), tc.expectedError)
}
}
t.Run(fmt.Sprintf("%s-%s-Exec", txStr, tc.testName), func(t *testing.T) {
qr, err := concatenate.TryExecute(context.Background(), vcursor, nil, true)
checkResult(t, qr, err)
})

t.Run(fmt.Sprintf("%s-%s-StreamExec", txStr, tc.testName), func(t *testing.T) {
qr, err := wrapStreamExecute(concatenate, vcursor, nil, true)
if tc.expectedError == "" {
require.NoError(t, err)
require.NoError(t, sqltypes.RowsEquals(tc.expectedResult.Rows, qr.Rows))
} else {
require.Error(t, err)
require.Contains(t, err.Error(), tc.expectedError)
}
checkResult(t, qr, err)
})
}
}
Expand Down
12 changes: 11 additions & 1 deletion go/vt/vtgate/engine/fake_primitive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"testing"

"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/proto"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
Expand Down Expand Up @@ -111,7 +112,7 @@ func (f *fakePrimitive) syncCall(wantfields bool, callback func(*sqltypes.Result
}
result := &sqltypes.Result{}
for i := 0; i < len(r.Rows); i++ {
result.Rows = append(result.Rows, r.Rows[i])
result.Rows = append(result.Rows, sqltypes.CopyRow(r.Rows[i]))
// Send only two rows at a time.
if i%2 == 1 {
if err := callback(result); err != nil {
Expand Down Expand Up @@ -188,6 +189,15 @@ func wrapStreamExecute(prim Primitive, vcursor VCursor, bindVars map[string]*que
if result == nil {
result = r
} else {
if r.Fields != nil {
for i, field := range r.Fields {
aField := field
bField := result.Fields[i]
if !proto.Equal(aField, bField) {
return fmt.Errorf("fields differ: %s <> %s", aField.String(), bField.String())
}
}
}
result.Rows = append(result.Rows, r.Rows...)
}
return nil
Expand Down

0 comments on commit d95e36f

Please sign in to comment.