diff --git a/go/vt/vtgate/planbuilder/operators/apply_join.go b/go/vt/vtgate/planbuilder/operators/apply_join.go index 138c17f2da7..6621a9d00eb 100644 --- a/go/vt/vtgate/planbuilder/operators/apply_join.go +++ b/go/vt/vtgate/planbuilder/operators/apply_join.go @@ -382,6 +382,30 @@ func (jc JoinColumn) IsPureRight() bool { return len(jc.LHSExprs) == 0 } +func (jc JoinColumn) Map(f func(sqlparser.Expr) (sqlparser.Expr, error)) (JoinColumn, error) { + var err error + if jc.Original != nil { + jc.Original.Expr, err = f(jc.Original.Expr) + if err != nil { + return JoinColumn{}, err + } + } + + jc.RHSExpr, err = f(jc.RHSExpr) + if err != nil { + return JoinColumn{}, err + } + + for i, expr := range jc.LHSExprs { + jc.LHSExprs[i].Expr, err = f(expr.Expr) + if err != nil { + return JoinColumn{}, err + } + } + + return jc, nil +} + func (jc JoinColumn) IsMixedLeftAndRight() bool { return len(jc.LHSExprs) > 0 && jc.RHSExpr != nil } diff --git a/go/vt/vtgate/planbuilder/operators/projection.go b/go/vt/vtgate/planbuilder/operators/projection.go index 2d4630bd87a..2721838c644 100644 --- a/go/vt/vtgate/planbuilder/operators/projection.go +++ b/go/vt/vtgate/planbuilder/operators/projection.go @@ -66,6 +66,16 @@ func (dt *DerivedTable) RewriteExpression(ctx *plancontext.PlanningContext, expr } return semantics.RewriteDerivedTableExpression(expr, tableInfo) } +func (dt *DerivedTable) RewriteExpression2(ctx *plancontext.PlanningContext, expr sqlparser.Expr) (sqlparser.Expr, error) { + if dt == nil { + return expr, nil + } + tableInfo, err := ctx.SemTable.TableInfoFor(dt.TableID) + if err != nil { + return nil, err + } + return semantics.ExposeExpressionThroughDerived(expr, tableInfo), nil +} func (dt *DerivedTable) introducesTableID() semantics.TableSet { if dt == nil { diff --git a/go/vt/vtgate/planbuilder/operators/query_planning.go b/go/vt/vtgate/planbuilder/operators/query_planning.go index 4554b09fcb7..dfa35cf7c49 100644 --- a/go/vt/vtgate/planbuilder/operators/query_planning.go +++ b/go/vt/vtgate/planbuilder/operators/query_planning.go @@ -19,7 +19,11 @@ package operators import ( "fmt" "io" + "math/rand" + "time" + "unsafe" + "vitess.io/vitess/go/test/dbg" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/ops" "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/rewrite" @@ -291,30 +295,25 @@ func pushProjectionInApplyJoin( src *ApplyJoin, ) (ops.Operator, *rewrite.ApplyResult, error) { ap, err := p.GetAliasedProjections() - if src.LeftJoin || err != nil { + if err != nil { // we can't push down expression evaluation to the rhs if we are not sure if it will even be executed return p, rewrite.SameTree, nil } - lhs, rhs := &projector{}, &projector{} - if p.DT != nil && len(p.DT.Columns) > 0 { - lhs.explicitColumnAliases = true - rhs.explicitColumnAliases = true - } + + explicitColumnAlias := p.DT != nil && len(p.DT.Columns) > 0 + lhs := &projector{explicitColumnAliases: explicitColumnAlias} + rhs := &projector{explicitColumnAliases: explicitColumnAlias} src.JoinColumns = nil for idx, pe := range ap { - var col *sqlparser.IdentifierCI - if p.DT != nil && idx < len(p.DT.Columns) { - col = &p.DT.Columns[idx] - } - err := splitProjectionAcrossJoin(ctx, src, lhs, rhs, pe, col) + err := splitProjectionAcrossJoin(ctx, src, lhs, rhs, pe, p.aliasFor(idx)) if err != nil { return nil, nil, err } } if p.isDerived() { - err := exposeColumnsThroughDerivedTable(ctx, p, src, lhs) + err := exposeColumnsThroughDerivedTable(ctx, p, src, lhs, rhs) if err != nil { return nil, nil, err } @@ -334,6 +333,14 @@ func pushProjectionInApplyJoin( return src, rewrite.NewTree("split projection to either side of join", src), nil } +func (p *Projection) aliasFor(idx int) *sqlparser.IdentifierCI { + if p.DT == nil || idx >= len(p.DT.Columns) { + return nil + } + + return &p.DT.Columns[idx] +} + // splitProjectionAcrossJoin creates JoinPredicates for all projections, // and pushes down columns as needed between the LHS and RHS of a join func splitProjectionAcrossJoin( @@ -408,45 +415,100 @@ func splitUnexploredExpression( // The function iterates through each join predicate, rewriting the expressions in the predicate's // LHS expressions to include the derived table. This allows the expressions to be accessed outside // the derived table. -func exposeColumnsThroughDerivedTable(ctx *plancontext.PlanningContext, p *Projection, src *ApplyJoin, lhs *projector) error { - derivedTbl, err := ctx.SemTable.TableInfoFor(p.DT.TableID) - if err != nil { - return err - } - derivedTblName, err := derivedTbl.Name() - if err != nil { - return err +func exposeColumnsThroughDerivedTable(ctx *plancontext.PlanningContext, p *Projection, src *ApplyJoin, lhs *projector, rhs *projector) error { + if p.DT == nil { + return nil } - for _, predicate := range src.JoinPredicates { - for idx, bve := range predicate.LHSExprs { - expr := bve.Expr - tbl, err := ctx.SemTable.TableInfoForExpr(expr) - if err != nil { - return err + + cols := p.Columns.GetColumns() + f := func(expr sqlparser.Expr) (e sqlparser.Expr, err error) { + rewriter := func(cursor *sqlparser.CopyOnWriteCursor) { + this, ok := cursor.Node().(sqlparser.Expr) + if !ok { + return } - tblExpr := tbl.GetExpr() - tblName, err := tblExpr.TableName() - if err != nil { - return err + + // If we didn't find it, and we are dealing with a ColName, we have to add it to the derived table + colExpr, ok := this.(*sqlparser.ColName) + if !ok { + return } - expr = semantics.RewriteDerivedTableExpression(expr, derivedTbl) - out := prefixColNames(tblName, expr) + // First we check if this expression is already being returned + for _, column := range cols { + if ctx.SemTable.EqualsExprWithDeps(column.Expr, colExpr) { + col := sqlparser.NewColName(column.ColumnName()) + cursor.Replace(col) + return + } + } + + colAlias := fmt.Sprintf("%s_vt_%s_%s", p.DT.Alias, sqlparser.String(colExpr), RandString(2)) - alias := sqlparser.UnescapedString(out) - predicate.LHSExprs[idx].Expr = sqlparser.NewColNameWithQualifier(alias, derivedTblName) - identifierCI := sqlparser.NewIdentifierCI(alias) - projExpr := newProjExprWithInner(&sqlparser.AliasedExpr{Expr: out, As: identifierCI}, out) - var colAlias *sqlparser.IdentifierCI - if lhs.explicitColumnAliases { - colAlias = &identifierCI + newCol := sqlparser.NewColName(colAlias) + inner := newProjExprWithInner(aeWrap(newCol), colExpr) + _, thisErr := p.addProjExpr(inner) + if thisErr != nil { + err = thisErr + cursor.StopTreeWalk() + return } - lhs.add(projExpr, colAlias) + cursor.Replace(newCol) + } + + e = sqlparser.CopyOnRewrite(expr, nil, rewriter, ctx.SemTable.CopySemanticInfo).(sqlparser.Expr) + if expr == e { + panic(dbg.S()) + } + + return e, nil + } + + //for i, pred := range src.JoinPredicates { + // x, err := pred.Map(f) + // if err != nil { + // return err + // } + // src.JoinPredicates[i] = x + //} + + for i, col := range src.JoinColumns { + var err error + src.JoinColumns[i], err = col.Map(f) + if err != nil { + return err } } return nil } +var src = rand.NewSource(time.Now().UnixNano()) + +const letterBytes = "_abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" +const ( + letterIdxBits = len(letterBytes) / 8 // bits to represent a letter index + letterIdxMask = 1<= 0; { + if remain == 0 { + cache, remain = src.Int63(), letterIdxMax + } + if idx := int(cache & letterIdxMask); idx < len(letterBytes) { + b[i] = letterBytes[idx] + i-- + } + cache >>= letterIdxBits + remain-- + } + + return *(*string)(unsafe.Pointer(&b)) +} + // prefixColNames adds qualifier prefixes to all ColName:s. // We want to be more explicit than the user was to make sure we never produce invalid SQL func prefixColNames(tblName sqlparser.TableName, e sqlparser.Expr) sqlparser.Expr { diff --git a/go/vt/vtgate/planbuilder/operators/route_planning.go b/go/vt/vtgate/planbuilder/operators/route_planning.go index 6ecfce5bf07..d8e284fc9da 100644 --- a/go/vt/vtgate/planbuilder/operators/route_planning.go +++ b/go/vt/vtgate/planbuilder/operators/route_planning.go @@ -342,23 +342,19 @@ func getJoinFor(ctx *plancontext.PlanningContext, cm opCacheMap, lhs, rhs ops.Op return join, nil } -// requiresSwitchingSides will return true if any of the operators with the root from the given operator tree -// is of the type that should not be on the RHS of a join -func requiresSwitchingSides(ctx *plancontext.PlanningContext, op ops.Operator) bool { - required := false - +// needsLimit will return true this op tree requires LIMIT +func needsLimit(op ops.Operator) (required bool) { _ = rewrite.Visit(op, func(current ops.Operator) error { horizon, isHorizon := current.(*Horizon) - if isHorizon && horizon.IsDerived() && !horizon.IsMergeable(ctx) { + if isHorizon && horizon.Query.GetLimit() != nil { required = true return io.EOF } return nil }) - - return required + return } func mergeOrJoin(ctx *plancontext.PlanningContext, lhs, rhs ops.Operator, joinPredicates []sqlparser.Expr, inner bool) (ops.Operator, *rewrite.ApplyResult, error) { @@ -367,21 +363,13 @@ func mergeOrJoin(ctx *plancontext.PlanningContext, lhs, rhs ops.Operator, joinPr return newPlan, rewrite.NewTree("merge routes into single operator", newPlan), nil } - if len(joinPredicates) > 0 && requiresSwitchingSides(ctx, rhs) { - if !inner { - return nil, nil, vterrors.VT12001("LEFT JOIN with derived tables") - } - - if requiresSwitchingSides(ctx, lhs) { - return nil, nil, vterrors.VT12001("JOIN between derived tables") - } - - join := NewApplyJoin(Clone(rhs), Clone(lhs), nil, !inner) - newOp, err := pushJoinPredicates(ctx, joinPredicates, join) - if err != nil { - return nil, nil, err + message := "logical join to applyJoin" + if needsLimit(rhs) { + if needsLimit(lhs) || !inner { + return nil, nil, vterrors.VT12001("can't handle join with limit on the RHS") } - return newOp, rewrite.NewTree("logical join to applyJoin, switching side because derived table", newOp), nil + lhs, rhs = rhs, lhs + message += ", switch sides because limit" } join := NewApplyJoin(Clone(lhs), Clone(rhs), nil, !inner) @@ -389,7 +377,7 @@ func mergeOrJoin(ctx *plancontext.PlanningContext, lhs, rhs ops.Operator, joinPr if err != nil { return nil, nil, err } - return newOp, rewrite.NewTree("logical join to applyJoin ", newOp), nil + return newOp, rewrite.NewTree(message, newOp), nil } func operatorsToRoutes(a, b ops.Operator) (*Route, *Route) { diff --git a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json index d5bd132cfaa..5830086f15b 100644 --- a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json @@ -3531,8 +3531,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select `user`.id from `user` where 1 != 1", - "Query": "select `user`.id from `user`", + "FieldQuery": "select id from `user` where 1 != 1", + "Query": "select id from `user`", "Table": "`user`" }, { @@ -3542,8 +3542,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select user_extra.col from user_extra where 1 != 1", - "Query": "select user_extra.col from user_extra where user_extra.id = :user_id", + "FieldQuery": "select user_extra.col from (select user_extra.col as col from user_extra where 1 != 1) as x where 1 != 1", + "Query": "select user_extra.col from (select user_extra.col as col from user_extra where user_extra.id = :user_id) as x", "Table": "user_extra" } ] @@ -5507,64 +5507,46 @@ "GroupBy": "0", "Inputs": [ { - "OperatorType": "Join", - "Variant": "Join", - "JoinColumnIndexes": "L:0", - "JoinVars": { - "d_id": 1 - }, - "TableName": "`user`_music", + "OperatorType": "Sort", + "Variant": "Memory", + "OrderBy": "0 ASC", "Inputs": [ { - "OperatorType": "Aggregate", - "Variant": "Ordered", - "GroupBy": "0, (1|2)", + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "R:0", + "JoinVars": { + "music_user_id": 0 + }, + "TableName": "music_`user`", "Inputs": [ { - "OperatorType": "SimpleProjection", - "Columns": [ - 1, - 0, - 2 + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select music.user_id from music where 1 != 1 group by music.user_id", + "Query": "select music.user_id from music group by music.user_id", + "Table": "music" + }, + { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select d.a from (select id, count(*) as a from `user` where 1 != 1) as d where 1 != 1 group by d.a", + "Query": "select d.a from (select id, count(*) as a from `user` where id = :music_user_id) as d group by d.a", + "Table": "`user`", + "Values": [ + ":music_user_id" ], - "Inputs": [ - { - "OperatorType": "Aggregate", - "Variant": "Scalar", - "Aggregates": "any_value(0) AS id, sum_count_star(1) AS a, any_value(2)", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select id, count(*) as a, weight_string(id) from `user` where 1 != 1", - "OrderBy": "1 ASC, (0|2) ASC", - "Query": "select id, count(*) as a, weight_string(id) from `user` order by count(*) asc, id asc", - "Table": "`user`" - } - ] - } - ] + "Vindex": "user_index" } ] - }, - { - "OperatorType": "Route", - "Variant": "EqualUnique", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select 1 from music where 1 != 1 group by .0", - "Query": "select 1 from music where music.user_id = :d_id group by .0", - "Table": "music", - "Values": [ - ":d_id" - ], - "Vindex": "user_index" } ] } diff --git a/go/vt/vtgate/planbuilder/testdata/from_cases.json b/go/vt/vtgate/planbuilder/testdata/from_cases.json index 155a8042fe9..78f26f018c4 100644 --- a/go/vt/vtgate/planbuilder/testdata/from_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/from_cases.json @@ -1890,8 +1890,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select t.id, t.`user.col` from (select `user`.id, `user`.col1, `user`.col as `user.col` from `user` where 1 != 1) as t where 1 != 1", - "Query": "select t.id, t.`user.col` from (select `user`.id, `user`.col1, `user`.col as `user.col` from `user`) as t", + "FieldQuery": "select t.id, col from (select `user`.id, `user`.col1 from `user` where 1 != 1) as t where 1 != 1", + "Query": "select t.id, col from (select `user`.id, `user`.col1 from `user`) as t", "Table": "`user`" }, { @@ -3534,43 +3534,35 @@ "QueryType": "SELECT", "Original": "select user_extra.col+1 from user left join user_extra on user.col = user_extra.col", "Instructions": { - "OperatorType": "Projection", - "Expressions": [ - "[COLUMN 0] + INT64(1) as user_extra.col + 1" - ], + "OperatorType": "Join", + "Variant": "LeftJoin", + "JoinColumnIndexes": "R:0", + "JoinVars": { + "user_col": 0 + }, + "TableName": "`user`_user_extra", "Inputs": [ { - "OperatorType": "Join", - "Variant": "LeftJoin", - "JoinColumnIndexes": "R:0", - "JoinVars": { - "user_col": 0 + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true }, - "TableName": "`user`_user_extra", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select `user`.col from `user` where 1 != 1", - "Query": "select `user`.col from `user`", - "Table": "`user`" - }, - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select user_extra.col from user_extra where 1 != 1", - "Query": "select user_extra.col from user_extra where user_extra.col = :user_col", - "Table": "user_extra" - } - ] + "FieldQuery": "select `user`.col from `user` where 1 != 1", + "Query": "select `user`.col from `user`", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select user_extra.col + 1 from user_extra where 1 != 1", + "Query": "select user_extra.col + 1 from user_extra where user_extra.col = :user_col", + "Table": "user_extra" } ] }, @@ -3593,44 +3585,35 @@ "TableName": "`user`_user_extra_user_extra", "Inputs": [ { - "OperatorType": "Projection", - "Expressions": [ - "[COLUMN 0] as id", - "[COLUMN 1] + INT64(1) as user_extra.col + 1" - ], + "OperatorType": "Join", + "Variant": "LeftJoin", + "JoinColumnIndexes": "L:0,R:0", + "JoinVars": { + "user_col": 1 + }, + "TableName": "`user`_user_extra", "Inputs": [ { - "OperatorType": "Join", - "Variant": "LeftJoin", - "JoinColumnIndexes": "L:0,R:0", - "JoinVars": { - "user_col": 1 + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true }, - "TableName": "`user`_user_extra", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select `user`.id, `user`.col from `user` where 1 != 1", - "Query": "select `user`.id, `user`.col from `user`", - "Table": "`user`" - }, - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select user_extra.col from user_extra where 1 != 1", - "Query": "select user_extra.col from user_extra where user_extra.col = :user_col", - "Table": "user_extra" - } - ] + "FieldQuery": "select `user`.id, `user`.col from `user` where 1 != 1", + "Query": "select `user`.id, `user`.col from `user`", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select user_extra.col + 1 from user_extra where 1 != 1", + "Query": "select user_extra.col + 1 from user_extra where user_extra.col = :user_col", + "Table": "user_extra" } ] }, @@ -3660,43 +3643,36 @@ "QueryType": "SELECT", "Original": "select user.foo+user_extra.col+1 from user left join user_extra on user.col = user_extra.col", "Instructions": { - "OperatorType": "Projection", - "Expressions": [ - "([COLUMN 0] + [COLUMN 1]) + INT64(1) as `user`.foo + user_extra.col + 1" - ], + "OperatorType": "Join", + "Variant": "LeftJoin", + "JoinColumnIndexes": "R:0", + "JoinVars": { + "user_col": 1, + "user_foo": 0 + }, + "TableName": "`user`_user_extra", "Inputs": [ { - "OperatorType": "Join", - "Variant": "LeftJoin", - "JoinColumnIndexes": "L:0,R:0", - "JoinVars": { - "user_col": 1 + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true }, - "TableName": "`user`_user_extra", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select `user`.foo, `user`.col from `user` where 1 != 1", - "Query": "select `user`.foo, `user`.col from `user`", - "Table": "`user`" - }, - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select user_extra.col from user_extra where 1 != 1", - "Query": "select user_extra.col from user_extra where user_extra.col = :user_col", - "Table": "user_extra" - } - ] + "FieldQuery": "select `user`.foo, `user`.col from `user` where 1 != 1", + "Query": "select `user`.foo, `user`.col from `user`", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select :user_foo + user_extra.col + 1 as ```user``.foo + user_extra.col + 1` from user_extra where 1 != 1", + "Query": "select :user_foo + user_extra.col + 1 as ```user``.foo + user_extra.col + 1` from user_extra where user_extra.col = :user_col", + "Table": "user_extra" } ] }, diff --git a/go/vt/vtgate/planbuilder/testdata/onecase.json b/go/vt/vtgate/planbuilder/testdata/onecase.json index da7543f706a..e0893e6b907 100644 --- a/go/vt/vtgate/planbuilder/testdata/onecase.json +++ b/go/vt/vtgate/planbuilder/testdata/onecase.json @@ -1,9 +1,60 @@ [ { - "comment": "Add your test case here for debugging and run go test -run=One.", - "query": "", + "comment": "count non-null columns incoming from outer joins should work well", + "query": "select count(col+42) from (select user_extra.col as col from user left join user_extra on user.id = user_extra.id limit 10) as x", "plan": { - + "QueryType": "SELECT", + "Original": "select count(col+42) from (select user_extra.col as col from user left join user_extra on user.id = user_extra.id limit 10) as x", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "count(0) AS count(col)", + "Inputs": [ + { + "OperatorType": "Limit", + "Count": "INT64(10)", + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "LeftJoin", + "JoinColumnIndexes": "R:0", + "JoinVars": { + "user_id": 0 + }, + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select id from `user` where 1 != 1", + "Query": "select id from `user`", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select user_extra.col from (select user_extra.col as col from user_extra where 1 != 1) as x where 1 != 1", + "Query": "select x.col from (select user_extra.col as col from user_extra where user_extra.id = :user_id) as x", + "Table": "user_extra" + } + ] + } + ] + } + ] + }, + "TablesUsed": [ + "user.user", + "user.user_extra" + ] } } ] \ No newline at end of file diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.json b/go/vt/vtgate/planbuilder/testdata/select_cases.json index 80d9cdf6d23..f7001b176f2 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.json @@ -4261,17 +4261,27 @@ "Instructions": { "OperatorType": "Join", "Variant": "Join", - "JoinColumnIndexes": "R:0,L:0", + "JoinColumnIndexes": "L:0,R:0", "JoinVars": { - "t_id": 1 + "user_id": 1 }, - "TableName": "user_extra_`user`", + "TableName": "`user`_user_extra", "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `user`.a, `user`.id from `user` where 1 != 1", + "Query": "select `user`.a, `user`.id from `user`", + "Table": "`user`" + }, { "OperatorType": "SimpleProjection", "Columns": [ - 1, - 0 + 1 ], "Inputs": [ { @@ -4289,27 +4299,12 @@ }, "FieldQuery": "select id, count(*) as b, req, weight_string(req), weight_string(id) from user_extra where 1 != 1 group by req, id, weight_string(req), weight_string(id)", "OrderBy": "(2|3) ASC, (0|4) ASC", - "Query": "select id, count(*) as b, req, weight_string(req), weight_string(id) from user_extra group by req, id, weight_string(req), weight_string(id) order by req asc, id asc", + "Query": "select id, count(*) as b, req, weight_string(req), weight_string(id) from user_extra where id = :user_id group by req, id, weight_string(req), weight_string(id) order by req asc, id asc", "Table": "user_extra" } ] } ] - }, - { - "OperatorType": "Route", - "Variant": "EqualUnique", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select `user`.a from `user` where 1 != 1", - "Query": "select `user`.a from `user` where `user`.id = :t_id", - "Table": "`user`", - "Values": [ - ":t_id" - ], - "Vindex": "user_index" } ] }, diff --git a/go/vt/vtgate/planbuilder/testdata/tpch_cases.json b/go/vt/vtgate/planbuilder/testdata/tpch_cases.json index 947f9cc0f96..72a237a7924 100644 --- a/go/vt/vtgate/planbuilder/testdata/tpch_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/tpch_cases.json @@ -548,8 +548,8 @@ "[COLUMN 0] * [COLUMN 1] as revenue", "[COLUMN 2] as supp_nation", "[COLUMN 3] as l_year", - "[COLUMN 4] as orders.o_custkey", - "[COLUMN 5] as n1.n_name", + "[COLUMN 4] as o_custkey", + "[COLUMN 5] as n_name", "[COLUMN 6] as weight_string(supp_nation)", "[COLUMN 7] as weight_string(l_year)" ], @@ -568,9 +568,9 @@ "Expressions": [ "[COLUMN 0] * [COLUMN 1] as revenue", "[COLUMN 2] as l_year", - "[COLUMN 3] as orders.o_custkey", - "[COLUMN 4] as n1.n_name", - "[COLUMN 5] as lineitem.l_suppkey", + "[COLUMN 3] as o_custkey", + "[COLUMN 4] as n_name", + "[COLUMN 5] as l_suppkey", "[COLUMN 6] as weight_string(l_year)" ], "Inputs": [ @@ -590,9 +590,9 @@ "Name": "main", "Sharded": true }, - "FieldQuery": "select sum(volume) as revenue, l_year, shipping.`orders.o_custkey`, shipping.`n1.n_name`, shipping.`lineitem.l_suppkey`, shipping.`lineitem.l_orderkey`, weight_string(l_year), supp_nation, weight_string(supp_nation), cust_nation, weight_string(cust_nation) from (select extract(year from l_shipdate) as l_year, l_extendedprice * (1 - l_discount) as volume, orders.o_custkey as `orders.o_custkey`, lineitem.l_suppkey as `lineitem.l_suppkey`, lineitem.l_orderkey as `lineitem.l_orderkey` from lineitem where 1 != 1) as shipping where 1 != 1 group by l_year, shipping.`orders.o_custkey`, shipping.`n1.n_name`, shipping.`lineitem.l_suppkey`, shipping.`lineitem.l_orderkey`, weight_string(l_year)", + "FieldQuery": "select sum(volume) as revenue, l_year, o_custkey, n_name, l_suppkey, l_orderkey, weight_string(l_year), supp_nation, weight_string(supp_nation), cust_nation, weight_string(cust_nation) from (select extract(year from l_shipdate) as l_year, l_extendedprice * (1 - l_discount) as volume from lineitem where 1 != 1) as shipping where 1 != 1 group by l_year, o_custkey, n_name, l_suppkey, l_orderkey, weight_string(l_year)", "OrderBy": "(7|8) ASC, (9|10) ASC, (1|6) ASC", - "Query": "select sum(volume) as revenue, l_year, shipping.`orders.o_custkey`, shipping.`n1.n_name`, shipping.`lineitem.l_suppkey`, shipping.`lineitem.l_orderkey`, weight_string(l_year), supp_nation, weight_string(supp_nation), cust_nation, weight_string(cust_nation) from (select extract(year from l_shipdate) as l_year, l_extendedprice * (1 - l_discount) as volume, orders.o_custkey as `orders.o_custkey`, lineitem.l_suppkey as `lineitem.l_suppkey`, lineitem.l_orderkey as `lineitem.l_orderkey` from lineitem where l_shipdate between date('1995-01-01') and date('1996-12-31')) as shipping group by l_year, shipping.`orders.o_custkey`, shipping.`n1.n_name`, shipping.`lineitem.l_suppkey`, shipping.`lineitem.l_orderkey`, weight_string(l_year) order by supp_nation asc, cust_nation asc, l_year asc", + "Query": "select sum(volume) as revenue, l_year, o_custkey, n_name, l_suppkey, l_orderkey, weight_string(l_year), supp_nation, weight_string(supp_nation), cust_nation, weight_string(cust_nation) from (select extract(year from l_shipdate) as l_year, l_extendedprice * (1 - l_discount) as volume from lineitem where l_shipdate between date('1995-01-01') and date('1996-12-31')) as shipping group by l_year, o_custkey, n_name, l_suppkey, l_orderkey, weight_string(l_year) order by supp_nation asc, cust_nation asc, l_year asc", "Table": "lineitem" }, { @@ -638,8 +638,8 @@ "Name": "main", "Sharded": true }, - "FieldQuery": "select count(*), shipping.`supplier.s_nationkey` from (select supplier.s_nationkey as `supplier.s_nationkey` from supplier where 1 != 1) as shipping where 1 != 1 group by shipping.`supplier.s_nationkey`", - "Query": "select count(*), shipping.`supplier.s_nationkey` from (select supplier.s_nationkey as `supplier.s_nationkey` from supplier where s_suppkey = :l_suppkey) as shipping group by shipping.`supplier.s_nationkey`", + "FieldQuery": "select count(*), s_nationkey from supplier where 1 != 1 group by s_nationkey", + "Query": "select count(*), s_nationkey from supplier where s_suppkey = :l_suppkey group by s_nationkey", "Table": "supplier", "Values": [ ":l_suppkey" @@ -693,8 +693,8 @@ "Name": "main", "Sharded": true }, - "FieldQuery": "select count(*), shipping.`customer.c_nationkey` from (select customer.c_nationkey as `customer.c_nationkey` from customer where 1 != 1) as shipping where 1 != 1 group by shipping.`customer.c_nationkey`", - "Query": "select count(*), shipping.`customer.c_nationkey` from (select customer.c_nationkey as `customer.c_nationkey` from customer where c_custkey = :o_custkey) as shipping group by shipping.`customer.c_nationkey`", + "FieldQuery": "select count(*), c_nationkey from customer where 1 != 1 group by c_nationkey", + "Query": "select count(*), c_nationkey from customer where c_custkey = :o_custkey group by c_nationkey", "Table": "customer", "Values": [ ":o_custkey" diff --git a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json index ea4383db911..67437ffbe3a 100644 --- a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json @@ -347,12 +347,12 @@ { "comment": "cant switch sides for outer joins", "query": "select id from user left join (select user_id from user_extra limit 10) ue on user.id = ue.user_id", - "plan": "VT12001: unsupported: LEFT JOIN with derived tables" + "plan": "VT13001: [BUG] can't handle join with limit on the RHS" }, { "comment": "limit on both sides means that we can't evaluate this at all", "query": "select id from (select id from user limit 10) u join (select user_id from user_extra limit 10) ue on u.id = ue.user_id", - "plan": "VT12001: unsupported: JOIN between derived tables" + "plan": "VT13001: [BUG] can't handle join with limit on the RHS" }, { "comment": "multi-shard union", diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go index 0af935918f9..615174e4f5f 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_state.go @@ -418,6 +418,30 @@ func RewriteDerivedTableExpression(expr sqlparser.Expr, vt TableInfo) sqlparser. }, nil).(sqlparser.Expr) } +// RewriteDerivedTableExpression rewrites all the ColName instances in the supplied expression with +// the expressions behind the column definition of the derived table +// SELECT foo FROM (SELECT id+42 as foo FROM user) as t +// We need `foo` to be translated to `id+42` on the inside of the derived table +func ExposeExpressionThroughDerived(expr sqlparser.Expr, vt TableInfo) sqlparser.Expr { + return sqlparser.CopyOnRewrite(expr, nil, func(cursor *sqlparser.CopyOnWriteCursor) { + node, ok := cursor.Node().(*sqlparser.ColName) + if !ok { + return + } + exp, err := vt.getExprFor(node.Name.String()) + if err == nil { + cursor.Replace(exp) + return + } + + // cloning the expression and removing the qualifier + col := *node + col.Qualifier = sqlparser.TableName{} + cursor.Replace(&col) + + }, nil).(sqlparser.Expr) +} + // CopyExprInfo lookups src in the ExprTypes map and, if a key is found, assign // the corresponding Type value of src to dest. func (st *SemTable) CopyExprInfo(src, dest sqlparser.Expr) {