Skip to content

Commit

Permalink
feat: allow the planner to understand aggregate UDFs
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <[email protected]>
  • Loading branch information
systay committed Apr 15, 2024
1 parent e6a4ee7 commit 773f6a0
Show file tree
Hide file tree
Showing 27 changed files with 149 additions and 84 deletions.
1 change: 1 addition & 0 deletions go/mysql/sqlerror/sql_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ var stateToMysqlCode = map[vterrors.State]mysqlCode{
vterrors.KillDeniedError: {num: ERKillDenied, state: SSUnknownSQLState},
vterrors.BadNullError: {num: ERBadNullError, state: SSConstraintViolation},
vterrors.InvalidGroupFuncUse: {num: ERInvalidGroupFuncUse, state: SSUnknownSQLState},
vterrors.AggregateMustPushDown: {num: ERNotSupportedYet, state: SSUnknownSQLState},
}

func getStateToMySQLState(state vterrors.State) mysqlCode {
Expand Down
10 changes: 10 additions & 0 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,16 @@ func (node IdentifierCI) EqualString(str string) bool {
return node.Lowered() == strings.ToLower(str)
}

// EqualsAnyString reports whether the identifier matches at least one of the
// candidate strings. Each candidate is compared with EqualString, so the match
// follows that method's rules; an empty or nil list never matches.
func (node IdentifierCI) EqualsAnyString(str []string) bool {
	for i := 0; i < len(str); i++ {
		if !node.EqualString(str[i]) {
			continue
		}
		// First hit wins; no need to look at the remaining candidates.
		return true
	}
	return false
}

// MarshalJSON marshals into JSON.
func (node IdentifierCI) MarshalJSON() ([]byte, error) {
return json.Marshal(node.val)
Expand Down
3 changes: 2 additions & 1 deletion go/vt/vterrors/code.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ var (
VT03029 = errorWithState("VT03029", vtrpcpb.Code_INVALID_ARGUMENT, WrongValueCountOnRow, "column count does not match value count with the row for vindex '%s'", "The number of columns you want to insert do not match the number of columns of your SELECT query.")
VT03030 = errorWithState("VT03030", vtrpcpb.Code_INVALID_ARGUMENT, WrongValueCountOnRow, "lookup column count does not match value count with the row (columns, count): (%v, %d)", "The number of columns you want to insert do not match the number of columns of your SELECT query.")
VT03031 = errorWithoutState("VT03031", vtrpcpb.Code_INVALID_ARGUMENT, "EXPLAIN is only supported for single keyspace", "EXPLAIN has to be sent down as a single query to the underlying MySQL, and this is not possible if it uses tables from multiple keyspaces")
VT03032 = errorWithState("VT03031", vtrpcpb.Code_INVALID_ARGUMENT, NonUpdateableTable, "the target table %s of the UPDATE is not updatable", "You cannot update a table that is not a real MySQL table.")
VT03032 = errorWithState("VT03032", vtrpcpb.Code_INVALID_ARGUMENT, NonUpdateableTable, "the target table %s of the UPDATE is not updatable", "You cannot update a table that is not a real MySQL table.")
VT03033 = errorWithState("VT03033", vtrpcpb.Code_INVALID_ARGUMENT, AggregateMustPushDown, "aggregate user-defined function %s must be pushed down to mysql", "The aggregate user-defined function must be pushed down to mysql and can't be evaluated on the vtgate. The query contains aggregation that can't be fully pushed down to MySQL.")

VT05001 = errorWithState("VT05001", vtrpcpb.Code_NOT_FOUND, DbDropExists, "cannot drop database '%s'; database does not exists", "The given database does not exist; Vitess cannot drop it.")
VT05002 = errorWithState("VT05002", vtrpcpb.Code_NOT_FOUND, BadDb, "cannot alter database '%s'; unknown database", "The given database does not exist; Vitess cannot alter it.")
Expand Down
1 change: 1 addition & 0 deletions go/vt/vterrors/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ const (
WrongArguments
BadNullError
InvalidGroupFuncUse
AggregateMustPushDown

// failed precondition
NoDB
Expand Down
14 changes: 1 addition & 13 deletions go/vt/vtgate/engine/opcode/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,22 +77,10 @@ const (
AggregateCountStar
AggregateGroupConcat
AggregateAvg
AggregateUDF // This is an opcode used to represent UDFs
_NumOfOpCodes // This line must be last of the opcodes!
)

var (
// OpcodeType keeps track of the known output types for different aggregate functions
OpcodeType = map[AggregateOpcode]querypb.Type{
AggregateCountDistinct: sqltypes.Int64,
AggregateCount: sqltypes.Int64,
AggregateCountStar: sqltypes.Int64,
AggregateSumDistinct: sqltypes.Decimal,
AggregateSum: sqltypes.Decimal,
AggregateAvg: sqltypes.Decimal,
AggregateGtid: sqltypes.VarChar,
}
)

// SupportedAggregates maps the list of supported aggregate
// functions to their opcodes.
var SupportedAggregates = map[string]AggregateOpcode{
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func gen4DeleteStmtPlanner(
return nil, err
}

err = queryRewrite(ctx.SemTable, reservedVars, deleteStmt)
err = queryRewrite(ctx, deleteStmt)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func gen4InsertStmtPlanner(version querypb.ExecuteOptions_PlannerVersion, insStm
return nil, err
}

err = queryRewrite(ctx.SemTable, reservedVars, insStmt)
err = queryRewrite(ctx, insStmt)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/SQL_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (qb *queryBuilder) addPredicate(expr sqlparser.Expr) {

switch stmt := qb.stmt.(type) {
case *sqlparser.Select:
if containsAggr(expr) {
if ContainsAggr(qb.ctx, expr) {
addPred = stmt.AddHaving
} else {
addPred = stmt.AddWhere
Expand Down
8 changes: 7 additions & 1 deletion go/vt/vtgate/planbuilder/operators/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func (a *Aggregator) AddPredicate(_ *plancontext.PlanningContext, expr sqlparser
return newFilter(a, expr)
}

func (a *Aggregator) addColumnWithoutPushing(_ *plancontext.PlanningContext, expr *sqlparser.AliasedExpr, addToGroupBy bool) int {
func (a *Aggregator) addColumnWithoutPushing(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr, addToGroupBy bool) int {
offset := len(a.Columns)
a.Columns = append(a.Columns, expr)

Expand All @@ -96,6 +96,12 @@ func (a *Aggregator) addColumnWithoutPushing(_ *plancontext.PlanningContext, exp
switch e := expr.Expr.(type) {
case sqlparser.AggrFunc:
aggr = createAggrFromAggrFunc(e, expr)
case *sqlparser.FuncExpr:
if IsAggr(ctx, e) {
aggr = NewAggr(opcode.AggregateUDF, nil, expr, expr.As.String())
} else {
aggr = NewAggr(opcode.AggregateAnyValue, nil, expr, expr.As.String())
}
default:
aggr = NewAggr(opcode.AggregateAnyValue, nil, expr, expr.As.String())
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func breakExpressionInLHSandRHSForApplyJoin(
) (col applyJoinColumn) {
rewrittenExpr := sqlparser.CopyOnRewrite(expr, nil, func(cursor *sqlparser.CopyOnWriteCursor) {
nodeExpr, ok := cursor.Node().(sqlparser.Expr)
if !ok || !mustFetchFromInput(nodeExpr) {
if !ok || !mustFetchFromInput(ctx, nodeExpr) {
return
}
deps := ctx.SemTable.RecursiveDeps(nodeExpr)
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/operators/hash_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ func (hj *HashJoin) addColumn(ctx *plancontext.PlanningContext, in sqlparser.Exp
}
inOffset := op.FindCol(ctx, expr, false)
if inOffset == -1 {
if !mustFetchFromInput(expr) {
if !mustFetchFromInput(ctx, expr) {
return -1
}

Expand Down Expand Up @@ -398,7 +398,7 @@ func (hj *HashJoin) addSingleSidedColumn(
}
inOffset := op.FindCol(ctx, expr, false)
if inOffset == -1 {
if !mustFetchFromInput(expr) {
if !mustFetchFromInput(ctx, expr) {
return -1
}

Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/horizon.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (h *Horizon) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.
}

newExpr := semantics.RewriteDerivedTableExpression(expr, tableInfo)
if sqlparser.ContainsAggregation(newExpr) {
if ContainsAggr(ctx, newExpr) {
return newFilter(h, expr)
}
h.Source = h.Source.AddPredicate(ctx, newExpr)
Expand Down
38 changes: 30 additions & 8 deletions go/vt/vtgate/planbuilder/operators/offset_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package operators
import (
"fmt"

"vitess.io/vitess/go/vt/vtgate/engine/opcode"

"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext"
Expand Down Expand Up @@ -56,10 +58,12 @@ func planOffsets(ctx *plancontext.PlanningContext, root Operator) Operator {
}

// mustFetchFromInput returns true for expressions that have to be fetched from the input and cannot be evaluated
func mustFetchFromInput(e sqlparser.SQLNode) bool {
switch e.(type) {
func mustFetchFromInput(ctx *plancontext.PlanningContext, e sqlparser.SQLNode) bool {
switch fun := e.(type) {
case *sqlparser.ColName, sqlparser.AggrFunc:
return true
case *sqlparser.FuncExpr:
return fun.Name.EqualsAnyString(ctx.VSchema.GetAggregateUDFs())
default:
return false
}
Expand Down Expand Up @@ -93,10 +97,10 @@ func useOffsets(ctx *plancontext.PlanningContext, expr sqlparser.Expr, op Operat
return rewritten.(sqlparser.Expr)
}

// addColumnsToInput adds columns needed by an operator to its input.
// This happens only when the filter expression can be retrieved as an offset from the underlying mysql.
func addColumnsToInput(ctx *plancontext.PlanningContext, root Operator) Operator {
visitor := func(in Operator, _ semantics.TableSet, isRoot bool) (Operator, *ApplyResult) {
// addColumnsToInput adds columns needed by an operator to its input.
// This happens only when the filter expression can be retrieved as an offset from the underlying mysql.
addColumnsNeededByFilter := func(in Operator, _ semantics.TableSet, _ bool) (Operator, *ApplyResult) {
filter, ok := in.(*Filter)
if !ok {
return in, NoRewrite
Expand Down Expand Up @@ -125,12 +129,30 @@ func addColumnsToInput(ctx *plancontext.PlanningContext, root Operator) Operator

return in, NoRewrite
}
// failUDFAggregation rejects plans in which an Aggregator still carries an
// aggregate UDF. Operators that are not Aggregators, and Aggregators whose
// aggregations contain no AggregateUDF opcode, are returned unchanged with
// NoRewrite. Otherwise it panics with VT03033, passing the SQL text of the
// offending expression.
failUDFAggregation := func(in Operator, _ semantics.TableSet, _ bool) (Operator, *ApplyResult) {
aggrOp, ok := in.(*Aggregator)
if !ok {
return in, NoRewrite
}
for _, aggr := range aggrOp.Aggregations {
if aggr.OpCode == opcode.AggregateUDF {
// we don't support UDFs in aggregation if it's still above a route
// (the error text says such a UDF can't be evaluated on the vtgate)
panic(vterrors.VT03033(sqlparser.String(aggr.Original.Expr)))
}
}
return in, NoRewrite
}

// visitor applies both checks to every visited operator. The operator and
// rewrite result it returns come from addColumnsNeededByFilter alone;
// failUDFAggregation is called on the original input purely for its panic,
// and its return values are intentionally discarded.
visitor := func(in Operator, _ semantics.TableSet, isRoot bool) (Operator, *ApplyResult) {
out, res := addColumnsNeededByFilter(in, semantics.EmptyTableSet(), isRoot)
failUDFAggregation(in, semantics.EmptyTableSet(), isRoot)
return out, res
}

return TopDown(root, TableID, visitor, stopAtRoute)
}

// addColumnsToInput adds columns needed by an operator to its input.
// This happens only when the filter expression can be retrieved as an offset from the underlying mysql.
// pullDistinctFromUNION will pull out the distinct from a union operator
func pullDistinctFromUNION(_ *plancontext.PlanningContext, root Operator) Operator {
visitor := func(in Operator, _ semantics.TableSet, isRoot bool) (Operator, *ApplyResult) {
union, ok := in.(*Union)
Expand Down Expand Up @@ -170,7 +192,7 @@ func getOffsetRewritingVisitor(
return false
}

if mustFetchFromInput(e) {
if mustFetchFromInput(ctx, e) {
notFound(e)
return false
}
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/operators/query_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ func tryPushOrdering(ctx *plancontext.PlanningContext, in *Ordering) (Operator,
case *Projection:
// we can move ordering under a projection if it's not introducing a column we're sorting by
for _, by := range in.Order {
if !mustFetchFromInput(by.SimplifiedExpr) {
if !mustFetchFromInput(ctx, by.SimplifiedExpr) {
return in, NoRewrite
}
}
Expand Down Expand Up @@ -459,7 +459,7 @@ func pushFilterUnderProjection(ctx *plancontext.PlanningContext, filter *Filter,
for _, p := range filter.Predicates {
cantPush := false
_ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
if !mustFetchFromInput(node) {
if !mustFetchFromInput(ctx, node) {
return true, nil
}

Expand Down
47 changes: 35 additions & 12 deletions go/vt/vtgate/planbuilder/operators/queryprojection.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,11 @@ func createQPFromSelect(ctx *plancontext.PlanningContext, sel *sqlparser.Select)
Distinct: sel.Distinct,
}

qp.addSelectExpressions(sel)
qp.addSelectExpressions(ctx, sel)
qp.addGroupBy(ctx, sel.GroupBy)
qp.addOrderBy(ctx, sel.OrderBy)
if !qp.HasAggr && sel.Having != nil {
qp.HasAggr = containsAggr(sel.Having.Expr)
qp.HasAggr = ContainsAggr(ctx, sel.Having.Expr)
}
qp.calculateDistinct(ctx)

Expand Down Expand Up @@ -239,14 +239,14 @@ func (qp *QueryProjection) AggrRewriter(ctx *plancontext.PlanningContext) *AggrR
}
}

func (qp *QueryProjection) addSelectExpressions(sel *sqlparser.Select) {
func (qp *QueryProjection) addSelectExpressions(ctx *plancontext.PlanningContext, sel *sqlparser.Select) {
for _, selExp := range sel.SelectExprs {
switch selExp := selExp.(type) {
case *sqlparser.AliasedExpr:
col := SelectExpr{
Col: selExp,
}
if containsAggr(selExp.Expr) {
if ContainsAggr(ctx, selExp.Expr) {
col.Aggr = true
qp.HasAggr = true
}
Expand All @@ -263,7 +263,18 @@ func (qp *QueryProjection) addSelectExpressions(sel *sqlparser.Select) {
}
}

func containsAggr(e sqlparser.SQLNode) (hasAggr bool) {
// IsAggr reports whether the node itself is an aggregation. That is the case
// when it is a parsed aggregate function (sqlparser.AggrFunc), or a plain
// function call whose name matches one of the aggregate UDFs listed by the
// VSchema. Only the node passed in is inspected; its children are not walked.
func IsAggr(ctx *plancontext.PlanningContext, e sqlparser.SQLNode) bool {
	if _, builtin := e.(sqlparser.AggrFunc); builtin {
		return true
	}
	if fn, isFunc := e.(*sqlparser.FuncExpr); isFunc {
		// UDFs are parsed as ordinary function calls, so the only way to
		// recognise an aggregate one is by name.
		return fn.Name.EqualsAnyString(ctx.VSchema.GetAggregateUDFs())
	}
	return false
}

func ContainsAggr(ctx *plancontext.PlanningContext, e sqlparser.SQLNode) (hasAggr bool) {
_ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
switch node.(type) {
case *sqlparser.Offset:
Expand All @@ -275,6 +286,11 @@ func containsAggr(e sqlparser.SQLNode) (hasAggr bool) {
return false, io.EOF
case *sqlparser.Subquery:
return false, nil
case *sqlparser.FuncExpr:
if IsAggr(ctx, node) {
hasAggr = true
return false, io.EOF
}
}

return true, nil
Expand All @@ -287,7 +303,7 @@ func createQPFromUnion(ctx *plancontext.PlanningContext, union *sqlparser.Union)
qp := &QueryProjection{}

sel := sqlparser.GetFirstSelect(union)
qp.addSelectExpressions(sel)
qp.addSelectExpressions(ctx, sel)
qp.addOrderBy(ctx, union.OrderBy)

return qp
Expand Down Expand Up @@ -325,7 +341,7 @@ func (qp *QueryProjection) addOrderBy(ctx *plancontext.PlanningContext, orderBy
Inner: ctx.SemTable.Clone(order).(*sqlparser.Order),
SimplifiedExpr: order.Expr,
})
canPushSorting = canPushSorting && !containsAggr(order.Expr)
canPushSorting = canPushSorting && !ContainsAggr(ctx, order.Expr)
}
}

Expand Down Expand Up @@ -480,7 +496,7 @@ func (qp *QueryProjection) AggregationExpressions(ctx *plancontext.PlanningConte

idxCopy := idx

if !containsAggr(expr.Col) {
if !ContainsAggr(ctx, expr.Col) {
getExpr, err := expr.GetExpr()
if err != nil {
panic(err)
Expand All @@ -492,8 +508,7 @@ func (qp *QueryProjection) AggregationExpressions(ctx *plancontext.PlanningConte
}
continue
}
_, isAggregate := aliasedExpr.Expr.(sqlparser.AggrFunc)
if !isAggregate && !allowComplexExpression {
if !IsAggr(ctx, aliasedExpr.Expr) && !allowComplexExpression {
panic(vterrors.VT12001("in scatter query: complex aggregate expression"))
}

Expand Down Expand Up @@ -524,7 +539,15 @@ func (qp *QueryProjection) extractAggr(
addAggr(aggrFunc)
return false
}
if containsAggr(node) {
if IsAggr(ctx, node) {
// If we are here, we have a function that is an aggregation but not parsed into an AggrFunc.
// This is the case for UDFs - we have to be careful with these because we can't evaluate them in VTGate.
aggr := NewAggr(opcode.AggregateUDF, nil, aeWrap(ex), "")
aggr.Index = &idx
addAggr(aggr)
return false
}
if ContainsAggr(ctx, node) {
makeComplex()
return true
}
Expand Down Expand Up @@ -553,7 +576,7 @@ orderBy:
}
qp.SelectExprs = append(qp.SelectExprs, SelectExpr{
Col: &sqlparser.AliasedExpr{Expr: orderExpr},
Aggr: containsAggr(orderExpr),
Aggr: ContainsAggr(ctx, orderExpr),
})
qp.AddedColumn++
}
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/operators/subquery_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ func isMergeable(ctx *plancontext.PlanningContext, query sqlparser.SelectStateme

// if we have grouping, we have already checked that it's safe, and don't need to check for aggregations
// but if we don't have groupings, we need to check if there are aggregations that will mess with us
if sqlparser.ContainsAggregation(node.SelectExprs) {
if ContainsAggr(ctx, node.SelectExprs) {
return false
}

if sqlparser.ContainsAggregation(node.Having) {
if ContainsAggr(ctx, node.Having) {
return false
}

Expand Down
Loading

0 comments on commit 773f6a0

Please sign in to comment.