diff --git a/go/sqltypes/result.go b/go/sqltypes/result.go index 389b7fff620..c1f63a1488b 100644 --- a/go/sqltypes/result.go +++ b/go/sqltypes/result.go @@ -153,7 +153,6 @@ func (result *Result) Truncate(l int) *Result { out := &Result{ InsertID: result.InsertID, - RowsAffected: result.RowsAffected, Info: result.Info, SessionStateChanges: result.SessionStateChanges, } @@ -166,6 +165,7 @@ func (result *Result) Truncate(l int) *Result { out.Rows = append(out.Rows, r[:l]) } } + out.MergeStats(result) return out } @@ -324,17 +324,33 @@ func (result *Result) StripMetadata(incl querypb.ExecuteOptions_IncludedFields) // to another result.Note currently it doesn't handle cases like // if two results have different fields.We will enhance this function. func (result *Result) AppendResult(src *Result) { - if src.RowsAffected == 0 && len(src.Rows) == 0 && len(src.Fields) == 0 { + if src.StatsEmpty() && len(src.Rows) == 0 && len(src.Fields) == 0 { return } if result.Fields == nil { result.Fields = src.Fields } - result.RowsAffected += src.RowsAffected if src.InsertID != 0 { result.InsertID = src.InsertID } result.Rows = append(result.Rows, src.Rows...) + result.MergeStats(src) +} + +// Stats returns a copy of result with only the stats fields +func (result *Result) Stats() *Result { + return &Result{ + RowsAffected: result.RowsAffected, + } +} + +func (result *Result) StatsEmpty() bool { + return result.RowsAffected == 0 +} + +// MergeStats updates the receiver's stats by merging in the stats from src. +func (result *Result) MergeStats(src *Result) { + result.RowsAffected += src.RowsAffected } // Named returns a NamedResult based on this struct diff --git a/go/sqltypes/result_test.go b/go/sqltypes/result_test.go index d8075ec0633..0771f1f3825 100644 --- a/go/sqltypes/result_test.go +++ b/go/sqltypes/result_test.go @@ -348,6 +348,32 @@ func TestAppendResult(t *testing.T) { } } +func TestStats(t *testing.T) { + result := &Result{ + RowsAffected: 1, + Fields: []*querypb.Field{{ + Type: Int64, + }, { + Type: VarChar, + }}, + InsertID: 1, + Rows: [][]Value{ + {TestValue(Int64, "1"), MakeTrusted(VarChar, nil)}, + {TestValue(Int64, "2"), MakeTrusted(VarChar, nil)}, + }, + } + want := &Result{ + RowsAffected: 1, + } + assert.Equal(t, want, result.Stats()) +} + +func TestMergeStats(t *testing.T) { + result := &Result{RowsAffected: 1} + result.MergeStats(&Result{RowsAffected: 2}) + assert.Equal(t, uint64(3), result.RowsAffected) +} + func TestReplaceKeyspace(t *testing.T) { result := &Result{ Fields: []*querypb.Field{{ diff --git a/go/vt/vtgate/engine/concatenate.go b/go/vt/vtgate/engine/concatenate.go index 13727124e78..6bda880c45c 100644 --- a/go/vt/vtgate/engine/concatenate.go +++ b/go/vt/vtgate/engine/concatenate.go @@ -101,19 +101,19 @@ func (c *Concatenate) TryExecute(ctx context.Context, vcursor VCursor, bindVars return nil, err } - var rows [][]sqltypes.Value + out := &sqltypes.Result{ + Fields: fields, + } err = c.coerceAndVisitResults(res, fieldTypes, func(result *sqltypes.Result) error { - rows = append(rows, result.Rows...) + out.Rows = append(out.Rows, result.Rows...) + out.MergeStats(result) return nil }, evalengine.ParseSQLMode(vcursor.SQLMode())) if err != nil { return nil, err } - return &sqltypes.Result{ - Fields: fields, - Rows: rows, - }, nil + return out, nil } func (c *Concatenate) coerceValuesTo(row sqltypes.Row, fieldTypes []evalengine.Type, sqlmode evalengine.SQLMode) error { diff --git a/go/vt/vtgate/engine/concatenate_test.go b/go/vt/vtgate/engine/concatenate_test.go index dd2b1300e9b..463bd627dd4 100644 --- a/go/vt/vtgate/engine/concatenate_test.go +++ b/go/vt/vtgate/engine/concatenate_test.go @@ -46,6 +46,12 @@ func r(names, types string, rows ...string) *sqltypes.Result { return sqltypes.MakeTestResult(fields, rows...) } +func rWithStats(rowsAffected uint64, names, types string, rows ...string) *sqltypes.Result { + result := r(names, types, rows...) + result.RowsAffected = rowsAffected + return result +} + func TestConcatenate_NoErrors(t *testing.T) { type testCase struct { testName string @@ -108,6 +114,13 @@ func TestConcatenate_NoErrors(t *testing.T) { r("id|col1|col2", "int64|varchar|varbinary", "1|a1|b1", "2|a2|b2"), }, expectedResult: r("myid|mycol1|mycol2", "int64|varchar|varbinary", "1|a1|b1", "2|a2|b2"), + }, { + testName: "merged stats", + inputs: []*sqltypes.Result{ + rWithStats(1, "id|col1|col2", "int64|varbinary|varbinary", "1|a1|b1"), + rWithStats(2, "id|col1|col2", "int64|varbinary|varbinary", "2|a2|b2"), + }, + expectedResult: rWithStats(3, "id|col1|col2", "int64|varbinary|varbinary", "1|a1|b1", "2|a2|b2"), }} for _, tc := range testCases { @@ -130,6 +143,9 @@ func TestConcatenate_NoErrors(t *testing.T) { require.NoError(t, err) utils.MustMatch(t, tc.expectedResult.Fields, qr.Fields, "fields") utils.MustMatch(t, tc.expectedResult.Rows, qr.Rows) + + // Only testing stats match in non-streaming mode + utils.MustMatch(t, tc.expectedResult.Stats(), qr.Stats(), "stats") } else { require.Error(t, err) require.Contains(t, err.Error(), tc.expectedError) diff --git a/go/vt/vtgate/engine/distinct.go b/go/vt/vtgate/engine/distinct.go index 189440611c3..e48b7b7cfb9 100644 --- a/go/vt/vtgate/engine/distinct.go +++ b/go/vt/vtgate/engine/distinct.go @@ -111,6 +111,7 @@ func (d *Distinct) TryExecute(ctx context.Context, vcursor VCursor, bindVars map Fields: input.Fields, InsertID: input.InsertID, } + result.MergeStats(input) pt := newProbeTable(d.CheckCols, vcursor.Environment().CollationEnv()) diff --git a/go/vt/vtgate/engine/distinct_test.go b/go/vt/vtgate/engine/distinct_test.go index d7fe8786158..d8cc3131044 100644 --- a/go/vt/vtgate/engine/distinct_test.go +++ b/go/vt/vtgate/engine/distinct_test.go @@ -75,6 +75,11 @@ func TestDistinct(t *testing.T) { collations: []collations.ID{collations.CollationUtf8mb4ID, collations.Unknown}, inputs: r("myid|id", "varchar|int64", "monkey|1", "horse|1", "Horse|1", "Monkey|1", "horses|1", "MONKEY|2"), expectedResult: r("myid|id", "varchar|int64", "monkey|1", "horse|1", "horses|1", "MONKEY|2"), + }, { + testName: "merged stats", + collations: []collations.ID{collations.CollationUtf8mb4ID, collations.Unknown}, + inputs: rWithStats(10, "myid", "int64", "0", "1", "1", "null", "null"), + expectedResult: rWithStats(10, "myid", "int64", "0", "1", "null"), }} for _, tc := range testCases { @@ -107,6 +112,9 @@ func TestDistinct(t *testing.T) { got := fmt.Sprintf("%v", qr.Rows) expected := fmt.Sprintf("%v", tc.expectedResult.Rows) utils.MustMatch(t, expected, got, "result not what correct") + + // Only testing stats match in non-streaming mode + utils.MustMatch(t, tc.expectedResult.Stats(), qr.Stats(), "result stats did not match") } else { require.EqualError(t, err, tc.expectedError) } diff --git a/go/vt/vtgate/engine/dml_with_input.go b/go/vt/vtgate/engine/dml_with_input.go index e0eb3b03592..288dda603f5 100644 --- a/go/vt/vtgate/engine/dml_with_input.go +++ b/go/vt/vtgate/engine/dml_with_input.go @@ -83,7 +83,7 @@ func (dml *DMLWithInput) TryExecute(ctx context.Context, vcursor VCursor, bindVa if res == nil { res = qr } else { - res.RowsAffected += qr.RowsAffected + res.MergeStats(qr) } } return res, nil @@ -146,7 +146,7 @@ func executeNonLiteralUpdate(ctx context.Context, vcursor VCursor, bindVars map[ if res == nil { res = qr } else { - res.RowsAffected += qr.RowsAffected + res.MergeStats(qr) } } return res, nil diff --git a/go/vt/vtgate/engine/fk_verify.go b/go/vt/vtgate/engine/fk_verify.go index 7184e5d8381..3668792d840 100644 --- a/go/vt/vtgate/engine/fk_verify.go +++ b/go/vt/vtgate/engine/fk_verify.go @@ -69,6 +69,7 @@ func (f *FkVerify) GetFields(ctx context.Context, vcursor VCursor, bindVars map[ // TryExecute implements the Primitive interface func (f *FkVerify) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { + stats := &sqltypes.Result{} for _, v := range f.Verify { qr, err := vcursor.ExecutePrimitive(ctx, v.Exec, bindVars, wantfields) if err != nil { @@ -77,8 +78,14 @@ func (f *FkVerify) TryExecute(ctx context.Context, vcursor VCursor, bindVars map if len(qr.Rows) > 0 { return nil, getError(v.Typ) } + stats.MergeStats(qr) } - return vcursor.ExecutePrimitive(ctx, f.Exec, bindVars, wantfields) + + result, err := vcursor.ExecutePrimitive(ctx, f.Exec, bindVars, wantfields) + if result != nil { + result.MergeStats(stats) + } + return result, err } // TryStreamExecute implements the Primitive interface diff --git a/go/vt/vtgate/engine/fk_verify_test.go b/go/vt/vtgate/engine/fk_verify_test.go index 5c9ff83c2ec..a16246bd8dd 100644 --- a/go/vt/vtgate/engine/fk_verify_test.go +++ b/go/vt/vtgate/engine/fk_verify_test.go @@ -20,6 +20,8 @@ import ( "context" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "vitess.io/vitess/go/sqltypes" @@ -58,9 +60,13 @@ func TestFKVerifyUpdate(t *testing.T) { t.Run("foreign key verification success", func(t *testing.T) { fakeRes := sqltypes.MakeTestResult(sqltypes.MakeTestFields("1", "int64")) + fakeRes.RowsAffected = 1 vc := newDMLTestVCursor("0") - vc.results = []*sqltypes.Result{fakeRes} - _, err := fkc.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, true) + vc.results = []*sqltypes.Result{ + fakeRes, + {RowsAffected: 2}, + } + result, err := fkc.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, true) require.NoError(t, err) vc.ExpectLog(t, []string{ `ResolveDestinations ks [] Destinations:DestinationAllShards()`, @@ -68,6 +74,7 @@ func TestFKVerifyUpdate(t *testing.T) { `ResolveDestinations ks [] Destinations:DestinationAllShards()`, `ExecuteMultiShard ks.0: update child set cola = 1, colb = 'a' where foo = 48 {} true true`, }) + assert.Equal(t, uint64(3), result.RowsAffected) vc.Rewind() err = fkc.TryStreamExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, true, func(result *sqltypes.Result) error { return nil }) diff --git a/go/vt/vtgate/engine/hash_join.go b/go/vt/vtgate/engine/hash_join.go index 6b9425a35d1..254c55f0c84 100644 --- a/go/vt/vtgate/engine/hash_join.go +++ b/go/vt/vtgate/engine/hash_join.go @@ -115,6 +115,8 @@ func (hj *HashJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars ma result := &sqltypes.Result{ Fields: joinFields(lresult.Fields, rresult.Fields, hj.Cols), } + result.MergeStats(lresult) + result.MergeStats(rresult) for _, currentRHSRow := range rresult.Rows { matches, err := pt.get(currentRHSRow) diff --git a/go/vt/vtgate/engine/hash_join_test.go b/go/vt/vtgate/engine/hash_join_test.go index d3271c643be..39a3beb185d 100644 --- a/go/vt/vtgate/engine/hash_join_test.go +++ b/go/vt/vtgate/engine/hash_join_test.go @@ -32,7 +32,7 @@ func TestHashJoinVariations(t *testing.T) { // This test tries the different variations of hash-joins: // comparing values of same type and different types, and both left and right outer joins lhs := func() Primitive { - return &fakePrimitive{ + p := &fakePrimitive{ results: []*sqltypes.Result{ sqltypes.MakeTestResult( sqltypes.MakeTestFields( @@ -46,9 +46,12 @@ func TestHashJoinVariations(t *testing.T) { ), }, } + p.results[0].RowsAffected = 1 + return p } + rhs := func() Primitive { - return &fakePrimitive{ + p := &fakePrimitive{ results: []*sqltypes.Result{ sqltypes.MakeTestResult( sqltypes.MakeTestFields( @@ -62,6 +65,8 @@ func TestHashJoinVariations(t *testing.T) { ), }, } + p.results[0].RowsAffected = 2 + return p } rows := func(r ...string) []string { return r } @@ -131,6 +136,7 @@ func TestHashJoinVariations(t *testing.T) { } expected := sqltypes.MakeTestResult(fields, tc.expected...) + expected.RowsAffected = 3 typ, err := evalengine.CoerceTypes(typeForOffset(tc.lhs), typeForOffset(tc.rhs), collations.MySQL8()) require.NoError(t, err) @@ -157,6 +163,9 @@ func TestHashJoinVariations(t *testing.T) { jn.Right = last() r, err := wrapStreamExecute(jn, &noopVCursor{}, map[string]*querypb.BindVariable{}, true) require.NoError(t, err) + + // Result stats handling not implemented for streaming + expected.RowsAffected = 0 expectResultAnyOrder(t, r, expected) }) } diff --git a/go/vt/vtgate/engine/insert_select.go b/go/vt/vtgate/engine/insert_select.go index 88767420508..d6477118b63 100644 --- a/go/vt/vtgate/engine/insert_select.go +++ b/go/vt/vtgate/engine/insert_select.go @@ -306,15 +306,20 @@ func (ins *InsertSelect) buildVindexRowsValues(rows []sqltypes.Row) ([][]sqltype } func (ins *InsertSelect) execInsertSharded(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { - result, err := ins.execSelect(ctx, vcursor, bindVars) + selectResult, err := ins.execSelect(ctx, vcursor, bindVars) if err != nil { return nil, err } - if len(result.rows) == 0 { - return &sqltypes.Result{}, nil + if len(selectResult.rows) == 0 { + return selectResult.stats, nil + } + + result, err := ins.insertIntoShardedTable(ctx, vcursor, bindVars, selectResult) + if result != nil { + result.MergeStats(selectResult.stats) } - return ins.insertIntoShardedTable(ctx, vcursor, bindVars, result) + return result, err } func (ins *InsertSelect) description() PrimitiveDescription { @@ -369,6 +374,7 @@ func insertVarOffset(rowNum, colOffset int) string { type insertRowsResult struct { rows []sqltypes.Row insertID uint64 + stats *sqltypes.Result } func (ins *InsertSelect) execSelect( @@ -377,9 +383,12 @@ func (ins *InsertSelect) execSelect( bindVars map[string]*querypb.BindVariable, ) (insertRowsResult, error) { res, err := vcursor.ExecutePrimitive(ctx, ins.Input, bindVars, false) - if err != nil || len(res.Rows) == 0 { + if err != nil { return insertRowsResult{}, err } + if len(res.Rows) == 0 { + return insertRowsResult{stats: res.Stats()}, nil + } insertID, err := ins.processGenerateFromSelect(ctx, vcursor, ins, res.Rows) if err != nil { @@ -389,6 +398,7 @@ func (ins *InsertSelect) execSelect( return insertRowsResult{ rows: res.Rows, insertID: uint64(insertID), + stats: res.Stats(), }, nil } diff --git a/go/vt/vtgate/engine/insert_test.go b/go/vt/vtgate/engine/insert_test.go index 2de95e5d186..be82fe92162 100644 --- a/go/vt/vtgate/engine/insert_test.go +++ b/go/vt/vtgate/engine/insert_test.go @@ -21,6 +21,8 @@ import ( "errors" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "vitess.io/vitess/go/sqltypes" @@ -1703,9 +1705,14 @@ func TestInsertSelectSimple(t *testing.T) { "varchar|int64"), "a|1", "a|3", - "b|2")} + "b|2"), + {}, + } - _, err := ins.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) + vc.results[0].RowsAffected = 1 + vc.results[1].RowsAffected = 2 + + result, err := ins.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) require.NoError(t, err) vc.ExpectLog(t, []string{ `ResolveDestinations sharded [] Destinations:DestinationAllShards()`, @@ -1721,6 +1728,7 @@ func TestInsertSelectSimple(t *testing.T) { ` _c2_0: type:VARCHAR value:"b" _c2_1: type:INT64 value:"2"} ` + `sharded.-20: prefix values (:_c1_0, :_c1_1)` + ` {_c1_0: type:VARCHAR value:"a" _c1_1: type:INT64 value:"3"} true false`}) + assert.Equal(t, uint64(3), result.RowsAffected) vc.Rewind() err = ins.TryStreamExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false, func(result *sqltypes.Result) error { diff --git a/go/vt/vtgate/engine/join.go b/go/vt/vtgate/engine/join.go index 51976396cba..30d0a82ce7b 100644 --- a/go/vt/vtgate/engine/join.go +++ b/go/vt/vtgate/engine/join.go @@ -60,6 +60,7 @@ func (jn *Join) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[st return nil, err } result := &sqltypes.Result{} + result.MergeStats(lresult) if len(lresult.Rows) == 0 && wantfields { for k, col := range jn.Vars { joinVars[k] = bindvarForType(lresult.Fields[col]) @@ -86,6 +87,7 @@ func (jn *Join) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[st for _, rrow := range rresult.Rows { result.Rows = append(result.Rows, joinRows(lrow, rrow, jn.Cols)) } + result.MergeStats(rresult) if jn.Opcode == LeftJoin && len(rresult.Rows) == 0 { result.Rows = append(result.Rows, joinRows(lrow, nil, jn.Cols)) } diff --git a/go/vt/vtgate/engine/join_test.go b/go/vt/vtgate/engine/join_test.go index eef5810ce69..90c95792c40 100644 --- a/go/vt/vtgate/engine/join_test.go +++ b/go/vt/vtgate/engine/join_test.go @@ -21,6 +21,8 @@ import ( "errors" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "vitess.io/vitess/go/sqltypes" @@ -42,6 +44,7 @@ func TestJoinExecute(t *testing.T) { ), }, } + leftPrim.results[0].RowsAffected = 1 rightFields := sqltypes.MakeTestFields( "col4|col5|col6", "int64|varchar|varchar", @@ -63,6 +66,10 @@ func TestJoinExecute(t *testing.T) { ), }, } + rightPrim.results[0].RowsAffected = 2 + rightPrim.results[1].RowsAffected = 3 + rightPrim.results[2].RowsAffected = 4 + bv := map[string]*querypb.BindVariable{ "a": sqltypes.Int64BindVariable(10), } @@ -99,6 +106,7 @@ func TestJoinExecute(t *testing.T) { "3|c|6|f", "3|c|7|g", )) + assert.Equal(t, uint64(10), r.RowsAffected) // Left Join leftPrim.rewind() @@ -127,6 +135,7 @@ func TestJoinExecute(t *testing.T) { "3|c|6|f", "3|c|7|g", )) + assert.Equal(t, uint64(10), r.RowsAffected) } func TestJoinExecuteMaxMemoryRows(t *testing.T) { diff --git a/go/vt/vtgate/engine/lock.go b/go/vt/vtgate/engine/lock.go index 7739cbcd0cc..9886eedaeb0 100644 --- a/go/vt/vtgate/engine/lock.go +++ b/go/vt/vtgate/engine/lock.go @@ -89,6 +89,7 @@ func (l *Lock) execLock(ctx context.Context, vcursor VCursor, bindVars map[strin env := evalengine.NewExpressionEnv(ctx, bindVars, vcursor) var fields []*querypb.Field var rrow sqltypes.Row + stats := &sqltypes.Result{} for _, lf := range l.LockFunctions { var lName string if lf.Name != nil { @@ -105,6 +106,7 @@ func (l *Lock) execLock(ctx context.Context, vcursor VCursor, bindVars map[strin fields = append(fields, qr.Fields...) lockRes := qr.Rows[0] rrow = append(rrow, lockRes...) + stats.MergeStats(qr) switch lf.Typ.Type { case sqlparser.IsFreeLock, sqlparser.IsUsedLock: @@ -130,10 +132,12 @@ func (l *Lock) execLock(ctx context.Context, vcursor VCursor, bindVars map[strin } } } - return &sqltypes.Result{ + result := &sqltypes.Result{ Fields: fields, Rows: []sqltypes.Row{rrow}, - }, nil + } + result.MergeStats(stats) + return result, nil } func (lf *LockFunc) execLock(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, rs *srvtopo.ResolvedShard) (*sqltypes.Result, error) { diff --git a/go/vt/vtgate/engine/scalar_aggregation.go b/go/vt/vtgate/engine/scalar_aggregation.go index e33204f5c58..d27576bc3a6 100644 --- a/go/vt/vtgate/engine/scalar_aggregation.go +++ b/go/vt/vtgate/engine/scalar_aggregation.go @@ -100,6 +100,7 @@ func (sa *ScalarAggregate) TryExecute(ctx context.Context, vcursor VCursor, bind Fields: fields, Rows: [][]sqltypes.Value{agg.finish()}, } + out.MergeStats(result) return out.Truncate(sa.TruncateColumnCount), nil } diff --git a/go/vt/vtgate/engine/scalar_aggregation_test.go b/go/vt/vtgate/engine/scalar_aggregation_test.go index 6fa0c8aecb8..735ab8c5962 100644 --- a/go/vt/vtgate/engine/scalar_aggregation_test.go +++ b/go/vt/vtgate/engine/scalar_aggregation_test.go @@ -79,6 +79,7 @@ func TestEmptyRows(outer *testing.T) { // Empty input table )}, } + fp.results[0].RowsAffected = 1 oa := &ScalarAggregate{ Aggregates: []*AggregateParams{{ @@ -100,6 +101,8 @@ func TestEmptyRows(outer *testing.T) { ), test.expectedVal, ) + // Ensure Result stats are passed through + wantResult.RowsAffected = 1 utils.MustMatch(t, wantResult, result) }) } diff --git a/go/vt/vtgate/engine/semi_join.go b/go/vt/vtgate/engine/semi_join.go index f0dd0d09033..39a2b4b71a3 100644 --- a/go/vt/vtgate/engine/semi_join.go +++ b/go/vt/vtgate/engine/semi_join.go @@ -45,6 +45,7 @@ func (jn *SemiJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars ma return nil, err } result := &sqltypes.Result{Fields: lresult.Fields} + result.MergeStats(lresult) for _, lrow := range lresult.Rows { for k, col := range jn.Vars { joinVars[k] = sqltypes.ValueBindVariable(lrow[col]) @@ -53,6 +54,7 @@ func (jn *SemiJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars ma if err != nil { return nil, err } + result.MergeStats(rresult) if len(rresult.Rows) > 0 { result.Rows = append(result.Rows, lrow) } diff --git a/go/vt/vtgate/engine/semi_join_test.go b/go/vt/vtgate/engine/semi_join_test.go index 8fee0490415..172a4777ff1 100644 --- a/go/vt/vtgate/engine/semi_join_test.go +++ b/go/vt/vtgate/engine/semi_join_test.go @@ -43,6 +43,7 @@ func TestSemiJoinExecute(t *testing.T) { ), }, } + leftPrim.results[0].RowsAffected = 1 rightFields := sqltypes.MakeTestFields( "col4|col5|col6", "int64|varchar|varchar", @@ -64,6 +65,10 @@ func TestSemiJoinExecute(t *testing.T) { ), }, } + rightPrim.results[0].RowsAffected = 2 + rightPrim.results[1].RowsAffected = 3 + rightPrim.results[2].RowsAffected = 4 + bv := map[string]*querypb.BindVariable{ "a": sqltypes.Int64BindVariable(10), } @@ -85,14 +90,16 @@ func TestSemiJoinExecute(t *testing.T) { `Execute a: type:INT64 value:"10" bv: type:VARCHAR value:"b" false`, `Execute a: type:INT64 value:"10" bv: type:VARCHAR value:"c" false`, }) - utils.MustMatch(t, sqltypes.MakeTestResult( + want := sqltypes.MakeTestResult( sqltypes.MakeTestFields( "col1|col2|col3", "int64|varchar|varchar", ), "1|a|aa", "3|c|cc", - ), r) + ) + want.RowsAffected = 10 + utils.MustMatch(t, want, r) } func TestSemiJoinStreamExecute(t *testing.T) { diff --git a/go/vt/vtgate/engine/sequential.go b/go/vt/vtgate/engine/sequential.go index ecf74d663a2..53d6b16e81d 100644 --- a/go/vt/vtgate/engine/sequential.go +++ b/go/vt/vtgate/engine/sequential.go @@ -72,7 +72,7 @@ func (s *Sequential) TryExecute(ctx context.Context, vcursor VCursor, bindVars m if err != nil { return nil, err } - finalRes.RowsAffected += res.RowsAffected + finalRes.MergeStats(res) if finalRes.InsertID == 0 { finalRes.InsertID = res.InsertID } diff --git a/go/vt/vtgate/engine/simple_projection.go b/go/vt/vtgate/engine/simple_projection.go index 6edc5883be1..b6d5c23356a 100644 --- a/go/vt/vtgate/engine/simple_projection.go +++ b/go/vt/vtgate/engine/simple_projection.go @@ -107,7 +107,7 @@ func (sc *SimpleProjection) buildResult(inner *sqltypes.Result) *sqltypes.Result } result.Rows = append(result.Rows, row) } - result.RowsAffected = inner.RowsAffected + result.MergeStats(inner) return result } diff --git a/go/vt/vtgate/engine/simple_projection_test.go b/go/vt/vtgate/engine/simple_projection_test.go index 37c5a4d1dc0..2f39cc58ffd 100644 --- a/go/vt/vtgate/engine/simple_projection_test.go +++ b/go/vt/vtgate/engine/simple_projection_test.go @@ -21,6 +21,8 @@ import ( "errors" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "vitess.io/vitess/go/sqltypes" @@ -42,6 +44,7 @@ func TestSubqueryExecute(t *testing.T) { ), }, } + prim.results[0].RowsAffected = 1 sq := &SimpleProjection{ Cols: []int{0, 2}, @@ -69,6 +72,7 @@ func TestSubqueryExecute(t *testing.T) { "2|bb", "3|cc", )) + assert.Equal(t, uint64(1), r.RowsAffected) // Error case. sq.Input = &fakePrimitive{ diff --git a/go/vt/vtgate/engine/uncorrelated_subquery.go b/go/vt/vtgate/engine/uncorrelated_subquery.go index 311cd8d203a..dd760adb946 100644 --- a/go/vt/vtgate/engine/uncorrelated_subquery.go +++ b/go/vt/vtgate/engine/uncorrelated_subquery.go @@ -67,16 +67,21 @@ func (ps *UncorrelatedSubquery) GetTableName() string { // TryExecute satisfies the Primitive interface. func (ps *UncorrelatedSubquery) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { - combinedVars, err := ps.execSubquery(ctx, vcursor, bindVars) + stats := &sqltypes.Result{} + combinedVars, err := ps.execSubquery(ctx, vcursor, bindVars, stats) if err != nil { return nil, err } - return vcursor.ExecutePrimitive(ctx, ps.Outer, combinedVars, wantfields) + results, err := vcursor.ExecutePrimitive(ctx, ps.Outer, combinedVars, wantfields) + if results != nil { + results.MergeStats(stats) + } + return results, err } // TryStreamExecute performs a streaming exec. func (ps *UncorrelatedSubquery) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { - combinedVars, err := ps.execSubquery(ctx, vcursor, bindVars) + combinedVars, err := ps.execSubquery(ctx, vcursor, bindVars, &sqltypes.Result{}) if err != nil { return err } @@ -114,7 +119,7 @@ var ( errSqColumn = vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "subquery returned more than one column") ) -func (ps *UncorrelatedSubquery) execSubquery(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (map[string]*querypb.BindVariable, error) { +func (ps *UncorrelatedSubquery) execSubquery(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, stats *sqltypes.Result) (map[string]*querypb.BindVariable, error) { subqueryBindVars := make(map[string]*querypb.BindVariable, len(bindVars)) for k, v := range bindVars { subqueryBindVars[k] = v @@ -123,6 +128,7 @@ func (ps *UncorrelatedSubquery) execSubquery(ctx context.Context, vcursor VCurso if err != nil { return nil, err } + stats.MergeStats(result) combinedVars := make(map[string]*querypb.BindVariable, len(bindVars)+1) for k, v := range bindVars { combinedVars[k] = v diff --git a/go/vt/vtgate/engine/uncorrelated_subquery_test.go b/go/vt/vtgate/engine/uncorrelated_subquery_test.go index 085fe09238f..7b13da6a76c 100644 --- a/go/vt/vtgate/engine/uncorrelated_subquery_test.go +++ b/go/vt/vtgate/engine/uncorrelated_subquery_test.go @@ -21,6 +21,8 @@ import ( "errors" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "vitess.io/vitess/go/sqltypes" @@ -41,6 +43,7 @@ func TestPulloutSubqueryValueGood(t *testing.T) { ), "1", ) + sqResult.RowsAffected = 1 sfp := &fakePrimitive{ results: []*sqltypes.Result{sqResult}, } @@ -51,6 +54,7 @@ func TestPulloutSubqueryValueGood(t *testing.T) { ), "0", ) + underlyingResult.RowsAffected = 2 ufp := &fakePrimitive{ results: []*sqltypes.Result{underlyingResult}, } @@ -66,6 +70,7 @@ func TestPulloutSubqueryValueGood(t *testing.T) { sfp.ExpectLog(t, []string{`Execute aa: type:INT64 value:"1" false`}) ufp.ExpectLog(t, []string{`Execute aa: type:INT64 value:"1" sq: type:INT64 value:"1" false`}) expectResult(t, result, underlyingResult) + assert.Equal(t, uint64(3), result.RowsAffected) } func TestPulloutSubqueryValueNone(t *testing.T) { diff --git a/go/vt/vtgate/engine/upsert.go b/go/vt/vtgate/engine/upsert.go index 2e42452a7a4..ef69f85f486 100644 --- a/go/vt/vtgate/engine/upsert.go +++ b/go/vt/vtgate/engine/upsert.go @@ -83,7 +83,7 @@ func (u *Upsert) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[s if err != nil { return nil, err } - result.RowsAffected += qr.RowsAffected + result.MergeStats(qr) } return result, nil }