[SPARK-49191][SS] Add support for reading transformWithState map state variables with state data source reader

### What changes were proposed in this pull request?

Add support for reading transformWithState map state variables with state data source reader.

### Why are the changes needed?

Changes are needed to integrate state reading with the new operator metadata and state schema formats for the map state types used in state variables within transformWithState.

### Does this PR introduce _any_ user-facing change?

No. In the same way as reading valueState, users can now read a mapState state variable as:
```
spark
   .read
   .format("statestore")
   .option("operatorId", <operatorId>)
   .option("stateVarName", <mapStateVarName>)
   .load(<state path>)
```
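For example, as a concrete sketch (the operator id, state variable name, and checkpoint path below are illustrative, not taken from this PR):
```
// Illustrative values: substitute your query's operator id, the state variable
// name used in the stateful processor, and the checkpoint's state directory.
val mapStateDf = spark
  .read
  .format("statestore")
  .option("operatorId", 0)
  .option("stateVarName", "sessionState")
  .load("/tmp/streaming-query/checkpoint/state")
```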
The output DataFrame will look like:
```
+----+---------------------------+------------+
|key |map_value                  |partition_id|
+----+---------------------------+------------+
|{k1}|{{v2} -> {5}, {v1} -> {10}}|0           |
|{k2}|{{v2} -> {3}}              |0           |
+----+---------------------------+------------+
```
Or this if TTL is enabled:
```
+----+------------------------------------------------+------------+
|key |map_value                                       |partition_id|
+----+------------------------------------------------+------------+
|{k1}|{{key2} -> {{2}, 61000}, {key1} -> {{1}, 61000}}|0           |
+----+------------------------------------------------+------------+
```
An example schema for the output DataFrame:
```
root
 |-- key: struct (nullable = true) # grouping key row
 |    |-- value: string (nullable = true)
 |-- map_value: map (nullable = true)
 |    |-- key: struct # user key row
 |    |    |-- value: string (nullable = true)
 |    |-- value: struct (valueContainsNull = false) # value row in state store
 |    |    |-- value: struct (nullable = true) # value row
 |    |    |    |-- value: integer (nullable = false)
 |    |    |-- ttlExpirationMs: long (nullable = true) # ttl column
 |-- partition_id: integer (nullable = true)
```
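Since `map_value` is an ordinary `MapType` column, the output can be post-processed with standard DataFrame functions. A minimal sketch (not part of this PR; it assumes the reader output above is bound to `mapStateDf`):
```
import org.apache.spark.sql.functions.{col, explode}

// Explode the map into one row per (grouping key, user key) entry.
val flattened = mapStateDf.select(
  col("key").as("grouping_key"),
  explode(col("map_value")).as(Seq("user_key", "state_value")),
  col("partition_id"))

// With TTL enabled, each entry's value row carries a ttlExpirationMs column,
// so already-expired entries can be filtered out.
val unexpired = flattened.filter(
  col("state_value.ttlExpirationMs") > System.currentTimeMillis())
```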
### How was this patch tested?

Unit tests in `StateDataSourceTransformWithStateSuite`

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #48000 from jingz-db/map-state-rebase.

Lead-authored-by: jingz-db <[email protected]>
Co-authored-by: Jing Zhan <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
2 people authored and HeartSaVioR committed Sep 10, 2024
1 parent c4a396b commit 8732528
Showing 3 changed files with 322 additions and 38 deletions.
StatePartitionReader.scala:

@@ -74,8 +74,11 @@ abstract class StatePartitionReaderBase(
   private val schemaForValueRow: StructType =
     StructType(Array(StructField("__dummy__", NullType)))
 
-  protected val keySchema = SchemaUtil.getSchemaAsDataType(
-    schema, "key").asInstanceOf[StructType]
+  protected val keySchema = {
+    if (!SchemaUtil.isMapStateVariable(stateVariableInfoOpt)) {
+      SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
+    } else SchemaUtil.getCompositeKeySchema(schema)
+  }
 
   protected val valueSchema = if (stateVariableInfoOpt.isDefined) {
     schemaForValueRow
@@ -178,38 +181,43 @@ class StatePartitionReader(
   override lazy val iter: Iterator[InternalRow] = {
     val stateVarName = stateVariableInfoOpt
       .map(_.stateName).getOrElse(StateStore.DEFAULT_COL_FAMILY_NAME)
-    store
-      .iterator(stateVarName)
-      .map { pair =>
-        stateVariableInfoOpt match {
-          case Some(stateVarInfo) =>
-            val stateVarType = stateVarInfo.stateVariableType
+    if (SchemaUtil.isMapStateVariable(stateVariableInfoOpt)) {
+      SchemaUtil.unifyMapStateRowPair(
+        store.iterator(stateVarName), keySchema, partition.partition)
+    } else {
+      store
+        .iterator(stateVarName)
+        .map { pair =>
+          stateVariableInfoOpt match {
+            case Some(stateVarInfo) =>
+              val stateVarType = stateVarInfo.stateVariableType
 
-            stateVarType match {
-              case StateVariableType.ValueState =>
-                SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition)
+              stateVarType match {
+                case StateVariableType.ValueState =>
+                  SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition)
 
-              case StateVariableType.ListState =>
-                val key = pair.key
-                val result = store.valuesIterator(key, stateVarName)
-                var unsafeRowArr: Seq[UnsafeRow] = Seq.empty
-                result.foreach { entry =>
-                  unsafeRowArr = unsafeRowArr :+ entry.copy()
-                }
-                // convert the list of values to array type
-                val arrData = new GenericArrayData(unsafeRowArr.toArray)
-                SchemaUtil.unifyStateRowPairWithMultipleValues((pair.key, arrData),
-                  partition.partition)
+                case StateVariableType.ListState =>
+                  val key = pair.key
+                  val result = store.valuesIterator(key, stateVarName)
+                  var unsafeRowArr: Seq[UnsafeRow] = Seq.empty
+                  result.foreach { entry =>
+                    unsafeRowArr = unsafeRowArr :+ entry.copy()
+                  }
+                  // convert the list of values to array type
+                  val arrData = new GenericArrayData(unsafeRowArr.toArray)
+                  SchemaUtil.unifyStateRowPairWithMultipleValues((pair.key, arrData),
+                    partition.partition)
 
-              case _ =>
-                throw new IllegalStateException(
-                  s"Unsupported state variable type: $stateVarType")
-            }
+                case _ =>
+                  throw new IllegalStateException(
+                    s"Unsupported state variable type: $stateVarType")
+              }
 
-          case None =>
-            SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition)
-        }
-      }
+            case None =>
+              SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition)
+          }
+        }
+    }
   }
 
   override def close(): Unit = {
SchemaUtil.scala:

@@ -16,14 +16,18 @@
  */
 package org.apache.spark.sql.execution.datasources.v2.state.utils
 
+import scala.collection.mutable
+import scala.util.control.NonFatal
+
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
-import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
 import org.apache.spark.sql.execution.datasources.v2.state.{StateDataSourceErrors, StateSourceOptions}
-import org.apache.spark.sql.execution.streaming.{StateVariableType, TransformWithStateVariableInfo}
-import org.apache.spark.sql.execution.streaming.state.StateStoreColFamilySchema
-import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, LongType, StringType, StructType}
+import org.apache.spark.sql.execution.streaming.StateVariableType._
+import org.apache.spark.sql.execution.streaming.TransformWithStateVariableInfo
+import org.apache.spark.sql.execution.streaming.state.{StateStoreColFamilySchema, UnsafeRowPair}
+import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, LongType, MapType, StringType, StructType}
 import org.apache.spark.util.ArrayImplicits._
 
 object SchemaUtil {
@@ -81,6 +85,112 @@ object SchemaUtil {
     row
   }
 
+  /**
+   * For map state variables, state rows are stored with a composite key.
+   * To return grouping key -> Map(user key -> value) as one state reader row
+   * to the users, we need to group the state rows by their grouping key
+   * and construct a map for each grouping key.
+   *
+   * We traverse the iterator returned from the state store,
+   * and return a row from `next()` only if the grouping key in the next row
+   * from the state store is different (or there are no more rows).
+   *
+   * Note that all state rows with the same grouping key are co-located, so they
+   * will appear consecutively during the iterator traversal.
+   */
+  def unifyMapStateRowPair(
+      stateRows: Iterator[UnsafeRowPair],
+      compositeKeySchema: StructType,
+      partitionId: Int): Iterator[InternalRow] = {
+    val groupingKeySchema = SchemaUtil.getSchemaAsDataType(
+      compositeKeySchema, "key"
+    ).asInstanceOf[StructType]
+    val userKeySchema = SchemaUtil.getSchemaAsDataType(
+      compositeKeySchema, "userKey"
+    ).asInstanceOf[StructType]
+
+    def appendKVPairToMap(
+        curMap: mutable.Map[Any, Any],
+        stateRowPair: UnsafeRowPair): Unit = {
+      curMap += (
+        stateRowPair.key.get(1, userKeySchema)
+          .asInstanceOf[UnsafeRow].copy() ->
+          stateRowPair.value.copy()
+      )
+    }
+
+    def createDataRow(
+        groupingKey: Any,
+        curMap: mutable.Map[Any, Any]): GenericInternalRow = {
+      val row = new GenericInternalRow(3)
+      val mapData = ArrayBasedMapData(curMap)
+      row.update(0, groupingKey)
+      row.update(1, mapData)
+      row.update(2, partitionId)
+      row
+    }
+
+    // All of the rows with the same grouping key are co-located and
+    // grouped together consecutively.
+    new Iterator[InternalRow] {
+      var curGroupingKey: UnsafeRow = _
+      var curStateRowPair: UnsafeRowPair = _
+      val curMap = mutable.Map.empty[Any, Any]
+
+      override def hasNext: Boolean =
+        stateRows.hasNext || !curMap.isEmpty
+
+      override def next(): InternalRow = {
+        var foundNewGroupingKey = false
+        while (stateRows.hasNext && !foundNewGroupingKey) {
+          curStateRowPair = stateRows.next()
+          if (curGroupingKey == null) {
+            // First time in the iterator.
+            // Need to make a copy because we need to keep the
+            // value across function calls.
+            curGroupingKey = curStateRowPair.key
+              .get(0, groupingKeySchema).asInstanceOf[UnsafeRow].copy()
+            appendKVPairToMap(curMap, curStateRowPair)
+          } else {
+            val curPairGroupingKey =
+              curStateRowPair.key.get(0, groupingKeySchema)
+            if (curPairGroupingKey == curGroupingKey) {
+              appendKVPairToMap(curMap, curStateRowPair)
+            } else {
+              // found a different grouping key; exit the loop and return a row
+              foundNewGroupingKey = true
+            }
+          }
+        }
+        if (foundNewGroupingKey) {
+          // found a different grouping key
+          val row = createDataRow(curGroupingKey, curMap)
+          // update vars
+          curGroupingKey =
+            curStateRowPair.key.get(0, groupingKeySchema)
+              .asInstanceOf[UnsafeRow].copy()
+          // empty the map, then append the current row
+          curMap.clear()
+          appendKVPairToMap(curMap, curStateRowPair)
+          // return the map value of the previous grouping key
+          row
+        } else {
+          if (curMap.isEmpty) {
+            throw new NoSuchElementException("Please check if the iterator hasNext(); likely " +
+              "the user is trying to get an element from an exhausted iterator.")
+          }
+          else {
+            // reached the end of the state rows
+            val row = createDataRow(curGroupingKey, curMap)
+            // clear the map to end the iterator
+            curMap.clear()
+            row
+          }
+        }
+      }
+    }
+  }
+
   def isValidSchema(
       sourceOptions: StateSourceOptions,
       schema: StructType,
@@ -92,6 +202,7 @@ object SchemaUtil {
       "value" -> classOf[StructType],
       "single_value" -> classOf[StructType],
       "list_value" -> classOf[ArrayType],
+      "map_value" -> classOf[MapType],
       "partition_id" -> classOf[IntegerType])
 
     val expectedFieldNames = if (sourceOptions.readChangeFeed) {
@@ -101,12 +212,15 @@
       val stateVarType = stateVarInfo.stateVariableType
 
       stateVarType match {
-        case StateVariableType.ValueState =>
+        case ValueState =>
           Seq("key", "single_value", "partition_id")
 
-        case StateVariableType.ListState =>
+        case ListState =>
           Seq("key", "list_value", "partition_id")
 
+        case MapState =>
+          Seq("key", "map_value", "partition_id")
+
         case _ =>
           throw StateDataSourceErrors
             .internalError(s"Unsupported state variable type $stateVarType")
@@ -131,20 +245,71 @@
     val stateVarType = stateVarInfo.stateVariableType
 
     stateVarType match {
-      case StateVariableType.ValueState =>
+      case ValueState =>
         new StructType()
           .add("key", stateStoreColFamilySchema.keySchema)
           .add("single_value", stateStoreColFamilySchema.valueSchema)
          .add("partition_id", IntegerType)
 
-      case StateVariableType.ListState =>
+      case ListState =>
         new StructType()
           .add("key", stateStoreColFamilySchema.keySchema)
           .add("list_value", ArrayType(stateStoreColFamilySchema.valueSchema))
           .add("partition_id", IntegerType)
 
+      case MapState =>
+        val groupingKeySchema = SchemaUtil.getSchemaAsDataType(
+          stateStoreColFamilySchema.keySchema, "key")
+        val userKeySchema = stateStoreColFamilySchema.userKeyEncoderSchema.get
+        val valueMapSchema = MapType.apply(
+          keyType = userKeySchema,
+          valueType = stateStoreColFamilySchema.valueSchema
+        )
+
+        new StructType()
+          .add("key", groupingKeySchema)
+          .add("map_value", valueMapSchema)
+          .add("partition_id", IntegerType)
+
       case _ =>
         throw StateDataSourceErrors.internalError(s"Unsupported state variable type $stateVarType")
     }
   }
+
+  /**
+   * Helper functions for the map state data source reader.
+   *
+   * Map state variables are stored in the RocksDB state store with the schema of
+   * `TransformWithStateKeyValueRowSchemaUtils.getCompositeKeySchema()`,
+   * but for the state store reader we need to return rows in the format of:
+   * "key": groupingKey, "map_value": Map(userKey -> value).
+   *
+   * The following functions help to translate between the two schemas.
+   */
+  def isMapStateVariable(
+      stateVariableInfoOpt: Option[TransformWithStateVariableInfo]): Boolean = {
+    stateVariableInfoOpt.isDefined &&
+      stateVariableInfoOpt.get.stateVariableType == MapState
+  }
+
+  /**
+   * Given the key-value schema generated from `generateSchemaForStateVar()`,
+   * returns the composite key schema under which the key is stored in the state store.
+   */
+  def getCompositeKeySchema(schema: StructType): StructType = {
+    val groupingKeySchema = SchemaUtil.getSchemaAsDataType(
+      schema, "key").asInstanceOf[StructType]
+    val userKeySchema = try {
+      Option(
+        SchemaUtil.getSchemaAsDataType(schema, "map_value").asInstanceOf[MapType]
+          .keyType.asInstanceOf[StructType])
+    } catch {
+      case NonFatal(e) =>
+        throw StateDataSourceErrors.internalError(s"No such field named 'map_value' " +
+          s"during state source reader schema initialization. Internal exception message: $e")
+    }
+    new StructType()
+      .add("key", groupingKeySchema)
+      .add("userKey", userKeySchema.get)
+  }
 }
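The consecutive-grouping logic in `unifyMapStateRowPair` is easier to see in isolation. Below is a minimal, self-contained sketch (not part of the patch) of the same pattern, using plain tuples in place of `UnsafeRowPair` and a buffered iterator in place of the hand-rolled bookkeeping:
```
import scala.collection.mutable

// Group consecutive (groupingKey, userKey, value) rows into one
// (groupingKey, Map(userKey -> value)) entry per grouping key. This relies on
// the same invariant as unifyMapStateRowPair: rows sharing a grouping key
// arrive consecutively from the state store iterator.
def groupConsecutive[K, UK, V](rows: Iterator[(K, UK, V)]): Iterator[(K, Map[UK, V])] =
  new Iterator[(K, Map[UK, V])] {
    private val buffered = rows.buffered
    override def hasNext: Boolean = buffered.hasNext
    override def next(): (K, Map[UK, V]) = {
      val groupingKey = buffered.head._1
      val entries = mutable.Map.empty[UK, V]
      // Consume every consecutive row that shares the current grouping key.
      while (buffered.hasNext && buffered.head._1 == groupingKey) {
        val (_, userKey, value) = buffered.next()
        entries += (userKey -> value)
      }
      (groupingKey, entries.toMap)
    }
  }

// Rows mirror the example output above: k1 has two map entries, k2 has one.
val rows = Iterator(("k1", "v1", 10), ("k1", "v2", 5), ("k2", "v2", 3))
groupConsecutive(rows).foreach(println)
// (k1,Map(v1 -> 10, v2 -> 5))
// (k2,Map(v2 -> 3))
```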
(The diff for the third changed file, the `StateDataSourceTransformWithStateSuite` tests, is not shown.)
