From 8e3b9d40dca6d3b6510d9d06859596b37dfb2ef8 Mon Sep 17 00:00:00 2001
From: Jungtaek Lim
Date: Tue, 7 Mar 2023 13:44:43 +0900
Subject: [PATCH] [SPARK-42376][SS] Introduce watermark propagation among operators

### What changes were proposed in this pull request?

This PR proposes to introduce watermark propagation among operators via simulation, which enables the workload of "stream-stream time interval join followed by stateful operator".

As of now, Spark considers all stateful operators to have the same input watermark and output watermark, which is insufficient to handle a stream-stream time interval join. Depending on the join criteria (e.g. `leftTime BETWEEN rightTime - INTERVAL 30 seconds AND rightTime + INTERVAL 40 seconds`), the joined output can be delayed beyond the global watermark. To address this, the join operator should be able to produce a "delayed" watermark for the downstream operator. In other words, Spark has to "propagate" the watermark among operators, flowing from leaf node(s) to the root (going downstream).

This PR introduces a new interface `WatermarkPropagator` which simulates watermark propagation based on the origin watermark value. There are three implementations of this interface:

1. NoOpWatermarkPropagator: Does nothing. This is used for initializing a dummy IncrementalExecution.
2. UseSingleWatermarkPropagator: Uses a single global watermark for late events and eviction. This is used for compatibility mode (`spark.sql.streaming.statefulOperator.allowMultiple` set to `false`).
3. PropagateWatermarkSimulator: Simulates propagation of watermark among operators.

The simulation algorithm used in `PropagateWatermarkSimulator` traverses the physical plan tree in post-order (children first) to calculate (input watermark, output watermark) for all nodes. For each node, the following logic is applied:

- The input watermark for a node is `min(input watermarks from all children)`.
  - Children providing no input watermark are excluded.
  - If there is no valid input watermark from any child, the node is considered to have no input watermark.
- The output watermark for a node is decided as follows:
  - watermark nodes: the origin watermark value (global watermark).
  - stateless nodes: same as the input watermark.
  - stateful nodes: the return value of `op.produceOutputWatermark(input watermark)`. (If there is no input watermark, there is no output watermark.)

Once the algorithm traverses the physical plan tree, the association between stateful operators and their input watermarks is constructed. The association is cached after calculation and reused across microbatches, until Spark determines it will no longer be used.

As mentioned above with `op.produceOutputWatermark()`, this PR also adds a new method `produceOutputWatermark` to StateStoreWriter, which requires each stateful operator to calculate its output watermark from a given input watermark. In most cases, this follows the same criteria as state eviction, as most stateful operators produce output of two different kinds:

1. without buffering (event time > input watermark)
2. with buffering (state)

State eviction happens when the event time exceeds a certain threshold timestamp, which denotes a lower bound of event time values for the output (the output watermark). Since most stateful operators construct the predicate for state eviction based on the watermark in the batch planning phase, they can produce an output watermark once Spark provides an input watermark.
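For illustration, here is a minimal sketch of the kind of query this change enables: a stream-stream time interval join whose output feeds a streaming aggregation in Append mode. The sources and column names (`rate` source, `leftId`, `leftTime`, `rightId`, `rightTime`) are hypothetical and not taken from the test suite; they just mirror the join criteria mentioned above.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, expr, window}

val spark = SparkSession.builder().appName("watermark-propagation-sketch").getOrCreate()

// Two streams, each with its own event time column and watermark.
val left = spark.readStream.format("rate").load()
  .selectExpr("value AS leftId", "timestamp AS leftTime")
  .withWatermark("leftTime", "10 seconds")

val right = spark.readStream.format("rate").load()
  .selectExpr("value AS rightId", "timestamp AS rightTime")
  .withWatermark("rightTime", "10 seconds")

// Time interval join: joined rows may carry event times behind the global watermark,
// so the join operator has to report a "delayed" output watermark downstream.
val joined = left.join(right,
  expr("""
    leftId = rightId AND
    leftTime BETWEEN rightTime - INTERVAL 30 seconds AND rightTime + INTERVAL 40 seconds
  """))

// With watermark propagation, the aggregation receives the join's delayed output watermark
// as its input watermark, so it can finalize windows correctly in Append mode.
val aggregated = joined
  .groupBy(window(col("leftTime"), "1 minute"))
  .count()

aggregated.writeStream
  .format("console")
  .outputMode("append")
  .start()
```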
Please refer to the walkthrough code comment for the test case of `stream-stream time interval left outer join -> aggregation, append mode`.

This PR also makes several additional decisions which introduce backward incompatibility:

1. Re-definition of watermark will be disallowed. Technically, each watermark node could track its own watermark value and PropagateWatermarkSimulator could propagate these values correctly (multiple origins). While this might help accelerate processing of the faster stream (as not all watermarks would need to follow the slowest one until the join/union), it raises more complicated UX questions, since all UX around watermarks is based on the global watermark. This seems harder to address, hence this PR proposes to retain the global watermark as it is. Since we want the watermark to be produced from a single origin value, redefinition of watermark does not make sense. Consider a stream-stream time interval join followed by another watermark node: which is the right output watermark value for that watermark node, the delayed watermark or the global watermark?

2. A stateful operator will no longer allow multiple event time columns to be defined in its input DataFrame. The output of a stream-stream join may have two event time columns, which makes late record filtering and eviction ambiguous. Currently the first event time column that appears is picked up for late record filtering and eviction, which makes correctness hard to reason about. After this PR, Spark will throw an exception; the downstream operator has to pick only one event time column to continue.

Turning off the flag `spark.sql.streaming.statefulOperator.allowMultiple` will restore the old behavior described above. (https://issues.apache.org/jira/browse/SPARK-42549 is filed to remove this limitation later.)

### Why are the changes needed?

A stream-stream time interval join followed by a stateful operator is not supported yet, and this PR unblocks it.

### Does this PR introduce _any_ user-facing change?

Yes, here is the list of user-facing changes (some are backward incompatible, though we have a compatibility flag):

- A stream-stream time-interval join followed by a stateful operator will be allowed.
- Re-definition of watermark will be disallowed.
- A stateful operator will not allow multiple event time columns to be defined in its input DataFrame.

### How was this patch tested?

New & modified test cases.

Closes #39931 from HeartSaVioR/SPARK-42376.
Authored-by: Jungtaek Lim Signed-off-by: Jungtaek Lim --- .../structured-streaming-programming-guide.md | 4 +- .../UnsupportedOperationChecker.scala | 4 - .../plans/logical/EventTimeWatermark.scala | 4 + .../analysis/UnsupportedOperationsSuite.scala | 25 +- .../spark/sql/execution/QueryExecution.scala | 5 +- .../FlatMapGroupsWithStateExec.scala | 9 +- .../streaming/IncrementalExecution.scala | 239 ++++++--- .../streaming/MicroBatchExecution.scala | 8 +- .../StreamingSymmetricHashJoinExec.scala | 28 +- .../StreamingSymmetricHashJoinHelper.scala | 100 +++- .../streaming/WatermarkPropagator.scala | 322 ++++++++++++ .../continuous/ContinuousExecution.scala | 3 +- .../streaming/statefulOperators.scala | 92 +++- .../streaming/EventTimeWatermarkSuite.scala | 131 ++++- .../MultiStatefulOperatorsSuite.scala | 475 ++++++++++++++++-- .../StreamingDeduplicationSuite.scala | 1 - 16 files changed, 1274 insertions(+), 176 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkPropagator.scala diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 2fc742736264f..cf7f0ab6e15d7 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -2067,9 +2067,7 @@ Some of them are as follows. for more details. - Chaining multiple stateful operations on streaming Datasets is not supported with Update and Complete mode. - - In addition, below operations followed by other stateful operation is not supported in Append mode. - - stream-stream time interval join (inner/outer) - - flatMapGroupsWithState + - In addition, mapGroupsWithState/flatMapGroupsWithState operation followed by other stateful operation is not supported in Append mode. - A known workaround is to split your streaming query into multiple queries having a single stateful operation per each query, and ensure end-to-end exactly once per query. Ensuring end-to-end exactly once for the last query is optional. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 06581e23d5854..69ebe09667deb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -21,7 +21,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryComparison, CurrentDate, CurrentTimestampLike, Expression, GreaterThan, GreaterThanOrEqual, GroupingSets, LessThan, LessThanOrEqual, LocalTimestamp, MonotonicallyIncreasingID, SessionWindow} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes @@ -84,9 +83,6 @@ object UnsupportedOperationChecker extends Logging { */ private def ifCannotBeFollowedByStatefulOperation( p: LogicalPlan, outputMode: OutputMode): Boolean = p match { - case ExtractEquiJoinKeys(_, _, _, otherCondition, _, left, right, _) => - left.isStreaming && right.isStreaming && - otherCondition.isDefined && hasRangeExprAgainstEventTimeCol(otherCondition.get) // FlatMapGroupsWithState configured with event time case f @ FlatMapGroupsWithState(_, _, _, _, _, _, _, _, _, timeout, _, _, _, _, _, _) if f.isStreaming && timeout == GroupStateTimeout.EventTimeTimeout => true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala index aaa11b4382ea1..32a9030ff62b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala @@ -45,6 +45,10 @@ case class EventTimeWatermark( final override val nodePatterns: Seq[TreePattern] = Seq(EVENT_TIME_WATERMARK) // Update the metadata on the eventTime column to include the desired delay. + // This is not allowed by default - WatermarkPropagator will throw an exception. We keep the + // logic here because we also maintain the compatibility flag. (See + // SQLConf.STATEFUL_OPERATOR_ALLOW_MULTIPLE for details.) + // TODO: Disallow updating the metadata once we remove the compatibility flag. 
override val output: Seq[Attribute] = child.output.map { a => if (a semanticEquals eventTime) { val delayMs = EventTimeWatermark.getDelayMs(delay) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 64c5ea3f5b19f..f9fd02b86e904 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -541,8 +541,8 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { isMapGroupsWithState = true, GroupStateTimeout.ProcessingTimeTimeout(), streamRelation)), outputMode = Append) - // stream-stream relation, time interval join can't be followed by any stateful operators - assertFailOnGlobalWatermarkLimit( + // stream-stream relation, time interval join can be followed by any stateful operators + assertPassOnGlobalWatermarkLimit( "multiple stateful ops - stream-stream time-interval join followed by agg", Aggregate(Nil, aggExprs("c"), streamRelation.join(streamRelation, joinType = Inner, @@ -550,7 +550,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { attributeWithWatermark > attributeWithWatermark + 10))), outputMode = Append) - // stream-stream relation, only equality join can be followed by any stateful operators + // stream-stream relation, equality join can be followed by any stateful operators assertPassOnGlobalWatermarkLimit( "multiple stateful ops - stream-stream equality join followed by agg", Aggregate(Nil, aggExprs("c"), @@ -601,10 +601,8 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { outputMode = outputMode) } - // Deduplication, if on event time column, is a stateful operator - // and cannot be placed after join - assertFailOnGlobalWatermarkLimit( - "multiple stateful ops - stream-stream time interval join followed by" + + assertPassOnGlobalWatermarkLimit( + "multiple stateful ops - stream-stream time interval join followed by " + "dedup (with event-time)", Deduplicate(Seq(attributeWithWatermark), streamRelation.join(streamRelation, joinType = Inner, @@ -612,11 +610,8 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { attributeWithWatermark > attributeWithWatermark + 10))), outputMode = Append) - // Deduplication, if not on event time column, - // although it is still a stateful operator, - // it can be placed after join assertPassOnGlobalWatermarkLimit( - "multiple stateful ops - stream-stream time interval join followed by" + + "multiple stateful ops - stream-stream time interval join followed by " + "dedup (without event-time)", Deduplicate(Seq(att), streamRelation.join(streamRelation, joinType = Inner, @@ -624,15 +619,11 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { attributeWithWatermark > attributeWithWatermark + 10))), outputMode = Append) - // for a stream-stream join followed by a stateful operator, - // if the join is keyed on time-interval inequality conditions (inequality on watermarked cols), - // should fail. 
- // if the join is keyed on time-interval equality conditions -> should pass Seq(Inner, LeftOuter, RightOuter, FullOuter).foreach { joinType => - assertFailOnGlobalWatermarkLimit( + assertPassOnGlobalWatermarkLimit( s"streaming aggregation after " + - s"stream-stream $joinType join keyed on time inequality in Append mode are not supported", + s"stream-stream $joinType join keyed on time interval in Append mode are not supported", streamRelation.join(streamRelation, joinType = joinType, condition = Some(attributeWithWatermark === attribute && attributeWithWatermark < attributeWithWatermark + 10)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index ea713e390e059..db18cdeaa65b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.execution.bucketing.{CoalesceBucketsInJoin, DisableU import org.apache.spark.sql.execution.dynamicpruning.PlanDynamicPruningFilters import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery -import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata} +import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata, WatermarkPropagator} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.util.Utils @@ -244,7 +244,8 @@ class QueryExecution( // output mode does not matter since there is no `Sink`. new IncrementalExecution( sparkSession, logical, OutputMode.Append(), "", - UUID.randomUUID, UUID.randomUUID, 0, None, OffsetSeqMetadata(0, 0)) + UUID.randomUUID, UUID.randomUUID, 0, None, OffsetSeqMetadata(0, 0), + WatermarkPropagator.noop()) } else { this } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 760681e81c916..d30b9ad116f58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -90,19 +90,24 @@ trait FlatMapGroupsWithStateExecBase override def shortName: String = "flatMapGroupsWithState" - override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = { timeoutConf match { case ProcessingTimeTimeout => true // Always run batches to process timeouts case EventTimeTimeout => // Process another non-data batch only if the watermark has changed in this executed plan eventTimeWatermarkForEviction.isDefined && - newMetadata.batchWatermarkMs > eventTimeWatermarkForEviction.get + newInputWatermark > eventTimeWatermarkForEviction.get case _ => false } } + // There is no guarantee that any of the column in the output is bound to the watermark. The + // user function is quite flexible. Hence Spark does not support the stateful operator(s) after + // (flat)MapGroupsWithState. + override def produceOutputWatermark(inputWatermarkMs: Long): Option[Long] = None + /** * Process data by applying the user defined function on a per partition basis. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index e5e4dc7d0dcb6..8bf3440c8386c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -49,7 +49,8 @@ class IncrementalExecution( val runId: UUID, val currentBatchId: Long, val prevOffsetSeqMetadata: Option[OffsetSeqMetadata], - val offsetSeqMetadata: OffsetSeqMetadata) + val offsetSeqMetadata: OffsetSeqMetadata, + val watermarkPropagator: WatermarkPropagator) extends QueryExecution(sparkSession, logicalPlan) with Logging { // Modified planner with stateful operations. @@ -97,6 +98,9 @@ class IncrementalExecution( } } + private val allowMultipleStatefulOperators: Boolean = + sparkSession.sessionState.conf.getConf(SQLConf.STATEFUL_OPERATOR_ALLOW_MULTIPLE) + /** * Records the current id for a given stateful operator in the query plan as the `state` * preparation walks the query plan. @@ -113,20 +117,31 @@ class IncrementalExecution( numStateStores) } - // Watermarks to use for late record filtering and state eviction in stateful operators. - // Using the previous watermark for late record filtering is a Spark behavior change so we allow - // this to be disabled. - val eventTimeWatermarkForEviction = offsetSeqMetadata.batchWatermarkMs - val eventTimeWatermarkForLateEvents = - if (sparkSession.conf.get(SQLConf.STATEFUL_OPERATOR_ALLOW_MULTIPLE)) { - prevOffsetSeqMetadata.getOrElse(offsetSeqMetadata).batchWatermarkMs - } else { - eventTimeWatermarkForEviction - } + sealed trait SparkPlanPartialRule { + val rule: PartialFunction[SparkPlan, SparkPlan] + } - /** Locates save/restore pairs surrounding aggregation. */ - val state = new Rule[SparkPlan] { + object ShufflePartitionsRule extends SparkPlanPartialRule { + override val rule: PartialFunction[SparkPlan, SparkPlan] = { + // NOTE: we should include all aggregate execs here which are used in streaming aggregations + case a: SortAggregateExec if a.isStreaming => + a.copy(numShufflePartitions = Some(numStateStores)) + + case a: HashAggregateExec if a.isStreaming => + a.copy(numShufflePartitions = Some(numStateStores)) + + case a: ObjectHashAggregateExec if a.isStreaming => + a.copy(numShufflePartitions = Some(numStateStores)) + case a: MergingSessionsExec if a.isStreaming => + a.copy(numShufflePartitions = Some(numStateStores)) + + case a: UpdatingSessionsExec if a.isStreaming => + a.copy(numShufflePartitions = Some(numStateStores)) + } + } + + object ConvertLocalLimitRule extends SparkPlanPartialRule { /** * Ensures that this plan DOES NOT have any stateful operation in it whose pipelined execution * depends on this plan. 
In other words, this function returns true if this plan does @@ -153,33 +168,27 @@ class IncrementalExecution( !statefulOpFound } - override def apply(plan: SparkPlan): SparkPlan = plan transform { - // NOTE: we should include all aggregate execs here which are used in streaming aggregations - case a: SortAggregateExec if a.isStreaming => - a.copy(numShufflePartitions = Some(numStateStores)) - - case a: HashAggregateExec if a.isStreaming => - a.copy(numShufflePartitions = Some(numStateStores)) - - case a: ObjectHashAggregateExec if a.isStreaming => - a.copy(numShufflePartitions = Some(numStateStores)) - - case a: MergingSessionsExec if a.isStreaming => - a.copy(numShufflePartitions = Some(numStateStores)) - - case a: UpdatingSessionsExec if a.isStreaming => - a.copy(numShufflePartitions = Some(numStateStores)) + override val rule: PartialFunction[SparkPlan, SparkPlan] = { + case StreamingLocalLimitExec(limit, child) if hasNoStatefulOp(child) => + // Optimize limit execution by replacing StreamingLocalLimitExec (consumes the iterator + // completely) to LocalLimitExec (does not consume the iterator) when the child plan has + // no stateful operator (i.e., consuming the iterator is not needed). + LocalLimitExec(limit, child) + } + } + object StateOpIdRule extends SparkPlanPartialRule { + override val rule: PartialFunction[SparkPlan, SparkPlan] = { case StateStoreSaveExec(keys, None, None, None, None, stateFormatVersion, - UnaryExecNode(agg, - StateStoreRestoreExec(_, None, _, child))) => + UnaryExecNode(agg, + StateStoreRestoreExec(_, None, _, child))) => val aggStateInfo = nextStatefulOperationStateInfo StateStoreSaveExec( keys, Some(aggStateInfo), Some(outputMode), - eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents), - eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction), + eventTimeWatermarkForLateEvents = None, + eventTimeWatermarkForEviction = None, stateFormatVersion, agg.withNewChildren( StateStoreRestoreExec( @@ -189,35 +198,35 @@ class IncrementalExecution( child) :: Nil)) case SessionWindowStateStoreSaveExec(keys, session, None, None, None, None, - stateFormatVersion, - UnaryExecNode(agg, - SessionWindowStateStoreRestoreExec(_, _, None, None, None, _, child))) => - val aggStateInfo = nextStatefulOperationStateInfo - SessionWindowStateStoreSaveExec( - keys, - session, - Some(aggStateInfo), - Some(outputMode), - eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents), - eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction), - stateFormatVersion, - agg.withNewChildren( - SessionWindowStateStoreRestoreExec( - keys, - session, - Some(aggStateInfo), - eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents), - eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction), - stateFormatVersion, - child) :: Nil)) + stateFormatVersion, + UnaryExecNode(agg, + SessionWindowStateStoreRestoreExec(_, _, None, None, None, _, child))) => + val aggStateInfo = nextStatefulOperationStateInfo + SessionWindowStateStoreSaveExec( + keys, + session, + Some(aggStateInfo), + Some(outputMode), + eventTimeWatermarkForLateEvents = None, + eventTimeWatermarkForEviction = None, + stateFormatVersion, + agg.withNewChildren( + SessionWindowStateStoreRestoreExec( + keys, + session, + Some(aggStateInfo), + eventTimeWatermarkForLateEvents = None, + eventTimeWatermarkForEviction = None, + stateFormatVersion, + child) :: Nil)) case StreamingDeduplicateExec(keys, child, None, None, None) => StreamingDeduplicateExec( keys, child, 
Some(nextStatefulOperationStateInfo), - eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents), - eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction)) + eventTimeWatermarkForLateEvents = None, + eventTimeWatermarkForEviction = None) case m: FlatMapGroupsWithStateExec => // We set this to true only for the first batch of the streaming query. @@ -225,8 +234,8 @@ class IncrementalExecution( m.copy( stateInfo = Some(nextStatefulOperationStateInfo), batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs), - eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents), - eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction), + eventTimeWatermarkForLateEvents = None, + eventTimeWatermarkForEviction = None, hasInitialState = hasInitialState ) @@ -234,30 +243,108 @@ class IncrementalExecution( m.copy( stateInfo = Some(nextStatefulOperationStateInfo), batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs), - eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents), - eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction) + eventTimeWatermarkForLateEvents = None, + eventTimeWatermarkForEviction = None ) case j: StreamingSymmetricHashJoinExec => j.copy( stateInfo = Some(nextStatefulOperationStateInfo), - eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents), - eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction), - stateWatermarkPredicates = - StreamingSymmetricHashJoinHelper.getStateWatermarkPredicates( - j.left.output, j.right.output, j.leftKeys, j.rightKeys, j.condition.full, - Some(eventTimeWatermarkForEviction))) + eventTimeWatermarkForLateEvents = None, + eventTimeWatermarkForEviction = None + ) case l: StreamingGlobalLimitExec => l.copy( stateInfo = Some(nextStatefulOperationStateInfo), outputMode = Some(outputMode)) + } + } - case StreamingLocalLimitExec(limit, child) if hasNoStatefulOp(child) => - // Optimize limit execution by replacing StreamingLocalLimitExec (consumes the iterator - // completely) to LocalLimitExec (does not consume the iterator) when the child plan has - // no stateful operator (i.e., consuming the iterator is not needed). 
- LocalLimitExec(limit, child) + object WatermarkPropagationRule extends SparkPlanPartialRule { + private def inputWatermarkForLateEvents(stateInfo: StatefulOperatorStateInfo): Option[Long] = { + Some(watermarkPropagator.getInputWatermarkForLateEvents(currentBatchId, + stateInfo.operatorId)) + } + + private def inputWatermarkForEviction(stateInfo: StatefulOperatorStateInfo): Option[Long] = { + Some(watermarkPropagator.getInputWatermarkForEviction(currentBatchId, stateInfo.operatorId)) + } + + override val rule: PartialFunction[SparkPlan, SparkPlan] = { + case s: StateStoreSaveExec if s.stateInfo.isDefined => + s.copy( + eventTimeWatermarkForLateEvents = inputWatermarkForLateEvents(s.stateInfo.get), + eventTimeWatermarkForEviction = inputWatermarkForEviction(s.stateInfo.get) + ) + + case s: SessionWindowStateStoreSaveExec if s.stateInfo.isDefined => + s.copy( + eventTimeWatermarkForLateEvents = inputWatermarkForLateEvents(s.stateInfo.get), + eventTimeWatermarkForEviction = inputWatermarkForEviction(s.stateInfo.get) + ) + + case s: SessionWindowStateStoreRestoreExec if s.stateInfo.isDefined => + s.copy( + eventTimeWatermarkForLateEvents = inputWatermarkForLateEvents(s.stateInfo.get), + eventTimeWatermarkForEviction = inputWatermarkForEviction(s.stateInfo.get) + ) + + case s: StreamingDeduplicateExec if s.stateInfo.isDefined => + s.copy( + eventTimeWatermarkForLateEvents = inputWatermarkForLateEvents(s.stateInfo.get), + eventTimeWatermarkForEviction = inputWatermarkForEviction(s.stateInfo.get) + ) + + case m: FlatMapGroupsWithStateExec if m.stateInfo.isDefined => + m.copy( + eventTimeWatermarkForLateEvents = inputWatermarkForLateEvents(m.stateInfo.get), + eventTimeWatermarkForEviction = inputWatermarkForEviction(m.stateInfo.get) + ) + + case m: FlatMapGroupsInPandasWithStateExec if m.stateInfo.isDefined => + m.copy( + eventTimeWatermarkForLateEvents = inputWatermarkForLateEvents(m.stateInfo.get), + eventTimeWatermarkForEviction = inputWatermarkForEviction(m.stateInfo.get) + ) + + case j: StreamingSymmetricHashJoinExec => + val iwLateEvents = inputWatermarkForLateEvents(j.stateInfo.get) + val iwEviction = inputWatermarkForEviction(j.stateInfo.get) + j.copy( + eventTimeWatermarkForLateEvents = iwLateEvents, + eventTimeWatermarkForEviction = iwEviction, + stateWatermarkPredicates = + StreamingSymmetricHashJoinHelper.getStateWatermarkPredicates( + j.left.output, j.right.output, j.leftKeys, j.rightKeys, j.condition.full, + iwEviction, !allowMultipleStatefulOperators) + ) + } + } + + val state = new Rule[SparkPlan] { + private def simulateWatermarkPropagation(plan: SparkPlan): Unit = { + val watermarkForPrevBatch = prevOffsetSeqMetadata.map(_.batchWatermarkMs).getOrElse(0L) + val watermarkForCurrBatch = offsetSeqMetadata.batchWatermarkMs + + // This is to simulate watermark propagation for late events. + watermarkPropagator.propagate(currentBatchId - 1, plan, watermarkForPrevBatch) + // This is to simulate watermark propagation for eviction. + watermarkPropagator.propagate(currentBatchId, plan, watermarkForCurrBatch) + } + + private lazy val composedRule: PartialFunction[SparkPlan, SparkPlan] = { + // There should be no same pattern across rules in the list. 
+ val rulesToCompose = Seq(ShufflePartitionsRule, ConvertLocalLimitRule, StateOpIdRule) + .map(_.rule) + + rulesToCompose.reduceLeft { (ruleA, ruleB) => ruleA orElse ruleB } + } + + override def apply(plan: SparkPlan): SparkPlan = { + val planWithStateOpId = plan transform composedRule + simulateWatermarkPropagation(planWithStateOpId) + planWithStateOpId transform WatermarkPropagationRule.rule } } @@ -269,10 +356,18 @@ class IncrementalExecution( /** * Should the MicroBatchExecution run another batch based on this execution and the current * updated metadata. + * + * This method performs simulation of watermark propagation against new batch (which is not + * planned yet), which is required for asking the needs of another batch to each stateful + * operator. */ def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + val tentativeBatchId = currentBatchId + 1 + watermarkPropagator.propagate(tentativeBatchId, executedPlan, newMetadata.batchWatermarkMs) executedPlan.collect { - case p: StateStoreWriter => p.shouldRunAnotherBatch(newMetadata) + case p: StateStoreWriter => p.shouldRunAnotherBatch( + watermarkPropagator.getInputWatermarkForEviction(tentativeBatchId, + p.stateInfo.get.operatorId)) }.exists(_ == true) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 5aece36e2f025..65a7032814896 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -212,6 +212,8 @@ class MicroBatchExecution( logInfo(s"Query $prettyIdString was stopped") } + private val watermarkPropagator = WatermarkPropagator(sparkSession.sessionState.conf) + override def cleanup(): Unit = { super.cleanup() @@ -713,7 +715,8 @@ class MicroBatchExecution( runId, currentBatchId, offsetLog.offsetSeqMetadataForBatchId(currentBatchId - 1), - offsetSeqMetadata) + offsetSeqMetadata, + watermarkPropagator) lastExecution.executedPlan // Force the lazy generation of execution plan } @@ -789,6 +792,9 @@ class MicroBatchExecution( val prevBatchOff = offsetLog.get(currentBatchId - 1) if (prevBatchOff.isDefined) { commitSources(prevBatchOff.get) + // The watermark for each batch is given as (prev. watermark, curr. watermark), hence + // we can't purge the previous version of watermark. 
+ watermarkPropagator.purge(currentBatchId - 2) } else { throw new IllegalStateException(s"batch ${currentBatchId - 1} doesn't exist") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index dfde4156812b1..2445dcd519ece 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -23,14 +23,13 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericInternalRow, JoinedRow, Literal, Predicate, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._ import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.execution.streaming.state.SymmetricHashJoinStateManager.KeyToValuePair -import org.apache.spark.sql.internal.SessionState +import org.apache.spark.sql.internal.{SessionState, SQLConf} import org.apache.spark.util.{CompletionIterator, SerializableConfiguration} @@ -190,6 +189,8 @@ case class StreamingSymmetricHashJoinExec( private val hadoopConfBcast = sparkContext.broadcast( new SerializableConfiguration(SessionState.newHadoopConf( sparkContext.hadoopConfiguration, conf))) + private val allowMultipleStatefulOperators = + conf.getConf(SQLConf.STATEFUL_OPERATOR_ALLOW_MULTIPLE) val nullLeft = new GenericInternalRow(left.output.map(_.withNullability(true)).length) val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length) @@ -219,14 +220,14 @@ case class StreamingSymmetricHashJoinExec( override def shortName: String = "symmetricHashJoin" - override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = { val watermarkUsedForStateCleanup = stateWatermarkPredicates.left.nonEmpty || stateWatermarkPredicates.right.nonEmpty // Latest watermark value is more than that used in this previous executed plan val watermarkHasChanged = eventTimeWatermarkForEviction.isDefined && - newMetadata.batchWatermarkMs > eventTimeWatermarkForEviction.get + newInputWatermark > eventTimeWatermarkForEviction.get watermarkUsedForStateCleanup && watermarkHasChanged } @@ -544,6 +545,8 @@ case class StreamingSymmetricHashJoinExec( } private[this] var updatedStateRowsCount = 0 + private[this] val allowMultipleStatefulOperators: Boolean = + conf.getConf(SQLConf.STATEFUL_OPERATOR_ALLOW_MULTIPLE) /** * Generate joined rows by consuming input from this side, and matching it with the buffered @@ -557,7 +560,8 @@ case class StreamingSymmetricHashJoinExec( generateJoinedRow: (InternalRow, InternalRow) => JoinedRow) : Iterator[InternalRow] = { - val watermarkAttribute = inputAttributes.find(_.metadata.contains(delayKey)) + val watermarkAttribute = WatermarkSupport.findEventTimeColumn(inputAttributes, + allowMultipleEventTimeColumns = !allowMultipleStatefulOperators) val nonLateRows = 
WatermarkSupport.watermarkExpression( watermarkAttribute, eventTimeWatermarkForLateEvents) match { @@ -699,4 +703,18 @@ case class StreamingSymmetricHashJoinExec( } else { Nil } + + // This operator will evict based on the state watermark on both side of inputs; we would like + // to let users leverage both sides of event time column for output of join, so the watermark + // must be lower bound of both sides of event time column. The lower bound of event time column + // for each side is determined by state watermark, hence we take a minimum of (left state + // watermark, right state watermark, input watermark) to decide the output watermark. + override def produceOutputWatermark(inputWatermarkMs: Long): Option[Long] = { + val (leftStateWatermark, rightStateWatermark) = + StreamingSymmetricHashJoinHelper.getStateWatermark( + left.output, right.output, leftKeys, rightKeys, condition.full, Some(inputWatermarkMs), + !allowMultipleStatefulOperators) + + Some((leftStateWatermark ++ rightStateWatermark ++ Some(inputWatermarkMs)).min) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala index 7bf6381e08ffe..49e1f5e8ba12a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala @@ -130,36 +130,74 @@ object StreamingSymmetricHashJoinHelper extends Logging { } } - /** Get the predicates defining the state watermarks for both sides of the join */ - def getStateWatermarkPredicates( + def getStateWatermark( leftAttributes: Seq[Attribute], rightAttributes: Seq[Attribute], leftKeys: Seq[Expression], rightKeys: Seq[Expression], condition: Option[Expression], - eventTimeWatermarkForEviction: Option[Long]): JoinStateWatermarkPredicates = { + eventTimeWatermarkForEviction: Option[Long], + allowMultipleEventTimeColumns: Boolean): (Option[Long], Option[Long]) = { + // Perform assertions against multiple event time columns in the same DataFrame. This method + // assumes there is only one event time column per each side (left / right) and it is not very + // clear to reason about the correctness if there are multiple event time columns. Disallow to + // be conservative. + WatermarkSupport.findEventTimeColumn(leftAttributes, + allowMultipleEventTimeColumns = allowMultipleEventTimeColumns) + WatermarkSupport.findEventTimeColumn(rightAttributes, + allowMultipleEventTimeColumns = allowMultipleEventTimeColumns) - // Join keys of both sides generate rows of the same fields, that is, same sequence of data - // types. If one side (say left side) has a column (say timestamp) that has a watermark on it, - // then it will never consider joining keys that are < state key watermark (i.e. event time - // watermark). On the other side (i.e. right side), even if there is no watermark defined, - // there has to be an equivalent column (i.e., timestamp). And any right side data that has the - // timestamp < watermark will not match will not match with left side data, as the left side get - // filtered with the explicitly defined watermark. So, the watermark in timestamp column in - // left side keys effectively causes the timestamp on the right side to have a watermark. 
- // We will use the ordinal of the left timestamp in the left keys to find the corresponding - // right timestamp in the right keys. - val joinKeyOrdinalForWatermark: Option[Int] = { - leftKeys.zipWithIndex.collectFirst { - case (ne: NamedExpression, index) if ne.metadata.contains(delayKey) => index - } orElse { - rightKeys.zipWithIndex.collectFirst { - case (ne: NamedExpression, index) if ne.metadata.contains(delayKey) => index - } + val joinKeyOrdinalForWatermark: Option[Int] = findJoinKeyOrdinalForWatermark( + leftKeys, rightKeys) + + def getOneSideStateWatermark( + oneSideInputAttributes: Seq[Attribute], + otherSideInputAttributes: Seq[Attribute]): Option[Long] = { + val isWatermarkDefinedOnInput = oneSideInputAttributes.exists(_.metadata.contains(delayKey)) + val isWatermarkDefinedOnJoinKey = joinKeyOrdinalForWatermark.isDefined + + if (isWatermarkDefinedOnJoinKey) { // case 1 and 3 in the StreamingSymmetricHashJoinExec docs + eventTimeWatermarkForEviction + } else if (isWatermarkDefinedOnInput) { // case 2 in the StreamingSymmetricHashJoinExec docs + StreamingJoinHelper.getStateValueWatermark( + attributesToFindStateWatermarkFor = AttributeSet(oneSideInputAttributes), + attributesWithEventWatermark = AttributeSet(otherSideInputAttributes), + condition, + eventTimeWatermarkForEviction) + } else { + None } } + val leftStateWatermark = getOneSideStateWatermark(leftAttributes, rightAttributes) + val rightStateWatermark = getOneSideStateWatermark(rightAttributes, leftAttributes) + + (leftStateWatermark, rightStateWatermark) + } + + /** Get the predicates defining the state watermarks for both sides of the join */ + def getStateWatermarkPredicates( + leftAttributes: Seq[Attribute], + rightAttributes: Seq[Attribute], + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + condition: Option[Expression], + eventTimeWatermarkForEviction: Option[Long], + useFirstEventTimeColumn: Boolean): JoinStateWatermarkPredicates = { + + // Perform assertions against multiple event time columns in the same DataFrame. This method + // assumes there is only one event time column per each side (left / right) and it is not very + // clear to reason about the correctness if there are multiple event time columns. Disallow to + // be conservative. + WatermarkSupport.findEventTimeColumn(leftAttributes, + allowMultipleEventTimeColumns = useFirstEventTimeColumn) + WatermarkSupport.findEventTimeColumn(rightAttributes, + allowMultipleEventTimeColumns = useFirstEventTimeColumn) + + val joinKeyOrdinalForWatermark: Option[Int] = findJoinKeyOrdinalForWatermark( + leftKeys, rightKeys) + def getOneSideStateWatermarkPredicate( oneSideInputAttributes: Seq[Attribute], oneSideJoinKeys: Seq[Expression], @@ -197,6 +235,28 @@ object StreamingSymmetricHashJoinHelper extends Logging { JoinStateWatermarkPredicates(leftStateWatermarkPredicate, rightStateWatermarkPredicate) } + private def findJoinKeyOrdinalForWatermark( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression]): Option[Int] = { + // Join keys of both sides generate rows of the same fields, that is, same sequence of data + // types. If one side (say left side) has a column (say timestamp) that has a watermark on it, + // then it will never consider joining keys that are < state key watermark (i.e. event time + // watermark). On the other side (i.e. right side), even if there is no watermark defined, + // there has to be an equivalent column (i.e., timestamp). 
And any right side data that has the + // timestamp < watermark will not match will not match with left side data, as the left side get + // filtered with the explicitly defined watermark. So, the watermark in timestamp column in + // left side keys effectively causes the timestamp on the right side to have a watermark. + // We will use the ordinal of the left timestamp in the left keys to find the corresponding + // right timestamp in the right keys. + leftKeys.zipWithIndex.collectFirst { + case (ne: NamedExpression, index) if ne.metadata.contains(delayKey) => index + } orElse { + rightKeys.zipWithIndex.collectFirst { + case (ne: NamedExpression, index) if ne.metadata.contains(delayKey) => index + } + } + } + /** * A custom RDD that allows partitions to be "zipped" together, while ensuring the tasks' * preferred location is based on which executors have the required join state stores already diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkPropagator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkPropagator.scala new file mode 100644 index 0000000000000..6f3725bebb9ab --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkPropagator.scala @@ -0,0 +1,322 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.{util => jutil} + +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Utils + +/** + * Interface for propagating watermark. The implementation is not required to be thread-safe, + * as all methods are expected to be called from the query execution thread. + * (The guarantee may change on further improvements on Structured Streaming - update + * implementations if we change the guarantee.) + */ +sealed trait WatermarkPropagator { + /** + * Request to propagate watermark among operators based on origin watermark value. The result + * should be input watermark per stateful operator, which Spark will request the value by calling + * getInputWatermarkXXX with operator ID. + * + * It is recommended for implementation to cache the result, as Spark can request the propagation + * multiple times with the same batch ID and origin watermark value. + */ + def propagate(batchId: Long, plan: SparkPlan, originWatermark: Long): Unit + + /** Provide the calculated input watermark for late events for given stateful operator. */ + def getInputWatermarkForLateEvents(batchId: Long, stateOpId: Long): Long + + /** Provide the calculated input watermark for eviction for given stateful operator. 
*/ + def getInputWatermarkForEviction(batchId: Long, stateOpId: Long): Long + + /** + * Request to clean up cached result on propagation. Spark will call this method when the given + * batch ID will be likely to be not re-executed. + */ + def purge(batchId: Long): Unit +} + +/** + * Do nothing. This is dummy implementation to help creating a dummy IncrementalExecution instance. + */ +object NoOpWatermarkPropagator extends WatermarkPropagator { + def propagate(batchId: Long, plan: SparkPlan, originWatermark: Long): Unit = {} + def getInputWatermarkForLateEvents(batchId: Long, stateOpId: Long): Long = Long.MinValue + def getInputWatermarkForEviction(batchId: Long, stateOpId: Long): Long = Long.MinValue + def purge(batchId: Long): Unit = {} +} + +/** + * This implementation uses a single global watermark for late events and eviction. + * + * This implementation provides the behavior before Structured Streaming supports multiple stateful + * operators. (prior to SPARK-40925) This is only used for compatibility mode. + */ +class UseSingleWatermarkPropagator extends WatermarkPropagator { + // We use treemap to sort the key (batchID) and evict old batch IDs efficiently. + private val batchIdToWatermark: jutil.TreeMap[Long, Long] = new jutil.TreeMap[Long, Long]() + + private def isInitialized(batchId: Long): Boolean = batchIdToWatermark.containsKey(batchId) + + override def propagate(batchId: Long, plan: SparkPlan, originWatermark: Long): Unit = { + if (batchId < 0) { + // no-op + } else if (isInitialized(batchId)) { + val cached = batchIdToWatermark.get(batchId) + assert(cached == originWatermark, + s"Watermark has been changed for the same batch ID! Batch ID: $batchId, " + + s"Value in cache: $cached, value given: $originWatermark") + } else { + batchIdToWatermark.put(batchId, originWatermark) + } + } + + private def getInputWatermark(batchId: Long, stateOpId: Long): Long = { + if (batchId < 0) { + 0 + } else { + assert(isInitialized(batchId), s"Watermark for batch ID $batchId is not yet set!") + batchIdToWatermark.get(batchId) + } + } + + def getInputWatermarkForLateEvents(batchId: Long, stateOpId: Long): Long = + getInputWatermark(batchId, stateOpId) + + def getInputWatermarkForEviction(batchId: Long, stateOpId: Long): Long = + getInputWatermark(batchId, stateOpId) + + override def purge(batchId: Long): Unit = { + val keyIter = batchIdToWatermark.keySet().iterator() + var stopIter = false + while (keyIter.hasNext && !stopIter) { + val currKey = keyIter.next() + if (currKey <= batchId) { + keyIter.remove() + } else { + stopIter = true + } + } + } +} + +/** + * This implementation simulates propagation of watermark among operators. + * + * The simulation algorithm traverses the physical plan tree via post-order (children first) to + * calculate (input watermark, output watermark) for all nodes. + * + * For each node, below logic is applied: + * + * - Input watermark for specific node is decided by `min(input watermarks from all children)`. + * -- Children providing no input watermark (DEFAULT_WATERMARK_MS) are excluded. + * -- If there is no valid input watermark from children, input watermark = DEFAULT_WATERMARK_MS. + * - Output watermark for specific node is decided as following: + * -- watermark nodes: origin watermark value + * This could be individual origin watermark value, but we decide to retain global watermark + * to keep the watermark model be simple. 
+ * -- stateless nodes: same as input watermark + * -- stateful nodes: the return value of `op.produceOutputWatermark(input watermark)`. + * + * @see [[StateStoreWriter.produceOutputWatermark]] + * + * Note that this implementation will throw an exception if watermark node sees a valid input + * watermark from children, meaning that we do not support re-definition of watermark. + * + * Once the algorithm traverses the physical plan tree, the association between stateful operator + * and input watermark will be constructed. Spark will request the input watermark for specific + * stateful operator, which this implementation will give the value from the association. + * + * We skip simulation of propagation for the value of watermark as 0. Input watermark for every + * operator will be 0. (This may not be expected for the case op.produceOutputWatermark returns + * higher than the input watermark, but it won't happen in most practical cases.) + */ +class PropagateWatermarkSimulator extends WatermarkPropagator with Logging { + // We use treemap to sort the key (batchID) and evict old batch IDs efficiently. + private val batchIdToWatermark: jutil.TreeMap[Long, Long] = new jutil.TreeMap[Long, Long]() + + // contains the association for batchId -> (stateful operator ID -> input watermark) + private val inputWatermarks: mutable.Map[Long, Map[Long, Option[Long]]] = + mutable.Map[Long, Map[Long, Option[Long]]]() + + private def isInitialized(batchId: Long): Boolean = batchIdToWatermark.containsKey(batchId) + + /** + * Retrieve the available input watermarks for specific node in the plan. Every child will + * produce an output watermark, which we capture as input watermark. If the child provides + * default watermark value (no watermark info), it is excluded. + */ + private def getInputWatermarks( + node: SparkPlan, + nodeToOutputWatermark: mutable.Map[Int, Option[Long]]): Seq[Long] = { + node.children.flatMap { child => + nodeToOutputWatermark.getOrElse(child.id, { + throw new IllegalStateException( + s"watermark for the node ${child.id} should be registered") + }) + // Since we use flatMap here, this will exclude children from watermark calculation + // which don't have watermark information. + } + } + + private def doSimulate(batchId: Long, plan: SparkPlan, originWatermark: Long): Unit = { + val statefulOperatorIdToNodeId = mutable.HashMap[Long, Int]() + val nodeToOutputWatermark = mutable.HashMap[Int, Option[Long]]() + val nextStatefulOperatorToWatermark = mutable.HashMap[Long, Option[Long]]() + + // This calculation relies on post-order traversal of the query plan. + plan.transformUp { + case node: EventTimeWatermarkExec => + val inputWatermarks = getInputWatermarks(node, nodeToOutputWatermark) + if (inputWatermarks.nonEmpty) { + throw new AnalysisException("Redefining watermark is disallowed. You can set the " + + s"config '${SQLConf.STATEFUL_OPERATOR_ALLOW_MULTIPLE.key}' to 'false' to restore " + + "the previous behavior. Note that multiple stateful operators will be disallowed.") + } + + nodeToOutputWatermark.put(node.id, Some(originWatermark)) + node + + case node: StateStoreWriter => + val stOpId = node.stateInfo.get.operatorId + statefulOperatorIdToNodeId.put(stOpId, node.id) + + val inputWatermarks = getInputWatermarks(node, nodeToOutputWatermark) + + val finalInputWatermarkMs = if (inputWatermarks.nonEmpty) { + Some(inputWatermarks.min) + } else { + // We can't throw exception here, as we allow stateful operator to process without + // watermark. E.g. 
streaming aggregation with update/complete mode. + None + } + + val outputWatermarkMs = finalInputWatermarkMs.flatMap { wm => + node.produceOutputWatermark(wm) + } + nodeToOutputWatermark.put(node.id, outputWatermarkMs) + nextStatefulOperatorToWatermark.put(stOpId, finalInputWatermarkMs) + node + + case node => + // pass-through, but also consider multiple children like the case of union + val inputWatermarks = getInputWatermarks(node, nodeToOutputWatermark) + val finalInputWatermarkMs = if (inputWatermarks.nonEmpty) { + Some(inputWatermarks.min) + } else { + None + } + + nodeToOutputWatermark.put(node.id, finalInputWatermarkMs) + node + } + + inputWatermarks.put(batchId, nextStatefulOperatorToWatermark.toMap) + batchIdToWatermark.put(batchId, originWatermark) + + logDebug(s"global watermark for batch ID $batchId is set to $originWatermark") + logDebug(s"input watermarks for batch ID $batchId is set to $nextStatefulOperatorToWatermark") + } + + override def propagate(batchId: Long, plan: SparkPlan, originWatermark: Long): Unit = { + if (batchId < 0) { + // no-op + } else if (isInitialized(batchId)) { + val cached = batchIdToWatermark.get(batchId) + assert(cached == originWatermark, + s"Watermark has been changed for the same batch ID! Batch ID: $batchId, " + + s"Value in cache: $cached, value given: $originWatermark") + } else { + logDebug(s"watermark for batch ID $batchId is received as $originWatermark, " + + s"call site: ${Utils.getCallSite().longForm}") + + if (originWatermark == 0) { + logDebug(s"skipping the propagation for batch $batchId as origin watermark is 0.") + batchIdToWatermark.put(batchId, 0L) + inputWatermarks.put(batchId, Map.empty[Long, Option[Long]]) + } else { + doSimulate(batchId, plan, originWatermark) + } + } + } + + private def getInputWatermark(batchId: Long, stateOpId: Long): Long = { + if (batchId < 0) { + 0 + } else { + assert(isInitialized(batchId), s"Watermark for batch ID $batchId is not yet set!") + inputWatermarks(batchId).get(stateOpId) match { + case Some(wmOpt) => + // In current Spark's logic, event time watermark cannot go down to negative. So even + // there is no input watermark for operator, the final input watermark for operator should + // be 0L. + Math.max(wmOpt.getOrElse(0L), 0L) + case None => + if (batchIdToWatermark.get(batchId) == 0L) { + // We skip the propagation when the origin watermark is produced as 0L. This is safe, + // as output watermark is not expected to be later than the input watermark. That said, + // all operators would have the input watermark as 0L. + 0L + } else { + throw new IllegalStateException(s"Watermark for batch ID $batchId and " + + s"stateOpId $stateOpId is not yet set!") + } + } + } + } + + override def getInputWatermarkForLateEvents(batchId: Long, stateOpId: Long): Long = { + // We use watermark for previous microbatch to determine late events. 
+ getInputWatermark(batchId - 1, stateOpId) + } + + override def getInputWatermarkForEviction(batchId: Long, stateOpId: Long): Long = + getInputWatermark(batchId, stateOpId) + + override def purge(batchId: Long): Unit = { + val keyIter = batchIdToWatermark.keySet().iterator() + var stopIter = false + while (keyIter.hasNext && !stopIter) { + val currKey = keyIter.next() + if (currKey <= batchId) { + keyIter.remove() + inputWatermarks.remove(currKey) + } else { + stopIter = true + } + } + } +} + +object WatermarkPropagator { + def apply(conf: SQLConf): WatermarkPropagator = { + if (conf.getConf(SQLConf.STATEFUL_OPERATOR_ALLOW_MULTIPLE)) { + new PropagateWatermarkSimulator + } else { + new UseSingleWatermarkPropagator + } + } + + def noop(): WatermarkPropagator = NoOpWatermarkPropagator +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index e8092e072bc22..58119b74f5a2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -219,7 +219,8 @@ class ContinuousExecution( runId, currentBatchId, None, - offsetSeqMetadata) + offsetSeqMetadata, + WatermarkPropagator.noop()) lastExecution.executedPlan // Force the lazy generation of execution plan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 457e5f80ae6bb..49bb8607128f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -25,6 +25,7 @@ import scala.collection.mutable import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection @@ -36,6 +37,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.python.PythonSQLMetrics import org.apache.spark.sql.execution.streaming.state._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.{OutputMode, StateOperatorProgress} import org.apache.spark.sql.types._ import org.apache.spark.util.{CompletionIterator, NextIterator, Utils} @@ -96,6 +98,39 @@ trait StateStoreReader extends StatefulOperator { /** An operator that writes to a StateStore. */ trait StateStoreWriter extends StatefulOperator with PythonSQLMetrics { self: SparkPlan => + /** + * Produce the output watermark for given input watermark (ms). + * + * In most cases, this is same as the criteria of state eviction, as most stateful operators + * produce the output from two different kinds: + * + * 1. without buffering + * 2. with buffering (state) + * + * The state eviction happens when event time exceeds a "certain threshold of timestamp", which + * denotes a lower bound of event time values for output (output watermark). + * + * The default implementation provides the input watermark as it is. 
Most built-in operators + * will evict based on the min input watermark and ensure it is a lower bound of the event time + * values in the output so far (including output produced by eviction). Operators which behave + * differently (e.g. different criteria on eviction) must override this method. + * + * Note that the default behavior will advance the watermark aggressively to simplify the logic, + * but it does not break the semantic of the output watermark, which is the following: + * + * An operator guarantees that it will not emit a record with an event timestamp lower than its + * output watermark. + * + * For example, for a 5 minute time window aggregation, the advancement of the watermark can + * happen "before" the window has been evicted and produced as output. Say, suppose there's a + * window in state: [0, 5) and the input watermark = 3. Although there is no output for this + * operator yet, this operator will produce an output watermark of 3. It's still respecting the + * guarantee, as the operator will produce the window [0, 5) only when the output watermark is + * equal to or greater than 5, and the downstream operator will process the input data, "and + * then" advance the watermark. Hence this window is considered a "non-late" record. + */ + def produceOutputWatermark(inputWatermarkMs: Long): Option[Long] = Some(inputWatermarkMs) + override lazy val metrics = statefulOperatorCustomMetrics ++ Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "numRowsDroppedByWatermark" -> SQLMetrics.createMetric(sparkContext, @@ -199,9 +234,9 @@ trait StateStoreWriter extends StatefulOperator with PythonSQLMetrics { self: Sp /** * Should the MicroBatchExecution run another batch based on this stateful operator and the - * current updated metadata. + * new input watermark. */ - def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = false + def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = false } /** An operator that supports watermark. */ @@ -234,10 +269,14 @@ trait WatermarkSupport extends SparkPlan { lazy val watermarkExpressionForEviction: Option[Expression] = watermarkExpression(eventTimeWatermarkForEviction) + lazy val allowMultipleStatefulOperators: Boolean = + conf.getConf(SQLConf.STATEFUL_OPERATOR_ALLOW_MULTIPLE) + /** Generate an expression that matches data older than the watermark */ private def watermarkExpression(watermark: Option[Long]): Option[Expression] = { WatermarkSupport.watermarkExpression( - child.output.find(_.metadata.contains(EventTimeWatermark.delayKey)), watermark) + WatermarkSupport.findEventTimeColumn(child.output, + allowMultipleEventTimeColumns = !allowMultipleStatefulOperators), watermark) } /** Predicate based on keys that matches data older than the late event filtering watermark */ @@ -324,6 +363,41 @@ object WatermarkSupport { } Some(evictionExpression) } + + /** + * Find the column which is marked as the "event time" column. + * + * If there are multiple event time columns in the given column list, the behavior depends on + * the parameter `allowMultipleEventTimeColumns`. If it's set to true, the first occurring column + * will be returned. If not, this method will throw an AnalysisException as it is not allowed to + * have multiple event time columns.
+ */ + def findEventTimeColumn( + attrs: Seq[Attribute], + allowMultipleEventTimeColumns: Boolean): Option[Attribute] = { + val eventTimeCols = attrs.filter(_.metadata.contains(EventTimeWatermark.delayKey)) + if (!allowMultipleEventTimeColumns) { + // There is a case where a projection leads the same column (same exprId) to appear more than + // once. Allowing duplicates does not hurt the correctness of state row eviction, hence let's + // start with allowing them. + val eventTimeColsSet = eventTimeCols.map(_.exprId).toSet + if (eventTimeColsSet.size > 1) { + throw new AnalysisException("More than one event time columns are available. Please " + + "ensure there is at most one event time column per stream. event time columns: " + + eventTimeCols.mkString("(", ",", ")")) + } + + // With the above check, even if there are multiple columns in eventTimeCols, all columns + // must be the same. + } else { + // This is for compatibility with the previous behavior - we allow multiple distinct event + // time columns and pick up the first occurrence. This is incorrect if a non-first occurrence + // is not smaller than the first one, but we allow it as an "escape hatch" in case we break + // an existing query. + } + // pick the first element if it exists + eventTimeCols.headOption + } } /** @@ -545,10 +619,10 @@ case class StateStoreSaveExec( override def shortName: String = "stateStoreSave" - override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = { (outputMode.contains(Append) || outputMode.contains(Update)) && eventTimeWatermarkForEviction.isDefined && - newMetadata.batchWatermarkMs > eventTimeWatermarkForEviction.get + newInputWatermark > eventTimeWatermarkForEviction.get } override protected def withNewChildInternal(newChild: SparkPlan): StateStoreSaveExec = @@ -744,10 +818,10 @@ case class SessionWindowStateStoreSaveExec( keyWithoutSessionExpressions, getStateInfo, conf) :: Nil } - override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = { (outputMode.contains(Append) || outputMode.contains(Update)) && eventTimeWatermarkForEviction.isDefined && - newMetadata.batchWatermarkMs > eventTimeWatermarkForEviction.get + newInputWatermark > eventTimeWatermarkForEviction.get } private def putToStore(iter: Iterator[InternalRow], store: StateStore): Unit = { @@ -893,9 +967,9 @@ case class StreamingDeduplicateExec( override def shortName: String = "dedupe" - override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = { eventTimeWatermarkForEviction.isDefined && - newMetadata.batchWatermarkMs > eventTimeWatermarkForEviction.get + newInputWatermark > eventTimeWatermarkForEviction.get } override protected def withNewChildInternal(newChild: SparkPlan): StreamingDeduplicateExec = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index 058c335ad43e0..ca8ad7f88bc34 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -28,14 +28,15 @@ import org.scalatest.BeforeAndAfter import org.scalatest.matchers.must.Matchers import org.scalatest.matchers.should.Matchers._ +import
org.apache.spark.SparkException import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, Dataset} +import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset} import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.UTC import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.MemorySink -import org.apache.spark.sql.functions.{count, timestamp_seconds, window} +import org.apache.spark.sql.functions.{count, expr, timestamp_seconds, window} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode._ import org.apache.spark.util.Utils @@ -548,17 +549,131 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche assert(e.getMessage contains "should not be negative.") } - test("the new watermark should override the old one") { - val df = MemoryStream[(Long, Long)].toDF() + private def buildTestQueryForOverridingWatermark(): (MemoryStream[(Long, Long)], DataFrame) = { + val input = MemoryStream[(Long, Long)] + val df = input.toDF() .withColumn("first", timestamp_seconds($"_1")) .withColumn("second", timestamp_seconds($"_2")) .withWatermark("first", "1 minute") + .select("*") .withWatermark("second", "2 minutes") + .groupBy(window($"second", "1 minute")) + .count() - val eventTimeColumns = df.logicalPlan.output - .filter(_.metadata.contains(EventTimeWatermark.delayKey)) - assert(eventTimeColumns.size === 1) - assert(eventTimeColumns(0).name === "second") + (input, df) + } + + test("overriding watermark should not be allowed by default") { + val (input, df) = buildTestQueryForOverridingWatermark() + testStream(df)( + AddData(input, (100L, 200L)), + ExpectFailure[AnalysisException](assertFailure = exc => { + assert(exc.getMessage.contains("Redefining watermark is disallowed.")) + assert(exc.getMessage.contains(SQLConf.STATEFUL_OPERATOR_ALLOW_MULTIPLE.key)) + }) + ) + } + + test("overriding watermark should not fail in compatibility mode") { + withSQLConf(SQLConf.STATEFUL_OPERATOR_ALLOW_MULTIPLE.key -> "false") { + val (input, df) = buildTestQueryForOverridingWatermark() + testStream(df)( + AddData(input, (100L, 200L)), + CheckAnswer(), + Execute { query => + val lastExecution = query.lastExecution + val aggSaveOperator = lastExecution.executedPlan.collect { + case j: StateStoreSaveExec => j + }.head + + // - watermark from first definition = 100 - 60 = 40 + // - watermark from second definition = 200 - 120 = 80 + // - global watermark = min(40, 80) = 40 + // + // As we see from the result, even though we override the watermark definition, the old + // definition of watermark still plays a part in calculating the global watermark. + // + // This is conceptually the right behavior. For operators after the first watermark + // definition, the column named "first" is considered as the event time column, and for + // operators after the second watermark definition, the column named "second" is + // considered as the event time column. The correct "single" value of watermark satisfying + // all operators should be the lower bound of both columns "first" and "second". + // + // That said, this easily leads to an incorrect definition - e.g. re-defining the watermark + // against the output of a streaming aggregation in append mode. The global watermark + // cannot advance. This is the reason we don't allow re-defining the watermark in the new + // behavior.
+ val expectedWatermarkMs = 40 * 1000 + + assert(aggSaveOperator.eventTimeWatermarkForLateEvents === Some(expectedWatermarkMs)) + assert(aggSaveOperator.eventTimeWatermarkForEviction === Some(expectedWatermarkMs)) + + val eventTimeCols = aggSaveOperator.keyExpressions.filter( + _.metadata.contains(EventTimeWatermark.delayKey)) + assert(eventTimeCols.size === 1) + assert(eventTimeCols.head.name === "window") + // 2 minutes delay threshold + assert(eventTimeCols.head.metadata.getLong(EventTimeWatermark.delayKey) === 120 * 1000) + } + ) + } + } + + private def buildTestQueryForMultiEventTimeColumns() + : (MemoryStream[(String, Long)], MemoryStream[(String, Long)], DataFrame) = { + val input1 = MemoryStream[(String, Long)] + val input2 = MemoryStream[(String, Long)] + val df1 = input1.toDF() + .selectExpr("_1 AS id1", "timestamp_seconds(_2) AS ts1") + .withWatermark("ts1", "1 minute") + + val df2 = input2.toDF() + .selectExpr("_1 AS id2", "timestamp_seconds(_2) AS ts2") + .withWatermark("ts2", "2 minutes") + + val joined = df1.join(df2, expr("id1 = id2 AND ts1 = ts2 + INTERVAL 10 SECONDS"), "inner") + .selectExpr("id1", "ts1", "ts2") + // the output of join contains both ts1 and ts2 + val dedup = joined.dropDuplicates() + .selectExpr("id1", "CAST(ts1 AS LONG) AS ts1", "CAST(ts2 AS LONG) AS ts2") + + (input1, input2, dedup) + } + + test("multiple event time columns in an input DataFrame for stateful operator is " + + "not allowed") { + // for ease of verification, we change the session timezone to UTC + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") { + val (input1, input2, dedup) = buildTestQueryForMultiEventTimeColumns() + testStream(dedup)( + MultiAddData( + (input1, Seq(("A", 200L), ("B", 300L))), + (input2, Seq(("A", 190L), ("C", 350L))) + ), + ExpectFailure[SparkException](assertFailure = exc => { + val cause = exc.getCause + assert(cause.getMessage.contains("More than one event time columns are available.")) + assert(cause.getMessage.contains( + "Please ensure there is at most one event time column per stream.")) + }) + ) + } + } + + test("stateful operator should pick the first occurrence of event time column if there is " + + "multiple event time columns in compatibility mode") { + // for ease of verification, we change the session timezone to UTC + withSQLConf( + SQLConf.STATEFUL_OPERATOR_ALLOW_MULTIPLE.key -> "false", + SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") { + val (input1, input2, dedup) = buildTestQueryForMultiEventTimeColumns() + testStream(dedup)( + MultiAddData( + (input1, Seq(("A", 200L), ("B", 300L))), + (input2, Seq(("A", 190L), ("C", 350L))) + ), + CheckAnswer(("A", 200L, 190L)) + ) + } } test("EventTime watermark should be ignored in batch query.") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala index eb1e0de79cae7..9f00aa2e6ee4e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.streaming +import java.sql.Timestamp + import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.{AnalysisException, SparkSession} -import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.execution.streaming.{MemoryStream, StateStoreSaveExec, 
StreamingSymmetricHashJoinExec} import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf // Tests for the multiple stateful operators support. class MultiStatefulOperatorsSuite @@ -403,35 +406,6 @@ class MultiStatefulOperatorsSuite ) } - test("join on time interval -> window agg, append mode, should fail") { - val input1 = MemoryStream[Int] - val inputDF1 = input1.toDF() - .withColumnRenamed("value", "value1") - .withColumn("eventTime1", timestamp_seconds($"value1")) - .withWatermark("eventTime1", "0 seconds") - - val input2 = MemoryStream[(Int, Int)] - val inputDF2 = input2.toDS().toDF("start", "end") - .withColumn("eventTime2Start", timestamp_seconds($"start")) - .withColumn("eventTime2End", timestamp_seconds($"end")) - .withColumn("start2", timestamp_seconds($"start")) - .withWatermark("eventTime2Start", "0 seconds") - - val stream = inputDF1.join(inputDF2, - expr("eventTime1 >= eventTime2Start AND eventTime1 < eventTime2End " + - "AND eventTime1 = start2"), "inner") - .groupBy(window($"eventTime1", "5 seconds") as 'window) - .agg(count("*") as 'count) - .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) - - val e = intercept[AnalysisException] { - testStream(stream)( - StartStream() - ) - } - assert(e.getMessage.contains("Detected pattern of possible 'correctness' issue")) - } - test("join with range join on non-time intervals -> window agg, append mode, shouldn't fail") { val input1 = MemoryStream[Int] val inputDF1 = input1.toDF() @@ -463,6 +437,445 @@ class MultiStatefulOperatorsSuite ) } + test("stream-stream time interval left outer join -> aggregation, append mode") { + // This test performs stream-stream time interval left outer join against two streams, and + // applies tumble time window aggregation based on the event time column from the output of + // stream-stream join. 
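+ // + // Notation for the per-batch walkthrough comments below: + // IW (a, b) = input watermark of the operator, as (value used for late events, value used for eviction) + // OW = output watermark the operator provides to downstream operators + // global watermark (a, b) = (watermark used for late events, watermark used for eviction) in the batch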
+ val input1 = MemoryStream[(String, Timestamp)] + val input2 = MemoryStream[(String, Timestamp)] + + val s1 = input1.toDF() + .toDF("id1", "timestamp1") + .withWatermark("timestamp1", "0 seconds") + .as("s1") + + val s2 = input2.toDF() + .toDF("id2", "timestamp2") + .withWatermark("timestamp2", "0 seconds") + .as("s2") + + val s3 = s1.join(s2, expr("s1.id1 = s2.id2 AND (s1.timestamp1 BETWEEN " + + "s2.timestamp2 - INTERVAL 1 hour AND s2.timestamp2 + INTERVAL 1 hour)"), "leftOuter") + + val agg = s3.groupBy(window($"timestamp1", "10 minutes")) + .agg(count("*").as("cnt")) + .selectExpr("CAST(window.start AS STRING) AS window_start", + "CAST(window.end AS STRING) AS window_end", "cnt") + + // for ease of verification, we change the session timezone to UTC + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") { + testStream(agg)( + MultiAddData( + (input1, Seq( + ("1", Timestamp.valueOf("2023-01-01 01:00:10")), + ("2", Timestamp.valueOf("2023-01-01 01:00:30"))) + ), + (input2, Seq( + ("1", Timestamp.valueOf("2023-01-01 01:00:20")))) + ), + + // < data batch > + // global watermark (0, 0) + // op1 (join) + // -- IW (0, 0) + // -- OW 0 + // -- left state + // ("1", "2023-01-01 01:00:10", matched=true) + // ("1", "2023-01-01 01:00:30", matched=false) + // -- right state + // ("1", "2023-01-01 01:00:20") + // -- result + // ("1", "2023-01-01 01:00:10", "1", "2023-01-01 01:00:20") + // op2 (aggregation) + // -- IW (0, 0) + // -- OW 0 + // -- state row + // ("2023-01-01 01:00:00", "2023-01-01 01:10:00", 1) + // -- result + // None + + // -- watermark calculation + // watermark in left input: 2023-01-01 01:00:30 + // watermark in right input: 2023-01-01 01:00:20 + // origin watermark: 2023-01-01 01:00:20 + + // < no-data batch > + // global watermark (0, 2023-01-01 01:00:20) + // op1 (join) + // -- IW (0, 2023-01-01 01:00:20) + // -- OW 2023-01-01 00:00:19.999999 + // -- left state + // ("1", "2023-01-01 01:00:10", matched=true) + // ("1", "2023-01-01 01:00:30", matched=false) + // -- right state + // ("1", "2023-01-01 01:00:20") + // -- result + // None + // op2 (aggregation) + // -- IW (0, 2023-01-01 00:00:19.999999) + // -- OW 2023-01-01 00:00:19.999999 + // -- state row + // ("2023-01-01 01:00:00", "2023-01-01 01:10:00", 1) + // -- result + // None + CheckAnswer(), + + Execute { query => + val lastExecution = query.lastExecution + val joinOperator = lastExecution.executedPlan.collect { + case j: StreamingSymmetricHashJoinExec => j + }.head + val aggSaveOperator = lastExecution.executedPlan.collect { + case j: StateStoreSaveExec => j + }.head + + assert(joinOperator.eventTimeWatermarkForLateEvents === Some(0)) + assert(joinOperator.eventTimeWatermarkForEviction === + Some(Timestamp.valueOf("2023-01-01 01:00:20").getTime)) + + assert(aggSaveOperator.eventTimeWatermarkForLateEvents === Some(0)) + assert(aggSaveOperator.eventTimeWatermarkForEviction === + Some(Timestamp.valueOf("2023-01-01 00:00:20").getTime - 1)) + }, + + MultiAddData( + (input1, Seq(("5", Timestamp.valueOf("2023-01-01 01:15:00")))), + (input2, Seq(("6", Timestamp.valueOf("2023-01-01 01:15:00")))) + ), + + // < data batch > + // global watermark (2023-01-01 01:00:20, 2023-01-01 01:00:20) + // op1 (join) + // -- IW (2023-01-01 01:00:20, 2023-01-01 01:00:20) + // -- OW 2023-01-01 00:00:19.999999 + // -- left state + // ("1", "2023-01-01 01:00:10", matched=true) + // ("1", "2023-01-01 01:00:30", matched=false) + // ("5", "2023-01-01 01:15:00", matched=false) + // -- right state + // ("1", "2023-01-01 01:00:20") + // ("6", 
"2023-01-01 01:15:00") + // -- result + // None + // op2 (aggregation) + // -- IW (2023-01-01 00:00:19.999999, 2023-01-01 00:00:19.999999) + // -- OW 2023-01-01 00:00:19.999999 + // -- state row + // ("2023-01-01 01:00:00", "2023-01-01 01:10:00", 1) + // -- result + // None + + // -- watermark calculation + // watermark in left input: 2023-01-01 01:15:00 + // watermark in right input: 2023-01-01 01:15:00 + // origin watermark: 2023-01-01 01:15:00 + + // < no-data batch > + // global watermark (2023-01-01 01:00:20, 2023-01-01 01:15:00) + // op1 (join) + // -- IW (2023-01-01 01:00:20, 2023-01-01 01:15:00) + // -- OW 2023-01-01 00:14:59.999999 + // -- left state + // ("1", "2023-01-01 01:00:10", matched=true) + // ("1", "2023-01-01 01:00:30", matched=false) + // ("5", "2023-01-01 01:15:00", matched=false) + // -- right state + // ("1", "2023-01-01 01:00:20") + // ("6", "2023-01-01 01:15:00") + // -- result + // None + // op2 (aggregation) + // -- IW (2023-01-01 00:00:19.999999, 2023-01-01 00:14:59.999999) + // -- OW 2023-01-01 00:14:59.999999 + // -- state row + // ("2023-01-01 01:00:00", "2023-01-01 01:10:00", 1) + // -- result + // None + CheckAnswer(), + + Execute { query => + val lastExecution = query.lastExecution + val joinOperator = lastExecution.executedPlan.collect { + case j: StreamingSymmetricHashJoinExec => j + }.head + val aggSaveOperator = lastExecution.executedPlan.collect { + case j: StateStoreSaveExec => j + }.head + + assert(joinOperator.eventTimeWatermarkForLateEvents === + Some(Timestamp.valueOf("2023-01-01 01:00:20").getTime)) + assert(joinOperator.eventTimeWatermarkForEviction === + Some(Timestamp.valueOf("2023-01-01 01:15:00").getTime)) + + assert(aggSaveOperator.eventTimeWatermarkForLateEvents === + Some(Timestamp.valueOf("2023-01-01 00:00:20").getTime - 1)) + assert(aggSaveOperator.eventTimeWatermarkForEviction === + Some(Timestamp.valueOf("2023-01-01 00:15:00").getTime - 1)) + }, + + MultiAddData( + (input1, Seq( + ("5", Timestamp.valueOf("2023-01-01 02:16:00")))), + (input2, Seq( + ("6", Timestamp.valueOf("2023-01-01 02:16:00")))) + ), + + // < data batch > + // global watermark (2023-01-01 01:15:00, 2023-01-01 01:15:00) + // op1 (join) + // -- IW (2023-01-01 01:15:00, 2023-01-01 01:15:00) + // -- OW 2023-01-01 00:14:59.999999 + // -- left state + // ("1", "2023-01-01 01:00:10", matched=true) + // ("1", "2023-01-01 01:00:30", matched=false) + // ("5", "2023-01-01 01:15:00", matched=false) + // ("5", "2023-01-01 02:16:00", matched=false) + // -- right state + // ("1", "2023-01-01 01:00:20") + // ("6", "2023-01-01 01:15:00") + // ("6", "2023-01-01 02:16:00") + // -- result + // None + // op2 (aggregation) + // -- IW (2023-01-01 00:14:59.999999, 2023-01-01 00:14:59.999999) + // -- OW 2023-01-01 00:14:59.999999 + // -- state row + // ("2023-01-01 01:00:00", "2023-01-01 01:10:00", 1) + // -- result + // None + + // -- watermark calculation + // watermark in left input: 2023-01-01 02:16:00 + // watermark in right input: 2023-01-01 02:16:00 + // origin watermark: 2023-01-01 02:16:00 + + // < no-data batch > + // global watermark (2023-01-01 01:15:00, 2023-01-01 02:16:00) + // op1 (join) + // -- IW (2023-01-01 01:15:00, 2023-01-01 02:16:00) + // -- OW 2023-01-01 01:15:59.999999 + // -- left state + // ("5", "2023-01-01 02:16:00", matched=false) + // -- right state + // ("6", "2023-01-01 02:16:00") + // -- result + // ("1", "2023-01-01 01:00:30", null, null) + // ("5", "2023-01-01 01:15:00", null, null) + // op2 (aggregation) + // -- IW (2023-01-01 00:14:59.999999, 
2023-01-01 01:15:59.999999) + // -- OW 2023-01-01 01:15:59.999999 + // -- state row + // ("2023-01-01 01:10:00", "2023-01-01 01:20:00", 1) + // -- result + // ("2023-01-01 01:00:00", "2023-01-01 01:10:00", 2) + CheckAnswer( + ("2023-01-01 01:00:00", "2023-01-01 01:10:00", 2) + ), + + Execute { query => + val lastExecution = query.lastExecution + val joinOperator = lastExecution.executedPlan.collect { + case j: StreamingSymmetricHashJoinExec => j + }.head + val aggSaveOperator = lastExecution.executedPlan.collect { + case j: StateStoreSaveExec => j + }.head + + assert(joinOperator.eventTimeWatermarkForLateEvents === + Some(Timestamp.valueOf("2023-01-01 01:15:00").getTime)) + assert(joinOperator.eventTimeWatermarkForEviction === + Some(Timestamp.valueOf("2023-01-01 02:16:00").getTime)) + + assert(aggSaveOperator.eventTimeWatermarkForLateEvents === + Some(Timestamp.valueOf("2023-01-01 00:15:00").getTime - 1)) + assert(aggSaveOperator.eventTimeWatermarkForEviction === + Some(Timestamp.valueOf("2023-01-01 01:16:00").getTime - 1)) + } + ) + } + } + + // This test case simply swaps the left and right from the test case "stream-stream time interval + // left outer join -> aggregation, append mode". This test case intends to verify the behavior + // that both event time columns from both inputs are available to use after stream-stream join. + // For explanation of the behavior, please refer to the test case "stream-stream time interval + // left outer join -> aggregation, append mode". + test("stream-stream time interval right outer join -> aggregation, append mode") { + val input1 = MemoryStream[(String, Timestamp)] + val input2 = MemoryStream[(String, Timestamp)] + + val s1 = input1.toDF() + .toDF("id1", "timestamp1") + .withWatermark("timestamp1", "0 seconds") + .as("s1") + + val s2 = input2.toDF() + .toDF("id2", "timestamp2") + .withWatermark("timestamp2", "0 seconds") + .as("s2") + + val s3 = s1.join(s2, expr("s1.id1 = s2.id2 AND (s1.timestamp1 BETWEEN " + + "s2.timestamp2 - INTERVAL 1 hour AND s2.timestamp2 + INTERVAL 1 hour)"), "rightOuter") + + val agg = s3.groupBy(window($"timestamp2", "10 minutes")) + .agg(count("*").as("cnt")) + .selectExpr("CAST(window.start AS STRING) AS window_start", + "CAST(window.end AS STRING) AS window_end", "cnt") + + // for ease of verification, we change the session timezone to UTC + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") { + testStream(agg)( + MultiAddData( + (input2, Seq( + ("1", Timestamp.valueOf("2023-01-01 01:00:10")), + ("2", Timestamp.valueOf("2023-01-01 01:00:30"))) + ), + (input1, Seq( + ("1", Timestamp.valueOf("2023-01-01 01:00:20")))) + ), + CheckAnswer(), + + Execute { query => + val lastExecution = query.lastExecution + val joinOperator = lastExecution.executedPlan.collect { + case j: StreamingSymmetricHashJoinExec => j + }.head + val aggSaveOperator = lastExecution.executedPlan.collect { + case j: StateStoreSaveExec => j + }.head + + assert(joinOperator.eventTimeWatermarkForLateEvents === Some(0)) + assert(joinOperator.eventTimeWatermarkForEviction === + Some(Timestamp.valueOf("2023-01-01 01:00:20").getTime)) + + assert(aggSaveOperator.eventTimeWatermarkForLateEvents === Some(0)) + assert(aggSaveOperator.eventTimeWatermarkForEviction === + Some(Timestamp.valueOf("2023-01-01 00:00:20").getTime - 1)) + }, + + MultiAddData( + (input2, Seq(("5", Timestamp.valueOf("2023-01-01 01:15:00")))), + (input1, Seq(("6", Timestamp.valueOf("2023-01-01 01:15:00")))) + ), + CheckAnswer(), + + Execute { query => + val lastExecution = 
query.lastExecution + val joinOperator = lastExecution.executedPlan.collect { + case j: StreamingSymmetricHashJoinExec => j + }.head + val aggSaveOperator = lastExecution.executedPlan.collect { + case j: StateStoreSaveExec => j + }.head + + assert(joinOperator.eventTimeWatermarkForLateEvents === + Some(Timestamp.valueOf("2023-01-01 01:00:20").getTime)) + assert(joinOperator.eventTimeWatermarkForEviction === + Some(Timestamp.valueOf("2023-01-01 01:15:00").getTime)) + + assert(aggSaveOperator.eventTimeWatermarkForLateEvents === + Some(Timestamp.valueOf("2023-01-01 00:00:20").getTime - 1)) + assert(aggSaveOperator.eventTimeWatermarkForEviction === + Some(Timestamp.valueOf("2023-01-01 00:15:00").getTime - 1)) + }, + + MultiAddData( + (input2, Seq( + ("5", Timestamp.valueOf("2023-01-01 02:16:00")))), + (input1, Seq( + ("6", Timestamp.valueOf("2023-01-01 02:16:00")))) + ), + CheckAnswer( + ("2023-01-01 01:00:00", "2023-01-01 01:10:00", 2) + ), + + Execute { query => + val lastExecution = query.lastExecution + val joinOperator = lastExecution.executedPlan.collect { + case j: StreamingSymmetricHashJoinExec => j + }.head + val aggSaveOperator = lastExecution.executedPlan.collect { + case j: StateStoreSaveExec => j + }.head + + assert(joinOperator.eventTimeWatermarkForLateEvents === + Some(Timestamp.valueOf("2023-01-01 01:15:00").getTime)) + assert(joinOperator.eventTimeWatermarkForEviction === + Some(Timestamp.valueOf("2023-01-01 02:16:00").getTime)) + + assert(aggSaveOperator.eventTimeWatermarkForLateEvents === + Some(Timestamp.valueOf("2023-01-01 00:15:00").getTime - 1)) + assert(aggSaveOperator.eventTimeWatermarkForEviction === + Some(Timestamp.valueOf("2023-01-01 01:16:00").getTime - 1)) + } + ) + } + } + + test("stream-stream time interval join - output watermark for various intervals") { + def testOutputWatermarkInJoin( + df: DataFrame, + input: MemoryStream[(String, Timestamp)], + expectedOutputWatermark: Long): Unit = { + testStream(df)( + // dummy row to trigger execution + AddData(input, ("1", Timestamp.valueOf("2023-01-01 01:00:10"))), + CheckAnswer(), + Execute { query => + val lastExecution = query.lastExecution + val joinOperator = lastExecution.executedPlan.collect { + case j: StreamingSymmetricHashJoinExec => j + }.head + + val outputWatermark = joinOperator.produceOutputWatermark(0) + assert(outputWatermark.get === expectedOutputWatermark) + } + ) + } + + val input1 = MemoryStream[(String, Timestamp)] + val df1 = input1.toDF + .selectExpr("_1 as leftId", "_2 as leftEventTime") + .withWatermark("leftEventTime", "5 minutes") + + val input2 = MemoryStream[(String, Timestamp)] + val df2 = input2.toDF + .selectExpr("_1 as rightId", "_2 as rightEventTime") + .withWatermark("rightEventTime", "10 minutes") + + val join1 = df1.join(df2, + expr( + """ + |leftId = rightId AND leftEventTime BETWEEN + | rightEventTime AND rightEventTime + INTERVAL 40 seconds + |""".stripMargin)) + + // right row should wait for additional 40 seconds (+ 1 ms) to be matched with left rows + testOutputWatermarkInJoin(join1, input1, -40L * 1000 - 1) + + val join2 = df1.join(df2, + expr( + """ + |leftId = rightId AND leftEventTime BETWEEN + | rightEventTime - INTERVAL 30 seconds AND rightEventTime + |""".stripMargin)) + + // left row should wait for additional 30 seconds (+ 1 ms) to be matched with right rows + testOutputWatermarkInJoin(join2, input1, -30L * 1000 - 1) + + val join3 = df1.join(df2, + expr( + """ + |leftId = rightId AND leftEventTime BETWEEN + | rightEventTime - INTERVAL 30 seconds AND rightEventTime + INTERVAL 40 seconds + |""".stripMargin)) + + // left row should wait for additional 30 seconds (+ 1 ms) to be matched with right rows + // right row should wait for additional 40 seconds (+ 1 ms) to be matched with left rows + // taking minimum of both criteria - 40 seconds (+ 1 ms) + testOutputWatermarkInJoin(join3, input1, -40L * 1000 - 1) + } + private def assertNumStateRows(numTotalRows: Seq[Long]): AssertOnQuery = AssertOnQuery { q => q.processAllAvailable() val progressWithData = q.recentProgress.lastOption.get diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala index 0315e03d18784..8607de389425c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala @@ -150,7 +150,6 @@ class StreamingDeduplicationSuite extends StateStoreMetricsTest { .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "10 seconds") .dropDuplicates() - .withWatermark("eventTime", "10 seconds") .groupBy(window($"eventTime", "5 seconds") as Symbol("window")) .agg(count("*") as Symbol("count")) .select($"window".getField("start").cast("long").as[Long], $"count".as[Long])
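For reference, the expected values passed to `testOutputWatermarkInJoin` above follow from the interval bounds of each join condition. A minimal sketch of that arithmetic (the helper below is illustrative only, not a Spark API; its name and parameters are assumptions for this example):

    // For a condition `leftTime BETWEEN rightTime - lower AND rightTime + upper`, a left row may
    // have to wait up to `lower` for matching right rows, and a right row up to `upper` for
    // matching left rows. The join's output watermark is therefore the input watermark minus the
    // larger of the two delays, minus 1 ms.
    def expectedJoinOutputWatermarkMs(
        inputWatermarkMs: Long,
        lowerBoundMs: Long,
        upperBoundMs: Long): Long = {
      inputWatermarkMs - math.max(lowerBoundMs, upperBoundMs) - 1
    }

    expectedJoinOutputWatermarkMs(0L, 0L, 40000L)     // join1: -40001
    expectedJoinOutputWatermarkMs(0L, 30000L, 0L)     // join2: -30001
    expectedJoinOutputWatermarkMs(0L, 30000L, 40000L) // join3: -40001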