diff --git a/go/test/endtoend/vtgate/queries/subquery/subquery_test.go b/go/test/endtoend/vtgate/queries/subquery/subquery_test.go index 72b886d1ddb..1bc318c0ff5 100644 --- a/go/test/endtoend/vtgate/queries/subquery/subquery_test.go +++ b/go/test/endtoend/vtgate/queries/subquery/subquery_test.go @@ -160,3 +160,19 @@ func TestSubqueryInReference(t *testing.T) { mcmp.AssertMatches(`select (select id1 from t1 where id2 = 30)`, `[[INT64(3)]]`) mcmp.AssertMatches(`select (select id1 from t1 where id2 = 9)`, `[[NULL]]`) } + +// TestSubqueryInAggregation validates that subquery work inside aggregation functions. +func TestSubqueryInAggregation(t *testing.T) { + mcmp, closer := start(t) + defer closer() + + mcmp.Exec("insert into t1(id1, id2) values(0,0),(1,1)") + mcmp.Exec("insert into t2(id3, id4) values(1,2),(5,7)") + mcmp.Exec(`SELECT max((select min(id2) from t1)) FROM t2`) + mcmp.Exec(`SELECT max((select group_concat(id1, id2) from t1 where id1 = 1)) FROM t1 where id1 = 1`) + mcmp.Exec(`SELECT max((select min(id2) from t1 where id2 = 1)) FROM dual`) + mcmp.Exec(`SELECT max((select min(id2) from t1)) FROM t2 where id4 = 7`) + + // This fails as the planner adds `weight_string` method which make the query fail on MySQL. + // mcmp.Exec(`SELECT max((select min(id2) from t1 where t1.id1 = t.id1)) FROM t1 t`) +} diff --git a/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go b/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go index 724ed50a9a8..657d8d129fc 100644 --- a/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go +++ b/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go @@ -107,6 +107,10 @@ func pushAggregationThroughSubquery( src.Outer = pushedAggr + for _, aggregation := range pushedAggr.Aggregations { + aggregation.Original.Expr = rewriteColNameToArgument(ctx, aggregation.Original.Expr, aggregation.SubQueryExpression, src.Inner...) + } + if !rootAggr.Original { return src, rewrite.NewTree("push Aggregation under subquery - keep original", rootAggr), nil } diff --git a/go/vt/vtgate/planbuilder/operators/aggregator.go b/go/vt/vtgate/planbuilder/operators/aggregator.go index bae1eca3007..9e38048d957 100644 --- a/go/vt/vtgate/planbuilder/operators/aggregator.go +++ b/go/vt/vtgate/planbuilder/operators/aggregator.go @@ -287,6 +287,10 @@ func (a *Aggregator) planOffsets(ctx *plancontext.PlanningContext) error { if a.offsetPlanned { return nil } + err := a.checkForInvalidAggregations() + if err != nil { + return err + } defer func() { a.offsetPlanned = true }() @@ -480,4 +484,27 @@ func (a *Aggregator) introducesTableID() semantics.TableSet { return a.DT.introducesTableID() } +// checkForInvalidAggregations validates that any aggregation functions evaluated at VTGate +// is supported with correct number for arguments. +func (a *Aggregator) checkForInvalidAggregations() error { + for _, aggr := range a.Aggregations { + err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { + aggrFunc, isAggregate := node.(sqlparser.AggrFunc) + if !isAggregate { + return true, nil + } + args := aggrFunc.GetArgs() + if args != nil && len(args) != 1 { + return false, vterrors.VT03001(sqlparser.String(node)) + } + return true, nil + + }, aggr.Original.Expr) + if err != nil { + return err + } + } + return nil +} + var _ ops.Operator = (*Aggregator)(nil) diff --git a/go/vt/vtgate/planbuilder/operators/horizon_expanding.go b/go/vt/vtgate/planbuilder/operators/horizon_expanding.go index 2714cb73ff1..84dd7aa3519 100644 --- a/go/vt/vtgate/planbuilder/operators/horizon_expanding.go +++ b/go/vt/vtgate/planbuilder/operators/horizon_expanding.go @@ -163,8 +163,9 @@ func createProjectionFromSelect(ctx *plancontext.PlanningContext, horizon *Horiz return nil, err } + src := horizon.src() a := &Aggregator{ - Source: horizon.src(), + Source: src, Original: true, QP: qp, Grouping: qp.GetGrouping(), @@ -172,6 +173,20 @@ func createProjectionFromSelect(ctx *plancontext.PlanningContext, horizon *Horiz DT: dt, } + sqc := &SubQueryBuilder{} + outerID := TableID(src) + for idx, aggr := range aggregations { + expr := aggr.Original.Expr + newExpr, subqs, err := sqc.pullOutValueSubqueries(ctx, expr, outerID, false) + if err != nil { + return nil, err + } + if newExpr != nil { + aggregations[idx].SubQueryExpression = subqs + } + } + a.Source = sqc.getRootOperator(src) + if complexAggr { return createProjectionForComplexAggregation(a, qp) } diff --git a/go/vt/vtgate/planbuilder/operators/queryprojection.go b/go/vt/vtgate/planbuilder/operators/queryprojection.go index 39bfe899223..2bcba2fc0a2 100644 --- a/go/vt/vtgate/planbuilder/operators/queryprojection.go +++ b/go/vt/vtgate/planbuilder/operators/queryprojection.go @@ -101,6 +101,8 @@ type ( // the offsets point to columns on the same aggregator ColOffset int WSOffset int + + SubQueryExpression []*SubQuery } AggrRewriter struct { @@ -287,10 +289,6 @@ func (qp *QueryProjection) addSelectExpressions(sel *sqlparser.Select) error { for _, selExp := range sel.SelectExprs { switch selExp := selExp.(type) { case *sqlparser.AliasedExpr: - err := checkForInvalidAggregations(selExp) - if err != nil { - return err - } col := SelectExpr{ Col: selExp, } @@ -464,21 +462,6 @@ func (qp *QueryProjection) GetGrouping() []GroupBy { return slices.Clone(qp.groupByExprs) } -func checkForInvalidAggregations(exp *sqlparser.AliasedExpr) error { - return sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { - aggrFunc, isAggregate := node.(sqlparser.AggrFunc) - if !isAggregate { - return true, nil - } - args := aggrFunc.GetArgs() - if args != nil && len(args) != 1 { - return false, vterrors.VT03001(sqlparser.String(node)) - } - return true, nil - - }, exp.Expr) -} - func (qp *QueryProjection) isExprInGroupByExprs(ctx *plancontext.PlanningContext, expr sqlparser.Expr) bool { for _, groupByExpr := range qp.groupByExprs { if ctx.SemTable.EqualsExprWithDeps(groupByExpr.SimplifiedExpr, expr) { diff --git a/go/vt/vtgate/planbuilder/operators/subquery_planning.go b/go/vt/vtgate/planbuilder/operators/subquery_planning.go index 44bce0e0f2e..2746c2e75e4 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery_planning.go +++ b/go/vt/vtgate/planbuilder/operators/subquery_planning.go @@ -98,6 +98,13 @@ func settleSubqueries(ctx *plancontext.PlanningContext, op ops.Operator) (ops.Op for _, setExpr := range op.Assignments { mergeSubqueryExpr(ctx, setExpr.Expr) } + case *Aggregator: + for _, aggr := range op.Aggregations { + newExpr, rewritten := rewriteMergedSubqueryExpr(ctx, aggr.SubQueryExpression, aggr.Original.Expr) + if rewritten { + aggr.Original.Expr = newExpr + } + } } return op, rewrite.SameTree, nil } diff --git a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json index d99763bd267..70f50e1f3f7 100644 --- a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json @@ -6088,5 +6088,288 @@ "user.user" ] } + }, + { + "comment": "select max((select min(col) from user where id = 1))", + "query": "select max((select min(col) from user where id = 1))", + "plan": { + "QueryType": "SELECT", + "Original": "select max((select min(col) from user where id = 1))", + "Instructions": { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select max((select min(col) from `user` where 1 != 1)) from dual where 1 != 1", + "Query": "select max((select min(col) from `user` where id = 1)) from dual", + "Table": "dual", + "Values": [ + "INT64(1)" + ], + "Vindex": "user_index" + }, + "TablesUsed": [ + "main.dual", + "user.user" + ] + } + }, + { + "comment": "select max((select min(col) from unsharded)) from user where id = 1", + "query": "select max((select min(col) from unsharded)) from user where id = 1", + "plan": { + "QueryType": "SELECT", + "Original": "select max((select min(col) from unsharded)) from user where id = 1", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "max(0|1) AS max((select min(col) from unsharded))", + "ResultColumns": 1, + "Inputs": [ + { + "OperatorType": "UncorrelatedSubquery", + "Variant": "PulloutValue", + "PulloutVars": [ + "__sq1" + ], + "Inputs": [ + { + "InputName": "SubQuery", + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select min(col) from unsharded where 1 != 1", + "Query": "select min(col) from unsharded", + "Table": "unsharded" + }, + { + "InputName": "Outer", + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select max(:__sq1), weight_string(:__sq1) from `user` where 1 != 1 group by weight_string(:__sq1)", + "Query": "select max(:__sq1), weight_string(:__sq1) from `user` where id = 1 group by weight_string(:__sq1)", + "Table": "`user`", + "Values": [ + "INT64(1)" + ], + "Vindex": "user_index" + } + ] + } + ] + }, + "TablesUsed": [ + "main.unsharded", + "user.user" + ] + } + }, + { + "comment": "select max((select min(col) from user where id = 1)) from user where id = 2", + "query": "select max((select min(col) from user where id = 1)) from user where id = 2", + "plan": { + "QueryType": "SELECT", + "Original": "select max((select min(col) from user where id = 1)) from user where id = 2", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "max(0|1) AS max((select min(col) from `user` where id = 1))", + "ResultColumns": 1, + "Inputs": [ + { + "OperatorType": "UncorrelatedSubquery", + "Variant": "PulloutValue", + "PulloutVars": [ + "__sq1" + ], + "Inputs": [ + { + "InputName": "SubQuery", + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select min(col) from `user` where 1 != 1", + "Query": "select min(col) from `user` where id = 1", + "Table": "`user`", + "Values": [ + "INT64(1)" + ], + "Vindex": "user_index" + }, + { + "InputName": "Outer", + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select max(:__sq1), weight_string(:__sq1) from `user` where 1 != 1 group by weight_string(:__sq1)", + "Query": "select max(:__sq1), weight_string(:__sq1) from `user` where id = 2 group by weight_string(:__sq1)", + "Table": "`user`", + "Values": [ + "INT64(2)" + ], + "Vindex": "user_index" + } + ] + } + ] + }, + "TablesUsed": [ + "user.user" + ] + } + }, + { + "comment": "select max((select group_concat(col1, col2) from user where id = 1))", + "query": "select max((select group_concat(col1, col2) from user where id = 1))", + "plan": { + "QueryType": "SELECT", + "Original": "select max((select group_concat(col1, col2) from user where id = 1))", + "Instructions": { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select max((select group_concat(col1, col2) from `user` where 1 != 1)) from dual where 1 != 1", + "Query": "select max((select group_concat(col1, col2) from `user` where id = 1)) from dual", + "Table": "dual", + "Values": [ + "INT64(1)" + ], + "Vindex": "user_index" + }, + "TablesUsed": [ + "main.dual", + "user.user" + ] + } + }, + { + "comment": "select max((select group_concat(col1, col2) from user where id = 1)) from user where id = 1", + "query": "select max((select group_concat(col1, col2) from user where id = 1)) from user where id = 1", + "plan": { + "QueryType": "SELECT", + "Original": "select max((select group_concat(col1, col2) from user where id = 1)) from user where id = 1", + "Instructions": { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select max((select group_concat(col1, col2) from `user` where 1 != 1)) from `user` where 1 != 1", + "Query": "select max((select group_concat(col1, col2) from `user` where id = 1)) from `user` where id = 1", + "Table": "`user`", + "Values": [ + "INT64(1)" + ], + "Vindex": "user_index" + }, + "TablesUsed": [ + "user.user" + ] + } + }, + { + "comment": "select max((select group_concat(col1, col2) from user where id = 1)) from user", + "query": "select max((select group_concat(col1, col2) from user where id = 1)) from user", + "plan": { + "QueryType": "SELECT", + "Original": "select max((select group_concat(col1, col2) from user where id = 1)) from user", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "max(0|1) AS max((select group_concat(col1, col2) from `user` where id = 1))", + "ResultColumns": 1, + "Inputs": [ + { + "OperatorType": "UncorrelatedSubquery", + "Variant": "PulloutValue", + "PulloutVars": [ + "__sq1" + ], + "Inputs": [ + { + "InputName": "SubQuery", + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select group_concat(col1, col2) from `user` where 1 != 1", + "Query": "select group_concat(col1, col2) from `user` where id = 1", + "Table": "`user`", + "Values": [ + "INT64(1)" + ], + "Vindex": "user_index" + }, + { + "InputName": "Outer", + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select max(:__sq1), weight_string(:__sq1) from `user` where 1 != 1 group by weight_string(:__sq1)", + "Query": "select max(:__sq1), weight_string(:__sq1) from `user` group by weight_string(:__sq1)", + "Table": "`user`" + } + ] + } + ] + }, + "TablesUsed": [ + "user.user" + ] + } + }, + { + "comment": "select max((select group_concat(col1, col2) from user where id = 1)) from user", + "query": "select max((select max(col2) from user u1 where u1.id = u2.id)) from user u2", + "plan": { + "QueryType": "SELECT", + "Original": "select max((select max(col2) from user u1 where u1.id = u2.id)) from user u2", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "max(0|1) AS max((select max(col2) from `user` as u1 where u1.id = u2.id))", + "ResultColumns": 1, + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select max((select max(col2) from `user` as u1 where 1 != 1)), weight_string((select max(col2) from `user` as u1 where 1 != 1)) from `user` as u2 where 1 != 1 group by weight_string((select max(col2) from `user` as u1 where 1 != 1))", + "Query": "select max((select max(col2) from `user` as u1 where u1.id = u2.id)), weight_string((select max(col2) from `user` as u1 where u1.id = u2.id)) from `user` as u2 group by weight_string((select max(col2) from `user` as u1 where u1.id = u2.id))", + "Table": "`user`" + } + ] + }, + "TablesUsed": [ + "user.user" + ] + } } ] diff --git a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json index e3d2f0be11b..a4b6576cbff 100644 --- a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json @@ -518,5 +518,10 @@ "comment": "select (select 1 from user u having count(ue.col) > 10) from user_extra ue", "query": "select (select 1 from user u having count(ue.col) > 10) from user_extra ue", "plan": "VT12001: unsupported: correlated subquery is only supported for EXISTS" + }, + { + "comment": "select max((select group_concat(col1, col2) from user where id = 1)) from user", + "query": "select group_concat(col1, col2) from user", + "plan": "VT03001: aggregate functions take a single argument 'group_concat(col1, col2)'" } ]