[SPARK-40571][SS][TESTS] Construct a new test case for applyInPandasWithState to verify fault-tolerance semantics with random Python worker failures

### What changes were proposed in this pull request?

This PR proposes a new test case for applyInPandasWithState to verify that the fault-tolerance semantics are not broken despite random Python worker failures. If the sink provides end-to-end exactly-once semantics, the query should uphold that guarantee. Otherwise, the query should still provide exactly-once semantics for state updates, but only at-least-once semantics for outputs.

The test leverages the file stream sink, which is end-to-end exactly-once, but to keep the verification simple we only verify that stateful exactly-once is guaranteed despite Python worker failures.
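
At a high level, the user function passed to `applyInPandasWithState` randomly kills its own Python worker, and the test then checks that the per-key counts recovered from the sink still match a batch-computed ground truth. A condensed sketch of that function, mirroring the test code in the diff below:

```python
import random
import sys

import pandas as pd
from pyspark.sql.streaming.state import GroupState


def func(key, pdf_iter, state: GroupState):
    # Crash the Python worker at random (~1 in 30 calls) to exercise recovery.
    if random.randrange(30) == 1:
        sys.exit(1)

    # Carry a per-key row count across microbatches in the group state.
    prev = state.getOption  # tuple of state values, or None on the first call
    count = prev[0] if prev is not None else 0
    for pdf in pdf_iter:
        count += len(pdf)
    state.update((count,))
    yield pd.DataFrame({"value": [key[0]], "count": [count]})
```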

### Why are the changes needed?

This strengthens the test coverage, especially for the fault-tolerance semantics.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

A new test was added. Manually ran `./python/run-tests --testnames 'pyspark.sql.tests.test_pandas_grouped_map_with_state'` 10 times, and all runs succeeded.

Closes apache#38008 from HeartSaVioR/SPARK-40571.

Authored-by: Jungtaek Lim <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
HeartSaVioR committed Sep 27, 2022
1 parent 311a855 commit 37517df
Showing 1 changed file with 147 additions and 2 deletions.
python/pyspark/sql/tests/test_pandas_grouped_map_with_state.py
@@ -15,9 +15,16 @@
# limitations under the License.
#

import random
import shutil
import string
import sys
import tempfile

import unittest
from typing import cast

from pyspark import SparkConf
from pyspark.sql.streaming.state import GroupStateTimeout, GroupState
from pyspark.sql.types import (
LongType,
@@ -33,6 +40,7 @@
pandas_requirement_message,
pyarrow_requirement_message,
)
from pyspark.testing.utils import eventually

if have_pandas:
import pandas as pd
@@ -46,8 +54,23 @@
cast(str, pandas_requirement_message or pyarrow_requirement_message),
)
class GroupedMapInPandasWithStateTests(ReusedSQLTestCase):
@classmethod
def conf(cls):
cfg = SparkConf()
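        # A small, fixed shuffle partition count keeps the test fast and, likely,
        # lets several groups share each partition.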
cfg.set("spark.sql.shuffle.partitions", "5")
return cfg

def test_apply_in_pandas_with_state_basic(self):
df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
input_path = tempfile.mkdtemp()

def prepare_test_resource():
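            # Write a small two-line input file for the basic test to consume.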
with open(input_path + "/text-test.txt", "w") as fw:
fw.write("hello\n")
fw.write("this\n")

prepare_test_resource()

df = self.spark.readStream.format("text").load(input_path)

for q in self.spark.streams.active:
q.stop()
@@ -71,7 +94,7 @@ def func(key, pdf_iter, state):

def check_results(batch_df, _):
self.assertEqual(
-                set(batch_df.collect()),
+                set(batch_df.sort("key").collect()),
{Row(key="hello", countAsString="1"), Row(key="this", countAsString="1")},
)

@@ -90,6 +113,128 @@ def check_results(batch_df, _):
self.assertTrue(q.isActive)
q.processAllAvailable()

def test_apply_in_pandas_with_state_python_worker_random_failure(self):
input_path = tempfile.mkdtemp()
output_path = tempfile.mkdtemp()
checkpoint_loc = tempfile.mkdtemp()

        # mkdtemp() creates the directories; remove the sink and checkpoint
        # directories so the streaming query can create them from scratch.
        shutil.rmtree(output_path)
        shutil.rmtree(checkpoint_loc)

def prepare_test_resource():
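            # Write five input files, each with 100 random lowercase letters (one
            # per line); note randrange(0, len - 1) never picks the last letter "z".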
data_range = list(string.ascii_lowercase)
for i in range(5):
picked_data = [
data_range[random.randrange(0, len(data_range) - 1)] for x in range(100)
]

with open(input_path + "/part-%i.txt" % i, "w") as fw:
for data in picked_data:
fw.write(data + "\n")

def run_query():
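            # Build and start the streaming query; assert_test below calls this
            # again whenever a killed worker has terminated the query.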
df = (
self.spark.readStream.format("text")
.option("maxFilesPerTrigger", "1")
.load(input_path)
)

for q in self.spark.streams.active:
q.stop()
self.assertTrue(df.isStreaming)

output_type = StructType(
[StructField("value", StringType()), StructField("count", LongType())]
)
state_type = StructType([StructField("cnt", LongType())])

def func(key, pdf_iter, state):
assert isinstance(state, GroupState)

                # The user function is called at most 26 times per batch (once per
                # distinct key). The failure probability should be low enough not to
                # trigger a kill in every batch, but high enough to trigger kills
                # multiple times across batches.
if random.randrange(30) == 1:
sys.exit(1)

                # getOption is a property on PySpark's GroupState; it returns the
                # state as a tuple, or None if no state exists for this key yet.
                count = state.getOption
if count is None:
count = 0
else:
count = count[0]

for pdf in pdf_iter:
count += len(pdf)

state.update((count,))
yield pd.DataFrame({"value": [key[0]], "count": [count]})

query = (
df.groupBy(df["value"])
.applyInPandasWithState(
func, output_type, state_type, "Append", GroupStateTimeout.NoTimeout
)
.writeStream.queryName("this_query")
.format("json")
.outputMode("append")
.option("path", output_path)
.option("checkpointLocation", checkpoint_loc)
.start()
)

return query

prepare_test_resource()

        # Ground truth: compute the expected per-key counts with a batch query
        # over the same input files.
        expected = (
self.spark.read.format("text")
.load(input_path)
.groupBy("value")
.count()
.sort("value")
.collect()
)

q = run_query()
self.assertEqual(q.name, "this_query")
self.assertTrue(q.isActive)

def assert_test():
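            # Polled by eventually(): restarts the query if it died, and returns
            # True once all input is processed and the output matches `expected`.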
nonlocal q
if not q.isActive:
print("query has been terminated, rerunning query...")

                # Rerun the query, since it may have been terminated by a killed Python worker.
q = run_query()

self.assertEqual(q.name, "this_query")
self.assertTrue(q.isActive)

curr_status = q.status
if not curr_status["isDataAvailable"] and not curr_status["isTriggerActive"]:
                # The query is active but idle because no further data is available.
                # Check the output now.
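                # The sink may contain multiple rows per key (the running count from
                # each microbatch, plus any duplicates from query restarts), so take
                # the max count per key, which equals the final count.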
result = (
self.spark.read.schema("value string, count int")
.format("json")
.load(output_path)
.groupBy("value")
.max("count")
.selectExpr("value", "`max(count)` AS count")
.sort("value")
.collect()
)

return result == expected
else:
                # Still processing the data; defer checking the output.
return False

        try:
            # eventually() retries assert_test until it returns True, failing the
            # test if it does not succeed within the 120-second timeout.
            eventually(assert_test, timeout=120)
        finally:
            q.stop()


if __name__ == "__main__":
from pyspark.sql.tests.test_pandas_grouped_map_with_state import * # noqa: F401
