diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 2b9d526937942..f9e2144d334e3 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -452,6 +452,7 @@ def __hash__(self): "pyspark.sql.tests.test_group", "pyspark.sql.tests.test_pandas_cogrouped_map", "pyspark.sql.tests.test_pandas_grouped_map", + "pyspark.sql.tests.test_pandas_grouped_map_with_state", "pyspark.sql.tests.test_pandas_map", "pyspark.sql.tests.test_arrow_map", "pyspark.sql.tests.test_pandas_udf", diff --git a/python/pyspark/sql/tests/test_pandas_grouped_map_with_state.py b/python/pyspark/sql/tests/test_pandas_grouped_map_with_state.py new file mode 100644 index 0000000000000..7eb3bb92b843e --- /dev/null +++ b/python/pyspark/sql/tests/test_pandas_grouped_map_with_state.py @@ -0,0 +1,103 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import unittest +from typing import cast + +from pyspark.sql.streaming.state import GroupStateTimeout, GroupState +from pyspark.sql.types import ( + LongType, + StringType, + StructType, + StructField, + Row, +) +from pyspark.testing.sqlutils import ( + ReusedSQLTestCase, + have_pandas, + have_pyarrow, + pandas_requirement_message, + pyarrow_requirement_message, +) + +if have_pandas: + import pandas as pd + +if have_pyarrow: + import pyarrow as pa # noqa: F401 + + +@unittest.skipIf( + not have_pandas or not have_pyarrow, + cast(str, pandas_requirement_message or pyarrow_requirement_message), +) +class GroupedMapInPandasWithStateTests(ReusedSQLTestCase): + def test_apply_in_pandas_with_state_basic(self): + df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") + + for q in self.spark.streams.active: + q.stop() + self.assertTrue(df.isStreaming) + + output_type = StructType( + [StructField("key", StringType()), StructField("countAsString", StringType())] + ) + state_type = StructType([StructField("c", LongType())]) + + def func(key, pdf_iter, state): + assert isinstance(state, GroupState) + + total_len = 0 + for pdf in pdf_iter: + total_len += len(pdf) + + state.update((total_len,)) + assert state.get[0] == 1 + yield pd.DataFrame({"key": [key[0]], "countAsString": [str(total_len)]}) + + def check_results(batch_df, _): + self.assertEqual( + set(batch_df.collect()), + {Row(key="hello", countAsString="1"), Row(key="this", countAsString="1")}, + ) + + q = ( + df.groupBy(df["value"]) + .applyInPandasWithState( + func, output_type, state_type, "Update", GroupStateTimeout.NoTimeout + ) + .writeStream.queryName("this_query") + .foreachBatch(check_results) + .outputMode("update") + .start() + ) + + self.assertEqual(q.name, "this_query") + self.assertTrue(q.isActive) + q.processAllAvailable() + + +if __name__ == "__main__": + from pyspark.sql.tests.test_pandas_grouped_map_with_state import * # noqa: F401 + + try: + import xmlrunner # type: 
ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala index 827cfcf32fead..9e97e53d9e835 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import java.nio.charset.StandardCharsets import java.nio.file.{Files, Paths} import scala.collection.JavaConverters._ @@ -31,7 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ExprId, Pyth import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.execution.python.UserDefinedPythonFunction import org.apache.spark.sql.expressions.SparkUserDefinedFunction -import org.apache.spark.sql.types.{DataType, IntegerType, StringType} +import org.apache.spark.sql.types.{DataType, IntegerType, NullType, StringType} /** * This object targets to integrate various UDF test cases so that Scalar UDF, Python UDF, @@ -190,7 +191,7 @@ object IntegratedUDFTestUtils extends SQLHelper { throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.") } - private lazy val pandasFunc: Array[Byte] = if (shouldTestScalarPandasUDFs) { + private lazy val pandasFunc: Array[Byte] = if (shouldTestPandasUDFs) { var binaryPandasFunc: Array[Byte] = null withTempPath { path => Process( @@ -213,7 +214,7 @@ object IntegratedUDFTestUtils extends SQLHelper { throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.") } - private lazy val pandasGroupedAggFunc: Array[Byte] = if (shouldTestGroupedAggPandasUDFs) { + private lazy val pandasGroupedAggFunc: Array[Byte] = if (shouldTestPandasUDFs) { var 
binaryPandasFunc: Array[Byte] = null withTempPath { path => Process( @@ -235,6 +236,33 @@ object IntegratedUDFTestUtils extends SQLHelper { throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.") } + private def createPandasGroupedMapFuncWithState(pythonScript: String): Array[Byte] = { + if (shouldTestPandasUDFs) { + var binaryPandasFunc: Array[Byte] = null + withTempPath { codePath => + Files.write(codePath.toPath, pythonScript.getBytes(StandardCharsets.UTF_8)) + withTempPath { path => + Process( + Seq( + pythonExec, + "-c", + "from pyspark.serializers import CloudPickleSerializer; " + + s"f = open('$path', 'wb');" + + s"exec(open('$codePath', 'r').read());" + + "f.write(CloudPickleSerializer().dumps((" + + "func, tpe)))"), + None, + "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!! + binaryPandasFunc = Files.readAllBytes(path.toPath) + } + } + assert(binaryPandasFunc != null) + binaryPandasFunc + } else { + throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.") + } + } + // Make sure this map stays mutable - this map gets updated later in Python runners. private val workerEnv = new java.util.HashMap[String, String]() workerEnv.put("PYTHONPATH", s"$pysparkPythonPath:$pythonPath") @@ -251,11 +279,9 @@ object IntegratedUDFTestUtils extends SQLHelper { lazy val shouldTestPythonUDFs: Boolean = isPythonAvailable && isPySparkAvailable - lazy val shouldTestScalarPandasUDFs: Boolean = + lazy val shouldTestPandasUDFs: Boolean = isPythonAvailable && isPandasAvailable && isPyArrowAvailable - lazy val shouldTestGroupedAggPandasUDFs: Boolean = shouldTestScalarPandasUDFs - /** * A base trait for various UDFs defined in this object. */ @@ -420,6 +446,41 @@ object IntegratedUDFTestUtils extends SQLHelper { val prettyName: String = "Grouped Aggregate Pandas UDF" } + /** + * Arbitrary stateful processing in Python is used for + * `DataFrame.groupBy.applyInPandasWithState`. 
It requires `pythonScript` to + * define `func` (Python function) and `tpe` (`StructType` for state key). + * + * Virtually equivalent to: + * + * {{{ + * # exec defines 'func' and 'tpe' (struct type for state key) + * exec(pythonScript) + * + * # ... are filled when this UDF is invoked, see also 'PythonFlatMapGroupsWithStateSuite'. + * df.groupBy(...).applyInPandasWithState(func, ..., tpe, ..., ...) + * }}} + */ + case class TestGroupedMapPandasUDFWithState(name: String, pythonScript: String) extends TestUDF { + private[IntegratedUDFTestUtils] lazy val udf = new UserDefinedPythonFunction( + name = name, + func = SimplePythonFunction( + command = createPandasGroupedMapFuncWithState(pythonScript), + envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]], + pythonIncludes = List.empty[String].asJava, + pythonExec = pythonExec, + pythonVer = pythonVer, + broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava, + accumulator = null), + dataType = NullType, // This is not respected. + pythonEvalType = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, + udfDeterministic = true) + + def apply(exprs: Column*): Column = udf(exprs: _*) + + val prettyName: String = "Grouped Map Pandas UDF with State" + } + /** * A Scala UDF that takes one column, casts into string, executes the * Scala native function, and casts back to the type of input column. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 83980535f878c..1585e6342ec28 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -251,14 +251,14 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper /* Do nothing */ } case udfTestCase: UDFTest - if udfTestCase.udf.isInstanceOf[TestScalarPandasUDF] && !shouldTestScalarPandasUDFs => + if udfTestCase.udf.isInstanceOf[TestScalarPandasUDF] && !shouldTestPandasUDFs => ignore(s"${testCase.name} is skipped because pyspark," + s"pandas and/or pyarrow were not available in [$pythonExec].") { /* Do nothing */ } case udfTestCase: UDFTest if udfTestCase.udf.isInstanceOf[TestGroupedAggPandasUDF] && - !shouldTestGroupedAggPandasUDFs => + !shouldTestPandasUDFs => ignore(s"${testCase.name} is skipped because pyspark," + s"pandas and/or pyarrow were not available in [$pythonExec].") { /* Do nothing */ @@ -447,12 +447,12 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper if udfTestCase.udf.isInstanceOf[TestPythonUDF] && shouldTestPythonUDFs => s"${testCase.name}${System.lineSeparator()}Python: $pythonVer${System.lineSeparator()}" case udfTestCase: UDFTest - if udfTestCase.udf.isInstanceOf[TestScalarPandasUDF] && shouldTestScalarPandasUDFs => + if udfTestCase.udf.isInstanceOf[TestScalarPandasUDF] && shouldTestPandasUDFs => s"${testCase.name}${System.lineSeparator()}" + s"Python: $pythonVer Pandas: $pandasVer PyArrow: $pyarrowVer${System.lineSeparator()}" case udfTestCase: UDFTest if udfTestCase.udf.isInstanceOf[TestGroupedAggPandasUDF] && - shouldTestGroupedAggPandasUDFs => + shouldTestPandasUDFs => s"${testCase.name}${System.lineSeparator()}" + s"Python: $pythonVer Pandas: $pandasVer PyArrow: $pyarrowVer${System.lineSeparator()}" case _ => diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala index 00c774e2d1bee..92aadb6779e9e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala @@ -128,7 +128,7 @@ class QueryCompilationErrorsSuite test("INVALID_PANDAS_UDF_PLACEMENT: Using aggregate function with grouped aggregate pandas UDF") { import IntegratedUDFTestUtils._ - assume(shouldTestGroupedAggPandasUDFs) + assume(shouldTestPandasUDFs) val df = Seq( (536361, "85123A", 2, 17850), @@ -180,7 +180,7 @@ class QueryCompilationErrorsSuite test("UNSUPPORTED_FEATURE: Using pandas UDF aggregate expression with pivot") { import IntegratedUDFTestUtils._ - assume(shouldTestGroupedAggPandasUDFs) + assume(shouldTestPandasUDFs) val df = Seq( (536361, "85123A", 2, 17850), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala index 4ad7f90105373..42e4b1accde72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala @@ -73,7 +73,7 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession { } test("SPARK-39962: Global aggregation of Pandas UDF should respect the column order") { - assume(shouldTestGroupedAggPandasUDFs) + assume(shouldTestPandasUDFs) val df = Seq[(java.lang.Integer, java.lang.Integer)]((1, null)).toDF("a", "b") val pandasTestUDF = TestGroupedAggPandasUDF(name = "pandas_udf") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateDistributionSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateDistributionSuite.scala new file mode 100644 index 0000000000000..985e45602a2ba --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateDistributionSuite.scala @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.sql.IntegratedUDFTestUtils.{shouldTestPandasUDFs, TestGroupedMapPandasUDFWithState} +import org.apache.spark.sql.catalyst.expressions.PythonUDF +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Update +import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.streaming.util.{StatefulOpClusteredDistributionTestHelper, StreamManualClock} +import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType} + +class FlatMapGroupsInPandasWithStateDistributionSuite extends StreamTest + with StatefulOpClusteredDistributionTestHelper { + + import testImplicits._ + + test("applyInPandasWithState should require StatefulOpClusteredDistribution " + + "from children - without initial state") { + // scalastyle:off assume + assume(shouldTestPandasUDFs) + // scalastyle:on assume + + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count if state is defined, otherwise does not return anything + val pythonScript = + """ + |import pandas as pd + |from pyspark.sql.types import StructType, StructField, StringType, IntegerType + | + |tpe = StructType([ + | StructField("key1", StringType()), + | StructField("key2", StringType()), + | StructField("count", IntegerType())]) + | + |def func(key, pdf_iter, state): + | count = state.getOption + | if count is None: + | count = 0 + | else: + | count = count[0] + | + | for pdf in pdf_iter: + | count += len(pdf) + | state.update((count,)) + | + | if count >= 3: + | state.remove() + | yield pd.DataFrame() + | else: + | yield pd.DataFrame({'key1': [key[0]], 'key2': [key[1]], 'count': [count]}) + |""".stripMargin + val pythonFunc = TestGroupedMapPandasUDFWithState( + name = "pandas_grouped_map_with_state", pythonScript = pythonScript) + + val inputData = 
MemoryStream[(String, String, Long)] + val outputStructType = StructType( + Seq( + StructField("key1", StringType), + StructField("key2", StringType), + StructField("count", IntegerType))) + val stateStructType = StructType(Seq(StructField("count", LongType))) + val inputDataDS = inputData.toDS().toDF("key1", "key2", "time") + .selectExpr("key1", "key2", "timestamp_seconds(time) as timestamp") + val result = + inputDataDS + .withWatermark("timestamp", "10 second") + .repartition($"key1") + .groupBy($"key1", $"key2") + .applyInPandasWithState( + pythonFunc(inputDataDS("key1"), inputDataDS("key2"), inputDataDS("timestamp")) + .expr.asInstanceOf[PythonUDF], + outputStructType, + stateStructType, + "Update", + "NoTimeout") + .select("key1", "key2", "count") + + val clock = new StreamManualClock + testStream(result, Update)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputData, ("a", "a", 1L)), + AdvanceManualClock(1 * 1000), // a is processed here for the first time. 
+ CheckNewAnswer(("a", "a", 1)), + Execute { query => + val numPartitions = query.lastExecution.numStateStores + + val flatMapGroupsInPandasWithStateExecs = query.lastExecution.executedPlan.collect { + case f: FlatMapGroupsInPandasWithStateExec => f + } + + assert(flatMapGroupsInPandasWithStateExecs.length === 1) + assert(requireStatefulOpClusteredDistribution( + flatMapGroupsInPandasWithStateExecs.head, Seq(Seq("key1", "key2")), numPartitions)) + assert(hasDesiredHashPartitioningInChildren( + flatMapGroupsInPandasWithStateExecs.head, Seq(Seq("key1", "key2")), numPartitions)) + } + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala new file mode 100644 index 0000000000000..d8f7aeb5ac886 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala @@ -0,0 +1,741 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.sql.IntegratedUDFTestUtils._ +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.PythonUDF +import org.apache.spark.sql.catalyst.plans.logical.{NoTimeout, ProcessingTimeTimeout} +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Complete, Update} +import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.functions.timestamp_seconds +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.util.StreamManualClock +import org.apache.spark.sql.types._ + +class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { + + import testImplicits._ + + test("applyInPandasWithState - streaming") { + // scalastyle:off assume + assume(shouldTestPandasUDFs) + // scalastyle:on assume + + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count if state is defined, otherwise does not return anything + val pythonScript = + """ + |import pandas as pd + |from pyspark.sql.types import StructType, StructField, StringType + | + |tpe = StructType([ + | StructField("key", StringType()), + | StructField("countAsString", StringType())]) + | + |def func(key, pdf_iter, state): + | assert state.getCurrentProcessingTimeMs() >= 0 + | try: + | state.getCurrentWatermarkMs() + | assert False + | except RuntimeError as e: + | assert "watermark" in str(e) + | + | count = state.getOption + | if count is None: + | count = 0 + | else: + | count = count[0] + | + | for pdf in pdf_iter: + | count += len(pdf) + | state.update((count,)) + | + | if count >= 3: + | state.remove() + | yield pd.DataFrame() + | else: + | yield pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) + |""".stripMargin + val pythonFunc = TestGroupedMapPandasUDFWithState( + name = 
"pandas_grouped_map_with_state", pythonScript = pythonScript) + + val inputData = MemoryStream[String] + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("countAsString", StringType))) + val stateStructType = StructType(Seq(StructField("count", LongType))) + val inputDataDS = inputData.toDS() + val result = + inputDataDS + .groupBy("value") + .applyInPandasWithState( + pythonFunc(inputDataDS("value")).expr.asInstanceOf[PythonUDF], + outputStructType, + stateStructType, + "Update", + "NoTimeout") + + testStream(result, Update)( + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + assertNumStateRows(total = 1, updated = 1), + AddData(inputData, "a", "b"), + CheckNewAnswer(("a", "2"), ("b", "1")), + assertNumStateRows(total = 2, updated = 2), + StopStream, + StartStream(), + AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a + CheckNewAnswer(("b", "2")), + assertNumStateRows( + total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed = Some(Seq(1))), + StopStream, + StartStream(), + AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and + CheckNewAnswer(("a", "1"), ("c", "1")), + assertNumStateRows(total = 3, updated = 2) + ) + } + + test("applyInPandasWithState - streaming, multiple groups in partition, " + + "multiple outputs per grouping key") { + // scalastyle:off assume + assume(shouldTestPandasUDFs) + // scalastyle:on assume + + val pythonScript = + """ + |import pandas as pd + |from pyspark.sql.types import IntegerType, StructType, StructField, StringType + | + |tpe = StructType([ + | StructField("key", StringType()), + | StructField("value", IntegerType()), + | StructField("valueAsString", StringType()), + | StructField("prevCountAsString", StringType())]) + | + |def func(key, pdf_iter, state): + | prev_count = state.getOption + | if prev_count is None: + | prev_count = 0 + | else: + | prev_count = prev_count[0] + | + | count = 
prev_count + | for pdf in pdf_iter: + | count += len(pdf) + | yield pdf.assign(valueAsString=lambda x: x.value.apply(str), + | prevCountAsString=str(prev_count)) + | + | state.update((count,)) + |""".stripMargin + val pythonFunc = TestGroupedMapPandasUDFWithState( + name = "pandas_grouped_map_with_state", pythonScript = pythonScript) + + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val inputData = MemoryStream[(String, Int)] + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("value", IntegerType), + StructField("valueAsString", StringType), + StructField("prevCountAsString", StringType))) + val stateStructType = StructType(Seq(StructField("count", LongType))) + val inputDataDS = inputData.toDS().selectExpr("_1 AS key", "_2 AS value") + val result = + inputDataDS + .groupBy("key") + .applyInPandasWithState( + pythonFunc(inputDataDS("key"), inputDataDS("value")).expr.asInstanceOf[PythonUDF], + outputStructType, + stateStructType, + "Update", + "NoTimeout") + .select("key", "value", "valueAsString", "prevCountAsString") + + testStream(result, Update)( + AddData(inputData, ("a", 1)), + CheckNewAnswer(("a", 1, "1", "0")), + assertNumStateRows(total = 1, updated = 1), + AddData(inputData, ("a", 2), ("a", 3), ("b", 1)), + CheckNewAnswer(("a", 2, "2", "1"), ("a", 3, "3", "1"), ("b", 1, "1", "0")), + assertNumStateRows(total = 2, updated = 2), + StopStream, + StartStream(), + AddData(inputData, ("b", 2), ("c", 1), ("d", 1), ("e", 1)), + CheckNewAnswer(("b", 2, "2", "1"), ("c", 1, "1", "0"), ("d", 1, "1", "0"), + ("e", 1, "1", "0")), + assertNumStateRows(total = 5, updated = 4), + AddData(inputData, ("a", 4)), + CheckNewAnswer(("a", 4, "4", "3")) + ) + } + } + + test("applyInPandasWithState - streaming + aggregation") { + // scalastyle:off assume + assume(shouldTestPandasUDFs) + // scalastyle:on assume + + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count (-1 if 
count reached beyond 2 and state was just removed) + val pythonScript = + """ + |import pandas as pd + |from pyspark.sql.types import StructType, StructField, StringType + | + |tpe = StructType([ + | StructField("key", StringType()), + | StructField("countAsString", StringType())]) + | + |def func(key, pdf_iter, state): + | count = state.getOption + | if count is None: + | count = 0 + | else: + | count = count[0] + | + | for pdf in pdf_iter: + | count += len(pdf) + | + | state.update((count,)) + | + | ret = pd.DataFrame() + | if count >= 3: + | state.remove() + | yield pd.DataFrame({'key': [key[0]], 'countAsString': [str(-1)]}) + | else: + | yield pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) + |""".stripMargin + val pythonFunc = TestGroupedMapPandasUDFWithState( + name = "pandas_grouped_map_with_state", pythonScript = pythonScript) + + val inputData = MemoryStream[String] + val inputDataDS = inputData.toDS + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("countAsString", StringType))) + val stateStructType = StructType(Seq(StructField("count", LongType))) + val result = + inputDataDS + .groupBy("value") + .applyInPandasWithState( + pythonFunc(inputDataDS("value")).expr.asInstanceOf[PythonUDF], + outputStructType, + stateStructType, + "Append", + "NoTimeout") + .groupBy("key") + .count() + + testStream(result, Complete)( + AddData(inputData, "a"), + CheckNewAnswer(("a", 1)), + AddData(inputData, "a", "b"), + // mapGroups generates ("a", "2"), ("b", "1"); so increases counts of a and b by 1 + CheckNewAnswer(("a", 2), ("b", 1)), + StopStream, + StartStream(), + AddData(inputData, "a", "b"), + // mapGroups should remove state for "a" and generate ("a", "-1"), ("b", "2") ; + // so increment a and b by 1 + CheckNewAnswer(("a", 3), ("b", 2)), + StopStream, + StartStream(), + AddData(inputData, "a", "c"), + // mapGroups should recreate state for "a" and generate ("a", "1"), ("c", "1") ; + // so increment a and c 
by 1 + CheckNewAnswer(("a", 4), ("b", 2), ("c", 1)) + ) + } + + test("applyInPandasWithState - streaming with processing time timeout") { + // scalastyle:off assume + assume(shouldTestPandasUDFs) + // scalastyle:on assume + + // Function to maintain the count as state and set the proc. time timeout delay of 10 seconds. + // It returns the count if changed, or -1 if the state was removed by timeout. + val pythonScript = + """ + |import pandas as pd + |from pyspark.sql.types import StructType, StructField, StringType + | + |tpe = StructType([ + | StructField("key", StringType()), + | StructField("countAsString", StringType())]) + | + |def func(key, pdf_iter, state): + | assert state.getCurrentProcessingTimeMs() >= 0 + | try: + | state.getCurrentWatermarkMs() + | assert False + | except RuntimeError as e: + | assert "watermark" in str(e) + | + | ret = None + | if state.hasTimedOut: + | state.remove() + | yield pd.DataFrame({'key': [key[0]], 'countAsString': [str(-1)]}) + | else: + | count = state.getOption + | if count is None: + | count = 0 + | else: + | count = count[0] + | + | for pdf in pdf_iter: + | count += len(pdf) + | + | state.update((count,)) + | state.setTimeoutDuration(10000) + | yield pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) + |""".stripMargin + val pythonFunc = TestGroupedMapPandasUDFWithState( + name = "pandas_grouped_map_with_state", pythonScript = pythonScript) + + val clock = new StreamManualClock + val inputData = MemoryStream[String] + val inputDataDS = inputData.toDS + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("countAsString", StringType))) + val stateStructType = StructType(Seq(StructField("count", LongType))) + val result = + inputDataDS + .groupBy("value") + .applyInPandasWithState( + pythonFunc(inputDataDS("value")).expr.asInstanceOf[PythonUDF], + outputStructType, + stateStructType, + "Update", + "ProcessingTimeTimeout") + + testStream(result, Update)( + 
StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputData, "a"), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("a", "1")), + assertNumStateRows(total = 1, updated = 1), + + AddData(inputData, "b"), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("b", "1")), + assertNumStateRows(total = 2, updated = 1), + + AddData(inputData, "b"), + AdvanceManualClock(10 * 1000), + CheckNewAnswer(("a", "-1"), ("b", "2")), + assertNumStateRows( + total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed = Some(Seq(1))), + + StopStream, + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + + AddData(inputData, "c"), + AdvanceManualClock(11 * 1000), + CheckNewAnswer(("b", "-1"), ("c", "1")), + assertNumStateRows( + total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed = Some(Seq(1))), + + AdvanceManualClock(12 * 1000), + AssertOnQuery { _ => clock.getTimeMillis() == 35000 }, + Execute { q => + failAfter(streamingTimeout) { + while (q.lastProgress.timestamp != "1970-01-01T00:00:35.000Z") { + Thread.sleep(1) + } + } + }, + CheckNewAnswer(("c", "-1")), + assertNumStateRows( + total = Seq(0), updated = Seq(0), droppedByWatermark = Seq(0), removed = Some(Seq(1))) + ) + } + + test("applyInPandasWithState - streaming w/ event time timeout + watermark") { + // scalastyle:off assume + assume(shouldTestPandasUDFs) + // scalastyle:on assume + + // timestamp_seconds assumes the base timezone is UTC. However, the provided function + // localizes it. 
Therefore, this test assumes the timezone is in UTC + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") { + val pythonScript = + """ + |import calendar + |import os + |import datetime + |import pandas as pd + |from pyspark.sql.types import StructType, StringType, StructField, IntegerType + | + |tpe = StructType([ + | StructField("key", StringType()), + | StructField("maxEventTimeSec", IntegerType())]) + | + |def func(key, pdf_iter, state): + | assert state.getCurrentProcessingTimeMs() >= 0 + | assert state.getCurrentWatermarkMs() >= -1 + | + | timeout_delay_sec = 5 + | if state.hasTimedOut: + | state.remove() + | yield pd.DataFrame({'key': [key[0]], 'maxEventTimeSec': [-1]}) + | else: + | m = state.getOption + | if m is None: + | max_event_time_sec = 0 + | else: + | max_event_time_sec = m[0] + | + | for pdf in pdf_iter: + | pser = pdf.eventTime.apply( + | lambda dt: (int(calendar.timegm(dt.utctimetuple()) + dt.microsecond))) + | max_event_time_sec = int(max(pser.max(), max_event_time_sec)) + | + | state.update((max_event_time_sec,)) + | timeout_timestamp_sec = max_event_time_sec + timeout_delay_sec + | state.setTimeoutTimestamp(timeout_timestamp_sec * 1000) + | yield pd.DataFrame({'key': [key[0]], + | 'maxEventTimeSec': [max_event_time_sec]}) + |""".stripMargin + val pythonFunc = TestGroupedMapPandasUDFWithState( + name = "pandas_grouped_map_with_state", pythonScript = pythonScript) + + val inputData = MemoryStream[(String, Int)] + val inputDataDF = + inputData.toDF.select($"_1".as("key"), timestamp_seconds($"_2").as("eventTime")) + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("maxEventTimeSec", IntegerType))) + val stateStructType = StructType(Seq(StructField("maxEventTimeSec", LongType))) + val result = + inputDataDF + .withWatermark("eventTime", "10 seconds") + .groupBy("key") + .applyInPandasWithState( + pythonFunc(inputDataDF("key"), inputDataDF("eventTime")).expr.asInstanceOf[PythonUDF], + outputStructType, 
+ stateStructType, + "Update", + "EventTimeTimeout") + + testStream(result, Update)( + StartStream(), + + AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), + // Max event time = 15. Timeout timestamp for "a" = 15 + 5 = 20. Watermark = 15 - 10 = 5. + CheckNewAnswer(("a", 15)), // Output = max event time of a + + AddData(inputData, ("a", 4)), // Add data older than watermark for "a" + CheckNewAnswer(), // No output as data should get filtered by watermark + + AddData(inputData, ("a", 10)), // Add data newer than watermark for "a" + CheckNewAnswer(("a", 15)), // Max event time is still the same + // Timeout timestamp for "a" is still 20 as max event time for "a" is still 15. + // Watermark is still 5 as max event time for all data is still 15. + + AddData(inputData, ("b", 31)), // Add data newer than watermark for "b", not "a" + // Watermark = 31 - 10 = 21, so "a" should be timed out as timeout timestamp for "a" is 20. + CheckNewAnswer(("a", -1), ("b", 31)) // State for "a" should timeout and emit -1 + ) + } + } + + // Registers one test, named after timeoutConf, that runs a watermarked applyInPandasWithState + // query, feeds it one row and checks the resulting output. + // NOTE(review): timeoutConf is only used in the test name; the query below always passes + // "ProcessingTimeTimeout" as its timeout conf - confirm whether that is intended. + def testWithTimeout(timeoutConf: GroupStateTimeout): Unit = { + test("SPARK-20714: watermark does not fail query when timeout = " + timeoutConf) { + // scalastyle:off assume + assume(shouldTestPandasUDFs) + // scalastyle:on assume + + // timestamp_seconds assumes the base timezone is UTC. However, the provided function + // localizes it.
Therefore, this test assumes the timezone is in UTC + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") { + // Function to maintain a running count per key and to remove the state once it has timed out + // Returns the key and the count as a string ("-1" if the state timed out and was just removed) + // State is a single long field ("count"); every update also calls setTimeoutDuration(10000) + val pythonScript = + """ + |import pandas as pd + |from pyspark.sql.types import StructType, StructField, StringType + | + |tpe = StructType([ + | StructField("key", StringType()), + | StructField("countAsString", StringType())]) + | + |def func(key, pdf_iter, state): + | if state.hasTimedOut: + | state.remove() + | yield pd.DataFrame({'key': [key[0]], 'countAsString': [str(-1)]}) + | else: + | count = state.getOption + | if count is None: + | count = 0 + | else: + | count = count[0] + | + | for pdf in pdf_iter: + | count += len(pdf) + | + | state.update((count,)) + | state.setTimeoutDuration(10000) + | yield pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) + |""".stripMargin + val pythonFunc = TestGroupedMapPandasUDFWithState( + name = "pandas_grouped_map_with_state", pythonScript = pythonScript) + + val clock = new StreamManualClock + val inputData = MemoryStream[(String, Long)] + val inputDataDF = inputData + .toDF.toDF("key", "time") + .selectExpr("key", "timestamp_seconds(time) as timestamp") + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("countAsString", StringType))) + val stateStructType = StructType(Seq(StructField("count", LongType))) + val result = + inputDataDF + .withWatermark("timestamp", "10 second") + .groupBy("key") + .applyInPandasWithState( + pythonFunc(inputDataDF("key"), inputDataDF("timestamp")).expr.asInstanceOf[PythonUDF], + outputStructType, + stateStructType, + "Update", + "ProcessingTimeTimeout") + + // Feed one row for "a", advance the manual clock by 1 * 1000 and expect the count "1" for "a" + testStream(result, Update)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputData, ("a", 1L)), + AdvanceManualClock(1
* 1000), + CheckNewAnswer(("a", "1")) + ) + } + } + } + testWithTimeout(NoTimeout) + testWithTimeout(ProcessingTimeTimeout) + + test("applyInPandasWithState - uses state format version 2 by default") { + // scalastyle:off assume + assume(shouldTestPandasUDFs) + // scalastyle:on assume + + // Function to maintain running count up to 2, and then remove the count (once it reaches 3) + // Returns the key and the count as a string while state is kept, otherwise an empty DataFrame + // It also asserts that getCurrentWatermarkMs() raises a RuntimeError mentioning "watermark"; + // the query below defines no watermark + val pythonScript = + """ + |import pandas as pd + |from pyspark.sql.types import StructType, StructField, StringType + | + |tpe = StructType([ + | StructField("key", StringType()), + | StructField("countAsString", StringType())]) + | + |def func(key, pdf_iter, state): + | assert state.getCurrentProcessingTimeMs() >= 0 + | try: + | state.getCurrentWatermarkMs() + | assert False + | except RuntimeError as e: + | assert "watermark" in str(e) + | + | count = state.getOption + | if count is None: + | count = 0 + | else: + | count = count[0] + | + | for pdf in pdf_iter: + | count += len(pdf) + | state.update((count,)) + | + | if count >= 3: + | state.remove() + | yield pd.DataFrame() + | else: + | yield pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) + |""".stripMargin + val pythonFunc = TestGroupedMapPandasUDFWithState( + name = "pandas_grouped_map_with_state", pythonScript = pythonScript) + + val inputData = MemoryStream[String] + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("countAsString", StringType))) + val stateStructType = StructType(Seq(StructField("count", LongType))) + val inputDataDS = inputData.toDS() + val result = + inputDataDS + .groupBy("value") + .applyInPandasWithState( + pythonFunc(inputDataDS("value")).expr.asInstanceOf[PythonUDF], + outputStructType, + stateStructType, + "Update", + "NoTimeout") + + // One row for "a" yields count "1" and a single state row; the executed plan is then + // inspected for the state format version + testStream(result, Update)( + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + assertNumStateRows(total = 1, updated = 1), + Execute {
query => + // Verify state format = 2 + val f = query.lastExecution.executedPlan.collect { + case f: FlatMapGroupsInPandasWithStateExec => f + } + assert(f.size == 1) + assert(f.head.stateFormatVersion == 2) + } + ) + } + + test("applyInPandasWithState - streaming - arrow RecordBatch size with chunking") { + // scalastyle:off assume + assume(shouldTestPandasUDFs) + // scalastyle:on assume + + // The function counts the pandas DataFrames (chunks) it receives for a key and emits + // (key, chunkCount); state is not used + val pythonScript = + """ + |import pandas as pd + |from pyspark.sql.types import IntegerType, StructType, StructField, StringType + | + |tpe = StructType([ + | StructField("key", StringType()), + | StructField("chunkCount", IntegerType())]) + | + |def func(key, pdf_iter, state): + | chunk_count = 0 + | for pdf in pdf_iter: + | chunk_count += 1 + | yield pd.DataFrame({'key': [key[0]], 'chunkCount': [chunk_count]}) + |""".stripMargin + val pythonFunc = TestGroupedMapPandasUDFWithState( + name = "pandas_grouped_map_with_state", pythonScript = pythonScript) + + val testData = Seq("a", "a", "a", "b", "b", "b", "b", "b") + + // check a few cases which should provide deterministic number of chunks + // each entry: (max records per arrow batch, expected (key, chunkCount) rows for testData) + val arrowBatchSizeAndExpectedOutputs = Seq( + (1, Seq(Row("a", 3), Row("b", 5))), + (2, Seq(Row("a", 2), Row("b", 3))), + (4, Seq(Row("a", 1), Row("b", 2))), + (8, Seq(Row("a", 1), Row("b", 1))) + ) + + arrowBatchSizeAndExpectedOutputs.foreach { case (arrowBatchSize, expectedOutputs) => + withSQLConf( + SQLConf.SHUFFLE_PARTITIONS.key -> "1", + SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowBatchSize.toString) { + val inputData = MemoryStream[String] + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("chunkCount", IntegerType))) + val stateStructType = StructType(Seq(StructField("notUsed", IntegerType))) + val inputDataDS = inputData.toDS().selectExpr("value AS key") + val result = + inputDataDS + .groupBy("key") + .applyInPandasWithState( + pythonFunc(inputDataDS("key")).expr.asInstanceOf[PythonUDF], + outputStructType,
stateStructType, + "Update", + "NoTimeout") + + testStream(result, Update)( + AddData(inputData, testData: _*), + CheckNewAnswer(expectedOutputs: _*) + ) + } + } + } + + test("applyInPandasWithState - streaming - partial consume of iterator in user function") { + // scalastyle:off assume + assume(shouldTestPandasUDFs) + // scalastyle:on assume + + // The function pulls only the first pandas DataFrame from the iterator and leaves the rest + // unconsumed; state tracks how many times the function has run for the key + val pythonScript = + """ + |import pandas as pd + |from pyspark.sql.types import IntegerType, StructType, StructField, StringType + | + |tpe = StructType([ + | StructField("key", StringType()), + | StructField("numBatches", IntegerType())]) + | + |def func(key, pdf_iter, state): + | numBatches = state.getOption + | if numBatches is None: + | numBatches = 0 + | else: + | numBatches = numBatches[0] + | numBatches += 1 + | + | # only consume the first element in the iterator + | pdf = next(pdf_iter) + | state.update((numBatches, )) + | yield pd.DataFrame({'key': [key[0]], 'numBatches': [numBatches]}) + |""".stripMargin + val pythonFunc = TestGroupedMapPandasUDFWithState( + name = "pandas_grouped_map_with_state", pythonScript = pythonScript) + + (1 to 3).foreach { arrowBatchSize => + withSQLConf( + SQLConf.SHUFFLE_PARTITIONS.key -> "1", + SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowBatchSize.toString) { + val inputData = MemoryStream[String] + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("numBatches", IntegerType))) + val stateStructType = StructType(Seq(StructField("numBatches", IntegerType))) + val inputDataDS = inputData.toDS().selectExpr("value AS key") + val result = + inputDataDS + .groupBy("key") + .applyInPandasWithState( + pythonFunc(inputDataDS("key")).expr.asInstanceOf[PythonUDF], + outputStructType, + stateStructType, + "Update", + "NoTimeout") + + // 2x the batch size of "a" rows, 2.5x (rounded down) of "b" rows and a single "c" row + val testData = (1 to arrowBatchSize * 2).map(_ => "a") ++ + (1 to (arrowBatchSize * 2.5).toInt).map(_ => "b") ++ + Seq("c") + + testStream(result, Update)( + AddData(inputData, testData: _*),
CheckNewAnswer(("a", 1), ("b", 1), ("c", 1)), + AddData(inputData, testData: _*), + CheckNewAnswer(("a", 2), ("b", 2), ("c", 2)) + ) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 26c201d5921ed..e6a0af48190ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -279,7 +279,7 @@ class ContinuousSuite extends ContinuousSuiteBase { Seq(TestScalaUDF("udf"), TestPythonUDF("udf"), TestScalarPandasUDF("udf")).foreach { udf => test(s"continuous mode with various UDFs - ${udf.prettyName}") { assume( - shouldTestScalarPandasUDFs && udf.isInstanceOf[TestScalarPandasUDF] || + shouldTestPandasUDFs && udf.isInstanceOf[TestScalarPandasUDF] || shouldTestPythonUDFs && udf.isInstanceOf[TestPythonUDF] || udf.isInstanceOf[TestScalaUDF])