[SPARK-50017][SS] Support Avro encoding for TransformWithState operator
### What changes were proposed in this pull request?

Currently, stateful streaming operators store state in the StateStore using the internal UnsafeRow byte representation. This PR introduces Avro serialization and deserialization capabilities in the RocksDBStateEncoder so that state can instead be stored with Avro encoding. The new encoding is currently enabled for the TransformWithState (TWS) operator via a SQLConf and supports all TWS functionality.
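
For illustration, a minimal sketch of opting in (only the conf key and its valid values come from the SQLConf change in this PR; the surrounding session setup is assumed):

```scala
// Sketch: opt in to Avro state encoding before starting a
// transformWithState query. `spark` is an active SparkSession.
// The key accepts "unsaferow" (the default) or "avro".
spark.conf.set("spark.sql.streaming.stateStore.encodingFormat", "avro")
```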

### Why are the changes needed?

UnsafeRow is an inherently unstable format that makes no backward-compatibility guarantees, so a format change between Spark releases could corrupt the StateStore. Avro is a more stable format and natively supports schema evolution.
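
As a concrete example of the schema evolution Avro provides (an illustrative sketch using the standard Avro Java API, not code from this PR; the record and field names are hypothetical):

```scala
import org.apache.avro.Schema
import org.apache.avro.generic.{GenericDatumReader, GenericRecord}

// Writer schema: the shape of state already persisted in the store.
val writerSchema = new Schema.Parser().parse(
  """{"type":"record","name":"State","fields":[
    |  {"name":"count","type":"long"}
    |]}""".stripMargin)

// Reader schema: a later release adds a field with a default value.
val readerSchema = new Schema.Parser().parse(
  """{"type":"record","name":"State","fields":[
    |  {"name":"count","type":"long"},
    |  {"name":"lastUpdated","type":"long","default":0}
    |]}""".stripMargin)

// Avro resolves writer -> reader schemas at read time, filling in the
// default for the new field, so previously written bytes stay readable.
val reader = new GenericDatumReader[GenericRecord](writerSchema, readerSchema)
```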

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Amended existing unit tests and added new ones.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #48401 from ericm-db/avro.

Lead-authored-by: Eric Marnadi <[email protected]>
Co-authored-by: Eric Marnadi <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
2 people authored and HeartSaVioR committed Nov 26, 2024
1 parent f712213 commit 331d0bf
Showing 17 changed files with 744 additions and 67 deletions.
SQLConf.scala

@@ -2221,6 +2221,17 @@ object SQLConf {
.intConf
.createWithDefault(1)

val STREAMING_STATE_STORE_ENCODING_FORMAT =
buildConf("spark.sql.streaming.stateStore.encodingFormat")
.doc("The encoding format used for stateful operators to store information " +
"in the state store")
.version("4.0.0")
.stringConf
.transform(_.toLowerCase(Locale.ROOT))
.checkValue(v => Set("unsaferow", "avro").contains(v),
"Valid values are 'unsaferow' and 'avro'")
.createWithDefault("unsaferow")

val STATE_STORE_COMPRESSION_CODEC =
buildConf("spark.sql.streaming.stateStore.compression.codec")
.internal()
@@ -5596,6 +5607,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {

def stateStoreCheckpointFormatVersion: Int = getConf(STATE_STORE_CHECKPOINT_FORMAT_VERSION)

def stateStoreEncodingFormat: String = getConf(STREAMING_STATE_STORE_ENCODING_FORMAT)

def checkpointRenamedFileCheck: Boolean = getConf(CHECKPOINT_RENAMEDFILE_CHECK_ENABLED)

def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED)
StateStoreColumnFamilySchemaUtils.scala

@@ -20,10 +20,49 @@ import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._
import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, StateStoreColFamilySchema}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types._

object StateStoreColumnFamilySchemaUtils {

/**
* Avro uses zig-zag encoding for some fixed-length types, like Longs and Ints. For range scans
* we want to use big-endian encoding, so we need to convert the source schema to replace these
* types with BinaryType.
*
* @param schema The schema to convert
* @param ordinals If non-empty, only convert fields at these ordinals.
* If empty, convert all fields.
*/
def convertForRangeScan(schema: StructType, ordinals: Seq[Int] = Seq.empty): StructType = {
val ordinalSet = ordinals.toSet

StructType(schema.fields.zipWithIndex.flatMap { case (field, idx) =>
if ((ordinals.isEmpty || ordinalSet.contains(idx)) && isFixedSize(field.dataType)) {
// For each numeric field, create two fields:
// 1. Byte marker for null, positive, or negative values
// 2. The original numeric value in big-endian format
// Byte type is converted to Int in Avro, which doesn't work for us as Avro
// uses zig-zag encoding as opposed to big-endian for Ints
Seq(
StructField(s"${field.name}_marker", BinaryType, nullable = false),
field.copy(name = s"${field.name}_value", BinaryType)
)
} else {
Seq(field)
}
})
}

private def isFixedSize(dataType: DataType): Boolean = dataType match {
case _: ByteType | _: BooleanType | _: ShortType | _: IntegerType | _: LongType |
_: FloatType | _: DoubleType => true
case _ => false
}

def getTtlColFamilyName(stateName: String): String = {
"$ttl_" + stateName
}

def getValueStateSchema[T](
stateName: String,
keyEncoder: ExpressionEncoder[Any],
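To make the comment in `convertForRangeScan` concrete, here is a minimal sketch of the marker-plus-big-endian layout behind the `_marker` and `_value` fields created above (illustrative only; the marker constants and helper name are not from this PR). Big-endian bytes sort in numeric order under the unsigned lexicographic comparison RocksDB applies to keys, whereas Avro's zig-zag varints do not:

```scala
import java.nio.ByteBuffer

// Illustrative marker bytes; the encoder's actual constants may differ.
val NullMarker: Byte = 0x00
val NegativeMarker: Byte = 0x01
val PositiveMarker: Byte = 0x02

// Encode a nullable Long as (marker, 8 big-endian bytes). The marker
// orders null < negative < positive; within each group, big-endian
// two's-complement bytes already sort in numeric order.
def encodeLongForRangeScan(v: java.lang.Long): (Byte, Array[Byte]) = {
  if (v == null) {
    (NullMarker, new Array[Byte](8))
  } else {
    val marker = if (v < 0) NegativeMarker else PositiveMarker
    (marker, ByteBuffer.allocate(8).putLong(v).array())
  }
}
```
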
StreamExecution.scala

@@ -715,6 +715,7 @@ abstract class StreamExecution(

object StreamExecution {
val QUERY_ID_KEY = "sql.streaming.queryId"
val RUN_ID_KEY = "sql.streaming.runId"
val IS_CONTINUOUS_PROCESSING = "__is_continuous_processing"
val IO_EXCEPTION_NAMES = Seq(
classOf[InterruptedException].getName,
