Skip to content

Commit

Permalink
Add support for AVG on sharded queries (#14419)
Browse files Browse the repository at this point in the history
  • Loading branch information
systay authored Nov 8, 2023
1 parent 1a9119d commit 225fc70
Show file tree
Hide file tree
Showing 8 changed files with 482 additions and 107 deletions.
50 changes: 35 additions & 15 deletions go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ func TestAggregateTypes(t *testing.T) {
mcmp.AssertMatches("select val1 as a, count(*) from aggr_test group by a order by a", `[[VARCHAR("a") INT64(2)] [VARCHAR("b") INT64(1)] [VARCHAR("c") INT64(2)] [VARCHAR("d") INT64(1)] [VARCHAR("e") INT64(2)]]`)
mcmp.AssertMatches("select val1 as a, count(*) from aggr_test group by a order by 2, a", `[[VARCHAR("b") INT64(1)] [VARCHAR("d") INT64(1)] [VARCHAR("a") INT64(2)] [VARCHAR("c") INT64(2)] [VARCHAR("e") INT64(2)]]`)
mcmp.AssertMatches("select sum(val1) from aggr_test", `[[FLOAT64(0)]]`)
mcmp.AssertMatches("select avg(val1) from aggr_test", `[[FLOAT64(0)]]`)
}

func TestGroupBy(t *testing.T) {
Expand Down Expand Up @@ -172,6 +173,13 @@ func TestAggrOnJoin(t *testing.T) {

mcmp.AssertMatches("select a.val1 from aggr_test a join t3 t on a.val2 = t.id7 group by a.val1 having count(*) = 4",
`[[VARCHAR("a")]]`)

mcmp.AssertMatches(`select avg(a1.val2), avg(a2.val2) from aggr_test a1 join aggr_test a2 on a1.val2 = a2.id join t3 t on a2.val2 = t.id7`,
"[[DECIMAL(1.5000) DECIMAL(1.0000)]]")

mcmp.AssertMatches(`select a1.val1, avg(a1.val2) from aggr_test a1 join aggr_test a2 on a1.val2 = a2.id join t3 t on a2.val2 = t.id7 group by a1.val1`,
`[[VARCHAR("a") DECIMAL(1.0000)] [VARCHAR("b") DECIMAL(1.0000)] [VARCHAR("c") DECIMAL(3.0000)]]`)

}

func TestNotEqualFilterOnScatter(t *testing.T) {
Expand Down Expand Up @@ -314,22 +322,26 @@ func TestAggOnTopOfLimit(t *testing.T) {
for _, workload := range []string{"oltp", "olap"} {
t.Run(workload, func(t *testing.T) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = '%s'", workload))
mcmp.AssertMatches(" select count(*) from (select id, val1 from aggr_test where val2 < 4 limit 2) as x", "[[INT64(2)]]")
mcmp.AssertMatches(" select count(val1) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2)]]")
mcmp.AssertMatches(" select count(*) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(2)]]")
mcmp.AssertMatches(" select count(val1) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(1)]]")
mcmp.AssertMatches(" select count(val2) from (select id, val2 from aggr_test where val2 is null limit 2) as x", "[[INT64(0)]]")
mcmp.AssertMatches(" select val1, count(*) from (select id, val1 from aggr_test where val2 < 4 order by val1 limit 2) as x group by val1", `[[NULL INT64(1)] [VARCHAR("a") INT64(1)]]`)
mcmp.AssertMatchesNoOrder(" select val1, count(val2) from (select val1, val2 from aggr_test limit 8) as x group by val1", `[[NULL INT64(1)] [VARCHAR("a") INT64(2)] [VARCHAR("b") INT64(1)] [VARCHAR("c") INT64(2)]]`)
mcmp.AssertMatches("select count(*) from (select id, val1 from aggr_test where val2 < 4 limit 2) as x", "[[INT64(2)]]")
mcmp.AssertMatches("select count(val1) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2)]]")
mcmp.AssertMatches("select count(*) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(2)]]")
mcmp.AssertMatches("select count(val1) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(1)]]")
mcmp.AssertMatches("select count(val2) from (select id, val2 from aggr_test where val2 is null limit 2) as x", "[[INT64(0)]]")
mcmp.AssertMatches("select avg(val2) from (select id, val2 from aggr_test where val2 is null limit 2) as x", "[[NULL]]")
mcmp.AssertMatches("select val1, count(*) from (select id, val1 from aggr_test where val2 < 4 order by val1 limit 2) as x group by val1", `[[NULL INT64(1)] [VARCHAR("a") INT64(1)]]`)
mcmp.AssertMatchesNoOrder("select val1, count(val2) from (select val1, val2 from aggr_test limit 8) as x group by val1", `[[NULL INT64(1)] [VARCHAR("a") INT64(2)] [VARCHAR("b") INT64(1)] [VARCHAR("c") INT64(2)]]`)
mcmp.AssertMatchesNoOrder("select val1, avg(val2) from (select val1, val2 from aggr_test limit 8) as x group by val1", `[[NULL DECIMAL(2.0000)] [VARCHAR("a") DECIMAL(3.5000)] [VARCHAR("b") DECIMAL(1.0000)] [VARCHAR("c") DECIMAL(3.5000)]]`)

// mysql returns FLOAT64(0), vitess returns DECIMAL(0)
mcmp.AssertMatchesNoCompare(" select count(*), sum(val1) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2) FLOAT64(0)]]", "[[INT64(2) FLOAT64(0)]]")
mcmp.AssertMatches(" select count(val1), sum(id) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2) DECIMAL(7)]]")
mcmp.AssertMatches(" select count(*), sum(id) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(2) DECIMAL(14)]]")
mcmp.AssertMatches(" select count(val1), sum(id) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(1) DECIMAL(14)]]")
mcmp.AssertMatches(" select count(val2), sum(val2) from (select id, val2 from aggr_test where val2 is null limit 2) as x", "[[INT64(0) NULL]]")
mcmp.AssertMatches(" select val1, count(*), sum(id) from (select id, val1 from aggr_test where val2 < 4 order by val1 limit 2) as x group by val1", `[[NULL INT64(1) DECIMAL(7)] [VARCHAR("a") INT64(1) DECIMAL(2)]]`)
mcmp.AssertMatchesNoOrder(" select val1, count(val2), sum(val2) from (select val1, val2 from aggr_test limit 8) as x group by val1", `[[NULL INT64(1) DECIMAL(2)] [VARCHAR("a") INT64(2) DECIMAL(7)] [VARCHAR("b") INT64(1) DECIMAL(1)] [VARCHAR("c") INT64(2) DECIMAL(7)]]`)
mcmp.AssertMatches("select count(*), sum(val1), avg(val1) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2) FLOAT64(0) FLOAT64(0)]]")
mcmp.AssertMatches("select count(val1), sum(id) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2) DECIMAL(7)]]")
mcmp.AssertMatches("select count(val1), sum(id), avg(id) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2) DECIMAL(7) DECIMAL(3.5000)]]")
mcmp.AssertMatches("select count(*), sum(id) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(2) DECIMAL(14)]]")
mcmp.AssertMatches("select count(val1), sum(id) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(1) DECIMAL(14)]]")
mcmp.AssertMatches("select count(val2), sum(val2) from (select id, val2 from aggr_test where val2 is null limit 2) as x", "[[INT64(0) NULL]]")
mcmp.AssertMatches("select val1, count(*), sum(id) from (select id, val1 from aggr_test where val2 < 4 order by val1 limit 2) as x group by val1", `[[NULL INT64(1) DECIMAL(7)] [VARCHAR("a") INT64(1) DECIMAL(2)]]`)
mcmp.AssertMatchesNoOrder("select val1, count(val2), sum(val2), avg(val2) from (select val1, val2 from aggr_test limit 8) as x group by val1",
`[[NULL INT64(1) DECIMAL(2) DECIMAL(2.0000)] [VARCHAR("a") INT64(2) DECIMAL(7) DECIMAL(3.5000)] [VARCHAR("b") INT64(1) DECIMAL(1) DECIMAL(1.0000)] [VARCHAR("c") INT64(2) DECIMAL(7) DECIMAL(3.5000)]]`)
})
}
}
Expand All @@ -343,6 +355,8 @@ func TestEmptyTableAggr(t *testing.T) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", workload))
mcmp.AssertMatches(" select count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]")
mcmp.AssertMatches(" select count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]")
mcmp.AssertMatches(" select count(t1.value) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]")
mcmp.AssertMatches(" select avg(t1.value) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[NULL]]")
mcmp.AssertMatches(" select t1.`name`, count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]")
mcmp.AssertMatches(" select t1.`name`, count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]")
})
Expand All @@ -355,8 +369,10 @@ func TestEmptyTableAggr(t *testing.T) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", workload))
mcmp.AssertMatches(" select count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]")
mcmp.AssertMatches(" select count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]")
mcmp.AssertMatches(" select t1.`name`, count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]")
mcmp.AssertMatches(" select count(t1.value) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]")
mcmp.AssertMatches(" select avg(t1.value) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[NULL]]")
mcmp.AssertMatches(" select t1.`name`, count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]")
mcmp.AssertMatches(" select t1.`name`, count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]")
})
}

Expand Down Expand Up @@ -398,6 +414,8 @@ func TestAggregateLeftJoin(t *testing.T) {
mcmp.AssertMatches("SELECT count(*) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[INT64(2)]]`)
mcmp.AssertMatches("SELECT sum(t1.shardkey) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[DECIMAL(1)]]`)
mcmp.AssertMatches("SELECT sum(t2.shardkey) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[DECIMAL(1)]]`)
mcmp.AssertMatches("SELECT avg(t1.shardkey) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[DECIMAL(0.5000)]]`)
mcmp.AssertMatches("SELECT avg(t2.shardkey) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[DECIMAL(1.0000)]]`)
mcmp.AssertMatches("SELECT count(*) FROM t2 LEFT JOIN t1 ON t1.t1_id = t2.id WHERE IFNULL(t1.name, 'NOTSET') = 'r'", `[[INT64(1)]]`)
}

Expand Down Expand Up @@ -426,6 +444,7 @@ func TestScalarAggregate(t *testing.T) {

mcmp.Exec("insert into aggr_test(id, val1, val2) values(1,'a',1), (2,'A',1), (3,'b',1), (4,'c',3), (5,'c',4)")
mcmp.AssertMatches("select count(distinct val1) from aggr_test", `[[INT64(3)]]`)
mcmp.AssertMatches("select avg(val1) from aggr_test", `[[FLOAT64(0)]]`)
}

func TestAggregationRandomOnAnAggregatedValue(t *testing.T) {
Expand Down Expand Up @@ -478,6 +497,7 @@ func TestComplexAggregation(t *testing.T) {
mcmp.Exec(`SELECT 1+COUNT(t1_id) FROM t1`)
mcmp.Exec(`SELECT COUNT(t1_id)+1 FROM t1`)
mcmp.Exec(`SELECT COUNT(t1_id)+MAX(shardkey) FROM t1`)
mcmp.Exec(`SELECT COUNT(t1_id)+MAX(shardkey)+AVG(t1_id) FROM t1`)
mcmp.Exec(`SELECT shardkey, MIN(t1_id)+MAX(t1_id) FROM t1 GROUP BY shardkey`)
mcmp.Exec(`SELECT shardkey + MIN(t1_id)+MAX(t1_id) FROM t1 GROUP BY shardkey`)
mcmp.Exec(`SELECT name+COUNT(t1_id)+1 FROM t1 GROUP BY name`)
Expand Down
6 changes: 5 additions & 1 deletion go/vt/vtgate/engine/opcode/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ const (
AggregateAnyValue
AggregateCountStar
AggregateGroupConcat
AggregateAvg
_NumOfOpCodes // This line must be last of the opcodes!
)

Expand All @@ -85,6 +86,7 @@ var (
AggregateCountStar: sqltypes.Int64,
AggregateSumDistinct: sqltypes.Decimal,
AggregateSum: sqltypes.Decimal,
AggregateAvg: sqltypes.Decimal,
AggregateGtid: sqltypes.VarChar,
}
)
Expand All @@ -96,6 +98,7 @@ var SupportedAggregates = map[string]AggregateOpcode{
"sum": AggregateSum,
"min": AggregateMin,
"max": AggregateMax,
"avg": AggregateAvg,
// These functions don't exist in mysql, but are used
// to display the plan.
"count_distinct": AggregateCountDistinct,
Expand All @@ -117,6 +120,7 @@ var AggregateName = map[AggregateOpcode]string{
AggregateCountStar: "count_star",
AggregateGroupConcat: "group_concat",
AggregateAnyValue: "any_value",
AggregateAvg: "avg",
}

func (code AggregateOpcode) String() string {
Expand Down Expand Up @@ -148,7 +152,7 @@ func (code AggregateOpcode) Type(typ querypb.Type) querypb.Type {
return sqltypes.Text
case AggregateMax, AggregateMin, AggregateAnyValue:
return typ
case AggregateSumDistinct, AggregateSum:
case AggregateSumDistinct, AggregateSum, AggregateAvg:
if typ == sqltypes.Unknown {
return sqltypes.Unknown
}
Expand Down
109 changes: 91 additions & 18 deletions go/vt/vtgate/planbuilder/operators/aggregation_pushing.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,33 @@ func tryPushAggregator(ctx *plancontext.PlanningContext, aggregator *Aggregator)
if aggregator.Pushed {
return aggregator, rewrite.SameTree, nil
}

// this rewrite is always valid, and we should do it whenever possible
if route, ok := aggregator.Source.(*Route); ok && (route.IsSingleShard() || overlappingUniqueVindex(ctx, aggregator.Grouping)) {
return rewrite.Swap(aggregator, route, "push down aggregation under route - remove original")
}

// other rewrites require us to have reached this phase before we can consider them
if !reachedPhase(ctx, delegateAggregation) {
return aggregator, rewrite.SameTree, nil
}

// if we have not yet been able to push this aggregation down,
// we need to turn AVG into SUM/COUNT to support this over a sharded keyspace
if needAvgBreaking(aggregator.Aggregations) {
return splitAvgAggregations(ctx, aggregator)
}

switch src := aggregator.Source.(type) {
case *Route:
// if we have a single sharded route, we can push it down
output, applyResult, err = pushAggregationThroughRoute(ctx, aggregator, src)
case *ApplyJoin:
if reachedPhase(ctx, delegateAggregation) {
output, applyResult, err = pushAggregationThroughJoin(ctx, aggregator, src)
}
output, applyResult, err = pushAggregationThroughJoin(ctx, aggregator, src)
case *Filter:
if reachedPhase(ctx, delegateAggregation) {
output, applyResult, err = pushAggregationThroughFilter(ctx, aggregator, src)
}
output, applyResult, err = pushAggregationThroughFilter(ctx, aggregator, src)
case *SubQueryContainer:
if reachedPhase(ctx, delegateAggregation) {
output, applyResult, err = pushAggregationThroughSubquery(ctx, aggregator, src)
}
output, applyResult, err = pushAggregationThroughSubquery(ctx, aggregator, src)
default:
return aggregator, rewrite.SameTree, nil
}
Expand Down Expand Up @@ -135,15 +146,6 @@ func pushAggregationThroughRoute(
aggregator *Aggregator,
route *Route,
) (ops.Operator, *rewrite.ApplyResult, error) {
// If the route is single-shard, or we are grouping by sharding keys, we can just push down the aggregation
if route.IsSingleShard() || overlappingUniqueVindex(ctx, aggregator.Grouping) {
return rewrite.Swap(aggregator, route, "push down aggregation under route - remove original")
}

if !reachedPhase(ctx, delegateAggregation) {
return nil, nil, nil
}

// Create a new aggregator to be placed below the route.
aggrBelowRoute := aggregator.SplitAggregatorBelowRoute(route.Inputs())
aggrBelowRoute.Aggregations = nil
Expand Down Expand Up @@ -806,3 +808,74 @@ func initColReUse(size int) []int {
}

func extractExpr(expr *sqlparser.AliasedExpr) sqlparser.Expr { return expr.Expr }

func needAvgBreaking(aggrs []Aggr) bool {
for _, aggr := range aggrs {
if aggr.OpCode == opcode.AggregateAvg {
return true
}
}
return false
}

// splitAvgAggregations takes an aggregator that has AVG aggregations in it and splits
// these into sum/count expressions that can be spread out to shards
func splitAvgAggregations(ctx *plancontext.PlanningContext, aggr *Aggregator) (ops.Operator, *rewrite.ApplyResult, error) {
proj := newAliasedProjection(aggr)

var columns []*sqlparser.AliasedExpr
var aggregations []Aggr

for offset, col := range aggr.Columns {
avg, ok := col.Expr.(*sqlparser.Avg)
if !ok {
proj.addColumnWithoutPushing(ctx, col, false /* addToGroupBy */)
continue
}

if avg.Distinct {
panic(vterrors.VT12001("AVG(distinct <>)"))
}

// We have an AVG that we need to split
sumExpr := &sqlparser.Sum{Arg: avg.Arg}
countExpr := &sqlparser.Count{Args: []sqlparser.Expr{avg.Arg}}
calcExpr := &sqlparser.BinaryExpr{
Operator: sqlparser.DivOp,
Left: sumExpr,
Right: countExpr,
}

outputColumn := aeWrap(col.Expr)
outputColumn.As = sqlparser.NewIdentifierCI(col.ColumnName())
_, err := proj.addUnexploredExpr(sqlparser.CloneRefOfAliasedExpr(col), calcExpr)
if err != nil {
return nil, nil, err
}
col.Expr = sumExpr
found := false
for aggrOffset, aggregation := range aggr.Aggregations {
if offset == aggregation.ColOffset {
// We have found the AVG column. We'll change it to SUM, and then we add a COUNT as well
aggr.Aggregations[aggrOffset].OpCode = opcode.AggregateSum

countExprAlias := aeWrap(countExpr)
countAggr := NewAggr(opcode.AggregateCount, countExpr, countExprAlias, sqlparser.String(countExpr))
countAggr.ColOffset = len(aggr.Columns) + len(columns)
aggregations = append(aggregations, countAggr)
columns = append(columns, countExprAlias)
found = true
break // no need to search the remaining aggregations
}
}
if !found {
// if we get here, it's because we didn't find the aggregation. Something is wrong
panic(vterrors.VT13001("no aggregation pointing to this column was found"))
}
}

aggr.Columns = append(aggr.Columns, columns...)
aggr.Aggregations = append(aggr.Aggregations, aggregations...)

return proj, rewrite.NewTree("split avg aggregation", proj), nil
}
Loading

0 comments on commit 225fc70

Please sign in to comment.