diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
index 3169e75031fca..b4c18c38f04aa 100644
--- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
+++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -27,4 +27,5 @@ org.apache.spark.sql.execution.streaming.ConsoleSinkProvider
 org.apache.spark.sql.execution.streaming.sources.RateStreamProvider
 org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider
 org.apache.spark.sql.execution.datasources.binaryfile.BinaryFileFormat
-org.apache.spark.sql.execution.streaming.sources.RatePerMicroBatchProvider
\ No newline at end of file
+org.apache.spark.sql.execution.streaming.sources.RatePerMicroBatchProvider
+org.apache.spark.sql.execution.datasources.v2.state.StateMetadataSource
\ No newline at end of file
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala
new file mode 100644
index 0000000000000..8a74db8d19639
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state
+
+import java.util
+
+import scala.jdk.CollectionConverters._
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{Path, PathFilter}
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.catalog.{MetadataColumn, SupportsMetadataColumns, SupportsRead, Table, TableCapability, TableProvider}
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder}
+import org.apache.spark.sql.execution.streaming.CheckpointFileManager
+import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, OperatorStateMetadataReader, OperatorStateMetadataV1}
+import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructType}
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.SerializableConfiguration
+
+case class StateMetadataTableEntry(
+    operatorId: Long,
+    operatorName: String,
+    stateStoreName: String,
+    numPartitions: Int,
+    minBatchId: Long,
+    maxBatchId: Long,
+    numColsPrefixKey: Int) {
+  def toRow(): InternalRow = {
+    InternalRow.fromSeq(
+      Seq(operatorId,
+        UTF8String.fromString(operatorName),
+        UTF8String.fromString(stateStoreName),
+        numPartitions,
+        minBatchId,
+        maxBatchId,
+        numColsPrefixKey))
+  }
+}
+
+object StateMetadataTableEntry {
+  private[sql] val schema = {
+    new StructType()
+      .add("operatorId", LongType)
+      .add("operatorName", StringType)
+      .add("stateStoreName", StringType)
+      .add("numPartitions", IntegerType)
+      .add("minBatchId", LongType)
+      .add("maxBatchId", LongType)
+  }
+}
+
+class StateMetadataSource extends TableProvider with DataSourceRegister {
+  override def shortName(): String = "state-metadata"
+
+  override def getTable(
+      schema: StructType,
+      partitioning: Array[Transform],
+      properties: util.Map[String, String]): Table = {
+    new StateMetadataTable
+  }
+
+  override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
+    // The schema of the state metadata table is static.
+    StateMetadataTableEntry.schema
+  }
+}
+
+
+class StateMetadataTable extends Table with SupportsRead with SupportsMetadataColumns {
+  override def name(): String = "state-metadata-table"
+
+  override def schema(): StructType = StateMetadataTableEntry.schema
+
+  override def capabilities(): util.Set[TableCapability] = Set(TableCapability.BATCH_READ).asJava
+
+  override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
+    () => {
+      if (!options.containsKey("path")) {
+        throw new IllegalArgumentException("Checkpoint path is not specified for" +
+          " state metadata data source.")
+      }
+      new StateMetadataScan(options.get("path"))
+    }
+  }
+
+  private object NumColsPrefixKeyColumn extends MetadataColumn {
+    override def name: String = "_numColsPrefixKey"
+    override def dataType: DataType = IntegerType
+    override def comment: String = "Number of columns in prefix key of the state store instance"
+  }
+
+  override val metadataColumns: Array[MetadataColumn] = Array(NumColsPrefixKeyColumn)
+}
+
+case class StateMetadataInputPartition(checkpointLocation: String) extends InputPartition
+
+class StateMetadataScan(checkpointLocation: String) extends Scan {
+  override def readSchema: StructType = StateMetadataTableEntry.schema
+
+  override def toBatch: Batch = {
+    new Batch {
+      override def planInputPartitions(): Array[InputPartition] = {
+        Array(StateMetadataInputPartition(checkpointLocation))
+      }
+
+      override def createReaderFactory(): PartitionReaderFactory = {
+        // No need to broadcast the Hadoop conf because this source only has one partition.
+        val conf = new SerializableConfiguration(SparkSession.active.sessionState.newHadoopConf())
+        StateMetadataPartitionReaderFactory(conf)
+      }
+    }
+  }
+}
+
+case class StateMetadataPartitionReaderFactory(hadoopConf: SerializableConfiguration)
+  extends PartitionReaderFactory {
+
+  override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
+    new StateMetadataPartitionReader(
+      partition.asInstanceOf[StateMetadataInputPartition].checkpointLocation, hadoopConf)
+  }
+}
+
+class StateMetadataPartitionReader(
+    checkpointLocation: String,
+    serializedHadoopConf: SerializableConfiguration) extends PartitionReader[InternalRow] {
+
+  override def next(): Boolean = {
+    stateMetadata.hasNext
+  }
+
+  override def get(): InternalRow = {
+    stateMetadata.next().toRow()
+  }
+
+  override def close(): Unit = {}
+
+  private def pathToLong(path: Path) = {
+    path.getName.toLong
+  }
+
+  private def pathNameCanBeParsedAsLong(path: Path) = {
+    try {
+      pathToLong(path)
+      true
+    } catch {
+      case _: NumberFormatException => false
+    }
+  }
+
+  // Only accept paths whose file name can be parsed as a long integer.
+  private val pathNameCanBeParsedAsLongFilter = new PathFilter {
+    override def accept(path: Path): Boolean = pathNameCanBeParsedAsLong(path)
+  }
+
+  private lazy val hadoopConf: Configuration = serializedHadoopConf.value
+
+  private lazy val fileManager =
+    CheckpointFileManager.create(new Path(checkpointLocation), hadoopConf)
+
+  // List the commit log entries to find all the available batch ids.
+  private def batchIds: Array[Long] = {
+    val commitLog = new Path(checkpointLocation, "commits")
+    if (fileManager.exists(commitLog)) {
+      fileManager
+        .list(commitLog, pathNameCanBeParsedAsLongFilter).map(f => pathToLong(f.getPath)).sorted
+    } else Array.empty
+  }
+
+  private def allOperatorStateMetadata: Array[OperatorStateMetadata] = {
+    val stateDir = new Path(checkpointLocation, "state")
+    val opIds = fileManager
+      .list(stateDir, pathNameCanBeParsedAsLongFilter).map(f => pathToLong(f.getPath)).sorted
+    opIds.map { opId =>
+      new OperatorStateMetadataReader(new Path(stateDir, opId.toString), hadoopConf).read()
+    }
+  }
+
+  private lazy val stateMetadata: Iterator[StateMetadataTableEntry] = {
+    allOperatorStateMetadata.flatMap { operatorStateMetadata =>
+      require(operatorStateMetadata.version == 1)
+      val operatorStateMetadataV1 = operatorStateMetadata.asInstanceOf[OperatorStateMetadataV1]
+      operatorStateMetadataV1.stateStoreInfo.map { stateStoreMetadata =>
+        StateMetadataTableEntry(operatorStateMetadataV1.operatorInfo.operatorId,
+          operatorStateMetadataV1.operatorInfo.operatorName,
+          stateStoreMetadata.storeName,
+          stateStoreMetadata.numPartitions,
+          if (batchIds.nonEmpty) batchIds.head else -1,
+          if (batchIds.nonEmpty) batchIds.last else -1,
+          stateStoreMetadata.numColsPrefixKey
+        )
+      }
+    }
+  }.iterator
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala
index 48cc17bbbabf2..340187fa49514 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.streaming.state
 
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.sql.Column
+import org.apache.spark.sql.{Column, Row}
 import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.streaming.{OutputMode, StreamTest}
@@ -53,6 +53,15 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession {
       val operatorMetadata = OperatorStateMetadataV1(operatorInfo, stateStoreInfo.toArray)
       new OperatorStateMetadataWriter(statePath, hadoopConf).write(operatorMetadata)
       checkOperatorStateMetadata(checkpointDir.toString, 0, operatorMetadata)
+      val df = spark.read.format("state-metadata").load(checkpointDir.toString)
+      // The commit log is empty, so there is no available batch id.
+      checkAnswer(df, Seq(Row(1, "Join", "store1", 200, -1L, -1L),
+        Row(1, "Join", "store2", 200, -1L, -1L),
+        Row(1, "Join", "store3", 200, -1L, -1L),
+        Row(1, "Join", "store4", 200, -1L, -1L)
+      ))
+      checkAnswer(df.select(df.metadataColumn("_numColsPrefixKey")),
+        Seq(Row(1), Row(1), Row(1), Row(1)))
     }
   }
 
@@ -105,6 +114,16 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession {
       val expectedMetadata = OperatorStateMetadataV1(
        OperatorInfoV1(0, "symmetricHashJoin"), expectedStateStoreInfo)
       checkOperatorStateMetadata(checkpointDir.toString, 0, expectedMetadata)
+
+      val df = spark.read.format("state-metadata")
+        .load(checkpointDir.toString)
+      checkAnswer(df, Seq(Row(0, "symmetricHashJoin", "left-keyToNumValues", 5, 0L, 1L),
+        Row(0, "symmetricHashJoin", "left-keyWithIndexToValue", 5, 0L, 1L),
+        Row(0, "symmetricHashJoin", "right-keyToNumValues", 5, 0L, 1L),
+        Row(0, "symmetricHashJoin", "right-keyWithIndexToValue", 5, 0L, 1L)
+      ))
+      checkAnswer(df.select(df.metadataColumn("_numColsPrefixKey")),
+        Seq(Row(0), Row(0), Row(0), Row(0)))
     }
   }
 
@@ -147,6 +166,10 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession {
         Array(StateStoreMetadataV1("default", 1, spark.sessionState.conf.numShufflePartitions))
       )
       checkOperatorStateMetadata(checkpointDir.toString, 0, expectedMetadata)
+
+      val df = spark.read.format("state-metadata").load(checkpointDir.toString)
+      checkAnswer(df, Seq(Row(0, "sessionWindowStateStoreSaveExec", "default", 5, 0L, 0L)))
+      checkAnswer(df.select(df.metadataColumn("_numColsPrefixKey")), Seq(Row(1)))
     }
   }
 
@@ -176,6 +199,18 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession {
         Array(StateStoreMetadataV1("default", 0, numShufflePartitions)))
       checkOperatorStateMetadata(checkpointDir.toString, 0, expectedMetadata0)
       checkOperatorStateMetadata(checkpointDir.toString, 1, expectedMetadata1)
+
+      val df = spark.read.format("state-metadata").load(checkpointDir.toString)
+      checkAnswer(df, Seq(Row(0, "stateStoreSave", "default", 5, 0L, 1L),
+        Row(1, "stateStoreSave", "default", 5, 0L, 1L)))
+      checkAnswer(df.select(df.metadataColumn("_numColsPrefixKey")), Seq(Row(0), Row(0)))
+    }
+  }
+
+  test("State metadata data source handles missing argument") {
+    val e = intercept[IllegalArgumentException] {
+      spark.read.format("state-metadata").load().collect()
     }
+    assert(e.getMessage == "Checkpoint path is not specified for state metadata data source.")
   }
 }
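
A minimal usage sketch of the new source, pieced together from the tests above. The format name ("state-metadata"), the required "path" option, the output columns, and the "_numColsPrefixKey" metadata column are taken from the patch; the SparkSession `spark` and the checkpoint path are placeholders.

// Read the operator/state-store metadata of an existing streaming checkpoint.
// Assumes an active SparkSession `spark`; the path below is a hypothetical example.
val checkpointDir = "/tmp/streaming-query/checkpoint"

val stateMetadataDf = spark.read
  .format("state-metadata")   // registered via DataSourceRegister in this patch
  .load(checkpointDir)        // load(path) supplies the required "path" option

// Columns: operatorId, operatorName, stateStoreName, numPartitions, minBatchId, maxBatchId.
stateMetadataDf.show(truncate = false)

// The prefix-key column count is exposed as a metadata column, so it has to be
// selected explicitly.
stateMetadataDf.select(stateMetadataDf.metadataColumn("_numColsPrefixKey")).show()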