[SPARK-50302][SS] Ensure secondary index sizes equal primary index sizes for TransformWithState stateful variables with TTL

### What changes were proposed in this pull request?

This PR ensures that the secondary indexes that state variables with TTL use are at most the size of the corresponding state variable's primary index. This change will eliminate unnecessary work done during the cleanup of stateful variables with TTL.

### Why are the changes needed?

#### Context

The `TransformWithState` operator (referred to from here on as "TWS") allows users to write procedural logic over streams of records. To store state between micro-batches, Spark provides users with _stateful variables_, which persist across micro-batches. For example, you might want to emit an average of the past 5 records, every 5 records. You might only receive 2 records in the first micro-batch, so you have to _buffer_ these 2 records until you get 3 more in a subsequent batch. TWS supports 3 different types of stateful variables: single values, lists, and maps.

The TWS operator also supports stateful variables with Time To Live (TTL); this allows you to say, "keep a certain record in state for `d` units of time". This TTL is per-record, which means that every record in a list (or map) can expire at a different point in time, depending on when it was inserted. A record inserted into a stateful list (or map) at time `t1` will expire at `t1 + d`, and a second record inserted at time `t2` will expire at `t2 + d`. (For value state, there's only one value, so "everything" expires at the same time.)

A very natural question to ask now is: how do we efficiently determine which elements in state have expired, without doing a full scan of every record? The idea here is to keep a secondary index from expiration timestamp to the specific record that needs to be evicted. Not so hard, right?

#### The state cleanup strategy today

Today's cleanup strategy is about as simple as I indicated earlier. For every insert to a value/map/list, you do the following (sketched in code after this list):

1. Write to the primary index
2. Using the expiration timestamp (the current timestamp plus the TTL duration), you write an entry into the secondary index
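
A minimal in-memory sketch of this write path (hypothetical `primary`/`secondary` maps standing in for the RocksDB column families; this is not the actual TWS API):

```scala
import scala.collection.mutable

object NaiveTtlIndexes {
  // Primary index: key -> (value, expirationMs)
  val primary = mutable.Map.empty[String, (String, Long)]
  // Secondary index: (expirationMs, key) -> unit marker, iterated in timestamp order
  val secondary = mutable.SortedMap.empty[(Long, String), Unit]

  def putWithTtl(key: String, value: String, batchTimestampMs: Long, ttlMs: Long): Unit = {
    val expirationMs = batchTimestampMs + ttlMs
    primary(key) = (value, expirationMs)  // write 1: unconditional
    secondary((expirationMs, key)) = ()   // write 2: unconditional; any older entry for `key` stays behind
  }
}
```

Calling `putWithTtl("foo", "v1", 100, 500)` and then `putWithTtl("foo", "v2", 200, 500)` leaves one entry in `primary` but two in `secondary`, which is exactly the situation walked through below.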

The issue with this approach is that we do two _unconditional_ writes. This means that if the same state variable is written to with different timestamps, there will be one element in the primary index but two elements in the secondary index. Consider the following example for a state variable `foo` with value `v1`, and a TTL delay of 500:

For batch 0, `batchTimestampMs = 100`, `foo` updates to `v1`:

- Primary index: `[foo -> (v1, 600)]`
- Secondary index: `[(600, foo) -> EMPTY]`

Note that the state variable is included in the secondary index key because we might have several elements with the same expiration timestamp; we want `(600, foo)` to not overwrite a `(600, bar)`, just because they both expire at 600.

For batch 1, `batchTimestampMs = 200`, `foo` updates to `v2`:

- Primary index: `[foo -> (v2, 700)]`
- Secondary index: `[(600, foo) -> EMPTY, (700, foo) -> EMPTY]`

Now, we have two entries in our secondary index. If the current timestamp advanced to something like 800, we'd take the following steps:

1. We'd take the first element from the secondary index, `(600, foo)`, and look up `foo` in the primary index. That would yield `(v2, 700)`. The value of 700 in the primary index is still less than 800, so we would remove `foo` from the primary index.
2. Then, we would look at `(700, foo)`. We'd look up `foo` in the primary index and see nothing, so we'd do nothing.

You'll notice here that step 2 is _entirely_ redundant. We read `(700, foo)` and did a get against the primary index for a lookup that was doomed to return nothing.
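
Continuing the in-memory sketch from above, the naive cleanup pass could look roughly like the following method, added alongside `putWithTtl` in the hypothetical `NaiveTtlIndexes` object (again, not the real `StateStore` code path):

```scala
def naiveValueCleanup(nowMs: Long): Unit = {
  // Collect every secondary entry whose expiration timestamp has already passed.
  val due = secondary.keys.takeWhile { case (expirationMs, _) => expirationMs < nowMs }.toSeq
  due.foreach { case (_, key) =>
    // Every entry, including stale ones like (700, foo), costs a get against the primary index.
    primary.get(key) match {
      case Some((_, primaryExpirationMs)) if primaryExpirationMs < nowMs => primary.remove(key)
      case _ => // stale secondary entry: this get could never have found anything left to evict
    }
  }
  due.foreach(secondary.remove)
}
```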

While this isn't great, the story is unfortunately significantly worse for lists. The way that we store lists is by having a single key in RocksDB, whose value is the concatenated bytes of all the values in that list. When we do cleanup for a list, we go through _all_ of its records and write back only the ones that haven't expired. Thus, it's possible for us to have a list that looks something like:

- Primary index: `[foo -> [(v1, 600), (v2, 700), (v3, 900)]]`
- Secondary index: `[(600, foo) -> EMPTY, (700, foo) -> EMPTY, (900, foo) -> EMPTY]`

Now, suppose that the current timestamp is 800. We need to expire the records in the list. So, we do the following:

1. We take the first element from the secondary index, `(600, foo)`. This tells us that the list `foo` needs cleaning up. We clean up everything in `foo` less than 800. Since we store lists as a single key, we issue a RocksDB `clear` operation, iterate through all of the existing values, eliminate `(v1, 600)` and `(v2, 700)`, and write back `(v3, 900)`.
2. But we still have things left in our secondary index! We now get `(700, foo)`, and we unknowingly do cleanup on `foo` _again_. This consists of clearing `foo`, iterating through its elements, and writing back `(v3, 900)`. But since cleanup already happened, this step is _entirely_ redundant.
3. We encounter `(900, foo)` from the secondary index, and since 900 > 800, we can bail out of cleanup.

Step 2 here is extremely wasteful. If we have `n` elements in our secondary index for the same key, then, in the worst case, we will do the extra cleanup `n-1` times; and each time is a _linear_ time operation! Thus, for a list that has `n` elements, `d` of which need to be cleaned up, the worst-case time complexity is in `O(d*(n-d))`, instead of `O(n)`. And it's _completely_ unnecessary work.
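
To make that concrete (illustrative numbers only, not a benchmark): for a list with `n = 1000` buffered elements of which `d = 500` expire in the same batch, the redundant passes touch on the order of `d * (n - d) = 500 * 500 = 250,000` elements, versus the roughly 1,000 element reads and writes that a single cleanup pass would need.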

#### How does this PR fix the issue?

It's pretty simple to fix this for value state and map state, because every key in value or map state maps to exactly one element in the secondary index. We can maintain a one-to-one correspondence: any time we modify value/map state, we make sure that we delete the previous entry in the secondary index. This logic is implemented by `OneToOneTTLState`.
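
A sketch of that one-to-one bookkeeping, reusing the hypothetical `NaiveTtlIndexes` model from above (the real logic lives in `OneToOneTTLState`, not in this helper):

```scala
def putWithTtlOneToOne(key: String, value: String, batchTimestampMs: Long, ttlMs: Long): Unit = {
  val newExpirationMs = batchTimestampMs + ttlMs
  // Before writing, drop the secondary entry for this key's previous expiration, if any.
  primary.get(key).foreach { case (_, oldExpirationMs) =>
    secondary.remove((oldExpirationMs, key))
  }
  primary(key) = (value, newExpirationMs)
  secondary((newExpirationMs, key)) = ()
}
```

With this version, replaying the `foo` example above ends with exactly one secondary entry, `(700, foo)`.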

The trickier aspect is handling this for `ListState`, where the secondary index points from a grouping key to the list that needs to be cleaned up. There's a one-to-many mapping here: one grouping key maps to multiple records, all of which could expire at different times. The trick to making sure that the secondary index doesn't explode is to have it store only the minimum expiration timestamp in each list. The rough intuition is that you don't need to store anything larger than that: when you clean up due to the minimum expiration timestamp, you go through the list anyway and can find the next minimum timestamp; you then put _that_ into your secondary index. This logic is implemented by `OneToManyTTLState`.
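
A rough sketch of the minimum-timestamp idea for lists (illustrative names and in-memory types; the real logic lives in `OneToManyTTLState`):

```scala
import scala.collection.mutable

object MinTimestampListIndex {
  // Primary index: grouping key -> buffered (value, expirationMs) pairs.
  val lists = mutable.Map.empty[String, Vector[(String, Long)]]
  // Secondary index: at most one (minExpirationMs, key) entry per list.
  val minIndex = mutable.SortedMap.empty[(Long, String), Unit]

  def appendWithTtl(key: String, value: String, batchTimestampMs: Long, ttlMs: Long): Unit = {
    val expirationMs = batchTimestampMs + ttlMs
    val current = lists.getOrElse(key, Vector.empty)
    lists(key) = current :+ (value -> expirationMs)

    current.map(_._2).minOption match {
      case Some(oldMin) if expirationMs >= oldMin =>
        // The existing minimum already covers this element; the index stays untouched.
      case Some(oldMin) =>
        minIndex.remove((oldMin, key))
        minIndex((expirationMs, key)) = ()
      case None =>
        minIndex((expirationMs, key)) = ()
    }
  }
}
```

When cleanup later fires for the minimum entry, the list is scanned anyway: expired elements are dropped, and the minimum expiration timestamp of the survivors is written back as the single secondary entry for that key.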

### How should reviewers review this PR?

- Start by reading this long description. If you have questions, please ping me in the comments. I would be more than happy to explain.
- Then, understand the class doc comments for `OneToOneTTLState` and `OneToManyTTLState` in `TTLState.scala`.
- Then, I'd recommend going through the unit tests, and making sure that the _behavior_ makes sense to you. If it doesn't, please leave a question.
- Finally, you can look at the actual stateful variable implementations.

### Does this PR introduce _any_ user-facing change?

No. It does change the format in which TWS represents its internal state, but since TWS is currently `private[sql]` and not publicly available, this is not an issue.

### How was this patch tested?

- Existing UTs have been modified to conform with this new behavior.
- New UTs were added to verify the behavior of the new indices.

### Was this patch authored or co-authored using generative AI tooling?

Generated-by: GitHub Copilot

Closes #48853 from neilramaswamy/spark-50302.

Authored-by: Neil Ramaswamy <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
neilramaswamy authored and HeartSaVioR committed Nov 26, 2024
1 parent 7cbfc2c commit 02bfce6
Showing 12 changed files with 965 additions and 468 deletions.
@@ -22,7 +22,6 @@ import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._
import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
import org.apache.spark.sql.streaming.{ListState, TTLConfig}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.NextIterator

/**
@@ -45,21 +44,13 @@ class ListStateImplWithTTL[S](
valEncoder: ExpressionEncoder[Any],
ttlConfig: TTLConfig,
batchTimestampMs: Long,
metrics: Map[String, SQLMetric] = Map.empty)
extends SingleKeyTTLStateImpl(stateName, store, keyExprEnc, batchTimestampMs)
with ListStateMetricsImpl
with ListState[S] {

override def stateStore: StateStore = store
override def baseStateName: String = stateName
override def exprEncSchema: StructType = keyExprEnc.schema
metrics: Map[String, SQLMetric])
extends OneToManyTTLState(
stateName, store, keyExprEnc.schema, ttlConfig, batchTimestampMs, metrics) with ListState[S] {

private lazy val stateTypesEncoder = StateTypesEncoder(keyExprEnc, valEncoder,
stateName, hasTtl = true)

private lazy val ttlExpirationMs =
StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs)

initialize()

private def initialize(): Unit = {
@@ -106,61 +97,44 @@ class ListStateImplWithTTL[S](
validateNewState(newState)

val encodedKey = stateTypesEncoder.encodeGroupingKey()
var isFirst = true
var entryCount = 0L
TWSMetricsUtils.resetMetric(metrics, "numUpdatedStateRows")

newState.foreach { v =>
val encodedValue = stateTypesEncoder.encodeValue(v, ttlExpirationMs)
if (isFirst) {
store.put(encodedKey, encodedValue, stateName)
isFirst = false
} else {
store.merge(encodedKey, encodedValue, stateName)
}
entryCount += 1
TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows")
val newStateUnsafeRows = newState.iterator.map { v =>
stateTypesEncoder.encodeValue(v, ttlExpirationMs)
}
upsertTTLForStateKey(encodedKey)
updateEntryCount(encodedKey, entryCount)

updatePrimaryAndSecondaryIndices(true, encodedKey, newStateUnsafeRows, ttlExpirationMs)
}

/** Append an entry to the list. */
override def appendValue(newState: S): Unit = {
StateStoreErrors.requireNonNullStateValue(newState, stateName)

val encodedKey = stateTypesEncoder.encodeGroupingKey()
val entryCount = getEntryCount(encodedKey)
store.merge(encodedKey,
stateTypesEncoder.encodeValue(newState, ttlExpirationMs), stateName)
TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows")
upsertTTLForStateKey(encodedKey)
updateEntryCount(encodedKey, entryCount + 1)
val newStateUnsafeRow = stateTypesEncoder.encodeValue(newState, ttlExpirationMs)

updatePrimaryAndSecondaryIndices(false, encodedKey,
Iterator.single(newStateUnsafeRow), ttlExpirationMs)
}

/** Append an entire list to the existing value. */
override def appendList(newState: Array[S]): Unit = {
validateNewState(newState)

val encodedKey = stateTypesEncoder.encodeGroupingKey()
var entryCount = getEntryCount(encodedKey)
newState.foreach { v =>
val encodedValue = stateTypesEncoder.encodeValue(v, ttlExpirationMs)
store.merge(encodedKey, encodedValue, stateName)
entryCount += 1
TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows")
// The UnsafeRows created here are reused: we do NOT copy them. As a result,
// this iterator must only be used lazily, and it should never be materialized,
// unless you call newStateUnsafeRows.map(_.copy()).
val newStateUnsafeRows = newState.iterator.map { v =>
stateTypesEncoder.encodeValue(v, ttlExpirationMs)
}
upsertTTLForStateKey(encodedKey)
updateEntryCount(encodedKey, entryCount)

updatePrimaryAndSecondaryIndices(false, encodedKey,
newStateUnsafeRows, ttlExpirationMs)
}

/** Remove this state. */
override def clear(): Unit = {
val encodedKey = stateTypesEncoder.encodeGroupingKey()
store.remove(encodedKey, stateName)
val entryCount = getEntryCount(encodedKey)
TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows", entryCount)
removeEntryCount(encodedKey)
clearTTLState()
val groupingKey = stateTypesEncoder.encodeGroupingKey()
clearAllStateForElementKey(groupingKey)
}

private def validateNewState(newState: Array[S]): Unit = {
@@ -175,36 +149,41 @@ class ListStateImplWithTTL[S](
/**
* Loops through all the values associated with the grouping key, and removes
* the expired elements from the list.
* @param groupingKey grouping key for which cleanup should be performed.
* @param elementKey grouping key for which cleanup should be performed.
*/
override def clearIfExpired(groupingKey: UnsafeRow): Long = {
override def clearExpiredValues(elementKey: UnsafeRow): ValueExpirationResult = {
var numValuesExpired = 0L
val unsafeRowValuesIterator = store.valuesIterator(groupingKey, stateName)
val unsafeRowValuesIterator = store.valuesIterator(elementKey, stateName)
// We clear the list, and use the iterator to put back all of the non-expired values
store.remove(groupingKey, stateName)
removeEntryCount(groupingKey)
store.remove(elementKey, stateName)

var newMinExpirationMsOpt: Option[Long] = None
var isFirst = true
var entryCount = 0L
unsafeRowValuesIterator.foreach { encodedValue =>
if (!stateTypesEncoder.isExpired(encodedValue, batchTimestampMs)) {
if (isFirst) {
store.put(groupingKey, encodedValue, stateName)
isFirst = false
store.put(elementKey, encodedValue, stateName)
} else {
store.merge(groupingKey, encodedValue, stateName)
store.merge(elementKey, encodedValue, stateName)
}

// If it is not expired, it needs to be reinserted (either via put or merge), but
// it also has an expiration time that might be the new minimum.
val currentExpirationMs = stateTypesEncoder.decodeTtlExpirationMs(encodedValue)

newMinExpirationMsOpt = newMinExpirationMsOpt match {
case Some(minExpirationMs) =>
Some(math.min(minExpirationMs, currentExpirationMs.get))
case None =>
Some(currentExpirationMs.get)
}
entryCount += 1
} else {
numValuesExpired += 1
}
}
updateEntryCount(groupingKey, entryCount)
TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows", numValuesExpired)
numValuesExpired
}

private def upsertTTLForStateKey(encodedGroupingKey: UnsafeRow): Unit = {
upsertTTLForStateKey(ttlExpirationMs, encodedGroupingKey)
ValueExpirationResult(numValuesExpired, newMinExpirationMsOpt)
}

/*
@@ -238,11 +217,23 @@ class ListStateImplWithTTL[S](
}
}

private[sql] def getMinValues(): Iterator[Long] = {
val groupingKey = stateTypesEncoder.encodeGroupingKey()
minIndexIterator()
.filter(_._1 == groupingKey)
.map(_._2)
}

/**
* Get all ttl values stored in ttl state for current implicit
* grouping key.
* Get the TTL value stored in TTL state for the current implicit grouping key,
* if it exists.
*/
private[sql] def getValuesInTTLState(): Iterator[Long] = {
getValuesInTTLState(stateTypesEncoder.encodeGroupingKey())
private[sql] def getValueInTTLState(): Option[Long] = {
val groupingKey = stateTypesEncoder.encodeGroupingKey()
val ttlRowsForGroupingKey = getTTLRows().filter(_.elementKey == groupingKey).toSeq

assert(ttlRowsForGroupingKey.size <= 1, "Multiple TTLRows found for grouping key " +
s"$groupingKey. Expected at most 1. Found: ${ttlRowsForGroupingKey.mkString(", ")}.")
ttlRowsForGroupingKey.headOption.map(_.expirationMs)
}
}
@@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.streaming

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._
import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors}
@@ -48,17 +47,14 @@ class MapStateImplWithTTL[K, V](
valEncoder: ExpressionEncoder[Any],
ttlConfig: TTLConfig,
batchTimestampMs: Long,
metrics: Map[String, SQLMetric] = Map.empty)
extends CompositeKeyTTLStateImpl[K](stateName, store,
keyExprEnc, userKeyEnc, batchTimestampMs)
with MapState[K, V] with Logging {
metrics: Map[String, SQLMetric])
extends OneToOneTTLState(
stateName, store, getCompositeKeySchema(keyExprEnc.schema, userKeyEnc.schema), ttlConfig,
batchTimestampMs, metrics) with MapState[K, V] with Logging {

private val stateTypesEncoder = new CompositeKeyStateEncoder(
keyExprEnc, userKeyEnc, valEncoder, stateName, hasTtl = true)

private val ttlExpirationMs =
StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs)

initialize()

private def initialize(): Unit = {
@@ -102,15 +98,12 @@ class MapStateImplWithTTL[K, V](
StateStoreErrors.requireNonNullStateValue(key, stateName)
StateStoreErrors.requireNonNullStateValue(value, stateName)

val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
val encodedUserKey = stateTypesEncoder.encodeUserKey(key)

val encodedValue = stateTypesEncoder.encodeValue(value, ttlExpirationMs)
val encodedCompositeKey = stateTypesEncoder.encodeCompositeKey(key)
store.put(encodedCompositeKey, encodedValue, stateName)
TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows")
val ttlExpirationMs = StateTTL
.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs)
val encodedValue = stateTypesEncoder.encodeValue(value, ttlExpirationMs)

upsertTTLForStateKey(ttlExpirationMs, encodedGroupingKey, encodedUserKey)
updatePrimaryAndSecondaryIndices(encodedCompositeKey, encodedValue, ttlExpirationMs)
}

/** Get the map associated with grouping key */
@@ -161,41 +154,12 @@ class MapStateImplWithTTL[K, V](

/** Remove this state. */
override def clear(): Unit = {
keys().foreach { itr =>
removeKey(itr)
}
clearTTLState()
}

/**
* Clears the user state associated with this grouping key
* if it has expired. This function is called by Spark to perform
* cleanup at the end of transformWithState processing.
*
* Spark uses a secondary index to determine if the user state for
* this grouping key has expired. However, its possible that the user
* has updated the TTL and secondary index is out of date. Implementations
* must validate that the user State has actually expired before cleanup based
* on their own State data.
*
* @param groupingKey grouping key for which cleanup should be performed.
* @param userKey user key for which cleanup should be performed.
*/
override def clearIfExpired(
groupingKeyRow: UnsafeRow,
userKeyRow: UnsafeRow): Long = {
val compositeKeyRow = stateTypesEncoder.encodeCompositeKey(groupingKeyRow, userKeyRow)
val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
val unsafeRowPairIterator = store.prefixScan(encodedGroupingKey, stateName)

val retRow = store.get(compositeKeyRow, stateName)
var numRemovedElements = 0L
if (retRow != null) {
if (stateTypesEncoder.isExpired(retRow, batchTimestampMs)) {
store.remove(compositeKeyRow, stateName)
numRemovedElements += 1
TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows")
}
unsafeRowPairIterator.foreach { rowPair =>
clearAllStateForElementKey(rowPair.key)
}
numRemovedElements
}

/*
@@ -243,30 +207,18 @@ class MapStateImplWithTTL[K, V](
* grouping key.
*/
private[sql] def getKeyValuesInTTLState(): Iterator[(K, Long)] = {
val ttlIterator = ttlIndexIterator()
val implicitGroupingKey = stateTypesEncoder.encodeGroupingKey()
var nextValue: Option[(K, Long)] = None

new Iterator[(K, Long)] {
override def hasNext: Boolean = {
while (nextValue.isEmpty && ttlIterator.hasNext) {
val nextTtlValue = ttlIterator.next()
val groupingKey = nextTtlValue.groupingKey
if (groupingKey equals implicitGroupingKey.getStruct(
0, keyExprEnc.schema.length)) {
val userKey = stateTypesEncoder.decodeUserKey(
nextTtlValue.userKey)
nextValue = Some(userKey.asInstanceOf[K], nextTtlValue.expirationMs)
}
}
nextValue.isDefined
}

override def next(): (K, Long) = {
val result = nextValue.get
nextValue = None
result
}
.getStruct(0, keyExprEnc.schema.length)

// We're getting composite rows back
getTTLRows().filter { ttlRow =>
val compositeKey = ttlRow.elementKey
val groupingKey = compositeKey.getStruct(0, keyExprEnc.schema.length)
groupingKey == implicitGroupingKey
}.map { ttlRow =>
val compositeKey = ttlRow.elementKey
val userKey = stateTypesEncoder.decodeCompositeKey(compositeKey)
(userKey.asInstanceOf[K], ttlRow.expirationMs)
}
}
}