diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index 46a98cd0bde..4c67c8c5fb2 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -289,11 +289,11 @@ func transformAggregator(ctx *plancontext.PlanningContext, op *operators.Aggrega oa.aggregates = append(oa.aggregates, aggrParam) } for _, groupBy := range op.Grouping { - typ, _ := ctx.SemTable.TypeForExpr(groupBy.SimplifiedExpr) + typ, _ := ctx.SemTable.TypeForExpr(groupBy.Inner) oa.groupByKeys = append(oa.groupByKeys, &engine.GroupByParams{ KeyCol: groupBy.ColOffset, WeightStringCol: groupBy.WSOffset, - Expr: groupBy.SimplifiedExpr, + Expr: groupBy.Inner, Type: typ, CollationEnv: ctx.VSchema.CollationEnv(), }) diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/SQL_builder.go index 1a9ef3c77c1..e5c7d6dcc7f 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/SQL_builder.go @@ -457,7 +457,7 @@ func buildAggregation(op *Aggregator, qb *queryBuilder) { for _, by := range op.Grouping { qb.addGroupBy(by.Inner) - simplified := by.SimplifiedExpr + simplified := by.Inner if by.WSOffset != -1 { qb.addGroupBy(weightStringFor(simplified)) } diff --git a/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go b/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go index 11ebebc75a1..a0963929eaa 100644 --- a/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go +++ b/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go @@ -193,7 +193,7 @@ func pushAggregations(ctx *plancontext.PlanningContext, aggregator *Aggregator, // doing the aggregating on the vtgate level instead // Adding to group by can be done only once even though there are multiple distinct aggregation with same expression. 
if !distinctAggrGroupByAdded { - groupBy := NewGroupBy(distinctExprs[0], distinctExprs[0]) + groupBy := NewGroupBy(distinctExprs[0]) groupBy.ColOffset = aggr.ColOffset aggrBelowRoute.Grouping = append(aggrBelowRoute.Grouping, groupBy) distinctAggrGroupByAdded = true @@ -260,7 +260,7 @@ func pushAggregationThroughFilter( withNextColumn: for _, col := range columnsNeeded { for _, gb := range pushedAggr.Grouping { - if ctx.SemTable.EqualsExpr(col, gb.SimplifiedExpr) { + if ctx.SemTable.EqualsExpr(col, gb.Inner) { continue withNextColumn } } @@ -300,7 +300,7 @@ func collectColNamesNeeded(ctx *plancontext.PlanningContext, f *Filter) (columns func overlappingUniqueVindex(ctx *plancontext.PlanningContext, groupByExprs []GroupBy) bool { for _, groupByExpr := range groupByExprs { - if exprHasUniqueVindex(ctx, groupByExpr.SimplifiedExpr) { + if exprHasUniqueVindex(ctx, groupByExpr.Inner) { return true } } @@ -387,7 +387,7 @@ func pushAggregationThroughApplyJoin(ctx *plancontext.PlanningContext, rootAggr // We need to add any columns coming from the lhs of the join to the group by on that side // If we don't, the LHS will not be able to return the column, and it can't be used to send down to the RHS - addColumnsFromLHSInJoinPredicates(ctx, rootAggr, join, lhs) + addColumnsFromLHSInJoinPredicates(ctx, join, lhs) join.LHS, join.RHS = lhs.pushed, rhs.pushed @@ -419,29 +419,28 @@ func pushAggregationThroughHashJoin(ctx *plancontext.PlanningContext, rootAggr * // The two sides of the hash comparisons are added as grouping expressions for _, cmp := range join.JoinComparisons { - lhs.addGrouping(ctx, NewGroupBy(cmp.LHS, cmp.LHS)) + lhs.addGrouping(ctx, NewGroupBy(cmp.LHS)) columns.addLeft(cmp.LHS) - rhs.addGrouping(ctx, NewGroupBy(cmp.RHS, cmp.RHS)) + rhs.addGrouping(ctx, NewGroupBy(cmp.RHS)) columns.addRight(cmp.RHS) } // The grouping columns need to be pushed down as grouping columns on the respective sides for _, groupBy := range rootAggr.Grouping { - expr := 
rootAggr.QP.GetSimplifiedExpr(ctx, groupBy.Inner) - deps := ctx.SemTable.RecursiveDeps(expr) + deps := ctx.SemTable.RecursiveDeps(groupBy.Inner) switch { case deps.IsSolvedBy(lhs.tableID): lhs.addGrouping(ctx, groupBy) - columns.addLeft(expr) + columns.addLeft(groupBy.Inner) case deps.IsSolvedBy(rhs.tableID): rhs.addGrouping(ctx, groupBy) - columns.addRight(expr) + columns.addRight(groupBy.Inner) case deps.IsSolvedBy(lhs.tableID.Merge(rhs.tableID)): // TODO: Support this as well return nil, nil default: - panic(vterrors.VT13001(fmt.Sprintf("grouping with bad dependencies %s", groupBy.SimplifiedExpr))) + panic(vterrors.VT13001(fmt.Sprintf("grouping with bad dependencies %s", groupBy.Inner))) } } @@ -473,18 +472,16 @@ func createJoinPusher(rootAggr *Aggregator, operator Operator) *joinPusher { } } -func addColumnsFromLHSInJoinPredicates(ctx *plancontext.PlanningContext, rootAggr *Aggregator, join *ApplyJoin, lhs *joinPusher) { +func addColumnsFromLHSInJoinPredicates(ctx *plancontext.PlanningContext, join *ApplyJoin, lhs *joinPusher) { for _, pred := range join.JoinPredicates.columns { for _, bve := range pred.LHSExprs { - expr := bve.Expr - wexpr := rootAggr.QP.GetSimplifiedExpr(ctx, expr) - idx, found := canReuseColumn(ctx, lhs.pushed.Columns, expr, extractExpr) + idx, found := canReuseColumn(ctx, lhs.pushed.Columns, bve.Expr, extractExpr) if !found { idx = len(lhs.pushed.Columns) - lhs.pushed.Columns = append(lhs.pushed.Columns, aeWrap(expr)) + lhs.pushed.Columns = append(lhs.pushed.Columns, aeWrap(bve.Expr)) } - _, found = canReuseColumn(ctx, lhs.pushed.Grouping, wexpr, func(by GroupBy) sqlparser.Expr { - return by.SimplifiedExpr + _, found = canReuseColumn(ctx, lhs.pushed.Grouping, bve.Expr, func(by GroupBy) sqlparser.Expr { + return by.Inner }) if found { @@ -492,10 +489,9 @@ func addColumnsFromLHSInJoinPredicates(ctx *plancontext.PlanningContext, rootAgg } lhs.pushed.Grouping = append(lhs.pushed.Grouping, GroupBy{ - Inner: expr, - SimplifiedExpr: wexpr, - 
ColOffset: idx, - WSOffset: -1, + Inner: bve.Expr, + ColOffset: idx, + WSOffset: -1, }) } } @@ -508,24 +504,23 @@ func splitGroupingToLeftAndRight( columns joinColumns, ) { for _, groupBy := range rootAggr.Grouping { - expr := rootAggr.QP.GetSimplifiedExpr(ctx, groupBy.Inner) - deps := ctx.SemTable.RecursiveDeps(expr) + deps := ctx.SemTable.RecursiveDeps(groupBy.Inner) switch { case deps.IsSolvedBy(lhs.tableID): lhs.addGrouping(ctx, groupBy) - columns.addLeft(expr) + columns.addLeft(groupBy.Inner) case deps.IsSolvedBy(rhs.tableID): rhs.addGrouping(ctx, groupBy) - columns.addRight(expr) + columns.addRight(groupBy.Inner) case deps.IsSolvedBy(lhs.tableID.Merge(rhs.tableID)): - jc := breakExpressionInLHSandRHSForApplyJoin(ctx, groupBy.SimplifiedExpr, lhs.tableID) + jc := breakExpressionInLHSandRHSForApplyJoin(ctx, groupBy.Inner, lhs.tableID) for _, lhsExpr := range jc.LHSExprs { e := lhsExpr.Expr - lhs.addGrouping(ctx, NewGroupBy(e, e)) + lhs.addGrouping(ctx, NewGroupBy(e)) } - rhs.addGrouping(ctx, NewGroupBy(jc.RHSExpr, jc.RHSExpr)) + rhs.addGrouping(ctx, NewGroupBy(jc.RHSExpr)) default: - panic(vterrors.VT13001(fmt.Sprintf("grouping with bad dependencies %s", groupBy.SimplifiedExpr))) + panic(vterrors.VT13001(fmt.Sprintf("grouping with bad dependencies %s", groupBy.Inner))) } } } diff --git a/go/vt/vtgate/planbuilder/operators/aggregator.go b/go/vt/vtgate/planbuilder/operators/aggregator.go index 3e577cfbb0b..f8148fb3f0e 100644 --- a/go/vt/vtgate/planbuilder/operators/aggregator.go +++ b/go/vt/vtgate/planbuilder/operators/aggregator.go @@ -86,12 +86,12 @@ func (a *Aggregator) AddPredicate(_ *plancontext.PlanningContext, expr sqlparser } } -func (a *Aggregator) addColumnWithoutPushing(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr, addToGroupBy bool) int { +func (a *Aggregator) addColumnWithoutPushing(_ *plancontext.PlanningContext, expr *sqlparser.AliasedExpr, addToGroupBy bool) int { offset := len(a.Columns) a.Columns = append(a.Columns, expr) if 
addToGroupBy { - groupBy := NewGroupBy(expr.Expr, expr.Expr) + groupBy := NewGroupBy(expr.Expr) groupBy.ColOffset = offset a.Grouping = append(a.Grouping, groupBy) } else { @@ -154,7 +154,7 @@ func (a *Aggregator) AddColumn(ctx *plancontext.PlanningContext, reuse bool, gro // This process also sets the weight string column offset, eliminating the need for a later addition in the aggregator operator's planOffset. if wsExpr, isWS := rewritten.(*sqlparser.WeightStringFuncExpr); isWS { idx := slices.IndexFunc(a.Grouping, func(by GroupBy) bool { - return ctx.SemTable.EqualsExprWithDeps(wsExpr.Expr, by.SimplifiedExpr) + return ctx.SemTable.EqualsExprWithDeps(wsExpr.Expr, by.Inner) }) if idx >= 0 { a.Grouping[idx].WSOffset = len(a.Columns) @@ -241,7 +241,7 @@ func (a *Aggregator) ShortDescription() string { var grouping []string for _, gb := range a.Grouping { - grouping = append(grouping, sqlparser.String(gb.SimplifiedExpr)) + grouping = append(grouping, sqlparser.String(gb.Inner)) } return fmt.Sprintf("%s%s group by %s", org, strings.Join(columns, ", "), strings.Join(grouping, ",")) @@ -268,11 +268,11 @@ func (a *Aggregator) planOffsets(ctx *plancontext.PlanningContext) Operator { offset := a.internalAddColumn(ctx, aeWrap(gb.Inner), false) a.Grouping[idx].ColOffset = offset } - if gb.WSOffset != -1 || !ctx.SemTable.NeedsWeightString(gb.SimplifiedExpr) { + if gb.WSOffset != -1 || !ctx.SemTable.NeedsWeightString(gb.Inner) { continue } - offset := a.internalAddColumn(ctx, aeWrap(weightStringFor(gb.SimplifiedExpr)), true) + offset := a.internalAddColumn(ctx, aeWrap(weightStringFor(gb.Inner)), true) a.Grouping[idx].WSOffset = offset } @@ -371,11 +371,11 @@ func (a *Aggregator) pushRemainingGroupingColumnsAndWeightStrings(ctx *planconte a.Grouping[idx].ColOffset = offset } - if gb.WSOffset != -1 || !ctx.SemTable.NeedsWeightString(gb.SimplifiedExpr) { + if gb.WSOffset != -1 || !ctx.SemTable.NeedsWeightString(gb.Inner) { continue } - offset := a.internalAddColumn(ctx, 
aeWrap(weightStringFor(gb.SimplifiedExpr)), false) + offset := a.internalAddColumn(ctx, aeWrap(weightStringFor(gb.Inner)), false) a.Grouping[idx].WSOffset = offset } for idx, aggr := range a.Aggregations { diff --git a/go/vt/vtgate/planbuilder/operators/distinct.go b/go/vt/vtgate/planbuilder/operators/distinct.go index 655bf2350cc..a7fa63a0c92 100644 --- a/go/vt/vtgate/planbuilder/operators/distinct.go +++ b/go/vt/vtgate/planbuilder/operators/distinct.go @@ -48,11 +48,7 @@ type ( func (d *Distinct) planOffsets(ctx *plancontext.PlanningContext) Operator { columns := d.GetColumns(ctx) for idx, col := range columns { - e, err := d.QP.TryGetSimplifiedExpr(ctx, col.Expr) - if err != nil { - // ambiguous columns are not a problem for DISTINCT - e = col.Expr - } + e := col.Expr var wsCol *int typ, _ := ctx.SemTable.TypeForExpr(e) diff --git a/go/vt/vtgate/planbuilder/operators/horizon_expanding.go b/go/vt/vtgate/planbuilder/operators/horizon_expanding.go index 300e4ef36b9..4d0a15db9b8 100644 --- a/go/vt/vtgate/planbuilder/operators/horizon_expanding.go +++ b/go/vt/vtgate/planbuilder/operators/horizon_expanding.go @@ -176,7 +176,7 @@ outer: } addedToCol := false for idx, groupBy := range a.Grouping { - if ctx.SemTable.EqualsExprWithDeps(groupBy.SimplifiedExpr, ae.Expr) { + if ctx.SemTable.EqualsExprWithDeps(groupBy.Inner, ae.Expr) { if !addedToCol { a.Columns = append(a.Columns, ae) addedToCol = true @@ -214,7 +214,7 @@ func createProjectionForComplexAggregation(a *Aggregator, qp *QueryProjection) O } for i, by := range a.Grouping { a.Grouping[i].ColOffset = len(a.Columns) - a.Columns = append(a.Columns, aeWrap(by.SimplifiedExpr)) + a.Columns = append(a.Columns, aeWrap(by.Inner)) } for i, aggregation := range a.Aggregations { a.Aggregations[i].ColOffset = len(a.Columns) diff --git a/go/vt/vtgate/planbuilder/operators/phases.go b/go/vt/vtgate/planbuilder/operators/phases.go index 8a47507a526..e9c35568400 100644 --- a/go/vt/vtgate/planbuilder/operators/phases.go +++ 
b/go/vt/vtgate/planbuilder/operators/phases.go @@ -168,7 +168,7 @@ func addOrderingFor(aggrOp *Aggregator) { func needsOrdering(ctx *plancontext.PlanningContext, in *Aggregator) bool { requiredOrder := slice.Map(in.Grouping, func(from GroupBy) sqlparser.Expr { - return from.SimplifiedExpr + return from.Inner }) if in.DistinctExpr != nil { requiredOrder = append(requiredOrder, in.DistinctExpr) @@ -209,7 +209,7 @@ func addLiteralGroupingToRHS(in *ApplyJoin) (Operator, *ApplyResult) { } if len(aggr.Grouping) == 0 { gb := sqlparser.NewIntLiteral(".0") - aggr.Grouping = append(aggr.Grouping, NewGroupBy(gb, gb)) + aggr.Grouping = append(aggr.Grouping, NewGroupBy(gb)) } return nil }) diff --git a/go/vt/vtgate/planbuilder/operators/query_planning.go b/go/vt/vtgate/planbuilder/operators/query_planning.go index 8137ec502a4..c0cdce4c4d6 100644 --- a/go/vt/vtgate/planbuilder/operators/query_planning.go +++ b/go/vt/vtgate/planbuilder/operators/query_planning.go @@ -642,7 +642,7 @@ func overlaps(ctx *plancontext.PlanningContext, order []OrderBy, grouping []Grou ordering: for _, orderBy := range order { for _, groupBy := range grouping { - if ctx.SemTable.EqualsExprWithDeps(orderBy.SimplifiedExpr, groupBy.SimplifiedExpr) { + if ctx.SemTable.EqualsExprWithDeps(orderBy.SimplifiedExpr, groupBy.Inner) { continue ordering } } @@ -674,7 +674,7 @@ func pushOrderingUnderAggr(ctx *plancontext.PlanningContext, order *Ordering, ag used := make([]bool, len(aggregator.Grouping)) for _, orderExpr := range order.Order { for grpIdx, by := range aggregator.Grouping { - if !used[grpIdx] && ctx.SemTable.EqualsExprWithDeps(by.SimplifiedExpr, orderExpr.SimplifiedExpr) { + if !used[grpIdx] && ctx.SemTable.EqualsExprWithDeps(by.Inner, orderExpr.SimplifiedExpr) { newGrouping = append(newGrouping, by) used[grpIdx] = true } diff --git a/go/vt/vtgate/planbuilder/operators/queryprojection.go b/go/vt/vtgate/planbuilder/operators/queryprojection.go index 6dbad50a50d..fdb323be005 100644 --- 
a/go/vt/vtgate/planbuilder/operators/queryprojection.go +++ b/go/vt/vtgate/planbuilder/operators/queryprojection.go @@ -47,27 +47,17 @@ type ( Distinct bool groupByExprs []GroupBy OrderExprs []OrderBy - HasStar bool // AddedColumn keeps a counter for expressions added to solve HAVING expressions the user is not selecting AddedColumn int hasCheckedAlignment bool - - // TODO Remove once all horizon planning is done on the operators - CanPushSorting bool } // GroupBy contains the expression to used in group by and also if grouping is needed at VTGate level then what the weight_string function expression to be sent down for evaluation. GroupBy struct { Inner sqlparser.Expr - // The simplified expressions is the "unaliased expression". - // In the following query, the group by has the inner expression - // `x` and the `SimplifiedExpr` is `table.col + 10`: - // select table.col + 10 as x, count(*) from tbl group by x - SimplifiedExpr sqlparser.Expr - // The index at which the user expects to see this column. Set to nil, if the user does not ask for it InnerIndex *int @@ -125,12 +115,11 @@ func (aggr Aggr) GetTypeCollation(ctx *plancontext.PlanningContext) evalengine.T } // NewGroupBy creates a new group by from the given fields. 
-func NewGroupBy(inner, simplified sqlparser.Expr) GroupBy { +func NewGroupBy(inner sqlparser.Expr) GroupBy { return GroupBy{ - Inner: inner, - SimplifiedExpr: simplified, - ColOffset: -1, - WSOffset: -1, + Inner: inner, + ColOffset: -1, + WSOffset: -1, } } @@ -151,7 +140,7 @@ func (b GroupBy) AsOrderBy() OrderBy { Expr: b.Inner, Direction: sqlparser.AscOrder, }, - SimplifiedExpr: b.SimplifiedExpr, + SimplifiedExpr: b.Inner, } } @@ -264,7 +253,6 @@ func (qp *QueryProjection) addSelectExpressions(sel *sqlparser.Select) { qp.SelectExprs = append(qp.SelectExprs, col) case *sqlparser.StarExpr: - qp.HasStar = true col := SelectExpr{ Col: selExp, } @@ -326,21 +314,19 @@ func (qp *QueryProjection) addOrderBy(ctx *plancontext.PlanningContext, orderBy canPushSorting := true es := &expressionSet{} for _, order := range orderBy { - simpleExpr := qp.GetSimplifiedExpr(ctx, order.Expr) - if sqlparser.IsNull(simpleExpr) { + if sqlparser.IsNull(order.Expr) { // ORDER BY null can safely be ignored continue } - if !es.add(ctx, simpleExpr) { + if !es.add(ctx, order.Expr) { continue } qp.OrderExprs = append(qp.OrderExprs, OrderBy{ Inner: sqlparser.CloneRefOfOrder(order), - SimplifiedExpr: simpleExpr, + SimplifiedExpr: order.Expr, }) - canPushSorting = canPushSorting && !containsAggr(simpleExpr) + canPushSorting = canPushSorting && !containsAggr(order.Expr) } - qp.CanPushSorting = canPushSorting } func (qp *QueryProjection) calculateDistinct(ctx *plancontext.PlanningContext) { @@ -365,7 +351,7 @@ func (qp *QueryProjection) calculateDistinct(ctx *plancontext.PlanningContext) { } for _, gb := range qp.groupByExprs { - _, found := canReuseColumn(ctx, qp.SelectExprs, gb.SimplifiedExpr, func(expr SelectExpr) sqlparser.Expr { + _, found := canReuseColumn(ctx, qp.SelectExprs, gb.Inner, func(expr SelectExpr) sqlparser.Expr { getExpr, err := expr.GetExpr() if err != nil { panic(err) @@ -383,16 +369,15 @@ func (qp *QueryProjection) calculateDistinct(ctx *plancontext.PlanningContext) { func (qp 
*QueryProjection) addGroupBy(ctx *plancontext.PlanningContext, groupBy sqlparser.GroupBy) { es := &expressionSet{} - for _, group := range groupBy { - selectExprIdx := qp.FindSelectExprIndexForExpr(ctx, group) - simpleExpr := qp.GetSimplifiedExpr(ctx, group) - checkForInvalidGroupingExpressions(simpleExpr) + for _, grouping := range groupBy { + selectExprIdx := qp.FindSelectExprIndexForExpr(ctx, grouping) + checkForInvalidGroupingExpressions(grouping) - if !es.add(ctx, simpleExpr) { + if !es.add(ctx, grouping) { continue } - groupBy := NewGroupBy(group, simpleExpr) + groupBy := NewGroupBy(grouping) groupBy.InnerIndex = selectExprIdx qp.groupByExprs = append(qp.groupByExprs, groupBy) @@ -406,79 +391,13 @@ func (qp *QueryProjection) GetGrouping() []GroupBy { func (qp *QueryProjection) isExprInGroupByExprs(ctx *plancontext.PlanningContext, expr sqlparser.Expr) bool { for _, groupByExpr := range qp.groupByExprs { - if ctx.SemTable.EqualsExprWithDeps(groupByExpr.SimplifiedExpr, expr) { + if ctx.SemTable.EqualsExprWithDeps(groupByExpr.Inner, expr) { return true } } return false } -// GetSimplifiedExpr takes an expression used in ORDER BY or GROUP BY, and returns an expression that is simpler to evaluate -func (qp *QueryProjection) GetSimplifiedExpr(ctx *plancontext.PlanningContext, e sqlparser.Expr) sqlparser.Expr { - expr, err := qp.TryGetSimplifiedExpr(ctx, e) - if err != nil { - panic(err) - } - return expr -} - -func (qp *QueryProjection) TryGetSimplifiedExpr(ctx *plancontext.PlanningContext, e sqlparser.Expr) (found sqlparser.Expr, err error) { - if qp == nil { - return e, nil - } - // If the ORDER BY is against a column alias, we need to remember the expression - // behind the alias. The weightstring(.) calls needs to be done against that expression and not the alias. 
- // Eg - select music.foo as bar, weightstring(music.foo) from music order by bar - - in, isColName := e.(*sqlparser.ColName) - if !(isColName && in.Qualifier.IsEmpty()) { - // we are only interested in unqualified column names. if it's not a column name and not unqualified, we're done - return e, nil - } - - check := func(e sqlparser.Expr) error { - if found != nil && !ctx.SemTable.EqualsExprWithDeps(found, e) { - return &semantics.AmbiguousColumnError{Column: sqlparser.String(in)} - } - found = e - return nil - } - - for _, selectExpr := range qp.SelectExprs { - ae, ok := selectExpr.Col.(*sqlparser.AliasedExpr) - if !ok { - continue - } - aliased := ae.As.NotEmpty() - if aliased { - if in.Name.Equal(ae.As) { - err = check(ae.Expr) - if err != nil { - return nil, err - } - } - } else { - seCol, ok := ae.Expr.(*sqlparser.ColName) - if !ok { - continue - } - if seCol.Name.Equal(in.Name) { - // If the column name matches, we have a match, even if the table name is not listed - err = check(ae.Expr) - if err != nil { - return nil, err - } - } - } - } - - if found == nil { - found = e - } - - return found, nil -} - // toString should only be used for tests func (qp *QueryProjection) toString() string { type output struct { @@ -706,7 +625,7 @@ func (qp *QueryProjection) OldAlignGroupByAndOrderBy(ctx *plancontext.PlanningCo used := make([]bool, len(qp.groupByExprs)) for _, orderExpr := range qp.OrderExprs { for i, groupingExpr := range qp.groupByExprs { - if !used[i] && ctx.SemTable.EqualsExpr(groupingExpr.SimplifiedExpr, orderExpr.SimplifiedExpr) { + if !used[i] && ctx.SemTable.EqualsExpr(groupingExpr.Inner, orderExpr.SimplifiedExpr) { newGrouping = append(newGrouping, groupingExpr) used[i] = true } @@ -746,7 +665,7 @@ func (qp *QueryProjection) AlignGroupByAndOrderBy(ctx *plancontext.PlanningConte outer: for _, orderBy := range qp.OrderExprs { for gidx, groupBy := range qp.groupByExprs { - if ctx.SemTable.EqualsExprWithDeps(groupBy.SimplifiedExpr, 
orderBy.SimplifiedExpr) { + if ctx.SemTable.EqualsExprWithDeps(groupBy.Inner, orderBy.SimplifiedExpr) { newGrouping = append(newGrouping, groupBy) used[gidx] = true continue outer @@ -796,15 +715,14 @@ func (qp *QueryProjection) useGroupingOverDistinct(ctx *plancontext.PlanningCont // not an alias Expr, cannot continue forward. return false } - sExpr := qp.GetSimplifiedExpr(ctx, ae.Expr) // check if the grouping already exists on that column. found := slices.IndexFunc(qp.groupByExprs, func(gb GroupBy) bool { - return ctx.SemTable.EqualsExprWithDeps(gb.SimplifiedExpr, sExpr) + return ctx.SemTable.EqualsExprWithDeps(gb.Inner, ae.Expr) }) if found != -1 { continue } - groupBy := NewGroupBy(ae.Expr, sExpr) + groupBy := NewGroupBy(ae.Expr) selectExprIdx := idx groupBy.InnerIndex = &selectExprIdx diff --git a/go/vt/vtgate/planbuilder/operators/queryprojection_test.go b/go/vt/vtgate/planbuilder/operators/queryprojection_test.go deleted file mode 100644 index 4495efeab3c..00000000000 --- a/go/vt/vtgate/planbuilder/operators/queryprojection_test.go +++ /dev/null @@ -1,250 +0,0 @@ -/* -Copyright 2021 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package operators - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "vitess.io/vitess/go/vt/sqlparser" - "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" - "vitess.io/vitess/go/vt/vtgate/semantics" -) - -func TestQP(t *testing.T) { - tcases := []struct { - sql string - - expErr string - expOrder []OrderBy - }{ - { - sql: "select * from user", - }, - { - sql: "select 1, count(1) from user", - }, - { - sql: "select max(id) from user", - }, - { - sql: "select 1, count(1) from user order by 1", - expOrder: []OrderBy{ - {Inner: &sqlparser.Order{Expr: sqlparser.NewIntLiteral("1")}, SimplifiedExpr: sqlparser.NewIntLiteral("1")}, - }, - }, - { - sql: "select id from user order by col, id, 1", - expOrder: []OrderBy{ - {Inner: &sqlparser.Order{Expr: sqlparser.NewColName("col")}, SimplifiedExpr: sqlparser.NewColName("col")}, - {Inner: &sqlparser.Order{Expr: sqlparser.NewColName("id")}, SimplifiedExpr: sqlparser.NewColName("id")}, - }, - }, - { - sql: "SELECT CONCAT(last_name,', ',first_name) AS full_name FROM mytable ORDER BY full_name", // alias in order not supported - expOrder: []OrderBy{ - { - Inner: &sqlparser.Order{Expr: sqlparser.NewColName("full_name")}, - SimplifiedExpr: &sqlparser.FuncExpr{ - Name: sqlparser.NewIdentifierCI("CONCAT"), - Exprs: sqlparser.SelectExprs{ - &sqlparser.AliasedExpr{Expr: sqlparser.NewColName("last_name")}, - &sqlparser.AliasedExpr{Expr: sqlparser.NewStrLiteral(", ")}, - &sqlparser.AliasedExpr{Expr: sqlparser.NewColName("first_name")}, - }, - }, - }, - }, - }, { - sql: "select count(*) b from user group by b", - expErr: "cannot group on 'count(*)'", - }, - } - ctx := &plancontext.PlanningContext{SemTable: semantics.EmptySemTable()} - for _, tcase := range tcases { - t.Run(tcase.sql, func(t *testing.T) { - stmt, err := sqlparser.NewTestParser().Parse(tcase.sql) - require.NoError(t, err) - - sel := stmt.(*sqlparser.Select) - _, err = semantics.Analyze(sel, "", 
&semantics.FakeSI{}) - require.NoError(t, err) - - qp, err := getQPAndError(ctx, sel) - if tcase.expErr != "" { - require.Error(t, err) - require.Contains(t, err.Error(), tcase.expErr) - } else { - require.NoError(t, err) - assert.Equal(t, len(sel.SelectExprs), len(qp.SelectExprs)) - require.Equal(t, len(tcase.expOrder), len(qp.OrderExprs), "not enough order expressions in QP") - for index, expOrder := range tcase.expOrder { - assert.True(t, sqlparser.Equals.SQLNode(expOrder.Inner, qp.OrderExprs[index].Inner), "want: %+v, got %+v", sqlparser.String(expOrder.Inner), sqlparser.String(qp.OrderExprs[index].Inner)) - assert.True(t, sqlparser.Equals.SQLNode(expOrder.SimplifiedExpr, qp.OrderExprs[index].SimplifiedExpr), "want: %v, got %v", sqlparser.String(expOrder.SimplifiedExpr), sqlparser.String(qp.OrderExprs[index].SimplifiedExpr)) - } - } - }) - } -} - -func getQPAndError(ctx *plancontext.PlanningContext, sel *sqlparser.Select) (qp *QueryProjection, err error) { - defer PanicHandler(&err) - qp = createQPFromSelect(ctx, sel) - return -} - -func TestQPSimplifiedExpr(t *testing.T) { - testCases := []struct { - query, expected string - }{ - { - query: "select intcol, count(*) from user group by 1", - expected: ` -{ - "Select": [ - "intcol", - "aggr: count(*)" - ], - "Grouping": [ - "intcol" - ], - "OrderBy": [], - "Distinct": false -}`, - }, - { - query: "select intcol, textcol from user order by 1, textcol", - expected: ` -{ - "Select": [ - "intcol", - "textcol" - ], - "Grouping": [], - "OrderBy": [ - "intcol asc", - "textcol asc" - ], - "Distinct": false -}`, - }, - { - query: "select intcol, textcol, count(id) from user group by intcol, textcol, extracol order by 2 desc", - expected: ` -{ - "Select": [ - "intcol", - "textcol", - "aggr: count(id)" - ], - "Grouping": [ - "intcol", - "textcol", - "extracol" - ], - "OrderBy": [ - "textcol desc" - ], - "Distinct": false -}`, - }, - { - query: "select distinct col1, col2 from user group by col1, col2", - expected: ` -{ - 
"Select": [ - "col1", - "col2" - ], - "Grouping": [], - "OrderBy": [], - "Distinct": true -}`, - }, - { - query: "select distinct count(*) from user", - expected: ` -{ - "Select": [ - "aggr: count(*)" - ], - "Grouping": [], - "OrderBy": [], - "Distinct": false -}`, - }, - } - - for _, tc := range testCases { - t.Run(tc.query, func(t *testing.T) { - ast, err := sqlparser.NewTestParser().Parse(tc.query) - require.NoError(t, err) - sel := ast.(*sqlparser.Select) - _, err = semantics.Analyze(sel, "", &semantics.FakeSI{}) - require.NoError(t, err) - ctx := &plancontext.PlanningContext{SemTable: semantics.EmptySemTable()} - qp := createQPFromSelect(ctx, sel) - require.NoError(t, err) - require.Equal(t, tc.expected[1:], qp.toString()) - }) - } -} - -func TestCompareRefInt(t *testing.T) { - one := 1 - two := 2 - tests := []struct { - name string - a *int - b *int - want bool - }{ - { - name: "1<2", - a: &one, - b: &two, - want: true, - }, { - name: "2<1", - a: &two, - b: &one, - want: false, - }, { - name: "2= len(it.node) { + return nil + } + + return it.node[it.idx].Expr +} + +func (it *orderByIterator) replace(e sqlparser.Expr) error { + if it.idx >= len(it.node) { + return vterrors.VT13001("went past the last item") + } + it.node[it.idx].Expr = e + return nil +} + +type exprIterator struct { + node []sqlparser.Expr + idx int +} + +func (it *exprIterator) next() sqlparser.Expr { + it.idx++ + + if it.idx >= len(it.node) { + return nil + } + + return it.node[it.idx] +} + +func (it *exprIterator) replace(e sqlparser.Expr) error { + if it.idx >= len(it.node) { + return vterrors.VT13001("went past the last item") + } + it.node[it.idx] = e + return nil +} + +type iterator interface { + next() sqlparser.Expr + replace(e sqlparser.Expr) error +} + +func (r *earlyRewriter) replaceLiteralsInOrderByGroupBy(e sqlparser.Expr, iter iterator) (bool, error) { + lit := getIntLiteral(e) + if lit == nil { + return false, nil + } + + newExpr, err := r.rewriteOrderByExpr(lit) + if err != 
nil { + return false, err + } + + if getIntLiteral(newExpr) == nil { + coll, ok := e.(*sqlparser.CollateExpr) + if ok { + coll.Expr = newExpr + newExpr = coll + } + } else { + // the expression is still a literal int. that means that we don't really need to sort by it. + // we'll just replace the number with a string instead, just like mysql would do in this situation + // mysql> explain select 1 as foo from user group by 1; + // + // mysql> show warnings; + // +-------+------+-----------------------------------------------------------------+ + // | Level | Code | Message | + // +-------+------+-----------------------------------------------------------------+ + // | Note | 1003 | /* select#1 */ select 1 AS `foo` from `test`.`user` group by '' | + // +-------+------+-----------------------------------------------------------------+ + newExpr = sqlparser.NewStrLiteral("") + } + + err = iter.replace(newExpr) + return true, err +} + +func getIntLiteral(e sqlparser.Expr) *sqlparser.Literal { + var lit *sqlparser.Literal + switch node := e.(type) { + case *sqlparser.Literal: + lit = node + case *sqlparser.CollateExpr: + expr, ok := node.Expr.(*sqlparser.Literal) + if !ok { + return nil + } + lit = expr + default: + return nil + } + if lit.Type != sqlparser.IntVal { + return nil + } + return lit +} + // handleOrderBy processes the ORDER BY clause. 
-func handleOrderBy(r *earlyRewriter, cursor *sqlparser.Cursor, node sqlparser.OrderBy) { - r.clause = "order clause" - rewriteHavingAndOrderBy(node, cursor.Parent()) +func (r *earlyRewriter) handleOrderByAndGroupBy(parent sqlparser.SQLNode, iter iterator) error { + stmt, ok := parent.(sqlparser.SelectStatement) + if !ok { + return nil + } + + sel := sqlparser.GetFirstSelect(stmt) + for e := iter.next(); e != nil; e = iter.next() { + lit, err := r.replaceLiteralsInOrderByGroupBy(e, iter) + if err != nil { + return err + } + if lit { + continue + } + expr, err := r.rewriteAliasesInOrderByHavingAndGroupBy(e, sel) + if err != nil { + return err + } + err = iter.replace(expr) + if err != nil { + return err + } + } + + return nil +} + +// rewriteHavingAndOrderBy rewrites columns in the ORDER BY and HAVING clauses to use aliases +// from the SELECT expressions when applicable, following MySQL scoping rules: +// - A column identifier without a table qualifier that matches an alias introduced +// in SELECT points to that expression, not any table column. +// - However, if the aliased expression is an aggregation and the column identifier in +// the HAVING/ORDER BY clause is inside an aggregation function, the rule does not apply. 
+func (r *earlyRewriter) rewriteAliasesInOrderByHavingAndGroupBy(node sqlparser.Expr, sel *sqlparser.Select) (expr sqlparser.Expr, err error) { + type ExprContainer struct { + expr sqlparser.Expr + ambiguous bool + } + + aliases := map[string]ExprContainer{} + for _, e := range sel.SelectExprs { + ae, ok := e.(*sqlparser.AliasedExpr) + if !ok { + continue + } + + var alias string + + item := ExprContainer{expr: ae.Expr} + if ae.As.NotEmpty() { + alias = ae.As.Lowered() + } else if col, ok := ae.Expr.(*sqlparser.ColName); ok { + alias = col.Name.Lowered() + } + + if old, alreadyExists := aliases[alias]; alreadyExists && (old.ambiguous || !sqlparser.Equals.Expr(old.expr, item.expr)) { + item.ambiguous = true // sticky: a later duplicate of one candidate must not clear an earlier conflict + } + + aliases[alias] = item + } + + insideAggr := false + downF := func(node, _ sqlparser.SQLNode) bool { + switch node.(type) { + case *sqlparser.Subquery: + return false // do not descend into subqueries + case sqlparser.AggrFunc: + insideAggr = true + } + + return true + } + + output := sqlparser.CopyOnRewrite(node, downF, func(cursor *sqlparser.CopyOnWriteCursor) { + switch col := cursor.Node().(type) { + case sqlparser.AggrFunc: + insideAggr = false + case *sqlparser.ColName: + if !col.Qualifier.IsEmpty() { + // we are only interested in columns not qualified by table names + break + } + + item, found := aliases[col.Name.Lowered()] + if !found { + break + } + + if item.ambiguous { + err = &AmbiguousColumnError{Column: sqlparser.String(col)} + cursor.StopTreeWalk() + return + } + + if insideAggr && sqlparser.ContainsAggregation(item.expr) { + // I'm not sure about this, but my experiments point to this being the behaviour mysql has + // mysql> select min(name) as name from user order by min(name); + // 1 row in set (0.00 sec) + // + // mysql> select id % 2, min(name) as name from user group by id % 2 order by min(name); + // 2 rows in set (0.00 sec) + // + // mysql> select id % 2, 'foobar' as name from user group by id % 2 order by min(name); + // 2 rows in set (0.00 sec) + // + // mysql> select id % 2 from 
user group by id % 2 order by min(min(name)); + // ERROR 1111 (HY000): Invalid use of group function + // + // mysql> select id % 2, min(name) as k from user group by id % 2 order by min(k); + // ERROR 1111 (HY000): Invalid use of group function + // + // mysql> select id % 2, -id as name from user group by id % 2, -id order by min(name); + // 6 rows in set (0.01 sec) + break + } + + cursor.Replace(sqlparser.CloneExpr(item.expr)) + } + }, nil) + + expr = output.(sqlparser.Expr) + return +} + +func (r *earlyRewriter) rewriteOrderByExpr(node *sqlparser.Literal) (sqlparser.Expr, error) { + scope, found := r.scoper.specialExprScopes[node] + if !found { + return node, nil + } + num, err := strconv.Atoi(node.Val) + if err != nil { + return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "error parsing column number: %s", node.Val) + } + + stmt, isSel := scope.stmt.(*sqlparser.Select) + if !isSel { + return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "error invalid statement type, expect Select, got: %T", scope.stmt) + } + + if num < 1 || num > len(stmt.SelectExprs) { + return nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.BadFieldError, "Unknown column '%d' in '%s'", num, r.clause) + } + + // We loop like this instead of directly accessing the offset, to make sure there are no unexpanded `*` before + for i := 0; i < num; i++ { + if _, ok := stmt.SelectExprs[i].(*sqlparser.AliasedExpr); !ok { + return nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "cannot use column offsets in %s when using `%s`", r.clause, sqlparser.String(stmt.SelectExprs[i])) + } + } + + aliasedExpr, ok := stmt.SelectExprs[num-1].(*sqlparser.AliasedExpr) + if !ok { + return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "don't know how to handle %s", sqlparser.String(node)) + } + + if scope.isUnion { + col, isCol := aliasedExpr.Expr.(*sqlparser.ColName) + + if aliasedExpr.As.IsEmpty() && isCol { + return sqlparser.NewColName(col.Name.String()), nil + } + + return 
sqlparser.NewColName(aliasedExpr.ColumnName()), nil + } + + return realCloneOfColNames(aliasedExpr.Expr, false), nil } // rewriteOrExpr rewrites OR expressions when the right side is FALSE. @@ -244,34 +532,6 @@ func rewriteAndTrue(andExpr sqlparser.AndExpr, collationEnv *collations.Environm return nil } -// handleLiteral processes literals within the context of ORDER BY expressions. -func handleLiteral(r *earlyRewriter, cursor *sqlparser.Cursor, node *sqlparser.Literal) error { - newNode, err := r.rewriteOrderByExpr(node) - if err != nil { - return err - } - if newNode != nil { - cursor.Replace(newNode) - } - return nil -} - -// handleCollateExpr processes COLLATE expressions. -func handleCollateExpr(r *earlyRewriter, node *sqlparser.CollateExpr) error { - lit, ok := node.Expr.(*sqlparser.Literal) - if !ok { - return nil - } - newNode, err := r.rewriteOrderByExpr(lit) - if err != nil { - return err - } - if newNode != nil { - node.Expr = newNode - } - return nil -} - // handleComparisonExpr processes Comparison expressions, specifically for tuples with equal length and EqualOp operator. func handleComparisonExpr(cursor *sqlparser.Cursor, node *sqlparser.ComparisonExpr) error { lft, lftOK := node.Left.(sqlparser.ValTuple) @@ -320,110 +580,6 @@ func (r *earlyRewriter) expandStar(cursor *sqlparser.Cursor, node sqlparser.Sele return nil } -// rewriteHavingAndOrderBy rewrites columns in the ORDER BY and HAVING clauses to use aliases -// from the SELECT expressions when applicable, following MySQL scoping rules: -// - A column identifier without a table qualifier that matches an alias introduced -// in SELECT points to that expression, not any table column. -// - However, if the aliased expression is an aggregation and the column identifier in -// the HAVING/ORDER BY clause is inside an aggregation function, the rule does not apply. 
-func rewriteHavingAndOrderBy(node, parent sqlparser.SQLNode) { - sel, isSel := parent.(*sqlparser.Select) - if !isSel { - return - } - - sqlparser.SafeRewrite(node, avoidSubqueries, - func(cursor *sqlparser.Cursor) bool { - col, ok := cursor.Node().(*sqlparser.ColName) - if !ok || !col.Qualifier.IsEmpty() { - // we are only interested in columns not qualified by table names - return true - } - - _, parentIsAggr := cursor.Parent().(sqlparser.AggrFunc) - - // Iterate through SELECT expressions. - for _, e := range sel.SelectExprs { - ae, ok := e.(*sqlparser.AliasedExpr) - if !ok || !ae.As.Equal(col.Name) { - // we are searching for aliased expressions that match the column we have found - continue - } - - expr := ae.Expr - if parentIsAggr { - if _, aliasPointsToAggr := expr.(sqlparser.AggrFunc); aliasPointsToAggr { - return false - } - } - - if isSafeToRewrite(expr) { - cursor.Replace(expr) - } - } - return true - }) -} - -func avoidSubqueries(node, _ sqlparser.SQLNode) bool { - _, isSubQ := node.(*sqlparser.Subquery) - return !isSubQ -} - -func isSafeToRewrite(e sqlparser.Expr) bool { - safeToRewrite := true - _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { - switch node.(type) { - case *sqlparser.ColName: - safeToRewrite = false - return false, nil - case sqlparser.AggrFunc: - return false, nil - } - return true, nil - }, e) - return safeToRewrite -} - -func (r *earlyRewriter) rewriteOrderByExpr(node *sqlparser.Literal) (sqlparser.Expr, error) { - currScope, found := r.scoper.specialExprScopes[node] - if !found { - return nil, nil - } - num, err := strconv.Atoi(node.Val) - if err != nil { - return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "error parsing column number: %s", node.Val) - } - stmt, isSel := currScope.stmt.(*sqlparser.Select) - if !isSel { - return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "error invalid statement type, expect Select, got: %T", currScope.stmt) - } - - if num < 1 || num > len(stmt.SelectExprs) { 
- return nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.BadFieldError, "Unknown column '%d' in '%s'", num, r.clause) - } - - for i := 0; i < num; i++ { - expr := stmt.SelectExprs[i] - _, ok := expr.(*sqlparser.AliasedExpr) - if !ok { - return nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "cannot use column offsets in %s when using `%s`", r.clause, sqlparser.String(expr)) - } - } - - aliasedExpr, ok := stmt.SelectExprs[num-1].(*sqlparser.AliasedExpr) - if !ok { - return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "don't know how to handle %s", sqlparser.String(node)) - } - - if aliasedExpr.As.NotEmpty() { - return sqlparser.NewColName(aliasedExpr.As.String()), nil - } - - expr := realCloneOfColNames(aliasedExpr.Expr, currScope.isUnion) - return expr, nil -} - // realCloneOfColNames clones all the expressions including ColName. // Since sqlparser.CloneRefOfColName does not clone col names, this method is needed. func realCloneOfColNames(expr sqlparser.Expr, union bool) sqlparser.Expr { diff --git a/go/vt/vtgate/semantics/early_rewriter_test.go b/go/vt/vtgate/semantics/early_rewriter_test.go index 476f993f3d7..3b7b30d5f39 100644 --- a/go/vt/vtgate/semantics/early_rewriter_test.go +++ b/go/vt/vtgate/semantics/early_rewriter_test.go @@ -170,7 +170,7 @@ func TestExpandStar(t *testing.T) { expanded: "main.t1.a, main.t1.b, main.t1.c, main.t5.a", }, { sql: "select * from t1 join t5 using (b) having b = 12", - expSQL: "select t1.b as b, t1.a as a, t1.c as c, t5.a as a from t1 join t5 on t1.b = t5.b having b = 12", + expSQL: "select t1.b as b, t1.a as a, t1.c as c, t5.a as a from t1 join t5 on t1.b = t5.b having t1.b = 12", }, { sql: "select 1 from t1 join t5 using (b) having b = 12", expSQL: "select 1 from t1 join t5 on t1.b = t5.b having t1.b = 12", @@ -315,22 +315,31 @@ func TestOrderByGroupByLiteral(t *testing.T) { expErr string }{{ sql: "select 1 as id from t1 order by 1", - expSQL: "select 1 as id from t1 order by id asc", + expSQL: "select 1 as id 
from t1 order by '' asc", }, { sql: "select t1.col from t1 order by 1", expSQL: "select t1.col from t1 order by t1.col asc", + }, { + sql: "select t1.col from t1 order by 1.0", + expSQL: "select t1.col from t1 order by 1.0 asc", + }, { + sql: "select t1.col from t1 order by 'fubick'", + expSQL: "select t1.col from t1 order by 'fubick' asc", + }, { + sql: "select t1.col as foo from t1 order by 1", + expSQL: "select t1.col as foo from t1 order by t1.col asc", }, { sql: "select t1.col from t1 group by 1", expSQL: "select t1.col from t1 group by t1.col", }, { sql: "select t1.col as xyz from t1 group by 1", - expSQL: "select t1.col as xyz from t1 group by xyz", + expSQL: "select t1.col as xyz from t1 group by t1.col", }, { sql: "select t1.col as xyz, count(*) from t1 group by 1 order by 2", - expSQL: "select t1.col as xyz, count(*) from t1 group by xyz order by count(*) asc", + expSQL: "select t1.col as xyz, count(*) from t1 group by t1.col order by count(*) asc", }, { sql: "select id from t1 group by 2", - expErr: "Unknown column '2' in 'group statement'", + expErr: "Unknown column '2' in 'group clause'", }, { sql: "select id from t1 order by 2", expErr: "Unknown column '2' in 'order clause'", @@ -339,16 +348,22 @@ func TestOrderByGroupByLiteral(t *testing.T) { expErr: "cannot use column offsets in order clause when using `*`", }, { sql: "select *, id from t1 group by 2", - expErr: "cannot use column offsets in group statement when using `*`", + expErr: "cannot use column offsets in group clause when using `*`", }, { sql: "select id from t1 order by 1 collate utf8_general_ci", expSQL: "select id from t1 order by id collate utf8_general_ci asc", + }, { + sql: "select a.id from `user` union select 1 from dual order by 1", + expSQL: "select a.id from `user` union select 1 from dual order by id asc", + }, { + sql: "select a.id, b.id from user as a, user_extra as b union select 1, 2 order by 1", + expErr: "Column 'id' in field list is ambiguous", }} for _, tcase := range 
tcases { t.Run(tcase.sql, func(t *testing.T) { ast, err := sqlparser.NewTestParser().Parse(tcase.sql) require.NoError(t, err) - selectStatement := ast.(*sqlparser.Select) + selectStatement := ast.(sqlparser.SelectStatement) _, err = Analyze(selectStatement, cDB, schemaInfo) if tcase.expErr == "" { require.NoError(t, err) @@ -378,12 +393,28 @@ func TestHavingAndOrderByColumnName(t *testing.T) { }, { sql: "select id, sum(foo) as foo from t1 having sum(foo) > 1", expSQL: "select id, sum(foo) as foo from t1 having sum(foo) > 1", + }, { + sql: "select id, lower(min(foo)) as foo from t1 order by min(foo)", + expSQL: "select id, lower(min(foo)) as foo from t1 order by min(foo) asc", + }, { + // invalid according to group by rules, but still accepted by mysql + sql: "select id, t1.bar as foo from t1 group by id order by min(foo)", + expSQL: "select id, t1.bar as foo from t1 group by id order by min(t1.bar) asc", + }, { + sql: "select foo + 2 as foo from t1 having foo = 42", + expSQL: "select foo + 2 as foo from t1 having foo + 2 = 42", + }, { + sql: "select id, b as id, count(*) from t1 order by id", + expErr: "Column 'id' in field list is ambiguous", + }, { + sql: "select id, id, count(*) from t1 order by id", + expSQL: "select id, id, count(*) from t1 order by id asc", }} for _, tcase := range tcases { t.Run(tcase.sql, func(t *testing.T) { ast, err := sqlparser.NewTestParser().Parse(tcase.sql) require.NoError(t, err) - selectStatement := ast.(*sqlparser.Select) + selectStatement := ast.(sqlparser.SelectStatement) _, err = Analyze(selectStatement, cDB, schemaInfo) if tcase.expErr == "" { require.NoError(t, err)