diff --git a/fe/fe-core/src/main/java/com/starrocks/analysis/JoinOperator.java b/fe/fe-core/src/main/java/com/starrocks/analysis/JoinOperator.java index 2cf8d2b43227e..53a5c0abb0757 100644 --- a/fe/fe-core/src/main/java/com/starrocks/analysis/JoinOperator.java +++ b/fe/fe-core/src/main/java/com/starrocks/analysis/JoinOperator.java @@ -40,20 +40,19 @@ import java.util.Set; public enum JoinOperator { - INNER_JOIN("INNER JOIN", TJoinOp.INNER_JOIN), - LEFT_OUTER_JOIN("LEFT OUTER JOIN", TJoinOp.LEFT_OUTER_JOIN), - - LEFT_SEMI_JOIN("LEFT SEMI JOIN", TJoinOp.LEFT_SEMI_JOIN), - LEFT_ANTI_JOIN("LEFT ANTI JOIN", TJoinOp.LEFT_ANTI_JOIN), - RIGHT_SEMI_JOIN("RIGHT SEMI JOIN", TJoinOp.RIGHT_SEMI_JOIN), - RIGHT_ANTI_JOIN("RIGHT ANTI JOIN", TJoinOp.RIGHT_ANTI_JOIN), - RIGHT_OUTER_JOIN("RIGHT OUTER JOIN", TJoinOp.RIGHT_OUTER_JOIN), - FULL_OUTER_JOIN("FULL OUTER JOIN", TJoinOp.FULL_OUTER_JOIN), - CROSS_JOIN("CROSS JOIN", TJoinOp.CROSS_JOIN), - // Variant of the LEFT ANTI JOIN that is used for the equal of + INNER_JOIN("INNER JOIN", "⋈", TJoinOp.INNER_JOIN), + LEFT_OUTER_JOIN("LEFT OUTER JOIN", "⟕", TJoinOp.LEFT_OUTER_JOIN), + + LEFT_SEMI_JOIN("LEFT SEMI JOIN", "⋉", TJoinOp.LEFT_SEMI_JOIN), + LEFT_ANTI_JOIN("LEFT ANTI JOIN", "◁", TJoinOp.LEFT_ANTI_JOIN), + RIGHT_SEMI_JOIN("RIGHT SEMI JOIN", "⋊", TJoinOp.RIGHT_SEMI_JOIN), + RIGHT_ANTI_JOIN("RIGHT ANTI JOIN", "▷", TJoinOp.RIGHT_ANTI_JOIN), + RIGHT_OUTER_JOIN("RIGHT OUTER JOIN", "⟖", TJoinOp.RIGHT_OUTER_JOIN), + FULL_OUTER_JOIN("FULL OUTER JOIN", "⟗", TJoinOp.FULL_OUTER_JOIN), + CROSS_JOIN("CROSS JOIN", "×", TJoinOp.CROSS_JOIN), // Variant of the LEFT ANTI JOIN that is used for the equal of // NOT IN subqueries. It can have a single equality join conjunct // that returns TRUE when the rhs is NULL. - NULL_AWARE_LEFT_ANTI_JOIN("NULL AWARE LEFT ANTI JOIN", + NULL_AWARE_LEFT_ANTI_JOIN("NULL AWARE LEFT ANTI JOIN", "▷*", TJoinOp.NULL_AWARE_LEFT_ANTI_JOIN); public static final String HINT_BUCKET = "BUCKET"; @@ -65,10 +64,12 @@ public enum JoinOperator { public static final String HINT_UNREORDER = "UNREORDER"; private final String description; + private final String algebra; private final TJoinOp thriftJoinOp; - private JoinOperator(String description, TJoinOp thriftJoinOp) { + JoinOperator(String description, String algebra, TJoinOp thriftJoinOp) { this.description = description; + this.algebra = algebra; this.thriftJoinOp = thriftJoinOp; } @@ -77,6 +78,10 @@ public String toString() { return description; } + public String toAlgebra() { + return algebra; + } + public TJoinOp toThrift() { return thriftJoinOp; } diff --git a/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java b/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java index 7e447f85fbd3c..da4cface02e01 100644 --- a/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java +++ b/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java @@ -376,6 +376,8 @@ public class SessionVariable implements Serializable, Writable, Cloneable { public static final String CBO_ENABLE_PARALLEL_PREPARE_METADATA = "enable_parallel_prepare_metadata"; + public static final String CBO_EXTRACT_COMMON_PLAN = "cbo_extract_common_plan"; + public static final String SKEW_JOIN_RAND_RANGE = "skew_join_rand_range"; public static final String ENABLE_STATS_TO_OPTIMIZE_SKEW_JOIN = "enable_stats_to_optimize_skew_join"; public static final String SKEW_JOIN_OPTIMIZE_USE_MCV_COUNT = "skew_join_use_mcv_count"; @@ -1279,6 +1281,9 @@ public static MaterializedViewRewriteMode parse(String str) { @VariableMgr.VarAttr(name = ENABLE_GIN_FILTER) private boolean enableGinFilter = true; + @VariableMgr.VarAttr(name = CBO_EXTRACT_COMMON_PLAN) + private boolean cboExtractCommonPlan = true; + @VariableMgr.VarAttr(name = CBO_MAX_REORDER_NODE_USE_EXHAUSTIVE) private int cboMaxReorderNodeUseExhaustive = 4; @@ -3515,6 +3520,14 @@ public void setEnableMultiColumnsOnGlobbalRuntimeFilter(boolean value) { this.enableMultiColumnsOnGlobalRuntimeFilter = value; } + public boolean isCboExtractCommonPlan() { + return cboExtractCommonPlan; + } + + public void setCboExtractCommonPlan(boolean cboExtractCommonPlan) { + this.cboExtractCommonPlan = cboExtractCommonPlan; + } + public boolean isEnableQueryDebugTrace() { return enableQueryDebugTrace; } diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/common/DebugOperatorTracer.java b/fe/fe-core/src/main/java/com/starrocks/sql/common/DebugOperatorTracer.java index d7a189e3aa5e7..fcad9b4eb0be1 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/common/DebugOperatorTracer.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/common/DebugOperatorTracer.java @@ -273,7 +273,7 @@ public String visitLogicalTableFunction(LogicalTableFunctionOperator node, Void @Override public String visitLogicalLimit(LogicalLimitOperator node, Void context) { - return "LogicalLimitOperator" + " {limit=" + node.getLimit() + + return "LogicalLimitOperator {" + node.getPhase().name() + " limit=" + node.getLimit() + ", offset=" + node.getOffset() + "}"; } diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/MaterializedViewOptimizer.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/MaterializedViewOptimizer.java index c677763d6f6bc..86b0c0f8fcafa 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/MaterializedViewOptimizer.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/MaterializedViewOptimizer.java @@ -48,6 +48,7 @@ public MvPlanContext optimize(MaterializedView mv, optimizerConfig.disableRuleSet(RuleSetType.INTERSECT_REWRITE); optimizerConfig.disableRule(RuleType.TF_REWRITE_GROUP_BY_COUNT_DISTINCT); optimizerConfig.disableRule(RuleType.TF_PRUNE_EMPTY_SCAN); + optimizerConfig.disableRule(RuleType.TF_REUSE_FUSION_RULE); optimizerConfig.disableRule(RuleType.TF_MV_TEXT_MATCH_REWRITE_RULE); optimizerConfig.disableRule(RuleType.TF_MV_TRANSPARENT_REWRITE_RULE); // For sync mv, no rewrite query by original sync mv rule to avoid useless rewrite. diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/OptExpression.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/OptExpression.java index 63cd8548cf239..c876390dc2fb4 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/OptExpression.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/OptExpression.java @@ -264,11 +264,9 @@ private String debugString(String headlinePrefix, String detailPrefix, int limit StringBuilder sb = new StringBuilder(); sb.append(headlinePrefix).append(op.accept(new DebugOperatorTracer(), null)); limitLine -= 1; - sb.append('\n'); if (limitLine <= 0 || inputs.isEmpty()) { return sb.toString(); } - String childHeadlinePrefix = detailPrefix + "-> "; String childDetailPrefix = detailPrefix + " "; for (OptExpression input : inputs) { diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/Optimizer.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/Optimizer.java index bcef9feb0c158..2a0c1a479aa52 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/Optimizer.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/Optimizer.java @@ -39,6 +39,7 @@ import com.starrocks.sql.optimizer.rewrite.JoinPredicatePushdown; import com.starrocks.sql.optimizer.rule.Rule; import com.starrocks.sql.optimizer.rule.RuleSetType; +import com.starrocks.sql.optimizer.rule.RuleType; import com.starrocks.sql.optimizer.rule.implementation.OlapScanImplementationRule; import com.starrocks.sql.optimizer.rule.join.ReorderJoinRule; import com.starrocks.sql.optimizer.rule.mv.MaterializedViewRule; @@ -107,6 +108,7 @@ import com.starrocks.sql.optimizer.rule.tree.SimplifyCaseWhenPredicateRule; import com.starrocks.sql.optimizer.rule.tree.SubfieldExprNoCopyRule; import com.starrocks.sql.optimizer.rule.tree.lowcardinality.LowCardinalityRewriteRule; +import com.starrocks.sql.optimizer.rule.tree.pieces.ReuseFusionPlanRule; import com.starrocks.sql.optimizer.rule.tree.prunesubfield.PruneSubfieldRule; import com.starrocks.sql.optimizer.rule.tree.prunesubfield.PushDownSubfieldRule; import com.starrocks.sql.optimizer.task.OptimizeGroupTask; @@ -530,6 +532,7 @@ private OptExpression logicalRuleRewrite( ruleRewriteOnlyOnce(tree, rootTaskContext, RuleSetType.PRUNE_COLUMNS); ruleRewriteIterative(tree, rootTaskContext, RuleSetType.PRUNE_UKFK_JOIN); deriveLogicalProperty(tree); + tree = extractCommonCTE(tree, rootTaskContext, requiredColumns); ruleRewriteOnlyOnce(tree, rootTaskContext, new PushDownJoinOnExpressionToChildProject()); @@ -643,6 +646,8 @@ private OptExpression logicalRuleRewrite( ruleRewriteIterative(tree, rootTaskContext, new MergeProjectWithChildRule()); ruleRewriteOnlyOnce(tree, rootTaskContext, new PushDownTopNBelowOuterJoinRule()); + // intersect rewrite depend on statistics + Utils.calculateStatistics(tree, rootTaskContext.getOptimizerContext()); ruleRewriteOnlyOnce(tree, rootTaskContext, RuleSetType.INTERSECT_REWRITE); ruleRewriteIterative(tree, rootTaskContext, new RemoveAggregationFromAggTable()); @@ -673,6 +678,22 @@ private OptExpression logicalRuleRewrite( return tree.getInputs().get(0); } + private OptExpression extractCommonCTE(OptExpression tree, TaskContext rootTaskContext, + ColumnRefSet requiredColumns) { + if (!context.getSessionVariable().isCboExtractCommonPlan() || + optimizerConfig.isRuleDisable(RuleType.TF_REUSE_FUSION_RULE)) { + return tree; + } + ReuseFusionPlanRule fusion = new ReuseFusionPlanRule(); + tree = fusion.rewrite(tree, rootTaskContext); + if (fusion.hasRewrite()) { + deriveLogicalProperty(tree); + rootTaskContext.setRequiredColumns(requiredColumns.clone()); + ruleRewriteOnlyOnce(tree, rootTaskContext, RuleSetType.PRUNE_COLUMNS); + } + return tree; + } + private void rewriteGroupingSets(OptExpression tree, TaskContext rootTaskContext, SessionVariable sessionVariable) { if (sessionVariable.isEnableRewriteGroupingsetsToUnionAll()) { ruleRewriteIterative(tree, rootTaskContext, new RewriteGroupingSetsByCTERule()); diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/OperatorType.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/OperatorType.java index aafa1711e180e..3281bad36d993 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/OperatorType.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/OperatorType.java @@ -57,6 +57,7 @@ public enum OperatorType { LOGICAL_CTE_ANCHOR, LOGICAL_CTE_PRODUCE, LOGICAL_CTE_CONSUME, + LOGICAL_SPJG_PIECES, /** * Physical operator diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rewrite/ReplaceColumnRefRewriter.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rewrite/ReplaceColumnRefRewriter.java index c0fe1dc93f67d..5a3a1479d7085 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rewrite/ReplaceColumnRefRewriter.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rewrite/ReplaceColumnRefRewriter.java @@ -66,15 +66,18 @@ public ScalarOperator visitVariableReference(ColumnRefOperator column, Void cont // The rewritten predicate will be rewritten continually, // Rewiring predicate shouldn't change the origin project columnRefMap - ScalarOperator mapperOperator = operatorMap.get(column).clone(); - if (isRecursively) { - while (mapperOperator.getChildren().isEmpty() && operatorMap.containsKey(mapperOperator)) { + ScalarOperator mapperOperator = operatorMap.get(column); + if (!isRecursively) { + return mapperOperator.clone(); + } else { + while (mapperOperator instanceof ColumnRefOperator && operatorMap.containsKey(mapperOperator)) { ScalarOperator mapped = operatorMap.get(mapperOperator); if (mapped.equals(mapperOperator)) { break; } - mapperOperator = mapped.clone(); + mapperOperator = mapped; } + mapperOperator = mapperOperator.clone(); for (int i = 0; i < mapperOperator.getChildren().size(); ++i) { mapperOperator.setChild(i, mapperOperator.getChild(i).accept(this, null)); } diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleSet.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleSet.java index 8bbd11e9cc665..2889142e28f63 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleSet.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleSet.java @@ -256,6 +256,7 @@ public class RuleSet { MergeLimitDirectRule.EXCEPT, MergeLimitDirectRule.VALUES, MergeLimitDirectRule.FILTER, + MergeLimitDirectRule.CTE_CONSUMER, MergeLimitDirectRule.TABLE_FUNCTION, MergeLimitDirectRule.TABLE_FUNCTION_TABLE_SCAN )); diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleType.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleType.java index 5b6eead20c664..1b5f0da8b0cdd 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleType.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleType.java @@ -206,6 +206,8 @@ public enum RuleType { TF_VECTOR_REWRITE_RULE, + TF_REUSE_FUSION_RULE, + // The following are implementation rules: IMP_OLAP_LSCAN_TO_PSCAN, IMP_HIVE_LSCAN_TO_PSCAN, diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/MergeLimitDirectRule.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/MergeLimitDirectRule.java index 29afbe272d36b..0c39a64ab1d41 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/MergeLimitDirectRule.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/MergeLimitDirectRule.java @@ -54,6 +54,8 @@ public class MergeLimitDirectRule extends TransformationRule { new MergeLimitDirectRule(OperatorType.LOGICAL_TABLE_FUNCTION); public static final MergeLimitDirectRule TABLE_FUNCTION_TABLE_SCAN = new MergeLimitDirectRule(OperatorType.LOGICAL_TABLE_FUNCTION_TABLE_SCAN); + public static final MergeLimitDirectRule CTE_CONSUMER = + new MergeLimitDirectRule(OperatorType.LOGICAL_CTE_CONSUME); private MergeLimitDirectRule(OperatorType logicalOperatorType) { super(RuleType.TF_MERGE_LIMIT_DIRECT, Pattern.create(OperatorType.LOGICAL_LIMIT) diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/pieces/LogicalPiecesOperator.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/pieces/LogicalPiecesOperator.java new file mode 100644 index 0000000000000..c418b3b27faa3 --- /dev/null +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/pieces/LogicalPiecesOperator.java @@ -0,0 +1,40 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.starrocks.sql.optimizer.rule.tree.pieces; + +import com.starrocks.sql.optimizer.OptExpression; +import com.starrocks.sql.optimizer.operator.Operator; +import com.starrocks.sql.optimizer.operator.OperatorType; + +class LogicalPiecesOperator extends Operator { + public LogicalPiecesOperator(OperatorType type, OptExpression plan, QueryPiecesPlan piece) { + super(type); + this.plan = plan; + this.piece = piece; + this.predicate = plan.getOp().getPredicate(); + } + + private final OptExpression plan; + + private final QueryPiecesPlan piece; + + public OptExpression getPlan() { + return plan; + } + + public QueryPiecesPlan getPiece() { + return piece; + } +} diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/pieces/PiecesPlanTransformer.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/pieces/PiecesPlanTransformer.java new file mode 100644 index 0000000000000..19c29bcced18a --- /dev/null +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/pieces/PiecesPlanTransformer.java @@ -0,0 +1,209 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.starrocks.sql.optimizer.rule.tree.pieces; + +import com.google.api.client.util.Lists; +import com.google.common.collect.Maps; +import com.starrocks.catalog.Column; +import com.starrocks.common.Pair; +import com.starrocks.sql.optimizer.OptExpression; +import com.starrocks.sql.optimizer.OptExpressionVisitor; +import com.starrocks.sql.optimizer.base.ColumnRefFactory; +import com.starrocks.sql.optimizer.base.ColumnRefSet; +import com.starrocks.sql.optimizer.operator.Operator; +import com.starrocks.sql.optimizer.operator.OperatorType; +import com.starrocks.sql.optimizer.operator.logical.LogicalAggregationOperator; +import com.starrocks.sql.optimizer.operator.logical.LogicalJoinOperator; +import com.starrocks.sql.optimizer.operator.logical.LogicalProjectOperator; +import com.starrocks.sql.optimizer.operator.logical.LogicalScanOperator; +import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator; +import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator; + +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.function.Predicate; +import java.util.stream.Collectors; + +class PiecesPlanTransformer { + private final ColumnRefFactory factory; + private final List planPieces = Lists.newArrayList(); + private int planId = 0; + + public PiecesPlanTransformer(ColumnRefFactory factory) { + this.factory = factory; + } + + public List getPlanPieces() { + return planPieces; + } + + private static boolean checkTrees(OptExpression root, Predicate lambda) { + if (!lambda.test(root.getOp())) { + return false; + } + List inputs = root.getInputs(); + for (OptExpression input : inputs) { + if (!checkTrees(input, lambda)) { + return false; + } + } + return true; + } + + public boolean isSPJGPieces(OptExpression tree) { + if (!tree.getOp().getOpType().equals(OperatorType.LOGICAL_AGGR)) { + return false; + } + return checkTrees(tree.inputAt(0), op -> op.getOpType().equals(OperatorType.LOGICAL_PROJECT) + || op.getOpType().equals(OperatorType.LOGICAL_JOIN) + || op instanceof LogicalScanOperator); + } + + // trans to special pieces plan + public OptExpression transformSPJGPieces(OptExpression root) { + for (int i = 0; i < root.arity(); i++) { + if (root.inputAt(i).getOp().getOpType() == OperatorType.LOGICAL_CTE_CONSUME) { + // don't check CTE consume plan, it's duplicate work + break; + } + if (isSPJGPieces(root.inputAt(i))) { + QueryPiecesPlan piece = planToPiece(root.inputAt(i), planId++); + LogicalPiecesOperator op = + new LogicalPiecesOperator(OperatorType.LOGICAL_SPJG_PIECES, root.inputAt(i), piece); + + root.getInputs().set(i, OptExpression.create(op)); + planPieces.add(piece); + continue; + } + transformSPJGPieces(root.inputAt(i)); + } + return root; + } + + private QueryPiecesPlan planToPiece(OptExpression plan, int id) { + QueryPiecesPlan piecesPlan = new QueryPiecesPlan(id, new ScalarOperatorConverter()); + piecesPlan.planId = id; + piecesPlan.root = new Transformer().visit(plan, piecesPlan); + return piecesPlan; + } + + public OptExpression transformPlan(OptExpression root) { + if (root.getOp().getOpType() == OperatorType.LOGICAL_SPJG_PIECES) { + LogicalPiecesOperator op = root.getOp().cast(); + return op.getPlan(); + } + for (int i = 0; i < root.arity(); i++) { + root.setChild(i, transformPlan(root.inputAt(i))); + } + return root; + } + + // replace columnRef with newId + private class Transformer extends OptExpressionVisitor { + @Override + public QueryPieces visit(OptExpression optExpression, QueryPiecesPlan context) { + + if (optExpression.getOp().getOpType() == OperatorType.LOGICAL_PROJECT) { + OptExpression child = optExpression.inputAt(0); + QueryPieces childPieces = child.getOp().accept(this, child, context); + childPieces.op = child.getOp(); + + QueryPieces pieces = visitProjection(optExpression.getOp().cast(), context); + pieces.inputs.add(childPieces); + pieces.algebra = childPieces.algebra; + return pieces; + } else { + QueryPieces childPieces = optExpression.getOp().accept(this, optExpression, context); + childPieces.op = optExpression.getOp(); + + ColumnRefSet refs = optExpression.getOutputColumns(); + Map project = Maps.newHashMap(); + refs.getColumnRefOperators(factory).forEach(ref -> project.put(ref, ref)); + + LogicalProjectOperator temp = new LogicalProjectOperator(project); + QueryPieces pieces = visitProjection(temp, context); + pieces.inputs.add(childPieces); + pieces.algebra = childPieces.algebra; + return pieces; + } + } + + private QueryPieces visitProjection(LogicalProjectOperator project, QueryPiecesPlan context) { + QueryPieces pieces = new QueryPieces(); + + project.getColumnRefMap().entrySet().stream() + .map(e -> Pair.create(e.getKey(), context.columnRefConverter.convert(e.getValue()))) + .sorted(Comparator.comparing(p -> p.second.toString())) + .forEach(p -> { + ColumnRefOperator newRef = context.columnRefConverter.convertRef(p.first); + context.columnRefConverter.addExpr(newRef, p.second); + }); + + pieces.op = project; + return pieces; + } + + @Override + public QueryPieces visitLogicalAggregate(OptExpression optExpression, QueryPiecesPlan context) { + QueryPieces pieces = new QueryPieces(); + pieces.inputs.add(visit(optExpression.inputAt(0), context)); + + LogicalAggregationOperator aggregate = optExpression.getOp().cast(); + List groupBy = aggregate.getGroupingKeys() + .stream().map(g -> context.columnRefConverter.convertExpr(g)) + .sorted(Comparator.comparing(ScalarOperator::toString)) + .collect(Collectors.toList()); + + pieces.algebra = "G(" + groupBy + " => " + pieces.inputs.get(0).algebra + ")"; + return pieces; + } + + @Override + public QueryPieces visitLogicalJoin(OptExpression optExpression, QueryPiecesPlan context) { + LogicalJoinOperator join = optExpression.getOp().cast(); + + QueryPieces pieces = new QueryPieces(); + pieces.inputs.add(visit(optExpression.inputAt(0), context)); + pieces.inputs.add(visit(optExpression.inputAt(1), context)); + + pieces.algebra = pieces.inputs.stream().map(p -> p.algebra) + .collect(Collectors.joining(" " + join.getJoinType().toAlgebra() + " ")); + pieces.algebra = "(" + pieces.algebra + " on " + + context.columnRefConverter.convertExpr(join.getPredicate()) + " & " + + context.columnRefConverter.convertExpr(join.getOnPredicate()) + ")"; + return pieces; + } + + @Override + public QueryPieces visitLogicalTableScan(OptExpression optExpression, QueryPiecesPlan context) { + QueryPieces pieces = new QueryPieces(); + LogicalScanOperator scan = optExpression.getOp().cast(); + + Map columnMetaToIdMap = Maps.newHashMap(); + scan.getTable().getColumns().stream().sorted(Comparator.comparing(Column::getName)).forEach(c -> + columnMetaToIdMap.put(c, context.columnRefConverter.getNextID()) + ); + + scan.getColumnMetaToColRefMap().forEach((c, ref) -> { + context.columnRefConverter.convertRef(ref, columnMetaToIdMap.get(c)); + }); + + pieces.algebra = scan.getTable().getUUID() + ":" + scan.getTable().getName(); + return pieces; + } + } + +} diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/pieces/QueryPieces.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/pieces/QueryPieces.java new file mode 100644 index 0000000000000..f25727205121b --- /dev/null +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/pieces/QueryPieces.java @@ -0,0 +1,82 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.starrocks.sql.optimizer.rule.tree.pieces; + +import com.google.common.collect.Lists; +import com.starrocks.sql.optimizer.OptExpression; +import com.starrocks.sql.optimizer.operator.Operator; +import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator; + +import java.util.Arrays; +import java.util.List; +import java.util.Optional; + +public class QueryPieces { + String algebra; + + // raw operator + Operator op; + + // filter used columns + List filterUsedRefs; + + List inputs = Lists.newArrayList(); + + public static Optional of(Operator op, List filterUsedRefs, QueryPieces... input) { + QueryPieces p = new QueryPieces(); + p.op = op; + p.inputs.addAll(Arrays.asList(input)); + p.filterUsedRefs = filterUsedRefs; + return Optional.of(p); + } + + @Override + public String toString() { + return "QueryPieces{" + + "algebra='" + algebra + '\'' + + ", op=" + op + + '}'; + } +} + +class QueryPiecesPlan { + ScalarOperatorConverter columnRefConverter; + int planId; + QueryPieces root; + + public QueryPiecesPlan(int planId, ScalarOperatorConverter converter) { + this.planId = planId; + this.columnRefConverter = converter; + } + + public String planIdentifier() { + return root.algebra; + } + + public OptExpression toOptExpression() { + return toOptExpressionImpl(root); + } + + public OptExpression toOptExpressionImpl(QueryPieces pieces) { + List inputs = Lists.newArrayList(); + for (QueryPieces input : pieces.inputs) { + inputs.add(toOptExpressionImpl(input)); + } + return OptExpression.create(pieces.op, inputs); + } +} + + + diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/pieces/ReuseFusionPlanRule.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/pieces/ReuseFusionPlanRule.java new file mode 100644 index 0000000000000..1ff067117417e --- /dev/null +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/pieces/ReuseFusionPlanRule.java @@ -0,0 +1,433 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.starrocks.sql.optimizer.rule.tree.pieces; + +import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; +import com.starrocks.analysis.Expr; +import com.starrocks.analysis.JoinOperator; +import com.starrocks.catalog.Column; +import com.starrocks.catalog.Function; +import com.starrocks.catalog.FunctionSet; +import com.starrocks.catalog.Type; +import com.starrocks.sql.optimizer.CTEContext; +import com.starrocks.sql.optimizer.OptExpression; +import com.starrocks.sql.optimizer.Utils; +import com.starrocks.sql.optimizer.base.ColumnRefFactory; +import com.starrocks.sql.optimizer.base.ColumnRefSet; +import com.starrocks.sql.optimizer.operator.Operator; +import com.starrocks.sql.optimizer.operator.OperatorBuilderFactory; +import com.starrocks.sql.optimizer.operator.OperatorType; +import com.starrocks.sql.optimizer.operator.OperatorVisitor; +import com.starrocks.sql.optimizer.operator.logical.LogicalAggregationOperator; +import com.starrocks.sql.optimizer.operator.logical.LogicalCTEAnchorOperator; +import com.starrocks.sql.optimizer.operator.logical.LogicalCTEConsumeOperator; +import com.starrocks.sql.optimizer.operator.logical.LogicalCTEProduceOperator; +import com.starrocks.sql.optimizer.operator.logical.LogicalJoinOperator; +import com.starrocks.sql.optimizer.operator.logical.LogicalProjectOperator; +import com.starrocks.sql.optimizer.operator.logical.LogicalScanOperator; +import com.starrocks.sql.optimizer.operator.scalar.CallOperator; +import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator; +import com.starrocks.sql.optimizer.operator.scalar.ConstantOperator; +import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator; +import com.starrocks.sql.optimizer.rewrite.ScalarRangePredicateExtractor; +import com.starrocks.sql.optimizer.rule.tree.TreeRewriteRule; +import com.starrocks.sql.optimizer.task.TaskContext; + +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +public class ReuseFusionPlanRule implements TreeRewriteRule { + // original piece id -> fusion piece + private final Map originPlan2FusionPiecsMap = Maps.newHashMap(); + + private final List fusionPieces = Lists.newArrayList(); + + private ColumnRefFactory factory; + + public boolean hasRewrite() { + return !fusionPieces.isEmpty(); + } + + @Override + public OptExpression rewrite(OptExpression root, TaskContext taskContext) { + factory = taskContext.getOptimizerContext().getColumnRefFactory(); + PiecesPlanTransformer transformer = new PiecesPlanTransformer(factory); + root = transformer.transformSPJGPieces(root); + + recommendFusionCTE(transformer.getPlanPieces()); + Preconditions.checkState(root.getOp().getOpType() == OperatorType.LOGICAL); + + if (!fusionPieces.isEmpty()) { + OptExpression anchor = generateCTEPlan(taskContext.getOptimizerContext().getCteContext(), root.inputAt(0)); + root.setChild(0, anchor); + } + return transformer.transformPlan(root); + } + + private OptExpression generateCTEPlan(CTEContext cteContext, OptExpression root) { + OptExpression start = root.inputAt(0); + for (QueryPiecesPlan plan : fusionPieces) { + int cteId = cteContext.getNextCteId(); + plan.planId = cteId; + OptExpression produce = OptExpression.create(new LogicalCTEProduceOperator(cteId), plan.toOptExpression()); + root = OptExpression.create(new LogicalCTEAnchorOperator(cteId), produce, root); + cteContext.addForceCTE(cteId); + } + rewritePlan(start); + return root; + } + + private OptExpression rewritePlan(OptExpression root) { + if (root.getOp().getOpType() != OperatorType.LOGICAL_SPJG_PIECES) { + for (int i = 0; i < root.arity(); i++) { + root.setChild(i, rewritePlan(root.inputAt(i))); + } + return root; + } + + LogicalPiecesOperator pieces = (LogicalPiecesOperator) root.getOp(); + if (!originPlan2FusionPiecsMap.containsKey(pieces.getPiece().planId)) { + return pieces.getPlan(); + } + + ColumnRefSet outputs = pieces.getPlan().getOutputColumns(); + Map cteRefsMapping = Maps.newHashMap(); + + QueryPiecesPlan fusionPlan = originPlan2FusionPiecsMap.get(pieces.getPiece().planId); + for (ColumnRefOperator originRef : outputs.getColumnRefOperators(factory)) { + cteRefsMapping.put(originRef, fusionPlan.columnRefConverter.convertRef(originRef)); + } + + LogicalCTEConsumeOperator consumer = LogicalCTEConsumeOperator.builder().setCteId(fusionPlan.planId) + .setCteOutputColumnRefMap(cteRefsMapping) + .setPredicate(pieces.getPredicate()) + .build(); + return OptExpression.create(consumer); + } + + private void recommendFusionCTE(List piecesPlans) { + Collectors.groupingBy(QueryPiecesPlan::planIdentifier); + + Map> groups = + piecesPlans.stream().collect(Collectors.groupingBy(QueryPiecesPlan::planIdentifier)); + + for (Map.Entry> entry : groups.entrySet()) { + List pieces = entry.getValue(); + + if (pieces.size() < 2) { + continue; + } + + PiecesFusion fusion = new PiecesFusion(factory, pieces); + Optional p = fusion.fusion(); + p.ifPresent(queryPiecesPlan -> { + fusionPieces.add(queryPiecesPlan); + pieces.forEach(pi -> originPlan2FusionPiecsMap.put(pi.planId, queryPiecesPlan)); + }); + } + } + + // top-down generator new plan + private static class PiecesFusion extends OperatorVisitor, List> { + + private final ScalarOperatorConverter converter; + + private final ColumnRefFactory factory; + + private final List piecePlans; + + private final List allJoinTypes; + + // align with piecePlans + private final List> fusionScanFilters; + + public PiecesFusion(ColumnRefFactory factory, List piecePlans) { + this.factory = factory; + this.piecePlans = piecePlans; + + this.converter = new ScalarOperatorConverter(); + this.converter.disableCreateRef(); + + this.allJoinTypes = Lists.newArrayList(); + this.fusionScanFilters = Lists.newArrayList(); + piecePlans.forEach(p -> this.fusionScanFilters.add(Lists.newArrayList())); + } + + public Optional fusion() { + List qp = piecePlans.stream().map(p -> p.root).collect(Collectors.toList()); + if (qp.stream().map(p -> p.op.getOpType()).distinct().count() > 1) { + return Optional.empty(); + } + Optional plan = qp.get(0).op.accept(this, qp); + return plan.map(p -> { + QueryPiecesPlan newPlan = new QueryPiecesPlan(-1, converter); + newPlan.root = p; + return newPlan; + }); + } + + private Optional visitChild(List pieces, int childIndex) { + List childSize = pieces.stream().map(p -> p.inputs.size()).distinct().collect(Collectors.toList()); + if (childSize.size() > 1 && childSize.get(0) <= childIndex) { + return Optional.empty(); + } + + List children = pieces.stream().map(p -> p.inputs.get(childIndex)) + .collect(Collectors.toList()); + List ops = children.stream().map(p -> p.op.getOpType()).distinct() + .collect(Collectors.toList()); + if (ops.size() > 1) { + return Optional.empty(); + } + return children.get(0).op.accept(this, children); + } + + @Override + public Optional visitLogicalProject(LogicalProjectOperator node, List context) { + Optional child = visitChild(context, 0); + if (child.isEmpty()) { + return child; + } + + Map project = Maps.newHashMap(); + for (QueryPieces pieces : context) { + LogicalProjectOperator p = pieces.op.cast(); + p.getColumnRefMap().forEach((k, v) -> { + if (converter.contains(k)) { + project.put(converter.convertRef(k), converter.convert(v)); + } else { + ColumnRefOperator newRef = factory.create(k.getName(), k.getType(), k.isNullable()); + project.put(newRef, converter.convert(v)); + converter.add(k, newRef); + } + }); + } + + for (ColumnRefOperator ref : child.get().filterUsedRefs) { + project.put(ref, ref); + } + return QueryPieces.of(new LogicalProjectOperator(project), child.get().filterUsedRefs, child.get()); + } + + @Override + public Optional visitLogicalAggregation(LogicalAggregationOperator node, + List context) { + Optional child = visitChild(context, 0); + if (child.isEmpty()) { + return Optional.empty(); + } + + List pieceFilters = fusionScanFilters.stream() + .map(Utils::compoundAnd) + .collect(Collectors.toList()); + + long filterDistinct = pieceFilters.stream().distinct().count(); + if (filterDistinct != 1 && !allJoinTypes.isEmpty() && + allJoinTypes.stream().anyMatch(p -> p != JoinOperator.INNER_JOIN)) { + return Optional.empty(); + } + + List groupBys = node.getGroupingKeys().stream().map(converter::convertRef) + .collect(Collectors.toList()); + List partitions = node.getPartitionByColumns().stream().map(converter::convertRef) + .collect(Collectors.toList()); + Map aggToRefs = Maps.newHashMap(); + Map aggFilterProject = Maps.newHashMap(); + + Preconditions.checkState(context.size() == pieceFilters.size()); + for (int i = 0; i < context.size(); i++) { + LogicalAggregationOperator aggregate = context.get(i).op.cast(); + + // check group by + if (aggregate.getGroupingKeys().stream().map(converter::convertRef) + .anyMatch(c -> !groupBys.contains(c))) { + return Optional.empty(); + } + + // check aggregate + if (filterDistinct > 1 && + aggregate.getAggregations().values().stream().anyMatch(v -> v.getChildren().size() > 1)) { + return Optional.empty(); + } + + ScalarOperator filter = filterDistinct > 1 ? pieceFilters.get(i) : null; + aggregate.getAggregations().forEach((ref, call) -> { + CallOperator newCall = addFilterAggCall((CallOperator) converter.convert(call), filter, + aggFilterProject); + + if (aggToRefs.containsKey(newCall)) { + converter.add(ref, aggToRefs.get(newCall)); + } else { + ColumnRefOperator newRef = factory.create(ref.getName(), ref.getType(), ref.isNullable()); + aggToRefs.put(newCall, newRef); + converter.add(ref, newRef); + } + }); + } + + Optional childPieces = child; + if (!aggFilterProject.isEmpty()) { + Map filterProject = aggFilterProject.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey)); + groupBys.forEach(k -> filterProject.put(k, k)); + childPieces = QueryPieces.of(new LogicalProjectOperator(filterProject), + Collections.emptyList(), childPieces.get()); + } + + Map aggs = aggToRefs.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey)); + Operator op = LogicalAggregationOperator.builder().withOperator(node) + .setAggregations(aggs) + .setGroupingKeys(groupBys) + .setPartitionByColumns(partitions) + .setPredicate(converter.convert(node.getPredicate())) + .build(); + Preconditions.checkState(childPieces.isPresent()); + return QueryPieces.of(op, Collections.emptyList(), childPieces.get()); + } + + private CallOperator addFilterAggCall(CallOperator call, ScalarOperator filter, + Map filterProject) { + Preconditions.checkState(call.getChildren().size() <= 1); + if (filter == null) { + return call; + } + + ScalarOperator child; + Function aggFunc; + if (call.getChildren().isEmpty()) { + // count(*) + Preconditions.checkState(FunctionSet.COUNT.equalsIgnoreCase(call.getFnName())); + child = ConstantOperator.createInt(1); + aggFunc = Expr.getBuiltinFunction(call.getFunction().getFunctionName().getFunction(), + new Type[] {Type.INT}, Function.CompareMode.IS_IDENTICAL); + } else { + child = call.getChild(0); + aggFunc = call.getFunction(); + } + Function f = Expr.getBuiltinFunction(FunctionSet.IF, + new Type[] {Type.BOOLEAN, child.getType(), child.getType()}, Function.CompareMode.IS_IDENTICAL); + CallOperator ifNull = new CallOperator("if", child.getType(), + List.of(filter, child, ConstantOperator.createNull(child.getType())), f); + + if (!filterProject.containsKey(ifNull)) { + filterProject.put(ifNull, factory.create(ifNull, child.getType(), true)); + } + + return new CallOperator(call.getFnName(), call.getType(), List.of(filterProject.get(ifNull)), + aggFunc, call.isDistinct()); + } + + @Override + public Optional visitLogicalJoin(LogicalJoinOperator node, List context) { + if (context.stream().map(p -> ((LogicalJoinOperator) p.op).getJoinType()).distinct().count() > 1) { + return Optional.empty(); + } + + allJoinTypes.add(node.getJoinType()); + Optional left = visitChild(context, 0); + if (left.isEmpty()) { + return Optional.empty(); + } + Optional right = visitChild(context, 1); + if (right.isEmpty()) { + return Optional.empty(); + } + + List filterRefs = Lists.newArrayList(left.get().filterUsedRefs); + filterRefs.addAll(right.get().filterUsedRefs); + + LogicalJoinOperator.Builder builder = LogicalJoinOperator.builder() + .withOperator(node) + .setOnPredicate(converter.convert(node.getOnPredicate())); + return QueryPieces.of(builder.build(), filterRefs, left.get(), right.get()); + } + + @Override + public Optional visitLogicalTableScan(LogicalScanOperator node, List context) { + if (context.stream().map(p -> ((LogicalScanOperator) p.op).getTable()).distinct().count() > 1) { + return Optional.empty(); + } + if (context.stream().map(p -> p.op.getLimit()).distinct().count() > 1) { + return Optional.empty(); + } + if (context.stream().map(p -> p.op.getLimit()).distinct().count() > 1) { + return Optional.empty(); + } + + Map columnMetaToColRefMap = Maps.newHashMap(); + Map colRefToColumnMetaMap = Maps.newHashMap(); + + int scanId = factory.getNextRelationId(); + node.getTable().getColumns().stream().sorted(Comparator.comparing(Column::getName)).forEach(c -> { + ColumnRefOperator newRef = factory.create(c.getName(), c.getType(), c.isAllowNull()); + columnMetaToColRefMap.put(c, newRef); + colRefToColumnMetaMap.put(newRef, c); + factory.updateColumnToRelationIds(newRef.getId(), scanId); + factory.updateColumnRefToColumns(newRef, c, node.getTable()); + }); + + LogicalScanOperator.Builder builder = OperatorBuilderFactory.build(node); + builder.withOperator(node); + builder.setColRefToColumnMetaMap(colRefToColumnMetaMap); + builder.setColumnMetaToColRefMap(columnMetaToColRefMap); + builder.setTable(node.getTable()); + + List fusionPredicates = Lists.newArrayList(); + Set fusionPredicateSet = Sets.newHashSet(); + for (QueryPieces piece : context) { + LogicalScanOperator scan = piece.op.cast(); + scan.getColumnMetaToColRefMap().forEach((c, ref) -> { + ColumnRefOperator newRef = columnMetaToColRefMap.get(c); + converter.add(ref, newRef); + }); + ScalarOperator newPredicate = converter.convert(scan.getPredicate()); + fusionPredicates.add(newPredicate); + fusionPredicateSet.add(newPredicate); + } + + List filterUsedRefs = Lists.newArrayList(); + if (fusionPredicateSet.size() > 1) { + // same predicate doesn't need do again + Preconditions.checkState(fusionPredicates.size() == this.fusionScanFilters.size()); + for (int i = 0; i < fusionPredicates.size(); i++) { + ScalarOperator filter = fusionPredicates.get(i); + if (filter != null) { + this.fusionScanFilters.get(i).add(filter); + filter.getColumnRefs(filterUsedRefs); + } + } + } + ScalarOperator predicate = Utils.compoundOr(fusionPredicateSet); + ScalarRangePredicateExtractor extractor = new ScalarRangePredicateExtractor(); + predicate = extractor.rewriteOnlyColumn(predicate); + builder.setPredicate(predicate); + return QueryPieces.of(builder.build(), filterUsedRefs); + } + + @Override + public Optional visitOperator(Operator node, List context) { + return Optional.empty(); + } + } +} diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/pieces/ScalarOperatorConverter.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/pieces/ScalarOperatorConverter.java new file mode 100644 index 0000000000000..02a57ed80a735 --- /dev/null +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/pieces/ScalarOperatorConverter.java @@ -0,0 +1,82 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.starrocks.sql.optimizer.rule.tree.pieces; + +import com.google.common.base.Preconditions; +import com.google.common.collect.Maps; +import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator; +import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator; +import com.starrocks.sql.optimizer.rewrite.ReplaceColumnRefRewriter; + +import java.util.Map; + +class ScalarOperatorConverter { + // origin column refs mapping to new column refs + private final Map refsMapping = Maps.newHashMap(); + + private final ReplaceColumnRefRewriter refRewriter = new ReplaceColumnRefRewriter(refsMapping, false); + + // new refs -> refs expression, use for normalize projection + private final Map expressionMapping = Maps.newHashMap(); + + private final ReplaceColumnRefRewriter exprRewriter = new ReplaceColumnRefRewriter(expressionMapping, true); + + private int nextID = 0; + + public int getNextID() { + Preconditions.checkState(nextID > -1); + return nextID++; + } + + public void disableCreateRef() { + nextID = -100; + } + + public void add(ColumnRefOperator originRef, ColumnRefOperator newRef) { + refsMapping.put(originRef, newRef); + } + + public boolean contains(ColumnRefOperator ref) { + return refsMapping.containsKey(ref); + } + + public ColumnRefOperator convertRef(ColumnRefOperator columnRef) { + if (refsMapping.containsKey(columnRef)) { + return refsMapping.get(columnRef); + } + + Preconditions.checkState(nextID > -1); + return convertRef(columnRef, nextID++); + } + + public ColumnRefOperator convertRef(ColumnRefOperator originRef, int newId) { + ColumnRefOperator newRef = + new ColumnRefOperator(newId, originRef.getType(), originRef.getName(), originRef.isNullable()); + refsMapping.put(originRef, newRef); + return newRef; + } + + public ScalarOperator convert(ScalarOperator operator) { + return refRewriter.rewrite(operator); + } + + public void addExpr(ColumnRefOperator ref, ScalarOperator expr) { + expressionMapping.put(ref, expr); + } + + public ScalarOperator convertExpr(ScalarOperator operator) { + return exprRewriter.rewrite(refRewriter.rewrite(operator)); + } +} diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/PlanTestNoneDBBase.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/PlanTestNoneDBBase.java index 3d63995d4bef8..4a387fa7bf029 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/plan/PlanTestNoneDBBase.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/PlanTestNoneDBBase.java @@ -90,6 +90,7 @@ public static void beforeClass() throws Exception { connectContext.getSessionVariable().setOptimizerExecuteTimeout(30000); connectContext.getSessionVariable().setUseLowCardinalityOptimizeV2(false); connectContext.getSessionVariable().setCboEqBaseType(SessionVariableConstants.VARCHAR); + connectContext.getSessionVariable().setCboExtractCommonPlan(false); FeConstants.enablePruneEmptyOutputScan = false; FeConstants.showJoinLocalShuffleInExplain = false; FeConstants.showFragmentCost = false; diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/TPCDS1TExtractCTETest.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/TPCDS1TExtractCTETest.java new file mode 100644 index 0000000000000..d978f65153e6f --- /dev/null +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/TPCDS1TExtractCTETest.java @@ -0,0 +1,91 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +package com.starrocks.sql.plan; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + + +public class TPCDS1TExtractCTETest extends TPCDS1TTestBase { + + @BeforeAll + public static void beforeClass() throws Exception { + TPCDSPlanTestBase.beforeClass(); + connectContext.getSessionVariable().setCboCTERuseRatio(0); + connectContext.getSessionVariable().setOptimizerExecuteTimeout(-1); + connectContext.getSessionVariable().setCboExtractCommonPlan(true); + } + + @AfterAll + public static void afterClass() { + connectContext.getSessionVariable().setCboExtractCommonPlan(true); + } + + @Test + public void testQuery09() throws Exception { + String plan = getFragmentPlan(Q09); + assertContains(plan, "MultiCastDataSinks"); + assertContains(plan, "2:AGGREGATE (update serialize)\n" + + " | output: count(if(461: expr, 1, NULL)), avg(if(461: expr, 394: ss_ext_discount_amt, NULL))"); + } + + @Test + public void testQuery28() throws Exception { + String plan = getFragmentPlan(Q28); + assertContains(plan, "MultiCastDataSinks"); + assertContains(plan, " 2:AGGREGATE (update serialize)\n" + + " | output: avg(192: if), count(192: if), multi_distinct_count(192: if), " + + "count(196: if), multi_distinct_count(196: if)"); + } + + @Test + public void testQuery44() throws Exception { + String plan = getFragmentPlan(Q44); + assertContains(plan, "MultiCastDataSinks"); + assertContains(plan, " 4:AGGREGATE (merge finalize)\n" + + " | output: avg(192: avg)\n" + + " | group by: 179: ss_item_sk"); + assertContains(plan, " 9:AGGREGATE (merge finalize)\n" + + " | output: avg(168: avg)\n" + + " | group by: 165: ss_store_sk"); + } + + @Test + public void testQuery65() throws Exception { + String plan = getFragmentPlan(Q65); + assertContains(plan, "MultiCastDataSinks"); + assertContains(plan, " 8:AGGREGATE (merge finalize)\n" + + " | output: sum(208: sum)\n" + + " | group by: 177: ss_store_sk, 167: ss_item_sk"); + } + + @Test + public void testQuery88() throws Exception { + String plan = getFragmentPlan(Q88); + assertContains(plan, "MultiCastDataSinks"); + assertContains(plan, "AGGREGATE (merge finalize)\n" + + " | output: count(641: count), count(643: count)"); + } + + @Test + public void testQuery90() throws Exception { + String plan = getFragmentPlan(Q90); + assertContains(plan, "MultiCastDataSinks"); + assertContains(plan, "AGGREGATE (update serialize)\n" + + " | output: count(if((172: t_hour >= 8)"); + } +}