diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_if.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_if.explain index ca1534c2aca33..f2ada15eccb7d 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_if.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_if.explain @@ -1,3 +1,3 @@ -Aggregate [count(if ((_common_expr_1#0 = false)) null else _common_expr_1#0) AS count_if((a > 0))#0L] -+- Project [id#0L, a#0, b#0, d#0, e#0, f#0, g#0, (a#0 > 0) AS _common_expr_1#0] +Aggregate [count(if ((_common_expr_0#0 = false)) null else _common_expr_0#0) AS count_if((a > 0))#0L] ++- Project [id#0L, a#0, b#0, d#0, e#0, f#0, g#0, (a#0 > 0) AS _common_expr_0#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_regexp_substr.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_regexp_substr.explain index 44bc59d4aef2b..1811f770f8297 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_regexp_substr.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_regexp_substr.explain @@ -1,3 +1,3 @@ -Project [if ((_common_expr_2#0 = )) null else _common_expr_2#0 AS regexp_substr(g, \d{2}(a|b|m))#0] -+- Project [id#0L, a#0, b#0, d#0, e#0, f#0, g#0, regexp_extract(g#0, \d{2}(a|b|m), 0) AS _common_expr_2#0] +Project [if ((_common_expr_0#0 = )) null else _common_expr_0#0 AS regexp_substr(g, \d{2}(a|b|m))#0] ++- Project [id#0L, a#0, b#0, d#0, e#0, f#0, g#0, regexp_extract(g#0, \d{2}(a|b|m), 0) AS _common_expr_0#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala index ea351e53080d8..073f60bca47f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala @@ -44,13 +44,13 @@ object RewriteWithExpression extends Rule[LogicalPlan] { case With(child, defs) if defs.forall(!_.containsPattern(WITH_EXPRESSION)) => val idToCheapExpr = mutable.HashMap.empty[Long, Expression] val idToNonCheapExpr = mutable.HashMap.empty[Long, Alias] - defs.foreach { commonExprDef => + defs.zipWithIndex.foreach { case (commonExprDef, index) => if (CollapseProject.isCheap(commonExprDef.child)) { idToCheapExpr(commonExprDef.id) = commonExprDef.child } else { // TODO: we should calculate the ref count and also inline the common expression // if it's ref count is 1. - val alias = Alias(commonExprDef.child, s"_common_expr_${commonExprDef.id}")() + val alias = Alias(commonExprDef.child, s"_common_expr_$index")() commonExprs += alias idToNonCheapExpr(commonExprDef.id) = alias } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala index 11fa7f143bbf4..c4b08e6e5de85 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala @@ -19,10 +19,11 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{CommonExpressionDef, CommonExpressionRef, With} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, CommonExpressionDef, CommonExpressionRef, With} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types.IntegerType class RewriteWithExpressionSuite extends PlanTest { @@ -46,7 +47,7 @@ class RewriteWithExpressionSuite extends PlanTest { val commonExprDef = CommonExpressionDef(a + a) val ref = new CommonExpressionRef(commonExprDef) val plan = testRelation.select(With(ref * ref, Seq(commonExprDef)).as("col")) - val commonExprName = "_common_expr_" + commonExprDef.id + val commonExprName = "_common_expr_0" comparePlans( Optimizer.execute(plan), testRelation @@ -61,23 +62,25 @@ class RewriteWithExpressionSuite extends PlanTest { val commonExprDef = CommonExpressionDef(a + a) val ref = new CommonExpressionRef(commonExprDef) val innerExpr = With(ref + ref, Seq(commonExprDef)) - val innerCommonExprName = "_common_expr_" + commonExprDef.id + val innerCommonExprName = "_common_expr_0" val b = testRelation.output.last val outerCommonExprDef = CommonExpressionDef(innerExpr + b) val outerRef = new CommonExpressionRef(outerCommonExprDef) val outerExpr = With(outerRef * outerRef, Seq(outerCommonExprDef)) - val outerCommonExprName = "_common_expr_" + outerCommonExprDef.id + val outerCommonExprName = "_common_expr_0" val plan = testRelation.select(outerExpr.as("col")) val rewrittenOuterExpr = ($"$innerCommonExprName" + $"$innerCommonExprName" + b) .as(outerCommonExprName) + val outerExprAttr = AttributeReference(outerCommonExprName, IntegerType)( + exprId = rewrittenOuterExpr.exprId) comparePlans( Optimizer.execute(plan), testRelation .select((testRelation.output :+ (a + a).as(innerCommonExprName)): _*) .select((testRelation.output :+ $"$innerCommonExprName" :+ rewrittenOuterExpr): _*) - .select(($"$outerCommonExprName" * $"$outerCommonExprName").as("col")) + .select((outerExprAttr * outerExprAttr).as("col")) .analyze ) } @@ -87,7 +90,7 @@ class RewriteWithExpressionSuite extends PlanTest { val commonExprDef = CommonExpressionDef(a + a) val ref = new CommonExpressionRef(commonExprDef) val plan = testRelation.where(With(ref < 10 && ref > 0, Seq(commonExprDef))) - val commonExprName = "_common_expr_" + commonExprDef.id + val commonExprName = "_common_expr_0" comparePlans( Optimizer.execute(plan), testRelation @@ -104,7 +107,7 @@ class RewriteWithExpressionSuite extends PlanTest { val ref = new CommonExpressionRef(commonExprDef) val condition = With(ref < 10 && ref > 0, Seq(commonExprDef)) val plan = testRelation.join(testRelation2, condition = Some(condition)) - val commonExprName = "_common_expr_" + commonExprDef.id + val commonExprName = "_common_expr_0" comparePlans( Optimizer.execute(plan), testRelation @@ -121,7 +124,7 @@ class RewriteWithExpressionSuite extends PlanTest { val ref = new CommonExpressionRef(commonExprDef) val condition = With(ref < 10 && ref > 0, Seq(commonExprDef)) val plan = testRelation.join(testRelation2, condition = Some(condition)) - val commonExprName = "_common_expr_" + commonExprDef.id + val commonExprName = "_common_expr_0" comparePlans( Optimizer.execute(plan), testRelation