Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support more queries with derived tables #14218

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions go/vt/vtgate/planbuilder/operators/apply_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,30 @@ func (jc JoinColumn) IsPureRight() bool {
return len(jc.LHSExprs) == 0
}

// Map returns a copy of the JoinColumn in which every expression it holds has
// been passed through f.
//
// f is applied, in order, to Original.Expr (when Original is set), to RHSExpr
// (when it is set) and to the Expr of every entry in LHSExprs. The first error
// returned by f stops the walk and is returned together with a zero JoinColumn.
//
// Two things are done to keep a failed or partial rewrite from leaking out:
//   - a nil RHSExpr is left alone instead of being handed to f, since a column
//     may legitimately have no right-hand side;
//   - results are collected first and only stored once every call to f has
//     succeeded, and LHSExprs is rewritten on a fresh slice so the receiver's
//     backing array is not written to.
//
// NOTE(review): Original is a pointer, so on success the pointed-to value is
// still updated in place and anything sharing that pointer observes the new
// expression. This also assumes LHSExprs holds values rather than pointers —
// confirm against the element type.
func (jc JoinColumn) Map(f func(sqlparser.Expr) (sqlparser.Expr, error)) (JoinColumn, error) {
	var (
		origExpr sqlparser.Expr
		err      error
	)
	if jc.Original != nil {
		origExpr, err = f(jc.Original.Expr)
		if err != nil {
			return JoinColumn{}, err
		}
	}

	rhsExpr := jc.RHSExpr
	if rhsExpr != nil {
		rhsExpr, err = f(rhsExpr)
		if err != nil {
			return JoinColumn{}, err
		}
	}

	lhsExprs := jc.LHSExprs
	if len(lhsExprs) > 0 {
		// Detach from the caller's backing array before rewriting.
		lhsExprs = append(lhsExprs[:0:0], lhsExprs...)
		for i := range lhsExprs {
			lhsExprs[i].Expr, err = f(lhsExprs[i].Expr)
			if err != nil {
				return JoinColumn{}, err
			}
		}
	}

	// Every rewrite succeeded: commit the results.
	if jc.Original != nil {
		jc.Original.Expr = origExpr
	}
	jc.RHSExpr = rhsExpr
	jc.LHSExprs = lhsExprs

	return jc, nil
}

// IsMixedLeftAndRight reports whether the column has at least one LHS
// expression and also carries an RHS expression.
func (jc JoinColumn) IsMixedLeftAndRight() bool {
	if len(jc.LHSExprs) == 0 {
		return false
	}
	return jc.RHSExpr != nil
}
Expand Down
10 changes: 10 additions & 0 deletions go/vt/vtgate/planbuilder/operators/projection.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,16 @@ func (dt *DerivedTable) RewriteExpression(ctx *plancontext.PlanningContext, expr
}
return semantics.RewriteDerivedTableExpression(expr, tableInfo)
}
// RewriteExpression2 passes expr through semantics.ExposeExpressionThroughDerived
// using the table info registered for this derived table.
//
// A nil receiver means there is no derived table, and expr is returned as is.
// If the table info for dt.TableID cannot be looked up, that error is returned
// and the expression result is nil.
func (dt *DerivedTable) RewriteExpression2(ctx *plancontext.PlanningContext, expr sqlparser.Expr) (sqlparser.Expr, error) {
	if dt == nil {
		return expr, nil
	}

	info, lookupErr := ctx.SemTable.TableInfoFor(dt.TableID)
	if lookupErr != nil {
		return nil, lookupErr
	}

	exposed := semantics.ExposeExpressionThroughDerived(expr, info)
	return exposed, nil
}

func (dt *DerivedTable) introducesTableID() semantics.TableSet {
if dt == nil {
Expand Down
142 changes: 102 additions & 40 deletions go/vt/vtgate/planbuilder/operators/query_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ package operators
import (
"fmt"
"io"
"math/rand"
"time"
"unsafe"

"vitess.io/vitess/go/test/dbg"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vtgate/planbuilder/operators/ops"
"vitess.io/vitess/go/vt/vtgate/planbuilder/operators/rewrite"
Expand Down Expand Up @@ -291,30 +295,25 @@ func pushProjectionInApplyJoin(
src *ApplyJoin,
) (ops.Operator, *rewrite.ApplyResult, error) {
ap, err := p.GetAliasedProjections()
if src.LeftJoin || err != nil {
if err != nil {
// we can't push down expression evaluation to the rhs if we are not sure if it will even be executed
return p, rewrite.SameTree, nil
}
lhs, rhs := &projector{}, &projector{}
if p.DT != nil && len(p.DT.Columns) > 0 {
lhs.explicitColumnAliases = true
rhs.explicitColumnAliases = true
}

explicitColumnAlias := p.DT != nil && len(p.DT.Columns) > 0
lhs := &projector{explicitColumnAliases: explicitColumnAlias}
rhs := &projector{explicitColumnAliases: explicitColumnAlias}

src.JoinColumns = nil
for idx, pe := range ap {
var col *sqlparser.IdentifierCI
if p.DT != nil && idx < len(p.DT.Columns) {
col = &p.DT.Columns[idx]
}
err := splitProjectionAcrossJoin(ctx, src, lhs, rhs, pe, col)
err := splitProjectionAcrossJoin(ctx, src, lhs, rhs, pe, p.aliasFor(idx))
if err != nil {
return nil, nil, err
}
}

if p.isDerived() {
err := exposeColumnsThroughDerivedTable(ctx, p, src, lhs)
err := exposeColumnsThroughDerivedTable(ctx, p, src, lhs, rhs)
if err != nil {
return nil, nil, err
}
Expand All @@ -334,6 +333,14 @@ func pushProjectionInApplyJoin(
return src, rewrite.NewTree("split projection to either side of join", src), nil
}

// aliasFor returns a pointer to the idx-th entry of p.DT.Columns.
// It returns nil when the projection has no DT, or when idx is at or past the
// end of the column list.
func (p *Projection) aliasFor(idx int) *sqlparser.IdentifierCI {
	if p.DT != nil && idx < len(p.DT.Columns) {
		return &p.DT.Columns[idx]
	}

	return nil
}

// splitProjectionAcrossJoin creates JoinPredicates for all projections,
// and pushes down columns as needed between the LHS and RHS of a join
func splitProjectionAcrossJoin(
Expand Down Expand Up @@ -408,45 +415,100 @@ func splitUnexploredExpression(
// The function iterates through each join predicate, rewriting the expressions in the predicate's
// LHS expressions to include the derived table. This allows the expressions to be accessed outside
// the derived table.
func exposeColumnsThroughDerivedTable(ctx *plancontext.PlanningContext, p *Projection, src *ApplyJoin, lhs *projector) error {
derivedTbl, err := ctx.SemTable.TableInfoFor(p.DT.TableID)
if err != nil {
return err
}
derivedTblName, err := derivedTbl.Name()
if err != nil {
return err
func exposeColumnsThroughDerivedTable(ctx *plancontext.PlanningContext, p *Projection, src *ApplyJoin, lhs *projector, rhs *projector) error {
if p.DT == nil {
return nil
}
for _, predicate := range src.JoinPredicates {
for idx, bve := range predicate.LHSExprs {
expr := bve.Expr
tbl, err := ctx.SemTable.TableInfoForExpr(expr)
if err != nil {
return err

cols := p.Columns.GetColumns()
f := func(expr sqlparser.Expr) (e sqlparser.Expr, err error) {
rewriter := func(cursor *sqlparser.CopyOnWriteCursor) {
this, ok := cursor.Node().(sqlparser.Expr)
if !ok {
return
}
tblExpr := tbl.GetExpr()
tblName, err := tblExpr.TableName()
if err != nil {
return err

// If we didn't find it, and we are dealing with a ColName, we have to add it to the derived table
colExpr, ok := this.(*sqlparser.ColName)
if !ok {
return
}

expr = semantics.RewriteDerivedTableExpression(expr, derivedTbl)
out := prefixColNames(tblName, expr)
// First we check if this expression is already being returned
for _, column := range cols {
if ctx.SemTable.EqualsExprWithDeps(column.Expr, colExpr) {
col := sqlparser.NewColName(column.ColumnName())
cursor.Replace(col)
return
}
}

colAlias := fmt.Sprintf("%s_vt_%s_%s", p.DT.Alias, sqlparser.String(colExpr), RandString(2))

alias := sqlparser.UnescapedString(out)
predicate.LHSExprs[idx].Expr = sqlparser.NewColNameWithQualifier(alias, derivedTblName)
identifierCI := sqlparser.NewIdentifierCI(alias)
projExpr := newProjExprWithInner(&sqlparser.AliasedExpr{Expr: out, As: identifierCI}, out)
var colAlias *sqlparser.IdentifierCI
if lhs.explicitColumnAliases {
colAlias = &identifierCI
newCol := sqlparser.NewColName(colAlias)
inner := newProjExprWithInner(aeWrap(newCol), colExpr)
_, thisErr := p.addProjExpr(inner)
if thisErr != nil {
err = thisErr
cursor.StopTreeWalk()
return
}
lhs.add(projExpr, colAlias)
cursor.Replace(newCol)
}

e = sqlparser.CopyOnRewrite(expr, nil, rewriter, ctx.SemTable.CopySemanticInfo).(sqlparser.Expr)
if expr == e {
panic(dbg.S())
}

return e, nil
}

//for i, pred := range src.JoinPredicates {
// x, err := pred.Map(f)
// if err != nil {
// return err
// }
// src.JoinPredicates[i] = x
//}

for i, col := range src.JoinColumns {
var err error
src.JoinColumns[i], err = col.Map(f)
if err != nil {
return err
}
}
return nil
}

// src is the pseudo-random source RandString draws from. It is seeded once,
// from the wall clock, when the package is initialised.
//
// NOTE(review): a Source created with rand.NewSource is documented as not safe
// for concurrent use by multiple goroutines. If RandString can be reached from
// more than one goroutine, this needs a lock — confirm against callers.
var src = rand.NewSource(time.Now().UnixNano())

// letterBytes is the alphabet RandString picks characters from.
const letterBytes = "_abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"

const (
	// letterIdxBits is the number of bits needed to represent an index into
	// letterBytes: the alphabet has 63 entries, which fits in 6 bits.
	letterIdxBits = 6
	// letterIdxMask has the low letterIdxBits bits set.
	letterIdxMask = 1<<letterIdxBits - 1
	// letterIdxMax is how many letter indices fit in the 63 bits of one Int63.
	letterIdxMax = 63 / letterIdxBits

	// Compile-time guard: fails to build (constant overflows uint) if the
	// alphabet ever grows beyond what letterIdxBits can index.
	_ = uint(1<<letterIdxBits - len(letterBytes))
)

// RandString returns a string of n characters drawn from letterBytes using the
// package-level source src. It returns the empty string when n is zero or
// negative.
func RandString(n int) string {
	if n <= 0 {
		return ""
	}

	b := make([]byte, n)
	// One src.Int63() call yields 63 random bits, enough for letterIdxMax
	// candidate indices; the cache is refilled once those are used up.
	for i, cache, remain := n-1, src.Int63(), letterIdxMax; i >= 0; {
		if remain == 0 {
			cache, remain = src.Int63(), letterIdxMax
		}
		// Candidates outside the alphabet are discarded rather than wrapped,
		// so every letter stays equally likely.
		if idx := int(cache & letterIdxMask); idx < len(letterBytes) {
			b[i] = letterBytes[idx]
			i--
		}
		cache >>= letterIdxBits
		remain--
	}

	// b is never written again, so reinterpreting it as a string avoids a copy.
	return *(*string)(unsafe.Pointer(&b))
}

// prefixColNames adds qualifier prefixes to all ColName:s.
// We want to be more explicit than the user was to make sure we never produce invalid SQL
func prefixColNames(tblName sqlparser.TableName, e sqlparser.Expr) sqlparser.Expr {
Expand Down
34 changes: 11 additions & 23 deletions go/vt/vtgate/planbuilder/operators/route_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -342,23 +342,19 @@ func getJoinFor(ctx *plancontext.PlanningContext, cm opCacheMap, lhs, rhs ops.Op
return join, nil
}

// requiresSwitchingSides will return true if any of the operators with the root from the given operator tree
// is of the type that should not be on the RHS of a join
func requiresSwitchingSides(ctx *plancontext.PlanningContext, op ops.Operator) bool {
required := false

// needsLimit will return true if this op tree requires LIMIT
func needsLimit(op ops.Operator) (required bool) {
_ = rewrite.Visit(op, func(current ops.Operator) error {
horizon, isHorizon := current.(*Horizon)

if isHorizon && horizon.IsDerived() && !horizon.IsMergeable(ctx) {
if isHorizon && horizon.Query.GetLimit() != nil {
required = true
return io.EOF
}

return nil
})

return required
return
}

func mergeOrJoin(ctx *plancontext.PlanningContext, lhs, rhs ops.Operator, joinPredicates []sqlparser.Expr, inner bool) (ops.Operator, *rewrite.ApplyResult, error) {
Expand All @@ -367,29 +363,21 @@ func mergeOrJoin(ctx *plancontext.PlanningContext, lhs, rhs ops.Operator, joinPr
return newPlan, rewrite.NewTree("merge routes into single operator", newPlan), nil
}

if len(joinPredicates) > 0 && requiresSwitchingSides(ctx, rhs) {
if !inner {
return nil, nil, vterrors.VT12001("LEFT JOIN with derived tables")
}

if requiresSwitchingSides(ctx, lhs) {
return nil, nil, vterrors.VT12001("JOIN between derived tables")
}

join := NewApplyJoin(Clone(rhs), Clone(lhs), nil, !inner)
newOp, err := pushJoinPredicates(ctx, joinPredicates, join)
if err != nil {
return nil, nil, err
message := "logical join to applyJoin"
if needsLimit(rhs) {
if needsLimit(lhs) || !inner {
return nil, nil, vterrors.VT12001("can't handle join with limit on the RHS")
}
return newOp, rewrite.NewTree("logical join to applyJoin, switching side because derived table", newOp), nil
lhs, rhs = rhs, lhs
message += ", switch sides because limit"
}

join := NewApplyJoin(Clone(lhs), Clone(rhs), nil, !inner)
newOp, err := pushJoinPredicates(ctx, joinPredicates, join)
if err != nil {
return nil, nil, err
}
return newOp, rewrite.NewTree("logical join to applyJoin ", newOp), nil
return newOp, rewrite.NewTree(message, newOp), nil
}

func operatorsToRoutes(a, b ops.Operator) (*Route, *Route) {
Expand Down
Loading
Loading