From 6c36526309cd103d6d9e4492412884e06b2a2a2b Mon Sep 17 00:00:00 2001 From: Manan Gupta Date: Fri, 23 Aug 2024 11:43:30 +0530 Subject: [PATCH 1/4] feat: fix fields passed in results from concatenate Signed-off-by: Manan Gupta --- go/vt/vtgate/engine/concatenate.go | 29 ++++++++++++++-------- go/vt/vtgate/engine/concatenate_test.go | 17 ++++++------- go/vt/vtgate/engine/fake_primitive_test.go | 10 ++++++++ 3 files changed, 36 insertions(+), 20 deletions(-) diff --git a/go/vt/vtgate/engine/concatenate.go b/go/vt/vtgate/engine/concatenate.go index 13727124e78..f6621be2192 100644 --- a/go/vt/vtgate/engine/concatenate.go +++ b/go/vt/vtgate/engine/concatenate.go @@ -102,7 +102,7 @@ func (c *Concatenate) TryExecute(ctx context.Context, vcursor VCursor, bindVars } var rows [][]sqltypes.Value - err = c.coerceAndVisitResults(res, fieldTypes, func(result *sqltypes.Result) error { + err = c.coerceAndVisitResults(res, fields, fieldTypes, func(result *sqltypes.Result) error { rows = append(rows, result.Rows...) return nil }, evalengine.ParseSQLMode(vcursor.SQLMode())) @@ -245,12 +245,13 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, // Mutexes for dealing with concurrent access to shared state. var ( - muCallback sync.Mutex // Protects callback - muFields sync.Mutex // Protects field state - condFields = sync.NewCond(&muFields) // Condition var for field arrival - wg errgroup.Group // Wait group for all streaming goroutines - rest = make([]*sqltypes.Result, len(c.Sources)) // Collects first result from each source to derive fields - fieldTypes []evalengine.Type // Cached final field types + muCallback sync.Mutex // Protects callback + muFields sync.Mutex // Protects field state + condFields = sync.NewCond(&muFields) // Condition var for field arrival + wg errgroup.Group // Wait group for all streaming goroutines + rest = make([]*sqltypes.Result, len(c.Sources)) // Collects first result from each source to derive fields + fieldTypes []evalengine.Type // Cached final field types + resultFields []*querypb.Field ) // Process each result chunk, considering type coercion. @@ -277,6 +278,9 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, } } } + if res.Fields != nil { + res.Fields = resultFields + } return in(res) } @@ -296,7 +300,7 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, if !slices.Contains(rest, nil) { // We have received fields from all sources. We can now calculate the output types var err error - resultChunk.Fields, fieldTypes, err = c.getFieldTypes(vcursor, rest) + resultFields, fieldTypes, err = c.getFieldTypes(vcursor, rest) if err != nil { muFields.Unlock() return err @@ -310,6 +314,7 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, // Wait for fields from all sources. for slices.Contains(rest, nil) { + // This wait call lets go of the muFields lock and acquires it again later after waiting. condFields.Wait() } muFields.Unlock() @@ -368,12 +373,12 @@ func (c *Concatenate) sequentialStreamExec(ctx context.Context, vcursor VCursor, firsts[i] = result[0] } - _, fieldTypes, err := c.getFieldTypes(vcursor, firsts) + fields, fieldTypes, err := c.getFieldTypes(vcursor, firsts) if err != nil { return err } for _, res := range results { - if err = c.coerceAndVisitResults(res, fieldTypes, callback, sqlmode); err != nil { + if err = c.coerceAndVisitResults(res, fields, fieldTypes, callback, sqlmode); err != nil { return err } } @@ -383,6 +388,7 @@ func (c *Concatenate) sequentialStreamExec(ctx context.Context, vcursor VCursor, func (c *Concatenate) coerceAndVisitResults( res []*sqltypes.Result, + fields []*querypb.Field, fieldTypes []evalengine.Type, callback func(*sqltypes.Result) error, sqlmode evalengine.SQLMode, @@ -408,6 +414,9 @@ func (c *Concatenate) coerceAndVisitResults( } } } + if r.Fields != nil { + r.Fields = fields + } err := callback(r) if err != nil { return err diff --git a/go/vt/vtgate/engine/concatenate_test.go b/go/vt/vtgate/engine/concatenate_test.go index dd2b1300e9b..39b9ed961b3 100644 --- a/go/vt/vtgate/engine/concatenate_test.go +++ b/go/vt/vtgate/engine/concatenate_test.go @@ -124,27 +124,24 @@ func TestConcatenate_NoErrors(t *testing.T) { if !tx { txStr = "NotInTx" } - t.Run(fmt.Sprintf("%s-%s-Exec", txStr, tc.testName), func(t *testing.T) { - qr, err := concatenate.TryExecute(context.Background(), vcursor, nil, true) + checkResult := func(t *testing.T, qr *sqltypes.Result, err error) { if tc.expectedError == "" { require.NoError(t, err) utils.MustMatch(t, tc.expectedResult.Fields, qr.Fields, "fields") - utils.MustMatch(t, tc.expectedResult.Rows, qr.Rows) + require.NoError(t, sqltypes.RowsEquals(tc.expectedResult.Rows, qr.Rows)) } else { require.Error(t, err) require.Contains(t, err.Error(), tc.expectedError) } + } + t.Run(fmt.Sprintf("%s-%s-Exec", txStr, tc.testName), func(t *testing.T) { + qr, err := concatenate.TryExecute(context.Background(), vcursor, nil, true) + checkResult(t, qr, err) }) t.Run(fmt.Sprintf("%s-%s-StreamExec", txStr, tc.testName), func(t *testing.T) { qr, err := wrapStreamExecute(concatenate, vcursor, nil, true) - if tc.expectedError == "" { - require.NoError(t, err) - require.NoError(t, sqltypes.RowsEquals(tc.expectedResult.Rows, qr.Rows)) - } else { - require.Error(t, err) - require.Contains(t, err.Error(), tc.expectedError) - } + checkResult(t, qr, err) }) } } diff --git a/go/vt/vtgate/engine/fake_primitive_test.go b/go/vt/vtgate/engine/fake_primitive_test.go index 6ab54fe9e7b..f1aa757e62d 100644 --- a/go/vt/vtgate/engine/fake_primitive_test.go +++ b/go/vt/vtgate/engine/fake_primitive_test.go @@ -24,6 +24,7 @@ import ( "testing" "golang.org/x/sync/errgroup" + "google.golang.org/protobuf/proto" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" @@ -188,6 +189,15 @@ func wrapStreamExecute(prim Primitive, vcursor VCursor, bindVars map[string]*que if result == nil { result = r } else { + if r.Fields != nil { + for i, field := range r.Fields { + aField := field + bField := result.Fields[i] + if !proto.Equal(aField, bField) { + return fmt.Errorf("fields differ: %s <> %s", aField.String(), bField.String()) + } + } + } result.Rows = append(result.Rows, r.Rows...) } return nil From 84461a7194783cb28693d45209b0b3561544fb18 Mon Sep 17 00:00:00 2001 From: Manan Gupta Date: Fri, 23 Aug 2024 13:08:41 +0530 Subject: [PATCH 2/4] feat: fix coercion for sequential streaming execute Signed-off-by: Manan Gupta --- go/vt/vtgate/engine/concatenate.go | 32 +++++++++++++--------- go/vt/vtgate/engine/fake_primitive_test.go | 2 +- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/go/vt/vtgate/engine/concatenate.go b/go/vt/vtgate/engine/concatenate.go index f6621be2192..1098d07ecc4 100644 --- a/go/vt/vtgate/engine/concatenate.go +++ b/go/vt/vtgate/engine/concatenate.go @@ -102,12 +102,14 @@ func (c *Concatenate) TryExecute(ctx context.Context, vcursor VCursor, bindVars } var rows [][]sqltypes.Value - err = c.coerceAndVisitResults(res, fields, fieldTypes, func(result *sqltypes.Result) error { + callback := func(result *sqltypes.Result) error { rows = append(rows, result.Rows...) return nil - }, evalengine.ParseSQLMode(vcursor.SQLMode())) - if err != nil { - return nil, err + } + for _, r := range res { + if err = c.coerceAndVisitResultsForOneSource([]*sqltypes.Result{r}, fields, fieldTypes, callback, evalengine.ParseSQLMode(vcursor.SQLMode())); err != nil { + return nil, err + } } return &sqltypes.Result{ @@ -378,7 +380,7 @@ func (c *Concatenate) sequentialStreamExec(ctx context.Context, vcursor VCursor, return err } for _, res := range results { - if err = c.coerceAndVisitResults(res, fields, fieldTypes, callback, sqlmode); err != nil { + if err = c.coerceAndVisitResultsForOneSource(res, fields, fieldTypes, callback, sqlmode); err != nil { return err } } @@ -386,26 +388,30 @@ func (c *Concatenate) sequentialStreamExec(ctx context.Context, vcursor VCursor, return nil } -func (c *Concatenate) coerceAndVisitResults( +func (c *Concatenate) coerceAndVisitResultsForOneSource( res []*sqltypes.Result, fields []*querypb.Field, fieldTypes []evalengine.Type, callback func(*sqltypes.Result) error, sqlmode evalengine.SQLMode, ) error { + if len(res) == 0 { + return nil + } + needsCoercion := false + for idx, field := range res[0].Fields { + if fieldTypes[idx].Type() != field.Type { + needsCoercion = true + break + } + } + for _, r := range res { if len(r.Rows) > 0 && len(fieldTypes) != len(r.Rows[0]) { return errWrongNumberOfColumnsInSelect } - needsCoercion := false - for idx, field := range r.Fields { - if fieldTypes[idx].Type() != field.Type { - needsCoercion = true - break - } - } if needsCoercion { for _, row := range r.Rows { err := c.coerceValuesTo(row, fieldTypes, sqlmode) diff --git a/go/vt/vtgate/engine/fake_primitive_test.go b/go/vt/vtgate/engine/fake_primitive_test.go index f1aa757e62d..b878c1931c0 100644 --- a/go/vt/vtgate/engine/fake_primitive_test.go +++ b/go/vt/vtgate/engine/fake_primitive_test.go @@ -112,7 +112,7 @@ func (f *fakePrimitive) syncCall(wantfields bool, callback func(*sqltypes.Result } result := &sqltypes.Result{} for i := 0; i < len(r.Rows); i++ { - result.Rows = append(result.Rows, r.Rows[i]) + result.Rows = append(result.Rows, sqltypes.CopyRow(r.Rows[i])) // Send only two rows at a time. if i%2 == 1 { if err := callback(result); err != nil { From b84cc3614ebb938af556f5235edea5a0fc1fd58b Mon Sep 17 00:00:00 2001 From: Manan Gupta Date: Fri, 23 Aug 2024 13:50:25 +0530 Subject: [PATCH 3/4] feat: fix the flakiness wherein one source might not have its values coerced Signed-off-by: Manan Gupta --- go/vt/vtgate/engine/concatenate.go | 42 ++++++++++++++++-------------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/go/vt/vtgate/engine/concatenate.go b/go/vt/vtgate/engine/concatenate.go index 1098d07ecc4..8d527d84033 100644 --- a/go/vt/vtgate/engine/concatenate.go +++ b/go/vt/vtgate/engine/concatenate.go @@ -247,13 +247,14 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, // Mutexes for dealing with concurrent access to shared state. var ( - muCallback sync.Mutex // Protects callback - muFields sync.Mutex // Protects field state - condFields = sync.NewCond(&muFields) // Condition var for field arrival - wg errgroup.Group // Wait group for all streaming goroutines - rest = make([]*sqltypes.Result, len(c.Sources)) // Collects first result from each source to derive fields - fieldTypes []evalengine.Type // Cached final field types - resultFields []*querypb.Field + muCallback sync.Mutex // Protects callback + muFields sync.Mutex // Protects field state + condFields = sync.NewCond(&muFields) // Condition var for field arrival + wg errgroup.Group // Wait group for all streaming goroutines + rest = make([]*sqltypes.Result, len(c.Sources)) // Collects first result from each source to derive fields + fieldTypes []evalengine.Type // Cached final field types + resultFields []*querypb.Field // Final fields that need to be set for any result having fields. + needsCoercion = make([]bool, len(c.Sources)) // Tracks if coercion is needed for each individual source ) // Process each result chunk, considering type coercion. @@ -261,19 +262,8 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, muCallback.Lock() defer muCallback.Unlock() - // Check if type coercion needed for this source. - // We only need to check if fields are not in NoNeedToTypeCheck set. - needsCoercion := false - for idx, field := range rest[srcIdx].Fields { - _, skip := c.NoNeedToTypeCheck[idx] - if !skip && fieldTypes[idx].Type() != field.Type { - needsCoercion = true - break - } - } - // Apply type coercion if needed. - if needsCoercion { + if needsCoercion[srcIdx] { for _, row := range res.Rows { if err := c.coerceValuesTo(row, fieldTypes, sqlmode); err != nil { return err @@ -308,6 +298,20 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, return err } + // Check if we need coercion for each source. + for srcIdx, result := range rest { + srcNeedsCoercion := false + for idx, field := range result.Fields { + _, skip := c.NoNeedToTypeCheck[idx] + // We only need to check if fields are not in NoNeedToTypeCheck set. + if !skip && fieldTypes[idx].Type() != field.Type { + srcNeedsCoercion = true + break + } + } + needsCoercion[srcIdx] = srcNeedsCoercion + } + muFields.Unlock() defer condFields.Broadcast() return callback(resultChunk, currIndex) From bcec73c35e9133b88ecd6e6881f6ec2de14b9024 Mon Sep 17 00:00:00 2001 From: Manan Gupta Date: Mon, 26 Aug 2024 11:01:30 +0530 Subject: [PATCH 4/4] feat: address review comments Signed-off-by: Manan Gupta --- go/vt/vtgate/engine/concatenate.go | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/go/vt/vtgate/engine/concatenate.go b/go/vt/vtgate/engine/concatenate.go index 8d527d84033..eb93711eed2 100644 --- a/go/vt/vtgate/engine/concatenate.go +++ b/go/vt/vtgate/engine/concatenate.go @@ -253,7 +253,7 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, wg errgroup.Group // Wait group for all streaming goroutines rest = make([]*sqltypes.Result, len(c.Sources)) // Collects first result from each source to derive fields fieldTypes []evalengine.Type // Cached final field types - resultFields []*querypb.Field // Final fields that need to be set for any result having fields. + resultFields []*querypb.Field // Final fields that need to be set for the first result. needsCoercion = make([]bool, len(c.Sources)) // Tracks if coercion is needed for each individual source ) @@ -270,9 +270,6 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, } } } - if res.Fields != nil { - res.Fields = resultFields - } return in(res) } @@ -312,6 +309,9 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, needsCoercion[srcIdx] = srcNeedsCoercion } + // We only need to send the fields in the first result. + // We set this field after the coercion check to avoid calculating incorrect needs coercion value. + resultChunk.Fields = resultFields muFields.Unlock() defer condFields.Broadcast() return callback(resultChunk, currIndex) @@ -323,6 +323,8 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, // This wait call lets go of the muFields lock and acquires it again later after waiting. condFields.Wait() } + // We only need to send fields in the first result. + resultChunk.Fields = nil muFields.Unlock() // Context check to avoid extra work. @@ -409,6 +411,9 @@ func (c *Concatenate) coerceAndVisitResultsForOneSource( break } } + if res[0].Fields != nil { + res[0].Fields = fields + } for _, r := range res { if len(r.Rows) > 0 && @@ -424,9 +429,6 @@ func (c *Concatenate) coerceAndVisitResultsForOneSource( } } } - if r.Fields != nil { - r.Fields = fields - } err := callback(r) if err != nil { return err