[SPARK-41539][SQL] Remap stats and constraints against output in logical plan for LogicalRDD

### What changes were proposed in this pull request?

This PR proposes to remap stats and constraints against the output in the logical plan for LogicalRDD, just as we remap stats and constraints against the "new" output when `newInstance` is called.
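
To illustrate the idea, here is a minimal, self-contained sketch using hypothetical stand-in types (`Attr`, `ColStat`), not Spark's actual `Attribute`/`ColumnStat`/`Statistics` classes: per-column statistics keyed by the optimized plan's attributes are re-keyed against the logical plan's attributes, much like `newInstance` re-keys them against fresh attributes.

```scala
// Simplified stand-ins for illustration only; Spark's real Attribute and
// ColumnStat types carry far more information.
case class Attr(name: String, exprId: Long)
case class ColStat(distinctCount: Long)

object StatsRemapSketch extends App {
  // Re-key per-column stats from the old attributes to the new output,
  // leaving any attribute without a mapping untouched.
  def rewriteStats(
      attributeStats: Map[Attr, ColStat],
      rewrite: Map[Attr, Attr]): Map[Attr, ColStat] =
    attributeStats.map { case (attr, stat) =>
      rewrite.getOrElse(attr, attr) -> stat
    }

  val optimizedAttr = Attr("cint", exprId = 2)   // as the optimized plan names it
  val logicalAttr   = Attr("cint", exprId = 14)  // as the logical plan names it
  val stats = Map(optimizedAttr -> ColStat(distinctCount = 2))

  // Prints: Map(Attr(cint,14) -> ColStat(2))
  println(rewriteStats(stats, Map(optimizedAttr -> logicalAttr)))
}
```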

### Why are the changes needed?

The output of the logical plan and the optimized plan can be "slightly" different (we observed exprIds differing between the two), and the query can then fail due to invalid attribute reference(s) in the stats and constraints for LogicalRDD.
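
For context, a sketch of how the mismatch bites, again with hypothetical stand-in types (`AttrRef`) rather than Spark's actual classes: a constraint inherited verbatim from the optimized plan still references the optimized plan's exprId, which no longer appears in the logical plan's output.

```scala
// Simplified stand-in for an attribute reference; only name and exprId
// matter for this illustration.
case class AttrRef(name: String, exprId: Long)

object ExprIdMismatchSketch extends App {
  // Output as the logical plan sees it.
  val logicalOutput = Seq(AttrRef("cint", exprId = 14))

  // A constraint inherited from the optimized plan still references
  // the optimized plan's exprId.
  val constraintAttr = AttrRef("cint", exprId = 2)

  // exprId 2 does not appear in the logical plan's output, so the
  // attribute reference is invalid there; the query would fail.
  val resolvable = logicalOutput.contains(constraintAttr)
  println(s"constraint resolvable against logical output? $resolvable") // false
}
```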

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Modified test cases.

Closes apache#39082 from HeartSaVioR/SPARK-41539.

Authored-by: Jungtaek Lim <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
HeartSaVioR committed Dec 21, 2022
1 parent 4539260 commit 074e1b3
Showing 2 changed files with 98 additions and 20 deletions.
89 changes: 71 additions & 18 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -17,6 +17,7 @@

package org.apache.spark.sql.execution

import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Dataset, Encoder, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
@@ -103,6 +104,8 @@ case class LogicalRDD(
originConstraints: Option[ExpressionSet] = None)
extends LeafNode with MultiInstanceRelation {

import LogicalRDD._

override protected final def otherCopyArgs: Seq[AnyRef] =
session :: originStats :: originConstraints :: Nil

@@ -122,22 +125,8 @@
case e: Attribute => rewrite.getOrElse(e, e)
}.asInstanceOf[SortOrder])

val rewrittenStatistics = originStats.map { s =>
Statistics(
s.sizeInBytes,
s.rowCount,
AttributeMap[ColumnStat](s.attributeStats.map {
case (attr, v) => (rewrite.getOrElse(attr, attr), v)
}),
s.isRuntime
)
}

val rewrittenConstraints = originConstraints.map { c =>
c.map(_.transform {
case e: Attribute => rewrite.getOrElse(e, e)
})
}
val rewrittenStatistics = originStats.map(rewriteStatistics(_, rewrite))
val rewrittenConstraints = originConstraints.map(rewriteConstraints(_, rewrite))

LogicalRDD(
output.map(rewrite),
@@ -163,7 +152,7 @@
override lazy val constraints: ExpressionSet = originConstraints.getOrElse(ExpressionSet())
}

object LogicalRDD {
object LogicalRDD extends Logging {
/**
* Create a new LogicalRDD based on an existing Dataset. Stats and constraints are inherited
* from the origin Dataset.
@@ -183,16 +172,80 @@
}
}

val logicalPlan = originDataset.logicalPlan
val optimizedPlan = originDataset.queryExecution.optimizedPlan
val executedPlan = originDataset.queryExecution.executedPlan

val (stats, constraints) = rewriteStatsAndConstraints(logicalPlan, optimizedPlan)

LogicalRDD(
originDataset.logicalPlan.output,
rdd,
firstLeafPartitioning(executedPlan.outputPartitioning),
executedPlan.outputOrdering,
isStreaming
)(originDataset.sparkSession, Some(optimizedPlan.stats), Some(optimizedPlan.constraints))
)(originDataset.sparkSession, stats, constraints)
}

private[sql] def buildOutputAssocForRewrite(
source: Seq[Attribute],
destination: Seq[Attribute]): Option[Map[Attribute, Attribute]] = {
// We check the name and type, allowing nullability, exprId, metadata, and qualifier to be
// different. E.g. this could happen during the optimization phase.
val rewrite = source.zip(destination).flatMap { case (attr1, attr2) =>
if (attr1.name == attr2.name && attr1.dataType == attr2.dataType) {
Some(attr1 -> attr2)
} else {
None
}
}.toMap

if (rewrite.size == source.size) {
Some(rewrite)
} else {
None
}
}

private[sql] def rewriteStatsAndConstraints(
logicalPlan: LogicalPlan,
optimizedPlan: LogicalPlan): (Option[Statistics], Option[ExpressionSet]) = {
val rewrite = buildOutputAssocForRewrite(optimizedPlan.output, logicalPlan.output)

rewrite.map { rw =>
val rewrittenStatistics = rewriteStatistics(optimizedPlan.stats, rw)
val rewrittenConstraints = rewriteConstraints(optimizedPlan.constraints, rw)

(Some(rewrittenStatistics), Some(rewrittenConstraints))
}.getOrElse {
// can't rewrite stats and constraints, give up
logWarning("The output columns are expected to the same (for name and type) for output " +
"between logical plan and optimized plan, but they aren't. output in logical plan: " +
s"${logicalPlan.output.map(_.simpleString(10))} / output in optimized plan: " +
s"${optimizedPlan.output.map(_.simpleString(10))}")

(None, None)
}
}

private[sql] def rewriteStatistics(
originStats: Statistics,
colRewrite: Map[Attribute, Attribute]): Statistics = {
Statistics(
originStats.sizeInBytes,
originStats.rowCount,
AttributeMap[ColumnStat](originStats.attributeStats.map {
case (attr, v) => (colRewrite.getOrElse(attr, attr), v)
}),
originStats.isRuntime)
}

private[sql] def rewriteConstraints(
originConstraints: ExpressionSet,
colRewrite: Map[Attribute, Attribute]): ExpressionSet = {
originConstraints.map(_.transform {
case e: Attribute => colRewrite.getOrElse(e, e)
})
}
}

29 changes: 27 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -2121,12 +2121,25 @@ class DataFrameSuite extends QueryTest

withSQLConf(SQLConf.CBO_ENABLED.key -> "true") {
val df = Dataset.ofRows(spark, statsPlan)
// add some map-like operations which the optimizer will optimize away, creating a divergence
// in output between the logical plan and the optimized plan
// logical plan
// Project [cb#6 AS cbool#12, cby#7 AS cbyte#13, ci#8 AS cint#14]
// +- Project [cbool#0 AS cb#6, cbyte#1 AS cby#7, cint#2 AS ci#8]
// +- OutputListAwareStatsTestPlan [cbool#0, cbyte#1, cint#2], 2, 16
// optimized plan
// OutputListAwareStatsTestPlan [cbool#0, cbyte#1, cint#2], 2, 16
.selectExpr("cbool AS cb", "cbyte AS cby", "cint AS ci")
.selectExpr("cb AS cbool", "cby AS cbyte", "ci AS cint")

// We can't leverage LogicalRDD.fromDataset here, since it triggers physical planning and
// there is no matching physical node for OutputListAwareStatsTestPlan.
val optimizedPlan = df.queryExecution.optimizedPlan
val rewrite = LogicalRDD.buildOutputAssocForRewrite(optimizedPlan.output,
df.logicalPlan.output)
val logicalRDD = LogicalRDD(
df.logicalPlan.output, spark.sparkContext.emptyRDD[InternalRow], isStreaming = true)(
spark, Some(df.queryExecution.optimizedPlan.stats), None)
spark, Some(LogicalRDD.rewriteStatistics(optimizedPlan.stats, rewrite.get)), None)

val stats = logicalRDD.computeStats()
val expectedStats = Statistics(sizeInBytes = expectedSize, rowCount = Some(2),
@@ -2164,12 +2177,24 @@
val statsPlan = OutputListAwareConstraintsTestPlan(outputList = outputList)

val df = Dataset.ofRows(spark, statsPlan)
// add some map-like operations which the optimizer will optimize away, creating a divergence
// in output between the logical plan and the optimized plan
// logical plan
// Project [cb#6 AS cbool#12, cby#7 AS cbyte#13, ci#8 AS cint#14]
// +- Project [cbool#0 AS cb#6, cbyte#1 AS cby#7, cint#2 AS ci#8]
// +- OutputListAwareConstraintsTestPlan [cbool#0, cbyte#1, cint#2]
// optimized plan
// OutputListAwareConstraintsTestPlan [cbool#0, cbyte#1, cint#2]
.selectExpr("cbool AS cb", "cbyte AS cby", "cint AS ci")
.selectExpr("cb AS cbool", "cby AS cbyte", "ci AS cint")

// We can't leverage LogicalRDD.fromDataset here, since it triggers physical planning and
// there is no matching physical node for OutputListAwareConstraintsTestPlan.
val optimizedPlan = df.queryExecution.optimizedPlan
val rewrite = LogicalRDD.buildOutputAssocForRewrite(optimizedPlan.output, df.logicalPlan.output)
val logicalRDD = LogicalRDD(
df.logicalPlan.output, spark.sparkContext.emptyRDD[InternalRow], isStreaming = true)(
spark, None, Some(df.queryExecution.optimizedPlan.constraints))
spark, None, Some(LogicalRDD.rewriteConstraints(optimizedPlan.constraints, rewrite.get)))

val constraints = logicalRDD.constraints
val expectedConstraints = buildExpectedConstraints(logicalRDD.output)
