From c4b7498b37a3608e96f5725d9458711ec6841631 Mon Sep 17 00:00:00 2001 From: Chaoqin Li Date: Sat, 4 Nov 2023 20:58:49 -0700 Subject: [PATCH 1/4] add implementation and tests --- ...pache.spark.sql.sources.DataSourceRegister | 3 +- .../state/metadata/StateMetadataSource.scala | 206 ++++++++++++++++++ .../state/OperatorStateMetadataSuite.scala | 23 +- 3 files changed, 230 insertions(+), 2 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 3169e75031fca..b4c18c38f04aa 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -27,4 +27,5 @@ org.apache.spark.sql.execution.streaming.ConsoleSinkProvider org.apache.spark.sql.execution.streaming.sources.RateStreamProvider org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider org.apache.spark.sql.execution.datasources.binaryfile.BinaryFileFormat -org.apache.spark.sql.execution.streaming.sources.RatePerMicroBatchProvider \ No newline at end of file +org.apache.spark.sql.execution.streaming.sources.RatePerMicroBatchProvider +org.apache.spark.sql.execution.datasources.v2.state.StateMetadataSource \ No newline at end of file diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala new file mode 100644 index 0000000000000..08395eb369551 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.v2.state + +import java.util + +import scala.jdk.CollectionConverters._ + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{Path, PathFilter} + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder} +import org.apache.spark.sql.execution.streaming.CheckpointFileManager +import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, OperatorStateMetadataReader, OperatorStateMetadataV1} +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.SerializableConfiguration + +case class StateMetadataTableEntry( + operatorId: Long, + operatorName: String, + stateStoreName: String, + numPartitions: Int, + numColsPrefixKey: Int, + minBatchId: Long, + maxBatchId: Long) { + def toRow(): InternalRow = { + InternalRow.fromSeq( + Seq(operatorId, + UTF8String.fromString(operatorName), + UTF8String.fromString(stateStoreName), + numPartitions, + numColsPrefixKey, + minBatchId, + maxBatchId)) + } +} + +object StateMetadataTableEntry { + private[sql] val schema = { + new StructType() + .add("operatorId", LongType) + .add("operatorName", StringType) + .add("stateStoreName", StringType) + .add("numPartitions", IntegerType) + .add("numColsPrefixKey", IntegerType) + .add("minBatchId", LongType) + .add("maxBatchId", LongType) + } +} + +class StateMetadataSource extends TableProvider with DataSourceRegister { + override def shortName(): String = "state-metadata" + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + new StateMetadataTable + } + + override def inferSchema(options: CaseInsensitiveStringMap): StructType = { + // The schema of state metadata table is static. + StateMetadataTableEntry.schema + } +} + + +class StateMetadataTable extends Table with SupportsRead { + override def name(): String = "state-metadata-table" + + override def schema(): StructType = StateMetadataTableEntry.schema + + override def capabilities(): util.Set[TableCapability] = Set(TableCapability.BATCH_READ).asJava + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + () => { + assert(options.containsKey("path"), "Must specify checkpoint path to read state metadata") + new StateMetadataScan(options.get("path")) + } + } +} + +case class StateMetadataInputPartition(checkpointLocation: String) extends InputPartition + +class StateMetadataScan(checkpointLocation: String) extends Scan { + override def readSchema: StructType = StateMetadataTableEntry.schema + + override def toBatch: Batch = { + new Batch { + override def planInputPartitions(): Array[InputPartition] = { + Array(StateMetadataInputPartition(checkpointLocation)) + } + + override def createReaderFactory(): PartitionReaderFactory = { + // Don't need to broadcast the hadoop conf because this source only has one partition. 
+ val conf = new SerializableConfiguration(SparkSession.active.sessionState.newHadoopConf()) + StateMetadataPartitionReaderFactory(conf) + } + } + } +} + +case class StateMetadataPartitionReaderFactory(hadoopConf: SerializableConfiguration) + extends PartitionReaderFactory { + + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + new StateMetadataPartitionReader( + partition.asInstanceOf[StateMetadataInputPartition].checkpointLocation, hadoopConf) + } +} + +class StateMetadataPartitionReader( + checkpointLocation: String, + serializedHadoopConf: SerializableConfiguration) extends PartitionReader[InternalRow] { + private lazy val hadoopConf: Configuration = serializedHadoopConf.value + + private lazy val fileManager = + CheckpointFileManager.create(new Path(checkpointLocation), hadoopConf) + + override def next(): Boolean = { + stateMetadata.hasNext + } + + override def get(): InternalRow = { + stateMetadata.next().toRow() + } + + private def stateDir = new Path(checkpointLocation, "state") + + private def pathNameCanBeParsedAsLong(path: Path) = { + try { + path.getName.toLong + true + } catch { + case _: NumberFormatException => false + } + } + + private def pathToLong(path: Path) = { + path.getName.toLong + } + + // Return true when the filename can be parsed as long integer. + private val longFileNameFilter = new PathFilter { + override def accept(path: Path): Boolean = pathNameCanBeParsedAsLong(path) + } + + override def close(): Unit = {} + + // List the state directory to find all the operator id. + private def opIds: Array[Long] = { + fileManager.list(stateDir, longFileNameFilter).map(f => pathToLong(f.getPath)).sorted + } + + // List the commit log entries to find all the available batch ids. + private def batchIds: Array[Long] = { + val commitLog = new Path(checkpointLocation, "commits") + if (fileManager.exists(commitLog)) { + fileManager.list(commitLog, longFileNameFilter).map(f => pathToLong(f.getPath)).sorted + } else Array.empty + } + + private def allOperatorStateMetadata: Array[OperatorStateMetadata] = { + opIds.map { opId => + new OperatorStateMetadataReader(new Path(stateDir, opId.toString), hadoopConf).read() + } + } + + private lazy val stateMetadata: Iterator[StateMetadataTableEntry] = { + allOperatorStateMetadata.flatMap { operatorStateMetadata => + require(operatorStateMetadata.version == 1) + val operatorStateMetadataV1 = operatorStateMetadata.asInstanceOf[OperatorStateMetadataV1] + operatorStateMetadataV1.stateStoreInfo.map { stateStoreMetadata => + StateMetadataTableEntry(operatorStateMetadataV1.operatorInfo.operatorId, + operatorStateMetadataV1.operatorInfo.operatorName, + stateStoreMetadata.storeName, + stateStoreMetadata.numPartitions, + stateStoreMetadata.numColsPrefixKey, + if (batchIds.nonEmpty) batchIds.head else -1, + if (batchIds.nonEmpty) batchIds.last else -1 + ) + } + } + }.iterator +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala index 48cc17bbbabf2..ef0770ca11453 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.streaming.state import org.apache.hadoop.fs.Path -import org.apache.spark.sql.Column +import 
org.apache.spark.sql.{Column, Row} import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.{OutputMode, StreamTest} @@ -53,6 +53,13 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession { val operatorMetadata = OperatorStateMetadataV1(operatorInfo, stateStoreInfo.toArray) new OperatorStateMetadataWriter(statePath, hadoopConf).write(operatorMetadata) checkOperatorStateMetadata(checkpointDir.toString, 0, operatorMetadata) + // Commit log is empty, there is no available batch id. + checkAnswer(spark.read.format("state-metadata").load(checkpointDir.toString), + Seq(Row(1, "Join", "store1", 200, 1, -1L, -1L), + Row(1, "Join", "store2", 200, 1, -1L, -1L), + Row(1, "Join", "store3", 200, 1, -1L, -1L), + Row(1, "Join", "store4", 200, 1, -1L, -1L) + )) } } @@ -105,6 +112,13 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession { val expectedMetadata = OperatorStateMetadataV1( OperatorInfoV1(0, "symmetricHashJoin"), expectedStateStoreInfo) checkOperatorStateMetadata(checkpointDir.toString, 0, expectedMetadata) + + checkAnswer(spark.read.format("state-metadata").load(checkpointDir.toString), + Seq(Row(0, "symmetricHashJoin", "left-keyToNumValues", 5, 0, 0L, 1L), + Row(0, "symmetricHashJoin", "left-keyWithIndexToValue", 5, 0, 0L, 1L), + Row(0, "symmetricHashJoin", "right-keyToNumValues", 5, 0, 0L, 1L), + Row(0, "symmetricHashJoin", "right-keyWithIndexToValue", 5, 0, 0L, 1L) + )) } } @@ -147,6 +161,9 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession { Array(StateStoreMetadataV1("default", 1, spark.sessionState.conf.numShufflePartitions)) ) checkOperatorStateMetadata(checkpointDir.toString, 0, expectedMetadata) + + checkAnswer(spark.read.format("state-metadata").load(checkpointDir.toString), + Seq(Row(0, "sessionWindowStateStoreSaveExec", "default", 5, 1, 0L, 0L))) } } @@ -176,6 +193,10 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession { Array(StateStoreMetadataV1("default", 0, numShufflePartitions))) checkOperatorStateMetadata(checkpointDir.toString, 0, expectedMetadata0) checkOperatorStateMetadata(checkpointDir.toString, 1, expectedMetadata1) + + checkAnswer(spark.read.format("state-metadata").load(checkpointDir.toString), + Seq(Row(0, "stateStoreSave", "default", 5, 0, 0L, 1L), + Row(1, "stateStoreSave", "default", 5, 0, 0L, 1L))) } } } From 76e0420b4e9bb114e3f758140310cd1c34a251ef Mon Sep 17 00:00:00 2001 From: Chaoqin Li Date: Sat, 4 Nov 2023 22:36:20 -0700 Subject: [PATCH 2/4] clean up --- .../state/metadata/StateMetadataSource.scala | 32 +++++++++---------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala index 08395eb369551..61d66d14c9c2c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala @@ -133,10 +133,6 @@ case class StateMetadataPartitionReaderFactory(hadoopConf: SerializableConfigura class StateMetadataPartitionReader( checkpointLocation: String, serializedHadoopConf: SerializableConfiguration) extends PartitionReader[InternalRow] { - private lazy val hadoopConf: Configuration = 
serializedHadoopConf.value - - private lazy val fileManager = - CheckpointFileManager.create(new Path(checkpointLocation), hadoopConf) override def next(): Boolean = { stateMetadata.hasNext @@ -146,42 +142,44 @@ class StateMetadataPartitionReader( stateMetadata.next().toRow() } - private def stateDir = new Path(checkpointLocation, "state") + override def close(): Unit = {} + + private def pathToLong(path: Path) = { + path.getName.toLong + } private def pathNameCanBeParsedAsLong(path: Path) = { try { - path.getName.toLong + pathToLong(path) true } catch { case _: NumberFormatException => false } } - private def pathToLong(path: Path) = { - path.getName.toLong - } - // Return true when the filename can be parsed as long integer. - private val longFileNameFilter = new PathFilter { + private val pathNameCanBeParsedAsLongFilter = new PathFilter { override def accept(path: Path): Boolean = pathNameCanBeParsedAsLong(path) } - override def close(): Unit = {} + private lazy val hadoopConf: Configuration = serializedHadoopConf.value - // List the state directory to find all the operator id. - private def opIds: Array[Long] = { - fileManager.list(stateDir, longFileNameFilter).map(f => pathToLong(f.getPath)).sorted - } + private lazy val fileManager = + CheckpointFileManager.create(new Path(checkpointLocation), hadoopConf) // List the commit log entries to find all the available batch ids. private def batchIds: Array[Long] = { val commitLog = new Path(checkpointLocation, "commits") if (fileManager.exists(commitLog)) { - fileManager.list(commitLog, longFileNameFilter).map(f => pathToLong(f.getPath)).sorted + fileManager + .list(commitLog, pathNameCanBeParsedAsLongFilter).map(f => pathToLong(f.getPath)).sorted } else Array.empty } private def allOperatorStateMetadata: Array[OperatorStateMetadata] = { + val stateDir = new Path(checkpointLocation, "state") + val opIds = fileManager + .list(stateDir, pathNameCanBeParsedAsLongFilter).map(f => pathToLong(f.getPath)).sorted opIds.map { opId => new OperatorStateMetadataReader(new Path(stateDir, opId.toString), hadoopConf).read() } From cb2baa8a7ff573d50678af1ae604581801686b84 Mon Sep 17 00:00:00 2001 From: Chaoqin Li Date: Tue, 7 Nov 2023 13:26:42 -0800 Subject: [PATCH 3/4] use metadata column and exception --- .../state/metadata/StateMetadataSource.scala | 32 ++++++++----- .../state/OperatorStateMetadataSuite.scala | 46 +++++++++++++------ 2 files changed, 52 insertions(+), 26 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala index 61d66d14c9c2c..8a74db8d19639 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala @@ -25,13 +25,13 @@ import org.apache.hadoop.fs.{Path, PathFilter} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.catalog.{MetadataColumn, SupportsMetadataColumns, SupportsRead, Table, TableCapability, TableProvider} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReader, PartitionReaderFactory, Scan, 
ScanBuilder} import org.apache.spark.sql.execution.streaming.CheckpointFileManager import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, OperatorStateMetadataReader, OperatorStateMetadataV1} import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} +import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.SerializableConfiguration @@ -41,18 +41,18 @@ case class StateMetadataTableEntry( operatorName: String, stateStoreName: String, numPartitions: Int, - numColsPrefixKey: Int, minBatchId: Long, - maxBatchId: Long) { + maxBatchId: Long, + numColsPrefixKey: Int) { def toRow(): InternalRow = { InternalRow.fromSeq( Seq(operatorId, UTF8String.fromString(operatorName), UTF8String.fromString(stateStoreName), numPartitions, - numColsPrefixKey, minBatchId, - maxBatchId)) + maxBatchId, + numColsPrefixKey)) } } @@ -63,7 +63,6 @@ object StateMetadataTableEntry { .add("operatorName", StringType) .add("stateStoreName", StringType) .add("numPartitions", IntegerType) - .add("numColsPrefixKey", IntegerType) .add("minBatchId", LongType) .add("maxBatchId", LongType) } @@ -86,7 +85,7 @@ class StateMetadataSource extends TableProvider with DataSourceRegister { } -class StateMetadataTable extends Table with SupportsRead { +class StateMetadataTable extends Table with SupportsRead with SupportsMetadataColumns { override def name(): String = "state-metadata-table" override def schema(): StructType = StateMetadataTableEntry.schema @@ -95,10 +94,21 @@ class StateMetadataTable extends Table with SupportsRead { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { () => { - assert(options.containsKey("path"), "Must specify checkpoint path to read state metadata") + if (!options.containsKey("path")) { + throw new IllegalArgumentException("Checkpoint path is not specified for" + + " state metadata data source.") + } new StateMetadataScan(options.get("path")) } } + + private object NumColsPrefixKeyColumn extends MetadataColumn { + override def name: String = "_numColsPrefixKey" + override def dataType: DataType = IntegerType + override def comment: String = "Number of columns in prefix key of the state store instance" + } + + override val metadataColumns: Array[MetadataColumn] = Array(NumColsPrefixKeyColumn) } case class StateMetadataInputPartition(checkpointLocation: String) extends InputPartition @@ -194,9 +204,9 @@ class StateMetadataPartitionReader( operatorStateMetadataV1.operatorInfo.operatorName, stateStoreMetadata.storeName, stateStoreMetadata.numPartitions, - stateStoreMetadata.numColsPrefixKey, if (batchIds.nonEmpty) batchIds.head else -1, - if (batchIds.nonEmpty) batchIds.last else -1 + if (batchIds.nonEmpty) batchIds.last else -1, + stateStoreMetadata.numColsPrefixKey ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala index ef0770ca11453..7d88b69534ceb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala @@ -53,13 +53,15 @@ class OperatorStateMetadataSuite extends 
StreamTest with SharedSparkSession { val operatorMetadata = OperatorStateMetadataV1(operatorInfo, stateStoreInfo.toArray) new OperatorStateMetadataWriter(statePath, hadoopConf).write(operatorMetadata) checkOperatorStateMetadata(checkpointDir.toString, 0, operatorMetadata) + val df = spark.read.format("state-metadata").load(checkpointDir.toString) // Commit log is empty, there is no available batch id. - checkAnswer(spark.read.format("state-metadata").load(checkpointDir.toString), - Seq(Row(1, "Join", "store1", 200, 1, -1L, -1L), - Row(1, "Join", "store2", 200, 1, -1L, -1L), - Row(1, "Join", "store3", 200, 1, -1L, -1L), - Row(1, "Join", "store4", 200, 1, -1L, -1L) + checkAnswer(df, Seq(Row(1, "Join", "store1", 200, -1L, -1L), + Row(1, "Join", "store2", 200, -1L, -1L), + Row(1, "Join", "store3", 200, -1L, -1L), + Row(1, "Join", "store4", 200, -1L, -1L) )) + checkAnswer(df.select(df.metadataColumn("_numColsPrefixKey")), + Seq(Row(1), Row(1), Row(1), Row(1))) } } @@ -113,12 +115,15 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession { OperatorInfoV1(0, "symmetricHashJoin"), expectedStateStoreInfo) checkOperatorStateMetadata(checkpointDir.toString, 0, expectedMetadata) - checkAnswer(spark.read.format("state-metadata").load(checkpointDir.toString), - Seq(Row(0, "symmetricHashJoin", "left-keyToNumValues", 5, 0, 0L, 1L), - Row(0, "symmetricHashJoin", "left-keyWithIndexToValue", 5, 0, 0L, 1L), - Row(0, "symmetricHashJoin", "right-keyToNumValues", 5, 0, 0L, 1L), - Row(0, "symmetricHashJoin", "right-keyWithIndexToValue", 5, 0, 0L, 1L) + val df = spark.read.format("state-metadata") + .load(checkpointDir.toString) + checkAnswer(df, Seq(Row(0, "symmetricHashJoin", "left-keyToNumValues", 5, 0L, 1L), + Row(0, "symmetricHashJoin", "left-keyWithIndexToValue", 5, 0L, 1L), + Row(0, "symmetricHashJoin", "right-keyToNumValues", 5, 0L, 1L), + Row(0, "symmetricHashJoin", "right-keyWithIndexToValue", 5, 0L, 1L) )) + checkAnswer(df.select(df.metadataColumn("_numColsPrefixKey")), + Seq(Row(0), Row(0), Row(0), Row(0))) } } @@ -162,8 +167,10 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession { ) checkOperatorStateMetadata(checkpointDir.toString, 0, expectedMetadata) - checkAnswer(spark.read.format("state-metadata").load(checkpointDir.toString), - Seq(Row(0, "sessionWindowStateStoreSaveExec", "default", 5, 1, 0L, 0L))) + val df = spark.read.format("state-metadata") + .load(checkpointDir.toString) + checkAnswer(df, Seq(Row(0, "sessionWindowStateStoreSaveExec", "default", 5, 0L, 0L))) + checkAnswer(df.select(df.metadataColumn("_numColsPrefixKey")), Seq(Row(1))) } } @@ -194,9 +201,18 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession { checkOperatorStateMetadata(checkpointDir.toString, 0, expectedMetadata0) checkOperatorStateMetadata(checkpointDir.toString, 1, expectedMetadata1) - checkAnswer(spark.read.format("state-metadata").load(checkpointDir.toString), - Seq(Row(0, "stateStoreSave", "default", 5, 0, 0L, 1L), - Row(1, "stateStoreSave", "default", 5, 0, 0L, 1L))) + val df = spark.read.format("state-metadata") + .load(checkpointDir.toString) + checkAnswer(df, Seq(Row(0, "stateStoreSave", "default", 5, 0L, 1L), + Row(1, "stateStoreSave", "default", 5, 0L, 1L))) + checkAnswer(df.select(df.metadataColumn("_numColsPrefixKey")), Seq(Row(0), Row(0))) } } + + test("State metadata data source handle missing argument") { + val e = intercept[IllegalArgumentException] { + spark.read.format("state-metadata").load().collect() + } + assert(e.getMessage 
== "Checkpoint path is not specified for state metadata data source.") + } } From 40d4c6b0bc0f0dd6e698a767af131a440b7fafdf Mon Sep 17 00:00:00 2001 From: Chaoqin Li Date: Tue, 7 Nov 2023 13:50:15 -0800 Subject: [PATCH 4/4] format --- .../streaming/state/OperatorStateMetadataSuite.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala index 7d88b69534ceb..340187fa49514 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala @@ -167,8 +167,7 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession { ) checkOperatorStateMetadata(checkpointDir.toString, 0, expectedMetadata) - val df = spark.read.format("state-metadata") - .load(checkpointDir.toString) + val df = spark.read.format("state-metadata").load(checkpointDir.toString) checkAnswer(df, Seq(Row(0, "sessionWindowStateStoreSaveExec", "default", 5, 0L, 0L))) checkAnswer(df.select(df.metadataColumn("_numColsPrefixKey")), Seq(Row(1))) } @@ -201,8 +200,7 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession { checkOperatorStateMetadata(checkpointDir.toString, 0, expectedMetadata0) checkOperatorStateMetadata(checkpointDir.toString, 1, expectedMetadata1) - val df = spark.read.format("state-metadata") - .load(checkpointDir.toString) + val df = spark.read.format("state-metadata").load(checkpointDir.toString) checkAnswer(df, Seq(Row(0, "stateStoreSave", "default", 5, 0L, 1L), Row(1, "stateStoreSave", "default", 5, 0L, 1L))) checkAnswer(df.select(df.metadataColumn("_numColsPrefixKey")), Seq(Row(0), Row(0)))