From 68410905fbe95be6a3f4e57379d9dfdc9cd5a0fc Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Thu, 21 Dec 2023 20:34:35 +0530 Subject: [PATCH] Subquery inside aggregration function (#14844) Signed-off-by: Harshit Gangal --- .../vtgate/queries/subquery/subquery_test.go | 17 ++ .../operators/aggregation_pushing.go | 4 + .../planbuilder/operators/aggregator.go | 18 ++ .../operators/horizon_expanding.go | 14 +- .../planbuilder/operators/queryprojection.go | 18 +- .../operators/subquery_planning.go | 7 + .../planbuilder/testdata/aggr_cases.json | 283 ++++++++++++++++++ .../testdata/unsupported_cases.json | 5 + 8 files changed, 349 insertions(+), 17 deletions(-) diff --git a/go/test/endtoend/vtgate/queries/subquery/subquery_test.go b/go/test/endtoend/vtgate/queries/subquery/subquery_test.go index ae46a99565d..e849f926d73 100644 --- a/go/test/endtoend/vtgate/queries/subquery/subquery_test.go +++ b/go/test/endtoend/vtgate/queries/subquery/subquery_test.go @@ -162,3 +162,20 @@ func TestSubqueryInReference(t *testing.T) { mcmp.AssertMatches(`select (select id1 from t1 where id2 = 30)`, `[[INT64(3)]]`) mcmp.AssertMatches(`select (select id1 from t1 where id2 = 9)`, `[[NULL]]`) } + +// TestSubqueryInAggregation validates that subqueries work inside aggregation functions. +func TestSubqueryInAggregation(t *testing.T) { + utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate") + mcmp, closer := start(t) + defer closer() + + mcmp.Exec("insert into t1(id1, id2) values(0,0),(1,1)") + mcmp.Exec("insert into t2(id3, id4) values(1,2),(5,7)") + mcmp.Exec(`SELECT max((select min(id2) from t1)) FROM t2`) + mcmp.Exec(`SELECT max((select group_concat(id1, id2) from t1 where id1 = 1)) FROM t1 where id1 = 1`) + mcmp.Exec(`SELECT max((select min(id2) from t1 where id2 = 1)) FROM dual`) + mcmp.Exec(`SELECT max((select min(id2) from t1)) FROM t2 where id4 = 7`) + + // This fails as the planner adds the `weight_string` function, which makes the query fail on MySQL.
+ // mcmp.Exec(`SELECT max((select min(id2) from t1 where t1.id1 = t.id1)) FROM t1 t`) +} diff --git a/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go b/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go index b56d4dbd869..567936b8a84 100644 --- a/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go +++ b/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go @@ -107,6 +107,10 @@ func pushAggregationThroughSubquery( src.Outer = pushedAggr + for _, aggregation := range pushedAggr.Aggregations { + aggregation.Original.Expr = rewriteColNameToArgument(ctx, aggregation.Original.Expr, aggregation.SubQueryExpression, src.Inner...) + } + if !rootAggr.Original { return src, Rewrote("push Aggregation under subquery - keep original") } diff --git a/go/vt/vtgate/planbuilder/operators/aggregator.go b/go/vt/vtgate/planbuilder/operators/aggregator.go index 6c07343498b..02e19d57654 100644 --- a/go/vt/vtgate/planbuilder/operators/aggregator.go +++ b/go/vt/vtgate/planbuilder/operators/aggregator.go @@ -255,6 +255,7 @@ func (a *Aggregator) planOffsets(ctx *plancontext.PlanningContext) Operator { if a.offsetPlanned { return nil } + a.checkForInvalidAggregations() defer func() { a.offsetPlanned = true }() @@ -413,4 +414,21 @@ func (a *Aggregator) introducesTableID() semantics.TableSet { return a.DT.introducesTableID() } +func (a *Aggregator) checkForInvalidAggregations() { + for _, aggr := range a.Aggregations { + _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { + aggrFunc, isAggregate := node.(sqlparser.AggrFunc) + if !isAggregate { + return true, nil + } + args := aggrFunc.GetArgs() + if args != nil && len(args) != 1 { + panic(vterrors.VT03001(sqlparser.String(node))) + } + return true, nil + + }, aggr.Original.Expr) + } +} + var _ Operator = (*Aggregator)(nil) diff --git a/go/vt/vtgate/planbuilder/operators/horizon_expanding.go b/go/vt/vtgate/planbuilder/operators/horizon_expanding.go index bbe9323509b..86b9ab3ceb6 100644 --- 
a/go/vt/vtgate/planbuilder/operators/horizon_expanding.go +++ b/go/vt/vtgate/planbuilder/operators/horizon_expanding.go @@ -138,8 +138,9 @@ func createProjectionFromSelect(ctx *plancontext.PlanningContext, horizon *Horiz aggregations, complexAggr := qp.AggregationExpressions(ctx, true) + src := horizon.src() a := &Aggregator{ - Source: horizon.src(), + Source: src, Original: true, QP: qp, Grouping: qp.GetGrouping(), @@ -147,6 +148,17 @@ func createProjectionFromSelect(ctx *plancontext.PlanningContext, horizon *Horiz DT: dt, } + sqc := &SubQueryBuilder{} + outerID := TableID(src) + for idx, aggr := range aggregations { + expr := aggr.Original.Expr + newExpr, subqs := sqc.pullOutValueSubqueries(ctx, expr, outerID, false) + if newExpr != nil { + aggregations[idx].SubQueryExpression = subqs + } + } + a.Source = sqc.getRootOperator(src, nil) + if complexAggr { return createProjectionForComplexAggregation(a, qp) } diff --git a/go/vt/vtgate/planbuilder/operators/queryprojection.go b/go/vt/vtgate/planbuilder/operators/queryprojection.go index 163a1213985..6dbad50a50d 100644 --- a/go/vt/vtgate/planbuilder/operators/queryprojection.go +++ b/go/vt/vtgate/planbuilder/operators/queryprojection.go @@ -96,6 +96,8 @@ type ( // the offsets point to columns on the same aggregator ColOffset int WSOffset int + + SubQueryExpression []*SubQuery } AggrRewriter struct { @@ -252,7 +254,6 @@ func (qp *QueryProjection) addSelectExpressions(sel *sqlparser.Select) { for _, selExp := range sel.SelectExprs { switch selExp := selExp.(type) { case *sqlparser.AliasedExpr: - checkForInvalidAggregations(selExp) col := SelectExpr{ Col: selExp, } @@ -403,21 +404,6 @@ func (qp *QueryProjection) GetGrouping() []GroupBy { return slices.Clone(qp.groupByExprs) } -func checkForInvalidAggregations(exp *sqlparser.AliasedExpr) { - _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { - aggrFunc, isAggregate := node.(sqlparser.AggrFunc) - if !isAggregate { - return true, nil - } - args := 
aggrFunc.GetArgs() - if args != nil && len(args) != 1 { - panic(vterrors.VT03001(sqlparser.String(node))) - } - return true, nil - - }, exp.Expr) -} - func (qp *QueryProjection) isExprInGroupByExprs(ctx *plancontext.PlanningContext, expr sqlparser.Expr) bool { for _, groupByExpr := range qp.groupByExprs { if ctx.SemTable.EqualsExprWithDeps(groupByExpr.SimplifiedExpr, expr) { diff --git a/go/vt/vtgate/planbuilder/operators/subquery_planning.go b/go/vt/vtgate/planbuilder/operators/subquery_planning.go index 88efdadd266..0980cca9cc8 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery_planning.go +++ b/go/vt/vtgate/planbuilder/operators/subquery_planning.go @@ -95,6 +95,13 @@ func settleSubqueries(ctx *plancontext.PlanningContext, op Operator) Operator { for _, setExpr := range op.Assignments { mergeSubqueryExpr(ctx, setExpr.Expr) } + case *Aggregator: + for _, aggr := range op.Aggregations { + newExpr, rewritten := rewriteMergedSubqueryExpr(ctx, aggr.SubQueryExpression, aggr.Original.Expr) + if rewritten { + aggr.Original.Expr = newExpr + } + } } return op, NoRewrite } diff --git a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json index 4739045b016..0e068283d5c 100644 --- a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json @@ -6562,5 +6562,288 @@ "user.user_extra" ] } + }, + { + "comment": "select max((select min(col) from user where id = 1))", + "query": "select max((select min(col) from user where id = 1))", + "plan": { + "QueryType": "SELECT", + "Original": "select max((select min(col) from user where id = 1))", + "Instructions": { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select max((select min(col) from `user` where 1 != 1)) from dual where 1 != 1", + "Query": "select max((select min(col) from `user` where id = 1)) from dual", + "Table": "dual", + "Values": [ + "1" + 
], + "Vindex": "user_index" + }, + "TablesUsed": [ + "main.dual", + "user.user" + ] + } + }, + { + "comment": "select max((select min(col) from unsharded)) from user where id = 1", + "query": "select max((select min(col) from unsharded)) from user where id = 1", + "plan": { + "QueryType": "SELECT", + "Original": "select max((select min(col) from unsharded)) from user where id = 1", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "max(0|1) AS max((select min(col) from unsharded))", + "ResultColumns": 1, + "Inputs": [ + { + "OperatorType": "UncorrelatedSubquery", + "Variant": "PulloutValue", + "PulloutVars": [ + "__sq1" + ], + "Inputs": [ + { + "InputName": "SubQuery", + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select min(col) from unsharded where 1 != 1", + "Query": "select min(col) from unsharded", + "Table": "unsharded" + }, + { + "InputName": "Outer", + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select max(:__sq1), weight_string(:__sq1) from `user` where 1 != 1 group by weight_string(:__sq1)", + "Query": "select max(:__sq1), weight_string(:__sq1) from `user` where id = 1 group by weight_string(:__sq1)", + "Table": "`user`", + "Values": [ + "1" + ], + "Vindex": "user_index" + } + ] + } + ] + }, + "TablesUsed": [ + "main.unsharded", + "user.user" + ] + } + }, + { + "comment": "select max((select min(col) from user where id = 1)) from user where id = 2", + "query": "select max((select min(col) from user where id = 1)) from user where id = 2", + "plan": { + "QueryType": "SELECT", + "Original": "select max((select min(col) from user where id = 1)) from user where id = 2", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "max(0|1) AS max((select min(col) from `user` where id = 1))", + "ResultColumns": 1, + "Inputs": [ + { 
+ "OperatorType": "UncorrelatedSubquery", + "Variant": "PulloutValue", + "PulloutVars": [ + "__sq1" + ], + "Inputs": [ + { + "InputName": "SubQuery", + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select min(col) from `user` where 1 != 1", + "Query": "select min(col) from `user` where id = 1", + "Table": "`user`", + "Values": [ + "1" + ], + "Vindex": "user_index" + }, + { + "InputName": "Outer", + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select max(:__sq1), weight_string(:__sq1) from `user` where 1 != 1 group by weight_string(:__sq1)", + "Query": "select max(:__sq1), weight_string(:__sq1) from `user` where id = 2 group by weight_string(:__sq1)", + "Table": "`user`", + "Values": [ + "2" + ], + "Vindex": "user_index" + } + ] + } + ] + }, + "TablesUsed": [ + "user.user" + ] + } + }, + { + "comment": "select max((select group_concat(col1, col2) from user where id = 1))", + "query": "select max((select group_concat(col1, col2) from user where id = 1))", + "plan": { + "QueryType": "SELECT", + "Original": "select max((select group_concat(col1, col2) from user where id = 1))", + "Instructions": { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select max((select group_concat(col1, col2) from `user` where 1 != 1)) from dual where 1 != 1", + "Query": "select max((select group_concat(col1, col2) from `user` where id = 1)) from dual", + "Table": "dual", + "Values": [ + "1" + ], + "Vindex": "user_index" + }, + "TablesUsed": [ + "main.dual", + "user.user" + ] + } + }, + { + "comment": "select max((select group_concat(col1, col2) from user where id = 1)) from user where id = 1", + "query": "select max((select group_concat(col1, col2) from user where id = 1)) from user where id = 1", + "plan": { + "QueryType": "SELECT", + "Original": 
"select max((select group_concat(col1, col2) from user where id = 1)) from user where id = 1", + "Instructions": { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select max((select group_concat(col1, col2) from `user` where 1 != 1)) from `user` where 1 != 1", + "Query": "select max((select group_concat(col1, col2) from `user` where id = 1)) from `user` where id = 1", + "Table": "`user`", + "Values": [ + "1" + ], + "Vindex": "user_index" + }, + "TablesUsed": [ + "user.user" + ] + } + }, + { + "comment": "select max((select group_concat(col1, col2) from user where id = 1)) from user", + "query": "select max((select group_concat(col1, col2) from user where id = 1)) from user", + "plan": { + "QueryType": "SELECT", + "Original": "select max((select group_concat(col1, col2) from user where id = 1)) from user", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "max(0|1) AS max((select group_concat(col1, col2) from `user` where id = 1))", + "ResultColumns": 1, + "Inputs": [ + { + "OperatorType": "UncorrelatedSubquery", + "Variant": "PulloutValue", + "PulloutVars": [ + "__sq1" + ], + "Inputs": [ + { + "InputName": "SubQuery", + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select group_concat(col1, col2) from `user` where 1 != 1", + "Query": "select group_concat(col1, col2) from `user` where id = 1", + "Table": "`user`", + "Values": [ + "1" + ], + "Vindex": "user_index" + }, + { + "InputName": "Outer", + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select max(:__sq1), weight_string(:__sq1) from `user` where 1 != 1 group by weight_string(:__sq1)", + "Query": "select max(:__sq1), weight_string(:__sq1) from `user` group by weight_string(:__sq1)", + "Table": "`user`" + } + ] + } + ] + }, + 
"TablesUsed": [ + "user.user" + ] + } + }, + { + "comment": "select max((select max(col2) from user u1 where u1.id = u2.id)) from user u2", + "query": "select max((select max(col2) from user u1 where u1.id = u2.id)) from user u2", + "plan": { + "QueryType": "SELECT", + "Original": "select max((select max(col2) from user u1 where u1.id = u2.id)) from user u2", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "max(0|1) AS max((select max(col2) from `user` as u1 where u1.id = u2.id))", + "ResultColumns": 1, + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select max((select max(col2) from `user` as u1 where 1 != 1)), weight_string((select max(col2) from `user` as u1 where 1 != 1)) from `user` as u2 where 1 != 1 group by weight_string((select max(col2) from `user` as u1 where 1 != 1))", + "Query": "select max((select max(col2) from `user` as u1 where u1.id = u2.id)), weight_string((select max(col2) from `user` as u1 where u1.id = u2.id)) from `user` as u2 group by weight_string((select max(col2) from `user` as u1 where u1.id = u2.id))", + "Table": "`user`" + } + ] + }, + "TablesUsed": [ + "user.user" + ] + } } ] diff --git a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json index b1c1c45001c..e93523710dd 100644 --- a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json @@ -413,5 +413,10 @@ "comment": "reference table delete with join", "query": "delete r from user u join ref_with_source r on u.col = r.col", "plan": "VT12001: unsupported: DELETE on reference table with join" + }, + { + "comment": "group_concat with multiple arguments", + "query": "select group_concat(col1, col2) from user", + "plan": "VT03001: aggregate functions take a single argument
'group_concat(col1, col2)'" } ]