[SPARK-40435][SS][PYTHON] Add test suites for applyInPandasWithState in PySpark

### What changes were proposed in this pull request?

This PR adds test suites for apache#37893, applyInPandasWithState. The new suite mostly ports E2E test cases from the existing [FlatMapGroupsWithStateSuite](https://github.com/apache/spark/blob/master/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala).
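For reference, the API surface being exercised looks roughly like the sketch below (a minimal outline, not the authoritative signature, which lives in apache#37893; `df`, `output_type`, and `state_type` stand in for the streaming DataFrame and schemas the new suite defines). The user function receives the grouping key, an iterator of pandas DataFrames holding that group's rows, and a `GroupState` handle:

```python
import pandas as pd
from pyspark.sql.streaming.state import GroupStateTimeout

def func(key, pdf_iter, state):
    # Count this group's rows, persist the count as state, emit one row.
    total = sum(len(pdf) for pdf in pdf_iter)
    state.update((total,))  # tuple matching state_type
    yield pd.DataFrame({"key": [key[0]], "countAsString": [str(total)]})

df.groupBy(df["value"]).applyInPandasWithState(
    func,
    outputStructType=output_type,  # schema of the yielded DataFrames
    stateStructType=state_type,    # schema of the state tuple
    outputMode="Update",
    timeoutConf=GroupStateTimeout.NoTimeout,
)
```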

### Why are the changes needed?

Tests were intentionally omitted from apache#37893 to keep that change small; this PR fills the gap.

### Does this PR introduce _any_ user-facing change?

No, test only.

### How was this patch tested?

New test suites.
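Locally, the new module can be targeted on its own with Spark's Python test runner, e.g. `python/run-tests --testnames pyspark.sql.tests.test_pandas_grouped_map_with_state`.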

Closes apache#37894 from HeartSaVioR/SPARK-40435-on-top-of-SPARK-40434-SPARK-40433-SPARK-40432.

Lead-authored-by: Jungtaek Lim <[email protected]>
Co-authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
HeartSaVioR and HyukjinKwon committed Sep 22, 2022
1 parent 32dd753 commit c22ddbe
Showing 9 changed files with 1,035 additions and 14 deletions.
1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
@@ -452,6 +452,7 @@ def __hash__(self):
"pyspark.sql.tests.test_group",
"pyspark.sql.tests.test_pandas_cogrouped_map",
"pyspark.sql.tests.test_pandas_grouped_map",
"pyspark.sql.tests.test_pandas_grouped_map_with_state",
"pyspark.sql.tests.test_pandas_map",
"pyspark.sql.tests.test_arrow_map",
"pyspark.sql.tests.test_pandas_udf",
103 changes: 103 additions & 0 deletions python/pyspark/sql/tests/test_pandas_grouped_map_with_state.py
@@ -0,0 +1,103 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest
from typing import cast

from pyspark.sql.streaming.state import GroupStateTimeout, GroupState
from pyspark.sql.types import (
    LongType,
    StringType,
    StructType,
    StructField,
    Row,
)
from pyspark.testing.sqlutils import (
    ReusedSQLTestCase,
    have_pandas,
    have_pyarrow,
    pandas_requirement_message,
    pyarrow_requirement_message,
)

if have_pandas:
    import pandas as pd

if have_pyarrow:
    import pyarrow as pa  # noqa: F401


@unittest.skipIf(
    not have_pandas or not have_pyarrow,
    cast(str, pandas_requirement_message or pyarrow_requirement_message),
)
class GroupedMapInPandasWithStateTests(ReusedSQLTestCase):
    def test_apply_in_pandas_with_state_basic(self):
        df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")

        for q in self.spark.streams.active:
            q.stop()
        self.assertTrue(df.isStreaming)

        output_type = StructType(
            [StructField("key", StringType()), StructField("countAsString", StringType())]
        )
        state_type = StructType([StructField("c", LongType())])

        def func(key, pdf_iter, state):
            assert isinstance(state, GroupState)

            total_len = 0
            for pdf in pdf_iter:
                total_len += len(pdf)

            state.update((total_len,))
            assert state.get[0] == 1
            yield pd.DataFrame({"key": [key[0]], "countAsString": [str(total_len)]})

        def check_results(batch_df, _):
            self.assertEqual(
                set(batch_df.collect()),
                {Row(key="hello", countAsString="1"), Row(key="this", countAsString="1")},
            )

        q = (
            df.groupBy(df["value"])
            .applyInPandasWithState(
                func, output_type, state_type, "Update", GroupStateTimeout.NoTimeout
            )
            .writeStream.queryName("this_query")
            .foreachBatch(check_results)
            .outputMode("update")
            .start()
        )

        self.assertEqual(q.name, "this_query")
        self.assertTrue(q.isActive)
        q.processAllAvailable()


if __name__ == "__main__":
    from pyspark.sql.tests.test_pandas_grouped_map_with_state import *  # noqa: F401

    try:
        import xmlrunner  # type: ignore[import]

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)
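A note on the `GroupState` calls above: `state.get[0] == 1` holds because each distinct line in the test input occurs once, so each group's `total_len` is 1. A hypothetical variant that instead accumulates a running count across micro-batches, using the same `GroupState` accessors plus `exists` (part of the same API, though unused in the test), might look like:

```python
def running_count(key, pdf_iter, state):
    # Resume from the previously stored count for this key, if any.
    prev = state.get[0] if state.exists else 0
    total = prev + sum(len(pdf) for pdf in pdf_iter)
    state.update((total,))  # persist the new running total
    yield pd.DataFrame({"key": [key[0]], "countAsString": [str(total)]})
```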
sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
@@ -17,6 +17,7 @@

package org.apache.spark.sql

+import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Paths}

import scala.collection.JavaConverters._
@@ -31,7 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ExprId, Pyth
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
import org.apache.spark.sql.expressions.SparkUserDefinedFunction
-import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
+import org.apache.spark.sql.types.{DataType, IntegerType, NullType, StringType}

/**
* This object targets to integrate various UDF test cases so that Scalar UDF, Python UDF,
@@ -190,7 +191,7 @@ object IntegratedUDFTestUtils extends SQLHelper {
throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.")
}

-private lazy val pandasFunc: Array[Byte] = if (shouldTestScalarPandasUDFs) {
+private lazy val pandasFunc: Array[Byte] = if (shouldTestPandasUDFs) {
var binaryPandasFunc: Array[Byte] = null
withTempPath { path =>
Process(
@@ -213,7 +214,7 @@
throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.")
}

-private lazy val pandasGroupedAggFunc: Array[Byte] = if (shouldTestGroupedAggPandasUDFs) {
+private lazy val pandasGroupedAggFunc: Array[Byte] = if (shouldTestPandasUDFs) {
var binaryPandasFunc: Array[Byte] = null
withTempPath { path =>
Process(
@@ -235,6 +236,33 @@
throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.")
}

private def createPandasGroupedMapFuncWithState(pythonScript: String): Array[Byte] = {
if (shouldTestPandasUDFs) {
var binaryPandasFunc: Array[Byte] = null
withTempPath { codePath =>
Files.write(codePath.toPath, pythonScript.getBytes(StandardCharsets.UTF_8))
withTempPath { path =>
Process(
Seq(
pythonExec,
"-c",
"from pyspark.serializers import CloudPickleSerializer; " +
s"f = open('$path', 'wb');" +
s"exec(open('$codePath', 'r').read());" +
"f.write(CloudPickleSerializer().dumps((" +
"func, tpe)))"),
None,
"PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!
binaryPandasFunc = Files.readAllBytes(path.toPath)
}
}
assert(binaryPandasFunc != null)
binaryPandasFunc
} else {
throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.")
}
}

// Make sure this map stays mutable - this map gets updated later in Python runners.
private val workerEnv = new java.util.HashMap[String, String]()
workerEnv.put("PYTHONPATH", s"$pysparkPythonPath:$pythonPath")
@@ -251,11 +279,9 @@

lazy val shouldTestPythonUDFs: Boolean = isPythonAvailable && isPySparkAvailable

-lazy val shouldTestScalarPandasUDFs: Boolean =
+lazy val shouldTestPandasUDFs: Boolean =
isPythonAvailable && isPandasAvailable && isPyArrowAvailable

-lazy val shouldTestGroupedAggPandasUDFs: Boolean = shouldTestScalarPandasUDFs

/**
* A base trait for various UDFs defined in this object.
*/
@@ -420,6 +446,41 @@
val prettyName: String = "Grouped Aggregate Pandas UDF"
}

/**
* Arbitrary stateful processing in Python is used for
* `DataFrame.groupBy.applyInPandasWithState`. It requires `pythonScript` to
* define `func` (Python function) and `tpe` (`StructType` for state key).
*
* Virtually equivalent to:
*
* {{{
* # exec defines 'func' and 'tpe' (struct type for state key)
* exec(pythonScript)
*
* # ... are filled when this UDF is invoked, see also 'PythonFlatMapGroupsWithStateSuite'.
* df.groupBy(...).applyInPandasWithState(func, ..., tpe, ..., ...)
* }}}
*/
case class TestGroupedMapPandasUDFWithState(name: String, pythonScript: String) extends TestUDF {
private[IntegratedUDFTestUtils] lazy val udf = new UserDefinedPythonFunction(
name = name,
func = SimplePythonFunction(
command = createPandasGroupedMapFuncWithState(pythonScript),
envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]],
pythonIncludes = List.empty[String].asJava,
pythonExec = pythonExec,
pythonVer = pythonVer,
broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava,
accumulator = null),
dataType = NullType, // This is not respected.
pythonEvalType = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
udfDeterministic = true)

def apply(exprs: Column*): Column = udf(exprs: _*)

val prettyName: String = "Grouped Map Pandas UDF with State"
}
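To illustrate how this helper is fed, here is a hypothetical body for `pythonScript` (names and schema are illustrative only; the real scripts live in the suites that use this UDF). The exec'd code must bind `func` and `tpe` at top level, since `createPandasGroupedMapFuncWithState` pickles exactly that pair:

```python
# Hypothetical contents of `pythonScript` for TestGroupedMapPandasUDFWithState.
import pandas as pd
from pyspark.sql.types import LongType, StructField, StructType

# `tpe` is the state schema; `func` is the stateful handler.
tpe = StructType([StructField("count", LongType())])

def func(key, pdf_iter, state):
    total = 0
    for pdf in pdf_iter:
        total += len(pdf)
    state.update((total,))
    yield pd.DataFrame({"key": [key[0]], "countAsString": [str(total)]})
```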

/**
* A Scala UDF that takes one column, casts into string, executes the
* Scala native function, and casts back to the type of input column.
sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
@@ -251,14 +251,14 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
/* Do nothing */
}
case udfTestCase: UDFTest
-if udfTestCase.udf.isInstanceOf[TestScalarPandasUDF] && !shouldTestScalarPandasUDFs =>
+if udfTestCase.udf.isInstanceOf[TestScalarPandasUDF] && !shouldTestPandasUDFs =>
ignore(s"${testCase.name} is skipped because pyspark," +
s"pandas and/or pyarrow were not available in [$pythonExec].") {
/* Do nothing */
}
case udfTestCase: UDFTest
if udfTestCase.udf.isInstanceOf[TestGroupedAggPandasUDF] &&
-!shouldTestGroupedAggPandasUDFs =>
+!shouldTestPandasUDFs =>
ignore(s"${testCase.name} is skipped because pyspark," +
s"pandas and/or pyarrow were not available in [$pythonExec].") {
/* Do nothing */
@@ -447,12 +447,12 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
if udfTestCase.udf.isInstanceOf[TestPythonUDF] && shouldTestPythonUDFs =>
s"${testCase.name}${System.lineSeparator()}Python: $pythonVer${System.lineSeparator()}"
case udfTestCase: UDFTest
-if udfTestCase.udf.isInstanceOf[TestScalarPandasUDF] && shouldTestScalarPandasUDFs =>
+if udfTestCase.udf.isInstanceOf[TestScalarPandasUDF] && shouldTestPandasUDFs =>
s"${testCase.name}${System.lineSeparator()}" +
s"Python: $pythonVer Pandas: $pandasVer PyArrow: $pyarrowVer${System.lineSeparator()}"
case udfTestCase: UDFTest
if udfTestCase.udf.isInstanceOf[TestGroupedAggPandasUDF] &&
-shouldTestGroupedAggPandasUDFs =>
+shouldTestPandasUDFs =>
s"${testCase.name}${System.lineSeparator()}" +
s"Python: $pythonVer Pandas: $pandasVer PyArrow: $pyarrowVer${System.lineSeparator()}"
case _ =>
sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala
@@ -128,7 +128,7 @@ class QueryCompilationErrorsSuite

test("INVALID_PANDAS_UDF_PLACEMENT: Using aggregate function with grouped aggregate pandas UDF") {
import IntegratedUDFTestUtils._
-assume(shouldTestGroupedAggPandasUDFs)
+assume(shouldTestPandasUDFs)

val df = Seq(
(536361, "85123A", 2, 17850),
@@ -180,7 +180,7 @@

test("UNSUPPORTED_FEATURE: Using pandas UDF aggregate expression with pivot") {
import IntegratedUDFTestUtils._
-assume(shouldTestGroupedAggPandasUDFs)
+assume(shouldTestPandasUDFs)

val df = Seq(
(536361, "85123A", 2, 17850),
sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala
@@ -73,7 +73,7 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession {
}

test("SPARK-39962: Global aggregation of Pandas UDF should respect the column order") {
-assume(shouldTestGroupedAggPandasUDFs)
+assume(shouldTestPandasUDFs)
val df = Seq[(java.lang.Integer, java.lang.Integer)]((1, null)).toDF("a", "b")

val pandasTestUDF = TestGroupedAggPandasUDF(name = "pandas_udf")
