Skip to content

Commit

Permalink
[release-18.0] Subquery inside aggregration function (#14844) (#14845)
Browse files Browse the repository at this point in the history
Signed-off-by: Harshit Gangal <[email protected]>
Co-authored-by: Harshit Gangal <[email protected]>
  • Loading branch information
vitess-bot[bot] and harshit-gangal authored Dec 21, 2023
1 parent 1d4937c commit f39adab
Show file tree
Hide file tree
Showing 8 changed files with 360 additions and 20 deletions.
16 changes: 16 additions & 0 deletions go/test/endtoend/vtgate/queries/subquery/subquery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,19 @@ func TestSubqueryInReference(t *testing.T) {
mcmp.AssertMatches(`select (select id1 from t1 where id2 = 30)`, `[[INT64(3)]]`)
mcmp.AssertMatches(`select (select id1 from t1 where id2 = 9)`, `[[NULL]]`)
}

// TestSubqueryInAggregation validates that subquery work inside aggregation functions.
func TestSubqueryInAggregation(t *testing.T) {
mcmp, closer := start(t)
defer closer()

mcmp.Exec("insert into t1(id1, id2) values(0,0),(1,1)")
mcmp.Exec("insert into t2(id3, id4) values(1,2),(5,7)")
mcmp.Exec(`SELECT max((select min(id2) from t1)) FROM t2`)
mcmp.Exec(`SELECT max((select group_concat(id1, id2) from t1 where id1 = 1)) FROM t1 where id1 = 1`)
mcmp.Exec(`SELECT max((select min(id2) from t1 where id2 = 1)) FROM dual`)
mcmp.Exec(`SELECT max((select min(id2) from t1)) FROM t2 where id4 = 7`)

// This fails as the planner adds `weight_string` method which make the query fail on MySQL.
// mcmp.Exec(`SELECT max((select min(id2) from t1 where t1.id1 = t.id1)) FROM t1 t`)
}
4 changes: 4 additions & 0 deletions go/vt/vtgate/planbuilder/operators/aggregation_pushing.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ func pushAggregationThroughSubquery(

src.Outer = pushedAggr

for _, aggregation := range pushedAggr.Aggregations {
aggregation.Original.Expr = rewriteColNameToArgument(ctx, aggregation.Original.Expr, aggregation.SubQueryExpression, src.Inner...)
}

if !rootAggr.Original {
return src, rewrite.NewTree("push Aggregation under subquery - keep original", rootAggr), nil
}
Expand Down
27 changes: 27 additions & 0 deletions go/vt/vtgate/planbuilder/operators/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,10 @@ func (a *Aggregator) planOffsets(ctx *plancontext.PlanningContext) error {
if a.offsetPlanned {
return nil
}
err := a.checkForInvalidAggregations()
if err != nil {
return err
}
defer func() {
a.offsetPlanned = true
}()
Expand Down Expand Up @@ -480,4 +484,27 @@ func (a *Aggregator) introducesTableID() semantics.TableSet {
return a.DT.introducesTableID()
}

// checkForInvalidAggregations validates that any aggregation functions evaluated at VTGate
// is supported with correct number for arguments.
func (a *Aggregator) checkForInvalidAggregations() error {
for _, aggr := range a.Aggregations {
err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
aggrFunc, isAggregate := node.(sqlparser.AggrFunc)
if !isAggregate {
return true, nil
}
args := aggrFunc.GetArgs()
if args != nil && len(args) != 1 {
return false, vterrors.VT03001(sqlparser.String(node))
}
return true, nil

}, aggr.Original.Expr)
if err != nil {
return err
}
}
return nil
}

var _ ops.Operator = (*Aggregator)(nil)
17 changes: 16 additions & 1 deletion go/vt/vtgate/planbuilder/operators/horizon_expanding.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,15 +163,30 @@ func createProjectionFromSelect(ctx *plancontext.PlanningContext, horizon *Horiz
return nil, err
}

src := horizon.src()
a := &Aggregator{
Source: horizon.src(),
Source: src,
Original: true,
QP: qp,
Grouping: qp.GetGrouping(),
Aggregations: aggregations,
DT: dt,
}

sqc := &SubQueryBuilder{}
outerID := TableID(src)
for idx, aggr := range aggregations {
expr := aggr.Original.Expr
newExpr, subqs, err := sqc.pullOutValueSubqueries(ctx, expr, outerID, false)
if err != nil {
return nil, err
}
if newExpr != nil {
aggregations[idx].SubQueryExpression = subqs
}
}
a.Source = sqc.getRootOperator(src)

if complexAggr {
return createProjectionForComplexAggregation(a, qp)
}
Expand Down
21 changes: 2 additions & 19 deletions go/vt/vtgate/planbuilder/operators/queryprojection.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ type (
// the offsets point to columns on the same aggregator
ColOffset int
WSOffset int

SubQueryExpression []*SubQuery
}

AggrRewriter struct {
Expand Down Expand Up @@ -287,10 +289,6 @@ func (qp *QueryProjection) addSelectExpressions(sel *sqlparser.Select) error {
for _, selExp := range sel.SelectExprs {
switch selExp := selExp.(type) {
case *sqlparser.AliasedExpr:
err := checkForInvalidAggregations(selExp)
if err != nil {
return err
}
col := SelectExpr{
Col: selExp,
}
Expand Down Expand Up @@ -464,21 +462,6 @@ func (qp *QueryProjection) GetGrouping() []GroupBy {
return slices.Clone(qp.groupByExprs)
}

func checkForInvalidAggregations(exp *sqlparser.AliasedExpr) error {
return sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
aggrFunc, isAggregate := node.(sqlparser.AggrFunc)
if !isAggregate {
return true, nil
}
args := aggrFunc.GetArgs()
if args != nil && len(args) != 1 {
return false, vterrors.VT03001(sqlparser.String(node))
}
return true, nil

}, exp.Expr)
}

func (qp *QueryProjection) isExprInGroupByExprs(ctx *plancontext.PlanningContext, expr sqlparser.Expr) bool {
for _, groupByExpr := range qp.groupByExprs {
if ctx.SemTable.EqualsExprWithDeps(groupByExpr.SimplifiedExpr, expr) {
Expand Down
7 changes: 7 additions & 0 deletions go/vt/vtgate/planbuilder/operators/subquery_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,13 @@ func settleSubqueries(ctx *plancontext.PlanningContext, op ops.Operator) (ops.Op
for _, setExpr := range op.Assignments {
mergeSubqueryExpr(ctx, setExpr.Expr)
}
case *Aggregator:
for _, aggr := range op.Aggregations {
newExpr, rewritten := rewriteMergedSubqueryExpr(ctx, aggr.SubQueryExpression, aggr.Original.Expr)
if rewritten {
aggr.Original.Expr = newExpr
}
}
}
return op, rewrite.SameTree, nil
}
Expand Down
Loading

0 comments on commit f39adab

Please sign in to comment.