From 50a95a456c8c73449d6f6e8b5da4826e56ed7a94 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Fri, 16 Sep 2022 18:42:03 +0900 Subject: [PATCH] [SPARK-40467][SS] Split FlatMapGroupsWithState down to multiple test suites ### What changes were proposed in this pull request? This PR proposes to split the FlatMapGroupsWithStateSuite into three pieces: 1. GroupStateSuite <- test the functionality with (Test)GroupState implementation 2. FlatMapGroupsWithStateWithInitialStateSuite <- test E2E cases which are specific to initial state 3. FlatMapGroupsWithStateSuite <- all other cases (E2E cases which don't leverage initial state) The change is pure extraction - it's cut and paste and no additional code change has been introduced. ### Why are the changes needed? This would help to maintain the test suite FlatMapGroupsWithStateSuite as it's quite huge. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Refactored test suites. Closes #37907 from HeartSaVioR/SPARK-40467. 
Authored-by: Jungtaek Lim Signed-off-by: Jungtaek Lim --- .../FlatMapGroupsWithStateSuite.scala | 758 +----------------- ...GroupsWithStateWithInitialStateSuite.scala | 365 +++++++++ .../spark/sql/streaming/GroupStateSuite.scala | 458 +++++++++++ 3 files changed, 825 insertions(+), 756 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/GroupStateSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index b7c9aa4178090..14f083bbd307a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -18,15 +18,14 @@ package org.apache.spark.sql.streaming import java.io.File -import java.sql.{Date, Timestamp} +import java.sql.Timestamp import org.apache.commons.io.FileUtils import org.scalatest.exceptions.TestFailedException import org.apache.spark.SparkException -import org.apache.spark.api.java.Optional import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction -import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, Encoder, KeyValueGroupedDataset} +import org.apache.spark.sql.{DataFrame, Encoder} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsWithState @@ -78,416 +77,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { } } - test("SPARK-35800: ensure TestGroupState creates instances the same as prod") { - val testState = TestGroupState.create[Int]( - Optional.of(5), EventTimeTimeout, 1L, Optional.of(1L), hasTimedOut = 
false) - - val prodState = GroupStateImpl.createForStreaming[Int]( - Some(5), 1L, 1L, EventTimeTimeout, false, true) - - assert(testState.isInstanceOf[GroupStateImpl[Int]]) - - assert(testState.isRemoved === prodState.isRemoved) - assert(testState.isUpdated === prodState.isUpdated) - assert(testState.exists === prodState.exists) - assert(testState.get === prodState.get) - assert(testState.getTimeoutTimestampMs === prodState.getTimeoutTimestampMs) - assert(testState.hasTimedOut === prodState.hasTimedOut) - assert(testState.getCurrentProcessingTimeMs === prodState.getCurrentProcessingTimeMs) - assert(testState.getCurrentWatermarkMs === prodState.getCurrentWatermarkMs) - - testState.update(6) - prodState.update(6) - assert(testState.isUpdated === prodState.isUpdated) - assert(testState.exists === prodState.exists) - assert(testState.get === prodState.get) - - testState.remove() - prodState.remove() - assert(testState.exists === prodState.exists) - assert(testState.isRemoved === prodState.isRemoved) - } - - test("GroupState - get, exists, update, remove") { - var state: TestGroupState[String] = null - - def testState( - expectedData: Option[String], - shouldBeUpdated: Boolean = false, - shouldBeRemoved: Boolean = false): Unit = { - if (expectedData.isDefined) { - assert(state.exists) - assert(state.get === expectedData.get) - } else { - assert(!state.exists) - intercept[NoSuchElementException] { - state.get - } - } - assert(state.getOption === expectedData) - assert(state.isUpdated === shouldBeUpdated) - assert(state.isRemoved === shouldBeRemoved) - } - - // === Tests for state in streaming queries === - // Updating empty state - state = TestGroupState.create[String]( - Optional.empty[String], NoTimeout, 1, Optional.empty[Long], hasTimedOut = false) - testState(None) - state.update("") - testState(Some(""), shouldBeUpdated = true) - - // Updating exiting state - state = TestGroupState.create[String]( - Optional.of("2"), NoTimeout, 1, Optional.empty[Long], hasTimedOut = 
false) - testState(Some("2")) - state.update("3") - testState(Some("3"), shouldBeUpdated = true) - - // Removing state - state.remove() - testState(None, shouldBeRemoved = true, shouldBeUpdated = false) - state.remove() // should be still callable - state.update("4") - testState(Some("4"), shouldBeRemoved = false, shouldBeUpdated = true) - - // Updating by null throw exception - intercept[IllegalArgumentException] { - state.update(null) - } - } - - test("GroupState - setTimeout - with NoTimeout") { - for (initValue <- Seq(Optional.empty[Int], Optional.of((5)))) { - val states = Seq( - TestGroupState.create[Int]( - initValue, NoTimeout, 1000, Optional.empty[Long], hasTimedOut = false), - GroupStateImpl.createForBatch(NoTimeout, watermarkPresent = false) - ) - for (state <- states) { - // for streaming queries - testTimeoutDurationNotAllowed[UnsupportedOperationException](state) - testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) - - // for batch queries - testTimeoutDurationNotAllowed[UnsupportedOperationException](state) - testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) - } - } - } - - test("GroupState - setTimeout - with ProcessingTimeTimeout") { - // for streaming queries - var state = TestGroupState.create[Int]( - Optional.empty[Int], ProcessingTimeTimeout, 1000, Optional.empty[Long], hasTimedOut = false) - assert(!state.getTimeoutTimestampMs.isPresent()) - state.setTimeoutDuration("-1 month 31 days 1 second") - assert(state.getTimeoutTimestampMs.isPresent()) - assert(state.getTimeoutTimestampMs.get() === 2000) - state.setTimeoutDuration(500) - assert(state.getTimeoutTimestampMs.get() === 1500) // can be set without initializing state - testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) - - state.update(5) - assert(state.getTimeoutTimestampMs.isPresent()) - assert(state.getTimeoutTimestampMs.get() === 1500) // does not change - state.setTimeoutDuration(1000) - assert(state.getTimeoutTimestampMs.get() === 
2000) - state.setTimeoutDuration("2 second") - assert(state.getTimeoutTimestampMs.get() === 3000) - testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) - - state.remove() - assert(state.getTimeoutTimestampMs.isPresent()) - assert(state.getTimeoutTimestampMs.get() === 3000) // does not change - state.setTimeoutDuration(500) // can still be set - assert(state.getTimeoutTimestampMs.get() === 1500) - testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) - - // for batch queries - state = GroupStateImpl.createForBatch( - ProcessingTimeTimeout, watermarkPresent = false).asInstanceOf[GroupStateImpl[Int]] - assert(!state.getTimeoutTimestampMs.isPresent()) - state.setTimeoutDuration(500) - testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) - - state.update(5) - state.setTimeoutDuration(1000) - state.setTimeoutDuration("2 second") - testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) - - state.remove() - state.setTimeoutDuration(500) - testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) - } - - test("GroupState - setTimeout - with EventTimeTimeout") { - var state = TestGroupState.create[Int]( - Optional.empty[Int], EventTimeTimeout, 1000, Optional.of(1000), hasTimedOut = false) - assert(!state.getTimeoutTimestampMs.isPresent()) - testTimeoutDurationNotAllowed[UnsupportedOperationException](state) - state.setTimeoutTimestamp(5000) - assert(state.getTimeoutTimestampMs.get() === 5000) // can be set without initializing state - - state.update(5) - assert(state.getTimeoutTimestampMs.get() === 5000) // does not change - state.setTimeoutTimestamp(10000) - assert(state.getTimeoutTimestampMs.get() === 10000) - state.setTimeoutTimestamp(new Date(20000)) - assert(state.getTimeoutTimestampMs.get() === 20000) - testTimeoutDurationNotAllowed[UnsupportedOperationException](state) - - state.remove() - assert(state.getTimeoutTimestampMs.get() === 20000) - state.setTimeoutTimestamp(5000) - 
assert(state.getTimeoutTimestampMs.get() === 5000) // can be set after removing state - testTimeoutDurationNotAllowed[UnsupportedOperationException](state) - - // for batch queries - state = GroupStateImpl.createForBatch( - EventTimeTimeout, watermarkPresent = false).asInstanceOf[GroupStateImpl[Int]] - assert(!state.getTimeoutTimestampMs.isPresent()) - testTimeoutDurationNotAllowed[UnsupportedOperationException](state) - state.setTimeoutTimestamp(5000) - - state.update(5) - state.setTimeoutTimestamp(10000) - state.setTimeoutTimestamp(new Date(20000)) - testTimeoutDurationNotAllowed[UnsupportedOperationException](state) - - state.remove() - state.setTimeoutTimestamp(5000) - testTimeoutDurationNotAllowed[UnsupportedOperationException](state) - } - - test("GroupState - illegal params to setTimeout") { - var state: TestGroupState[Int] = null - - // Test setTimeout() with illegal values - def testIllegalTimeout(body: => Unit): Unit = { - intercept[IllegalArgumentException] { - body - } - assert(!state.getTimeoutTimestampMs.isPresent()) - } - - // Test setTimeout() with illegal values - state = TestGroupState.create[Int]( - Optional.of(5), ProcessingTimeTimeout, 1000, Optional.empty[Long], hasTimedOut = false) - - testIllegalTimeout { - state.setTimeoutDuration(-1000) - } - testIllegalTimeout { - state.setTimeoutDuration(0) - } - testIllegalTimeout { - state.setTimeoutDuration("-2 second") - } - testIllegalTimeout { - state.setTimeoutDuration("-1 month") - } - - testIllegalTimeout { - state.setTimeoutDuration("1 month -31 day") - } - - state = TestGroupState.create[Int]( - Optional.of(5), EventTimeTimeout, 1000, Optional.of(1000), hasTimedOut = false) - testIllegalTimeout { - state.setTimeoutTimestamp(-10000) - } - testIllegalTimeout { - state.setTimeoutTimestamp(10000, "-3 second") - } - testIllegalTimeout { - state.setTimeoutTimestamp(10000, "-1 month") - } - testIllegalTimeout { - state.setTimeoutTimestamp(10000, "1 month -32 day") - } - testIllegalTimeout { - 
state.setTimeoutTimestamp(new Date(-10000)) - } - testIllegalTimeout { - state.setTimeoutTimestamp(new Date(-10000), "-3 second") - } - testIllegalTimeout { - state.setTimeoutTimestamp(new Date(-10000), "-1 month") - } - testIllegalTimeout { - state.setTimeoutTimestamp(new Date(-10000), "1 month -32 day") - } - } - - test("SPARK-35800: illegal params to create") { - // eventTimeWatermarkMs >= 0 if present - var illegalArgument = intercept[IllegalArgumentException] { - TestGroupState.create[Int]( - Optional.of(5), EventTimeTimeout, 100L, Optional.of(-1000), hasTimedOut = false) - } - assert( - illegalArgument.getMessage.contains("eventTimeWatermarkMs must be 0 or positive if present")) - illegalArgument = intercept[IllegalArgumentException] { - GroupStateImpl.createForStreaming[Int]( - Some(5), 100L, -1000L, EventTimeTimeout, false, true) - } - assert( - illegalArgument.getMessage.contains("eventTimeWatermarkMs must be 0 or positive if present")) - - // batchProcessingTimeMs must be positive - illegalArgument = intercept[IllegalArgumentException] { - TestGroupState.create[Int]( - Optional.of(5), EventTimeTimeout, -100L, Optional.of(1000), hasTimedOut = false) - } - assert(illegalArgument.getMessage.contains("batchProcessingTimeMs must be 0 or positive")) - illegalArgument = intercept[IllegalArgumentException] { - GroupStateImpl.createForStreaming[Int]( - Some(5), -100L, 1000L, EventTimeTimeout, false, true) - } - assert(illegalArgument.getMessage.contains("batchProcessingTimeMs must be 0 or positive")) - - // hasTimedOut cannot be true if there's no timeout configured - var unsupportedOperation = intercept[UnsupportedOperationException] { - TestGroupState.create[Int]( - Optional.of(5), NoTimeout, 100L, Optional.empty[Long], hasTimedOut = true) - } - assert( - unsupportedOperation - .getMessage.contains("hasTimedOut is true however there's no timeout configured")) - unsupportedOperation = intercept[UnsupportedOperationException] { - 
GroupStateImpl.createForStreaming[Int]( - Some(5), 100L, NO_TIMESTAMP, NoTimeout, true, false) - } - assert( - unsupportedOperation - .getMessage.contains("hasTimedOut is true however there's no timeout configured")) - } - - test("GroupState - hasTimedOut") { - for (timeoutConf <- Seq(NoTimeout, ProcessingTimeTimeout, EventTimeTimeout)) { - // for streaming queries - for (initState <- Seq(Optional.empty[Int], Optional.of(5))) { - val state1 = TestGroupState.create[Int]( - initState, timeoutConf, 1000, Optional.empty[Long], hasTimedOut = false) - assert(state1.hasTimedOut === false) - - // hasTimedOut can only be set as true when timeoutConf isn't NoTimeout - if (timeoutConf != NoTimeout) { - val state2 = TestGroupState.create[Int]( - initState, timeoutConf, 1000, Optional.empty[Long], hasTimedOut = true) - assert(state2.hasTimedOut) - } - } - - // for batch queries - assert( - GroupStateImpl.createForBatch(timeoutConf, watermarkPresent = false).hasTimedOut === false) - } - } - - test("GroupState - getCurrentWatermarkMs") { - def streamingState( - timeoutConf: GroupStateTimeout, - watermark: Optional[Long]): GroupState[Int] = { - TestGroupState.create[Int]( - Optional.empty[Int], timeoutConf, 1000, watermark, hasTimedOut = false) - } - - def batchState(timeoutConf: GroupStateTimeout, watermarkPresent: Boolean): GroupState[Any] = { - GroupStateImpl.createForBatch(timeoutConf, watermarkPresent) - } - - def assertWrongTimeoutError(test: => Unit): Unit = { - val e = intercept[UnsupportedOperationException] { test } - assert(e.getMessage.contains( - "Cannot get event time watermark timestamp without setting watermark")) - } - - for (timeoutConf <- Seq(NoTimeout, EventTimeTimeout, ProcessingTimeTimeout)) { - // Tests for getCurrentWatermarkMs in streaming queries - assertWrongTimeoutError { - streamingState(timeoutConf, Optional.empty[Long]).getCurrentWatermarkMs() - } - assert(streamingState(timeoutConf, Optional.of(0)).getCurrentWatermarkMs() === 0) - 
assert(streamingState(timeoutConf, Optional.of(1000)).getCurrentWatermarkMs() === 1000) - assert(streamingState(timeoutConf, Optional.of(2000)).getCurrentWatermarkMs() === 2000) - assert(batchState(EventTimeTimeout, watermarkPresent = true).getCurrentWatermarkMs() === -1) - - // Tests for getCurrentWatermarkMs in batch queries - assertWrongTimeoutError { - batchState(timeoutConf, watermarkPresent = false).getCurrentWatermarkMs() - } - assert(batchState(timeoutConf, watermarkPresent = true).getCurrentWatermarkMs() === -1) - } - } - - test("GroupState - getCurrentProcessingTimeMs") { - def streamingState( - timeoutConf: GroupStateTimeout, - procTime: Long, - watermarkPresent: Boolean): GroupState[Int] = { - val eventTimeWatermarkMs = if (watermarkPresent) { - Optional.of(1000L) - } else { - Optional.empty[Long] - } - TestGroupState.create[Int]( - Optional.of(1000), timeoutConf, procTime, eventTimeWatermarkMs, hasTimedOut = false) - } - - def batchState(timeoutConf: GroupStateTimeout, watermarkPresent: Boolean): GroupState[Any] = { - GroupStateImpl.createForBatch(timeoutConf, watermarkPresent) - } - - for (timeoutConf <- Seq(NoTimeout, EventTimeTimeout, ProcessingTimeTimeout)) { - for (watermarkPresent <- Seq(false, true)) { - // Tests for getCurrentProcessingTimeMs in streaming queries - // No negative processing time is allowed, and - // illegal input validation has been added in the separate test - assert(streamingState(timeoutConf, 0, watermarkPresent) - .getCurrentProcessingTimeMs() === 0) - assert(streamingState(timeoutConf, 1000, watermarkPresent) - .getCurrentProcessingTimeMs() === 1000) - assert(streamingState(timeoutConf, 2000, watermarkPresent) - .getCurrentProcessingTimeMs() === 2000) - - // Tests for getCurrentProcessingTimeMs in batch queries - val currentTime = System.currentTimeMillis() - assert(batchState(timeoutConf, watermarkPresent).getCurrentProcessingTimeMs >= currentTime) - } - } - } - - - test("GroupState - primitive type") { - var intState = 
TestGroupState.create[Int]( - Optional.empty[Int], - NoTimeout, - 1000, - Optional.empty[Long], - hasTimedOut = false) - intercept[NoSuchElementException] { - intState.get - } - assert(intState.getOption === None) - - intState = TestGroupState.create[Int]( - Optional.of(10), - NoTimeout, - 1000, - Optional.empty[Long], - hasTimedOut = false) - - assert(intState.get == 10) - intState.update(0) - assert(intState.get == 0) - intState.remove() - intercept[NoSuchElementException] { - intState.get - } - } - // Values used for testing InputProcessor val currentBatchTimestamp = 1000 val currentBatchWatermark = 1000 @@ -1268,258 +857,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { assert(e.getMessage === "The output mode of function should be append or update") } - import testImplicits._ - - /** - * FlatMapGroupsWithState function that returns the key, value as passed to it - * along with the updated state. The state is incremented for every value. - */ - val flatMapGroupsWithStateFunc = - (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { - val valList = values.toSeq - if (valList.isEmpty) { - // When the function is called on just the initial state make sure the other fields - // are set correctly - assert(state.exists) - } - assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } - assertCannotGetWatermark { state.getCurrentWatermarkMs() } - assert(!state.hasTimedOut) - if (key.contains("EventTime")) { - state.setTimeoutTimestamp(0, "1 hour") - } - if (key.contains("ProcessingTime")) { - state.setTimeoutDuration("1 hour") - } - val count = state.getOption.map(_.count).getOrElse(0L) + valList.size - // We need to check if not explicitly calling update will still save the init state or not - if (!key.contains("NoUpdate")) { - // this is not reached when valList is empty and the state count is 2 - state.update(new RunningCount(count)) - } - Iterator((key, valList, count.toString)) - } - - Seq("1", "2", 
"6").foreach { shufflePartitions => - testWithAllStateVersions(s"flatMapGroupsWithState - initial " + - s"state - all cases - shuffle partitions ${shufflePartitions}") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> shufflePartitions) { - // We will test them on different shuffle partition configuration to make sure the - // grouping by key will still work. On higher number of shuffle partitions its possible - // that all keys end up on different partitions. - val initialState: Dataset[(String, RunningCount)] = Seq( - ("keyInStateAndData-1", new RunningCount(1)), - ("keyInStateAndData-2", new RunningCount(2)), - ("keyNoUpdate", new RunningCount(2)), // state.update will not be called - ("keyOnlyInState-1", new RunningCount(1)) - ).toDS() - - val it = initialState.groupByKey(x => x._1).mapValues(_._2) - val inputData = MemoryStream[String] - val result = - inputData.toDS() - .groupByKey(x => x) - .flatMapGroupsWithState( - Update, GroupStateTimeout.NoTimeout, it)(flatMapGroupsWithStateFunc) - - testStream(result, Update)( - AddData(inputData, "keyOnlyInData", "keyInStateAndData-2"), - CheckNewAnswer( - ("keyOnlyInState-1", Seq[String](), "1"), - ("keyNoUpdate", Seq[String](), "2"), // update will not be called - ("keyInStateAndData-2", Seq[String]("keyInStateAndData-2"), "3"), // inc by 1 - ("keyInStateAndData-1", Seq[String](), "1"), - ("keyOnlyInData", Seq[String]("keyOnlyInData"), "1") // inc by 1 - ), - assertNumStateRows(total = 5, updated = 4), - // Stop and Start stream to make sure initial state doesn't get applied again. 
- StopStream, - StartStream(), - AddData(inputData, "keyInStateAndData-1"), - CheckNewAnswer( - // state incremented by 1 - ("keyInStateAndData-1", Seq[String]("keyInStateAndData-1"), "2") - ), - assertNumStateRows(total = 5, updated = 1), - StopStream - ) - } - } - } - - testWithAllStateVersions("flatMapGroupsWithState - initial state - case class key") { - val stateFunc = (key: User, values: Iterator[User], state: GroupState[Long]) => { - val valList = values.toSeq - if (valList.isEmpty) { - // When the function is called on just the initial state make sure the other fields - // are set correctly - assert(state.exists) - } - assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } - assertCannotGetWatermark { state.getCurrentWatermarkMs() } - assert(!state.hasTimedOut) - val count = state.getOption.getOrElse(0L) + valList.size - // We need to check if not explicitly calling update will still save the state or not - if (!key.name.contains("NoUpdate")) { - // this is not reached when valList is empty and the state count is 2 - state.update(count) - } - Iterator((key, valList.map(_.name), count.toString)) - } - - val ds = Seq( - (User("keyInStateAndData", "1"), (1L)), - (User("keyOnlyInState", "1"), (1L)), - (User("keyNoUpdate", "2"), (2L)) // state.update will not be called on this in the function - ).toDS().groupByKey(_._1).mapValues(_._2) - - val inputData = MemoryStream[User] - val result = - inputData.toDS() - .groupByKey(x => x) - .flatMapGroupsWithState(Update, NoTimeout(), ds)(stateFunc) - - testStream(result, Update)( - AddData(inputData, User("keyInStateAndData", "1"), User("keyOnlyInData", "1")), - CheckNewAnswer( - (("keyInStateAndData", "1"), Seq[String]("keyInStateAndData"), "2"), - (("keyOnlyInState", "1"), Seq[String](), "1"), - (("keyNoUpdate", "2"), Seq[String](), "2"), - (("keyOnlyInData", "1"), Seq[String]("keyOnlyInData"), "1") - ), - assertNumStateRows(total = 4, updated = 3), // (keyOnlyInState, 2) does not call update() - // 
Stop and Start stream to make sure initial state doesn't get applied again. - StopStream, - StartStream(), - AddData(inputData, User("keyOnlyInData", "1")), - CheckNewAnswer( - (("keyOnlyInData", "1"), Seq[String]("keyOnlyInData"), "2") - ), - assertNumStateRows(total = 4, updated = 1), - StopStream - ) - } - - testQuietly("flatMapGroupsWithState - initial state - duplicate keys") { - val initialState = Seq( - ("a", new RunningCount(2)), - ("a", new RunningCount(1)) - ).toDS().groupByKey(_._1).mapValues(_._2) - - val inputData = MemoryStream[String] - val result = - inputData.toDS() - .groupByKey(x => x) - .flatMapGroupsWithState(Update, NoTimeout(), initialState)(flatMapGroupsWithStateFunc) - testStream(result, Update)( - AddData(inputData, "a"), - ExpectFailure[SparkException] { e => - assert(e.getCause.getMessage.contains("The initial state provided contained " + - "multiple rows(state) with the same key")) - } - ) - } - - Seq(NoTimeout(), EventTimeTimeout(), ProcessingTimeTimeout()).foreach { timeout => - test(s"flatMapGroupsWithState - initial state - batch mode - timeout ${timeout}") { - // We will test them on different shuffle partition configuration to make sure the - // grouping by key will still work. On higher number of shuffle partitions its possible - // that all keys end up on different partitions. 
- val initialState = Seq( - (s"keyInStateAndData-1-$timeout", new RunningCount(1)), - ("keyInStateAndData-2", new RunningCount(2)), - ("keyNoUpdate", new RunningCount(2)), // state.update will not be called - ("keyOnlyInState-1", new RunningCount(1)) - ).toDS().groupByKey(x => x._1).mapValues(_._2) - - val inputData = Seq( - ("keyOnlyInData"), ("keyInStateAndData-2") - ) - val result = inputData.toDS().groupByKey(x => x) - .flatMapGroupsWithState( - Update, timeout, initialState)(flatMapGroupsWithStateFunc) - - val expected = Seq( - ("keyOnlyInState-1", Seq[String](), "1"), - ("keyNoUpdate", Seq[String](), "2"), // update will not be called - ("keyInStateAndData-2", Seq[String]("keyInStateAndData-2"), "3"), // inc by 1 - (s"keyInStateAndData-1-$timeout", Seq[String](), "1"), - ("keyOnlyInData", Seq[String]("keyOnlyInData"), "1") // inc by 1 - ).toDF() - checkAnswer(result.toDF(), expected) - } - } - - testQuietly("flatMapGroupsWithState - initial state - batch mode - duplicate state") { - val initialState = Seq( - ("a", new RunningCount(1)), - ("a", new RunningCount(2)) - ).toDS().groupByKey(x => x._1).mapValues(_._2) - - val e = intercept[SparkException] { - Seq("a", "b").toDS().groupByKey(x => x) - .flatMapGroupsWithState(Update, NoTimeout(), initialState)(flatMapGroupsWithStateFunc) - .show() - } - assert(e.getMessage.contains( - "The initial state provided contained multiple rows(state) with the same key." 
+ - " Make sure to de-duplicate the initial state before passing it.")) - } - - testQuietly("flatMapGroupsWithState - initial state - streaming initial state") { - val initialStateData = MemoryStream[(String, RunningCount)] - initialStateData.addData(("a", new RunningCount(1))) - - val inputData = MemoryStream[String] - - val result = - inputData.toDS() - .groupByKey(x => x) - .flatMapGroupsWithState( - Update, NoTimeout(), initialStateData.toDS().groupByKey(_._1).mapValues(_._2) - )(flatMapGroupsWithStateFunc) - - val e = intercept[AnalysisException] { - result.writeStream - .format("console") - .start() - } - - val expectedError = "Non-streaming DataFrame/Dataset is not supported" + - " as the initial state in [flatMap|map]GroupsWithState" + - " operation on a streaming DataFrame/Dataset" - assert(e.message.contains(expectedError)) - } - - test("flatMapGroupsWithState - initial state - initial state has flatMapGroupsWithState") { - val initialStateDS = Seq(("keyInStateAndData", new RunningCount(1))).toDS() - val initialState: KeyValueGroupedDataset[String, RunningCount] = - initialStateDS.groupByKey(_._1).mapValues(_._2) - .mapGroupsWithState( - GroupStateTimeout.NoTimeout())( - (key: String, values: Iterator[RunningCount], state: GroupState[Boolean]) => { - (key, values.next()) - } - ).groupByKey(_._1).mapValues(_._2) - - val inputData = MemoryStream[String] - - val result = - inputData.toDS() - .groupByKey(x => x) - .flatMapGroupsWithState( - Update, NoTimeout(), initialState - )(flatMapGroupsWithStateFunc) - - testStream(result, Update)( - AddData(inputData, "keyInStateAndData"), - CheckNewAnswer( - ("keyInStateAndData", Seq[String]("keyInStateAndData"), "2") - ), - StopStream - ) - } - test("SPARK-38320 - flatMapGroupsWithState state with data should not timeout") { withTempDir { dir => withSQLConf( @@ -1564,75 +901,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { } } - testWithAllStateVersions("mapGroupsWithState - initial state - null 
key") { - val mapGroupsWithStateFunc = - (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { - val valList = values.toList - val count = state.getOption.map(_.count).getOrElse(0L) + valList.size - state.update(new RunningCount(count)) - (key, state.get.count.toString) - } - val initialState = Seq( - ("key", new RunningCount(5)), - (null, new RunningCount(2)) - ).toDS().groupByKey(_._1).mapValues(_._2) - - val inputData = MemoryStream[String] - val result = - inputData.toDS() - .groupByKey(x => x) - .mapGroupsWithState(NoTimeout(), initialState)(mapGroupsWithStateFunc) - testStream(result, Update)( - AddData(inputData, "key", null), - CheckNewAnswer( - ("key", "6"), // state is incremented by 1 - (null, "3") // incremented by 1 - ), - assertNumStateRows(total = 2, updated = 2), - StopStream - ) - } - - testWithAllStateVersions("flatMapGroupsWithState - initial state - processing time timeout") { - // function will return -1 on timeout and returns count of the state otherwise - val stateFunc = - (key: String, values: Iterator[(String, Long)], state: GroupState[RunningCount]) => { - if (state.hasTimedOut) { - state.remove() - Iterator((key, "-1")) - } else { - val count = state.getOption.map(_.count).getOrElse(0L) + values.size - state.update(RunningCount(count)) - state.setTimeoutDuration("10 seconds") - Iterator((key, count.toString)) - } - } - - val clock = new StreamManualClock - val inputData = MemoryStream[(String, Long)] - val initialState = Seq( - ("c", new RunningCount(2)) - ).toDS().groupByKey(_._1).mapValues(_._2) - val result = - inputData.toDF().toDF("key", "time") - .selectExpr("key", "timestamp_seconds(time) as timestamp") - .withWatermark("timestamp", "10 second") - .as[(String, Long)] - .groupByKey(x => x._1) - .flatMapGroupsWithState(Update, ProcessingTimeTimeout(), initialState)(stateFunc) - - testStream(result, Update)( - StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), - AddData(inputData, ("a", 
1L)), - AdvanceManualClock(1 * 1000), // a and c are processed here for the first time. - CheckNewAnswer(("a", "1"), ("c", "2")), - AdvanceManualClock(10 * 1000), - AddData(inputData, ("b", 1L)), // this will trigger c and a to get timed out - AdvanceManualClock(1 * 1000), - CheckNewAnswer(("a", "-1"), ("b", "1"), ("c", "-1")) - ) - } - def testWithTimeout(timeoutConf: GroupStateTimeout): Unit = { test("SPARK-20714: watermark does not fail query when timeout = " + timeoutConf) { // Function to maintain running count up to 2, and then remove the count @@ -1787,26 +1055,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { }.get } - def testTimeoutDurationNotAllowed[T <: Exception: Manifest](state: TestGroupState[_]): Unit = { - val prevTimestamp = state.getTimeoutTimestampMs - intercept[T] { state.setTimeoutDuration(1000) } - assert(state.getTimeoutTimestampMs === prevTimestamp) - intercept[T] { state.setTimeoutDuration("2 second") } - assert(state.getTimeoutTimestampMs === prevTimestamp) - } - - def testTimeoutTimestampNotAllowed[T <: Exception: Manifest](state: TestGroupState[_]): Unit = { - val prevTimestamp = state.getTimeoutTimestampMs - intercept[T] { state.setTimeoutTimestamp(2000) } - assert(state.getTimeoutTimestampMs === prevTimestamp) - intercept[T] { state.setTimeoutTimestamp(2000, "1 second") } - assert(state.getTimeoutTimestampMs === prevTimestamp) - intercept[T] { state.setTimeoutTimestamp(new Date(2000)) } - assert(state.getTimeoutTimestampMs === prevTimestamp) - intercept[T] { state.setTimeoutTimestamp(new Date(2000), "1 second") } - assert(state.getTimeoutTimestampMs === prevTimestamp) - } - def newStateStore(): StateStore = new MemoryStateStore() val intProj = UnsafeProjection.create(Array[DataType](IntegerType)) @@ -1829,8 +1077,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { } } -case class User(name: String, id: String) - object FlatMapGroupsWithStateSuite { var failInTask = true diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala new file mode 100644 index 0000000000000..beee07b9fbcd1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala @@ -0,0 +1,365 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.SparkException +import org.apache.spark.sql.{AnalysisException, Dataset, KeyValueGroupedDataset} +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Update +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.{assertCanGetProcessingTime, assertCannotGetWatermark} +import org.apache.spark.sql.streaming.GroupStateTimeout.{EventTimeTimeout, NoTimeout, ProcessingTimeTimeout} +import org.apache.spark.sql.streaming.util.StreamManualClock + +class FlatMapGroupsWithStateWithInitialStateSuite extends StateStoreMetricsTest { + import testImplicits._ + + /** + * FlatMapGroupsWithState function that returns the key, value as passed to it + * along with the updated state. The state is incremented for every value. 
+ */ + val flatMapGroupsWithStateFunc = + (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + val valList = values.toSeq + if (valList.isEmpty) { + // When the function is called on just the initial state make sure the other fields + // are set correctly + assert(state.exists) + } + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } + assertCannotGetWatermark { state.getCurrentWatermarkMs() } + assert(!state.hasTimedOut) + if (key.contains("EventTime")) { + state.setTimeoutTimestamp(0, "1 hour") + } + if (key.contains("ProcessingTime")) { + state.setTimeoutDuration("1 hour") + } + val count = state.getOption.map(_.count).getOrElse(0L) + valList.size + // We need to check if not explicitly calling update will still save the init state or not + if (!key.contains("NoUpdate")) { + // this is not reached when valList is empty and the state count is 2 + state.update(new RunningCount(count)) + } + Iterator((key, valList, count.toString)) + } + + Seq("1", "2", "6").foreach { shufflePartitions => + testWithAllStateVersions(s"flatMapGroupsWithState - initial " + + s"state - all cases - shuffle partitions ${shufflePartitions}") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> shufflePartitions) { + // We will test them on different shuffle partition configuration to make sure the + // grouping by key will still work. On higher number of shuffle partitions it's possible + // that all keys end up on different partitions. 
+ val initialState: Dataset[(String, RunningCount)] = Seq( + ("keyInStateAndData-1", new RunningCount(1)), + ("keyInStateAndData-2", new RunningCount(2)), + ("keyNoUpdate", new RunningCount(2)), // state.update will not be called + ("keyOnlyInState-1", new RunningCount(1)) + ).toDS() + + val it = initialState.groupByKey(x => x._1).mapValues(_._2) + val inputData = MemoryStream[String] + val result = + inputData.toDS() + .groupByKey(x => x) + .flatMapGroupsWithState( + Update, GroupStateTimeout.NoTimeout, it)(flatMapGroupsWithStateFunc) + + testStream(result, Update)( + AddData(inputData, "keyOnlyInData", "keyInStateAndData-2"), + CheckNewAnswer( + ("keyOnlyInState-1", Seq[String](), "1"), + ("keyNoUpdate", Seq[String](), "2"), // update will not be called + ("keyInStateAndData-2", Seq[String]("keyInStateAndData-2"), "3"), // inc by 1 + ("keyInStateAndData-1", Seq[String](), "1"), + ("keyOnlyInData", Seq[String]("keyOnlyInData"), "1") // inc by 1 + ), + assertNumStateRows(total = 5, updated = 4), + // Stop and Start stream to make sure initial state doesn't get applied again. 
+ StopStream, + StartStream(), + AddData(inputData, "keyInStateAndData-1"), + CheckNewAnswer( + // state incremented by 1 + ("keyInStateAndData-1", Seq[String]("keyInStateAndData-1"), "2") + ), + assertNumStateRows(total = 5, updated = 1), + StopStream + ) + } + } + } + + testWithAllStateVersions("flatMapGroupsWithState - initial state - case class key") { + val stateFunc = (key: User, values: Iterator[User], state: GroupState[Long]) => { + val valList = values.toSeq + if (valList.isEmpty) { + // When the function is called on just the initial state make sure the other fields + // are set correctly + assert(state.exists) + } + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } + assertCannotGetWatermark { state.getCurrentWatermarkMs() } + assert(!state.hasTimedOut) + val count = state.getOption.getOrElse(0L) + valList.size + // We need to check if not explicitly calling update will still save the state or not + if (!key.name.contains("NoUpdate")) { + // this is not reached when valList is empty and the state count is 2 + state.update(count) + } + Iterator((key, valList.map(_.name), count.toString)) + } + + val ds = Seq( + (User("keyInStateAndData", "1"), (1L)), + (User("keyOnlyInState", "1"), (1L)), + (User("keyNoUpdate", "2"), (2L)) // state.update will not be called on this in the function + ).toDS().groupByKey(_._1).mapValues(_._2) + + val inputData = MemoryStream[User] + val result = + inputData.toDS() + .groupByKey(x => x) + .flatMapGroupsWithState(Update, NoTimeout(), ds)(stateFunc) + + testStream(result, Update)( + AddData(inputData, User("keyInStateAndData", "1"), User("keyOnlyInData", "1")), + CheckNewAnswer( + (("keyInStateAndData", "1"), Seq[String]("keyInStateAndData"), "2"), + (("keyOnlyInState", "1"), Seq[String](), "1"), + (("keyNoUpdate", "2"), Seq[String](), "2"), + (("keyOnlyInData", "1"), Seq[String]("keyOnlyInData"), "1") + ), + assertNumStateRows(total = 4, updated = 3), // (keyNoUpdate, 2) does not call update() + // 
Stop and Start stream to make sure initial state doesn't get applied again. + StopStream, + StartStream(), + AddData(inputData, User("keyOnlyInData", "1")), + CheckNewAnswer( + (("keyOnlyInData", "1"), Seq[String]("keyOnlyInData"), "2") + ), + assertNumStateRows(total = 4, updated = 1), + StopStream + ) + } + + testQuietly("flatMapGroupsWithState - initial state - duplicate keys") { + val initialState = Seq( + ("a", new RunningCount(2)), + ("a", new RunningCount(1)) + ).toDS().groupByKey(_._1).mapValues(_._2) + + val inputData = MemoryStream[String] + val result = + inputData.toDS() + .groupByKey(x => x) + .flatMapGroupsWithState(Update, NoTimeout(), initialState)(flatMapGroupsWithStateFunc) + testStream(result, Update)( + AddData(inputData, "a"), + ExpectFailure[SparkException] { e => + assert(e.getCause.getMessage.contains("The initial state provided contained " + + "multiple rows(state) with the same key")) + } + ) + } + + Seq(NoTimeout(), EventTimeTimeout(), ProcessingTimeTimeout()).foreach { timeout => + test(s"flatMapGroupsWithState - initial state - batch mode - timeout ${timeout}") { + // Run the same batch query once per timeout configuration to make sure the initial + // state is applied in batch mode regardless of the timeout type; keys present only in + // the initial state are expected in the output as well. 
+ val initialState = Seq( + (s"keyInStateAndData-1-$timeout", new RunningCount(1)), + ("keyInStateAndData-2", new RunningCount(2)), + ("keyNoUpdate", new RunningCount(2)), // state.update will not be called + ("keyOnlyInState-1", new RunningCount(1)) + ).toDS().groupByKey(x => x._1).mapValues(_._2) + + val inputData = Seq( + ("keyOnlyInData"), ("keyInStateAndData-2") + ) + val result = inputData.toDS().groupByKey(x => x) + .flatMapGroupsWithState( + Update, timeout, initialState)(flatMapGroupsWithStateFunc) + + val expected = Seq( + ("keyOnlyInState-1", Seq[String](), "1"), + ("keyNoUpdate", Seq[String](), "2"), // update will not be called + ("keyInStateAndData-2", Seq[String]("keyInStateAndData-2"), "3"), // inc by 1 + (s"keyInStateAndData-1-$timeout", Seq[String](), "1"), + ("keyOnlyInData", Seq[String]("keyOnlyInData"), "1") // inc by 1 + ).toDF() + checkAnswer(result.toDF(), expected) + } + } + + testQuietly("flatMapGroupsWithState - initial state - batch mode - duplicate state") { + val initialState = Seq( + ("a", new RunningCount(1)), + ("a", new RunningCount(2)) + ).toDS().groupByKey(x => x._1).mapValues(_._2) + + val e = intercept[SparkException] { + Seq("a", "b").toDS().groupByKey(x => x) + .flatMapGroupsWithState(Update, NoTimeout(), initialState)(flatMapGroupsWithStateFunc) + .show() + } + assert(e.getMessage.contains( + "The initial state provided contained multiple rows(state) with the same key." 
+ + " Make sure to de-duplicate the initial state before passing it.")) + } + + testQuietly("flatMapGroupsWithState - initial state - streaming initial state") { + val initialStateData = MemoryStream[(String, RunningCount)] + initialStateData.addData(("a", new RunningCount(1))) + + val inputData = MemoryStream[String] + + val result = + inputData.toDS() + .groupByKey(x => x) + .flatMapGroupsWithState( + Update, NoTimeout(), initialStateData.toDS().groupByKey(_._1).mapValues(_._2) + )(flatMapGroupsWithStateFunc) + + val e = intercept[AnalysisException] { + result.writeStream + .format("console") + .start() + } + + val expectedError = "Non-streaming DataFrame/Dataset is not supported" + + " as the initial state in [flatMap|map]GroupsWithState" + + " operation on a streaming DataFrame/Dataset" + assert(e.message.contains(expectedError)) + } + + test("flatMapGroupsWithState - initial state - initial state has flatMapGroupsWithState") { + val initialStateDS = Seq(("keyInStateAndData", new RunningCount(1))).toDS() + val initialState: KeyValueGroupedDataset[String, RunningCount] = + initialStateDS.groupByKey(_._1).mapValues(_._2) + .mapGroupsWithState( + GroupStateTimeout.NoTimeout())( + (key: String, values: Iterator[RunningCount], state: GroupState[Boolean]) => { + (key, values.next()) + } + ).groupByKey(_._1).mapValues(_._2) + + val inputData = MemoryStream[String] + + val result = + inputData.toDS() + .groupByKey(x => x) + .flatMapGroupsWithState( + Update, NoTimeout(), initialState + )(flatMapGroupsWithStateFunc) + + testStream(result, Update)( + AddData(inputData, "keyInStateAndData"), + CheckNewAnswer( + ("keyInStateAndData", Seq[String]("keyInStateAndData"), "2") + ), + StopStream + ) + } + + testWithAllStateVersions("mapGroupsWithState - initial state - null key") { + val mapGroupsWithStateFunc = + (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + val valList = values.toList + val count = state.getOption.map(_.count).getOrElse(0L) + 
valList.size + state.update(new RunningCount(count)) + (key, state.get.count.toString) + } + val initialState = Seq( + ("key", new RunningCount(5)), + (null, new RunningCount(2)) + ).toDS().groupByKey(_._1).mapValues(_._2) + + val inputData = MemoryStream[String] + val result = + inputData.toDS() + .groupByKey(x => x) + .mapGroupsWithState(NoTimeout(), initialState)(mapGroupsWithStateFunc) + testStream(result, Update)( + AddData(inputData, "key", null), + CheckNewAnswer( + ("key", "6"), // state is incremented by 1 + (null, "3") // incremented by 1 + ), + assertNumStateRows(total = 2, updated = 2), + StopStream + ) + } + + testWithAllStateVersions("flatMapGroupsWithState - initial state - processing time timeout") { + // function will return -1 on timeout and returns count of the state otherwise + val stateFunc = + (key: String, values: Iterator[(String, Long)], state: GroupState[RunningCount]) => { + if (state.hasTimedOut) { + state.remove() + Iterator((key, "-1")) + } else { + val count = state.getOption.map(_.count).getOrElse(0L) + values.size + state.update(RunningCount(count)) + state.setTimeoutDuration("10 seconds") + Iterator((key, count.toString)) + } + } + + val clock = new StreamManualClock + val inputData = MemoryStream[(String, Long)] + val initialState = Seq( + ("c", new RunningCount(2)) + ).toDS().groupByKey(_._1).mapValues(_._2) + val result = + inputData.toDF().toDF("key", "time") + .selectExpr("key", "timestamp_seconds(time) as timestamp") + .withWatermark("timestamp", "10 second") + .as[(String, Long)] + .groupByKey(x => x._1) + .flatMapGroupsWithState(Update, ProcessingTimeTimeout(), initialState)(stateFunc) + + testStream(result, Update)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputData, ("a", 1L)), + AdvanceManualClock(1 * 1000), // a and c are processed here for the first time. 
+ CheckNewAnswer(("a", "1"), ("c", "2")), + AdvanceManualClock(10 * 1000), + AddData(inputData, ("b", 1L)), // this will trigger c and a to get timed out + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("a", "-1"), ("b", "1"), ("c", "-1")) + ) + } + + def testWithAllStateVersions(name: String)(func: => Unit): Unit = { + for (version <- FlatMapGroupsWithStateExecHelper.supportedVersions) { + test(s"$name - state format version $version") { + withSQLConf( + SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> version.toString, + SQLConf.STATEFUL_OPERATOR_CHECK_CORRECTNESS_ENABLED.key -> "false") { + func + } + } + } + } +} + +case class User(name: String, id: String) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/GroupStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/GroupStateSuite.scala new file mode 100644 index 0000000000000..93dac3406df9d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/GroupStateSuite.scala @@ -0,0 +1,458 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import java.sql.Date + +import org.apache.spark.SparkFunSuite +import org.apache.spark.api.java.Optional +import org.apache.spark.sql.execution.streaming.GroupStateImpl +import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP +import org.apache.spark.sql.streaming.GroupStateTimeout.{EventTimeTimeout, NoTimeout, ProcessingTimeTimeout} + +class GroupStateSuite extends SparkFunSuite { + + test("SPARK-35800: ensure TestGroupState creates instances the same as prod") { + val testState = TestGroupState.create[Int]( + Optional.of(5), EventTimeTimeout, 1L, Optional.of(1L), hasTimedOut = false) + + val prodState = GroupStateImpl.createForStreaming[Int]( + Some(5), 1L, 1L, EventTimeTimeout, false, true) + + assert(testState.isInstanceOf[GroupStateImpl[Int]]) + + assert(testState.isRemoved === prodState.isRemoved) + assert(testState.isUpdated === prodState.isUpdated) + assert(testState.exists === prodState.exists) + assert(testState.get === prodState.get) + assert(testState.getTimeoutTimestampMs === prodState.getTimeoutTimestampMs) + assert(testState.hasTimedOut === prodState.hasTimedOut) + assert(testState.getCurrentProcessingTimeMs === prodState.getCurrentProcessingTimeMs) + assert(testState.getCurrentWatermarkMs === prodState.getCurrentWatermarkMs) + + testState.update(6) + prodState.update(6) + assert(testState.isUpdated === prodState.isUpdated) + assert(testState.exists === prodState.exists) + assert(testState.get === prodState.get) + + testState.remove() + prodState.remove() + assert(testState.exists === prodState.exists) + assert(testState.isRemoved === prodState.isRemoved) + } + + test("GroupState - get, exists, update, remove") { + var state: TestGroupState[String] = null + + def testState( + expectedData: Option[String], + shouldBeUpdated: Boolean = false, + shouldBeRemoved: Boolean = false): Unit = { + if (expectedData.isDefined) { + assert(state.exists) + assert(state.get === expectedData.get) + 
} else { + assert(!state.exists) + intercept[NoSuchElementException] { + state.get + } + } + assert(state.getOption === expectedData) + assert(state.isUpdated === shouldBeUpdated) + assert(state.isRemoved === shouldBeRemoved) + } + + // === Tests for state in streaming queries === + // Updating empty state + state = TestGroupState.create[String]( + Optional.empty[String], NoTimeout, 1, Optional.empty[Long], hasTimedOut = false) + testState(None) + state.update("") + testState(Some(""), shouldBeUpdated = true) + + // Updating existing state + state = TestGroupState.create[String]( + Optional.of("2"), NoTimeout, 1, Optional.empty[Long], hasTimedOut = false) + testState(Some("2")) + state.update("3") + testState(Some("3"), shouldBeUpdated = true) + + // Removing state + state.remove() + testState(None, shouldBeRemoved = true, shouldBeUpdated = false) + state.remove() // should still be callable + state.update("4") + testState(Some("4"), shouldBeRemoved = false, shouldBeUpdated = true) + + // Updating with null throws an exception + intercept[IllegalArgumentException] { + state.update(null) + } + } + + test("GroupState - setTimeout - with NoTimeout") { + for (initValue <- Seq(Optional.empty[Int], Optional.of((5)))) { + val states = Seq( + TestGroupState.create[Int]( + initValue, NoTimeout, 1000, Optional.empty[Long], hasTimedOut = false), + GroupStateImpl.createForBatch(NoTimeout, watermarkPresent = false) + ) + for (state <- states) { + // for streaming queries + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + + // for batch queries + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + } + } + } + + test("GroupState - setTimeout - with ProcessingTimeTimeout") { + // for streaming queries + var state = TestGroupState.create[Int]( + Optional.empty[Int], ProcessingTimeTimeout, 1000, 
Optional.empty[Long], hasTimedOut = false) + assert(!state.getTimeoutTimestampMs.isPresent()) + state.setTimeoutDuration("-1 month 31 days 1 second") + assert(state.getTimeoutTimestampMs.isPresent()) + assert(state.getTimeoutTimestampMs.get() === 2000) + state.setTimeoutDuration(500) + assert(state.getTimeoutTimestampMs.get() === 1500) // can be set without initializing state + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + + state.update(5) + assert(state.getTimeoutTimestampMs.isPresent()) + assert(state.getTimeoutTimestampMs.get() === 1500) // does not change + state.setTimeoutDuration(1000) + assert(state.getTimeoutTimestampMs.get() === 2000) + state.setTimeoutDuration("2 second") + assert(state.getTimeoutTimestampMs.get() === 3000) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + + state.remove() + assert(state.getTimeoutTimestampMs.isPresent()) + assert(state.getTimeoutTimestampMs.get() === 3000) // does not change + state.setTimeoutDuration(500) // can still be set + assert(state.getTimeoutTimestampMs.get() === 1500) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + + // for batch queries + state = GroupStateImpl.createForBatch( + ProcessingTimeTimeout, watermarkPresent = false).asInstanceOf[GroupStateImpl[Int]] + assert(!state.getTimeoutTimestampMs.isPresent()) + state.setTimeoutDuration(500) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + + state.update(5) + state.setTimeoutDuration(1000) + state.setTimeoutDuration("2 second") + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + + state.remove() + state.setTimeoutDuration(500) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + } + + test("GroupState - setTimeout - with EventTimeTimeout") { + var state = TestGroupState.create[Int]( + Optional.empty[Int], EventTimeTimeout, 1000, Optional.of(1000), hasTimedOut = false) + assert(!state.getTimeoutTimestampMs.isPresent()) + 
testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + state.setTimeoutTimestamp(5000) + assert(state.getTimeoutTimestampMs.get() === 5000) // can be set without initializing state + + state.update(5) + assert(state.getTimeoutTimestampMs.get() === 5000) // does not change + state.setTimeoutTimestamp(10000) + assert(state.getTimeoutTimestampMs.get() === 10000) + state.setTimeoutTimestamp(new Date(20000)) + assert(state.getTimeoutTimestampMs.get() === 20000) + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + + state.remove() + assert(state.getTimeoutTimestampMs.get() === 20000) + state.setTimeoutTimestamp(5000) + assert(state.getTimeoutTimestampMs.get() === 5000) // can be set after removing state + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + + // for batch queries + state = GroupStateImpl.createForBatch( + EventTimeTimeout, watermarkPresent = false).asInstanceOf[GroupStateImpl[Int]] + assert(!state.getTimeoutTimestampMs.isPresent()) + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + state.setTimeoutTimestamp(5000) + + state.update(5) + state.setTimeoutTimestamp(10000) + state.setTimeoutTimestamp(new Date(20000)) + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + + state.remove() + state.setTimeoutTimestamp(5000) + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + } + + test("GroupState - illegal params to setTimeout") { + var state: TestGroupState[Int] = null + + // Test setTimeout() with illegal values + def testIllegalTimeout(body: => Unit): Unit = { + intercept[IllegalArgumentException] { + body + } + assert(!state.getTimeoutTimestampMs.isPresent()) + } + + // Test setTimeout() with illegal values + state = TestGroupState.create[Int]( + Optional.of(5), ProcessingTimeTimeout, 1000, Optional.empty[Long], hasTimedOut = false) + + testIllegalTimeout { + state.setTimeoutDuration(-1000) + } + testIllegalTimeout { + state.setTimeoutDuration(0) + } 
+ testIllegalTimeout { + state.setTimeoutDuration("-2 second") + } + testIllegalTimeout { + state.setTimeoutDuration("-1 month") + } + + testIllegalTimeout { + state.setTimeoutDuration("1 month -31 day") + } + + state = TestGroupState.create[Int]( + Optional.of(5), EventTimeTimeout, 1000, Optional.of(1000), hasTimedOut = false) + testIllegalTimeout { + state.setTimeoutTimestamp(-10000) + } + testIllegalTimeout { + state.setTimeoutTimestamp(10000, "-3 second") + } + testIllegalTimeout { + state.setTimeoutTimestamp(10000, "-1 month") + } + testIllegalTimeout { + state.setTimeoutTimestamp(10000, "1 month -32 day") + } + testIllegalTimeout { + state.setTimeoutTimestamp(new Date(-10000)) + } + testIllegalTimeout { + state.setTimeoutTimestamp(new Date(-10000), "-3 second") + } + testIllegalTimeout { + state.setTimeoutTimestamp(new Date(-10000), "-1 month") + } + testIllegalTimeout { + state.setTimeoutTimestamp(new Date(-10000), "1 month -32 day") + } + } + + test("SPARK-35800: illegal params to create") { + // eventTimeWatermarkMs >= 0 if present + var illegalArgument = intercept[IllegalArgumentException] { + TestGroupState.create[Int]( + Optional.of(5), EventTimeTimeout, 100L, Optional.of(-1000), hasTimedOut = false) + } + assert( + illegalArgument.getMessage.contains("eventTimeWatermarkMs must be 0 or positive if present")) + illegalArgument = intercept[IllegalArgumentException] { + GroupStateImpl.createForStreaming[Int]( + Some(5), 100L, -1000L, EventTimeTimeout, false, true) + } + assert( + illegalArgument.getMessage.contains("eventTimeWatermarkMs must be 0 or positive if present")) + + // batchProcessingTimeMs must be positive + illegalArgument = intercept[IllegalArgumentException] { + TestGroupState.create[Int]( + Optional.of(5), EventTimeTimeout, -100L, Optional.of(1000), hasTimedOut = false) + } + assert(illegalArgument.getMessage.contains("batchProcessingTimeMs must be 0 or positive")) + illegalArgument = intercept[IllegalArgumentException] { + 
GroupStateImpl.createForStreaming[Int]( + Some(5), -100L, 1000L, EventTimeTimeout, false, true) + } + assert(illegalArgument.getMessage.contains("batchProcessingTimeMs must be 0 or positive")) + + // hasTimedOut cannot be true if there's no timeout configured + var unsupportedOperation = intercept[UnsupportedOperationException] { + TestGroupState.create[Int]( + Optional.of(5), NoTimeout, 100L, Optional.empty[Long], hasTimedOut = true) + } + assert( + unsupportedOperation + .getMessage.contains("hasTimedOut is true however there's no timeout configured")) + unsupportedOperation = intercept[UnsupportedOperationException] { + GroupStateImpl.createForStreaming[Int]( + Some(5), 100L, NO_TIMESTAMP, NoTimeout, true, false) + } + assert( + unsupportedOperation + .getMessage.contains("hasTimedOut is true however there's no timeout configured")) + } + + test("GroupState - hasTimedOut") { + for (timeoutConf <- Seq(NoTimeout, ProcessingTimeTimeout, EventTimeTimeout)) { + // for streaming queries + for (initState <- Seq(Optional.empty[Int], Optional.of(5))) { + val state1 = TestGroupState.create[Int]( + initState, timeoutConf, 1000, Optional.empty[Long], hasTimedOut = false) + assert(state1.hasTimedOut === false) + + // hasTimedOut can only be set as true when timeoutConf isn't NoTimeout + if (timeoutConf != NoTimeout) { + val state2 = TestGroupState.create[Int]( + initState, timeoutConf, 1000, Optional.empty[Long], hasTimedOut = true) + assert(state2.hasTimedOut) + } + } + + // for batch queries + assert( + GroupStateImpl.createForBatch(timeoutConf, watermarkPresent = false).hasTimedOut === false) + } + } + + test("GroupState - getCurrentWatermarkMs") { + def streamingState( + timeoutConf: GroupStateTimeout, + watermark: Optional[Long]): GroupState[Int] = { + TestGroupState.create[Int]( + Optional.empty[Int], timeoutConf, 1000, watermark, hasTimedOut = false) + } + + def batchState(timeoutConf: GroupStateTimeout, watermarkPresent: Boolean): GroupState[Any] = { + 
GroupStateImpl.createForBatch(timeoutConf, watermarkPresent) + } + + def assertWrongTimeoutError(test: => Unit): Unit = { + val e = intercept[UnsupportedOperationException] { test } + assert(e.getMessage.contains( + "Cannot get event time watermark timestamp without setting watermark")) + } + + for (timeoutConf <- Seq(NoTimeout, EventTimeTimeout, ProcessingTimeTimeout)) { + // Tests for getCurrentWatermarkMs in streaming queries + assertWrongTimeoutError { + streamingState(timeoutConf, Optional.empty[Long]).getCurrentWatermarkMs() + } + assert(streamingState(timeoutConf, Optional.of(0)).getCurrentWatermarkMs() === 0) + assert(streamingState(timeoutConf, Optional.of(1000)).getCurrentWatermarkMs() === 1000) + assert(streamingState(timeoutConf, Optional.of(2000)).getCurrentWatermarkMs() === 2000) + assert(batchState(EventTimeTimeout, watermarkPresent = true).getCurrentWatermarkMs() === -1) + + // Tests for getCurrentWatermarkMs in batch queries + assertWrongTimeoutError { + batchState(timeoutConf, watermarkPresent = false).getCurrentWatermarkMs() + } + assert(batchState(timeoutConf, watermarkPresent = true).getCurrentWatermarkMs() === -1) + } + } + + test("GroupState - getCurrentProcessingTimeMs") { + def streamingState( + timeoutConf: GroupStateTimeout, + procTime: Long, + watermarkPresent: Boolean): GroupState[Int] = { + val eventTimeWatermarkMs = if (watermarkPresent) { + Optional.of(1000L) + } else { + Optional.empty[Long] + } + TestGroupState.create[Int]( + Optional.of(1000), timeoutConf, procTime, eventTimeWatermarkMs, hasTimedOut = false) + } + + def batchState(timeoutConf: GroupStateTimeout, watermarkPresent: Boolean): GroupState[Any] = { + GroupStateImpl.createForBatch(timeoutConf, watermarkPresent) + } + + for (timeoutConf <- Seq(NoTimeout, EventTimeTimeout, ProcessingTimeTimeout)) { + for (watermarkPresent <- Seq(false, true)) { + // Tests for getCurrentProcessingTimeMs in streaming queries + // No negative processing time is allowed, and + // illegal input 
validation has been added in the separate test + assert(streamingState(timeoutConf, 0, watermarkPresent) + .getCurrentProcessingTimeMs() === 0) + assert(streamingState(timeoutConf, 1000, watermarkPresent) + .getCurrentProcessingTimeMs() === 1000) + assert(streamingState(timeoutConf, 2000, watermarkPresent) + .getCurrentProcessingTimeMs() === 2000) + + // Tests for getCurrentProcessingTimeMs in batch queries + val currentTime = System.currentTimeMillis() + assert(batchState(timeoutConf, watermarkPresent).getCurrentProcessingTimeMs >= currentTime) + } + } + } + + test("GroupState - primitive type") { + var intState = TestGroupState.create[Int]( + Optional.empty[Int], + NoTimeout, + 1000, + Optional.empty[Long], + hasTimedOut = false) + intercept[NoSuchElementException] { + intState.get + } + assert(intState.getOption === None) + + intState = TestGroupState.create[Int]( + Optional.of(10), + NoTimeout, + 1000, + Optional.empty[Long], + hasTimedOut = false) + + assert(intState.get == 10) + intState.update(0) + assert(intState.get == 0) + intState.remove() + intercept[NoSuchElementException] { + intState.get + } + } + + def testTimeoutDurationNotAllowed[T <: Exception: Manifest](state: TestGroupState[_]): Unit = { + val prevTimestamp = state.getTimeoutTimestampMs + intercept[T] { state.setTimeoutDuration(1000) } + assert(state.getTimeoutTimestampMs === prevTimestamp) + intercept[T] { state.setTimeoutDuration("2 second") } + assert(state.getTimeoutTimestampMs === prevTimestamp) + } + + def testTimeoutTimestampNotAllowed[T <: Exception: Manifest](state: TestGroupState[_]): Unit = { + val prevTimestamp = state.getTimeoutTimestampMs + intercept[T] { state.setTimeoutTimestamp(2000) } + assert(state.getTimeoutTimestampMs === prevTimestamp) + intercept[T] { state.setTimeoutTimestamp(2000, "1 second") } + assert(state.getTimeoutTimestampMs === prevTimestamp) + intercept[T] { state.setTimeoutTimestamp(new Date(2000)) } + assert(state.getTimeoutTimestampMs === prevTimestamp) + 
intercept[T] { state.setTimeoutTimestamp(new Date(2000), "1 second") } + assert(state.getTimeoutTimestampMs === prevTimestamp) + } +}