diff --git a/go/test/vschemawrapper/vschema_wrapper.go b/go/test/vschemawrapper/vschema_wrapper.go index 0f8b47a9804..4d1c424dda8 100644 --- a/go/test/vschemawrapper/vschema_wrapper.go +++ b/go/test/vschemawrapper/vschema_wrapper.go @@ -147,6 +147,10 @@ func (vw *VSchemaWrapper) KeyspaceError(keyspace string) error { return nil } +func (vw *VSchemaWrapper) GetAggregateUDFs() (udfs []string) { + return vw.V.GetAggregateUDFs() +} + func (vw *VSchemaWrapper) GetForeignKeyChecksState() *bool { return vw.ForeignKeyChecksState } diff --git a/go/vt/schemadiff/semantics.go b/go/vt/schemadiff/semantics.go index ee175a37966..ccbf654f566 100644 --- a/go/vt/schemadiff/semantics.go +++ b/go/vt/schemadiff/semantics.go @@ -71,6 +71,10 @@ func (si *declarativeSchemaInformation) KeyspaceError(keyspace string) error { return nil } +func (si *declarativeSchemaInformation) GetAggregateUDFs() []string { + return nil +} + func (si *declarativeSchemaInformation) GetForeignKeyChecksState() *bool { return nil } diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index 77cbed714b0..8019645c250 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -923,6 +923,16 @@ func (node IdentifierCI) EqualString(str string) bool { return node.Lowered() == strings.ToLower(str) } +// EqualsAnyString returns true if any of these strings match +func (node IdentifierCI) EqualsAnyString(str []string) bool { + for _, s := range str { + if node.EqualString(s) { + return true + } + } + return false +} + // MarshalJSON marshals into JSON. func (node IdentifierCI) MarshalJSON() ([]byte, error) { return json.Marshal(node.val) diff --git a/go/vt/vterrors/code.go b/go/vt/vterrors/code.go index bc4bd9bbe35..ffce4fc553d 100644 --- a/go/vt/vterrors/code.go +++ b/go/vt/vterrors/code.go @@ -57,7 +57,7 @@ var ( VT03029 = errorWithState("VT03029", vtrpcpb.Code_INVALID_ARGUMENT, WrongValueCountOnRow, "column count does not match value count with the row for vindex '%s'", "The number of columns you want to insert do not match the number of columns of your SELECT query.") VT03030 = errorWithState("VT03030", vtrpcpb.Code_INVALID_ARGUMENT, WrongValueCountOnRow, "lookup column count does not match value count with the row (columns, count): (%v, %d)", "The number of columns you want to insert do not match the number of columns of your SELECT query.") VT03031 = errorWithoutState("VT03031", vtrpcpb.Code_INVALID_ARGUMENT, "EXPLAIN is only supported for single keyspace", "EXPLAIN has to be sent down as a single query to the underlying MySQL, and this is not possible if it uses tables from multiple keyspaces") - VT03032 = errorWithState("VT03031", vtrpcpb.Code_INVALID_ARGUMENT, NonUpdateableTable, "the target table %s of the UPDATE is not updatable", "You cannot update a table that is not a real MySQL table.") + VT03032 = errorWithState("VT03032", vtrpcpb.Code_INVALID_ARGUMENT, NonUpdateableTable, "the target table %s of the UPDATE is not updatable", "You cannot update a table that is not a real MySQL table.") VT05001 = errorWithState("VT05001", vtrpcpb.Code_NOT_FOUND, DbDropExists, "cannot drop database '%s'; database does not exists", "The given database does not exist; Vitess cannot drop it.") VT05002 = errorWithState("VT05002", vtrpcpb.Code_NOT_FOUND, BadDb, "cannot alter database '%s'; unknown database", "The given database does not exist; Vitess cannot alter it.") diff --git a/go/vt/vtgate/engine/opcode/constants.go b/go/vt/vtgate/engine/opcode/constants.go index 2fa0e9446a4..1bdbe61fd65 100644 --- a/go/vt/vtgate/engine/opcode/constants.go +++ b/go/vt/vtgate/engine/opcode/constants.go @@ -77,22 +77,10 @@ const ( AggregateCountStar AggregateGroupConcat AggregateAvg + AggregateUDF // This is an opcode used to represent UDFs _NumOfOpCodes // This line must be last of the opcodes! ) -var ( - // OpcodeType keeps track of the known output types for different aggregate functions - OpcodeType = map[AggregateOpcode]querypb.Type{ - AggregateCountDistinct: sqltypes.Int64, - AggregateCount: sqltypes.Int64, - AggregateCountStar: sqltypes.Int64, - AggregateSumDistinct: sqltypes.Decimal, - AggregateSum: sqltypes.Decimal, - AggregateAvg: sqltypes.Decimal, - AggregateGtid: sqltypes.VarChar, - } -) - // SupportedAggregates maps the list of supported aggregate // functions to their opcodes. var SupportedAggregates = map[string]AggregateOpcode{ @@ -166,6 +154,8 @@ func (code AggregateOpcode) SQLType(typ querypb.Type) querypb.Type { return sqltypes.Int64 case AggregateGtid: return sqltypes.VarChar + case AggregateUDF: + return sqltypes.Unknown default: panic(code.String()) // we have a unit test checking we never reach here } diff --git a/go/vt/vtgate/planbuilder/delete.go b/go/vt/vtgate/planbuilder/delete.go index 6d56a41c6df..980af21df61 100644 --- a/go/vt/vtgate/planbuilder/delete.go +++ b/go/vt/vtgate/planbuilder/delete.go @@ -49,7 +49,7 @@ func gen4DeleteStmtPlanner( return nil, err } - err = queryRewrite(ctx.SemTable, reservedVars, deleteStmt) + err = queryRewrite(ctx, deleteStmt) if err != nil { return nil, err } diff --git a/go/vt/vtgate/planbuilder/insert.go b/go/vt/vtgate/planbuilder/insert.go index b08330f060d..e674850c753 100644 --- a/go/vt/vtgate/planbuilder/insert.go +++ b/go/vt/vtgate/planbuilder/insert.go @@ -33,7 +33,7 @@ func gen4InsertStmtPlanner(version querypb.ExecuteOptions_PlannerVersion, insStm return nil, err } - err = queryRewrite(ctx.SemTable, reservedVars, insStmt) + err = queryRewrite(ctx, insStmt) if err != nil { return nil, err } diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/SQL_builder.go index 062f5b7303d..6604c7587c3 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/SQL_builder.go @@ -96,7 +96,7 @@ func (qb *queryBuilder) addPredicate(expr sqlparser.Expr) { switch stmt := qb.stmt.(type) { case *sqlparser.Select: - if containsAggr(expr) { + if ContainsAggr(qb.ctx, expr) { addPred = stmt.AddHaving } else { addPred = stmt.AddWhere diff --git a/go/vt/vtgate/planbuilder/operators/aggregator.go b/go/vt/vtgate/planbuilder/operators/aggregator.go index 256372c172f..5c9ba167171 100644 --- a/go/vt/vtgate/planbuilder/operators/aggregator.go +++ b/go/vt/vtgate/planbuilder/operators/aggregator.go @@ -83,7 +83,7 @@ func (a *Aggregator) AddPredicate(_ *plancontext.PlanningContext, expr sqlparser return newFilter(a, expr) } -func (a *Aggregator) addColumnWithoutPushing(_ *plancontext.PlanningContext, expr *sqlparser.AliasedExpr, addToGroupBy bool) int { +func (a *Aggregator) addColumnWithoutPushing(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr, addToGroupBy bool) int { offset := len(a.Columns) a.Columns = append(a.Columns, expr) @@ -96,6 +96,12 @@ func (a *Aggregator) addColumnWithoutPushing(_ *plancontext.PlanningContext, exp switch e := expr.Expr.(type) { case sqlparser.AggrFunc: aggr = createAggrFromAggrFunc(e, expr) + case *sqlparser.FuncExpr: + if IsAggr(ctx, e) { + aggr = NewAggr(opcode.AggregateUDF, nil, expr, expr.As.String()) + } else { + aggr = NewAggr(opcode.AggregateAnyValue, nil, expr, expr.As.String()) + } default: aggr = NewAggr(opcode.AggregateAnyValue, nil, expr, expr.As.String()) } diff --git a/go/vt/vtgate/planbuilder/operators/expressions.go b/go/vt/vtgate/planbuilder/operators/expressions.go index 612c1e7ec08..521024ab7c9 100644 --- a/go/vt/vtgate/planbuilder/operators/expressions.go +++ b/go/vt/vtgate/planbuilder/operators/expressions.go @@ -31,7 +31,7 @@ func breakExpressionInLHSandRHSForApplyJoin( ) (col applyJoinColumn) { rewrittenExpr := sqlparser.CopyOnRewrite(expr, nil, func(cursor *sqlparser.CopyOnWriteCursor) { nodeExpr, ok := cursor.Node().(sqlparser.Expr) - if !ok || !mustFetchFromInput(nodeExpr) { + if !ok || !mustFetchFromInput(ctx, nodeExpr) { return } deps := ctx.SemTable.RecursiveDeps(nodeExpr) diff --git a/go/vt/vtgate/planbuilder/operators/hash_join.go b/go/vt/vtgate/planbuilder/operators/hash_join.go index 0ad46bcbc82..da6a63200db 100644 --- a/go/vt/vtgate/planbuilder/operators/hash_join.go +++ b/go/vt/vtgate/planbuilder/operators/hash_join.go @@ -292,7 +292,7 @@ func (hj *HashJoin) addColumn(ctx *plancontext.PlanningContext, in sqlparser.Exp } inOffset := op.FindCol(ctx, expr, false) if inOffset == -1 { - if !mustFetchFromInput(expr) { + if !mustFetchFromInput(ctx, expr) { return -1 } @@ -398,7 +398,7 @@ func (hj *HashJoin) addSingleSidedColumn( } inOffset := op.FindCol(ctx, expr, false) if inOffset == -1 { - if !mustFetchFromInput(expr) { + if !mustFetchFromInput(ctx, expr) { return -1 } diff --git a/go/vt/vtgate/planbuilder/operators/horizon.go b/go/vt/vtgate/planbuilder/operators/horizon.go index 34f6dc79217..4a4c990b1ed 100644 --- a/go/vt/vtgate/planbuilder/operators/horizon.go +++ b/go/vt/vtgate/planbuilder/operators/horizon.go @@ -100,7 +100,7 @@ func (h *Horizon) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser. } newExpr := semantics.RewriteDerivedTableExpression(expr, tableInfo) - if sqlparser.ContainsAggregation(newExpr) { + if ContainsAggr(ctx, newExpr) { return newFilter(h, expr) } h.Source = h.Source.AddPredicate(ctx, newExpr) diff --git a/go/vt/vtgate/planbuilder/operators/offset_planning.go b/go/vt/vtgate/planbuilder/operators/offset_planning.go index 638d3d80907..4204b6e0420 100644 --- a/go/vt/vtgate/planbuilder/operators/offset_planning.go +++ b/go/vt/vtgate/planbuilder/operators/offset_planning.go @@ -19,6 +19,8 @@ package operators import ( "fmt" + "vitess.io/vitess/go/vt/vtgate/engine/opcode" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" @@ -56,10 +58,12 @@ func planOffsets(ctx *plancontext.PlanningContext, root Operator) Operator { } // mustFetchFromInput returns true for expressions that have to be fetched from the input and cannot be evaluated -func mustFetchFromInput(e sqlparser.SQLNode) bool { - switch e.(type) { +func mustFetchFromInput(ctx *plancontext.PlanningContext, e sqlparser.SQLNode) bool { + switch fun := e.(type) { case *sqlparser.ColName, sqlparser.AggrFunc: return true + case *sqlparser.FuncExpr: + return fun.Name.EqualsAnyString(ctx.VSchema.GetAggregateUDFs()) default: return false } @@ -93,10 +97,10 @@ func useOffsets(ctx *plancontext.PlanningContext, expr sqlparser.Expr, op Operat return rewritten.(sqlparser.Expr) } -// addColumnsToInput adds columns needed by an operator to its input. -// This happens only when the filter expression can be retrieved as an offset from the underlying mysql. func addColumnsToInput(ctx *plancontext.PlanningContext, root Operator) Operator { - visitor := func(in Operator, _ semantics.TableSet, isRoot bool) (Operator, *ApplyResult) { + // addColumnsToInput adds columns needed by an operator to its input. + // This happens only when the filter expression can be retrieved as an offset from the underlying mysql. + addColumnsNeededByFilter := func(in Operator, _ semantics.TableSet, _ bool) (Operator, *ApplyResult) { filter, ok := in.(*Filter) if !ok { return in, NoRewrite @@ -126,12 +130,33 @@ func addColumnsToInput(ctx *plancontext.PlanningContext, root Operator) Operator return in, NoRewrite } + // while we are out here walking the operator tree, if we find a UDF in an aggregation, we should fail + failUDFAggregation := func(in Operator, _ semantics.TableSet, _ bool) (Operator, *ApplyResult) { + aggrOp, ok := in.(*Aggregator) + if !ok { + return in, NoRewrite + } + for _, aggr := range aggrOp.Aggregations { + if aggr.OpCode == opcode.AggregateUDF { + // we don't support UDFs in aggregation if it's still above a route + message := fmt.Sprintf("Aggregate UDF '%s' must be pushed down to MySQL", sqlparser.String(aggr.Original.Expr)) + panic(vterrors.VT12001(message)) + } + } + return in, NoRewrite + } + + visitor := func(in Operator, _ semantics.TableSet, isRoot bool) (Operator, *ApplyResult) { + out, res := addColumnsNeededByFilter(in, semantics.EmptyTableSet(), isRoot) + failUDFAggregation(in, semantics.EmptyTableSet(), isRoot) + return out, res + } + return TopDown(root, TableID, visitor, stopAtRoute) } -// addColumnsToInput adds columns needed by an operator to its input. -// This happens only when the filter expression can be retrieved as an offset from the underlying mysql. -func pullDistinctFromUNION(_ *plancontext.PlanningContext, root Operator) Operator { +// isolateDistinctFromUnion will pull out the distinct from a union operator +func isolateDistinctFromUnion(_ *plancontext.PlanningContext, root Operator) Operator { visitor := func(in Operator, _ semantics.TableSet, isRoot bool) (Operator, *ApplyResult) { union, ok := in.(*Union) if !ok || !union.distinct { @@ -170,7 +195,7 @@ func getOffsetRewritingVisitor( return false } - if mustFetchFromInput(e) { + if mustFetchFromInput(ctx, e) { notFound(e) return false } diff --git a/go/vt/vtgate/planbuilder/operators/phases.go b/go/vt/vtgate/planbuilder/operators/phases.go index 2fc3a5a044f..60e937a5b92 100644 --- a/go/vt/vtgate/planbuilder/operators/phases.go +++ b/go/vt/vtgate/planbuilder/operators/phases.go @@ -88,7 +88,7 @@ func (p Phase) shouldRun(s semantics.QuerySignature) bool { func (p Phase) act(ctx *plancontext.PlanningContext, op Operator) Operator { switch p { case pullDistinctFromUnion: - return pullDistinctFromUNION(ctx, op) + return isolateDistinctFromUnion(ctx, op) case delegateAggregation: return enableDelegateAggregation(ctx, op) case addAggrOrdering: diff --git a/go/vt/vtgate/planbuilder/operators/query_planning.go b/go/vt/vtgate/planbuilder/operators/query_planning.go index 1b54a94201d..d2eb1c37ccd 100644 --- a/go/vt/vtgate/planbuilder/operators/query_planning.go +++ b/go/vt/vtgate/planbuilder/operators/query_planning.go @@ -287,7 +287,7 @@ func tryPushOrdering(ctx *plancontext.PlanningContext, in *Ordering) (Operator, case *Projection: // we can move ordering under a projection if it's not introducing a column we're sorting by for _, by := range in.Order { - if !mustFetchFromInput(by.SimplifiedExpr) { + if !mustFetchFromInput(ctx, by.SimplifiedExpr) { return in, NoRewrite } } @@ -459,7 +459,7 @@ func pushFilterUnderProjection(ctx *plancontext.PlanningContext, filter *Filter, for _, p := range filter.Predicates { cantPush := false _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { - if !mustFetchFromInput(node) { + if !mustFetchFromInput(ctx, node) { return true, nil } diff --git a/go/vt/vtgate/planbuilder/operators/queryprojection.go b/go/vt/vtgate/planbuilder/operators/queryprojection.go index 14bea4f4674..c5db49b37ee 100644 --- a/go/vt/vtgate/planbuilder/operators/queryprojection.go +++ b/go/vt/vtgate/planbuilder/operators/queryprojection.go @@ -29,7 +29,6 @@ import ( "vitess.io/vitess/go/vt/vtgate/engine/opcode" "vitess.io/vitess/go/vt/vtgate/evalengine" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" - "vitess.io/vitess/go/vt/vtgate/semantics" ) type ( @@ -89,12 +88,6 @@ type ( SubQueryExpression []*SubQuery } - - AggrRewriter struct { - qp *QueryProjection - st *semantics.SemTable - failed bool - } ) func (aggr Aggr) NeedsWeightString(ctx *plancontext.PlanningContext) bool { @@ -173,80 +166,25 @@ func createQPFromSelect(ctx *plancontext.PlanningContext, sel *sqlparser.Select) Distinct: sel.Distinct, } - qp.addSelectExpressions(sel) + qp.addSelectExpressions(ctx, sel) qp.addGroupBy(ctx, sel.GroupBy) qp.addOrderBy(ctx, sel.OrderBy) if !qp.HasAggr && sel.Having != nil { - qp.HasAggr = containsAggr(sel.Having.Expr) + qp.HasAggr = ContainsAggr(ctx, sel.Having.Expr) } qp.calculateDistinct(ctx) return qp } -// RewriteDown stops the walker from entering inside aggregation functions -func (ar *AggrRewriter) RewriteDown() func(sqlparser.SQLNode, sqlparser.SQLNode) bool { - return func(node, _ sqlparser.SQLNode) bool { - if ar.failed { - return true - } - _, ok := node.(sqlparser.AggrFunc) - return !ok - } -} - -// RewriteUp will go through an expression, add aggregations to the QP, and rewrite them to use column offset -func (ar *AggrRewriter) RewriteUp() func(*sqlparser.Cursor) bool { - return func(cursor *sqlparser.Cursor) bool { - if ar.failed { - return false - } - sqlNode := cursor.Node() - fExp, ok := sqlNode.(sqlparser.AggrFunc) - if !ok { - return true - } - for offset, expr := range ar.qp.SelectExprs { - ae, err := expr.GetAliasedExpr() - if err != nil { - ar.failed = true - return false - } - if ar.st.EqualsExprWithDeps(ae.Expr, fExp) { - cursor.Replace(sqlparser.NewOffset(offset, fExp)) - return true - } - } - - col := SelectExpr{ - Aggr: true, - Col: &sqlparser.AliasedExpr{Expr: fExp}, - } - ar.qp.HasAggr = true - cursor.Replace(sqlparser.NewOffset(len(ar.qp.SelectExprs), fExp)) - ar.qp.SelectExprs = append(ar.qp.SelectExprs, col) - ar.qp.AddedColumn++ - - return true - } -} - -// AggrRewriter extracts -func (qp *QueryProjection) AggrRewriter(ctx *plancontext.PlanningContext) *AggrRewriter { - return &AggrRewriter{ - qp: qp, - st: ctx.SemTable, - } -} - -func (qp *QueryProjection) addSelectExpressions(sel *sqlparser.Select) { +func (qp *QueryProjection) addSelectExpressions(ctx *plancontext.PlanningContext, sel *sqlparser.Select) { for _, selExp := range sel.SelectExprs { switch selExp := selExp.(type) { case *sqlparser.AliasedExpr: col := SelectExpr{ Col: selExp, } - if containsAggr(selExp.Expr) { + if ContainsAggr(ctx, selExp.Expr) { col.Aggr = true qp.HasAggr = true } @@ -263,7 +201,18 @@ func (qp *QueryProjection) addSelectExpressions(sel *sqlparser.Select) { } } -func containsAggr(e sqlparser.SQLNode) (hasAggr bool) { +func IsAggr(ctx *plancontext.PlanningContext, e sqlparser.SQLNode) bool { + switch node := e.(type) { + case sqlparser.AggrFunc: + return true + case *sqlparser.FuncExpr: + return node.Name.EqualsAnyString(ctx.VSchema.GetAggregateUDFs()) + } + + return false +} + +func ContainsAggr(ctx *plancontext.PlanningContext, e sqlparser.SQLNode) (hasAggr bool) { _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { switch node.(type) { case *sqlparser.Offset: @@ -275,6 +224,11 @@ func containsAggr(e sqlparser.SQLNode) (hasAggr bool) { return false, io.EOF case *sqlparser.Subquery: return false, nil + case *sqlparser.FuncExpr: + if IsAggr(ctx, node) { + hasAggr = true + return false, io.EOF + } } return true, nil @@ -287,7 +241,7 @@ func createQPFromUnion(ctx *plancontext.PlanningContext, union *sqlparser.Union) qp := &QueryProjection{} sel := sqlparser.GetFirstSelect(union) - qp.addSelectExpressions(sel) + qp.addSelectExpressions(ctx, sel) qp.addOrderBy(ctx, union.OrderBy) return qp @@ -325,7 +279,7 @@ func (qp *QueryProjection) addOrderBy(ctx *plancontext.PlanningContext, orderBy Inner: ctx.SemTable.Clone(order).(*sqlparser.Order), SimplifiedExpr: order.Expr, }) - canPushSorting = canPushSorting && !containsAggr(order.Expr) + canPushSorting = canPushSorting && !ContainsAggr(ctx, order.Expr) } } @@ -371,7 +325,7 @@ func (qp *QueryProjection) addGroupBy(ctx *plancontext.PlanningContext, groupBy es := &expressionSet{} for _, grouping := range groupBy { selectExprIdx := qp.FindSelectExprIndexForExpr(ctx, grouping) - checkForInvalidGroupingExpressions(grouping) + checkForInvalidGroupingExpressions(ctx, grouping) if !es.add(ctx, grouping) { continue @@ -480,7 +434,7 @@ func (qp *QueryProjection) AggregationExpressions(ctx *plancontext.PlanningConte idxCopy := idx - if !containsAggr(expr.Col) { + if !ContainsAggr(ctx, expr.Col) { getExpr, err := expr.GetExpr() if err != nil { panic(err) @@ -492,8 +446,7 @@ func (qp *QueryProjection) AggregationExpressions(ctx *plancontext.PlanningConte } continue } - _, isAggregate := aliasedExpr.Expr.(sqlparser.AggrFunc) - if !isAggregate && !allowComplexExpression { + if !IsAggr(ctx, aliasedExpr.Expr) && !allowComplexExpression { panic(vterrors.VT12001("in scatter query: complex aggregate expression")) } @@ -524,7 +477,15 @@ func (qp *QueryProjection) extractAggr( addAggr(aggrFunc) return false } - if containsAggr(node) { + if IsAggr(ctx, node) { + // If we are here, we have a function that is an aggregation but not parsed into an AggrFunc. + // This is the case for UDFs - we have to be careful with these because we can't evaluate them in VTGate. + aggr := NewAggr(opcode.AggregateUDF, nil, aeWrap(ex), "") + aggr.Index = &idx + addAggr(aggr) + return false + } + if ContainsAggr(ctx, node) { makeComplex() return true } @@ -553,7 +514,7 @@ orderBy: } qp.SelectExprs = append(qp.SelectExprs, SelectExpr{ Col: &sqlparser.AliasedExpr{Expr: orderExpr}, - Aggr: containsAggr(orderExpr), + Aggr: ContainsAggr(ctx, orderExpr), }) qp.AddedColumn++ } @@ -732,9 +693,9 @@ func (qp *QueryProjection) useGroupingOverDistinct(ctx *plancontext.PlanningCont return true } -func checkForInvalidGroupingExpressions(expr sqlparser.Expr) { +func checkForInvalidGroupingExpressions(ctx *plancontext.PlanningContext, expr sqlparser.Expr) { _ = sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) { - if _, isAggregate := node.(sqlparser.AggrFunc); isAggregate { + if IsAggr(ctx, node) { panic(vterrors.VT03005(sqlparser.String(expr))) } _, isSubQ := node.(*sqlparser.Subquery) diff --git a/go/vt/vtgate/planbuilder/operators/subquery_planning.go b/go/vt/vtgate/planbuilder/operators/subquery_planning.go index af25136b16a..c7216e3bdae 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery_planning.go +++ b/go/vt/vtgate/planbuilder/operators/subquery_planning.go @@ -56,11 +56,11 @@ func isMergeable(ctx *plancontext.PlanningContext, query sqlparser.SelectStateme // if we have grouping, we have already checked that it's safe, and don't need to check for aggregations // but if we don't have groupings, we need to check if there are aggregations that will mess with us - if sqlparser.ContainsAggregation(node.SelectExprs) { + if ContainsAggr(ctx, node.SelectExprs) { return false } - if sqlparser.ContainsAggregation(node.Having) { + if ContainsAggr(ctx, node.Having) { return false } diff --git a/go/vt/vtgate/planbuilder/plan_test.go b/go/vt/vtgate/planbuilder/plan_test.go index 6a2d0fe8b3f..5ac3cf8a7cc 100644 --- a/go/vt/vtgate/planbuilder/plan_test.go +++ b/go/vt/vtgate/planbuilder/plan_test.go @@ -586,23 +586,18 @@ func TestOtherPlanningFromFile(t *testing.T) { func loadSchema(t testing.TB, filename string, setCollation bool) *vindexes.VSchema { formal, err := vindexes.LoadFormal(locateFile(filename)) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) vschema := vindexes.BuildVSchema(formal, sqlparser.NewTestParser()) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) for _, ks := range vschema.Keyspaces { - if ks.Error != nil { - t.Fatal(ks.Error) - } + require.NoError(t, ks.Error) // adding view in user keyspace if ks.Keyspace.Name == "user" { - if err = vschema.AddView(ks.Keyspace.Name, "user_details_view", "select user.id, user_extra.col from user join user_extra on user.id = user_extra.user_id", sqlparser.NewTestParser()); err != nil { - t.Fatal(err) - } + err = vschema.AddView(ks.Keyspace.Name, "user_details_view", "select user.id, user_extra.col from user join user_extra on user.id = user_extra.user_id", sqlparser.NewTestParser()) + require.NoError(t, err) + err = vschema.AddUDF(ks.Keyspace.Name, "udf_aggr") + require.NoError(t, err) } // setting a default value to all the text columns in the tables of this keyspace diff --git a/go/vt/vtgate/planbuilder/plancontext/vschema.go b/go/vt/vtgate/planbuilder/plancontext/vschema.go index 1bca80cbc94..8ac4c57bfd7 100644 --- a/go/vt/vtgate/planbuilder/plancontext/vschema.go +++ b/go/vt/vtgate/planbuilder/plancontext/vschema.go @@ -93,6 +93,9 @@ type VSchema interface { // StorePrepareData stores the prepared data in the session. StorePrepareData(name string, v *vtgatepb.PrepareData) + + // GetAggregateUDFs returns the list of aggregate UDFs. + GetAggregateUDFs() []string } // PlannerNameToVersion returns the numerical representation of the planner diff --git a/go/vt/vtgate/planbuilder/planner_test.go b/go/vt/vtgate/planbuilder/planner_test.go index 2601615522f..6ad1bb4116c 100644 --- a/go/vt/vtgate/planbuilder/planner_test.go +++ b/go/vt/vtgate/planbuilder/planner_test.go @@ -19,6 +19,8 @@ package planbuilder import ( "testing" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" + "github.com/stretchr/testify/require" "vitess.io/vitess/go/vt/sqlparser" @@ -67,9 +69,13 @@ func TestBindingSubquery(t *testing.T) { "foo": {Name: sqlparser.NewIdentifierCS("foo")}, }, }) + ctx := &plancontext.PlanningContext{ + ReservedVars: sqlparser.NewReservedVars("vt", make(sqlparser.BindVars)), + SemTable: semTable, + } require.NoError(t, err) if testcase.rewrite { - err = queryRewrite(semTable, sqlparser.NewReservedVars("vt", make(sqlparser.BindVars)), selStmt) + err = queryRewrite(ctx, selStmt) require.NoError(t, err) } expr := testcase.extractor(selStmt) diff --git a/go/vt/vtgate/planbuilder/rewrite.go b/go/vt/vtgate/planbuilder/rewrite.go index 915b5e753cd..30423229038 100644 --- a/go/vt/vtgate/planbuilder/rewrite.go +++ b/go/vt/vtgate/planbuilder/rewrite.go @@ -18,19 +18,18 @@ package planbuilder import ( "vitess.io/vitess/go/vt/sqlparser" - "vitess.io/vitess/go/vt/vtgate/semantics" + "vitess.io/vitess/go/vt/vtgate/planbuilder/operators" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" ) type rewriter struct { - semTable *semantics.SemTable - reservedVars *sqlparser.ReservedVars - err error + err error + ctx *plancontext.PlanningContext } -func queryRewrite(semTable *semantics.SemTable, reservedVars *sqlparser.ReservedVars, statement sqlparser.Statement) error { +func queryRewrite(ctx *plancontext.PlanningContext, statement sqlparser.Statement) error { r := rewriter{ - semTable: semTable, - reservedVars: reservedVars, + ctx: ctx, } sqlparser.Rewrite(statement, r.rewriteDown, nil) return nil @@ -39,15 +38,15 @@ func queryRewrite(semTable *semantics.SemTable, reservedVars *sqlparser.Reserved func (r *rewriter) rewriteDown(cursor *sqlparser.Cursor) bool { switch node := cursor.Node().(type) { case *sqlparser.Select: - rewriteHavingClause(node) + r.rewriteHavingClause(node) case *sqlparser.AliasedTableExpr: if _, isDerived := node.Expr.(*sqlparser.DerivedTable); isDerived { break } // find the tableSet and tableInfo that this table points to // tableInfo should contain the information for the original table that the routed table points to - tableSet := r.semTable.TableSetFor(node) - tableInfo, err := r.semTable.TableInfoFor(tableSet) + tableSet := r.ctx.SemTable.TableSetFor(node) + tableInfo, err := r.ctx.SemTable.TableInfoFor(tableSet) if err != nil { // Fail-safe code, should never happen break @@ -77,7 +76,7 @@ func (r *rewriter) rewriteDown(cursor *sqlparser.Cursor) bool { return true } -func rewriteHavingClause(node *sqlparser.Select) { +func (r *rewriter) rewriteHavingClause(node *sqlparser.Select) { if node.Having == nil { return } @@ -89,7 +88,7 @@ func rewriteHavingClause(node *sqlparser.Select) { exprs := sqlparser.SplitAndExpression(nil, node.Having.Expr) node.Having = nil for _, expr := range exprs { - if sqlparser.ContainsAggregation(expr) { + if operators.ContainsAggr(r.ctx, expr) { node.AddHaving(expr) } else { node.AddWhere(expr) diff --git a/go/vt/vtgate/planbuilder/rewrite_test.go b/go/vt/vtgate/planbuilder/rewrite_test.go index 87c8985fd63..7902b69e8f9 100644 --- a/go/vt/vtgate/planbuilder/rewrite_test.go +++ b/go/vt/vtgate/planbuilder/rewrite_test.go @@ -19,6 +19,8 @@ package planbuilder import ( "testing" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -74,7 +76,11 @@ func TestHavingRewrite(t *testing.T) { for _, tcase := range tcases { t.Run(tcase.input, func(t *testing.T) { semTable, reservedVars, sel := prepTest(t, tcase.input) - err := queryRewrite(semTable, reservedVars, sel) + ctx := &plancontext.PlanningContext{ + ReservedVars: reservedVars, + SemTable: semTable, + } + err := queryRewrite(ctx, sel) require.NoError(t, err) assert.Equal(t, tcase.output, sqlparser.String(sel)) }) diff --git a/go/vt/vtgate/planbuilder/select.go b/go/vt/vtgate/planbuilder/select.go index 13671e7efa0..83a6ba650f4 100644 --- a/go/vt/vtgate/planbuilder/select.go +++ b/go/vt/vtgate/planbuilder/select.go @@ -225,7 +225,7 @@ func newBuildSelectPlan( } func createSelectOperator(ctx *plancontext.PlanningContext, selStmt sqlparser.SelectStatement, reservedVars *sqlparser.ReservedVars) (operators.Operator, error) { - err := queryRewrite(ctx.SemTable, reservedVars, selStmt) + err := queryRewrite(ctx, selStmt) if err != nil { return nil, err } diff --git a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json index f7e556956e3..311f0f874cc 100644 --- a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json @@ -6958,5 +6958,80 @@ "comment": "baz in the HAVING clause can't be accessed because of the GROUP BY", "query": "select foo, count(bar) as x from user group by foo having baz > avg(baz) order by x", "plan": "Unknown column 'baz' in 'having clause'" + }, + { + "comment": "Aggregate UDFs can't be handled by vtgate", + "query": "select id from t1 group by id having udf_aggr(foo) > 1 and sum(foo) = 10", + "plan": "VT12001: unsupported: Aggregate UDF 'udf_aggr(foo)' must be pushed down to MySQL" + }, + { + "comment": "Valid to run since we can push down the aggregate function because of the grouping", + "query": "select id from user group by id having udf_aggr(foo) > 1", + "plan": { + "QueryType": "SELECT", + "Original": "select id from user group by id having udf_aggr(foo) > 1", + "Instructions": { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select id from `user` where 1 != 1 group by id", + "Query": "select id from `user` group by id having udf_aggr(foo) > 1", + "Table": "`user`" + }, + "TablesUsed": [ + "user.user" + ] + } + }, + { + "comment": "Valid to run since we can push down the aggregate function because it's unsharded", + "query": "select bar, udf_aggr(foo) from unsharded group by bar", + "plan": { + "QueryType": "SELECT", + "Original": "select bar, udf_aggr(foo) from unsharded group by bar", + "Instructions": { + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select bar, udf_aggr(foo) from unsharded where 1 != 1 group by bar", + "Query": "select bar, udf_aggr(foo) from unsharded group by bar", + "Table": "unsharded" + }, + "TablesUsed": [ + "main.unsharded" + ] + } + }, + { + "comment": "Valid to run since we can push down the aggregate function because the where clause using the sharding key", + "query": "select bar, udf_aggr(foo) from user where id = 17 group by bar", + "plan": { + "QueryType": "SELECT", + "Original": "select bar, udf_aggr(foo) from user where id = 17 group by bar", + "Instructions": { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select bar, udf_aggr(foo) from `user` where 1 != 1 group by bar", + "Query": "select bar, udf_aggr(foo) from `user` where id = 17 group by bar", + "Table": "`user`", + "Values": [ + "17" + ], + "Vindex": "user_index" + }, + "TablesUsed": [ + "user.user" + ] + } } ] diff --git a/go/vt/vtgate/planbuilder/testdata/dml_cases.json b/go/vt/vtgate/planbuilder/testdata/dml_cases.json index 224c41d43eb..abd467fbfd0 100644 --- a/go/vt/vtgate/planbuilder/testdata/dml_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/dml_cases.json @@ -5822,12 +5822,12 @@ { "comment": "update with multi table reference with multi target update on a derived table", "query": "update ignore (select foo, col, bar from user) u, music m set u.foo = 21, u.bar = 'abc' where u.col = m.col", - "plan": "VT03031: the target table (select foo, col, bar from `user`) as u of the UPDATE is not updatable" + "plan": "VT03032: the target table (select foo, col, bar from `user`) as u of the UPDATE is not updatable" }, { "comment": "update with derived table", "query": "update (select id from user) as u set id = 4", - "plan": "VT03031: the target table (select id from `user`) as u of the UPDATE is not updatable" + "plan": "VT03032: the target table (select id from `user`) as u of the UPDATE is not updatable" }, { "comment": "Delete with routed table on music", diff --git a/go/vt/vtgate/planbuilder/update.go b/go/vt/vtgate/planbuilder/update.go index 124eaf87310..313f33b6bf1 100644 --- a/go/vt/vtgate/planbuilder/update.go +++ b/go/vt/vtgate/planbuilder/update.go @@ -41,7 +41,7 @@ func gen4UpdateStmtPlanner( return nil, err } - err = queryRewrite(ctx.SemTable, reservedVars, updStmt) + err = queryRewrite(ctx, updStmt) if err != nil { return nil, err } diff --git a/go/vt/vtgate/planbuilder/vexplain.go b/go/vt/vtgate/planbuilder/vexplain.go index 7b200fb2e09..db62da75122 100644 --- a/go/vt/vtgate/planbuilder/vexplain.go +++ b/go/vt/vtgate/planbuilder/vexplain.go @@ -128,7 +128,7 @@ func explainPlan(explain *sqlparser.ExplainStmt, reservedVars *sqlparser.Reserve return nil, vterrors.VT03031() } - if err = queryRewrite(ctx.SemTable, reservedVars, explain.Statement); err != nil { + if err = queryRewrite(ctx, explain.Statement); err != nil { return nil, err } diff --git a/go/vt/vtgate/semantics/FakeSI.go b/go/vt/vtgate/semantics/FakeSI.go index 933f4cd40f8..1ca6718f1a8 100644 --- a/go/vt/vtgate/semantics/FakeSI.go +++ b/go/vt/vtgate/semantics/FakeSI.go @@ -36,6 +36,7 @@ type FakeSI struct { VindexTables map[string]vindexes.Vindex KsForeignKeyMode map[string]vschemapb.Keyspace_ForeignKeyMode KsError map[string]error + UDFs []string } // FindTableOrVindex implements the SchemaInformation interface @@ -80,3 +81,7 @@ func (s *FakeSI) KeyspaceError(keyspace string) error { } return nil } + +func (s *FakeSI) GetAggregateUDFs() []string { + return s.UDFs +} diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index bfd5f413f80..b872a1dde04 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -53,7 +53,7 @@ type analyzer struct { // newAnalyzer create the semantic analyzer func newAnalyzer(dbName string, si SchemaInformation, fullAnalysis bool) *analyzer { // TODO dependencies between these components are a little tangled. We should try to clean up - s := newScoper() + s := newScoper(si) a := &analyzer{ scoper: s, earlyTables: newEarlyTableCollector(si, dbName), @@ -78,6 +78,7 @@ func (a *analyzer) lateInit() { aliasMapCache: map[*sqlparser.Select]map[string]exprContainer{}, reAnalyze: a.reAnalyze, tables: a.tables, + aggrUDFs: a.si.GetAggregateUDFs(), } a.fk = &fkManager{ binder: a.binder, diff --git a/go/vt/vtgate/semantics/early_rewriter.go b/go/vt/vtgate/semantics/early_rewriter.go index 2e67509c06f..61abd9c3fa7 100644 --- a/go/vt/vtgate/semantics/early_rewriter.go +++ b/go/vt/vtgate/semantics/early_rewriter.go @@ -41,6 +41,7 @@ type earlyRewriter struct { // have happened, and we are introducing or changing the AST. We invoke it so all parts of the query have been // typed, scoped and bound correctly reAnalyze func(n sqlparser.SQLNode) error + aggrUDFs []string } func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error { @@ -508,7 +509,10 @@ func (r *earlyRewriter) rewriteAliasesInHaving(node sqlparser.Expr, sel *sqlpars } aliases := r.getAliasMap(sel) - aggrTrack := &aggrTracker{} + aggrTrack := &aggrTracker{ + insideAggr: false, + aggrUDFs: r.aggrUDFs, + } output := sqlparser.CopyOnRewrite(node, aggrTrack.down, func(cursor *sqlparser.CopyOnWriteCursor) { var col *sqlparser.ColName @@ -516,6 +520,11 @@ func (r *earlyRewriter) rewriteAliasesInHaving(node sqlparser.Expr, sel *sqlpars case sqlparser.AggrFunc: aggrTrack.popAggr() return + case *sqlparser.FuncExpr: + if node.Name.EqualsAnyString(r.aggrUDFs) { + aggrTrack.popAggr() + } + return case *sqlparser.ColName: col = node default: @@ -565,14 +574,19 @@ func (r *earlyRewriter) rewriteAliasesInHaving(node sqlparser.Expr, sel *sqlpars type aggrTracker struct { insideAggr bool + aggrUDFs []string } func (at *aggrTracker) down(node, _ sqlparser.SQLNode) bool { - switch node.(type) { + switch node := node.(type) { case *sqlparser.Subquery: return false case sqlparser.AggrFunc: at.insideAggr = true + case *sqlparser.FuncExpr: + if node.Name.EqualsAnyString(at.aggrUDFs) { + at.insideAggr = true + } } return true @@ -738,7 +752,10 @@ func (r *earlyRewriter) rewriteOrderByLiteral(node *sqlparser.Literal) (expr sql } if num < 1 || num > len(stmt.SelectExprs) { - return nil, false, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.BadFieldError, "Unknown column '%d' in '%s'", num, r.clause) + return nil, false, &ColumnNotFoundClauseError{ + Column: fmt.Sprintf("%d", num), + Clause: r.clause, + } } // We loop like this instead of directly accessing the offset, to make sure there are no unexpanded `*` before diff --git a/go/vt/vtgate/semantics/early_rewriter_test.go b/go/vt/vtgate/semantics/early_rewriter_test.go index c44d6f6307d..ae3c040697d 100644 --- a/go/vt/vtgate/semantics/early_rewriter_test.go +++ b/go/vt/vtgate/semantics/early_rewriter_test.go @@ -536,6 +536,15 @@ func TestHavingColumnName(t *testing.T) { expSQL: "select id, sum(t1.foo) as foo from t1 having sum(foo) > 1", expDeps: TS0, warning: "Column 'foo' in having clause is ambiguous", + }, { + sql: "select id, sum(t1.foo) as foo from t1 having custom_udf(foo) > 1", + expSQL: "select id, sum(t1.foo) as foo from t1 having custom_udf(foo) > 1", + expDeps: TS0, + warning: "Column 'foo' in having clause is ambiguous", + }, { + sql: "select id, custom_udf(t1.foo) as foo from t1 having foo > 1", + expSQL: "select id, custom_udf(t1.foo) as foo from t1 having custom_udf(t1.foo) > 1", + expDeps: TS0, }, { sql: "select id, sum(t1.foo) as XYZ from t1 having sum(XYZ) > 1", expErr: "Invalid use of group function", @@ -640,6 +649,7 @@ func getSchemaWithKnownColumns() *FakeSI { ColumnListAuthoritative: true, }, }, + UDFs: []string{"custom_udf"}, } return schemaInfo } diff --git a/go/vt/vtgate/semantics/info_schema.go b/go/vt/vtgate/semantics/info_schema.go index d7470e2fd0a..11e577f3fa7 100644 --- a/go/vt/vtgate/semantics/info_schema.go +++ b/go/vt/vtgate/semantics/info_schema.go @@ -1661,3 +1661,7 @@ func (i *infoSchemaWithColumns) GetForeignKeyChecksState() *bool { func (i *infoSchemaWithColumns) KeyspaceError(keyspace string) error { return i.inner.KeyspaceError(keyspace) } + +func (i *infoSchemaWithColumns) GetAggregateUDFs() []string { + return i.inner.GetAggregateUDFs() +} diff --git a/go/vt/vtgate/semantics/scoper.go b/go/vt/vtgate/semantics/scoper.go index 3a6fbe4c35c..5901df51af7 100644 --- a/go/vt/vtgate/semantics/scoper.go +++ b/go/vt/vtgate/semantics/scoper.go @@ -37,6 +37,7 @@ type ( // These scopes are only used for rewriting ORDER BY 1 and GROUP BY 1 specialExprScopes map[*sqlparser.Literal]*scope statementIDs map[sqlparser.Statement]TableSet + si SchemaInformation } scope struct { @@ -53,12 +54,13 @@ type ( } ) -func newScoper() *scoper { +func newScoper(si SchemaInformation) *scoper { return &scoper{ rScope: map[*sqlparser.Select]*scope{}, wScope: map[*sqlparser.Select]*scope{}, specialExprScopes: map[*sqlparser.Literal]*scope{}, statementIDs: map[sqlparser.Statement]TableSet{}, + si: si, } } @@ -84,6 +86,13 @@ func (s *scoper) down(cursor *sqlparser.Cursor) error { break } s.currentScope().inHavingAggr = true + case *sqlparser.FuncExpr: + if !s.currentScope().inHaving { + break + } + if node.Name.EqualsAnyString(s.si.GetAggregateUDFs()) { + s.currentScope().inHavingAggr = true + } case *sqlparser.Where: if node.Type == sqlparser.HavingClause { err := s.createSpecialScopePostProjection(cursor.Parent()) diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go index 1ea4bc2a889..6c6e495b33d 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_state.go @@ -161,6 +161,7 @@ type ( ForeignKeyMode(keyspace string) (vschemapb.Keyspace_ForeignKeyMode, error) GetForeignKeyChecksState() *bool KeyspaceError(keyspace string) error + GetAggregateUDFs() []string } shortCut = int diff --git a/go/vt/vtgate/vcursor_impl.go b/go/vt/vtgate/vcursor_impl.go index 15c6296f108..9372012f77d 100644 --- a/go/vt/vtgate/vcursor_impl.go +++ b/go/vt/vtgate/vcursor_impl.go @@ -1075,6 +1075,10 @@ func (vc *vcursorImpl) KeyspaceError(keyspace string) error { return ks.Error } +func (vc *vcursorImpl) GetAggregateUDFs() []string { + return vc.vschema.GetAggregateUDFs() +} + // ParseDestinationTarget parses destination target string and sets default keyspace if possible. func parseDestinationTarget(targetString string, vschema *vindexes.VSchema) (string, topodatapb.TabletType, key.Destination, error) { destKeyspace, destTabletType, dest, err := topoprotopb.ParseDestination(targetString, defaultTabletType) diff --git a/go/vt/vtgate/vindexes/vschema.go b/go/vt/vtgate/vindexes/vschema.go index 9e21505690c..8dc889fc848 100644 --- a/go/vt/vtgate/vindexes/vschema.go +++ b/go/vt/vtgate/vindexes/vschema.go @@ -245,6 +245,9 @@ type KeyspaceSchema struct { Views map[string]sqlparser.SelectStatement Error error MultiTenantSpec *vschemapb.MultiTenantSpec + + // These are the UDFs that exist in the schema and are aggregations + AggregateUDFs []string } type ksJSON struct { @@ -422,6 +425,18 @@ func (vschema *VSchema) AddView(ksname, viewName, query string, parser *sqlparse return nil } +// AddUDF adds a UDF to an existing keyspace in the VSchema. +// It's only used from tests. +func (vschema *VSchema) AddUDF(ksname, udfName string) error { + ks, ok := vschema.Keyspaces[ksname] + if !ok { + return fmt.Errorf("keyspace %s not found in vschema", ksname) + } + + ks.AggregateUDFs = append(ks.AggregateUDFs, udfName) + return nil +} + func buildGlobalTables(source *vschemapb.SrvVSchema, vschema *VSchema) { for ksname, ks := range source.Keyspaces { ksvschema := vschema.Keyspaces[ksname] @@ -1272,6 +1287,20 @@ func (vschema *VSchema) ResetCreated() { vschema.created = time.Time{} } +func (vschema *VSchema) GetAggregateUDFs() (udfs []string) { + seen := make(map[string]bool) + for _, ks := range vschema.Keyspaces { + for _, udf := range ks.AggregateUDFs { + if seen[udf] { + continue + } + seen[udf] = true + udfs = append(udfs, udf) + } + } + return +} + // ByCost provides the interface needed for ColumnVindexes to // be sorted by cost order. type ByCost []*ColumnVindex diff --git a/go/vt/wrangler/vexec_plan.go b/go/vt/wrangler/vexec_plan.go index 1878c25441c..76b2d0fe732 100644 --- a/go/vt/wrangler/vexec_plan.go +++ b/go/vt/wrangler/vexec_plan.go @@ -248,13 +248,7 @@ func (vx *vexec) buildUpdatePlan(ctx context.Context, planner vexecPlanner, upd if updatableColumnNames := plannerParams.updatableColumnNames; len(updatableColumnNames) > 0 { // if updatableColumnNames is non empty, then we must only accept changes to columns listed there for _, expr := range upd.Exprs { - isUpdatable := false - for _, updatableColName := range updatableColumnNames { - if expr.Name.Name.EqualString(updatableColName) { - isUpdatable = true - } - } - if !isUpdatable { + if !expr.Name.Name.EqualsAnyString(updatableColumnNames) { return nil, fmt.Errorf("%+v cannot be changed: %v", expr.Name.Name, sqlparser.String(expr)) } }