diff --git a/go/vt/vtgate/engine/limit.go b/go/vt/vtgate/engine/limit.go index 824689d2859..a142fc8274c 100644 --- a/go/vt/vtgate/engine/limit.go +++ b/go/vt/vtgate/engine/limit.go @@ -89,6 +89,10 @@ func (l *Limit) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[st return result, nil } +func (l *Limit) mustRetrieveAll(vcursor VCursor) bool { + return l.RequireCompleteInput || vcursor.Session().InTransaction() +} + // TryStreamExecute satisfies the Primitive interface. func (l *Limit) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { count, offset, err := l.getCountAndOffset(ctx, vcursor, bindVars) @@ -107,22 +111,31 @@ func (l *Limit) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars mu.Lock() defer mu.Unlock() + inputSize := len(qr.Rows) + if inputSize == 0 { + if wantfields && len(qr.Fields) != 0 { + wantfields = false + } + return callback(qr) + } + // If this is the first callback and fields are requested, send the fields immediately. if wantfields && len(qr.Fields) != 0 { + wantfields = false + // otherwise, we need to send the fields first, and then the rows if err := callback(&sqltypes.Result{Fields: qr.Fields}); err != nil { return err } } - inputSize := len(qr.Rows) - if inputSize == 0 { - return callback(qr) - } // If we still need to skip `offset` rows before returning any to the client: if offset > 0 { if inputSize <= offset { // not enough to return anything yet, but we still want to pass on metadata such as last_insert_id offset -= inputSize + if !l.mustRetrieveAll(vcursor) { + return nil + } qr.Rows = nil return callback(qr) } @@ -134,7 +147,7 @@ func (l *Limit) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars // At this point, we've dealt with the offset. Now handle the count (limit). if count == 0 { // If count is zero, we've fetched everything we need. - if !l.RequireCompleteInput && !vcursor.Session().InTransaction() { + if !l.mustRetrieveAll(vcursor) { return io.EOF } @@ -159,7 +172,7 @@ func (l *Limit) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars // If we required complete input or are in a transaction, we must not exit early. // We'll return empty batches until the input is done. - if l.RequireCompleteInput || vcursor.Session().InTransaction() { + if l.mustRetrieveAll(vcursor) { return nil } diff --git a/go/vt/vtgate/engine/limit_test.go b/go/vt/vtgate/engine/limit_test.go index 8b91dadecb5..8ab31610e62 100644 --- a/go/vt/vtgate/engine/limit_test.go +++ b/go/vt/vtgate/engine/limit_test.go @@ -353,9 +353,7 @@ func TestLimitOffsetExecute(t *testing.T) { t.Errorf("l.Execute:\n got %v, want\n%v", result, wantResult) } } - func TestLimitStreamExecute(t *testing.T) { - bindVars := make(map[string]*querypb.BindVariable) fields := sqltypes.MakeTestFields( "col1|col2", "int64|varchar", @@ -366,88 +364,88 @@ func TestLimitStreamExecute(t *testing.T) { "b|2", "c|3", ) - fp := &fakePrimitive{ - results: []*sqltypes.Result{inputResult}, - } - - l := &Limit{ - Count: evalengine.NewLiteralInt(2), - Input: fp, - } - // Test with limit smaller than input. - var results []*sqltypes.Result - err := l.TryStreamExecute(context.Background(), &noopVCursor{}, bindVars, true, func(qr *sqltypes.Result) error { - results = append(results, qr) - return nil - }) - require.NoError(t, err) - wantResults := sqltypes.MakeTestStreamingResults( - fields, - "a|1", - "b|2", - ) - require.Len(t, results, len(wantResults)) - for i, result := range results { - if !result.Equal(wantResults[i]) { - t.Errorf("l.StreamExecute:\n%s, want\n%s", sqltypes.PrintResults(results), sqltypes.PrintResults(wantResults)) - } - } + tests := []struct { + name string + countExpr evalengine.Expr + bindVars map[string]*querypb.BindVariable + want []*sqltypes.Result + RequireCompleteInput bool + }{{ + name: "limit smaller than input (literal)", + countExpr: evalengine.NewLiteralInt(2), + want: sqltypes.MakeTestStreamingResults( + fields, + "a|1", + "b|2", + ), + }, { + name: "limit smaller than input (literal) - require complete input", + countExpr: evalengine.NewLiteralInt(2), + RequireCompleteInput: true, + want: sqltypes.MakeTestStreamingResults( + fields, + "a|1", + "b|2", + "---", // this extra result is required by RequireCompleteInput + ), + }, { + name: "limit smaller than input (bind var)", + countExpr: evalengine.NewBindVar("l", evalengine.NewType(sqltypes.Int64, collations.CollationBinaryID)), + bindVars: map[string]*querypb.BindVariable{"l": sqltypes.Int64BindVariable(2)}, + want: sqltypes.MakeTestStreamingResults( + fields, + "a|1", + "b|2", + ), + }, { + name: "limit equal to input", + countExpr: evalengine.NewLiteralInt(3), + want: sqltypes.MakeTestStreamingResults( + fields, + "a|1", + "b|2", + "---", + "c|3", + ), + }, { + name: "limit higher than input", + countExpr: evalengine.NewLiteralInt(4), + // same as limit=3 + want: sqltypes.MakeTestStreamingResults( + fields, + "a|1", + "b|2", + "---", + "c|3", + ), + }} - // Test with bind vars. - fp.rewind() - l.Count = evalengine.NewBindVar("l", evalengine.NewType(sqltypes.Int64, collations.CollationBinaryID)) - results = nil - err = l.TryStreamExecute(context.Background(), &noopVCursor{}, map[string]*querypb.BindVariable{"l": sqltypes.Int64BindVariable(2)}, true, func(qr *sqltypes.Result) error { - results = append(results, qr) - return nil - }) - require.NoError(t, err) - require.Len(t, results, len(wantResults)) - for i, result := range results { - if !result.Equal(wantResults[i]) { - t.Errorf("l.StreamExecute:\n%s, want\n%s", sqltypes.PrintResults(results), sqltypes.PrintResults(wantResults)) - } - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fp := &fakePrimitive{ + results: []*sqltypes.Result{inputResult}, + } - // Test with limit equal to input - fp.rewind() - l.Count = evalengine.NewLiteralInt(3) - results = nil - err = l.TryStreamExecute(context.Background(), &noopVCursor{}, bindVars, true, func(qr *sqltypes.Result) error { - results = append(results, qr) - return nil - }) - require.NoError(t, err) - wantResults = sqltypes.MakeTestStreamingResults( - fields, - "a|1", - "b|2", - "---", - "c|3", - ) - require.Len(t, results, len(wantResults)) - for i, result := range results { - if !result.Equal(wantResults[i]) { - t.Errorf("l.StreamExecute:\n%s, want\n%s", sqltypes.PrintResults(results), sqltypes.PrintResults(wantResults)) - } - } + l := &Limit{ + Count: tt.countExpr, + RequireCompleteInput: tt.RequireCompleteInput, + Input: fp, + } - // Test with limit higher than input. - fp.rewind() - l.Count = evalengine.NewLiteralInt(4) - results = nil - err = l.TryStreamExecute(context.Background(), &noopVCursor{}, bindVars, true, func(qr *sqltypes.Result) error { - results = append(results, qr) - return nil - }) - require.NoError(t, err) - // wantResults is same as before. - require.Len(t, results, len(wantResults)) - for i, result := range results { - if !result.Equal(wantResults[i]) { - t.Errorf("l.StreamExecute:\n%s, want\n%s", sqltypes.PrintResults(results), sqltypes.PrintResults(wantResults)) - } + var results []*sqltypes.Result + err := l.TryStreamExecute(context.Background(), &noopVCursor{}, tt.bindVars, true, func(qr *sqltypes.Result) error { + results = append(results, qr) + return nil + }) + require.NoError(t, err) + require.Len(t, results, len(tt.want)) + for i, result := range results { + if !result.Equal(tt.want[i]) { + t.Errorf("l.StreamExecute:\n%s, want\n%s", sqltypes.PrintResults(results), sqltypes.PrintResults(tt.want)) + } + } + }) } }