Skip to content

Commit

Permalink
make it stable
Browse files Browse the repository at this point in the history
  • Loading branch information
cloud-fan committed Nov 3, 2023
1 parent 90b424d commit 2316fac
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
Aggregate [count(if ((_common_expr_1#0 = false)) null else _common_expr_1#0) AS count_if((a > 0))#0L]
+- Project [id#0L, a#0, b#0, d#0, e#0, f#0, g#0, (a#0 > 0) AS _common_expr_1#0]
Aggregate [count(if ((_common_expr_0#0 = false)) null else _common_expr_0#0) AS count_if((a > 0))#0L]
+- Project [id#0L, a#0, b#0, d#0, e#0, f#0, g#0, (a#0 > 0) AS _common_expr_0#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
Project [if ((_common_expr_2#0 = )) null else _common_expr_2#0 AS regexp_substr(g, \d{2}(a|b|m))#0]
+- Project [id#0L, a#0, b#0, d#0, e#0, f#0, g#0, regexp_extract(g#0, \d{2}(a|b|m), 0) AS _common_expr_2#0]
Project [if ((_common_expr_0#0 = )) null else _common_expr_0#0 AS regexp_substr(g, \d{2}(a|b|m))#0]
+- Project [id#0L, a#0, b#0, d#0, e#0, f#0, g#0, regexp_extract(g#0, \d{2}(a|b|m), 0) AS _common_expr_0#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
case With(child, defs) if defs.forall(!_.containsPattern(WITH_EXPRESSION)) =>
val idToCheapExpr = mutable.HashMap.empty[Long, Expression]
val idToNonCheapExpr = mutable.HashMap.empty[Long, Alias]
defs.foreach { commonExprDef =>
defs.zipWithIndex.foreach { case (commonExprDef, index) =>
if (CollapseProject.isCheap(commonExprDef.child)) {
idToCheapExpr(commonExprDef.id) = commonExprDef.child
} else {
// TODO: we should calculate the ref count and also inline the common expression
// if its ref count is 1.
val alias = Alias(commonExprDef.child, s"_common_expr_${commonExprDef.id}")()
val alias = Alias(commonExprDef.child, s"_common_expr_$index")()
commonExprs += alias
idToNonCheapExpr(commonExprDef.id) = alias
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@ package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{CommonExpressionDef, CommonExpressionRef, With}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, CommonExpressionDef, CommonExpressionRef, With}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types.IntegerType

class RewriteWithExpressionSuite extends PlanTest {

Expand All @@ -46,7 +47,7 @@ class RewriteWithExpressionSuite extends PlanTest {
val commonExprDef = CommonExpressionDef(a + a)
val ref = new CommonExpressionRef(commonExprDef)
val plan = testRelation.select(With(ref * ref, Seq(commonExprDef)).as("col"))
val commonExprName = "_common_expr_" + commonExprDef.id
val commonExprName = "_common_expr_0"
comparePlans(
Optimizer.execute(plan),
testRelation
Expand All @@ -61,23 +62,25 @@ class RewriteWithExpressionSuite extends PlanTest {
val commonExprDef = CommonExpressionDef(a + a)
val ref = new CommonExpressionRef(commonExprDef)
val innerExpr = With(ref + ref, Seq(commonExprDef))
val innerCommonExprName = "_common_expr_" + commonExprDef.id
val innerCommonExprName = "_common_expr_0"

val b = testRelation.output.last
val outerCommonExprDef = CommonExpressionDef(innerExpr + b)
val outerRef = new CommonExpressionRef(outerCommonExprDef)
val outerExpr = With(outerRef * outerRef, Seq(outerCommonExprDef))
val outerCommonExprName = "_common_expr_" + outerCommonExprDef.id
val outerCommonExprName = "_common_expr_0"

val plan = testRelation.select(outerExpr.as("col"))
val rewrittenOuterExpr = ($"$innerCommonExprName" + $"$innerCommonExprName" + b)
.as(outerCommonExprName)
val outerExprAttr = AttributeReference(outerCommonExprName, IntegerType)(
exprId = rewrittenOuterExpr.exprId)
comparePlans(
Optimizer.execute(plan),
testRelation
.select((testRelation.output :+ (a + a).as(innerCommonExprName)): _*)
.select((testRelation.output :+ $"$innerCommonExprName" :+ rewrittenOuterExpr): _*)
.select(($"$outerCommonExprName" * $"$outerCommonExprName").as("col"))
.select((outerExprAttr * outerExprAttr).as("col"))
.analyze
)
}
Expand All @@ -87,7 +90,7 @@ class RewriteWithExpressionSuite extends PlanTest {
val commonExprDef = CommonExpressionDef(a + a)
val ref = new CommonExpressionRef(commonExprDef)
val plan = testRelation.where(With(ref < 10 && ref > 0, Seq(commonExprDef)))
val commonExprName = "_common_expr_" + commonExprDef.id
val commonExprName = "_common_expr_0"
comparePlans(
Optimizer.execute(plan),
testRelation
Expand All @@ -104,7 +107,7 @@ class RewriteWithExpressionSuite extends PlanTest {
val ref = new CommonExpressionRef(commonExprDef)
val condition = With(ref < 10 && ref > 0, Seq(commonExprDef))
val plan = testRelation.join(testRelation2, condition = Some(condition))
val commonExprName = "_common_expr_" + commonExprDef.id
val commonExprName = "_common_expr_0"
comparePlans(
Optimizer.execute(plan),
testRelation
Expand All @@ -121,7 +124,7 @@ class RewriteWithExpressionSuite extends PlanTest {
val ref = new CommonExpressionRef(commonExprDef)
val condition = With(ref < 10 && ref > 0, Seq(commonExprDef))
val plan = testRelation.join(testRelation2, condition = Some(condition))
val commonExprName = "_common_expr_" + commonExprDef.id
val commonExprName = "_common_expr_0"
comparePlans(
Optimizer.execute(plan),
testRelation
Expand Down

0 comments on commit 2316fac

Please sign in to comment.