From 436ce5f3de3321162694a08c377f2de24e4741c3 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Sun, 27 Nov 2022 11:01:32 +0900 Subject: [PATCH] [SPARK-41261][PYTHON][SS] Fix issue for applyInPandasWithState when the columns of grouping keys are not placed in order from earliest ### What changes were proposed in this pull request? This PR fixes the issue for applyInPandasWithState, which is triggered when the columns of grouping keys are not placed in order from the earliest position. If the condition is met, the user function may get an "incorrect" value of the key, including `None`. This is because the projection for the value is shared between the normal input row and the row for timed-out state. The projection assumed that the schema for the row is the same as the output schema of the child node, whereas the row for timed-out state is constructed by concatenating the key row and a null value row. This PR creates a separate projection for the row for timed-out state, so that the projection can pick up the values for grouping columns correctly. ### Why are the changes needed? Without this fix, the user function may get an "incorrect" value of the key, including `None`. ### Does this PR introduce _any_ user-facing change? No. This feature is not released yet. ### How was this patch tested? New test case. Closes #38798 from HeartSaVioR/SPARK-41261. 
Authored-by: Jungtaek Lim Signed-off-by: Hyukjin Kwon --- .../FlatMapGroupsInPandasWithStateExec.scala | 17 ++-- .../FlatMapGroupsInPandasWithStateSuite.scala | 83 +++++++++++++++++++ 2 files changed, 95 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala index bc1a5ae17e4d5..1c7ccaf4fe2a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -97,6 +97,15 @@ case class FlatMapGroupsInPandasWithStateExec( private lazy val unsafeProj = UnsafeProjection.create(dedupAttributesWithNull, childOutputWithNull) + // See processTimedOutState: we create a row which contains the actual values for grouping key, + // but all nulls for value side by intention. The schema for this row is different from + // child.output, hence we should create another projection to deal with such schema. 
+ private lazy val valueAttributesWithNull = childOutputWithNull.filterNot { attr => + groupingAttributes.exists(_.withNullability(newNullability = true) == attr) + } + private lazy val unsafeProjForTimedOut = UnsafeProjection.create(dedupAttributesWithNull, + groupingAttributes ++ valueAttributesWithNull) + override def requiredChildDistribution: Seq[Distribution] = StatefulOperatorPartitioning.getCompatibleDistribution( groupingAttributes, getStateInfo, conf) :: Nil @@ -142,12 +151,10 @@ case class FlatMapGroupsInPandasWithStateExec( state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold } + val emptyValueRow = new GenericInternalRow( + Array.fill(valueAttributesWithNull.length)(null: Any)) val processIter = timingOutPairs.map { stateData => - val joinedKeyRow = unsafeProj( - new JoinedRow( - stateData.keyRow, - new GenericInternalRow(Array.fill(dedupAttributesWithNull.length)(null: Any)))) - + val joinedKeyRow = unsafeProjForTimedOut(new JoinedRow(stateData.keyRow, emptyValueRow)) (stateData.keyRow, stateData, Iterator.single(joinedKeyRow)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala index 785cbec9805b0..a83f7cce4c1cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala @@ -803,4 +803,87 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed = Some(Seq(1))) ) } + + test("SPARK-41261: applyInPandasWithState - key in user function should be correct for " + + "timed out state despite of place for key columns") { + assume(shouldTestPandasUDFs) + + // Function to maintain the count as state and set the proc. 
time timeout delay of 10 seconds. + // It returns the count if changed, or -1 if the state was removed by timeout. + val pythonScript = + """ + |import pandas as pd + |from pyspark.sql.types import StructType, StructField, StringType + | + |tpe = StructType([ + | StructField("key1", StringType()), + | StructField("key2", StringType()), + | StructField("countAsStr", StringType())]) + | + |def func(key, pdf_iter, state): + | ret = None + | if state.hasTimedOut: + | state.remove() + | yield pd.DataFrame({'key1': [key[0]], 'key2': [key[1]], 'countAsStr': [str(-1)]}) + | else: + | count = state.getOption + | if count is None: + | count = 0 + | else: + | count = count[0] + | + | for pdf in pdf_iter: + | count += len(pdf) + | + | state.update((count, )) + | state.setTimeoutDuration(10000) + | yield pd.DataFrame({'key1': [key[0]], 'key2': [key[1]], 'countAsStr': [str(count)]}) + |""".stripMargin + val pythonFunc = TestGroupedMapPandasUDFWithState( + name = "pandas_grouped_map_with_state", pythonScript = pythonScript) + + val clock = new StreamManualClock + val inputData = MemoryStream[String] + // schema: val1, key2, val2, key1, val3 + val inputDataDS = inputData.toDS + .withColumnRenamed("value", "val1") + .withColumn("key2", $"val1") + // the type of columns with string literal will be non-nullable + .withColumn("val2", lit("__FAKE__")) + .withColumn("key1", lit("__FAKE__")) + .withColumn("val3", lit("__FAKE__")) + val outputStructType = StructType( + Seq( + StructField("key1", StringType), + StructField("key2", StringType), + StructField("countAsStr", StringType))) + val stateStructType = StructType(Seq(StructField("count", LongType))) + val result = + inputDataDS + // grouping columns: key1, key2 (swapped order) + .groupBy("key1", "key2") + .applyInPandasWithState( + pythonFunc( + inputDataDS("val1"), inputDataDS("key2"), inputDataDS("val2"), inputDataDS("key1"), + inputDataDS("val3") + ).expr.asInstanceOf[PythonUDF], + outputStructType, + stateStructType, + "Update", 
+ "ProcessingTimeTimeout") + + testStream(result, Update)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputData, "a"), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("__FAKE__", "a", "1")), + assertNumStateRows(total = 1, updated = 1), + + AddData(inputData, "b"), + AdvanceManualClock(11 * 1000), + CheckNewAnswer(("__FAKE__", "a", "-1"), ("__FAKE__", "b", "1")), + assertNumStateRows( + total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed = Some(Seq(1))) + ) + } }