diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
index bc1a5ae17e4d5..1c7ccaf4fe2a4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
@@ -97,6 +97,15 @@ case class FlatMapGroupsInPandasWithStateExec(
   private lazy val unsafeProj = UnsafeProjection.create(dedupAttributesWithNull,
     childOutputWithNull)
 
+  // See processTimedOutState: there we build a row which contains the actual values for the
+  // grouping key but intentionally all nulls on the value side. The schema of this row differs
+  // from child.output, hence we need a separate projection to handle it.
+  private lazy val valueAttributesWithNull = childOutputWithNull.filterNot { attr =>
+    groupingAttributes.exists(_.withNullability(newNullability = true) == attr)
+  }
+  private lazy val unsafeProjForTimedOut = UnsafeProjection.create(dedupAttributesWithNull,
+    groupingAttributes ++ valueAttributesWithNull)
+
   override def requiredChildDistribution: Seq[Distribution] =
     StatefulOperatorPartitioning.getCompatibleDistribution(
       groupingAttributes, getStateInfo, conf) :: Nil
@@ -142,12 +151,10 @@ case class FlatMapGroupsInPandasWithStateExec(
       state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold
     }
 
+    val emptyValueRow = new GenericInternalRow(
+      Array.fill(valueAttributesWithNull.length)(null: Any))
     val processIter = timingOutPairs.map { stateData =>
-      val joinedKeyRow = unsafeProj(
-        new JoinedRow(
-          stateData.keyRow,
-          new GenericInternalRow(Array.fill(dedupAttributesWithNull.length)(null: Any))))
-
+      val joinedKeyRow = unsafeProjForTimedOut(new JoinedRow(stateData.keyRow, emptyValueRow))
       (stateData.keyRow, stateData, Iterator.single(joinedKeyRow))
     }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala
index 785cbec9805b0..a83f7cce4c1cc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala
@@ -803,4 +803,87 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest {
         total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed = Some(Seq(1)))
     )
   }
+
+  test("SPARK-41261: applyInPandasWithState - key in user function should be correct for " +
+    "timed out state regardless of the position of key columns") {
+    assume(shouldTestPandasUDFs)
+
+    // Function to maintain the count as state and set a processing-time timeout of 10 seconds.
+    // It returns the count if changed, or -1 if the state was removed by timeout.
+    val pythonScript =
+      """
+        |import pandas as pd
+        |from pyspark.sql.types import StructType, StructField, StringType
+        |
+        |tpe = StructType([
+        |    StructField("key1", StringType()),
+        |    StructField("key2", StringType()),
+        |    StructField("countAsStr", StringType())])
+        |
+        |def func(key, pdf_iter, state):
+        |    ret = None
+        |    if state.hasTimedOut:
+        |        state.remove()
+        |        yield pd.DataFrame({'key1': [key[0]], 'key2': [key[1]], 'countAsStr': [str(-1)]})
+        |    else:
+        |        count = state.getOption
+        |        if count is None:
+        |            count = 0
+        |        else:
+        |            count = count[0]
+        |
+        |        for pdf in pdf_iter:
+        |            count += len(pdf)
+        |
+        |        state.update((count,))
+        |        state.setTimeoutDuration(10000)
+        |        yield pd.DataFrame({'key1': [key[0]], 'key2': [key[1]], 'countAsStr': [str(count)]})
+        |""".stripMargin
+    val pythonFunc = TestGroupedMapPandasUDFWithState(
+      name = "pandas_grouped_map_with_state", pythonScript = pythonScript)
+
+    val clock = new StreamManualClock
+    val inputData = MemoryStream[String]
+    // schema: val1, key2, val2, key1, val3
+    val inputDataDS = inputData.toDS
+      .withColumnRenamed("value", "val1")
+      .withColumn("key2", $"val1")
+      // columns created from string literals are non-nullable
+      .withColumn("val2", lit("__FAKE__"))
+      .withColumn("key1", lit("__FAKE__"))
+      .withColumn("val3", lit("__FAKE__"))
+    val outputStructType = StructType(
+      Seq(
+        StructField("key1", StringType),
+        StructField("key2", StringType),
+        StructField("countAsStr", StringType)))
+    val stateStructType = StructType(Seq(StructField("count", LongType)))
+    val result =
+      inputDataDS
+        // grouping columns: key1, key2 (reversed relative to their order in the child output)
+        .groupBy("key1", "key2")
+        .applyInPandasWithState(
+          pythonFunc(
+            inputDataDS("val1"), inputDataDS("key2"), inputDataDS("val2"), inputDataDS("key1"),
+            inputDataDS("val3")
+          ).expr.asInstanceOf[PythonUDF],
+          outputStructType,
+          stateStructType,
+          "Update",
+          "ProcessingTimeTimeout")
+
+    testStream(result, Update)(
+      StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
+      AddData(inputData, "a"),
+      AdvanceManualClock(1 * 1000),
+      CheckNewAnswer(("__FAKE__", "a", "1")),
+      assertNumStateRows(total = 1, updated = 1),
+
+      AddData(inputData, "b"),
+      AdvanceManualClock(11 * 1000),
+      CheckNewAnswer(("__FAKE__", "a", "-1"), ("__FAKE__", "b", "1")),
+      assertNumStateRows(
+        total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed = Some(Seq(1)))
+    )
+  }
 }
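
A note on the root cause, for reviewers: unsafeProj binds its expressions by ordinal against the childOutputWithNull schema, while the timed-out path feeds it a JoinedRow laid out as grouping key first, null values second. Whenever the grouping columns are not the leading columns of the child output, the old projection therefore read the key from the wrong ordinals. The standalone sketch below is not part of the patch; it assumes only spark-catalyst on the classpath, uses attribute names that mirror the test but are otherwise illustrative, and simplifies the patch's dedupAttributesWithNull to just the grouping attributes. It demonstrates the ordinal-binding behavior that the new unsafeProjForTimedOut accounts for.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericInternalRow, JoinedRow, UnsafeProjection}
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

object ProjectionOrderSketch extends App {
  // Child output order: val1, key2, val2, key1 -- the grouping keys are not the leading columns.
  val val1 = AttributeReference("val1", StringType)()
  val key2 = AttributeReference("key2", StringType)()
  val val2 = AttributeReference("val2", StringType)()
  val key1 = AttributeReference("key1", StringType)()
  val childOutput = Seq(val1, key2, val2, key1)
  val groupingAttrs = Seq(key1, key2)

  // The timed-out path builds JoinedRow(keyRow, all-null values): the row's layout is
  // groupingAttrs ++ valueAttrs, not childOutput.
  val keyRow = InternalRow(UTF8String.fromString("k1"), UTF8String.fromString("k2"))
  val joined = new JoinedRow(keyRow, new GenericInternalRow(Array.fill[Any](2)(null)))

  // Before the fix: binding against childOutput resolves key1 to ordinal 3, a null slot here.
  val boundToChildOutput = UnsafeProjection.create(groupingAttrs, childOutput)
  assert(boundToChildOutput(joined).isNullAt(0)) // key1 silently comes back as null

  // After the fix: bind against the layout the row actually has.
  val boundToKeyFirstLayout =
    UnsafeProjection.create(groupingAttrs, groupingAttrs ++ Seq(val1, val2))
  assert(boundToKeyFirstLayout(joined).getUTF8String(0).toString == "k1")
}

The patch applies the same reasoning with UnsafeProjection.create(dedupAttributesWithNull, groupingAttributes ++ valueAttributesWithNull), whose input schema matches the JoinedRow(keyRow, emptyValueRow) layout exactly.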