Skip to content

Commit

Permalink
planbuilder: split avg aggregations into sum/count
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <[email protected]>
  • Loading branch information
systay committed Nov 7, 2023
1 parent 0c849ca commit a0506a1
Show file tree
Hide file tree
Showing 6 changed files with 235 additions and 33 deletions.
6 changes: 5 additions & 1 deletion go/vt/vtgate/engine/opcode/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ const (
AggregateAnyValue
AggregateCountStar
AggregateGroupConcat
AggregateAvg
_NumOfOpCodes // This line must be last of the opcodes!
)

Expand All @@ -85,6 +86,7 @@ var (
AggregateCountStar: sqltypes.Int64,
AggregateSumDistinct: sqltypes.Decimal,
AggregateSum: sqltypes.Decimal,
AggregateAvg: sqltypes.Decimal,
AggregateGtid: sqltypes.VarChar,
}
)
Expand All @@ -96,6 +98,7 @@ var SupportedAggregates = map[string]AggregateOpcode{
"sum": AggregateSum,
"min": AggregateMin,
"max": AggregateMax,
"avg": AggregateAvg,
// These functions don't exist in mysql, but are used
// to display the plan.
"count_distinct": AggregateCountDistinct,
Expand All @@ -117,6 +120,7 @@ var AggregateName = map[AggregateOpcode]string{
AggregateCountStar: "count_star",
AggregateGroupConcat: "group_concat",
AggregateAnyValue: "any_value",
AggregateAvg: "avg",
}

func (code AggregateOpcode) String() string {
Expand Down Expand Up @@ -148,7 +152,7 @@ func (code AggregateOpcode) Type(typ querypb.Type) querypb.Type {
return sqltypes.Text
case AggregateMax, AggregateMin, AggregateAnyValue:
return typ
case AggregateSumDistinct, AggregateSum:
case AggregateSumDistinct, AggregateSum, AggregateAvg:
if typ == sqltypes.Unknown {
return sqltypes.Unknown
}
Expand Down
110 changes: 92 additions & 18 deletions go/vt/vtgate/planbuilder/operators/aggregation_pushing.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,39 @@ func tryPushAggregator(ctx *plancontext.PlanningContext, aggregator *Aggregator)
if aggregator.Pushed {
return aggregator, rewrite.SameTree, nil
}

// this rewrite is always valid, and we should do it whenever possible
if route, ok := aggregator.Source.(*Route); ok && (route.IsSingleShard() || overlappingUniqueVindex(ctx, aggregator.Grouping)) {
return rewrite.Swap(aggregator, route, "push down aggregation under route - remove original")
}

// other rewrites require us to have reached this phase before we can consider them
if !reachedPhase(ctx, delegateAggregation) {
return aggregator, rewrite.SameTree, nil
}

// if we have not yet been able to push this aggregation down,
// we need to turn AVG into SUM/COUNT to support this over a sharded keyspace
if needAvgBreaking(aggregator.Aggregations) {
output, err = splitAvgAggregations(ctx, aggregator)
if err != nil {
return nil, nil, err
}

applyResult = rewrite.NewTree("split avg aggregation", output)
return
}

switch src := aggregator.Source.(type) {
case *Route:
// if we have a single sharded route, we can push it down
output, applyResult, err = pushAggregationThroughRoute(ctx, aggregator, src)
case *ApplyJoin:
if reachedPhase(ctx, delegateAggregation) {
output, applyResult, err = pushAggregationThroughJoin(ctx, aggregator, src)
}
output, applyResult, err = pushAggregationThroughJoin(ctx, aggregator, src)
case *Filter:
if reachedPhase(ctx, delegateAggregation) {
output, applyResult, err = pushAggregationThroughFilter(ctx, aggregator, src)
}
output, applyResult, err = pushAggregationThroughFilter(ctx, aggregator, src)
case *SubQueryContainer:
if reachedPhase(ctx, delegateAggregation) {
output, applyResult, err = pushAggregationThroughSubquery(ctx, aggregator, src)
}
output, applyResult, err = pushAggregationThroughSubquery(ctx, aggregator, src)
default:
return aggregator, rewrite.SameTree, nil
}
Expand Down Expand Up @@ -135,15 +152,6 @@ func pushAggregationThroughRoute(
aggregator *Aggregator,
route *Route,
) (ops.Operator, *rewrite.ApplyResult, error) {
// If the route is single-shard, or we are grouping by sharding keys, we can just push down the aggregation
if route.IsSingleShard() || overlappingUniqueVindex(ctx, aggregator.Grouping) {
return rewrite.Swap(aggregator, route, "push down aggregation under route - remove original")
}

if !reachedPhase(ctx, delegateAggregation) {
return nil, nil, nil
}

// Create a new aggregator to be placed below the route.
aggrBelowRoute := aggregator.SplitAggregatorBelowRoute(route.Inputs())
aggrBelowRoute.Aggregations = nil
Expand Down Expand Up @@ -806,3 +814,69 @@ func initColReUse(size int) []int {
}

func extractExpr(expr *sqlparser.AliasedExpr) sqlparser.Expr { return expr.Expr }

func needAvgBreaking(aggrs []Aggr) bool {
for _, aggr := range aggrs {
if aggr.OpCode == opcode.AggregateAvg {
return true
}
}
return false
}

// splitAvgAggregations takes an aggregator that has AVG aggregations in it and splits
// these into sum/count expressions that can be spread out to shards
func splitAvgAggregations(ctx *plancontext.PlanningContext, aggr *Aggregator) (ops.Operator, error) {
proj := newAliasedProjection(aggr)

var columns []*sqlparser.AliasedExpr
var aggregations []Aggr

for offset, col := range aggr.Columns {
avg, ok := col.Expr.(*sqlparser.Avg)
if !ok {
proj.addColumnWithoutPushing(ctx, col, false /* addToGroupBy */)
continue
}

if avg.Distinct {
panic(vterrors.VT12001("AVG(distinct <>)"))
}

// We have an AVG that we need to split
sumExpr := &sqlparser.Sum{Arg: avg.Arg}
countExpr := &sqlparser.Count{Args: []sqlparser.Expr{avg.Arg}}
calcExpr := &sqlparser.BinaryExpr{
Operator: sqlparser.DivOp,
Left: sumExpr,
Right: countExpr,
}

outputColumn := aeWrap(col.Expr)
outputColumn.As = sqlparser.NewIdentifierCI(col.ColumnName())
_, err := proj.addUnexploredExpr(sqlparser.CloneRefOfAliasedExpr(col), calcExpr)
if err != nil {
return nil, err
}
col.Expr = sumExpr

for aggrOffset, aggregation := range aggr.Aggregations {
if offset == aggregation.ColOffset {
// We have found the AVG column. We'll change it to SUM, and then we add a COUNT as well
aggr.Aggregations[aggrOffset].OpCode = opcode.AggregateSum

countExprAlias := aeWrap(countExpr)
countAggr := NewAggr(opcode.AggregateCount, countExpr, countExprAlias, sqlparser.String(countExpr))
countAggr.ColOffset = len(aggr.Columns) + len(columns)
aggregations = append(aggregations, countAggr)
columns = append(columns, countExprAlias)
break // no need to search the remaining aggregations
}
}
}

aggr.Columns = append(aggr.Columns, columns...)
aggr.Aggregations = append(aggr.Aggregations, aggregations...)

return proj, nil
}
16 changes: 8 additions & 8 deletions go/vt/vtgate/planbuilder/operators/phases.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package operators
import (
"vitess.io/vitess/go/slice"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vtgate/planbuilder/operators/ops"
"vitess.io/vitess/go/vt/vtgate/planbuilder/operators/rewrite"
"vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext"
Expand Down Expand Up @@ -56,9 +57,9 @@ func (p Phase) String() string {
return "optimize Distinct operations"
case subquerySettling:
return "settle subqueries"
default:
panic(vterrors.VT13001("unhandled default case"))
}

return "unknown"
}

func (p Phase) shouldRun(s semantics.QuerySignature) bool {
Expand All @@ -73,8 +74,9 @@ func (p Phase) shouldRun(s semantics.QuerySignature) bool {
return s.Distinct
case subquerySettling:
return s.SubQueries
default:
return true
}
return true
}

func (p Phase) act(ctx *plancontext.PlanningContext, op ops.Operator) (ops.Operator, error) {
Expand All @@ -89,9 +91,9 @@ func (p Phase) act(ctx *plancontext.PlanningContext, op ops.Operator) (ops.Opera
return removePerformanceDistinctAboveRoute(ctx, op)
case subquerySettling:
return settleSubqueries(ctx, op), nil
default:
return op, nil
}

return op, nil
}

// getPhases returns the ordered phases that the planner will undergo.
Expand Down Expand Up @@ -128,18 +130,16 @@ func addOrderingForAllAggregations(ctx *plancontext.PlanningContext, root ops.Op
return in, rewrite.SameTree, nil
}

var res *rewrite.ApplyResult

requireOrdering, err := needsOrdering(ctx, aggrOp)
if err != nil {
return nil, nil, err
}

var res *rewrite.ApplyResult
if requireOrdering {
addOrderingFor(aggrOp)
res = rewrite.NewTree("added ordering before aggregation", in)
}

return in, res, nil
}

Expand Down
86 changes: 86 additions & 0 deletions go/vt/vtgate/planbuilder/testdata/aggr_cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -6038,5 +6038,91 @@
"user.user_extra"
]
}
},
{
"comment": "avg function on scatter query",
"query": "select avg(id) from user",
"plan": {
"QueryType": "SELECT",
"Original": "select avg(id) from user",
"Instructions": {
"OperatorType": "Projection",
"Expressions": [
"sum(id) / count(id) as avg(id)"
],
"Inputs": [
{
"OperatorType": "Aggregate",
"Variant": "Scalar",
"Aggregates": "sum(0) AS avg(id), sum_count(1) AS count(id)",
"Inputs": [
{
"OperatorType": "Route",
"Variant": "Scatter",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"FieldQuery": "select sum(id), count(id) from `user` where 1 != 1",
"Query": "select sum(id), count(id) from `user`",
"Table": "`user`"
}
]
}
]
},
"TablesUsed": [
"user.user"
]
}
},
{
"comment": "avg function on scatter query deep inside the output expression",
"query": "select avg(id)+count(foo)+bar from user group by bar",
"plan": {
"QueryType": "SELECT",
"Original": "select avg(id)+count(foo)+bar from user group by bar",
"Instructions": {
"OperatorType": "Projection",
"Expressions": [
"avg(id) + count(foo) + bar as avg(id) + count(foo) + bar"
],
"Inputs": [
{
"OperatorType": "Projection",
"Expressions": [
":0 as bar",
"sum(id) / count(id) as avg(id)",
":2 as count(foo)"
],
"Inputs": [
{
"OperatorType": "Aggregate",
"Variant": "Ordered",
"Aggregates": "sum(1) AS avg(id), sum_count(2) AS count(foo), sum_count(3) AS count(id)",
"GroupBy": "(0|4)",
"Inputs": [
{
"OperatorType": "Route",
"Variant": "Scatter",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"FieldQuery": "select bar, sum(id), count(foo), count(id), weight_string(bar) from `user` where 1 != 1 group by bar, weight_string(bar)",
"OrderBy": "(0|4) ASC",
"Query": "select bar, sum(id), count(foo), count(id), weight_string(bar) from `user` group by bar, weight_string(bar) order by bar asc",
"Table": "`user`"
}
]
}
]
}
]
},
"TablesUsed": [
"user.user"
]
}
}
]
45 changes: 44 additions & 1 deletion go/vt/vtgate/planbuilder/testdata/tpch_cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,50 @@
{
"comment": "TPC-H query 1",
"query": "select l_returnflag, l_linestatus, sum(l_quantity) as sum_qty, sum(l_extendedprice) as sum_base_price, sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, avg(l_quantity) as avg_qty, avg(l_extendedprice) as avg_price, avg(l_discount) as avg_disc, count(*) as count_order from lineitem where l_shipdate <= '1998-12-01' - interval '108' day group by l_returnflag, l_linestatus order by l_returnflag, l_linestatus",
"plan": "VT12001: unsupported: in scatter query: aggregation function 'avg(l_quantity) as avg_qty'"
"plan": {
"QueryType": "SELECT",
"Original": "select l_returnflag, l_linestatus, sum(l_quantity) as sum_qty, sum(l_extendedprice) as sum_base_price, sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, avg(l_quantity) as avg_qty, avg(l_extendedprice) as avg_price, avg(l_discount) as avg_disc, count(*) as count_order from lineitem where l_shipdate <= '1998-12-01' - interval '108' day group by l_returnflag, l_linestatus order by l_returnflag, l_linestatus",
"Instructions": {
"OperatorType": "Projection",
"Expressions": [
":0 as l_returnflag",
":1 as l_linestatus",
":2 as sum_qty",
":3 as sum_base_price",
":4 as sum_disc_price",
":5 as sum_charge",
"sum(l_quantity) / count(l_quantity) as avg_qty",
"sum(l_extendedprice) / count(l_extendedprice) as avg_price",
"sum(l_discount) / count(l_discount) as avg_disc",
":9 as count_order"
],
"Inputs": [
{
"OperatorType": "Aggregate",
"Variant": "Ordered",
"Aggregates": "sum(2) AS sum_qty, sum(3) AS sum_base_price, sum(4) AS sum_disc_price, sum(5) AS sum_charge, sum(6) AS avg_qty, sum(7) AS avg_price, sum(8) AS avg_disc, sum_count_star(9) AS count_order, sum_count(10) AS count(l_quantity), sum_count(11) AS count(l_extendedprice), sum_count(12) AS count(l_discount)",
"GroupBy": "(0|13), (1|14)",
"Inputs": [
{
"OperatorType": "Route",
"Variant": "Scatter",
"Keyspace": {
"Name": "main",
"Sharded": true
},
"FieldQuery": "select l_returnflag, l_linestatus, sum(l_quantity) as sum_qty, sum(l_extendedprice) as sum_base_price, sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, sum(l_quantity) as avg_qty, sum(l_extendedprice) as avg_price, sum(l_discount) as avg_disc, count(*) as count_order, count(l_quantity), count(l_extendedprice), count(l_discount), weight_string(l_returnflag), weight_string(l_linestatus) from lineitem where 1 != 1 group by l_returnflag, l_linestatus, weight_string(l_returnflag), weight_string(l_linestatus)",
"OrderBy": "(0|13) ASC, (1|14) ASC",
"Query": "select l_returnflag, l_linestatus, sum(l_quantity) as sum_qty, sum(l_extendedprice) as sum_base_price, sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, sum(l_quantity) as avg_qty, sum(l_extendedprice) as avg_price, sum(l_discount) as avg_disc, count(*) as count_order, count(l_quantity), count(l_extendedprice), count(l_discount), weight_string(l_returnflag), weight_string(l_linestatus) from lineitem where l_shipdate <= '1998-12-01' - interval '108' day group by l_returnflag, l_linestatus, weight_string(l_returnflag), weight_string(l_linestatus) order by l_returnflag asc, l_linestatus asc",
"Table": "lineitem"
}
]
}
]
},
"TablesUsed": [
"main.lineitem"
]
}
},
{
"comment": "TPC-H query 2",
Expand Down
5 changes: 0 additions & 5 deletions go/vt/vtgate/planbuilder/testdata/unsupported_cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -224,11 +224,6 @@
"query": "create view main.view_a as select * from user.user_extra",
"plan": "VT12001: unsupported: Select query does not belong to the same keyspace as the view statement"
},
{
"comment": "avg function on scatter query",
"query": "select avg(id) from user",
"plan": "VT12001: unsupported: in scatter query: aggregation function 'avg(id)'"
},
{
"comment": "outer and inner subquery route reference the same \"uu.id\" name\n# but they refer to different things. The first reference is to the outermost query,\n# and the second reference is to the innermost 'from' subquery.\n# This query will never work as the inner derived table is only selecting one of the column",
"query": "select id2 from user uu where id in (select id from user where id = uu.id and user.col in (select col from (select id from user_extra where user_id = 5) uu where uu.user_id = uu.id))",
Expand Down

0 comments on commit a0506a1

Please sign in to comment.