[SPARK-40670][SS][PYTHON] Fix NPE in applyInPandasWithState when the input schema has "non-nullable" column(s)

### What changes were proposed in this pull request?

This PR fixes a bug that throws a NullPointerException (NPE) when the input schema of applyInPandasWithState has "non-nullable" column(s).
This PR also adds a code comment explaining the fix. Quoting:

```
  // See processTimedOutState: we create a row which contains the actual values for the grouping
  // key, but all nulls for the value side, by design. This effectively makes the input schema
  // "nullable", so the schema information and the internal row projection must take this into
  // account. Strictly speaking, this does not apply to the grouping key part, but applying the
  // same relaxation to the grouping key does no harm.
```
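
To make the failure mode concrete, here is a minimal sketch (not part of the PR) using Spark's internal Catalyst APIs; the attribute name `val1` and the single-column schema are made up for illustration, and the exact generated-code behavior may vary across Spark versions:

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericInternalRow, UnsafeProjection}
import org.apache.spark.sql.types.StringType

// An input attribute declared non-nullable, e.g. a column produced by lit("...").
val attr = AttributeReference("val1", StringType, nullable = false)()

// processTimedOutState fills the value side with nulls, contradicting that schema.
val rowWithNulls = new GenericInternalRow(Array[Any](null))

// With the original non-nullable schema, the generated projection skips the
// null check and dereferences the null value.
val brokenProj = UnsafeProjection.create(Seq(attr), Seq(attr))
// brokenProj(rowWithNulls)  // throws NullPointerException before the fix

// The fix relaxes nullability before creating the projection, so the null is
// written out as a proper null field.
val nullableAttr = attr.withNullability(newNullability = true)
val fixedProj = UnsafeProjection.create(Seq(nullableAttr), Seq(nullableAttr))
fixedProj(rowWithNulls)  // row with a null field, as intended
```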

### Why are the changes needed?

There is a bug where non-nullable columns in the input schema were not taken into account, leading to an NPE. This PR fixes that bug.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New unit test. The new test case failed with an NPE without the fix and succeeded with it.
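
For reference, the premise of the new test — that columns built with `lit(...)` are non-nullable — can be checked with a snippet like the following (a sketch assuming a `spark` session is in scope, e.g. in spark-shell):

```scala
import org.apache.spark.sql.functions.lit

// A non-null string literal yields a non-nullable column, which is exactly the
// input shape that used to make the operator's projection NPE on timed-out state.
val df = spark.range(1).toDF("key1").withColumn("key2", lit("__FAKE__"))
assert(!df.schema("key2").nullable)  // lit("__FAKE__") => nullable = false
```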

Closes apache#38115 from HeartSaVioR/SPARK-40670.

Authored-by: Jungtaek Lim <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
HeartSaVioR committed Oct 6, 2022
1 parent 5523972 commit edd6076
Showing 2 changed files with 99 additions and 4 deletions.
```diff
@@ -83,7 +83,17 @@ case class FlatMapGroupsInPandasWithStateExec(
   private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction)))
   private lazy val (dedupAttributes, argOffsets) = resolveArgOffsets(
     groupingAttributes ++ child.output, groupingAttributes)
-  private lazy val unsafeProj = UnsafeProjection.create(dedupAttributes, child.output)
+
+  // See processTimedOutState: we create a row which contains the actual values for the grouping
+  // key, but all nulls for the value side, by design. This effectively makes the input schema
+  // "nullable", so the schema information and the internal row projection must take this into
+  // account. Strictly speaking, this does not apply to the grouping key part, but applying the
+  // same relaxation to the grouping key does no harm.
+  private lazy val dedupAttributesWithNull =
+    dedupAttributes.map(_.withNullability(newNullability = true))
+  private lazy val childOutputWithNull = child.output.map(_.withNullability(newNullability = true))
+  private lazy val unsafeProj = UnsafeProjection.create(dedupAttributesWithNull,
+    childOutputWithNull)
 
   override def requiredChildDistribution: Seq[Distribution] =
     StatefulOperatorPartitioning.getCompatibleDistribution(
@@ -134,7 +144,7 @@ case class FlatMapGroupsInPandasWithStateExec(
         val joinedKeyRow = unsafeProj(
           new JoinedRow(
             stateData.keyRow,
-            new GenericInternalRow(Array.fill(dedupAttributes.length)(null: Any))))
+            new GenericInternalRow(Array.fill(dedupAttributesWithNull.length)(null: Any))))
 
         (stateData.keyRow, stateData, Iterator.single(joinedKeyRow))
       }
@@ -150,7 +160,7 @@
       chainedFunc,
       PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
       Array(argOffsets),
-      StructType.fromAttributes(dedupAttributes),
+      StructType.fromAttributes(dedupAttributesWithNull),
       sessionLocalTimeZone,
       pythonRunnerConf,
       stateEncoder.asInstanceOf[ExpressionEncoder[Row]],
```
```diff
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{NoTimeout, ProcessingTimeTimeout}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Complete, Update}
 import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec
 import org.apache.spark.sql.execution.streaming.MemoryStream
-import org.apache.spark.sql.functions.timestamp_seconds
+import org.apache.spark.sql.functions.{lit, timestamp_seconds}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.util.StreamManualClock
 import org.apache.spark.sql.types._
@@ -738,4 +738,89 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest {
       }
     }
   }
+
+  test("SPARK-40670: applyInPandasWithState - streaming having non-null columns") {
+    // scalastyle:off assume
+    assume(shouldTestPandasUDFs)
+    // scalastyle:on assume
+
+    // Function to maintain the count as state and set a processing-time timeout delay of
+    // 10 seconds. It returns the count if changed, or -1 if the state was removed by timeout.
+    val pythonScript =
+      """
+        |import pandas as pd
+        |from pyspark.sql.types import StructType, StructField, StringType
+        |
+        |tpe = StructType([
+        |    StructField("key1", StringType()),
+        |    StructField("key2", StringType()),
+        |    StructField("countAsStr", StringType())])
+        |
+        |def func(key, pdf_iter, state):
+        |    ret = None
+        |    if state.hasTimedOut:
+        |        state.remove()
+        |        yield pd.DataFrame({'key1': [key[0]], 'key2': [key[1]], 'countAsStr': [str(-1)]})
+        |    else:
+        |        count = state.getOption
+        |        if count is None:
+        |            count = 0
+        |        else:
+        |            count = count[0]
+        |
+        |        for pdf in pdf_iter:
+        |            count += len(pdf)
+        |
+        |        state.update((count,))
+        |        state.setTimeoutDuration(10000)
+        |        yield pd.DataFrame({'key1': [key[0]], 'key2': [key[1]], 'countAsStr': [str(count)]})
+        |""".stripMargin
+    val pythonFunc = TestGroupedMapPandasUDFWithState(
+      name = "pandas_grouped_map_with_state", pythonScript = pythonScript)
+
+    val clock = new StreamManualClock
+    val inputData = MemoryStream[String]
+    val inputDataDS = inputData.toDS
+      .withColumnRenamed("value", "key1")
+      // columns built from a non-null string literal are non-nullable
+      .withColumn("key2", lit("__FAKE__"))
+      .withColumn("val1", lit("__FAKE__"))
+      .withColumn("val2", lit("__FAKE__"))
+    val outputStructType = StructType(
+      Seq(
+        StructField("key1", StringType),
+        StructField("key2", StringType),
+        StructField("countAsStr", StringType)))
+    val stateStructType = StructType(Seq(StructField("count", LongType)))
+    val result =
+      inputDataDS
+        .groupBy("key1", "key2")
+        .applyInPandasWithState(
+          pythonFunc(
+            inputDataDS("key1"), inputDataDS("key2"), inputDataDS("val1"), inputDataDS("val2")
+          ).expr.asInstanceOf[PythonUDF],
+          outputStructType,
+          stateStructType,
+          "Update",
+          "ProcessingTimeTimeout")
+
+    testStream(result, Update)(
+      StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
+      AddData(inputData, "a"),
+      AdvanceManualClock(1 * 1000),
+      CheckNewAnswer(("a", "__FAKE__", "1")),
+      assertNumStateRows(total = 1, updated = 1),
+
+      AddData(inputData, "b"),
+      AdvanceManualClock(1 * 1000),
+      CheckNewAnswer(("b", "__FAKE__", "1")),
+      assertNumStateRows(total = 2, updated = 1),
+
+      AddData(inputData, "b"),
+      AdvanceManualClock(10 * 1000),
+      CheckNewAnswer(("a", "__FAKE__", "-1"), ("b", "__FAKE__", "2")),
+      assertNumStateRows(
+        total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed = Some(Seq(1)))
+    )
+  }
 }
```
