Skip to content

Commit

Permalink
feat: optimise outer joins (#15840)
Browse files Browse the repository at this point in the history
  • Loading branch information
systay authored May 24, 2024
1 parent 7c6d5e5 commit 0cc5acd
Show file tree
Hide file tree
Showing 26 changed files with 642 additions and 372 deletions.
8 changes: 4 additions & 4 deletions go/test/endtoend/vtgate/queries/misc/misc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,12 +365,12 @@ func TestAliasesInOuterJoinQueries(t *testing.T) {
mcmp.Exec("insert into t1(id1, id2) values (1,2), (42,5), (5, 42)")
mcmp.Exec("insert into tbl(id, unq_col, nonunq_col) values (1,2,3), (2,5,3), (3, 42, 2)")

// Check that the select query works as intended and then run it again verifying the column names as well.
mcmp.AssertMatches("select t1.id1 as t0, t1.id1 as t1, tbl.unq_col as col from t1 left outer join tbl on t1.id2 = tbl.nonunq_col", `[[INT64(1) INT64(1) INT64(42)] [INT64(5) INT64(5) NULL] [INT64(42) INT64(42) NULL]]`)
// Check that the select query works as intended and verifying the column names as well.
mcmp.ExecWithColumnCompare("select t1.id1 as t0, t1.id1 as t1, tbl.unq_col as col from t1 left outer join tbl on t1.id2 = tbl.nonunq_col")

mcmp.AssertMatches("select t1.id1 as t0, t1.id1 as t1, tbl.unq_col as col from t1 left outer join tbl on t1.id2 = tbl.nonunq_col order by t1.id2 limit 2", `[[INT64(1) INT64(1) INT64(42)] [INT64(42) INT64(42) NULL]]`)
mcmp.ExecWithColumnCompare("select t1.id1 as t0, t1.id1 as t1, tbl.unq_col as col from t1 left outer join tbl on t1.id2 = tbl.nonunq_col order by t1.id2 limit 2")
mcmp.ExecWithColumnCompare("select t1.id1 as t0, t1.id1 as t1, tbl.unq_col as col from t1 left outer join tbl on t1.id2 = tbl.nonunq_col order by t1.id2 limit 2 offset 2")
mcmp.ExecWithColumnCompare("select t1.id1 as t0, t1.id1 as t1, count(*) as leCount from t1 left outer join tbl on t1.id2 = tbl.nonunq_col group by 1, t1")
mcmp.ExecWithColumnCompare("select t.id1, t.id2, derived.unq_col from t1 t join (select id, unq_col, nonunq_col from tbl) as derived on t.id2 = derived.nonunq_col")
}

func TestAlterTableWithView(t *testing.T) {
Expand Down
15 changes: 8 additions & 7 deletions go/test/endtoend/vtgate/queries/misc/schema.sql
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
create table if not exists t1(
id1 bigint,
id2 bigint,
primary key(id1)
create table t1
(
id1 bigint,
id2 bigint,
primary key (id1)
) Engine=InnoDB;

create table unq_idx
Expand Down Expand Up @@ -30,8 +31,8 @@ create table tbl

create table tbl_enum_set
(
id bigint,
enum_col enum('xsmall', 'small', 'medium', 'large', 'xlarge'),
set_col set('a', 'b', 'c', 'd', 'e', 'f', 'g'),
id bigint,
enum_col enum('xsmall', 'small', 'medium', 'large', 'xlarge'),
set_col set('a', 'b', 'c', 'd', 'e', 'f', 'g'),
primary key (id)
) Engine = InnoDB;
7 changes: 5 additions & 2 deletions go/vt/vtgate/engine/limit.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ type Limit struct {
Input Primitive
}

var UpperLimitStr = "__upper_limit"

// RouteType returns a description of the query routing type used by the primitive
func (l *Limit) RouteType() string {
return l.Input.RouteType()
Expand All @@ -63,7 +65,8 @@ func (l *Limit) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[st
}
// When offset is present, we hijack the limit value so we can calculate
// the offset in memory from the result of the scatter query with count + offset.
bindVars["__upper_limit"] = sqltypes.Int64BindVariable(int64(count + offset))

bindVars[UpperLimitStr] = sqltypes.Int64BindVariable(int64(count + offset))

result, err := vcursor.ExecutePrimitive(ctx, l.Input, bindVars, wantfields)
if err != nil {
Expand Down Expand Up @@ -96,7 +99,7 @@ func (l *Limit) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars

// When offset is present, we hijack the limit value so we can calculate
// the offset in memory from the result of the scatter query with count + offset.
bindVars["__upper_limit"] = sqltypes.Int64BindVariable(int64(count + offset))
bindVars[UpperLimitStr] = sqltypes.Int64BindVariable(int64(count + offset))

var mu sync.Mutex
err = vcursor.StreamExecutePrimitive(ctx, l.Input, bindVars, wantfields, func(qr *sqltypes.Result) error {
Expand Down
41 changes: 33 additions & 8 deletions go/vt/vtgate/engine/simple_projection.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,13 @@ package engine

import (
"context"
"fmt"
"strconv"
"strings"

"google.golang.org/protobuf/proto"

"vitess.io/vitess/go/slice"
"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
)
Expand Down Expand Up @@ -90,6 +94,10 @@ func (sc *SimpleProjection) Inputs() ([]Primitive, []map[string]any) {
// buildResult builds a new result by pulling the necessary columns from
// the input in the requested order.
func (sc *SimpleProjection) buildResult(inner *sqltypes.Result) *sqltypes.Result {
if sc.namesOnly() {
sc.renameFields(inner.Fields)
return inner
}
result := &sqltypes.Result{Fields: sc.buildFields(inner)}
result.Rows = make([][]sqltypes.Value, 0, len(inner.Rows))
for _, innerRow := range inner.Rows {
Expand All @@ -103,6 +111,10 @@ func (sc *SimpleProjection) buildResult(inner *sqltypes.Result) *sqltypes.Result
return result
}

func (sc *SimpleProjection) namesOnly() bool {
return sc.Cols == nil
}

func (sc *SimpleProjection) buildFields(inner *sqltypes.Result) []*querypb.Field {
if len(inner.Fields) == 0 {
return nil
Expand All @@ -119,20 +131,33 @@ func (sc *SimpleProjection) buildFields(inner *sqltypes.Result) []*querypb.Field
return fields
}

func (sc *SimpleProjection) renameFields(fields []*querypb.Field) {
if len(fields) == 0 {
return
}
for idx, name := range sc.ColNames {
if sc.ColNames[idx] != "" {
fields[idx].Name = name
}
}
}

func (sc *SimpleProjection) description() PrimitiveDescription {
other := map[string]any{
"Columns": sc.Cols,
other := map[string]any{}
if !sc.namesOnly() {
other["Columns"] = strings.Join(slice.Map(sc.Cols, strconv.Itoa), ",")
}
emptyColNames := true
for _, cName := range sc.ColNames {

var colNames []string
for idx, cName := range sc.ColNames {
if cName != "" {
emptyColNames = false
break
colNames = append(colNames, fmt.Sprintf("%d:%s", idx, cName))
}
}
if !emptyColNames {
other["ColumnNames"] = sc.ColNames
if colNames != nil {
other["ColumnNames"] = colNames
}

return PrimitiveDescription{
OperatorType: "SimpleProjection",
Other: other,
Expand Down
60 changes: 26 additions & 34 deletions go/vt/vtgate/executor_select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2329,7 +2329,7 @@ func TestSelectScatterLimit(t *testing.T) {
require.NoError(t, err)

wantQueries := []*querypb.BoundQuery{{
Sql: "select col1, col2, weight_string(col2) from `user` order by `user`.col2 desc limit :__upper_limit",
Sql: "select col1, col2, weight_string(col2) from `user` order by `user`.col2 desc limit 3",
BindVariables: map[string]*querypb.BindVariable{"__upper_limit": sqltypes.Int64BindVariable(3)},
}}
for _, conn := range conns {
Expand Down Expand Up @@ -2401,7 +2401,7 @@ func TestStreamSelectScatterLimit(t *testing.T) {
require.NoError(t, err)

wantQueries := []*querypb.BoundQuery{{
Sql: "select col1, col2, weight_string(col2) from `user` order by `user`.col2 desc limit :__upper_limit",
Sql: "select col1, col2, weight_string(col2) from `user` order by `user`.col2 desc limit 3",
BindVariables: map[string]*querypb.BindVariable{"__upper_limit": sqltypes.Int64BindVariable(3)},
}}
for _, conn := range conns {
Expand Down Expand Up @@ -2863,11 +2863,11 @@ func TestEmptyJoinRecursiveStream(t *testing.T) {
}
}

func TestCrossShardSubquery(t *testing.T) {
func TestCrossShardDerivedTable(t *testing.T) {
executor, sbc1, sbc2, _, ctx := createExecutorEnv(t)
result1 := []*sqltypes.Result{{
Fields: []*querypb.Field{
{Name: "id", Type: sqltypes.Int32},
{Name: "id1", Type: sqltypes.Int32},
{Name: "col", Type: sqltypes.Int32},
},
InsertID: 0,
Expand All @@ -2894,10 +2894,8 @@ func TestCrossShardSubquery(t *testing.T) {
}}
utils.MustMatch(t, wantQueries, sbc2.Queries)

wantResult := sqltypes.MakeTestResult(sqltypes.MakeTestFields("id", "int32"), "1")
if !result.Equal(wantResult) {
t.Errorf("result: %+v, want %+v", result, wantResult)
}
wantResult := sqltypes.MakeTestResult(sqltypes.MakeTestFields("id1", "int32"), "1")
assert.Equal(t, wantResult, result)
}

func TestSubQueryAndQueryWithLimit(t *testing.T) {
Expand Down Expand Up @@ -2946,7 +2944,7 @@ func TestCrossShardSubqueryStream(t *testing.T) {
executor, sbc1, sbc2, _, ctx := createExecutorEnv(t)
result1 := []*sqltypes.Result{{
Fields: []*querypb.Field{
{Name: "id", Type: sqltypes.Int32, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)},
{Name: "id1", Type: sqltypes.Int32, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)},
{Name: "col", Type: sqltypes.Int32, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)},
},
InsertID: 0,
Expand All @@ -2971,18 +2969,16 @@ func TestCrossShardSubqueryStream(t *testing.T) {

wantResult := &sqltypes.Result{
Fields: []*querypb.Field{
{Name: "id", Type: sqltypes.Int32, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)},
{Name: "id1", Type: sqltypes.Int32, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)},
},
Rows: [][]sqltypes.Value{{
sqltypes.NewInt32(1),
}},
}
if !result.Equal(wantResult) {
t.Errorf("result: %+v, want %+v", result, wantResult)
}
assert.Equal(t, wantResult, result)
}

func TestCrossShardSubqueryGetFields(t *testing.T) {
func TestCrossShardDerivedTableGetFields(t *testing.T) {
executor, sbc1, _, sbclookup, ctx := createExecutorEnv(t)
sbclookup.SetResults([]*sqltypes.Result{{
Fields: []*querypb.Field{
Expand All @@ -2991,7 +2987,7 @@ func TestCrossShardSubqueryGetFields(t *testing.T) {
}})
result1 := []*sqltypes.Result{{
Fields: []*querypb.Field{
{Name: "id", Type: sqltypes.Int32, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)},
{Name: "id1", Type: sqltypes.Int32, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)},
{Name: "col", Type: sqltypes.Int32, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)},
},
}}
Expand All @@ -3015,12 +3011,10 @@ func TestCrossShardSubqueryGetFields(t *testing.T) {
wantResult := &sqltypes.Result{
Fields: []*querypb.Field{
{Name: "col", Type: sqltypes.Int32, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)},
{Name: "id", Type: sqltypes.Int32, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)},
{Name: "id1", Type: sqltypes.Int32, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)},
},
}
if !result.Equal(wantResult) {
t.Errorf("result: %+v, want %+v", result, wantResult)
}
assert.Equal(t, wantResult, result)
}

func TestSelectBindvarswithPrepare(t *testing.T) {
Expand All @@ -3042,9 +3036,7 @@ func TestSelectBindvarswithPrepare(t *testing.T) {
BindVariables: map[string]*querypb.BindVariable{"id": sqltypes.Int64BindVariable(1)},
}}
utils.MustMatch(t, wantQueries, sbc1.Queries)
if sbc2.Queries != nil {
t.Errorf("sbc2.Queries: %+v, want nil\n", sbc2.Queries)
}
assert.Empty(t, sbc2.Queries)
}

func TestSelectDatabasePrepare(t *testing.T) {
Expand Down Expand Up @@ -3908,14 +3900,14 @@ func TestSelectAggregationNoData(t *testing.T) {
{
sql: `select count(*) from (select col1, col2 from user limit 2) x`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|1", "int64|int64|int64")),
expSandboxQ: "select x.col1, x.col2, 1 from (select col1, col2 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.col1, x.col2, 1 from (select col1, col2 from `user`) as x limit 2",
expField: `[name:"count(*)" type:INT64]`,
expRow: `[[INT64(0)]]`,
},
{
sql: `select col2, count(*) from (select col1, col2 from user limit 2) x group by col2`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|1|weight_string(col2)", "int64|int64|int64|varbinary")),
expSandboxQ: "select x.col1, x.col2, 1, weight_string(x.col2) from (select col1, col2 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.col1, x.col2, 1, weight_string(x.col2) from (select col1, col2 from `user`) as x limit 2",
expField: `[name:"col2" type:INT64 name:"count(*)" type:INT64]`,
expRow: `[]`,
},
Expand Down Expand Up @@ -4000,70 +3992,70 @@ func TestSelectAggregationData(t *testing.T) {
{
sql: `select count(*) from (select col1, col2 from user limit 2) x`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|1", "int64|int64|int64"), "100|200|1", "200|300|1"),
expSandboxQ: "select x.col1, x.col2, 1 from (select col1, col2 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.col1, x.col2, 1 from (select col1, col2 from `user`) as x limit 2",
expField: `[name:"count(*)" type:INT64]`,
expRow: `[[INT64(2)]]`,
},
{
sql: `select col2, count(*) from (select col1, col2 from user limit 9) x group by col2`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|1|weight_string(col2)", "int64|int64|int64|varbinary"), "100|3|1|NULL", "200|2|1|NULL"),
expSandboxQ: "select x.col1, x.col2, 1, weight_string(x.col2) from (select col1, col2 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.col1, x.col2, 1, weight_string(x.col2) from (select col1, col2 from `user`) as x limit 9",
expField: `[name:"col2" type:INT64 name:"count(*)" type:INT64]`,
expRow: `[[INT64(2) INT64(4)] [INT64(3) INT64(5)]]`,
},
{
sql: `select count(col1) from (select id, col1 from user limit 2) x`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1", "int64|varchar"), "1|a", "2|b"),
expSandboxQ: "select x.id, x.col1 from (select id, col1 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.id, x.col1 from (select id, col1 from `user`) as x limit 2",
expField: `[name:"count(col1)" type:INT64]`,
expRow: `[[INT64(2)]]`,
},
{
sql: `select count(col1), col2 from (select col2, col1 from user limit 9) x group by col2`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col2|col1|weight_string(col2)", "int64|varchar|varbinary"), "3|a|NULL", "2|b|NULL"),
expSandboxQ: "select x.col2, x.col1, weight_string(x.col2) from (select col2, col1 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.col2, x.col1, weight_string(x.col2) from (select col2, col1 from `user`) as x limit 9",
expField: `[name:"count(col1)" type:INT64 name:"col2" type:INT64]`,
expRow: `[[INT64(4) INT64(2)] [INT64(5) INT64(3)]]`,
},
{
sql: `select col1, count(col2) from (select col1, col2 from user limit 9) x group by col1`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|int64|varbinary"), "a|1|a", "b|null|b"),
expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit 9",
expField: `[name:"col1" type:VARCHAR name:"count(col2)" type:INT64]`,
expRow: `[[VARCHAR("a") INT64(5)] [VARCHAR("b") INT64(0)]]`,
},
{
sql: `select col1, count(col2) from (select col1, col2 from user limit 32) x group by col1`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|int64|varbinary"), "null|1|null", "null|null|null", "a|1|a", "b|null|b"),
expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit 32",
expField: `[name:"col1" type:VARCHAR name:"count(col2)" type:INT64]`,
expRow: `[[NULL INT64(8)] [VARCHAR("a") INT64(8)] [VARCHAR("b") INT64(0)]]`,
},
{
sql: `select col1, sum(col2) from (select col1, col2 from user limit 4) x group by col1`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|int64|varbinary"), "a|3|a"),
expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit 4",
expField: `[name:"col1" type:VARCHAR name:"sum(col2)" type:DECIMAL]`,
expRow: `[[VARCHAR("a") DECIMAL(12)]]`,
},
{
sql: `select col1, sum(col2) from (select col1, col2 from user limit 4) x group by col1`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|varchar|varbinary"), "a|2|a"),
expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit 4",
expField: `[name:"col1" type:VARCHAR name:"sum(col2)" type:FLOAT64]`,
expRow: `[[VARCHAR("a") FLOAT64(8)]]`,
},
{
sql: `select col1, sum(col2) from (select col1, col2 from user limit 4) x group by col1`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|varchar|varbinary"), "a|x|a"),
expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit 4",
expField: `[name:"col1" type:VARCHAR name:"sum(col2)" type:FLOAT64]`,
expRow: `[[VARCHAR("a") FLOAT64(0)]]`,
},
{
sql: `select col1, sum(col2) from (select col1, col2 from user limit 4) x group by col1`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|varchar|varbinary"), "a|null|a"),
expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit 4",
expField: `[name:"col1" type:VARCHAR name:"sum(col2)" type:FLOAT64]`,
expRow: `[[VARCHAR("a") NULL]]`,
},
Expand Down
20 changes: 16 additions & 4 deletions go/vt/vtgate/planbuilder/operator_transformers.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,10 @@ func transformProjection(ctx *plancontext.PlanningContext, op *operators.Project
if cols, colNames := op.AllOffsets(); cols != nil {
// if all this op is doing is passing through columns from the input, we
// can use the faster SimpleProjection
return useSimpleProjection(cols, colNames, src)
if len(op.Source.GetColumns(ctx)) == len(cols) && offsetInInputOrder(cols) {
cols = nil
}
return newSimpleProjection(cols, colNames, src)
}

ap, err := op.GetAliasedProjections()
Expand Down Expand Up @@ -393,6 +396,16 @@ func transformProjection(ctx *plancontext.PlanningContext, op *operators.Project
}, nil
}

// offsetInInputOrder returns true if the columns are in the same order as the input
func offsetInInputOrder(cols []int) bool {
for i, c := range cols {
if c != i {
return false
}
}
return true
}

func getEvalEngingeExpr(ctx *plancontext.PlanningContext, pe *operators.ProjExpr) (evalengine.Expr, error) {
switch e := pe.Info.(type) {
case *operators.EvalEngine:
Expand All @@ -406,9 +419,8 @@ func getEvalEngingeExpr(ctx *plancontext.PlanningContext, pe *operators.ProjExpr

}

// useSimpleProjection uses nothing at all if the output is already correct,
// or SimpleProjection when we have to reorder or truncate the columns
func useSimpleProjection(cols []int, colNames []string, src logicalPlan) (logicalPlan, error) {
// newSimpleProjection creates a simple projections
func newSimpleProjection(cols []int, colNames []string, src logicalPlan) (logicalPlan, error) {
return &simpleProjection{
logicalPlanCommon: newBuilderCommon(src),
eSimpleProj: &engine.SimpleProjection{
Expand Down
Loading

0 comments on commit 0cc5acd

Please sign in to comment.