[SPARK-44839][SS][CONNECT] Better Error Logging when user tries to serialize spark session #42594

Closed (wants to merge 7 commits)
Changes from 5 commits
2 changes: 2 additions & 0 deletions python/pyspark/errors/__init__.py
@@ -41,6 +41,7 @@
PySparkRuntimeError,
PySparkAssertionError,
PySparkNotImplementedError,
PySparkPicklingError,
)


@@ -67,4 +68,5 @@
"PySparkRuntimeError",
"PySparkAssertionError",
"PySparkNotImplementedError",
"PySparkPicklingError",
]
5 changes: 5 additions & 0 deletions python/pyspark/errors/error_classes.py
@@ -718,6 +718,11 @@
"pandas iterator UDF should exhaust the input iterator."
]
},
"STREAMING_CONNECT_SERIALIZATION_ERROR" : {
"message" : [
"Cannot serialize the function `<name>`. If you accessed the Spark session, or a DataFrame defined outside of the function, or any object that contains a Spark session, please be aware that they are not allowed in Spark Connect. For `foreachBatch`, please access the Spark session using `df.sparkSession`, where `df` is the first parameter in your `foreachBatch` function. For `StreamingQueryListener`, please access the Spark session using `self.spark`. For details please check out the PySpark doc for `foreachBatch` and `StreamingQueryListener`."
]
},
"TOO_MANY_VALUES" : {
"message" : [
"Expected <expected> values for `<item>`, got <actual>."
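In practice, the guidance in this message looks like the following minimal sketch (assumes a Spark Connect session bound to `spark`; the rate source mirrors the tests later in this diff):

def process_batch(df, batch_id):
    # Use the session attached to the micro-batch DataFrame instead of a
    # `spark` variable captured from the enclosing scope; a captured session
    # cannot be pickled and shipped to the server.
    session = df.sparkSession
    session.createDataFrame([(batch_id,)], ["batch_id"]).collect()

query = spark.readStream.format("rate").load().writeStream.foreachBatch(process_batch).start()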
7 changes: 7 additions & 0 deletions python/pyspark/errors/exceptions/base.py
@@ -18,6 +18,7 @@
from typing import Dict, Optional, cast

from pyspark.errors.utils import ErrorClassesReader
from pickle import PicklingError


class PySparkException(Exception):
@@ -226,3 +227,9 @@ class PySparkNotImplementedError(PySparkException, NotImplementedError):
"""
Wrapper class for NotImplementedError to support error classes.
"""


class PySparkPicklingError(PySparkException, PicklingError):
"""
Wrapper class for pickle.PicklingError to support error classes.
"""
4 changes: 2 additions & 2 deletions python/pyspark/sql/connect/plan.py
@@ -41,7 +41,7 @@
LiteralExpression,
)
from pyspark.sql.connect.types import pyspark_types_to_proto_types, UnparsedDataType
- from pyspark.errors import PySparkTypeError, PySparkNotImplementedError, PySparkRuntimeError
+ from pyspark.errors import PySparkTypeError, PySparkNotImplementedError, PySparkPicklingError

if TYPE_CHECKING:
from pyspark.sql.connect._typing import ColumnOrName
@@ -2206,7 +2206,7 @@ def to_plan(self, session: "SparkConnectClient") -> proto.PythonUDTF:
try:
udtf.command = CloudPickleSerializer().dumps(self._func)
except pickle.PicklingError:
- raise PySparkRuntimeError(
+ raise PySparkPicklingError(
error_class="UDTF_SERIALIZATION_ERROR",
message_parameters={
"name": self._name,
10 changes: 9 additions & 1 deletion python/pyspark/sql/connect/streaming/query.py
@@ -17,6 +17,7 @@

import json
import sys
import pickle
from typing import TYPE_CHECKING, Any, cast, Dict, List, Optional

from pyspark.errors import StreamingQueryException, PySparkValueError
@@ -32,6 +33,7 @@
from pyspark.errors.exceptions.connect import (
StreamingQueryException as CapturedStreamingQueryException,
)
from pyspark.errors import PySparkPicklingError

__all__ = ["StreamingQuery", "StreamingQueryManager"]

@@ -237,7 +239,13 @@ def addListener(self, listener: StreamingQueryListener) -> None:
listener._init_listener_id()
cmd = pb2.StreamingQueryManagerCommand()
expr = proto.PythonUDF()
- expr.command = CloudPickleSerializer().dumps(listener)
+ try:
+     expr.command = CloudPickleSerializer().dumps(listener)
+ except pickle.PicklingError:
+     raise PySparkPicklingError(
+         error_class="STREAMING_CONNECT_SERIALIZATION_ERROR",
+         message_parameters={"name": "addListener"},
+     )
expr.python_ver = get_python_ver()
cmd.add_listener.python_listener_payload.CopyFrom(expr)
cmd.add_listener.id = listener._id
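For listeners, the new error message steers users to `self.spark` rather than a captured client-side session. A hedged sketch of a listener that should serialize cleanly (assumes a Connect session bound to `spark`; the method bodies are illustrative only):

from pyspark.sql.streaming.listener import StreamingQueryListener

class QueryStartLogger(StreamingQueryListener):
    def onQueryStarted(self, event):
        # Access the session through `self.spark`, as the error message
        # recommends; capturing an outer `spark` variable would make this
        # listener unpicklable under Spark Connect.
        self.spark.createDataFrame([(str(event.id),)], ["query_id"]).collect()

    def onQueryProgress(self, event):
        pass

    def onQueryIdle(self, event):
        pass

    def onQueryTerminated(self, event):
        pass

spark.streams.addListener(QueryStartLogger())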
27 changes: 20 additions & 7 deletions python/pyspark/sql/connect/streaming/readwriter.py
@@ -20,6 +20,7 @@
check_dependencies(__name__)

import sys
import pickle
from typing import cast, overload, Callable, Dict, List, Optional, TYPE_CHECKING, Union

from pyspark.serializers import CloudPickleSerializer
@@ -33,7 +34,7 @@
)
from pyspark.sql.connect.utils import get_python_ver
from pyspark.sql.types import Row, StructType
- from pyspark.errors import PySparkTypeError, PySparkValueError
+ from pyspark.errors import PySparkTypeError, PySparkValueError, PySparkPicklingError

if TYPE_CHECKING:
from pyspark.sql.connect.session import SparkSession
@@ -488,18 +489,30 @@ def foreach(self, f: Union[Callable[[Row], None], "SupportsProcess"]) -> "DataStreamWriter":
serializer = AutoBatchedSerializer(CPickleSerializer())
command = (func, None, serializer, serializer)
# Python ForeachWriter isn't really a PythonUDF. But we reuse it for simplicity.
- self._write_proto.foreach_writer.python_function.command = CloudPickleSerializer().dumps(
-     command
- )
+ try:
+     self._write_proto.foreach_writer.python_function.command = (
+         CloudPickleSerializer().dumps(command)
+     )
+ except pickle.PicklingError:
+     raise PySparkPicklingError(
+         error_class="STREAMING_CONNECT_SERIALIZATION_ERROR",
+         message_parameters={"name": "foreach"},
+     )
self._write_proto.foreach_writer.python_function.python_ver = "%d.%d" % sys.version_info[:2]
return self

foreach.__doc__ = PySparkDataStreamWriter.foreach.__doc__

def foreachBatch(self, func: Callable[["DataFrame", int], None]) -> "DataStreamWriter":
- self._write_proto.foreach_batch.python_function.command = CloudPickleSerializer().dumps(
-     func
- )
+ try:
+     self._write_proto.foreach_batch.python_function.command = CloudPickleSerializer().dumps(
+         func
+     )
+ except pickle.PicklingError:
+     raise PySparkPicklingError(
+         error_class="STREAMING_CONNECT_SERIALIZATION_ERROR",
+         message_parameters={"name": "foreachBatch"},
+     )
self._write_proto.foreach_batch.python_function.python_ver = get_python_ver()
return self

python/pyspark/sql/tests/connect/streaming/test_parity_foreachBatch.py
@@ -19,6 +19,7 @@

from pyspark.sql.tests.streaming.test_streaming_foreachBatch import StreamingTestsForeachBatchMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.errors import PySparkPicklingError


class StreamingForeachBatchParityTests(StreamingTestsForeachBatchMixin, ReusedConnectTestCase):
@@ -30,6 +31,35 @@ def test_streaming_foreachBatch_propagates_python_errors(self):
def test_streaming_foreachBatch_graceful_stop(self):
super().test_streaming_foreachBatch_graceful_stop()

def test_accessing_spark_session(self):
spark = self.spark

def func(df, _):
spark.createDataFrame([("do", "not"), ("serialize", "spark")]).collect()

error_thrown = False
try:
self.spark.readStream.format("rate").load().writeStream.foreachBatch(func).start()
except PySparkPicklingError as e:
self.assertEqual(e.getErrorClass(), "STREAMING_CONNECT_SERIALIZATION_ERROR")
error_thrown = True
self.assertTrue(error_thrown)

def test_accessing_spark_session_through_df(self):
dataframe = self.spark.createDataFrame([("do", "not"), ("serialize", "dataframe")])

def func(df, _):
dataframe.collect()

error_thrown = False
try:
self.spark.readStream.format("rate").load().writeStream.foreachBatch(func).start()
except PySparkPicklingError as e:
self.assertEqual(e.getErrorClass(), "STREAMING_CONNECT_SERIALIZATION_ERROR")
error_thrown = True
self.assertTrue(error_thrown)


if __name__ == "__main__":
import unittest
49 changes: 49 additions & 0 deletions python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
@@ -18,6 +18,7 @@
import unittest
import time

from pyspark.errors import PySparkPicklingError
from pyspark.sql.tests.streaming.test_streaming_listener import StreamingListenerTestsMixin
from pyspark.sql.streaming.listener import StreamingQueryListener, QueryStartedEvent
from pyspark.sql.types import StructType, StructField, StringType
@@ -83,6 +84,54 @@ def test_listener_events(self):
# Remove again to verify this won't throw any error
self.spark.streams.removeListener(test_listener)

def test_accessing_spark_session(self):
spark = self.spark

class TestListener(StreamingQueryListener):
def onQueryStarted(self, event):
spark.createDataFrame([("do", "not"), ("serialize", "spark")]).collect()

def onQueryProgress(self, event):
pass

def onQueryIdle(self, event):
pass

def onQueryTerminated(self, event):
pass

error_thrown = False
try:
self.spark.streams.addListener(TestListener())
except PySparkPicklingError as e:
self.assertEqual(e.getErrorClass(), "STREAMING_CONNECT_SERIALIZATION_ERROR")
error_thrown = True
self.assertTrue(error_thrown)

def test_accessing_spark_session_through_df(self):
dataframe = self.spark.createDataFrame([("do", "not"), ("serialize", "dataframe")])

class TestListener(StreamingQueryListener):
def onQueryStarted(self, event):
dataframe.collect()

def onQueryProgress(self, event):
pass

def onQueryIdle(self, event):
pass

def onQueryTerminated(self, event):
pass

error_thrown = False
try:
self.spark.streams.addListener(TestListener())
except PySparkPicklingError as e:
self.assertEqual(e.getErrorClass(), "STREAMING_CONNECT_SERIALIZATION_ERROR")
error_thrown = True
self.assertTrue(error_thrown)


if __name__ == "__main__":
import unittest
6 changes: 3 additions & 3 deletions python/pyspark/sql/tests/test_udtf.py
@@ -27,7 +27,7 @@
PythonException,
PySparkTypeError,
AnalysisException,
- PySparkRuntimeError,
+ PySparkPicklingError,
)
from pyspark.files import SparkFiles
from pyspark.rdd import PythonEvalType
@@ -872,7 +872,7 @@ def eval(self):
file_obj
yield 1,

- with self.assertRaisesRegex(PySparkRuntimeError, "UDTF_SERIALIZATION_ERROR"):
+ with self.assertRaisesRegex(PySparkPicklingError, "UDTF_SERIALIZATION_ERROR"):
TestUDTF().collect()

def test_udtf_access_spark_session(self):
@@ -884,7 +884,7 @@ def eval(self):
df.collect()
yield 1,

- with self.assertRaisesRegex(PySparkRuntimeError, "UDTF_SERIALIZATION_ERROR"):
+ with self.assertRaisesRegex(PySparkPicklingError, "UDTF_SERIALIZATION_ERROR"):
TestUDTF().collect()

def test_udtf_no_eval(self):
6 changes: 3 additions & 3 deletions python/pyspark/sql/udtf.py
@@ -26,7 +26,7 @@

from py4j.java_gateway import JavaObject

- from pyspark.errors import PySparkAttributeError, PySparkRuntimeError, PySparkTypeError
+ from pyspark.errors import PySparkAttributeError, PySparkPicklingError, PySparkTypeError
from pyspark.rdd import PythonEvalType
from pyspark.sql.column import _to_java_column, _to_java_expr, _to_seq
from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version
@@ -234,7 +234,7 @@ def _create_judtf(self, func: Type) -> JavaObject:
wrapped_func = _wrap_function(sc, func)
except pickle.PicklingError as e:
if "CONTEXT_ONLY_VALID_ON_DRIVER" in str(e):
- raise PySparkRuntimeError(
+ raise PySparkPicklingError(
error_class="UDTF_SERIALIZATION_ERROR",
message_parameters={
"name": self._name,
@@ -244,7 +244,7 @@
"and try again.",
},
) from None
- raise PySparkRuntimeError(
+ raise PySparkPicklingError(
error_class="UDTF_SERIALIZATION_ERROR",
message_parameters={
"name": self._name,