Skip to content

Commit

Permalink
[SPARK-41261][PYTHON][SS] Fix issue for applyInPandasWithState when the columns of grouping keys are not placed in order from earliest
Browse files Browse the repository at this point in the history

### What changes were proposed in this pull request?

This PR fixes the issue for applyInPandasWithState, which is triggered when the columns of grouping keys are not placed in order from the earliest position. If the condition is met, the user function may get an "incorrect" value for the key, including `None`.

This is because the projection for the value is co-used between the normal input row and the row for timed-out state. The projection assumed that the schema for the row is the same as the output schema of the child node, whereas the row for timed-out state is constructed by concatenating the key row with a null value row.

This PR creates a separate projection for the row for timed-out state, so that the projection can pick up the values for grouping columns correctly.

### Why are the changes needed?

Without this fix, the user function may get an "incorrect" value for the key, including `None`.

### Does this PR introduce _any_ user-facing change?

No. This feature is not released yet.

### How was this patch tested?

New test case.

Closes apache#38798 from HeartSaVioR/SPARK-41261.

Authored-by: Jungtaek Lim <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
  • Loading branch information
HeartSaVioR authored and HyukjinKwon committed Nov 27, 2022
1 parent 77e2d45 commit 436ce5f
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,15 @@ case class FlatMapGroupsInPandasWithStateExec(
// Projection from the child's output schema (with nullability relaxed) into the
// deduplicated attribute layout. Used for normal (non-timed-out) input rows.
private lazy val unsafeProj = UnsafeProjection.create(dedupAttributesWithNull,
childOutputWithNull)

// See processTimedOutState: we create a row which contains the actual values for grouping key,
// but all nulls for value side by intention. The schema for this row is different from
// child.output, hence we should create another projection to deal with such schema.
// Value-side attributes only: the child's output (nullability relaxed) minus the grouping
// attributes. Grouping attributes are compared with nullability forced to true so that they
// match the relaxed attributes in childOutputWithNull.
private lazy val valueAttributesWithNull = childOutputWithNull.filterNot { attr =>
groupingAttributes.exists(_.withNullability(newNullability = true) == attr)
}
// Projection for timed-out state rows, whose input layout is (grouping key ++ null values)
// rather than the child's output ordering — this fixes SPARK-41261, where grouping columns
// not at the front of the row produced wrong key values for timed-out state.
private lazy val unsafeProjForTimedOut = UnsafeProjection.create(dedupAttributesWithNull,
groupingAttributes ++ valueAttributesWithNull)

/**
 * Child rows must arrive partitioned compatibly on the grouping attributes so that all
 * rows for a given key (and its state) land in the same partition.
 */
override def requiredChildDistribution: Seq[Distribution] = {
  val distribution = StatefulOperatorPartitioning.getCompatibleDistribution(
    groupingAttributes, getStateInfo, conf)
  Seq(distribution)
}
Expand Down Expand Up @@ -142,12 +151,10 @@ case class FlatMapGroupsInPandasWithStateExec(
state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold
}

val emptyValueRow = new GenericInternalRow(
Array.fill(valueAttributesWithNull.length)(null: Any))
val processIter = timingOutPairs.map { stateData =>
val joinedKeyRow = unsafeProj(
new JoinedRow(
stateData.keyRow,
new GenericInternalRow(Array.fill(dedupAttributesWithNull.length)(null: Any))))

val joinedKeyRow = unsafeProjForTimedOut(new JoinedRow(stateData.keyRow, emptyValueRow))
(stateData.keyRow, stateData, Iterator.single(joinedKeyRow))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -803,4 +803,87 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest {
total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed = Some(Seq(1)))
)
}

// Regression test for SPARK-41261: the grouping columns (key1, key2) are deliberately NOT
// the leading columns of the input row, and the key the user function receives — especially
// on the timed-out-state path — must still carry the correct values.
test("SPARK-41261: applyInPandasWithState - key in user function should be correct for " +
"timed out state despite of place for key columns") {
assume(shouldTestPandasUDFs)

// Function to maintain the count as state and set the proc. time timeout delay of 10 seconds.
// It returns the count if changed, or -1 if the state was removed by timeout.
// Note: the function echoes key[0]/key[1] back into the output, which is what this test
// asserts on — wrong key propagation would surface directly in CheckNewAnswer below.
val pythonScript =
"""
|import pandas as pd
|from pyspark.sql.types import StructType, StructField, StringType
|
|tpe = StructType([
| StructField("key1", StringType()),
| StructField("key2", StringType()),
| StructField("countAsStr", StringType())])
|
|def func(key, pdf_iter, state):
| ret = None
| if state.hasTimedOut:
| state.remove()
| yield pd.DataFrame({'key1': [key[0]], 'key2': [key[1]], 'countAsStr': [str(-1)]})
| else:
| count = state.getOption
| if count is None:
| count = 0
| else:
| count = count[0]
|
| for pdf in pdf_iter:
| count += len(pdf)
|
| state.update((count, ))
| state.setTimeoutDuration(10000)
| yield pd.DataFrame({'key1': [key[0]], 'key2': [key[1]], 'countAsStr': [str(count)]})
|""".stripMargin
val pythonFunc = TestGroupedMapPandasUDFWithState(
name = "pandas_grouped_map_with_state", pythonScript = pythonScript)

// Manual clock lets the test drive processing-time advancement (and hence state timeouts).
val clock = new StreamManualClock
val inputData = MemoryStream[String]
// schema: val1, key2, val2, key1, val3
// Grouping columns are interleaved with value columns so the key columns are not a prefix
// of the row — the exact condition that triggered SPARK-41261.
val inputDataDS = inputData.toDS
.withColumnRenamed("value", "val1")
.withColumn("key2", $"val1")
// the type of columns with string literal will be non-nullable
.withColumn("val2", lit("__FAKE__"))
.withColumn("key1", lit("__FAKE__"))
.withColumn("val3", lit("__FAKE__"))
val outputStructType = StructType(
Seq(
StructField("key1", StringType),
StructField("key2", StringType),
StructField("countAsStr", StringType)))
val stateStructType = StructType(Seq(StructField("count", LongType)))
val result =
inputDataDS
// grouping columns: key1, key2 (swapped order)
.groupBy("key1", "key2")
.applyInPandasWithState(
pythonFunc(
inputDataDS("val1"), inputDataDS("key2"), inputDataDS("val2"), inputDataDS("key1"),
inputDataDS("val3")
).expr.asInstanceOf[PythonUDF],
outputStructType,
stateStructType,
"Update",
"ProcessingTimeTimeout")

testStream(result, Update)(
StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
// First batch: key ("__FAKE__", "a") is created with count 1.
AddData(inputData, "a"),
AdvanceManualClock(1 * 1000),
CheckNewAnswer(("__FAKE__", "a", "1")),
assertNumStateRows(total = 1, updated = 1),

// Advancing 11s exceeds the 10s timeout, so key "a" times out (emits -1 with the
// CORRECT key values) while key "b" is newly created in the same batch.
AddData(inputData, "b"),
AdvanceManualClock(11 * 1000),
CheckNewAnswer(("__FAKE__", "a", "-1"), ("__FAKE__", "b", "1")),
assertNumStateRows(
total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed = Some(Seq(1)))
)
}
}

0 comments on commit 436ce5f

Please sign in to comment.