Skip to content

Commit

Permalink
push values table under route and add values table to values join
Browse files Browse the repository at this point in the history
Signed-off-by: Harshit Gangal <[email protected]>
Signed-off-by: Florent Poinsard <[email protected]>
  • Loading branch information
harshit-gangal authored and frouioui committed Jan 8, 2025
1 parent b73aded commit f584098
Show file tree
Hide file tree
Showing 10 changed files with 107 additions and 41 deletions.
43 changes: 43 additions & 0 deletions go/vt/vtgate/engine/cached_size.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 18 additions & 7 deletions go/vt/vtgate/engine/join_values.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ var _ Primitive = (*JoinValues)(nil)
type JoinValues struct {
// Left and Right are the LHS and RHS primitives
// of the Join. They can be any primitive.
// When Left is empty, the WhenLeftEmpty primitive will be executed.
Left, Right, WhenLeftEmpty Primitive
Left, Right Primitive

Vars map[string]int
Columns []string
Expand All @@ -44,12 +43,24 @@ func (jv *JoinValues) TryExecute(ctx context.Context, vcursor VCursor, bindVars
if err != nil {
return nil, err
}
if len(lresult.Rows) == 0 && wantfields {
return jv.WhenLeftEmpty.GetFields(ctx, vcursor, bindVars)
}
bv := &querypb.BindVariable{
Type: querypb.Type_TUPLE,
}
if len(lresult.Rows) == 0 && wantfields {
// If there are no rows, we still need to construct a single row
// to send down to RHS for Values Table to execute correctly.
// It will be used to execute the field query to provide the output fields.
var vals []sqltypes.Value
for _, field := range lresult.Fields {
val, _ := sqltypes.NewValue(field.Type, nil)
vals = append(vals, val)
}
bv.Values = append(bv.Values, sqltypes.TupleToProto(vals))

bindVars[jv.RowConstructorArg] = bv
return jv.Right.GetFields(ctx, vcursor, bindVars)
}

for _, row := range lresult.Rows {
bv.Values = append(bv.Values, sqltypes.TupleToProto(row))
}
Expand All @@ -64,12 +75,12 @@ func (jv *JoinValues) TryStreamExecute(ctx context.Context, vcursor VCursor, bin

// GetFields fetches the field info.
func (jv *JoinValues) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
return jv.WhenLeftEmpty.GetFields(ctx, vcursor, bindVars)
return jv.Right.GetFields(ctx, vcursor, bindVars)
}

// Inputs returns the input primitives for this join
func (jv *JoinValues) Inputs() ([]Primitive, []map[string]any) {
return []Primitive{jv.Left, jv.Right, jv.WhenLeftEmpty}, nil
return []Primitive{jv.Left, jv.Right}, nil
}

// RouteType returns a description of the query routing type used by the primitive
Expand Down
2 changes: 0 additions & 2 deletions go/vt/vtgate/planbuilder/operator_transformers.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,6 @@ func transformValuesJoin(ctx *plancontext.PlanningContext, op *operators.ValuesJ
Right: rhs,
Vars: op.Vars,
Columns: op.Columns,

WhenLeftEmpty: lhs, // wip florent
}, nil
}

Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/SQL_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -750,7 +750,7 @@ func buildValuesTable(vt *ValuesTable, qb *queryBuilder) {
},
},

As: sqlparser.NewIdentifierCS(vt.ListArgName),
As: sqlparser.NewIdentifierCS(vt.TableName),
Hints: nil,
Columns: slice.Map(cols, func(s string) sqlparser.IdentifierCI {
return sqlparser.NewIdentifierCI(s)
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/ast_to_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ func crossJoin(ctx *plancontext.PlanningContext, exprs sqlparser.TableExprs) Ope
if output == nil {
output = op
} else {
output = createJoin(ctx, output, op)
output = createLogicalJoin(ctx, output, op)
}
}
return output
Expand Down
34 changes: 17 additions & 17 deletions go/vt/vtgate/planbuilder/operators/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ import (
"vitess.io/vitess/go/vt/vtgate/semantics"
)

// Join represents a join. If we have a predicate, this is an inner join. If no predicate exists, it is a cross join
type Join struct {
// LogicalJoin represents a join. If we have a predicate, this is an inner join. If no predicate exists, it is a cross join
type LogicalJoin struct {
binaryOperator
Predicate sqlparser.Expr
// JoinType is permitted to store only 3 of the possible values
Expand All @@ -35,21 +35,21 @@ type Join struct {
noColumns
}

var _ Operator = (*Join)(nil)
var _ Operator = (*LogicalJoin)(nil)

// Clone implements the Operator interface
func (j *Join) Clone(inputs []Operator) Operator {
func (j *LogicalJoin) Clone(inputs []Operator) Operator {
clone := *j
clone.LHS = inputs[0]
clone.RHS = inputs[1]
return &clone
}

func (j *Join) GetOrdering(*plancontext.PlanningContext) []OrderBy {
func (j *LogicalJoin) GetOrdering(*plancontext.PlanningContext) []OrderBy {
return nil
}

func (j *Join) Compact(ctx *plancontext.PlanningContext) (Operator, *ApplyResult) {
func (j *LogicalJoin) Compact(ctx *plancontext.PlanningContext) (Operator, *ApplyResult) {
if !j.JoinType.IsCommutative() {
// if we can't move tables around, we can't merge these inputs
return j, NoRewrite
Expand All @@ -74,7 +74,7 @@ func (j *Join) Compact(ctx *plancontext.PlanningContext) (Operator, *ApplyResult

func createStraightJoin(ctx *plancontext.PlanningContext, join *sqlparser.JoinTableExpr, lhs, rhs Operator) Operator {
// for inner joins we can treat the predicates as filters on top of the join
joinOp := &Join{
joinOp := &LogicalJoin{
binaryOperator: newBinaryOp(lhs, rhs),
JoinType: join.Join,
}
Expand All @@ -93,7 +93,7 @@ func createLeftOuterJoin(ctx *plancontext.PlanningContext, join *sqlparser.JoinT
join.Join = sqlparser.NaturalLeftJoinType
}

joinOp := &Join{
joinOp := &LogicalJoin{
binaryOperator: newBinaryOp(lhs, rhs),
JoinType: join.Join,
}
Expand All @@ -116,7 +116,7 @@ func createLeftOuterJoin(ctx *plancontext.PlanningContext, join *sqlparser.JoinT
}

func createInnerJoin(ctx *plancontext.PlanningContext, tableExpr *sqlparser.JoinTableExpr, lhs, rhs Operator) Operator {
op := createJoin(ctx, lhs, rhs)
op := createLogicalJoin(ctx, lhs, rhs)
return addJoinPredicates(ctx, tableExpr.Condition.On, op)
}

Expand Down Expand Up @@ -177,7 +177,7 @@ func breakCTEExpressionInLhsAndRhs(ctx *plancontext.PlanningContext, pred sqlpar
}
}

func createJoin(ctx *plancontext.PlanningContext, LHS, RHS Operator) Operator {
func createLogicalJoin(ctx *plancontext.PlanningContext, LHS, RHS Operator) Operator {
lqg, lok := LHS.(*QueryGraph)
rqg, rok := RHS.(*QueryGraph)
if lok && rok {
Expand All @@ -188,32 +188,32 @@ func createJoin(ctx *plancontext.PlanningContext, LHS, RHS Operator) Operator {
}
return op
}
return &Join{
return &LogicalJoin{
binaryOperator: newBinaryOp(LHS, RHS),
}
}

func (j *Join) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) Operator {
func (j *LogicalJoin) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) Operator {
return AddPredicate(ctx, j, expr, false, newFilterSinglePredicate)
}

var _ JoinOp = (*Join)(nil)
var _ JoinOp = (*LogicalJoin)(nil)

func (j *Join) MakeInner() {
func (j *LogicalJoin) MakeInner() {
if j.IsInner() {
return
}
j.JoinType = sqlparser.NormalJoinType
}

func (j *Join) IsInner() bool {
func (j *LogicalJoin) IsInner() bool {
return j.JoinType.IsInner()
}

func (j *Join) AddJoinPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) {
func (j *LogicalJoin) AddJoinPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) {
j.Predicate = ctx.SemTable.AndExpressions(j.Predicate, expr)
}

func (j *Join) ShortDescription() string {
func (j *LogicalJoin) ShortDescription() string {
return sqlparser.String(j.Predicate)
}
15 changes: 12 additions & 3 deletions go/vt/vtgate/planbuilder/operators/query_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func runRewriters(ctx *plancontext.PlanningContext, root Operator) Operator {
switch in := in.(type) {
case *Horizon:
return pushOrExpandHorizon(ctx, in)
case *Join:
case *LogicalJoin:
return optimizeJoin(ctx, in)
case *Projection:
return tryPushProjection(ctx, in)
Expand Down Expand Up @@ -104,7 +104,8 @@ func runRewriters(ctx *plancontext.PlanningContext, root Operator) Operator {
return tryPushUpdate(in)
case *RecurseCTE:
return tryMergeRecurse(ctx, in)

case *ValuesTable:
return tryPushValuesTable(in)
default:
return in, NoRewrite
}
Expand All @@ -120,6 +121,14 @@ func runRewriters(ctx *plancontext.PlanningContext, root Operator) Operator {
return FixedPointBottomUp(root, TableID, visitor, stopAtRoute)
}

func tryPushValuesTable(in *ValuesTable) (Operator, *ApplyResult) {
r, ok := in.Source.(*Route)
if ok {
return Swap(in, r, "push values table under route")
}
return in, NoRewrite
}

func tryPushDelete(in *Delete) (Operator, *ApplyResult) {
if src, ok := in.Source.(*Route); ok {
return pushDMLUnderRoute(in, src, "pushed delete under route")
Expand Down Expand Up @@ -482,7 +491,7 @@ func setUpperLimit(in *Limit) (Operator, *ApplyResult) {
var result *ApplyResult
shouldVisit := func(op Operator) VisitRule {
switch op := op.(type) {
case *Join, *ApplyJoin, *SubQueryContainer, *SubQuery:
case *LogicalJoin, *ApplyJoin, *SubQueryContainer, *SubQuery:
// we can't push limits down on either side
return SkipChildren
case *Aggregator:
Expand Down
16 changes: 8 additions & 8 deletions go/vt/vtgate/planbuilder/operators/route_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func pushDerived(ctx *plancontext.PlanningContext, op *Horizon) (Operator, *Appl
return Swap(op, op.Source, "push derived under route")
}

func optimizeJoin(ctx *plancontext.PlanningContext, op *Join) (Operator, *ApplyResult) {
func optimizeJoin(ctx *plancontext.PlanningContext, op *LogicalJoin) (Operator, *ApplyResult) {
return mergeOrJoin(ctx, op.LHS, op.RHS, sqlparser.SplitAndExpression(nil, op.Predicate), op.JoinType)
}

Expand Down Expand Up @@ -290,13 +290,13 @@ func requiresSwitchingSides(ctx *plancontext.PlanningContext, op Operator) (requ
}

// Will create a join valid for the current mysql version
func createVersionJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, joinType sqlparser.JoinType, joinPredicates []sqlparser.Expr) (join JoinOp) {
ok, err := capabilities.MySQLVersionHasCapability(ctx.VSchema.Environment().MySQLVersion(), capabilities.ValuesRow)
if !ok || err != nil {
func createJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, joinType sqlparser.JoinType, joinPredicates []sqlparser.Expr) (join JoinOp) {
ok, _ := capabilities.MySQLVersionHasCapability(ctx.VSchema.Environment().MySQLVersion(), capabilities.ValuesRow)
if ok {
join = newValuesJoin(ctx, lhs, rhs, joinType)
} else {
// if we can't determine the MySQL version, we'll just assume we can't use the VALUES row
join = NewApplyJoin(ctx, Clone(lhs), Clone(rhs), nil, joinType)
} else {
join = newValuesJoin(ctx, lhs, rhs, joinType)
}

for _, pred := range joinPredicates {
Expand All @@ -323,11 +323,11 @@ func mergeOrJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, joinPredic
return join, Rewrote("use a hash join because we have LIMIT on the LHS")
}

join := createVersionJoin(ctx, Clone(rhs), Clone(lhs), joinType, joinPredicates)
join := createJoin(ctx, Clone(rhs), Clone(lhs), joinType, joinPredicates)
return join, Rewrote("logical join to applyJoin, switching side because LIMIT")
}

join := createVersionJoin(ctx, Clone(lhs), Clone(rhs), joinType, joinPredicates)
join := createJoin(ctx, Clone(lhs), Clone(rhs), joinType, joinPredicates)

return join, Rewrote("logical join to applyJoin ")
}
Expand Down
7 changes: 6 additions & 1 deletion go/vt/vtgate/planbuilder/operators/values_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,16 @@ func newValuesJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, joinType
if err == nil {
name, err := tbl.Name()
if err == nil {
tblName = sqlparser.String(name)
tblName = name.Name.String()
}
}
}
listArg := ctx.GetReservedArgumentForString(tblName)
rhs = &ValuesTable{
unaryOperator: newUnaryOp(rhs),
ListArgName: listArg,
TableName: tblName,
}
return &ValuesJoin{
binaryOperator: newBinaryOp(lhs, rhs),
JoinType: joinType,
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/testdata/onecase.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[
{
"comment": "Add your test case here for debugging and run go test -run=One.",
"query": "select u.foo+ue.bar from user u join user_extra ue on u.val = ue.user_id",
"query": "",
"plan": {
}
}
Expand Down

0 comments on commit f584098

Please sign in to comment.