diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
index 1af2ec174c66d..24166a46bbd39 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
@@ -74,8 +74,11 @@ abstract class StatePartitionReaderBase(
   private val schemaForValueRow: StructType =
     StructType(Array(StructField("__dummy__", NullType)))
 
-  protected val keySchema = SchemaUtil.getSchemaAsDataType(
-    schema, "key").asInstanceOf[StructType]
+  protected val keySchema = {
+    if (!SchemaUtil.isMapStateVariable(stateVariableInfoOpt)) {
+      SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
+    } else SchemaUtil.getCompositeKeySchema(schema)
+  }
 
   protected val valueSchema = if (stateVariableInfoOpt.isDefined) {
     schemaForValueRow
@@ -178,38 +181,43 @@ class StatePartitionReader(
   override lazy val iter: Iterator[InternalRow] = {
     val stateVarName = stateVariableInfoOpt
       .map(_.stateName).getOrElse(StateStore.DEFAULT_COL_FAMILY_NAME)
-    store
-      .iterator(stateVarName)
-      .map { pair =>
-        stateVariableInfoOpt match {
-          case Some(stateVarInfo) =>
-            val stateVarType = stateVarInfo.stateVariableType
+    if (SchemaUtil.isMapStateVariable(stateVariableInfoOpt)) {
+      SchemaUtil.unifyMapStateRowPair(
+        store.iterator(stateVarName), keySchema, partition.partition)
+    } else {
+      store
+        .iterator(stateVarName)
+        .map { pair =>
+          stateVariableInfoOpt match {
+            case Some(stateVarInfo) =>
+              val stateVarType = stateVarInfo.stateVariableType
 
-            stateVarType match {
-              case StateVariableType.ValueState =>
-                SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition)
+              stateVarType match {
+                case StateVariableType.ValueState =>
+                  SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition)
 
-              case StateVariableType.ListState =>
-                val key = pair.key
-                val result = store.valuesIterator(key, stateVarName)
-                var unsafeRowArr: Seq[UnsafeRow] = Seq.empty
-                result.foreach { entry =>
-                  unsafeRowArr = unsafeRowArr :+ entry.copy()
-                }
-                // convert the list of values to array type
-                val arrData = new GenericArrayData(unsafeRowArr.toArray)
-                SchemaUtil.unifyStateRowPairWithMultipleValues((pair.key, arrData),
-                  partition.partition)
+                case StateVariableType.ListState =>
+                  val key = pair.key
+                  val result = store.valuesIterator(key, stateVarName)
+                  var unsafeRowArr: Seq[UnsafeRow] = Seq.empty
+                  result.foreach { entry =>
+                    unsafeRowArr = unsafeRowArr :+ entry.copy()
+                  }
+                  // convert the list of values to array type
+                  val arrData = new GenericArrayData(unsafeRowArr.toArray)
+                  SchemaUtil.unifyStateRowPairWithMultipleValues((pair.key, arrData),
+                    partition.partition)
 
-              case _ =>
-                throw new IllegalStateException(
-                  s"Unsupported state variable type: $stateVarType")
-            }
+                case _ =>
+                  throw new IllegalStateException(
+                    s"Unsupported state variable type: $stateVarType")
+              }
 
-          case None =>
-            SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition)
+            case None =>
+              SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition)
+          }
         }
-      }
+    }
   }
 
   override def close(): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
index 47bf9250000a4..88ea06d598e56 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
@@ -16,14 +16,18 @@
  */
 package org.apache.spark.sql.execution.datasources.v2.state.utils
 
+import scala.collection.mutable
+import scala.util.control.NonFatal
+
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
-import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
 import org.apache.spark.sql.execution.datasources.v2.state.{StateDataSourceErrors, StateSourceOptions}
-import org.apache.spark.sql.execution.streaming.{StateVariableType, TransformWithStateVariableInfo}
-import org.apache.spark.sql.execution.streaming.state.StateStoreColFamilySchema
-import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, LongType, StringType, StructType}
+import org.apache.spark.sql.execution.streaming.StateVariableType._
+import org.apache.spark.sql.execution.streaming.TransformWithStateVariableInfo
+import org.apache.spark.sql.execution.streaming.state.{StateStoreColFamilySchema, UnsafeRowPair}
+import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, LongType, MapType, StringType, StructType}
 import org.apache.spark.util.ArrayImplicits._
 
 object SchemaUtil {
@@ -81,6 +85,112 @@ object SchemaUtil {
     row
   }
 
+  /**
+   * For map state variables, state rows are stored with a composite key.
+   * To return grouping key -> Map{user key -> value} as one state reader row to
+   * the users, we need to group state rows by their grouping key
+   * and construct a map for that grouping key.
+   *
+   * We traverse the iterator returned from the state store, and `next()`
+   * returns a row only when the grouping key of the next state store row
+   * differs from the current one (or there are no more rows).
+   *
+   * Note that all state rows with the same grouping key are co-located, so they
+   * appear consecutively during the iterator traversal.
+   */
+  def unifyMapStateRowPair(
+      stateRows: Iterator[UnsafeRowPair],
+      compositeKeySchema: StructType,
+      partitionId: Int): Iterator[InternalRow] = {
+    val groupingKeySchema = SchemaUtil.getSchemaAsDataType(
+      compositeKeySchema, "key"
+    ).asInstanceOf[StructType]
+    val userKeySchema = SchemaUtil.getSchemaAsDataType(
+      compositeKeySchema, "userKey"
+    ).asInstanceOf[StructType]
+
+    def appendKVPairToMap(
+        curMap: mutable.Map[Any, Any],
+        stateRowPair: UnsafeRowPair): Unit = {
+      curMap += (
+        stateRowPair.key.get(1, userKeySchema)
+          .asInstanceOf[UnsafeRow].copy() ->
+          stateRowPair.value.copy()
+      )
+    }
+
+    def createDataRow(
+        groupingKey: Any,
+        curMap: mutable.Map[Any, Any]): GenericInternalRow = {
+      val row = new GenericInternalRow(3)
+      val mapData = ArrayBasedMapData(curMap)
+      row.update(0, groupingKey)
+      row.update(1, mapData)
+      row.update(2, partitionId)
+      row
+    }
+
+    // All rows with the same grouping key are co-located, so they are
+    // grouped together consecutively.
+    new Iterator[InternalRow] {
+      var curGroupingKey: UnsafeRow = _
+      var curStateRowPair: UnsafeRowPair = _
+      val curMap = mutable.Map.empty[Any, Any]
+
+      override def hasNext: Boolean =
+        stateRows.hasNext || curMap.nonEmpty
+
+      override def next(): InternalRow = {
+        var foundNewGroupingKey = false
+        while (stateRows.hasNext && !foundNewGroupingKey) {
+          curStateRowPair = stateRows.next()
+          if (curGroupingKey == null) {
+            // First time in the iterator.
+            // Need to make a copy because we need to keep the
+            // value across function calls.
+            curGroupingKey = curStateRowPair.key
+              .get(0, groupingKeySchema).asInstanceOf[UnsafeRow].copy()
+            appendKVPairToMap(curMap, curStateRowPair)
+          } else {
+            val curPairGroupingKey =
+              curStateRowPair.key.get(0, groupingKeySchema)
+            if (curPairGroupingKey == curGroupingKey) {
+              appendKVPairToMap(curMap, curStateRowPair)
+            } else {
+              // found a different grouping key; exit the loop and return a row
+              foundNewGroupingKey = true
+            }
+          }
+        }
+        if (foundNewGroupingKey) {
+          // found a different grouping key
+          val row = createDataRow(curGroupingKey, curMap)
+          // update vars
+          curGroupingKey =
+            curStateRowPair.key.get(0, groupingKeySchema)
+              .asInstanceOf[UnsafeRow].copy()
+          // empty the map, append current row
+          curMap.clear()
+          appendKVPairToMap(curMap, curStateRowPair)
+          // return the map value of the previous grouping key
+          row
+        } else {
+          if (curMap.isEmpty) {
+            throw new NoSuchElementException("Please check if the iterator hasNext(); likely " +
+              "the caller is trying to get an element from an exhausted iterator.")
+          } else {
+            // reached the end of the state rows
+            val row = createDataRow(curGroupingKey, curMap)
+            // clear the map to end the iterator
+            curMap.clear()
+            row
+          }
+        }
+      }
+    }
+  }
+
   def isValidSchema(
       sourceOptions: StateSourceOptions,
       schema: StructType,
@@ -92,6 +202,7 @@ object SchemaUtil {
       "value" -> classOf[StructType],
       "single_value" -> classOf[StructType],
       "list_value" -> classOf[ArrayType],
+      "map_value" -> classOf[MapType],
       "partition_id" -> classOf[IntegerType])
 
     val expectedFieldNames = if (sourceOptions.readChangeFeed) {
@@ -101,12 +212,15 @@ object SchemaUtil {
         val stateVarType = stateVarInfo.stateVariableType
 
         stateVarType match {
-          case StateVariableType.ValueState =>
+          case ValueState =>
            Seq("key", "single_value", "partition_id")
 
-          case StateVariableType.ListState =>
+          case ListState =>
            Seq("key", "list_value", "partition_id")
 
+          case MapState =>
+            Seq("key", "map_value", "partition_id")
+
           case _ =>
             throw StateDataSourceErrors
               .internalError(s"Unsupported state variable type $stateVarType")
@@ -131,20 +245,71 @@ object SchemaUtil {
     val stateVarType = stateVarInfo.stateVariableType
 
     stateVarType match {
-      case StateVariableType.ValueState =>
+      case ValueState =>
         new StructType()
           .add("key", stateStoreColFamilySchema.keySchema)
           .add("single_value", stateStoreColFamilySchema.valueSchema)
           .add("partition_id", IntegerType)
 
-      case StateVariableType.ListState =>
+      case ListState =>
         new StructType()
           .add("key", stateStoreColFamilySchema.keySchema)
           .add("list_value", ArrayType(stateStoreColFamilySchema.valueSchema))
           .add("partition_id", IntegerType)
 
+      case MapState =>
+        val groupingKeySchema = SchemaUtil.getSchemaAsDataType(
+          stateStoreColFamilySchema.keySchema, "key")
+        val userKeySchema = stateStoreColFamilySchema.userKeyEncoderSchema.get
+        val valueMapSchema = MapType.apply(
+          keyType = userKeySchema,
+          valueType = stateStoreColFamilySchema.valueSchema
+        )
+
+        new StructType()
+          .add("key", groupingKeySchema)
+          .add("map_value", valueMapSchema)
+          .add("partition_id", IntegerType)
+
       case _ =>
         throw StateDataSourceErrors.internalError(s"Unsupported state variable type $stateVarType")
     }
   }
+
+  /**
+   * Helper functions for the map state data source reader.
+   *
+   * Map state variables are stored in the RocksDB state store with the schema of
+   * `TransformWithStateKeyValueRowSchemaUtils.getCompositeKeySchema()`;
+   * but the state store reader needs to return them in the format of:
+   * "key": groupingKey, "map_value": Map(userKey -> value).
+   *
+   * The following functions help translate between the two schemas.
+   */
+  def isMapStateVariable(
+      stateVariableInfoOpt: Option[TransformWithStateVariableInfo]): Boolean = {
+    stateVariableInfoOpt.isDefined &&
+      stateVariableInfoOpt.get.stateVariableType == MapState
+  }
+
+  /**
+   * Given the key-value schema generated by `generateSchemaForStateVar()`,
+   * returns the composite key schema under which keys are stored in the state store.
+   */
+  def getCompositeKeySchema(schema: StructType): StructType = {
+    val groupingKeySchema = SchemaUtil.getSchemaAsDataType(
+      schema, "key").asInstanceOf[StructType]
+    val userKeySchema = try {
+      Option(
+        SchemaUtil.getSchemaAsDataType(schema, "map_value").asInstanceOf[MapType]
+          .keyType.asInstanceOf[StructType])
+    } catch {
+      case NonFatal(e) =>
+        throw StateDataSourceErrors.internalError(s"No such field named 'map_value' " +
+          s"during state source reader schema initialization. Internal exception message: $e")
+    }
+    new StructType()
+      .add("key", groupingKeySchema)
+      .add("userKey", userKeySchema.get)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
index 1c06e4f97f2b7..61091fde35e79 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
@@ -23,7 +23,8 @@ import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider, TestClass}
 import org.apache.spark.sql.functions.explode
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.streaming.{ExpiredTimerInfo, ListState, OutputMode, RunningCountStatefulProcessor, StatefulProcessor, StateStoreMetricsTest, TimeMode, TimerValues, TransformWithStateSuiteUtils, TTLConfig, ValueState}
+import org.apache.spark.sql.streaming.{ExpiredTimerInfo, InputMapRow, ListState, MapInputEvent, MapOutputEvent, MapStateTTLProcessor, OutputMode, RunningCountStatefulProcessor, StatefulProcessor, StateStoreMetricsTest, TestMapStateProcessor, TimeMode, TimerValues, TransformWithStateSuiteUtils, Trigger, TTLConfig, ValueState}
+import org.apache.spark.sql.streaming.util.StreamManualClock
 
 /** Stateful processor of single value state var with non-primitive type */
 class StatefulProcessorWithSingleValueVar extends RunningCountStatefulProcessor {
@@ -370,4 +371,114 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest
       }
     }
   }
+
+  test("state data source integration - map state with single variable") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.SHUFFLE_PARTITIONS.key ->
+        TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
+      withTempDir { tempDir =>
+        val inputData = MemoryStream[InputMapRow]
+        val result = inputData.toDS()
+          .groupByKey(x => x.key)
+          .transformWithState(new TestMapStateProcessor(),
+            TimeMode.None(),
+            OutputMode.Append())
+        testStream(result, OutputMode.Append())(
+          StartStream(checkpointLocation = tempDir.getCanonicalPath),
+          AddData(inputData, InputMapRow("k1", "updateValue", ("v1", "10"))),
+          AddData(inputData, InputMapRow("k1", "exists", ("", ""))),
+          AddData(inputData, InputMapRow("k2", "exists", ("", ""))),
+          CheckNewAnswer(("k1", "exists", "true"), ("k2", "exists", "false")),
+
+          AddData(inputData, InputMapRow("k1", "updateValue", ("v2", "5"))),
+          AddData(inputData, InputMapRow("k2", "updateValue", ("v2", "3"))),
+          ProcessAllAvailable(),
+          StopStream
+        )
+
+        val stateReaderDf = spark.read
+          .format("statestore")
+          .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+          .option(StateSourceOptions.STATE_VAR_NAME, "sessionState")
+          .load()
+
+        val resultDf = stateReaderDf.selectExpr(
+          "key.value AS groupingKey", "map_value AS mapValue")
+
+        checkAnswer(resultDf,
+          Seq(
+            Row("k1",
+              Map(Row("v1") -> Row("10"), Row("v2") -> Row("5"))),
+            Row("k2",
+              Map(Row("v2") -> Row("3"))))
+        )
+      }
+    }
+  }
+
+  test("state data source integration - map state TTL with single variable") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.SHUFFLE_PARTITIONS.key ->
+        TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
+      withTempDir { tempDir =>
+        val inputStream = MemoryStream[MapInputEvent]
+        val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
+        val result = inputStream.toDS()
+          .groupByKey(x => x.key)
+          .transformWithState(
+            new MapStateTTLProcessor(ttlConfig),
+            TimeMode.ProcessingTime(),
+            OutputMode.Append())
+
+        val clock = new StreamManualClock
+        testStream(result)(
+          StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock,
+            checkpointLocation = tempDir.getCanonicalPath),
+          AddData(inputStream,
+            MapInputEvent("k1", "key1", "put", 1),
+            MapInputEvent("k1", "key2", "put", 2)
+          ),
+          AdvanceManualClock(1 * 1000), // batch timestamp: 1000
+          CheckNewAnswer(),
+          AddData(inputStream,
+            MapInputEvent("k1", "key1", "get", -1),
+            MapInputEvent("k1", "key2", "get", -1)
+          ),
+          AdvanceManualClock(30 * 1000), // batch timestamp: 31000
+          CheckNewAnswer(
+            MapOutputEvent("k1", "key1", 1, isTTLValue = false, -1),
+            MapOutputEvent("k1", "key2", 2, isTTLValue = false, -1)
+          ),
+          // get values from ttl state
+          AddData(inputStream,
+            MapInputEvent("k1", "", "get_values_in_ttl_state", -1)
+          ),
+          AdvanceManualClock(1 * 1000), // batch timestamp: 32000
+          CheckNewAnswer(
+            MapOutputEvent("k1", "key1", -1, isTTLValue = true, 61000),
+            MapOutputEvent("k1", "key2", -1, isTTLValue = true, 61000)
+          ),
+          StopStream
+        )
+
+        val stateReaderDf = spark.read
+          .format("statestore")
+          .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+          .option(StateSourceOptions.STATE_VAR_NAME, "mapState")
+          .load()
+
+        val resultDf = stateReaderDf.selectExpr(
+          "key.value AS groupingKey", "map_value AS mapValue")
+
+        checkAnswer(resultDf,
+          Seq(
+            Row("k1",
+              Map(Row("key2") -> Row(Row(2), 61000L),
+                Row("key1") -> Row(Row(1), 61000L))))
+        )
+      }
+    }
+  }
 }
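
For reference, below is a minimal usage sketch (not part of the patch above) of how the new map_value column can be read back through the state data source. It mirrors the option names used in the tests in this diff; the checkpoint path and the state variable name "mapState" are placeholders for whatever the transformWithState query actually used.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions

// Assumes an active SparkSession and an existing streaming checkpoint at
// `checkpointDir`, written by a transformWithState query that registered a
// map state variable named "mapState" (both names are placeholders).
val spark = SparkSession.builder().getOrCreate()
val checkpointDir = "/path/to/checkpoint"

val mapStateDf = spark.read
  .format("statestore")
  .option(StateSourceOptions.PATH, checkpointDir)
  .option(StateSourceOptions.STATE_VAR_NAME, "mapState")
  .load()

// Expected shape (approximately): one row per grouping key, carrying the full
// userKey -> value map for that grouping key plus the state store partition id:
//   key: struct, map_value: map<struct, struct>, partition_id: int
mapStateDf.printSchema()

mapStateDf
  .selectExpr("key", "map_value", "partition_id")
  .show(truncate = false)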