From 773f6a0ddcdc72a0bf5f8dc7109bdcfc9df54a22 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Mon, 15 Apr 2024 15:51:42 +0200 Subject: [PATCH] feat: allow the planner to understand aggregate UDFs Signed-off-by: Andres Taylor --- go/mysql/sqlerror/sql_error.go | 1 + go/vt/sqlparser/ast_funcs.go | 10 ++++ go/vt/vterrors/code.go | 3 +- go/vt/vterrors/state.go | 1 + go/vt/vtgate/engine/opcode/constants.go | 14 +----- go/vt/vtgate/planbuilder/delete.go | 2 +- go/vt/vtgate/planbuilder/insert.go | 2 +- .../planbuilder/operators/SQL_builder.go | 2 +- .../planbuilder/operators/aggregator.go | 8 +++- .../planbuilder/operators/expressions.go | 2 +- .../vtgate/planbuilder/operators/hash_join.go | 4 +- go/vt/vtgate/planbuilder/operators/horizon.go | 2 +- .../planbuilder/operators/offset_planning.go | 38 +++++++++++---- .../planbuilder/operators/query_planning.go | 4 +- .../planbuilder/operators/queryprojection.go | 47 ++++++++++++++----- .../operators/subquery_planning.go | 4 +- go/vt/vtgate/planbuilder/planner_test.go | 8 +++- go/vt/vtgate/planbuilder/rewrite.go | 23 +++++---- go/vt/vtgate/planbuilder/rewrite_test.go | 8 +++- go/vt/vtgate/planbuilder/select.go | 2 +- .../planbuilder/testdata/aggr_cases.json | 5 ++ go/vt/vtgate/planbuilder/update.go | 2 +- go/vt/vtgate/planbuilder/vexplain.go | 2 +- go/vt/vtgate/semantics/analyzer.go | 2 +- go/vt/vtgate/semantics/early_rewriter.go | 18 +++---- go/vt/vtgate/semantics/scoper.go | 11 ++++- go/vt/wrangler/vexec_plan.go | 8 +--- 27 files changed, 149 insertions(+), 84 deletions(-) diff --git a/go/mysql/sqlerror/sql_error.go b/go/mysql/sqlerror/sql_error.go index bebd9e41ca7..69cb397a015 100644 --- a/go/mysql/sqlerror/sql_error.go +++ b/go/mysql/sqlerror/sql_error.go @@ -244,6 +244,7 @@ var stateToMysqlCode = map[vterrors.State]mysqlCode{ vterrors.KillDeniedError: {num: ERKillDenied, state: SSUnknownSQLState}, vterrors.BadNullError: {num: ERBadNullError, state: SSConstraintViolation}, vterrors.InvalidGroupFuncUse: {num: ERInvalidGroupFuncUse, state: SSUnknownSQLState}, + vterrors.AggregateMustPushDown: {num: ERNotSupportedYet, state: SSUnknownSQLState}, } func getStateToMySQLState(state vterrors.State) mysqlCode { diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index 77cbed714b0..8019645c250 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -923,6 +923,16 @@ func (node IdentifierCI) EqualString(str string) bool { return node.Lowered() == strings.ToLower(str) } +// EqualsAnyString returns true if any of these strings match +func (node IdentifierCI) EqualsAnyString(str []string) bool { + for _, s := range str { + if node.EqualString(s) { + return true + } + } + return false +} + // MarshalJSON marshals into JSON. func (node IdentifierCI) MarshalJSON() ([]byte, error) { return json.Marshal(node.val) diff --git a/go/vt/vterrors/code.go b/go/vt/vterrors/code.go index bc4bd9bbe35..a0d2b301aec 100644 --- a/go/vt/vterrors/code.go +++ b/go/vt/vterrors/code.go @@ -57,7 +57,8 @@ var ( VT03029 = errorWithState("VT03029", vtrpcpb.Code_INVALID_ARGUMENT, WrongValueCountOnRow, "column count does not match value count with the row for vindex '%s'", "The number of columns you want to insert do not match the number of columns of your SELECT query.") VT03030 = errorWithState("VT03030", vtrpcpb.Code_INVALID_ARGUMENT, WrongValueCountOnRow, "lookup column count does not match value count with the row (columns, count): (%v, %d)", "The number of columns you want to insert do not match the number of columns of your SELECT query.") VT03031 = errorWithoutState("VT03031", vtrpcpb.Code_INVALID_ARGUMENT, "EXPLAIN is only supported for single keyspace", "EXPLAIN has to be sent down as a single query to the underlying MySQL, and this is not possible if it uses tables from multiple keyspaces") - VT03032 = errorWithState("VT03031", vtrpcpb.Code_INVALID_ARGUMENT, NonUpdateableTable, "the target table %s of the UPDATE is not updatable", "You cannot update a table that is not a real MySQL table.") + VT03032 = errorWithState("VT03032", vtrpcpb.Code_INVALID_ARGUMENT, NonUpdateableTable, "the target table %s of the UPDATE is not updatable", "You cannot update a table that is not a real MySQL table.") + VT03033 = errorWithState("VT03033", vtrpcpb.Code_INVALID_ARGUMENT, AggregateMustPushDown, "aggregate user-defined function %s must be pushed down to mysql", "The aggregate user-defined function must be pushed down to mysql and can't be evaluated on the vtgate. The query contains aggregation that can't be fully pushed down to MySQL.") VT05001 = errorWithState("VT05001", vtrpcpb.Code_NOT_FOUND, DbDropExists, "cannot drop database '%s'; database does not exists", "The given database does not exist; Vitess cannot drop it.") VT05002 = errorWithState("VT05002", vtrpcpb.Code_NOT_FOUND, BadDb, "cannot alter database '%s'; unknown database", "The given database does not exist; Vitess cannot alter it.") diff --git a/go/vt/vterrors/state.go b/go/vt/vterrors/state.go index 2b0ada0bc6d..78724cda892 100644 --- a/go/vt/vterrors/state.go +++ b/go/vt/vterrors/state.go @@ -49,6 +49,7 @@ const ( WrongArguments BadNullError InvalidGroupFuncUse + AggregateMustPushDown // failed precondition NoDB diff --git a/go/vt/vtgate/engine/opcode/constants.go b/go/vt/vtgate/engine/opcode/constants.go index 2fa0e9446a4..189964536fb 100644 --- a/go/vt/vtgate/engine/opcode/constants.go +++ b/go/vt/vtgate/engine/opcode/constants.go @@ -77,22 +77,10 @@ const ( AggregateCountStar AggregateGroupConcat AggregateAvg + AggregateUDF // This is an opcode used to represent UDFs _NumOfOpCodes // This line must be last of the opcodes! ) -var ( - // OpcodeType keeps track of the known output types for different aggregate functions - OpcodeType = map[AggregateOpcode]querypb.Type{ - AggregateCountDistinct: sqltypes.Int64, - AggregateCount: sqltypes.Int64, - AggregateCountStar: sqltypes.Int64, - AggregateSumDistinct: sqltypes.Decimal, - AggregateSum: sqltypes.Decimal, - AggregateAvg: sqltypes.Decimal, - AggregateGtid: sqltypes.VarChar, - } -) - // SupportedAggregates maps the list of supported aggregate // functions to their opcodes. var SupportedAggregates = map[string]AggregateOpcode{ diff --git a/go/vt/vtgate/planbuilder/delete.go b/go/vt/vtgate/planbuilder/delete.go index 6d56a41c6df..980af21df61 100644 --- a/go/vt/vtgate/planbuilder/delete.go +++ b/go/vt/vtgate/planbuilder/delete.go @@ -49,7 +49,7 @@ func gen4DeleteStmtPlanner( return nil, err } - err = queryRewrite(ctx.SemTable, reservedVars, deleteStmt) + err = queryRewrite(ctx, deleteStmt) if err != nil { return nil, err } diff --git a/go/vt/vtgate/planbuilder/insert.go b/go/vt/vtgate/planbuilder/insert.go index b08330f060d..e674850c753 100644 --- a/go/vt/vtgate/planbuilder/insert.go +++ b/go/vt/vtgate/planbuilder/insert.go @@ -33,7 +33,7 @@ func gen4InsertStmtPlanner(version querypb.ExecuteOptions_PlannerVersion, insStm return nil, err } - err = queryRewrite(ctx.SemTable, reservedVars, insStmt) + err = queryRewrite(ctx, insStmt) if err != nil { return nil, err } diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/SQL_builder.go index 062f5b7303d..6604c7587c3 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/SQL_builder.go @@ -96,7 +96,7 @@ func (qb *queryBuilder) addPredicate(expr sqlparser.Expr) { switch stmt := qb.stmt.(type) { case *sqlparser.Select: - if containsAggr(expr) { + if ContainsAggr(qb.ctx, expr) { addPred = stmt.AddHaving } else { addPred = stmt.AddWhere diff --git a/go/vt/vtgate/planbuilder/operators/aggregator.go b/go/vt/vtgate/planbuilder/operators/aggregator.go index 256372c172f..5c9ba167171 100644 --- a/go/vt/vtgate/planbuilder/operators/aggregator.go +++ b/go/vt/vtgate/planbuilder/operators/aggregator.go @@ -83,7 +83,7 @@ func (a *Aggregator) AddPredicate(_ *plancontext.PlanningContext, expr sqlparser return newFilter(a, expr) } -func (a *Aggregator) addColumnWithoutPushing(_ *plancontext.PlanningContext, expr *sqlparser.AliasedExpr, addToGroupBy bool) int { +func (a *Aggregator) addColumnWithoutPushing(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr, addToGroupBy bool) int { offset := len(a.Columns) a.Columns = append(a.Columns, expr) @@ -96,6 +96,12 @@ func (a *Aggregator) addColumnWithoutPushing(_ *plancontext.PlanningContext, exp switch e := expr.Expr.(type) { case sqlparser.AggrFunc: aggr = createAggrFromAggrFunc(e, expr) + case *sqlparser.FuncExpr: + if IsAggr(ctx, e) { + aggr = NewAggr(opcode.AggregateUDF, nil, expr, expr.As.String()) + } else { + aggr = NewAggr(opcode.AggregateAnyValue, nil, expr, expr.As.String()) + } default: aggr = NewAggr(opcode.AggregateAnyValue, nil, expr, expr.As.String()) } diff --git a/go/vt/vtgate/planbuilder/operators/expressions.go b/go/vt/vtgate/planbuilder/operators/expressions.go index 612c1e7ec08..521024ab7c9 100644 --- a/go/vt/vtgate/planbuilder/operators/expressions.go +++ b/go/vt/vtgate/planbuilder/operators/expressions.go @@ -31,7 +31,7 @@ func breakExpressionInLHSandRHSForApplyJoin( ) (col applyJoinColumn) { rewrittenExpr := sqlparser.CopyOnRewrite(expr, nil, func(cursor *sqlparser.CopyOnWriteCursor) { nodeExpr, ok := cursor.Node().(sqlparser.Expr) - if !ok || !mustFetchFromInput(nodeExpr) { + if !ok || !mustFetchFromInput(ctx, nodeExpr) { return } deps := ctx.SemTable.RecursiveDeps(nodeExpr) diff --git a/go/vt/vtgate/planbuilder/operators/hash_join.go b/go/vt/vtgate/planbuilder/operators/hash_join.go index 0ad46bcbc82..da6a63200db 100644 --- a/go/vt/vtgate/planbuilder/operators/hash_join.go +++ b/go/vt/vtgate/planbuilder/operators/hash_join.go @@ -292,7 +292,7 @@ func (hj *HashJoin) addColumn(ctx *plancontext.PlanningContext, in sqlparser.Exp } inOffset := op.FindCol(ctx, expr, false) if inOffset == -1 { - if !mustFetchFromInput(expr) { + if !mustFetchFromInput(ctx, expr) { return -1 } @@ -398,7 +398,7 @@ func (hj *HashJoin) addSingleSidedColumn( } inOffset := op.FindCol(ctx, expr, false) if inOffset == -1 { - if !mustFetchFromInput(expr) { + if !mustFetchFromInput(ctx, expr) { return -1 } diff --git a/go/vt/vtgate/planbuilder/operators/horizon.go b/go/vt/vtgate/planbuilder/operators/horizon.go index 34f6dc79217..4a4c990b1ed 100644 --- a/go/vt/vtgate/planbuilder/operators/horizon.go +++ b/go/vt/vtgate/planbuilder/operators/horizon.go @@ -100,7 +100,7 @@ func (h *Horizon) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser. } newExpr := semantics.RewriteDerivedTableExpression(expr, tableInfo) - if sqlparser.ContainsAggregation(newExpr) { + if ContainsAggr(ctx, newExpr) { return newFilter(h, expr) } h.Source = h.Source.AddPredicate(ctx, newExpr) diff --git a/go/vt/vtgate/planbuilder/operators/offset_planning.go b/go/vt/vtgate/planbuilder/operators/offset_planning.go index 638d3d80907..714700642c0 100644 --- a/go/vt/vtgate/planbuilder/operators/offset_planning.go +++ b/go/vt/vtgate/planbuilder/operators/offset_planning.go @@ -19,6 +19,8 @@ package operators import ( "fmt" + "vitess.io/vitess/go/vt/vtgate/engine/opcode" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" @@ -56,10 +58,12 @@ func planOffsets(ctx *plancontext.PlanningContext, root Operator) Operator { } // mustFetchFromInput returns true for expressions that have to be fetched from the input and cannot be evaluated -func mustFetchFromInput(e sqlparser.SQLNode) bool { - switch e.(type) { +func mustFetchFromInput(ctx *plancontext.PlanningContext, e sqlparser.SQLNode) bool { + switch fun := e.(type) { case *sqlparser.ColName, sqlparser.AggrFunc: return true + case *sqlparser.FuncExpr: + return fun.Name.EqualsAnyString(ctx.VSchema.GetAggregateUDFs()) default: return false } @@ -93,10 +97,10 @@ func useOffsets(ctx *plancontext.PlanningContext, expr sqlparser.Expr, op Operat return rewritten.(sqlparser.Expr) } -// addColumnsToInput adds columns needed by an operator to its input. -// This happens only when the filter expression can be retrieved as an offset from the underlying mysql. func addColumnsToInput(ctx *plancontext.PlanningContext, root Operator) Operator { - visitor := func(in Operator, _ semantics.TableSet, isRoot bool) (Operator, *ApplyResult) { + // addColumnsToInput adds columns needed by an operator to its input. + // This happens only when the filter expression can be retrieved as an offset from the underlying mysql. + addColumnsNeededByFilter := func(in Operator, _ semantics.TableSet, _ bool) (Operator, *ApplyResult) { filter, ok := in.(*Filter) if !ok { return in, NoRewrite @@ -125,12 +129,30 @@ func addColumnsToInput(ctx *plancontext.PlanningContext, root Operator) Operator return in, NoRewrite } + failUDFAggregation := func(in Operator, _ semantics.TableSet, _ bool) (Operator, *ApplyResult) { + aggrOp, ok := in.(*Aggregator) + if !ok { + return in, NoRewrite + } + for _, aggr := range aggrOp.Aggregations { + if aggr.OpCode == opcode.AggregateUDF { + // we don't support UDFs in aggregation if it's still above a route + panic(vterrors.VT03033(sqlparser.String(aggr.Original.Expr))) + } + } + return in, NoRewrite + } + + visitor := func(in Operator, _ semantics.TableSet, isRoot bool) (Operator, *ApplyResult) { + out, res := addColumnsNeededByFilter(in, semantics.EmptyTableSet(), isRoot) + failUDFAggregation(in, semantics.EmptyTableSet(), isRoot) + return out, res + } return TopDown(root, TableID, visitor, stopAtRoute) } -// addColumnsToInput adds columns needed by an operator to its input. -// This happens only when the filter expression can be retrieved as an offset from the underlying mysql. +// pullDistinctFromUNION will pull out the distinct from a union operator func pullDistinctFromUNION(_ *plancontext.PlanningContext, root Operator) Operator { visitor := func(in Operator, _ semantics.TableSet, isRoot bool) (Operator, *ApplyResult) { union, ok := in.(*Union) @@ -170,7 +192,7 @@ func getOffsetRewritingVisitor( return false } - if mustFetchFromInput(e) { + if mustFetchFromInput(ctx, e) { notFound(e) return false } diff --git a/go/vt/vtgate/planbuilder/operators/query_planning.go b/go/vt/vtgate/planbuilder/operators/query_planning.go index 1b54a94201d..d2eb1c37ccd 100644 --- a/go/vt/vtgate/planbuilder/operators/query_planning.go +++ b/go/vt/vtgate/planbuilder/operators/query_planning.go @@ -287,7 +287,7 @@ func tryPushOrdering(ctx *plancontext.PlanningContext, in *Ordering) (Operator, case *Projection: // we can move ordering under a projection if it's not introducing a column we're sorting by for _, by := range in.Order { - if !mustFetchFromInput(by.SimplifiedExpr) { + if !mustFetchFromInput(ctx, by.SimplifiedExpr) { return in, NoRewrite } } @@ -459,7 +459,7 @@ func pushFilterUnderProjection(ctx *plancontext.PlanningContext, filter *Filter, for _, p := range filter.Predicates { cantPush := false _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { - if !mustFetchFromInput(node) { + if !mustFetchFromInput(ctx, node) { return true, nil } diff --git a/go/vt/vtgate/planbuilder/operators/queryprojection.go b/go/vt/vtgate/planbuilder/operators/queryprojection.go index 14bea4f4674..dd4fec80135 100644 --- a/go/vt/vtgate/planbuilder/operators/queryprojection.go +++ b/go/vt/vtgate/planbuilder/operators/queryprojection.go @@ -173,11 +173,11 @@ func createQPFromSelect(ctx *plancontext.PlanningContext, sel *sqlparser.Select) Distinct: sel.Distinct, } - qp.addSelectExpressions(sel) + qp.addSelectExpressions(ctx, sel) qp.addGroupBy(ctx, sel.GroupBy) qp.addOrderBy(ctx, sel.OrderBy) if !qp.HasAggr && sel.Having != nil { - qp.HasAggr = containsAggr(sel.Having.Expr) + qp.HasAggr = ContainsAggr(ctx, sel.Having.Expr) } qp.calculateDistinct(ctx) @@ -239,14 +239,14 @@ func (qp *QueryProjection) AggrRewriter(ctx *plancontext.PlanningContext) *AggrR } } -func (qp *QueryProjection) addSelectExpressions(sel *sqlparser.Select) { +func (qp *QueryProjection) addSelectExpressions(ctx *plancontext.PlanningContext, sel *sqlparser.Select) { for _, selExp := range sel.SelectExprs { switch selExp := selExp.(type) { case *sqlparser.AliasedExpr: col := SelectExpr{ Col: selExp, } - if containsAggr(selExp.Expr) { + if ContainsAggr(ctx, selExp.Expr) { col.Aggr = true qp.HasAggr = true } @@ -263,7 +263,18 @@ func (qp *QueryProjection) addSelectExpressions(sel *sqlparser.Select) { } } -func containsAggr(e sqlparser.SQLNode) (hasAggr bool) { +func IsAggr(ctx *plancontext.PlanningContext, e sqlparser.SQLNode) bool { + switch node := e.(type) { + case sqlparser.AggrFunc: + return true + case *sqlparser.FuncExpr: + return node.Name.EqualsAnyString(ctx.VSchema.GetAggregateUDFs()) + } + + return false +} + +func ContainsAggr(ctx *plancontext.PlanningContext, e sqlparser.SQLNode) (hasAggr bool) { _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { switch node.(type) { case *sqlparser.Offset: @@ -275,6 +286,11 @@ func containsAggr(e sqlparser.SQLNode) (hasAggr bool) { return false, io.EOF case *sqlparser.Subquery: return false, nil + case *sqlparser.FuncExpr: + if IsAggr(ctx, node) { + hasAggr = true + return false, io.EOF + } } return true, nil @@ -287,7 +303,7 @@ func createQPFromUnion(ctx *plancontext.PlanningContext, union *sqlparser.Union) qp := &QueryProjection{} sel := sqlparser.GetFirstSelect(union) - qp.addSelectExpressions(sel) + qp.addSelectExpressions(ctx, sel) qp.addOrderBy(ctx, union.OrderBy) return qp @@ -325,7 +341,7 @@ func (qp *QueryProjection) addOrderBy(ctx *plancontext.PlanningContext, orderBy Inner: ctx.SemTable.Clone(order).(*sqlparser.Order), SimplifiedExpr: order.Expr, }) - canPushSorting = canPushSorting && !containsAggr(order.Expr) + canPushSorting = canPushSorting && !ContainsAggr(ctx, order.Expr) } } @@ -480,7 +496,7 @@ func (qp *QueryProjection) AggregationExpressions(ctx *plancontext.PlanningConte idxCopy := idx - if !containsAggr(expr.Col) { + if !ContainsAggr(ctx, expr.Col) { getExpr, err := expr.GetExpr() if err != nil { panic(err) @@ -492,8 +508,7 @@ func (qp *QueryProjection) AggregationExpressions(ctx *plancontext.PlanningConte } continue } - _, isAggregate := aliasedExpr.Expr.(sqlparser.AggrFunc) - if !isAggregate && !allowComplexExpression { + if !IsAggr(ctx, aliasedExpr.Expr) && !allowComplexExpression { panic(vterrors.VT12001("in scatter query: complex aggregate expression")) } @@ -524,7 +539,15 @@ func (qp *QueryProjection) extractAggr( addAggr(aggrFunc) return false } - if containsAggr(node) { + if IsAggr(ctx, node) { + // If we are here, we have a function that is an aggregation but not parsed into an AggrFunc. + // This is the case for UDFs - we have to be careful with these because we can't evaluate them in VTGate. + aggr := NewAggr(opcode.AggregateUDF, nil, aeWrap(ex), "") + aggr.Index = &idx + addAggr(aggr) + return false + } + if ContainsAggr(ctx, node) { makeComplex() return true } @@ -553,7 +576,7 @@ orderBy: } qp.SelectExprs = append(qp.SelectExprs, SelectExpr{ Col: &sqlparser.AliasedExpr{Expr: orderExpr}, - Aggr: containsAggr(orderExpr), + Aggr: ContainsAggr(ctx, orderExpr), }) qp.AddedColumn++ } diff --git a/go/vt/vtgate/planbuilder/operators/subquery_planning.go b/go/vt/vtgate/planbuilder/operators/subquery_planning.go index af25136b16a..c7216e3bdae 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery_planning.go +++ b/go/vt/vtgate/planbuilder/operators/subquery_planning.go @@ -56,11 +56,11 @@ func isMergeable(ctx *plancontext.PlanningContext, query sqlparser.SelectStateme // if we have grouping, we have already checked that it's safe, and don't need to check for aggregations // but if we don't have groupings, we need to check if there are aggregations that will mess with us - if sqlparser.ContainsAggregation(node.SelectExprs) { + if ContainsAggr(ctx, node.SelectExprs) { return false } - if sqlparser.ContainsAggregation(node.Having) { + if ContainsAggr(ctx, node.Having) { return false } diff --git a/go/vt/vtgate/planbuilder/planner_test.go b/go/vt/vtgate/planbuilder/planner_test.go index 2601615522f..6ad1bb4116c 100644 --- a/go/vt/vtgate/planbuilder/planner_test.go +++ b/go/vt/vtgate/planbuilder/planner_test.go @@ -19,6 +19,8 @@ package planbuilder import ( "testing" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" + "github.com/stretchr/testify/require" "vitess.io/vitess/go/vt/sqlparser" @@ -67,9 +69,13 @@ func TestBindingSubquery(t *testing.T) { "foo": {Name: sqlparser.NewIdentifierCS("foo")}, }, }) + ctx := &plancontext.PlanningContext{ + ReservedVars: sqlparser.NewReservedVars("vt", make(sqlparser.BindVars)), + SemTable: semTable, + } require.NoError(t, err) if testcase.rewrite { - err = queryRewrite(semTable, sqlparser.NewReservedVars("vt", make(sqlparser.BindVars)), selStmt) + err = queryRewrite(ctx, selStmt) require.NoError(t, err) } expr := testcase.extractor(selStmt) diff --git a/go/vt/vtgate/planbuilder/rewrite.go b/go/vt/vtgate/planbuilder/rewrite.go index 915b5e753cd..30423229038 100644 --- a/go/vt/vtgate/planbuilder/rewrite.go +++ b/go/vt/vtgate/planbuilder/rewrite.go @@ -18,19 +18,18 @@ package planbuilder import ( "vitess.io/vitess/go/vt/sqlparser" - "vitess.io/vitess/go/vt/vtgate/semantics" + "vitess.io/vitess/go/vt/vtgate/planbuilder/operators" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" ) type rewriter struct { - semTable *semantics.SemTable - reservedVars *sqlparser.ReservedVars - err error + err error + ctx *plancontext.PlanningContext } -func queryRewrite(semTable *semantics.SemTable, reservedVars *sqlparser.ReservedVars, statement sqlparser.Statement) error { +func queryRewrite(ctx *plancontext.PlanningContext, statement sqlparser.Statement) error { r := rewriter{ - semTable: semTable, - reservedVars: reservedVars, + ctx: ctx, } sqlparser.Rewrite(statement, r.rewriteDown, nil) return nil @@ -39,15 +38,15 @@ func queryRewrite(semTable *semantics.SemTable, reservedVars *sqlparser.Reserved func (r *rewriter) rewriteDown(cursor *sqlparser.Cursor) bool { switch node := cursor.Node().(type) { case *sqlparser.Select: - rewriteHavingClause(node) + r.rewriteHavingClause(node) case *sqlparser.AliasedTableExpr: if _, isDerived := node.Expr.(*sqlparser.DerivedTable); isDerived { break } // find the tableSet and tableInfo that this table points to // tableInfo should contain the information for the original table that the routed table points to - tableSet := r.semTable.TableSetFor(node) - tableInfo, err := r.semTable.TableInfoFor(tableSet) + tableSet := r.ctx.SemTable.TableSetFor(node) + tableInfo, err := r.ctx.SemTable.TableInfoFor(tableSet) if err != nil { // Fail-safe code, should never happen break @@ -77,7 +76,7 @@ func (r *rewriter) rewriteDown(cursor *sqlparser.Cursor) bool { return true } -func rewriteHavingClause(node *sqlparser.Select) { +func (r *rewriter) rewriteHavingClause(node *sqlparser.Select) { if node.Having == nil { return } @@ -89,7 +88,7 @@ func rewriteHavingClause(node *sqlparser.Select) { exprs := sqlparser.SplitAndExpression(nil, node.Having.Expr) node.Having = nil for _, expr := range exprs { - if sqlparser.ContainsAggregation(expr) { + if operators.ContainsAggr(r.ctx, expr) { node.AddHaving(expr) } else { node.AddWhere(expr) diff --git a/go/vt/vtgate/planbuilder/rewrite_test.go b/go/vt/vtgate/planbuilder/rewrite_test.go index 87c8985fd63..7902b69e8f9 100644 --- a/go/vt/vtgate/planbuilder/rewrite_test.go +++ b/go/vt/vtgate/planbuilder/rewrite_test.go @@ -19,6 +19,8 @@ package planbuilder import ( "testing" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -74,7 +76,11 @@ func TestHavingRewrite(t *testing.T) { for _, tcase := range tcases { t.Run(tcase.input, func(t *testing.T) { semTable, reservedVars, sel := prepTest(t, tcase.input) - err := queryRewrite(semTable, reservedVars, sel) + ctx := &plancontext.PlanningContext{ + ReservedVars: reservedVars, + SemTable: semTable, + } + err := queryRewrite(ctx, sel) require.NoError(t, err) assert.Equal(t, tcase.output, sqlparser.String(sel)) }) diff --git a/go/vt/vtgate/planbuilder/select.go b/go/vt/vtgate/planbuilder/select.go index 13671e7efa0..83a6ba650f4 100644 --- a/go/vt/vtgate/planbuilder/select.go +++ b/go/vt/vtgate/planbuilder/select.go @@ -225,7 +225,7 @@ func newBuildSelectPlan( } func createSelectOperator(ctx *plancontext.PlanningContext, selStmt sqlparser.SelectStatement, reservedVars *sqlparser.ReservedVars) (operators.Operator, error) { - err := queryRewrite(ctx.SemTable, reservedVars, selStmt) + err := queryRewrite(ctx, selStmt) if err != nil { return nil, err } diff --git a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json index f7e556956e3..912cee2184d 100644 --- a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json @@ -6958,5 +6958,10 @@ "comment": "baz in the HAVING clause can't be accessed because of the GROUP BY", "query": "select foo, count(bar) as x from user group by foo having baz > avg(baz) order by x", "plan": "Unknown column 'baz' in 'having clause'" + }, + { + "comment": "Aggregate UDFs can't be handled by vtgate", + "query": "select id from t1 group by id having udf_aggr(foo) > 1 and sum(foo) = 10", + "plan": "VT03033: aggregate user-defined function udf_aggr(foo) must be pushed down to mysql" } ] diff --git a/go/vt/vtgate/planbuilder/update.go b/go/vt/vtgate/planbuilder/update.go index 124eaf87310..313f33b6bf1 100644 --- a/go/vt/vtgate/planbuilder/update.go +++ b/go/vt/vtgate/planbuilder/update.go @@ -41,7 +41,7 @@ func gen4UpdateStmtPlanner( return nil, err } - err = queryRewrite(ctx.SemTable, reservedVars, updStmt) + err = queryRewrite(ctx, updStmt) if err != nil { return nil, err } diff --git a/go/vt/vtgate/planbuilder/vexplain.go b/go/vt/vtgate/planbuilder/vexplain.go index 7b200fb2e09..db62da75122 100644 --- a/go/vt/vtgate/planbuilder/vexplain.go +++ b/go/vt/vtgate/planbuilder/vexplain.go @@ -128,7 +128,7 @@ func explainPlan(explain *sqlparser.ExplainStmt, reservedVars *sqlparser.Reserve return nil, vterrors.VT03031() } - if err = queryRewrite(ctx.SemTable, reservedVars, explain.Statement); err != nil { + if err = queryRewrite(ctx, explain.Statement); err != nil { return nil, err } diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index a5a38f09641..a0788a595e7 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -53,7 +53,7 @@ type analyzer struct { // newAnalyzer create the semantic analyzer func newAnalyzer(dbName string, si SchemaInformation, fullAnalysis bool) *analyzer { // TODO dependencies between these components are a little tangled. We should try to clean up - s := newScoper() + s := newScoper(si.GetAggregateUDFs()) a := &analyzer{ scoper: s, earlyTables: newEarlyTableCollector(si, dbName), diff --git a/go/vt/vtgate/semantics/early_rewriter.go b/go/vt/vtgate/semantics/early_rewriter.go index 6c796289057..61abd9c3fa7 100644 --- a/go/vt/vtgate/semantics/early_rewriter.go +++ b/go/vt/vtgate/semantics/early_rewriter.go @@ -501,15 +501,6 @@ func (r *earlyRewriter) rewriteAliasesInGroupBy(node sqlparser.Expr, sel *sqlpar return } -func (at *aggrTracker) isAggregateUDF(name sqlparser.IdentifierCI) bool { - for _, aggrUDF := range at.aggrUDFs { - if name.EqualString(aggrUDF) { - return true - } - } - return false -} - func (r *earlyRewriter) rewriteAliasesInHaving(node sqlparser.Expr, sel *sqlparser.Select) (expr sqlparser.Expr, err error) { currentScope := r.scoper.currentScope() if currentScope.isUnion { @@ -530,7 +521,7 @@ func (r *earlyRewriter) rewriteAliasesInHaving(node sqlparser.Expr, sel *sqlpars aggrTrack.popAggr() return case *sqlparser.FuncExpr: - if aggrTrack.isAggregateUDF(node.Name) { + if node.Name.EqualsAnyString(r.aggrUDFs) { aggrTrack.popAggr() } return @@ -593,7 +584,7 @@ func (at *aggrTracker) down(node, _ sqlparser.SQLNode) bool { case sqlparser.AggrFunc: at.insideAggr = true case *sqlparser.FuncExpr: - if at.isAggregateUDF(node.Name) { + if node.Name.EqualsAnyString(at.aggrUDFs) { at.insideAggr = true } } @@ -761,7 +752,10 @@ func (r *earlyRewriter) rewriteOrderByLiteral(node *sqlparser.Literal) (expr sql } if num < 1 || num > len(stmt.SelectExprs) { - return nil, false, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.BadFieldError, "Unknown column '%d' in '%s'", num, r.clause) + return nil, false, &ColumnNotFoundClauseError{ + Column: fmt.Sprintf("%d", num), + Clause: r.clause, + } } // We loop like this instead of directly accessing the offset, to make sure there are no unexpanded `*` before diff --git a/go/vt/vtgate/semantics/scoper.go b/go/vt/vtgate/semantics/scoper.go index 3a6fbe4c35c..b8a41253f64 100644 --- a/go/vt/vtgate/semantics/scoper.go +++ b/go/vt/vtgate/semantics/scoper.go @@ -37,6 +37,7 @@ type ( // These scopes are only used for rewriting ORDER BY 1 and GROUP BY 1 specialExprScopes map[*sqlparser.Literal]*scope statementIDs map[sqlparser.Statement]TableSet + aggrUDFs []string } scope struct { @@ -53,12 +54,13 @@ type ( } ) -func newScoper() *scoper { +func newScoper(udfs []string) *scoper { return &scoper{ rScope: map[*sqlparser.Select]*scope{}, wScope: map[*sqlparser.Select]*scope{}, specialExprScopes: map[*sqlparser.Literal]*scope{}, statementIDs: map[sqlparser.Statement]TableSet{}, + aggrUDFs: udfs, } } @@ -84,6 +86,13 @@ func (s *scoper) down(cursor *sqlparser.Cursor) error { break } s.currentScope().inHavingAggr = true + case *sqlparser.FuncExpr: + if !s.currentScope().inHaving { + break + } + if node.Name.EqualsAnyString(s.aggrUDFs) { + s.currentScope().inHavingAggr = true + } case *sqlparser.Where: if node.Type == sqlparser.HavingClause { err := s.createSpecialScopePostProjection(cursor.Parent()) diff --git a/go/vt/wrangler/vexec_plan.go b/go/vt/wrangler/vexec_plan.go index 1878c25441c..76b2d0fe732 100644 --- a/go/vt/wrangler/vexec_plan.go +++ b/go/vt/wrangler/vexec_plan.go @@ -248,13 +248,7 @@ func (vx *vexec) buildUpdatePlan(ctx context.Context, planner vexecPlanner, upd if updatableColumnNames := plannerParams.updatableColumnNames; len(updatableColumnNames) > 0 { // if updatableColumnNames is non empty, then we must only accept changes to columns listed there for _, expr := range upd.Exprs { - isUpdatable := false - for _, updatableColName := range updatableColumnNames { - if expr.Name.Name.EqualString(updatableColName) { - isUpdatable = true - } - } - if !isUpdatable { + if !expr.Name.Name.EqualsAnyString(updatableColumnNames) { return nil, fmt.Errorf("%+v cannot be changed: %v", expr.Name.Name, sqlparser.String(expr)) } }