Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race conditions in the concatenate engine streaming #16640

Merged
merged 4 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 54 additions & 33 deletions go/vt/vtgate/engine/concatenate.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,14 @@ func (c *Concatenate) TryExecute(ctx context.Context, vcursor VCursor, bindVars
}

var rows [][]sqltypes.Value
err = c.coerceAndVisitResults(res, fieldTypes, func(result *sqltypes.Result) error {
callback := func(result *sqltypes.Result) error {
rows = append(rows, result.Rows...)
return nil
}, evalengine.ParseSQLMode(vcursor.SQLMode()))
if err != nil {
return nil, err
}
for _, r := range res {
if err = c.coerceAndVisitResultsForOneSource([]*sqltypes.Result{r}, fields, fieldTypes, callback, evalengine.ParseSQLMode(vcursor.SQLMode())); err != nil {
return nil, err
}
}

return &sqltypes.Result{
Expand Down Expand Up @@ -245,32 +247,23 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor,

// Mutexes for dealing with concurrent access to shared state.
var (
muCallback sync.Mutex // Protects callback
muFields sync.Mutex // Protects field state
condFields = sync.NewCond(&muFields) // Condition var for field arrival
wg errgroup.Group // Wait group for all streaming goroutines
rest = make([]*sqltypes.Result, len(c.Sources)) // Collects first result from each source to derive fields
fieldTypes []evalengine.Type // Cached final field types
muCallback sync.Mutex // Protects callback
muFields sync.Mutex // Protects field state
condFields = sync.NewCond(&muFields) // Condition var for field arrival
wg errgroup.Group // Wait group for all streaming goroutines
rest = make([]*sqltypes.Result, len(c.Sources)) // Collects first result from each source to derive fields
fieldTypes []evalengine.Type // Cached final field types
resultFields []*querypb.Field // Final fields that need to be set for the first result.
needsCoercion = make([]bool, len(c.Sources)) // Tracks if coercion is needed for each individual source
)

// Process each result chunk, considering type coercion.
callback := func(res *sqltypes.Result, srcIdx int) error {
muCallback.Lock()
defer muCallback.Unlock()

// Check if type coercion needed for this source.
// We only need to check if fields are not in NoNeedToTypeCheck set.
needsCoercion := false
for idx, field := range rest[srcIdx].Fields {
_, skip := c.NoNeedToTypeCheck[idx]
if !skip && fieldTypes[idx].Type() != field.Type {
needsCoercion = true
break
}
}

// Apply type coercion if needed.
if needsCoercion {
if needsCoercion[srcIdx] {
for _, row := range res.Rows {
if err := c.coerceValuesTo(row, fieldTypes, sqlmode); err != nil {
return err
Expand All @@ -296,12 +289,29 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor,
if !slices.Contains(rest, nil) {
// We have received fields from all sources. We can now calculate the output types
var err error
resultChunk.Fields, fieldTypes, err = c.getFieldTypes(vcursor, rest)
resultFields, fieldTypes, err = c.getFieldTypes(vcursor, rest)
if err != nil {
muFields.Unlock()
return err
}

// Check if we need coercion for each source.
for srcIdx, result := range rest {
srcNeedsCoercion := false
for idx, field := range result.Fields {
_, skip := c.NoNeedToTypeCheck[idx]
// We only need to check if fields are not in NoNeedToTypeCheck set.
if !skip && fieldTypes[idx].Type() != field.Type {
srcNeedsCoercion = true
break
}
}
needsCoercion[srcIdx] = srcNeedsCoercion
}

// We only need to send the fields in the first result.
// We set this field after the coercion check to avoid calculating incorrect needs coercion value.
resultChunk.Fields = resultFields
muFields.Unlock()
defer condFields.Broadcast()
return callback(resultChunk, currIndex)
GuptaManan100 marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -310,8 +320,11 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor,

// Wait for fields from all sources.
for slices.Contains(rest, nil) {
// This wait call lets go of the muFields lock and acquires it again later after waiting.
condFields.Wait()
}
// We only need to send fields in the first result.
resultChunk.Fields = nil
muFields.Unlock()

// Context check to avoid extra work.
Expand Down Expand Up @@ -368,38 +381,46 @@ func (c *Concatenate) sequentialStreamExec(ctx context.Context, vcursor VCursor,
firsts[i] = result[0]
}

_, fieldTypes, err := c.getFieldTypes(vcursor, firsts)
fields, fieldTypes, err := c.getFieldTypes(vcursor, firsts)
if err != nil {
return err
}
for _, res := range results {
if err = c.coerceAndVisitResults(res, fieldTypes, callback, sqlmode); err != nil {
if err = c.coerceAndVisitResultsForOneSource(res, fields, fieldTypes, callback, sqlmode); err != nil {
return err
}
}

return nil
}

func (c *Concatenate) coerceAndVisitResults(
func (c *Concatenate) coerceAndVisitResultsForOneSource(
res []*sqltypes.Result,
fields []*querypb.Field,
fieldTypes []evalengine.Type,
callback func(*sqltypes.Result) error,
sqlmode evalengine.SQLMode,
) error {
if len(res) == 0 {
return nil
}
needsCoercion := false
for idx, field := range res[0].Fields {
if fieldTypes[idx].Type() != field.Type {
needsCoercion = true
break
}
}
if res[0].Fields != nil {
res[0].Fields = fields
}

for _, r := range res {
if len(r.Rows) > 0 &&
len(fieldTypes) != len(r.Rows[0]) {
return errWrongNumberOfColumnsInSelect
}

needsCoercion := false
for idx, field := range r.Fields {
if fieldTypes[idx].Type() != field.Type {
needsCoercion = true
break
}
}
if needsCoercion {
for _, row := range r.Rows {
err := c.coerceValuesTo(row, fieldTypes, sqlmode)
Expand Down
17 changes: 7 additions & 10 deletions go/vt/vtgate/engine/concatenate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,27 +124,24 @@ func TestConcatenate_NoErrors(t *testing.T) {
if !tx {
txStr = "NotInTx"
}
t.Run(fmt.Sprintf("%s-%s-Exec", txStr, tc.testName), func(t *testing.T) {
qr, err := concatenate.TryExecute(context.Background(), vcursor, nil, true)
checkResult := func(t *testing.T, qr *sqltypes.Result, err error) {
if tc.expectedError == "" {
require.NoError(t, err)
utils.MustMatch(t, tc.expectedResult.Fields, qr.Fields, "fields")
utils.MustMatch(t, tc.expectedResult.Rows, qr.Rows)
require.NoError(t, sqltypes.RowsEquals(tc.expectedResult.Rows, qr.Rows))
} else {
require.Error(t, err)
require.Contains(t, err.Error(), tc.expectedError)
}
}
t.Run(fmt.Sprintf("%s-%s-Exec", txStr, tc.testName), func(t *testing.T) {
qr, err := concatenate.TryExecute(context.Background(), vcursor, nil, true)
checkResult(t, qr, err)
})

t.Run(fmt.Sprintf("%s-%s-StreamExec", txStr, tc.testName), func(t *testing.T) {
qr, err := wrapStreamExecute(concatenate, vcursor, nil, true)
if tc.expectedError == "" {
require.NoError(t, err)
require.NoError(t, sqltypes.RowsEquals(tc.expectedResult.Rows, qr.Rows))
} else {
require.Error(t, err)
require.Contains(t, err.Error(), tc.expectedError)
}
checkResult(t, qr, err)
})
}
}
Expand Down
12 changes: 11 additions & 1 deletion go/vt/vtgate/engine/fake_primitive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"testing"

"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/proto"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
Expand Down Expand Up @@ -111,7 +112,7 @@ func (f *fakePrimitive) syncCall(wantfields bool, callback func(*sqltypes.Result
}
result := &sqltypes.Result{}
for i := 0; i < len(r.Rows); i++ {
result.Rows = append(result.Rows, r.Rows[i])
result.Rows = append(result.Rows, sqltypes.CopyRow(r.Rows[i]))
// Send only two rows at a time.
if i%2 == 1 {
if err := callback(result); err != nil {
Expand Down Expand Up @@ -188,6 +189,15 @@ func wrapStreamExecute(prim Primitive, vcursor VCursor, bindVars map[string]*que
if result == nil {
result = r
} else {
if r.Fields != nil {
for i, field := range r.Fields {
aField := field
bField := result.Fields[i]
if !proto.Equal(aField, bField) {
return fmt.Errorf("fields differ: %s <> %s", aField.String(), bField.String())
}
}
}
result.Rows = append(result.Rows, r.Rows...)
}
return nil
Expand Down
Loading