[SPARK-50152][SS] Support handleInitialState with state data source reader

### What changes were proposed in this pull request?

This PR adds support for users to provide a DataFrame that can be used to instantiate state for the query in the first batch of the arbitrary state API v2 (transformWithState). More specifically, this DataFrame comes from the state data source reader.

It also removes the restriction that the initial state DataFrame can contain only one value row per grouping key. This enables the integration with the state data source reader: in the flattened state data source reader output for composite types, multiple value rows map to the same grouping key.

For example, we can take the DataFrames created by the state data source reader for individual state variables, union them together, and use the resulting DataFrame as the initial state for a transformWithState operator, like this:
```
+-----------+-----+---------+----------+------------+
|groupingKey|value|listValue|userMapKey|userMapValue|
+-----------+-----+---------+----------+------------+
|a          |3    |NULL     |NULL      |NULL        |
|b          |2    |NULL     |NULL      |NULL        |
|a          |NULL |1        |NULL      |NULL        |
|a          |NULL |2        |NULL      |NULL        |
|a          |NULL |3        |NULL      |NULL        |
|b          |NULL |1        |NULL      |NULL        |
|b          |NULL |2        |NULL      |NULL        |
|a          |NULL |NULL     |a         |3           |
|b          |NULL |NULL     |b         |2           |
+-----------+-----+---------+----------+------------+
```
### Why are the changes needed?

This change supports initial state handling in the integration with the state data source reader.

### Does this PR introduce _any_ user-facing change?

No. The user API is the same as in the prior PR (#45467), which added initial state support without the state data source reader.

### How was this patch tested?

Unit test cases added in `TransformWithStateInitialStateSuite`.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #48686 from jingz-db/initial-state-reader-integration.

Lead-authored-by: jingz-db <[email protected]>
Co-authored-by: Jungtaek Lim <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
jingz-db and HeartSaVioR committed Nov 12, 2024
1 parent d96c623 commit 5432cef
Showing 5 changed files with 257 additions and 56 deletions.
6 changes: 0 additions & 6 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -4233,12 +4233,6 @@
],
"sqlState" : "42802"
},
"STATEFUL_PROCESSOR_CANNOT_REINITIALIZE_STATE_ON_KEY" : {
"message" : [
"Cannot re-initialize state on the same grouping key during initial state handling for stateful processor. Invalid grouping key=<groupingKey>."
],
"sqlState" : "42802"
},
"STATEFUL_PROCESSOR_DUPLICATE_STATE_VARIABLE_DEFINED" : {
"message" : [
"State variable with name <stateVarName> has already been defined in the StatefulProcessor."
@@ -121,6 +121,9 @@ private[sql] abstract class StatefulProcessorWithInitialState[K, I, O, S]

/**
* Function that will be invoked only in the first batch for users to process initial states.
 * The provided initial state can be an arbitrary DataFrame with the same grouping key schema as
 * the input rows, e.g. a DataFrame from the state data source reader of an existing streaming
 * query checkpoint.
*
* @param key
* \- grouping key
@@ -271,13 +271,9 @@ case class TransformWithStateExec(
ImplicitGroupingKeyTracker.setImplicitKey(keyObj)
val initStateObjIter = initStateIter.map(getInitStateValueObj.apply)

var seenInitStateOnKey = false
initStateObjIter.foreach { initState =>
// cannot re-initialize state on the same grouping key during initial state handling
if (seenInitStateOnKey) {
throw StateStoreErrors.cannotReInitializeStateOnKey(keyObj.toString)
}
seenInitStateOnKey = true
      // Allow multiple initial state rows on the same grouping key, to enable
      // integration with the state data source reader as the initial state.
statefulProcessor
.asInstanceOf[StatefulProcessorWithInitialState[Any, Any, Any, Any]]
.handleInitialState(keyObj, initState,
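Since `handleInitialState` can now be invoked several times for the same grouping key (once per initial state row), a processor should treat each call as carrying one fragment of that key's state. A minimal sketch, condensed from the `InitialStatefulProcessorWithStateDataSource` test class added below (the `_valState`/`_listState`/`_mapState` variables are the ones that class declares):

```scala
// Sketch: each flattened initial-state row populates at most one state
// variable, so repeated calls for the same key touch disjoint state.
override def handleInitialState(
    key: String,
    initialState: UnionInitialStateRow,
    timerValues: TimerValues): Unit = {
  initialState.value.foreach(_valState.update)
  initialState.listValue.foreach(_listState.appendValue)
  for {
    mapKey <- initialState.userMapKey
    mapValue <- initialState.userMapValue
  } _mapState.updateValue(mapKey, mapValue)
}
```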
@@ -123,11 +123,6 @@ object StateStoreErrors {
new StatefulProcessorCannotPerformOperationWithInvalidHandleState(operationType, handleState)
}

def cannotReInitializeStateOnKey(groupingKey: String):
StatefulProcessorCannotReInitializeState = {
new StatefulProcessorCannotReInitializeState(groupingKey)
}

def cannotProvideTTLConfigForTimeMode(stateName: String, timeMode: String):
StatefulProcessorCannotAssignTTLInTimeMode = {
new StatefulProcessorCannotAssignTTLInTimeMode(stateName, timeMode)
@@ -272,11 +267,6 @@ class StatefulProcessorCannotPerformOperationWithInvalidHandleState(
messageParameters = Map("operationType" -> operationType, "handleState" -> handleState)
)

class StatefulProcessorCannotReInitializeState(groupingKey: String)
extends SparkUnsupportedOperationException(
errorClass = "STATEFUL_PROCESSOR_CANNOT_REINITIALIZE_STATE_ON_KEY",
messageParameters = Map("groupingKey" -> groupingKey))

class StateStoreUnsupportedOperationOnMissingColumnFamily(
operationType: String,
colFamilyName: String) extends SparkUnsupportedOperationException(
@@ -17,18 +17,33 @@

package org.apache.spark.sql.streaming

import org.apache.spark.SparkUnsupportedOperationException
import org.apache.spark.sql.{Dataset, Encoders, KeyValueGroupedDataset}
import org.apache.spark.sql.{DataFrame, Dataset, Encoders, KeyValueGroupedDataset}
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider}
import org.apache.spark.sql.functions.timestamp_seconds
import org.apache.spark.sql.functions.{col, timestamp_seconds}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.util.StreamManualClock

case class InitInputRow(key: String, action: String, value: Double)
case class InputRowForInitialState(
key: String, value: Double, entries: List[Double], mapping: Map[Double, Int])

case class UnionInitialStateRow(
groupingKey: String,
value: Option[Long],
listValue: Option[Long],
userMapKey: Option[String],
userMapValue: Option[Long]
)

case class UnionUnflattenInitialStateRow(
groupingKey: String,
value: Option[Long],
listValue: Option[Seq[Long]],
mapValue: Option[Map[String, Long]]
)

abstract class StatefulProcessorWithInitialStateTestClass[V]
extends StatefulProcessorWithInitialState[
String, InitInputRow, (String, String, Double), V] {
@@ -86,6 +101,86 @@ abstract class StatefulProcessorWithInitialStateTestClass[V]
}
}

/**
 * Class that updates all state variables with input rows. Acts as a counter:
 * keeps the count in value state, keeps all the occurrences in list state, and
 * keeps a map of key -> occurrence count in the map state.
*/
abstract class InitialStateWithStateDataSourceBase[V]
extends StatefulProcessorWithInitialState[
String, String, (String, String), V] {
@transient var _valState: ValueState[Long] = _
@transient var _listState: ListState[Long] = _
@transient var _mapState: MapState[String, Long] = _

override def init(
outputMode: OutputMode,
timeMode: TimeMode): Unit = {
_valState = getHandle.getValueState[Long]("testVal", Encoders.scalaLong, TTLConfig.NONE)
_listState = getHandle.getListState[Long]("testList", Encoders.scalaLong, TTLConfig.NONE)
_mapState = getHandle.getMapState[String, Long](
"testMap", Encoders.STRING, Encoders.scalaLong, TTLConfig.NONE)
}

override def handleInputRows(
key: String,
inputRows: Iterator[String],
timerValues: TimerValues): Iterator[(String, String)] = {
val curCountValue = if (_valState.exists()) {
_valState.get()
} else {
0L
}
var cnt = curCountValue
inputRows.foreach { row =>
cnt += 1
_listState.appendValue(cnt)
val mapCurVal = if (_mapState.containsKey(row)) {
_mapState.getValue(row)
} else {
0
}
_mapState.updateValue(row, mapCurVal + 1L)
}
_valState.update(cnt)
Iterator.single((key, cnt.toString))
}

override def close(): Unit = super.close()
}

class InitialStatefulProcessorWithStateDataSource
extends InitialStateWithStateDataSourceBase[UnionInitialStateRow] {
override def handleInitialState(
key: String, initialState: UnionInitialStateRow, timerValues: TimerValues): Unit = {
if (initialState.value.isDefined) {
_valState.update(initialState.value.get)
} else if (initialState.listValue.isDefined) {
_listState.appendValue(initialState.listValue.get)
} else if (initialState.userMapKey.isDefined) {
_mapState.updateValue(
initialState.userMapKey.get, initialState.userMapValue.get)
}
}
}

class InitialStatefulProcessorWithUnflattenStateDataSource
extends InitialStateWithStateDataSourceBase[UnionUnflattenInitialStateRow] {
override def handleInitialState(
key: String, initialState: UnionUnflattenInitialStateRow, timerValues: TimerValues): Unit = {
if (initialState.value.isDefined) {
_valState.update(initialState.value.get)
} else if (initialState.listValue.isDefined) {
_listState.appendList(
initialState.listValue.get.toArray)
} else if (initialState.mapValue.isDefined) {
initialState.mapValue.get.keys.foreach { key =>
_mapState.updateValue(key, initialState.mapValue.get.get(key).get)
}
}
}
}

class AccumulateStatefulProcessorWithInitState
extends StatefulProcessorWithInitialStateTestClass[(String, Double)] {
override def handleInitialState(
@@ -398,37 +493,6 @@ class TransformWithStateInitialStateSuite extends StateStoreMetricsTest
checkAnswer(df, Seq(("k1", "getOption", 37.0)).toDF())
}

test("transformWithStateWithInitialState - " +
"cannot re-initialize state during initial state handling") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName) {
val initDf = Seq(("init_1", 40.0), ("init_2", 100.0), ("init_1", 50.0)).toDS()
.groupByKey(x => x._1).mapValues(x => x)
val inputData = MemoryStream[InitInputRow]
val query = inputData.toDS()
.groupByKey(x => x.key)
.transformWithState(new AccumulateStatefulProcessorWithInitState(),
TimeMode.None(),
OutputMode.Append(),
initDf)

testStream(query, OutputMode.Update())(
AddData(inputData, InitInputRow("k1", "add", 50.0)),
Execute { q =>
val e = intercept[Exception] {
q.processAllAvailable()
}
checkError(
exception = e.getCause.asInstanceOf[SparkUnsupportedOperationException],
condition = "STATEFUL_PROCESSOR_CANNOT_REINITIALIZE_STATE_ON_KEY",
sqlState = Some("42802"),
parameters = Map("groupingKey" -> "init_1")
)
}
)
}
}

test("transformWithStateWithInitialState - streaming with processing time timer, " +
"can emit expired initial state rows when grouping key is not received for new input rows") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
@@ -503,4 +567,158 @@ class TransformWithStateInitialStateSuite extends StateStoreMetricsTest
)
}
}

Seq(true, false).foreach { flattenOption =>
Seq(("5", "2"), ("5", "8"), ("5", "5")).foreach { partitions =>
    test("state data source reader dataframe as initial state " +
      s"(flatten option=$flattenOption, shuffle partition for 1st stream=${partitions._1}, " +
      s"shuffle partition for 2nd stream=${partitions._2})") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName) {
withTempPaths(2) { checkpointDirs =>
SQLConf.get.setConfString(SQLConf.SHUFFLE_PARTITIONS.key, partitions._1)
val inputData = MemoryStream[String]
val result = inputData.toDS()
.groupByKey(x => x)
.transformWithState(new InitialStatefulProcessorWithStateDataSource(),
TimeMode.None(),
OutputMode.Update())

testStream(result, OutputMode.Update())(
StartStream(checkpointLocation = checkpointDirs(0).getCanonicalPath),
AddData(inputData, "a", "b"),
CheckNewAnswer(("a", "1"), ("b", "1")),
AddData(inputData, "a", "b", "a"),
CheckNewAnswer(("a", "3"), ("b", "2"))
)

          // Mimic a use case where users load all state rows from a previous
          // transformWithState query as the initial state of a new query. Users
          // need a single dataframe containing the state rows of every state
          // variable, but one state data source reader query can only read one
          // state variable, and the variables have different schemas. So we read
          // one dataframe per state variable and union them into a single dataframe.
val valueDf = spark.read
.format("statestore")
.option(StateSourceOptions.PATH, checkpointDirs(0).getAbsolutePath)
.option(StateSourceOptions.STATE_VAR_NAME, "testVal")
.load()
.drop("partition_id")

val listDf = spark.read
.format("statestore")
.option(StateSourceOptions.PATH, checkpointDirs(0).getAbsolutePath)
.option(StateSourceOptions.STATE_VAR_NAME, "testList")
.option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, flattenOption)
.load()
.drop("partition_id")

val mapDf = spark.read
.format("statestore")
.option(StateSourceOptions.PATH, checkpointDirs(0).getAbsolutePath)
.option(StateSourceOptions.STATE_VAR_NAME, "testMap")
.option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, flattenOption)
.load()
.drop("partition_id")

          // Create a df where each row carries one of the value/list/map state rows,
          // with the missing columns filled with null.
SQLConf.get.setConfString(SQLConf.SHUFFLE_PARTITIONS.key, partitions._2)
val inputData2 = MemoryStream[String]
val query = startQueryWithDataSourceDataframeAsInitState(
flattenOption, valueDf, listDf, mapDf, inputData2)

testStream(query, OutputMode.Update())(
StartStream(checkpointLocation = checkpointDirs(1).getCanonicalPath),
// check initial state is updated for state vars
AddData(inputData2, "c"),
CheckNewAnswer(("c", "1")),
Execute { _ =>
val valueDf2 = spark.read
.format("statestore")
.option(StateSourceOptions.PATH, checkpointDirs(1).getAbsolutePath)
.option(StateSourceOptions.STATE_VAR_NAME, "testVal")
.option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, flattenOption)
.load()
.drop("partition_id")
.filter(col("key.value") =!= "c")

val listDf2 = spark.read
.format("statestore")
.option(StateSourceOptions.PATH, checkpointDirs(1).getAbsolutePath)
.option(StateSourceOptions.STATE_VAR_NAME, "testList")
.option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, flattenOption)
.load()
.drop("partition_id")
.filter(col("key.value") =!= "c")

val mapDf2 = spark.read
.format("statestore")
.option(StateSourceOptions.PATH, checkpointDirs(1).getAbsolutePath)
.option(StateSourceOptions.STATE_VAR_NAME, "testMap")
.option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, flattenOption)
.load()
.drop("partition_id")
.filter(col("key.value") =!= "c")

checkAnswer(valueDf, valueDf2)
checkAnswer(listDf, listDf2)
checkAnswer(mapDf, mapDf2)
}
)
}
}
}
}
}

private def startQueryWithDataSourceDataframeAsInitState(
flattenOption: Boolean,
valDf: DataFrame,
listDf: DataFrame,
mapDf: DataFrame,
inputData: MemoryStream[String]): DataFrame = {
if (flattenOption) {
      // When the state rows are read with the flatten option set to true, the values of a
      // composite state variable are flattened into multiple rows, where each row is a
      // key -> single value pair.
val valueDf = valDf.selectExpr("key.value AS groupingKey", "value.value AS value")
val flattenListDf = listDf
.selectExpr("key.value AS groupingKey", "list_element.value AS listValue")
val flattenMapDf = mapDf
.selectExpr(
"key.value AS groupingKey",
"user_map_key.value AS userMapKey",
"user_map_value.value AS userMapValue")
val df_joined =
valueDf.unionByName(flattenListDf, true)
.unionByName(flattenMapDf, true)
val kvDataSet = inputData.toDS().groupByKey(x => x)
val initDf = df_joined.as[UnionInitialStateRow].groupByKey(x => x.groupingKey)
(kvDataSet.transformWithState(
new InitialStatefulProcessorWithStateDataSource(),
TimeMode.None(), OutputMode.Append(), initDf).toDF())
} else {
      // When the state rows are read with the flatten option set to false, the values of a
      // composite state variable are composed into a single row of list/map type.
val valueDf = valDf.selectExpr("key.value AS groupingKey", "value.value AS value")
val unflattenListDf = listDf
.selectExpr("key.value AS groupingKey",
"list_value.value as listValue")
val unflattenMapDf = mapDf
.selectExpr(
"key.value AS groupingKey",
"map_from_entries(transform(map_entries(map_value), x -> " +
"struct(x.key.value, x.value.value))) as mapValue")
val df_joined =
valueDf.unionByName(unflattenListDf, true)
.unionByName(unflattenMapDf, true)
val kvDataSet = inputData.toDS().groupByKey(x => x)
val initDf = df_joined.as[UnionUnflattenInitialStateRow].groupByKey(x => x.groupingKey)
kvDataSet.transformWithState(
new InitialStatefulProcessorWithUnflattenStateDataSource(),
TimeMode.None(), OutputMode.Append(), initDf).toDF()
}
}
}
