[SPARK-48481][SQL][SS] Do not apply OptimizeOneRowPlan against streaming Dataset

### What changes were proposed in this pull request?

This PR proposes to exclude streaming Dataset from the target of OptimizeOneRowPlan.
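For context, here is a minimal sketch (not part of this PR) of what the rule does for a batch Dataset, where eliding the Sort/Aggregate is safe because the child's max row count holds for the whole query:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()

// A one-row child: spark.range(1) reports a max row count of 1.
val oneRow = spark.range(1)

// Sorting at most one row is a no-op, so the Sort node can be dropped.
oneRow.sort("id").explain(true)

// DISTINCT over at most one row is rewritten to a grouping-only Aggregate,
// which the rule can then reduce to a plain Project.
oneRow.distinct().explain(true)
```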

### Why are the changes needed?

The rule should not be applied to a streaming source, since the number of rows it sees reflects only the current microbatch; it does not mean the streaming source will produce at most 1 row during the lifetime of the query.

Consider the case where batch 0 runs with empty data in streaming source A, which triggers the rule against the Aggregate, and batch 1 runs with several rows in streaming source A, which no longer triggers the rule.

In the above scenario, this could fail the query, since a stateful operator is expected to be planned for every batch, whereas here it is planned "selectively".
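To make the scenario concrete, here is a rough sketch of the failing query shape (mirroring the new UT; it assumes a `StreamTest`-style environment where `MemoryStream`'s implicit `SQLContext` and encoders are in scope):

```scala
import org.apache.spark.sql.execution.streaming.MemoryStream

val inputA = MemoryStream[Int]
val inputB = MemoryStream[Int]
inputA.toDS().createOrReplaceTempView("a")
inputB.toDS().createOrReplaceTempView("b")

// SQL DISTINCT is rewritten to a grouping-only Aggregate for each source.
val unioned = spark.sql(
  "SELECT DISTINCT value FROM a UNION ALL SELECT DISTINCT value FROM b")

// Batch 0: if source `a` gets no data, its per-batch plan sees max rows == 0,
// so the rule (before this PR) replaced that Aggregate with a stateless Project.
// Batch 1: source `a` now has rows, the Aggregate is planned as a stateful
// operator again, and the mismatch across batches can fail the query.
```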

### Does this PR introduce _any_ user-facing change?

Yes, but the behavior can be reverted with a new config, `spark.sql.streaming.optimizeOneRowPlan.enabled`, although it should be a really rare case where users have to turn the config on.
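For an existing query that was planned (and checkpointed) under the old behavior, the flag can be set before restarting the query:

```scala
// Fallback only: re-enables OptimizeOneRowPlan against streaming Datasets.
spark.conf.set("spark.sql.streaming.optimizeOneRowPlan.enabled", "true")
```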

### How was this patch tested?

New UT.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #46820 from HeartSaVioR/SPARK-48481.

Authored-by: Jungtaek Lim <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
HeartSaVioR committed Jun 1, 2024
1 parent 114164b commit 1cecdc7
Showing 3 changed files with 107 additions and 4 deletions.
Changes to the OptimizeOneRowPlan rule:

```diff
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.trees.TreePattern._
+import org.apache.spark.sql.internal.SQLConf
 
 /**
  * The rule is applied both normal and AQE Optimizer. It optimizes plan using max rows:
@@ -31,19 +32,37 @@ import org.apache.spark.sql.catalyst.trees.TreePattern._
  *   it's grouping only(include the rewritten distinct plan), convert aggregate to project
  * - if the max rows of the child of aggregate is less than or equal to 1,
  *   set distinct to false in all aggregate expression
+ *
+ * Note: the rule should not be applied to streaming source, since the number of rows it sees is
+ * just for current microbatch. It does not mean the streaming source will ever produce max 1
+ * rows during lifetime of the query. Suppose the case: the streaming query has a case where
+ * batch 0 runs with empty data in streaming source A which triggers the rule with Aggregate,
+ * and batch 1 runs with several data in streaming source A which no longer trigger the rule.
+ * In the above scenario, this could fail the query as stateful operator is expected to be planned
+ * for every batches whereas here it is planned "selectively".
  */
 object OptimizeOneRowPlan extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = {
+    val enableForStreaming = conf.getConf(SQLConf.STREAMING_OPTIMIZE_ONE_ROW_PLAN_ENABLED)
+
     plan.transformUpWithPruning(_.containsAnyPattern(SORT, AGGREGATE), ruleId) {
-      case Sort(_, _, child) if child.maxRows.exists(_ <= 1L) => child
-      case Sort(_, false, child) if child.maxRowsPerPartition.exists(_ <= 1L) => child
-      case agg @ Aggregate(_, _, child) if agg.groupOnly && child.maxRows.exists(_ <= 1L) =>
+      case Sort(_, _, child) if child.maxRows.exists(_ <= 1L) &&
+          isChildEligible(child, enableForStreaming) => child
+      case Sort(_, false, child) if child.maxRowsPerPartition.exists(_ <= 1L) &&
+          isChildEligible(child, enableForStreaming) => child
+      case agg @ Aggregate(_, _, child) if agg.groupOnly && child.maxRows.exists(_ <= 1L) &&
+          isChildEligible(child, enableForStreaming) =>
         Project(agg.aggregateExpressions, child)
-      case agg: Aggregate if agg.child.maxRows.exists(_ <= 1L) =>
+      case agg: Aggregate if agg.child.maxRows.exists(_ <= 1L) &&
+          isChildEligible(agg.child, enableForStreaming) =>
         agg.transformExpressions {
           case aggExpr: AggregateExpression if aggExpr.isDistinct =>
             aggExpr.copy(isDistinct = false)
         }
     }
   }
+
+  private def isChildEligible(child: LogicalPlan, enableForStreaming: Boolean): Boolean = {
+    enableForStreaming || !child.isStreaming
+  }
 }
```
New SQLConf entry:

```diff
@@ -2334,6 +2334,17 @@
     .booleanConf
     .createWithDefault(false)
 
+  val STREAMING_OPTIMIZE_ONE_ROW_PLAN_ENABLED =
+    buildConf("spark.sql.streaming.optimizeOneRowPlan.enabled")
+      .internal()
+      .doc("When true, enable OptimizeOneRowPlan rule for the case where the child is a " +
+        "streaming Dataset. This is a fallback flag to revert the 'incorrect' behavior, hence " +
+        "this configuration must not be used without understanding in depth. Use this only to " +
+        "quickly recover failure in existing query!")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val VARIABLE_SUBSTITUTE_ENABLED =
     buildConf("spark.sql.variable.substitute")
       .doc("This enables substitution using syntax like `${var}`, `${system:var}`, " +
```
New test in StreamingQueryOptimizationCorrectnessSuite:

```diff
@@ -22,6 +22,7 @@ import java.sql.Timestamp
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.functions.{expr, lit, window}
+import org.apache.spark.sql.internal.SQLConf
 
 /**
  * This test ensures that any optimizations done by Spark SQL optimizer are
@@ -451,4 +452,76 @@ class StreamingQueryOptimizationCorrectnessSuite extends StreamTest {
       )
     }
   }
+
+  test("SPARK-48481: DISTINCT with empty stream source should retain AGGREGATE") {
+    def doTest(numExpectedStatefulOperatorsForOneEmptySource: Int): Unit = {
+      withTempView("tv1", "tv2") {
+        val inputStream1 = MemoryStream[Int]
+        val ds1 = inputStream1.toDS()
+        ds1.registerTempTable("tv1")
+
+        val inputStream2 = MemoryStream[Int]
+        val ds2 = inputStream2.toDS()
+        ds2.registerTempTable("tv2")
+
+        // DISTINCT is rewritten to AGGREGATE, hence an AGGREGATEs for each source
+        val unioned = spark.sql(
+          """
+            | WITH u AS (
+            | SELECT DISTINCT value AS value FROM tv1
+            | ), v AS (
+            | SELECT DISTINCT value AS value FROM tv2
+            | )
+            | SELECT value FROM u UNION ALL SELECT value FROM v
+            |""".stripMargin
+        )
+
+        testStream(unioned, OutputMode.Update())(
+          MultiAddData(inputStream1, 1, 1, 2)(inputStream2, 1, 1, 2),
+          CheckNewAnswer(1, 2, 1, 2),
+          Execute { qe =>
+            val stateOperators = qe.lastProgress.stateOperators
+            // Aggregate should be "stateful" one
+            assert(stateOperators.length === 2)
+            stateOperators.zipWithIndex.foreach { case (op, id) =>
+              assert(op.numRowsUpdated === 2, s"stateful OP ID: $id")
+            }
+          },
+          AddData(inputStream2, 2, 2, 3),
+          // NOTE: this is probably far from expectation to have 2 as output given user intends
+          // deduplicate, but the behavior is still correct with rewritten node and output mode:
+          // Aggregate & Update mode.
+          // TODO: Probably we should disallow DISTINCT or rewrite to
+          //  dropDuplicates(WithinWatermark) for streaming source?
+          CheckNewAnswer(2, 3),
+          Execute { qe =>
+            val stateOperators = qe.lastProgress.stateOperators
+            // Aggregate should be "stateful" one
+            assert(stateOperators.length === numExpectedStatefulOperatorsForOneEmptySource)
+            val opWithUpdatedRows = stateOperators.zipWithIndex.filterNot(_._1.numRowsUpdated == 0)
+            assert(opWithUpdatedRows.length === 1)
+            // If this were dropDuplicates, numRowsUpdated should have been 1.
+            assert(opWithUpdatedRows.head._1.numRowsUpdated === 2,
+              s"stateful OP ID: ${opWithUpdatedRows.head._2}")
+          },
+          AddData(inputStream1, 4, 4, 5),
+          CheckNewAnswer(4, 5),
+          Execute { qe =>
+            val stateOperators = qe.lastProgress.stateOperators
+            assert(stateOperators.length === numExpectedStatefulOperatorsForOneEmptySource)
+            val opWithUpdatedRows = stateOperators.zipWithIndex.filterNot(_._1.numRowsUpdated == 0)
+            assert(opWithUpdatedRows.length === 1)
+            assert(opWithUpdatedRows.head._1.numRowsUpdated === 2,
+              s"stateful OP ID: ${opWithUpdatedRows.head._2}")
+          }
+        )
+      }
+    }
+
+    doTest(numExpectedStatefulOperatorsForOneEmptySource = 2)
+
+    withSQLConf(SQLConf.STREAMING_OPTIMIZE_ONE_ROW_PLAN_ENABLED.key -> "true") {
+      doTest(numExpectedStatefulOperatorsForOneEmptySource = 1)
+    }
+  }
 }
```
