Skip to content

Commit

Permalink
[BugFix] Fix mv non rollup rewrite result contains aggregate functions (
Browse files Browse the repository at this point in the history
#53218)

Signed-off-by: shuming.li <[email protected]>
  • Loading branch information
LiShuMing authored Nov 27, 2024
1 parent 3e89c4e commit 00fafdb
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,7 @@ protected OptExpression viewBasedRewrite(RewriteContext rewriteContext, OptExpre

// Cannot ROLLUP distinct
if (isRollup) {
boolean mvHasDistinctAggFunc =
mvAggOp.getAggregations().values().stream().anyMatch(callOp -> callOp.isDistinct()
boolean mvHasDistinctAggFunc = mvAggOp.getAggregations().values().stream().anyMatch(callOp -> callOp.isDistinct()
&& !callOp.getFnName().equalsIgnoreCase(FunctionSet.ARRAY_AGG));
boolean queryHasDistinctAggFunc =
queryAggOp.getAggregations().values().stream().anyMatch(callOp -> callOp.isDistinct());
Expand All @@ -143,19 +142,48 @@ protected OptExpression viewBasedRewrite(RewriteContext rewriteContext, OptExpre

// TODO:duplicate if mv has already outputted.
// mvOptExpr = duplicateMvOptExpression(rewriteContext, mvOptExpr, queryExprToMvExprRewriter);

if (isRollup) {
return rewriteForRollup(queryAggOp, queryGroupingKeys, columnRewriter, queryExprToMvExprRewriter,
rewriteContext, mvOptExpr);
} else {
return rewriteProjection(rewriteContext, queryAggOp, queryExprToMvExprRewriter, mvOptExpr);
Pair<OptExpression, Boolean> result =
rewriteProjection(rewriteContext, queryAggOp, queryExprToMvExprRewriter, mvOptExpr);
// even if query and mv's group-by keys are the same, it may still need rollup
// eg:
// example1:
// mv : select dt from t group by dt
// query : select count(dt) from t where dt='2024-11-27';
// rewritten : select count(dt) from mv where dt='2024-11-27'
// example2:
// mv : select dt, avg_union(avg_state(c1)) as s from t group by dt
// query : select dt, avg(c1) from t group by dt
// rewritten : select dt, avg_merge(s) from mv group by dt
if (result.first != null && !result.second) {
return result.first;
} else if (result.second) {
return rewriteForRollup(queryAggOp, queryGroupingKeys, columnRewriter, queryExprToMvExprRewriter,
rewriteContext, mvOptExpr);
} else {
return null;
}
}
}

/**
 * Tells whether a rewritten expression is, at its top level, an aggregate function call.
 *
 * <p>This method does not itself walk the operator's children; it only looks at the
 * operator that is passed in.
 *
 * @param rewritten the rewritten scalar operator; may be {@code null}
 * @return {@code true} only when {@code rewritten} is a {@code CallOperator} whose
 *         {@code isAggregate()} reports {@code true}; {@code false} for {@code null}
 *         or for any other operator type
 */
private boolean isAggregate(ScalarOperator rewritten) {
    // `instanceof` evaluates to false for null, so a failed (null) rewrite is reported as non-aggregate.
    return rewritten instanceof CallOperator && ((CallOperator) rewritten).isAggregate();
}

protected OptExpression rewriteProjection(RewriteContext rewriteContext,
LogicalAggregationOperator queryAggregationOperator,
EquationRewriter queryExprToMvExprRewriter,
OptExpression mvOptExpr) {
/**
* If rewritten aggregate expr contains aggregate functions, we can still try rollup rewrite again.
*/
protected Pair<OptExpression, Boolean> rewriteProjection(RewriteContext rewriteContext,
LogicalAggregationOperator queryAggregationOperator,
EquationRewriter queryExprToMvExprRewriter,
OptExpression mvOptExpr) {
Map<ColumnRefOperator, ScalarOperator> queryMap = MvUtils.getColumnRefMap(
rewriteContext.getQueryExpression(), rewriteContext.getQueryRefFactory());
ColumnRewriter columnRewriter = new ColumnRewriter(rewriteContext);
Expand All @@ -180,11 +208,13 @@ protected OptExpression rewriteProjection(RewriteContext rewriteContext,
ScalarOperator rewritten = rewriteScalarOperator(rewriteContext, scalarOp,
queryExprToMvExprRewriter, rewriteContext.getOutputMapping(),
originalColumnSet, aggregateFunctionRewriter);
if (rewritten == null) {
// for non-rollup rewrite, the rewritten result should not contain aggregate functions.
boolean isAggregate = isAggregate(rewritten);
if (rewritten == null || isAggregate) {
OptimizerTraceUtil.logMVRewriteFailReason(mvRewriteContext.getMVName(),
"Rewrite projection with aggregate group-by/agg expr " +
"failed: {}", scalarOp.toString());
return null;
return Pair.create(null, isAggregate);
}
newQueryProjection.put(entry.getKey(), rewritten);
}
Expand All @@ -202,10 +232,12 @@ protected OptExpression rewriteProjection(RewriteContext rewriteContext,
ScalarOperator rewritten = rewriteScalarOperator(rewriteContext, swapped,
queryExprToMvExprRewriter, rewriteContext.getOutputMapping(),
originalColumnSet, aggregateFunctionRewriter);
if (rewritten == null) {
// for non-rollup rewrite, the rewritten result should not contain aggregate functions.
boolean isAggregate = isAggregate(rewritten);
if (rewritten == null || isAggregate) {
OptimizerTraceUtil.logMVRewriteFailReason(mvRewriteContext.getMVName(),
"Rewrite aggregate with having expr failed: {}", scalarOp.toString());
return null;
return Pair.create(null, isAggregate);
}
queryColumnRefToScalarMap.put(entry.getKey(), rewritten);
}
Expand All @@ -215,10 +247,11 @@ protected OptExpression rewriteProjection(RewriteContext rewriteContext,
ScalarOperator rewritten = rewriteScalarOperator(rewriteContext, swapped,
queryExprToMvExprRewriter, rewriteContext.getOutputMapping(),
originalColumnSet, aggregateFunctionRewriter);
if (rewritten == null) {
boolean isAggregate = isAggregate(rewritten);
if (rewritten == null || isAggregate) {
OptimizerTraceUtil.logMVRewriteFailReason(mvRewriteContext.getMVName(),
"Mapping grouping key failed: {}", groupKey.toString());
return null;
return Pair.create(null, isAggregate);
}
queryColumnRefToScalarMap.put(groupKey, rewritten);
}
Expand All @@ -229,15 +262,15 @@ protected OptExpression rewriteProjection(RewriteContext rewriteContext,
OptimizerTraceUtil.logMVRewriteFailReason(mvRewriteContext.getMVName(),
"Rewrite aggregate with having failed, cannot compensate aggregate having predicates: {}",
queryAggregationOperator.getPredicate().toString());
return null;
return Pair.create(null, false);
}
Operator op = mvOptExpr.getOp().cast();
// take care original scan predicates and new having exprs
ScalarOperator newPredicate = Utils.compoundAnd(rewrittenPred, op.getPredicate());
mvOptExpr = addExtraPredicate(mvOptExpr, newPredicate);
}

return mvOptExpr;
return Pair.create(mvOptExpr, false);
}

private ScalarOperator rewriteScalarOperator(RewriteContext rewriteContext,
Expand Down Expand Up @@ -269,7 +302,10 @@ private ScalarOperator rewriteScalarOperator(RewriteContext rewriteContext,
// - all matched group by keys bit is less than mvGroupByKeys
// - if query contains one non-mv-existed agg, set it `rollup` and use `replaceExprWithTarget` to
// - check whether to rewrite later.
private boolean isRollupAggregate(List<ScalarOperator> mvGroupingKeys, List<ScalarOperator> queryGroupingKeys,
// NOTE: It's not safe to decide rollup from the group-by keys alone; a rollup may still be needed even when
// the query and the mv have the same group-by keys.
private boolean isRollupAggregate(List<ScalarOperator> mvGroupingKeys,
List<ScalarOperator> queryGroupingKeys,
ScalarOperator queryRangePredicate) {
MaterializedView mv = mvRewriteContext.getMaterializationContext().getMv();
if (mv.getRefreshScheme().isSync() && mv.getDefaultDistributionInfo() instanceof RandomDistributionInfo) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -553,12 +553,12 @@ public void testRewriteWithEliminateJoinsBasic1() {
" distributed by random" +
" as select sum(t1f) as total, t1a, t1b from test.test_all_type group by t1a, t1b;", () -> {
{
String query = "select t1.t1b, sum(t1b) as total from test.test_all_type t1 " +
String query = "select t1.t1b, sum(t1f) as total from test.test_all_type t1 " +
"join (select 'k1' as k1) t2 on t1.t1a=t2.k1 group by t1.t1b;";
sql(query).match("mv0")
.contains(" 1:Project\n" +
" | <slot 2> : 16: t1b\n" +
" | <slot 13> : sum(16: t1b)\n" +
" | <slot 13> : 15: total\n" +
" | \n" +
" 0:OlapScanNode\n" +
" TABLE: mv0\n" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5834,4 +5834,12 @@ public void testRangePredicateRewriteCase1() {
" PREAGGREGATION: ON\n" +
" PREDICATES: 20: lo_linenumber = 1, 21: lo_shipmode IN ('A', 'a')");
}

@Test
public void testAggregateToProjection() {
    // The mv only groups by lo_orderkey, while the query applies count(distinct ...) on the same column
    // with a point predicate; the query is expected to be rewritten against the mv.
    // Original author's note: if agg push down is open, cannot rewrite.
    final String mvDefinition = "select lo_orderkey from lineorder group by lo_orderkey";
    final String distinctCountQuery =
            "select count(distinct lo_orderkey) from lineorder where lo_orderkey = 1";
    testRewriteOK(mvDefinition, distinctCountQuery);
}
}

0 comments on commit 00fafdb

Please sign in to comment.