Skip to content

Commit

Permalink
feat: clean up limit code
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <[email protected]>
  • Loading branch information
systay committed Dec 19, 2024
1 parent bea0515 commit 61c0b64
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 86 deletions.
25 changes: 19 additions & 6 deletions go/vt/vtgate/engine/limit.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ func (l *Limit) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[st
return result, nil
}

func (l *Limit) mustRetrieveAll(vcursor VCursor) bool {
return l.RequireCompleteInput || vcursor.Session().InTransaction()
}

// TryStreamExecute satisfies the Primitive interface.
func (l *Limit) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
count, offset, err := l.getCountAndOffset(ctx, vcursor, bindVars)
Expand All @@ -107,22 +111,31 @@ func (l *Limit) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars
mu.Lock()
defer mu.Unlock()

inputSize := len(qr.Rows)
if inputSize == 0 {
if wantfields && len(qr.Fields) != 0 {
wantfields = false
}
return callback(qr)
}

// If this is the first callback and fields are requested, send the fields immediately.
if wantfields && len(qr.Fields) != 0 {
wantfields = false
// otherwise, we need to send the fields first, and then the rows
if err := callback(&sqltypes.Result{Fields: qr.Fields}); err != nil {
return err
}
}
inputSize := len(qr.Rows)
if inputSize == 0 {
return callback(qr)
}

// If we still need to skip `offset` rows before returning any to the client:
if offset > 0 {
if inputSize <= offset {
// not enough to return anything yet, but we still want to pass on metadata such as last_insert_id
offset -= inputSize
if !l.mustRetrieveAll(vcursor) {
return nil
}
qr.Rows = nil
return callback(qr)
}
Expand All @@ -134,7 +147,7 @@ func (l *Limit) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars
// At this point, we've dealt with the offset. Now handle the count (limit).
if count == 0 {
// If count is zero, we've fetched everything we need.
if !l.RequireCompleteInput && !vcursor.Session().InTransaction() {
if !l.mustRetrieveAll(vcursor) {
return io.EOF
}

Expand All @@ -159,7 +172,7 @@ func (l *Limit) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars

// If we required complete input or are in a transaction, we must not exit early.
// We'll return empty batches until the input is done.
if l.RequireCompleteInput || vcursor.Session().InTransaction() {
if l.mustRetrieveAll(vcursor) {
return nil
}

Expand Down
158 changes: 78 additions & 80 deletions go/vt/vtgate/engine/limit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,9 +353,7 @@ func TestLimitOffsetExecute(t *testing.T) {
t.Errorf("l.Execute:\n got %v, want\n%v", result, wantResult)
}
}

func TestLimitStreamExecute(t *testing.T) {
bindVars := make(map[string]*querypb.BindVariable)
fields := sqltypes.MakeTestFields(
"col1|col2",
"int64|varchar",
Expand All @@ -366,88 +364,88 @@ func TestLimitStreamExecute(t *testing.T) {
"b|2",
"c|3",
)
fp := &fakePrimitive{
results: []*sqltypes.Result{inputResult},
}

l := &Limit{
Count: evalengine.NewLiteralInt(2),
Input: fp,
}

// Test with limit smaller than input.
var results []*sqltypes.Result
err := l.TryStreamExecute(context.Background(), &noopVCursor{}, bindVars, true, func(qr *sqltypes.Result) error {
results = append(results, qr)
return nil
})
require.NoError(t, err)
wantResults := sqltypes.MakeTestStreamingResults(
fields,
"a|1",
"b|2",
)
require.Len(t, results, len(wantResults))
for i, result := range results {
if !result.Equal(wantResults[i]) {
t.Errorf("l.StreamExecute:\n%s, want\n%s", sqltypes.PrintResults(results), sqltypes.PrintResults(wantResults))
}
}
tests := []struct {
name string
countExpr evalengine.Expr
bindVars map[string]*querypb.BindVariable
want []*sqltypes.Result
RequireCompleteInput bool
}{{
name: "limit smaller than input (literal)",
countExpr: evalengine.NewLiteralInt(2),
want: sqltypes.MakeTestStreamingResults(
fields,
"a|1",
"b|2",
),
}, {
name: "limit smaller than input (literal) - require complete input",
countExpr: evalengine.NewLiteralInt(2),
RequireCompleteInput: true,
want: sqltypes.MakeTestStreamingResults(
fields,
"a|1",
"b|2",
"---", // this extra result is required by RequireCompleteInput
),
}, {
name: "limit smaller than input (bind var)",
countExpr: evalengine.NewBindVar("l", evalengine.NewType(sqltypes.Int64, collations.CollationBinaryID)),
bindVars: map[string]*querypb.BindVariable{"l": sqltypes.Int64BindVariable(2)},
want: sqltypes.MakeTestStreamingResults(
fields,
"a|1",
"b|2",
),
}, {
name: "limit equal to input",
countExpr: evalengine.NewLiteralInt(3),
want: sqltypes.MakeTestStreamingResults(
fields,
"a|1",
"b|2",
"---",
"c|3",
),
}, {
name: "limit higher than input",
countExpr: evalengine.NewLiteralInt(4),
// same as limit=3
want: sqltypes.MakeTestStreamingResults(
fields,
"a|1",
"b|2",
"---",
"c|3",
),
}}

// Test with bind vars.
fp.rewind()
l.Count = evalengine.NewBindVar("l", evalengine.NewType(sqltypes.Int64, collations.CollationBinaryID))
results = nil
err = l.TryStreamExecute(context.Background(), &noopVCursor{}, map[string]*querypb.BindVariable{"l": sqltypes.Int64BindVariable(2)}, true, func(qr *sqltypes.Result) error {
results = append(results, qr)
return nil
})
require.NoError(t, err)
require.Len(t, results, len(wantResults))
for i, result := range results {
if !result.Equal(wantResults[i]) {
t.Errorf("l.StreamExecute:\n%s, want\n%s", sqltypes.PrintResults(results), sqltypes.PrintResults(wantResults))
}
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fp := &fakePrimitive{
results: []*sqltypes.Result{inputResult},
}

// Test with limit equal to input
fp.rewind()
l.Count = evalengine.NewLiteralInt(3)
results = nil
err = l.TryStreamExecute(context.Background(), &noopVCursor{}, bindVars, true, func(qr *sqltypes.Result) error {
results = append(results, qr)
return nil
})
require.NoError(t, err)
wantResults = sqltypes.MakeTestStreamingResults(
fields,
"a|1",
"b|2",
"---",
"c|3",
)
require.Len(t, results, len(wantResults))
for i, result := range results {
if !result.Equal(wantResults[i]) {
t.Errorf("l.StreamExecute:\n%s, want\n%s", sqltypes.PrintResults(results), sqltypes.PrintResults(wantResults))
}
}
l := &Limit{
Count: tt.countExpr,
RequireCompleteInput: tt.RequireCompleteInput,
Input: fp,
}

// Test with limit higher than input.
fp.rewind()
l.Count = evalengine.NewLiteralInt(4)
results = nil
err = l.TryStreamExecute(context.Background(), &noopVCursor{}, bindVars, true, func(qr *sqltypes.Result) error {
results = append(results, qr)
return nil
})
require.NoError(t, err)
// wantResults is same as before.
require.Len(t, results, len(wantResults))
for i, result := range results {
if !result.Equal(wantResults[i]) {
t.Errorf("l.StreamExecute:\n%s, want\n%s", sqltypes.PrintResults(results), sqltypes.PrintResults(wantResults))
}
var results []*sqltypes.Result
err := l.TryStreamExecute(context.Background(), &noopVCursor{}, tt.bindVars, true, func(qr *sqltypes.Result) error {
results = append(results, qr)
return nil
})
require.NoError(t, err)
require.Len(t, results, len(tt.want))
for i, result := range results {
if !result.Equal(tt.want[i]) {
t.Errorf("l.StreamExecute:\n%s, want\n%s", sqltypes.PrintResults(results), sqltypes.PrintResults(tt.want))
}
}
})
}
}

Expand Down

0 comments on commit 61c0b64

Please sign in to comment.