Skip to content

Commit

Permalink
[Enhancement] (MV For TimeSeries Scene: Part 2) Support date_trunc wi…
Browse files Browse the repository at this point in the history
…th day/month/year rollup (#51451)

Signed-off-by: shuming.li <[email protected]>
  • Loading branch information
LiShuMing authored Oct 10, 2024
1 parent 716434a commit 219b62c
Show file tree
Hide file tree
Showing 11 changed files with 1,010 additions and 350 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public class TimeUnitUtils {
public static final String SECOND = "second";
public static final String MINUTE = "minute";
public static final String HOUR = "hour";
public static final String WEEK = "week";
public static final String DAY = "day";
public static final String MONTH = "month";
public static final String QUARTER = "quarter";
Expand All @@ -29,12 +30,25 @@ public class TimeUnitUtils {
// "week" can not exist in timeMap due "month" not sure contains week
public static final ImmutableMap<String, Integer> TIME_MAP =
new ImmutableMap.Builder<String, Integer>()
.put("second", 1)
.put("minute", 2)
.put("hour", 3)
.put("day", 4)
.put("month", 5)
.put("quarter", 6)
.put("year", 7)
.put(SECOND, 1)
.put(MINUTE, 2)
.put(HOUR, 3)
.put(DAY, 4)
.put(MONTH, 5)
.put(QUARTER, 6)
.put(YEAR, 7)
.build();

// all time units which date_trunc supported
public static final ImmutableMap<String, Integer> DATE_TRUNC_SUPPORTED_TIME_MAP =
new ImmutableMap.Builder<String, Integer>()
.put(SECOND, 1)
.put(MINUTE, 2)
.put(HOUR, 3)
.put(DAY, 4)
.put(WEEK, 5)
.put(MONTH, 6)
.put(QUARTER, 7)
.put(YEAR, 8)
.build();
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import com.starrocks.sql.optimizer.rewrite.ReplaceColumnRefRewriter;
import com.starrocks.sql.optimizer.rule.transformation.materialization.common.AggregateFunctionRollupUtils;
import com.starrocks.sql.optimizer.rule.transformation.materialization.equivalent.EquivalentShuttleContext;
import com.starrocks.sql.optimizer.rule.transformation.materialization.equivalent.IRewriteEquivalent;
import com.starrocks.sql.optimizer.rule.tree.pdagg.AggregatePushDownContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
Expand Down Expand Up @@ -127,7 +128,7 @@ protected OptExpression viewBasedRewrite(RewriteContext rewriteContext, OptExpre
boolean mvHasDistinctAggFunc =
mvAggOp.getAggregations().values().stream().anyMatch(callOp -> callOp.isDistinct()
&& !callOp.getFnName().equalsIgnoreCase(FunctionSet.ARRAY_AGG));
boolean queryHasDistinctAggFunc =
boolean queryHasDistinctAggFunc =
queryAggOp.getAggregations().values().stream().anyMatch(callOp -> callOp.isDistinct());
if (mvHasDistinctAggFunc && queryHasDistinctAggFunc) {
OptimizerTraceUtil.logMVRewriteFailReason(
Expand All @@ -137,6 +138,7 @@ protected OptExpression viewBasedRewrite(RewriteContext rewriteContext, OptExpre
return null;
}
}
rewriteContext.setRollup(isRollup);

// normalize mv's aggs by using query's table ref and query ec
Map<ColumnRefOperator, ScalarOperator> mvProjection =
Expand Down Expand Up @@ -181,7 +183,7 @@ protected OptExpression rewriteProjection(RewriteContext rewriteContext,
// rewrite group by + aggregate functions
for (Map.Entry<ColumnRefOperator, ScalarOperator> entry : swappedQueryColumnMap.entrySet()) {
ScalarOperator scalarOp = entry.getValue();
ScalarOperator rewritten = rewriteScalarOperator(scalarOp,
ScalarOperator rewritten = rewriteScalarOperator(rewriteContext, scalarOp,
queryExprToMvExprRewriter, rewriteContext.getOutputMapping(),
originalColumnSet, aggregateFunctionRewriter);
if (rewritten == null) {
Expand All @@ -203,7 +205,7 @@ protected OptExpression rewriteProjection(RewriteContext rewriteContext,
ScalarOperator scalarOp = entry.getValue();
ScalarOperator mapped = rewriteContext.getQueryColumnRefRewriter().rewrite(scalarOp.clone());
ScalarOperator swapped = columnRewriter.rewriteByQueryEc(mapped);
ScalarOperator rewritten = rewriteScalarOperator(swapped,
ScalarOperator rewritten = rewriteScalarOperator(rewriteContext, swapped,
queryExprToMvExprRewriter, rewriteContext.getOutputMapping(),
originalColumnSet, aggregateFunctionRewriter);
if (rewritten == null) {
Expand All @@ -216,7 +218,7 @@ protected OptExpression rewriteProjection(RewriteContext rewriteContext,
for (ColumnRefOperator groupKey : queryAggregationOperator.getGroupingKeys()) {
ScalarOperator mapped = rewriteContext.getQueryColumnRefRewriter().rewrite(groupKey.clone());
ScalarOperator swapped = columnRewriter.rewriteByQueryEc(mapped);
ScalarOperator rewritten = rewriteScalarOperator(swapped,
ScalarOperator rewritten = rewriteScalarOperator(rewriteContext, swapped,
queryExprToMvExprRewriter, rewriteContext.getOutputMapping(),
originalColumnSet, aggregateFunctionRewriter);
if (rewritten == null) {
Expand Down Expand Up @@ -244,7 +246,8 @@ protected OptExpression rewriteProjection(RewriteContext rewriteContext,
return mvOptExpr;
}

private ScalarOperator rewriteScalarOperator(ScalarOperator scalarOp,
private ScalarOperator rewriteScalarOperator(RewriteContext rewriteContext,
ScalarOperator scalarOp,
EquationRewriter equationRewriter,
Map<ColumnRefOperator, ColumnRefOperator> columnMapping,
ColumnRefSet originalColumnSet,
Expand All @@ -254,12 +257,15 @@ private ScalarOperator rewriteScalarOperator(ScalarOperator scalarOp,
}
equationRewriter.setAggregateFunctionRewriter(aggregateFunctionRewriter);
equationRewriter.setOutputMapping(columnMapping);
ScalarOperator rewritten = equationRewriter.replaceExprWithTarget(scalarOp);

Pair<ScalarOperator, EquivalentShuttleContext> result =
equationRewriter.replaceExprWithEquivalent(rewriteContext, scalarOp);
ScalarOperator rewritten = result.first;
if (rewritten == null || scalarOp == rewritten) {
return null;
}
if (!isAllExprReplaced(rewritten, originalColumnSet)) {
// it means there is some column that can not be rewritten by outputs of mv
// it means there is some column that cannot be rewritten by outputs of mv
return null;
}
return rewritten;
Expand Down Expand Up @@ -370,7 +376,7 @@ private OptExpression rewriteForRollup(
Map<ColumnRefOperator, ScalarOperator> queryColumnRefToScalarMap = Maps.newHashMap();

// rewrite group by keys by using mv
List<ScalarOperator> newQueryGroupKeys = rewriteGroupKeys(queryGroupingKeys, equationRewriter,
List<ScalarOperator> newQueryGroupKeys = rewriteGroupKeys(rewriteContext, queryGroupingKeys, equationRewriter,
rewriteContext.getOutputMapping(), new ColumnRefSet(rewriteContext.getQueryColumnSet()));
if (newQueryGroupKeys == null) {
OptimizerTraceUtil.logMVRewriteFailReason(mvRewriteContext.getMVName(),
Expand Down Expand Up @@ -608,14 +614,17 @@ private OptExpression createNewAggregate(
/**
* Rewrite group by keys by using MV.
*/
private List<ScalarOperator> rewriteGroupKeys(List<ScalarOperator> groupKeys,
private List<ScalarOperator> rewriteGroupKeys(RewriteContext rewriteContext,
List<ScalarOperator> groupKeys,
EquationRewriter equationRewriter,
Map<ColumnRefOperator, ColumnRefOperator> mapping,
ColumnRefSet queryColumnSet) {
List<ScalarOperator> newGroupByKeys = Lists.newArrayList();
equationRewriter.setOutputMapping(mapping);
for (ScalarOperator key : groupKeys) {
ScalarOperator newGroupByKey = equationRewriter.replaceExprWithTarget(key);
Pair<ScalarOperator, EquivalentShuttleContext> result = equationRewriter.replaceExprWithEquivalent(rewriteContext,
key, IRewriteEquivalent.RewriteEquivalentType.PREDICATE);
ScalarOperator newGroupByKey = result.first;
if (key.isVariable() && key == newGroupByKey) {
OptimizerTraceUtil.logMVRewriteFailReason(mvRewriteContext.getMVName(),
"Rewrite group by key failed: {}", key.toString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public class EquationRewriter {
boolean underAggFunctionRewriteContext;

private final EquivalentShuttle defaultShuttle = new EquivalentShuttle(new EquivalentShuttleContext(null,
false, true));
false, true, IRewriteEquivalent.RewriteEquivalentType.AGGREGATE));

public EquationRewriter() {
this.equationMap = ArrayListMultimap.create();
Expand Down Expand Up @@ -123,13 +123,27 @@ public ScalarOperator visitBinaryPredicate(BinaryPredicateOperator predicate, Vo

private ScalarOperator rewriteByEquivalent(ScalarOperator input,
IRewriteEquivalent.RewriteEquivalentType type) {
if (!shuttleContext.isUseEquivalent() || !rewriteEquivalents.containsKey(type)) {
if (!shuttleContext.isUseEquivalent()) {
return null;
}
for (RewriteEquivalent equivalent : rewriteEquivalents.get(type)) {
ScalarOperator replaced = equivalent.rewrite(shuttleContext, columnMapping, input);
if (replaced != null) {
return replaced;
if (type.isAny()) {
for (List<RewriteEquivalent> equivalents : rewriteEquivalents.values()) {
for (RewriteEquivalent equivalent : equivalents) {
ScalarOperator replaced = equivalent.rewrite(shuttleContext, columnMapping, input);
if (replaced != null) {
return replaced;
}
}
}
} else {
if (!rewriteEquivalents.containsKey(type)) {
return null;
}
for (RewriteEquivalent equivalent : rewriteEquivalents.get(type)) {
ScalarOperator replaced = equivalent.rewrite(shuttleContext, columnMapping, input);
if (replaced != null) {
return replaced;
}
}
}
return null;
Expand All @@ -144,7 +158,7 @@ public ScalarOperator visitCall(CallOperator call, Void context) {
}

// rewrite by equivalent
ScalarOperator rewritten = rewriteByEquivalent(call, IRewriteEquivalent.RewriteEquivalentType.AGGREGATE);
ScalarOperator rewritten = rewriteByEquivalent(call, shuttleContext.getRewriteEquivalentType());
if (rewritten != null) {
shuttleContext.setRewrittenByEquivalent(true);
return rewritten;
Expand Down Expand Up @@ -228,12 +242,36 @@ public ScalarOperator replaceExprWithTarget(ScalarOperator expr) {
*/
public Pair<ScalarOperator, EquivalentShuttleContext> replaceExprWithRollup(RewriteContext rewriteContext,
ScalarOperator expr) {
return replaceExprWithEquivalent(rewriteContext, expr, IRewriteEquivalent.RewriteEquivalentType.AGGREGATE);
}

/**
* Replace expr with equivalent shuttle with specific type.
* @param rewriteContext rewrite context
* @param expr input expr to be rewritten
* @param type equivalent type which is used for call operator rewrite to deduce rewriting strategy
* @return rewritten expr and equivalent shuttle context
*/
public Pair<ScalarOperator, EquivalentShuttleContext> replaceExprWithEquivalent(
RewriteContext rewriteContext,
ScalarOperator expr,
IRewriteEquivalent.RewriteEquivalentType type) {
boolean isRollup = rewriteContext.isRollup();
final EquivalentShuttleContext shuttleContext = new EquivalentShuttleContext(rewriteContext,
true, true);
isRollup, true, type);
final EquivalentShuttle shuttle = new EquivalentShuttle(shuttleContext);
return Pair.create(expr.accept(shuttle, null), shuttleContext);
}

/**
* Replace expr with equivalent shuttle, by default, we can rewrite call operator with any type of equivalent
* since call operator can be aggregate or predicate or group by keys.
*/
public Pair<ScalarOperator, EquivalentShuttleContext> replaceExprWithEquivalent(RewriteContext rewriteContext,
ScalarOperator expr) {
return replaceExprWithEquivalent(rewriteContext, expr, IRewriteEquivalent.RewriteEquivalentType.ANY);
}

public boolean containsKey(ScalarOperator scalarOperator) {
return equationMap.containsKey(scalarOperator);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,7 @@ public RangePredicate visit(ScalarOperator scalarOperator, PredicateExtractorCon
}

@Override
public RangePredicate visitBinaryPredicate(
BinaryPredicateOperator predicate, PredicateExtractorContext context) {
public RangePredicate visitBinaryPredicate(BinaryPredicateOperator predicate, PredicateExtractorContext context) {
RangePredicate rangePredicate = rewriteBinaryPredicate(predicate);
if (rangePredicate != null) {
return rangePredicate;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ public class RewriteContext {
private BiMap<Integer, Integer> queryToMvRelationIdMapping;
private ScalarOperator unionRewriteQueryExtraPredicate;
private AggregatePushDownContext aggregatePushDownContext;
// whether this rewritten query is a rollup query
private boolean isRollup;

public RewriteContext(OptExpression queryExpression,
PredicateSplit queryPredicateSplit,
Expand Down Expand Up @@ -171,4 +173,12 @@ public AggregatePushDownContext getAggregatePushDownContext() {
public void setAggregatePushDownContext(AggregatePushDownContext aggregatePushDownContext) {
this.aggregatePushDownContext = aggregatePushDownContext;
}

public boolean isRollup() {
return isRollup;
}

public void setRollup(boolean rollup) {
isRollup = rollup;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

import java.util.Set;

import static com.starrocks.sql.common.TimeUnitUtils.DATE_TRUNC_SUPPORTED_TIME_MAP;

public class DateTruncEquivalent extends IPredicateRewriteEquivalent {
public static final DateTruncEquivalent INSTANCE = new DateTruncEquivalent();

Expand All @@ -34,7 +36,7 @@ public DateTruncEquivalent() {}
* TODO: we can support this later.
* Change date_trunc('month', col) to col = '2023-12-01' will get a wrong result.
* MV : select date_trunc('day', col) as dt from t
* Query : select date_trunc('day, col) from t where date_trunc('month', col) = '2023-11-01'
* Query : select date_trunc('month, col) from t where date_trunc('month', col) = '2023-11-01'
*/
private static Set<BinaryType> SUPPORTED_BINARY_TYPES = ImmutableSet.of(
BinaryType.GE,
Expand Down Expand Up @@ -93,20 +95,49 @@ public ScalarOperator rewrite(RewriteEquivalentContext eqContext,
EquivalentShuttleContext shuttleContext,
ColumnRefOperator replace,
ScalarOperator newInput) {
if (!(newInput instanceof BinaryPredicateOperator)) {
return null;
}
ScalarOperator left = newInput.getChild(0);
ScalarOperator right = newInput.getChild(1);
if (newInput instanceof BinaryPredicateOperator) {
ScalarOperator left = newInput.getChild(0);
ScalarOperator right = newInput.getChild(1);

if (!right.isConstantRef() || !left.equals(eqContext.getEquivalent())) {
return null;
}
if (!isEquivalent(eqContext.getInput(), (ConstantOperator) right)) {
return null;
if (!right.isConstantRef() || !left.equals(eqContext.getEquivalent())) {
return null;
}
if (!isEquivalent(eqContext.getInput(), (ConstantOperator) right)) {
return null;
}
BinaryPredicateOperator predicate = (BinaryPredicateOperator) newInput.clone();
predicate.setChild(0, replace);
return predicate;
} else if (newInput instanceof CallOperator) {
// only in rollup aggregate, `date_trunc('day', dt) as dt` can be rewritten to `date_trunc('month', dt)`
if (!shuttleContext.isRollup()) {
return null;
}
CallOperator newCall = (CallOperator) newInput;
if (!checkDateTrucFunc(newCall)) {
return null;
}
CallOperator oldCall = (CallOperator) eqContext.getInput();
ConstantOperator oldChild0 = (ConstantOperator) oldCall.getChild(0);
// ensure col ref is the same in date_trunc
if (!newCall.getChild(1).equals(oldCall.getChild(1))) {
return null;
}
ConstantOperator newChild0 = (ConstantOperator) newCall.getChild(0);
if (!DATE_TRUNC_SUPPORTED_TIME_MAP.containsKey(oldChild0.getVarchar()) ||
!DATE_TRUNC_SUPPORTED_TIME_MAP.containsKey(newChild0.getVarchar())) {
// only can rewrite date_trunc('day', col) to date_trunc('month', col)
return null;
}
int oldTimeUnit = DATE_TRUNC_SUPPORTED_TIME_MAP.get(oldChild0.getVarchar());
int newTimeUnit = DATE_TRUNC_SUPPORTED_TIME_MAP.get(newChild0.getVarchar());
if (oldTimeUnit > newTimeUnit) {
return null;
}
CallOperator rewritten = (CallOperator) newCall.clone();
rewritten.setChild(1, replace);
return rewritten;
}
BinaryPredicateOperator predicate = (BinaryPredicateOperator) newInput.clone();
predicate.setChild(0, replace);
return predicate;
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@ public class EquivalentShuttleContext {
private final boolean isRollup;
private boolean isUseEquivalent;
private boolean isRewrittenByEquivalent;
private IRewriteEquivalent.RewriteEquivalentType rewriteEquivalentType;

public EquivalentShuttleContext(RewriteContext rewriteContext, boolean isRollup, boolean isRewrittenByEquivalent) {
public EquivalentShuttleContext(RewriteContext rewriteContext, boolean isRollup, boolean isRewrittenByEquivalent,
IRewriteEquivalent.RewriteEquivalentType type) {
this.rewriteContext = rewriteContext;
this.isRollup = isRollup;
this.isUseEquivalent = isRewrittenByEquivalent;
this.rewriteEquivalentType = type;
}

public boolean isUseEquivalent() {
Expand All @@ -47,4 +50,8 @@ public void setRewrittenByEquivalent(boolean rewrittenByEquivalent) {
public RewriteContext getRewriteContext() {
return rewriteContext;
}

public IRewriteEquivalent.RewriteEquivalentType getRewriteEquivalentType() {
return rewriteEquivalentType;
}
}
Loading

0 comments on commit 219b62c

Please sign in to comment.