diff --git a/go/vt/vtgate/planbuilder/operators/ast_to_op.go b/go/vt/vtgate/planbuilder/operators/ast_to_op.go index aa6495eec9e..8f6f06132b5 100644 --- a/go/vt/vtgate/planbuilder/operators/ast_to_op.go +++ b/go/vt/vtgate/planbuilder/operators/ast_to_op.go @@ -108,26 +108,6 @@ func findTablesContained(ctx *plancontext.PlanningContext, node sqlparser.SQLNod return } -func checkForCorrelatedSubqueries( - ctx *plancontext.PlanningContext, - stmt sqlparser.SelectStatement, - subqID semantics.TableSet, -) (correlated bool) { - _ = sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) { - colname, isColname := node.(*sqlparser.ColName) - if !isColname { - return true, nil - } - deps := ctx.SemTable.RecursiveDeps(colname) - if deps.IsSolvedBy(subqID) { - return true, nil - } - correlated = true - return false, nil - }, stmt) - return correlated -} - // joinPredicateCollector is used to inspect the predicates inside the subquery, looking for any // comparisons between the inner and the outer side. // They can be used for merging the two parts of the query together diff --git a/go/vt/vtgate/planbuilder/operators/horizon_expanding.go b/go/vt/vtgate/planbuilder/operators/horizon_expanding.go index 4d0a15db9b8..68880bef90b 100644 --- a/go/vt/vtgate/planbuilder/operators/horizon_expanding.go +++ b/go/vt/vtgate/planbuilder/operators/horizon_expanding.go @@ -98,10 +98,7 @@ func expandSelectHorizon(ctx *plancontext.PlanningContext, horizon *Horizon, sel } if len(qp.OrderExprs) > 0 { - op = &Ordering{ - Source: op, - Order: qp.OrderExprs, - } + op = expandOrderBy(ctx, op, qp) extracted = append(extracted, "Ordering") } @@ -116,6 +113,40 @@ func expandSelectHorizon(ctx *plancontext.PlanningContext, horizon *Horizon, sel return op, Rewrote(fmt.Sprintf("expand SELECT horizon into (%s)", strings.Join(extracted, ", "))) } +func expandOrderBy(ctx *plancontext.PlanningContext, op Operator, qp *QueryProjection) Operator { + proj := newAliasedProjection(op) + var newOrder []OrderBy + sqc := &SubQueryBuilder{} + for _, expr := range qp.OrderExprs { + newExpr, subqs := sqc.pullOutValueSubqueries(ctx, expr.SimplifiedExpr, TableID(op), false) + if newExpr == nil { + // no subqueries found, let's move on + newOrder = append(newOrder, expr) + continue + } + proj.addSubqueryExpr(aeWrap(newExpr), newExpr, subqs...) + newOrder = append(newOrder, OrderBy{ + Inner: &sqlparser.Order{ + Expr: newExpr, + Direction: expr.Inner.Direction, + }, + SimplifiedExpr: newExpr, + }) + + } + + if len(proj.Columns.GetColumns()) > 0 { + // if we had to project columns for the ordering, + // we need the projection as source + op = proj + } + + return &Ordering{ + Source: op, + Order: newOrder, + } +} + func createProjectionFromSelect(ctx *plancontext.PlanningContext, horizon *Horizon) Operator { qp := horizon.getQP(ctx) @@ -242,7 +273,7 @@ func createProjectionWithoutAggr(ctx *plancontext.PlanningContext, qp *QueryProj sqc := &SubQueryBuilder{} outerID := TableID(src) for _, ae := range aes { - org := sqlparser.CloneRefOfAliasedExpr(ae) + org := ctx.SemTable.Clone(ae).(*sqlparser.AliasedExpr) expr := ae.Expr newExpr, subqs := sqc.pullOutValueSubqueries(ctx, expr, outerID, false) if newExpr == nil { diff --git a/go/vt/vtgate/planbuilder/operators/queryprojection.go b/go/vt/vtgate/planbuilder/operators/queryprojection.go index fdb323be005..14bea4f4674 100644 --- a/go/vt/vtgate/planbuilder/operators/queryprojection.go +++ b/go/vt/vtgate/planbuilder/operators/queryprojection.go @@ -322,7 +322,7 @@ func (qp *QueryProjection) addOrderBy(ctx *plancontext.PlanningContext, orderBy continue } qp.OrderExprs = append(qp.OrderExprs, OrderBy{ - Inner: sqlparser.CloneRefOfOrder(order), + Inner: ctx.SemTable.Clone(order).(*sqlparser.Order), SimplifiedExpr: order.Expr, }) canPushSorting = canPushSorting && !containsAggr(order.Expr) diff --git a/go/vt/vtgate/planbuilder/operators/subquery.go b/go/vt/vtgate/planbuilder/operators/subquery.go index 2bcf1e97f74..0765a878a3e 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery.go +++ b/go/vt/vtgate/planbuilder/operators/subquery.go @@ -203,7 +203,7 @@ func (sq *SubQuery) settle(ctx *plancontext.PlanningContext, outer Operator) Ope if !sq.TopLevel { panic(subqueryNotAtTopErr) } - if sq.correlated { + if sq.correlated && sq.FilterType != opcode.PulloutExists { panic(correlatedSubqueryErr) } if sq.IsProjection { diff --git a/go/vt/vtgate/planbuilder/operators/subquery_builder.go b/go/vt/vtgate/planbuilder/operators/subquery_builder.go index f69de0dedc4..4caf3530075 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery_builder.go +++ b/go/vt/vtgate/planbuilder/operators/subquery_builder.go @@ -169,7 +169,7 @@ func createSubquery( sqc := &SubQueryBuilder{totalID: totalID, subqID: subqID, outerID: outerID} predicates, joinCols := sqc.inspectStatement(ctx, subq.Select) - correlated := checkForCorrelatedSubqueries(ctx, subq.Select, subqID) + correlated := !ctx.SemTable.RecursiveDeps(subq).IsEmpty() opInner := translateQueryToOp(ctx, subq.Select) diff --git a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json index 1484915ea6e..4a1c8fa1559 100644 --- a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json @@ -6942,7 +6942,7 @@ "Sharded": true }, "FieldQuery": "select id, from_unixtime(min(col)) as col, min(col) from `user` where 1 != 1 group by id", - "OrderBy": "2 ASC", + "OrderBy": "2 ASC COLLATE utf8mb4_0900_ai_ci", "Query": "select id, from_unixtime(min(col)) as col, min(col) from `user` group by id order by min(col) asc", "ResultColumns": 2, "Table": "`user`" diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.json b/go/vt/vtgate/planbuilder/testdata/select_cases.json index db8230d8f7b..039093cd0c7 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.json @@ -2023,7 +2023,73 @@ { "comment": "select (select col from user limit 1) as a from user join user_extra order by a", "query": "select (select col from user limit 1) as a from user join user_extra order by a", - "plan": "VT12001: unsupported: correlated subquery is only supported for EXISTS" + "plan": { + "QueryType": "SELECT", + "Original": "select (select col from user limit 1) as a from user join user_extra order by a", + "Instructions": { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0", + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "UncorrelatedSubquery", + "Variant": "PulloutValue", + "PulloutVars": [ + "__sq1" + ], + "Inputs": [ + { + "InputName": "SubQuery", + "OperatorType": "Limit", + "Count": "1", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select col from `user` where 1 != 1", + "Query": "select col from `user` limit :__upper_limit", + "Table": "`user`" + } + ] + }, + { + "InputName": "Outer", + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select :__sq1 as __sq1, weight_string(:__sq1) from `user` where 1 != 1", + "OrderBy": "(0|1) ASC", + "Query": "select :__sq1 as __sq1, weight_string(:__sq1) from `user` order by __sq1 asc", + "Table": "`user`" + } + ] + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select 1 from user_extra where 1 != 1", + "Query": "select 1 from user_extra", + "Table": "user_extra" + } + ] + }, + "TablesUsed": [ + "user.user", + "user.user_extra" + ] + } }, { "comment": "select t.a from (select (select col from user limit 1) as a from user join user_extra) t", diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go index 4a2b70fe81a..5b4ff4f69de 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_state.go @@ -902,3 +902,13 @@ func (st *SemTable) ASTEquals() *sqlparser.Comparator { } return st.comparator } + +func (st *SemTable) Clone(n sqlparser.SQLNode) sqlparser.SQLNode { + return sqlparser.CopyOnRewrite(n, nil, func(cursor *sqlparser.CopyOnWriteCursor) { + expr, isExpr := cursor.Node().(sqlparser.Expr) + if !isExpr { + return + } + cursor.Replace(sqlparser.CloneExpr(expr)) + }, st.CopySemanticInfo) +}