diff --git a/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go b/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go index ed917efda4c..b1123e634e8 100644 --- a/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go +++ b/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go @@ -73,6 +73,7 @@ func TestAggregateTypes(t *testing.T) { mcmp.AssertMatches("select val1 as a, count(*) from aggr_test group by a order by a", `[[VARCHAR("a") INT64(2)] [VARCHAR("b") INT64(1)] [VARCHAR("c") INT64(2)] [VARCHAR("d") INT64(1)] [VARCHAR("e") INT64(2)]]`) mcmp.AssertMatches("select val1 as a, count(*) from aggr_test group by a order by 2, a", `[[VARCHAR("b") INT64(1)] [VARCHAR("d") INT64(1)] [VARCHAR("a") INT64(2)] [VARCHAR("c") INT64(2)] [VARCHAR("e") INT64(2)]]`) mcmp.AssertMatches("select sum(val1) from aggr_test", `[[FLOAT64(0)]]`) + mcmp.AssertMatches("select avg(val1) from aggr_test", `[[FLOAT64(0)]]`) } func TestGroupBy(t *testing.T) { @@ -172,6 +173,13 @@ func TestAggrOnJoin(t *testing.T) { mcmp.AssertMatches("select a.val1 from aggr_test a join t3 t on a.val2 = t.id7 group by a.val1 having count(*) = 4", `[[VARCHAR("a")]]`) + + mcmp.AssertMatches(`select avg(a1.val2), avg(a2.val2) from aggr_test a1 join aggr_test a2 on a1.val2 = a2.id join t3 t on a2.val2 = t.id7`, + "[[DECIMAL(1.5000) DECIMAL(1.0000)]]") + + mcmp.AssertMatches(`select a1.val1, avg(a1.val2) from aggr_test a1 join aggr_test a2 on a1.val2 = a2.id join t3 t on a2.val2 = t.id7 group by a1.val1`, + `[[VARCHAR("a") DECIMAL(1.0000)] [VARCHAR("b") DECIMAL(1.0000)] [VARCHAR("c") DECIMAL(3.0000)]]`) + } func TestNotEqualFilterOnScatter(t *testing.T) { @@ -314,22 +322,26 @@ func TestAggOnTopOfLimit(t *testing.T) { for _, workload := range []string{"oltp", "olap"} { t.Run(workload, func(t *testing.T) { utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = '%s'", workload)) - mcmp.AssertMatches(" select count(*) from (select id, val1 from aggr_test where val2 < 4 limit 2) as x", "[[INT64(2)]]") - mcmp.AssertMatches(" select count(val1) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2)]]") - mcmp.AssertMatches(" select count(*) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(2)]]") - mcmp.AssertMatches(" select count(val1) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(1)]]") - mcmp.AssertMatches(" select count(val2) from (select id, val2 from aggr_test where val2 is null limit 2) as x", "[[INT64(0)]]") - mcmp.AssertMatches(" select val1, count(*) from (select id, val1 from aggr_test where val2 < 4 order by val1 limit 2) as x group by val1", `[[NULL INT64(1)] [VARCHAR("a") INT64(1)]]`) - mcmp.AssertMatchesNoOrder(" select val1, count(val2) from (select val1, val2 from aggr_test limit 8) as x group by val1", `[[NULL INT64(1)] [VARCHAR("a") INT64(2)] [VARCHAR("b") INT64(1)] [VARCHAR("c") INT64(2)]]`) + mcmp.AssertMatches("select count(*) from (select id, val1 from aggr_test where val2 < 4 limit 2) as x", "[[INT64(2)]]") + mcmp.AssertMatches("select count(val1) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2)]]") + mcmp.AssertMatches("select count(*) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(2)]]") + mcmp.AssertMatches("select count(val1) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(1)]]") + mcmp.AssertMatches("select count(val2) from (select id, val2 from aggr_test where val2 is null limit 2) as x", "[[INT64(0)]]") + mcmp.AssertMatches("select avg(val2) from (select id, val2 from aggr_test where val2 is null limit 2) as x", "[[NULL]]") + mcmp.AssertMatches("select val1, count(*) from (select id, val1 from aggr_test where val2 < 4 order by val1 limit 2) as x group by val1", `[[NULL INT64(1)] [VARCHAR("a") INT64(1)]]`) + mcmp.AssertMatchesNoOrder("select val1, count(val2) from (select val1, val2 from aggr_test limit 8) as x group by val1", `[[NULL INT64(1)] [VARCHAR("a") INT64(2)] [VARCHAR("b") INT64(1)] [VARCHAR("c") INT64(2)]]`) + mcmp.AssertMatchesNoOrder("select val1, avg(val2) from (select val1, val2 from aggr_test limit 8) as x group by val1", `[[NULL DECIMAL(2.0000)] [VARCHAR("a") DECIMAL(3.5000)] [VARCHAR("b") DECIMAL(1.0000)] [VARCHAR("c") DECIMAL(3.5000)]]`) // mysql returns FLOAT64(0), vitess returns DECIMAL(0) - mcmp.AssertMatchesNoCompare(" select count(*), sum(val1) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2) FLOAT64(0)]]", "[[INT64(2) FLOAT64(0)]]") - mcmp.AssertMatches(" select count(val1), sum(id) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2) DECIMAL(7)]]") - mcmp.AssertMatches(" select count(*), sum(id) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(2) DECIMAL(14)]]") - mcmp.AssertMatches(" select count(val1), sum(id) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(1) DECIMAL(14)]]") - mcmp.AssertMatches(" select count(val2), sum(val2) from (select id, val2 from aggr_test where val2 is null limit 2) as x", "[[INT64(0) NULL]]") - mcmp.AssertMatches(" select val1, count(*), sum(id) from (select id, val1 from aggr_test where val2 < 4 order by val1 limit 2) as x group by val1", `[[NULL INT64(1) DECIMAL(7)] [VARCHAR("a") INT64(1) DECIMAL(2)]]`) - mcmp.AssertMatchesNoOrder(" select val1, count(val2), sum(val2) from (select val1, val2 from aggr_test limit 8) as x group by val1", `[[NULL INT64(1) DECIMAL(2)] [VARCHAR("a") INT64(2) DECIMAL(7)] [VARCHAR("b") INT64(1) DECIMAL(1)] [VARCHAR("c") INT64(2) DECIMAL(7)]]`) + mcmp.AssertMatches("select count(*), sum(val1), avg(val1) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2) FLOAT64(0) FLOAT64(0)]]") + mcmp.AssertMatches("select count(val1), sum(id) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2) DECIMAL(7)]]") + mcmp.AssertMatches("select count(val1), sum(id), avg(id) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2) DECIMAL(7) DECIMAL(3.5000)]]") + mcmp.AssertMatches("select count(*), sum(id) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(2) DECIMAL(14)]]") + mcmp.AssertMatches("select count(val1), sum(id) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(1) DECIMAL(14)]]") + mcmp.AssertMatches("select count(val2), sum(val2) from (select id, val2 from aggr_test where val2 is null limit 2) as x", "[[INT64(0) NULL]]") + mcmp.AssertMatches("select val1, count(*), sum(id) from (select id, val1 from aggr_test where val2 < 4 order by val1 limit 2) as x group by val1", `[[NULL INT64(1) DECIMAL(7)] [VARCHAR("a") INT64(1) DECIMAL(2)]]`) + mcmp.AssertMatchesNoOrder("select val1, count(val2), sum(val2), avg(val2) from (select val1, val2 from aggr_test limit 8) as x group by val1", + `[[NULL INT64(1) DECIMAL(2) DECIMAL(2.0000)] [VARCHAR("a") INT64(2) DECIMAL(7) DECIMAL(3.5000)] [VARCHAR("b") INT64(1) DECIMAL(1) DECIMAL(1.0000)] [VARCHAR("c") INT64(2) DECIMAL(7) DECIMAL(3.5000)]]`) }) } } @@ -343,6 +355,8 @@ func TestEmptyTableAggr(t *testing.T) { utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", workload)) mcmp.AssertMatches(" select count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]") mcmp.AssertMatches(" select count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]") + mcmp.AssertMatches(" select count(t1.value) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]") + mcmp.AssertMatches(" select avg(t1.value) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[NULL]]") mcmp.AssertMatches(" select t1.`name`, count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]") mcmp.AssertMatches(" select t1.`name`, count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]") }) @@ -355,8 +369,10 @@ func TestEmptyTableAggr(t *testing.T) { utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", workload)) mcmp.AssertMatches(" select count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]") mcmp.AssertMatches(" select count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]") - mcmp.AssertMatches(" select t1.`name`, count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]") + mcmp.AssertMatches(" select count(t1.value) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]") + mcmp.AssertMatches(" select avg(t1.value) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[NULL]]") mcmp.AssertMatches(" select t1.`name`, count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]") + mcmp.AssertMatches(" select t1.`name`, count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]") }) } @@ -398,6 +414,8 @@ func TestAggregateLeftJoin(t *testing.T) { mcmp.AssertMatches("SELECT count(*) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[INT64(2)]]`) mcmp.AssertMatches("SELECT sum(t1.shardkey) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[DECIMAL(1)]]`) mcmp.AssertMatches("SELECT sum(t2.shardkey) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[DECIMAL(1)]]`) + mcmp.AssertMatches("SELECT avg(t1.shardkey) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[DECIMAL(0.5000)]]`) + mcmp.AssertMatches("SELECT avg(t2.shardkey) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[DECIMAL(1.0000)]]`) mcmp.AssertMatches("SELECT count(*) FROM t2 LEFT JOIN t1 ON t1.t1_id = t2.id WHERE IFNULL(t1.name, 'NOTSET') = 'r'", `[[INT64(1)]]`) } @@ -426,6 +444,7 @@ func TestScalarAggregate(t *testing.T) { mcmp.Exec("insert into aggr_test(id, val1, val2) values(1,'a',1), (2,'A',1), (3,'b',1), (4,'c',3), (5,'c',4)") mcmp.AssertMatches("select count(distinct val1) from aggr_test", `[[INT64(3)]]`) + mcmp.AssertMatches("select avg(val1) from aggr_test", `[[FLOAT64(0)]]`) } func TestAggregationRandomOnAnAggregatedValue(t *testing.T) { @@ -478,6 +497,7 @@ func TestComplexAggregation(t *testing.T) { mcmp.Exec(`SELECT 1+COUNT(t1_id) FROM t1`) mcmp.Exec(`SELECT COUNT(t1_id)+1 FROM t1`) mcmp.Exec(`SELECT COUNT(t1_id)+MAX(shardkey) FROM t1`) + mcmp.Exec(`SELECT COUNT(t1_id)+MAX(shardkey)+AVG(t1_id) FROM t1`) mcmp.Exec(`SELECT shardkey, MIN(t1_id)+MAX(t1_id) FROM t1 GROUP BY shardkey`) mcmp.Exec(`SELECT shardkey + MIN(t1_id)+MAX(t1_id) FROM t1 GROUP BY shardkey`) mcmp.Exec(`SELECT name+COUNT(t1_id)+1 FROM t1 GROUP BY name`) diff --git a/go/vt/vtgate/engine/opcode/constants.go b/go/vt/vtgate/engine/opcode/constants.go index dd73a78974d..8a70df79d0c 100644 --- a/go/vt/vtgate/engine/opcode/constants.go +++ b/go/vt/vtgate/engine/opcode/constants.go @@ -74,6 +74,7 @@ const ( AggregateAnyValue AggregateCountStar AggregateGroupConcat + AggregateAvg _NumOfOpCodes // This line must be last of the opcodes! ) @@ -85,6 +86,7 @@ var ( AggregateCountStar: sqltypes.Int64, AggregateSumDistinct: sqltypes.Decimal, AggregateSum: sqltypes.Decimal, + AggregateAvg: sqltypes.Decimal, AggregateGtid: sqltypes.VarChar, } ) @@ -96,6 +98,7 @@ var SupportedAggregates = map[string]AggregateOpcode{ "sum": AggregateSum, "min": AggregateMin, "max": AggregateMax, + "avg": AggregateAvg, // These functions don't exist in mysql, but are used // to display the plan. "count_distinct": AggregateCountDistinct, @@ -117,6 +120,7 @@ var AggregateName = map[AggregateOpcode]string{ AggregateCountStar: "count_star", AggregateGroupConcat: "group_concat", AggregateAnyValue: "any_value", + AggregateAvg: "avg", } func (code AggregateOpcode) String() string { @@ -148,7 +152,7 @@ func (code AggregateOpcode) Type(typ querypb.Type) querypb.Type { return sqltypes.Text case AggregateMax, AggregateMin, AggregateAnyValue: return typ - case AggregateSumDistinct, AggregateSum: + case AggregateSumDistinct, AggregateSum, AggregateAvg: if typ == sqltypes.Unknown { return sqltypes.Unknown } diff --git a/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go b/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go index 85c50427364..0127e81b4c4 100644 --- a/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go +++ b/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go @@ -34,22 +34,33 @@ func tryPushAggregator(ctx *plancontext.PlanningContext, aggregator *Aggregator) if aggregator.Pushed { return aggregator, rewrite.SameTree, nil } + + // this rewrite is always valid, and we should do it whenever possible + if route, ok := aggregator.Source.(*Route); ok && (route.IsSingleShard() || overlappingUniqueVindex(ctx, aggregator.Grouping)) { + return rewrite.Swap(aggregator, route, "push down aggregation under route - remove original") + } + + // other rewrites require us to have reached this phase before we can consider them + if !reachedPhase(ctx, delegateAggregation) { + return aggregator, rewrite.SameTree, nil + } + + // if we have not yet been able to push this aggregation down, + // we need to turn AVG into SUM/COUNT to support this over a sharded keyspace + if needAvgBreaking(aggregator.Aggregations) { + return splitAvgAggregations(ctx, aggregator) + } + switch src := aggregator.Source.(type) { case *Route: // if we have a single sharded route, we can push it down output, applyResult, err = pushAggregationThroughRoute(ctx, aggregator, src) case *ApplyJoin: - if reachedPhase(ctx, delegateAggregation) { - output, applyResult, err = pushAggregationThroughJoin(ctx, aggregator, src) - } + output, applyResult, err = pushAggregationThroughJoin(ctx, aggregator, src) case *Filter: - if reachedPhase(ctx, delegateAggregation) { - output, applyResult, err = pushAggregationThroughFilter(ctx, aggregator, src) - } + output, applyResult, err = pushAggregationThroughFilter(ctx, aggregator, src) case *SubQueryContainer: - if reachedPhase(ctx, delegateAggregation) { - output, applyResult, err = pushAggregationThroughSubquery(ctx, aggregator, src) - } + output, applyResult, err = pushAggregationThroughSubquery(ctx, aggregator, src) default: return aggregator, rewrite.SameTree, nil } @@ -135,15 +146,6 @@ func pushAggregationThroughRoute( aggregator *Aggregator, route *Route, ) (ops.Operator, *rewrite.ApplyResult, error) { - // If the route is single-shard, or we are grouping by sharding keys, we can just push down the aggregation - if route.IsSingleShard() || overlappingUniqueVindex(ctx, aggregator.Grouping) { - return rewrite.Swap(aggregator, route, "push down aggregation under route - remove original") - } - - if !reachedPhase(ctx, delegateAggregation) { - return nil, nil, nil - } - // Create a new aggregator to be placed below the route. aggrBelowRoute := aggregator.SplitAggregatorBelowRoute(route.Inputs()) aggrBelowRoute.Aggregations = nil @@ -806,3 +808,74 @@ func initColReUse(size int) []int { } func extractExpr(expr *sqlparser.AliasedExpr) sqlparser.Expr { return expr.Expr } + +func needAvgBreaking(aggrs []Aggr) bool { + for _, aggr := range aggrs { + if aggr.OpCode == opcode.AggregateAvg { + return true + } + } + return false +} + +// splitAvgAggregations takes an aggregator that has AVG aggregations in it and splits +// these into sum/count expressions that can be spread out to shards +func splitAvgAggregations(ctx *plancontext.PlanningContext, aggr *Aggregator) (ops.Operator, *rewrite.ApplyResult, error) { + proj := newAliasedProjection(aggr) + + var columns []*sqlparser.AliasedExpr + var aggregations []Aggr + + for offset, col := range aggr.Columns { + avg, ok := col.Expr.(*sqlparser.Avg) + if !ok { + proj.addColumnWithoutPushing(ctx, col, false /* addToGroupBy */) + continue + } + + if avg.Distinct { + panic(vterrors.VT12001("AVG(distinct <>)")) + } + + // We have an AVG that we need to split + sumExpr := &sqlparser.Sum{Arg: avg.Arg} + countExpr := &sqlparser.Count{Args: []sqlparser.Expr{avg.Arg}} + calcExpr := &sqlparser.BinaryExpr{ + Operator: sqlparser.DivOp, + Left: sumExpr, + Right: countExpr, + } + + outputColumn := aeWrap(col.Expr) + outputColumn.As = sqlparser.NewIdentifierCI(col.ColumnName()) + _, err := proj.addUnexploredExpr(sqlparser.CloneRefOfAliasedExpr(col), calcExpr) + if err != nil { + return nil, nil, err + } + col.Expr = sumExpr + found := false + for aggrOffset, aggregation := range aggr.Aggregations { + if offset == aggregation.ColOffset { + // We have found the AVG column. We'll change it to SUM, and then we add a COUNT as well + aggr.Aggregations[aggrOffset].OpCode = opcode.AggregateSum + + countExprAlias := aeWrap(countExpr) + countAggr := NewAggr(opcode.AggregateCount, countExpr, countExprAlias, sqlparser.String(countExpr)) + countAggr.ColOffset = len(aggr.Columns) + len(columns) + aggregations = append(aggregations, countAggr) + columns = append(columns, countExprAlias) + found = true + break // no need to search the remaining aggregations + } + } + if !found { + // if we get here, it's because we didn't find the aggregation. Something is wrong + panic(vterrors.VT13001("no aggregation pointing to this column was found")) + } + } + + aggr.Columns = append(aggr.Columns, columns...) + aggr.Aggregations = append(aggr.Aggregations, aggregations...) + + return proj, rewrite.NewTree("split avg aggregation", proj), nil +} diff --git a/go/vt/vtgate/planbuilder/operators/phases.go b/go/vt/vtgate/planbuilder/operators/phases.go index ba13a828d0b..36149f89985 100644 --- a/go/vt/vtgate/planbuilder/operators/phases.go +++ b/go/vt/vtgate/planbuilder/operators/phases.go @@ -19,6 +19,7 @@ package operators import ( "vitess.io/vitess/go/slice" "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/ops" "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/rewrite" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" @@ -56,9 +57,9 @@ func (p Phase) String() string { return "optimize Distinct operations" case subquerySettling: return "settle subqueries" + default: + panic(vterrors.VT13001("unhandled default case")) } - - return "unknown" } func (p Phase) shouldRun(s semantics.QuerySignature) bool { @@ -73,8 +74,9 @@ func (p Phase) shouldRun(s semantics.QuerySignature) bool { return s.Distinct case subquerySettling: return s.SubQueries + default: + return true } - return true } func (p Phase) act(ctx *plancontext.PlanningContext, op ops.Operator) (ops.Operator, error) { @@ -84,14 +86,14 @@ func (p Phase) act(ctx *plancontext.PlanningContext, op ops.Operator) (ops.Opera case delegateAggregation: return enableDelegateAggregation(ctx, op) case addAggrOrdering: - return addOrderBysForAggregations(ctx, op) + return addOrderingForAllAggregations(ctx, op) case cleanOutPerfDistinct: return removePerformanceDistinctAboveRoute(ctx, op) case subquerySettling: return settleSubqueries(ctx, op), nil + default: + return op, nil } - - return op, nil } // getPhases returns the ordered phases that the planner will undergo. @@ -120,7 +122,8 @@ func enableDelegateAggregation(ctx *plancontext.PlanningContext, op ops.Operator return addColumnsToInput(ctx, op) } -func addOrderBysForAggregations(ctx *plancontext.PlanningContext, root ops.Operator) (ops.Operator, error) { +// addOrderingForAllAggregations is run we have pushed down Aggregators as far down as possible. +func addOrderingForAllAggregations(ctx *plancontext.PlanningContext, root ops.Operator) (ops.Operator, error) { visitor := func(in ops.Operator, _ semantics.TableSet, isRoot bool) (ops.Operator, *rewrite.ApplyResult, error) { aggrOp, ok := in.(*Aggregator) if !ok { @@ -131,30 +134,36 @@ func addOrderBysForAggregations(ctx *plancontext.PlanningContext, root ops.Opera if err != nil { return nil, nil, err } - if !requireOrdering { - return in, rewrite.SameTree, nil - } - orderBys := slice.Map(aggrOp.Grouping, func(from GroupBy) ops.OrderBy { - return from.AsOrderBy() - }) - if aggrOp.DistinctExpr != nil { - orderBys = append(orderBys, ops.OrderBy{ - Inner: &sqlparser.Order{ - Expr: aggrOp.DistinctExpr, - }, - SimplifiedExpr: aggrOp.DistinctExpr, - }) - } - aggrOp.Source = &Ordering{ - Source: aggrOp.Source, - Order: orderBys, + + var res *rewrite.ApplyResult + if requireOrdering { + addOrderingFor(aggrOp) + res = rewrite.NewTree("added ordering before aggregation", in) } - return in, rewrite.NewTree("added ordering before aggregation", in), nil + return in, res, nil } return rewrite.BottomUp(root, TableID, visitor, stopAtRoute) } +func addOrderingFor(aggrOp *Aggregator) { + orderBys := slice.Map(aggrOp.Grouping, func(from GroupBy) ops.OrderBy { + return from.AsOrderBy() + }) + if aggrOp.DistinctExpr != nil { + orderBys = append(orderBys, ops.OrderBy{ + Inner: &sqlparser.Order{ + Expr: aggrOp.DistinctExpr, + }, + SimplifiedExpr: aggrOp.DistinctExpr, + }) + } + aggrOp.Source = &Ordering{ + Source: aggrOp.Source, + Order: orderBys, + } +} + func needsOrdering(ctx *plancontext.PlanningContext, in *Aggregator) (bool, error) { requiredOrder := slice.Map(in.Grouping, func(from GroupBy) sqlparser.Expr { return from.SimplifiedExpr diff --git a/go/vt/vtgate/planbuilder/operators/queryprojection.go b/go/vt/vtgate/planbuilder/operators/queryprojection.go index b7040807300..ff188d6c895 100644 --- a/go/vt/vtgate/planbuilder/operators/queryprojection.go +++ b/go/vt/vtgate/planbuilder/operators/queryprojection.go @@ -686,25 +686,13 @@ func (qp *QueryProjection) NeedsDistinct() bool { } func (qp *QueryProjection) AggregationExpressions(ctx *plancontext.PlanningContext, allowComplexExpression bool) (out []Aggr, complex bool, err error) { -orderBy: - for _, orderExpr := range qp.OrderExprs { - orderExpr := orderExpr.SimplifiedExpr - for _, expr := range qp.SelectExprs { - col, ok := expr.Col.(*sqlparser.AliasedExpr) - if !ok { - continue - } - if ctx.SemTable.EqualsExprWithDeps(col.Expr, orderExpr) { - continue orderBy // we found the expression we were looking for! - } - } - qp.SelectExprs = append(qp.SelectExprs, SelectExpr{ - Col: &sqlparser.AliasedExpr{Expr: orderExpr}, - Aggr: containsAggr(orderExpr), - }) - qp.AddedColumn++ + qp.addOrderByToSelect(ctx) + addAggr := func(a Aggr) { + out = append(out, a) + } + makeComplex := func() { + complex = true } - // Here we go over the expressions we are returning. Since we know we are aggregating, // all expressions have to be either grouping expressions or aggregate expressions. // If we find an expression that is neither, we treat is as a special aggregation function AggrRandom @@ -733,34 +721,66 @@ orderBy: return nil, false, vterrors.VT12001("in scatter query: complex aggregate expression") } - sqlparser.CopyOnRewrite(aliasedExpr.Expr, func(node, parent sqlparser.SQLNode) bool { - ex, isExpr := node.(sqlparser.Expr) - if !isExpr { - return true - } - if aggr, isAggr := node.(sqlparser.AggrFunc); isAggr { - ae := aeWrap(aggr) - if aggr == aliasedExpr.Expr { - ae = aliasedExpr - } - aggrFunc := createAggrFromAggrFunc(aggr, ae) - aggrFunc.Index = &idxCopy - out = append(out, aggrFunc) - return false + sqlparser.CopyOnRewrite(aliasedExpr.Expr, qp.extractAggr(ctx, idx, aliasedExpr, addAggr, makeComplex), nil, nil) + } + return +} + +func (qp *QueryProjection) extractAggr( + ctx *plancontext.PlanningContext, + idx int, + aliasedExpr *sqlparser.AliasedExpr, + addAggr func(a Aggr), + makeComplex func(), +) func(node sqlparser.SQLNode, parent sqlparser.SQLNode) bool { + return func(node, parent sqlparser.SQLNode) bool { + ex, isExpr := node.(sqlparser.Expr) + if !isExpr { + return true + } + if aggr, isAggr := node.(sqlparser.AggrFunc); isAggr { + ae := aeWrap(aggr) + if aggr == aliasedExpr.Expr { + ae = aliasedExpr } - if containsAggr(node) { - complex = true - return true + aggrFunc := createAggrFromAggrFunc(aggr, ae) + aggrFunc.Index = &idx + addAggr(aggrFunc) + return false + } + if containsAggr(node) { + makeComplex() + return true + } + if !qp.isExprInGroupByExprs(ctx, ex) { + aggr := NewAggr(opcode.AggregateAnyValue, nil, aeWrap(ex), "") + aggr.Index = &idx + addAggr(aggr) + } + return false + } +} + +func (qp *QueryProjection) addOrderByToSelect(ctx *plancontext.PlanningContext) { +orderBy: + // We need to return all columns that are being used for ordering + for _, orderExpr := range qp.OrderExprs { + orderExpr := orderExpr.SimplifiedExpr + for _, expr := range qp.SelectExprs { + col, ok := expr.Col.(*sqlparser.AliasedExpr) + if !ok { + continue } - if !qp.isExprInGroupByExprs(ctx, ex) { - aggr := NewAggr(opcode.AggregateAnyValue, nil, aeWrap(ex), "") - aggr.Index = &idxCopy - out = append(out, aggr) + if ctx.SemTable.EqualsExprWithDeps(col.Expr, orderExpr) { + continue orderBy // we found the expression we were looking for! } - return false - }, nil, nil) + } + qp.SelectExprs = append(qp.SelectExprs, SelectExpr{ + Col: &sqlparser.AliasedExpr{Expr: orderExpr}, + Aggr: containsAggr(orderExpr), + }) + qp.AddedColumn++ } - return } func createAggrFromAggrFunc(fnc sqlparser.AggrFunc, aliasedExpr *sqlparser.AliasedExpr) Aggr { diff --git a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json index e47b17cd2ae..827c64464f8 100644 --- a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json @@ -6038,5 +6038,216 @@ "user.user_extra" ] } + }, + { + "comment": "avg function on scatter query", + "query": "select avg(id) from user", + "plan": { + "QueryType": "SELECT", + "Original": "select avg(id) from user", + "Instructions": { + "OperatorType": "Projection", + "Expressions": [ + "sum(id) / count(id) as avg(id)" + ], + "Inputs": [ + { + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "sum(0) AS avg(id), sum_count(1) AS count(id)", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select sum(id), count(id) from `user` where 1 != 1", + "Query": "select sum(id), count(id) from `user`", + "Table": "`user`" + } + ] + } + ] + }, + "TablesUsed": [ + "user.user" + ] + } + }, + { + "comment": "avg function on scatter query deep inside the output expression", + "query": "select avg(id)+count(foo)+bar from user group by bar", + "plan": { + "QueryType": "SELECT", + "Original": "select avg(id)+count(foo)+bar from user group by bar", + "Instructions": { + "OperatorType": "Projection", + "Expressions": [ + "avg(id) + count(foo) + bar as avg(id) + count(foo) + bar" + ], + "Inputs": [ + { + "OperatorType": "Projection", + "Expressions": [ + ":0 as bar", + "sum(id) / count(id) as avg(id)", + ":2 as count(foo)" + ], + "Inputs": [ + { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "sum(1) AS avg(id), sum_count(2) AS count(foo), sum_count(3) AS count(id)", + "GroupBy": "(0|4)", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select bar, sum(id), count(foo), count(id), weight_string(bar) from `user` where 1 != 1 group by bar, weight_string(bar)", + "OrderBy": "(0|4) ASC", + "Query": "select bar, sum(id), count(foo), count(id), weight_string(bar) from `user` group by bar, weight_string(bar) order by bar asc", + "Table": "`user`" + } + ] + } + ] + } + ] + }, + "TablesUsed": [ + "user.user" + ] + } + }, + { + "comment": "avg function on scatter query deep inside the output expression", + "query": "select avg(id)+count(foo)+bar from user group by bar", + "plan": { + "QueryType": "SELECT", + "Original": "select avg(id)+count(foo)+bar from user group by bar", + "Instructions": { + "OperatorType": "Projection", + "Expressions": [ + "avg(id) + count(foo) + bar as avg(id) + count(foo) + bar" + ], + "Inputs": [ + { + "OperatorType": "Projection", + "Expressions": [ + ":0 as bar", + "sum(id) / count(id) as avg(id)", + ":2 as count(foo)" + ], + "Inputs": [ + { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "sum(1) AS avg(id), sum_count(2) AS count(foo), sum_count(3) AS count(id)", + "GroupBy": "(0|4)", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select bar, sum(id), count(foo), count(id), weight_string(bar) from `user` where 1 != 1 group by bar, weight_string(bar)", + "OrderBy": "(0|4) ASC", + "Query": "select bar, sum(id), count(foo), count(id), weight_string(bar) from `user` group by bar, weight_string(bar) order by bar asc", + "Table": "`user`" + } + ] + } + ] + } + ] + }, + "TablesUsed": [ + "user.user" + ] + } + }, + { + "comment": "two avg aggregations", + "query": "select avg(foo), avg(bar) from user", + "plan": { + "QueryType": "SELECT", + "Original": "select avg(foo), avg(bar) from user", + "Instructions": { + "OperatorType": "Projection", + "Expressions": [ + "sum(foo) / count(foo) as avg(foo)", + "sum(bar) / count(bar) as avg(bar)" + ], + "Inputs": [ + { + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "sum(0) AS avg(foo), sum(1) AS avg(bar), sum_count(2) AS count(foo), sum_count(3) AS count(bar)", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select sum(foo), sum(bar), count(foo), count(bar) from `user` where 1 != 1", + "Query": "select sum(foo), sum(bar), count(foo), count(bar) from `user`", + "Table": "`user`" + } + ] + } + ] + }, + "TablesUsed": [ + "user.user" + ] + } + }, + { + "comment": "avg and count on the same argument", + "query": "select avg(foo), count(foo) from user", + "plan": { + "QueryType": "SELECT", + "Original": "select avg(foo), count(foo) from user", + "Instructions": { + "OperatorType": "Projection", + "Expressions": [ + "sum(foo) / count(foo) as avg(foo)", + ":1 as count(foo)" + ], + "Inputs": [ + { + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "sum(0) AS avg(foo), sum_count(1) AS count(foo), sum_count(2) AS count(foo)", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select sum(foo), count(foo), count(foo) from `user` where 1 != 1", + "Query": "select sum(foo), count(foo), count(foo) from `user`", + "Table": "`user`" + } + ] + } + ] + }, + "TablesUsed": [ + "user.user" + ] + } } ] diff --git a/go/vt/vtgate/planbuilder/testdata/tpch_cases.json b/go/vt/vtgate/planbuilder/testdata/tpch_cases.json index f40ea961334..61f602c4e33 100644 --- a/go/vt/vtgate/planbuilder/testdata/tpch_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/tpch_cases.json @@ -2,7 +2,50 @@ { "comment": "TPC-H query 1", "query": "select l_returnflag, l_linestatus, sum(l_quantity) as sum_qty, sum(l_extendedprice) as sum_base_price, sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, avg(l_quantity) as avg_qty, avg(l_extendedprice) as avg_price, avg(l_discount) as avg_disc, count(*) as count_order from lineitem where l_shipdate <= '1998-12-01' - interval '108' day group by l_returnflag, l_linestatus order by l_returnflag, l_linestatus", - "plan": "VT12001: unsupported: in scatter query: aggregation function 'avg(l_quantity) as avg_qty'" + "plan": { + "QueryType": "SELECT", + "Original": "select l_returnflag, l_linestatus, sum(l_quantity) as sum_qty, sum(l_extendedprice) as sum_base_price, sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, avg(l_quantity) as avg_qty, avg(l_extendedprice) as avg_price, avg(l_discount) as avg_disc, count(*) as count_order from lineitem where l_shipdate <= '1998-12-01' - interval '108' day group by l_returnflag, l_linestatus order by l_returnflag, l_linestatus", + "Instructions": { + "OperatorType": "Projection", + "Expressions": [ + ":0 as l_returnflag", + ":1 as l_linestatus", + ":2 as sum_qty", + ":3 as sum_base_price", + ":4 as sum_disc_price", + ":5 as sum_charge", + "sum(l_quantity) / count(l_quantity) as avg_qty", + "sum(l_extendedprice) / count(l_extendedprice) as avg_price", + "sum(l_discount) / count(l_discount) as avg_disc", + ":9 as count_order" + ], + "Inputs": [ + { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "sum(2) AS sum_qty, sum(3) AS sum_base_price, sum(4) AS sum_disc_price, sum(5) AS sum_charge, sum(6) AS avg_qty, sum(7) AS avg_price, sum(8) AS avg_disc, sum_count_star(9) AS count_order, sum_count(10) AS count(l_quantity), sum_count(11) AS count(l_extendedprice), sum_count(12) AS count(l_discount)", + "GroupBy": "(0|13), (1|14)", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "main", + "Sharded": true + }, + "FieldQuery": "select l_returnflag, l_linestatus, sum(l_quantity) as sum_qty, sum(l_extendedprice) as sum_base_price, sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, sum(l_quantity) as avg_qty, sum(l_extendedprice) as avg_price, sum(l_discount) as avg_disc, count(*) as count_order, count(l_quantity), count(l_extendedprice), count(l_discount), weight_string(l_returnflag), weight_string(l_linestatus) from lineitem where 1 != 1 group by l_returnflag, l_linestatus, weight_string(l_returnflag), weight_string(l_linestatus)", + "OrderBy": "(0|13) ASC, (1|14) ASC", + "Query": "select l_returnflag, l_linestatus, sum(l_quantity) as sum_qty, sum(l_extendedprice) as sum_base_price, sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, sum(l_quantity) as avg_qty, sum(l_extendedprice) as avg_price, sum(l_discount) as avg_disc, count(*) as count_order, count(l_quantity), count(l_extendedprice), count(l_discount), weight_string(l_returnflag), weight_string(l_linestatus) from lineitem where l_shipdate <= '1998-12-01' - interval '108' day group by l_returnflag, l_linestatus, weight_string(l_returnflag), weight_string(l_linestatus) order by l_returnflag asc, l_linestatus asc", + "Table": "lineitem" + } + ] + } + ] + }, + "TablesUsed": [ + "main.lineitem" + ] + } }, { "comment": "TPC-H query 2", diff --git a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json index 923e7804782..6fb4e6d36ab 100644 --- a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json @@ -224,11 +224,6 @@ "query": "create view main.view_a as select * from user.user_extra", "plan": "VT12001: unsupported: Select query does not belong to the same keyspace as the view statement" }, - { - "comment": "avg function on scatter query", - "query": "select avg(id) from user", - "plan": "VT12001: unsupported: in scatter query: aggregation function 'avg(id)'" - }, { "comment": "outer and inner subquery route reference the same \"uu.id\" name\n# but they refer to different things. The first reference is to the outermost query,\n# and the second reference is to the innermost 'from' subquery.\n# This query will never work as the inner derived table is only selecting one of the column", "query": "select id2 from user uu where id in (select id from user where id = uu.id and user.col in (select col from (select id from user_extra where user_id = 5) uu where uu.user_id = uu.id))",