Commit

add new PySparkPicklingError
WweiL committed Aug 22, 2023
1 parent ab971ae commit 01618a7
Showing 9 changed files with 28 additions and 19 deletions.
2 changes: 2 additions & 0 deletions python/pyspark/errors/__init__.py
@@ -41,6 +41,7 @@
     PySparkRuntimeError,
     PySparkAssertionError,
     PySparkNotImplementedError,
+    PySparkPicklingError,
 )


@@ -67,4 +68,5 @@
     "PySparkRuntimeError",
     "PySparkAssertionError",
     "PySparkNotImplementedError",
+    "PySparkPicklingError",
 ]
7 changes: 7 additions & 0 deletions python/pyspark/errors/exceptions/base.py
@@ -18,6 +18,7 @@
 from typing import Dict, Optional, cast

 from pyspark.errors.utils import ErrorClassesReader
+from pickle import PicklingError


 class PySparkException(Exception):
@@ -226,3 +227,9 @@ class PySparkNotImplementedError(PySparkException, NotImplementedError):
     """
     Wrapper class for NotImplementedError to support error classes.
     """
+
+
+class PySparkPicklingError(PySparkException, PicklingError):
+    """
+    Wrapper class for pickle.PicklingError to support error classes.
+    """
4 changes: 2 additions & 2 deletions python/pyspark/sql/connect/plan.py
@@ -41,7 +41,7 @@
     LiteralExpression,
 )
 from pyspark.sql.connect.types import pyspark_types_to_proto_types, UnparsedDataType
-from pyspark.errors import PySparkTypeError, PySparkNotImplementedError, PySparkRuntimeError
+from pyspark.errors import PySparkTypeError, PySparkNotImplementedError, PySparkPicklingError

 if TYPE_CHECKING:
     from pyspark.sql.connect._typing import ColumnOrName
@@ -2206,7 +2206,7 @@ def to_plan(self, session: "SparkConnectClient") -> proto.PythonUDTF:
         try:
             udtf.command = CloudPickleSerializer().dumps(self._func)
         except pickle.PicklingError:
-            raise PySparkRuntimeError(
+            raise PySparkPicklingError(
                 error_class="UDTF_SERIALIZATION_ERROR",
                 message_parameters={
                     "name": self._name,
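
For context, PySpark's CloudPickleSerializer.dumps normalizes most serialization failures into pickle.PicklingError, which the hunk above then rewraps. A self-contained sketch of the same pattern; the helper and the lock-capturing function are illustrative, not part of this commit:

import pickle
import threading

from pyspark.serializers import CloudPickleSerializer
from pyspark.errors import PySparkPicklingError

def serialize_command(func, name):
    # Same shape as to_plan above: cloudpickle the function, and rewrap
    # any pickling failure with an error class.
    try:
        return CloudPickleSerializer().dumps(func)
    except pickle.PicklingError as e:
        raise PySparkPicklingError(
            error_class="UDTF_SERIALIZATION_ERROR",
            message_parameters={"name": name, "message": str(e)},
        ) from None

lock = threading.Lock()  # OS-level handles like locks cannot be pickled

def bad_func(x):
    return (x, lock)  # the closure drags the lock into the pickle

# serialize_command(bad_func, "bad_func")  # raises PySparkPicklingError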
4 changes: 2 additions & 2 deletions python/pyspark/sql/connect/streaming/query.py
@@ -33,7 +33,7 @@
 from pyspark.errors.exceptions.connect import (
     StreamingQueryException as CapturedStreamingQueryException,
 )
-from pyspark.errors import PySparkRuntimeError
+from pyspark.errors import PySparkPicklingError

 __all__ = ["StreamingQuery", "StreamingQueryManager"]

@@ -242,7 +242,7 @@ def addListener(self, listener: StreamingQueryListener) -> None:
         try:
             expr.command = CloudPickleSerializer().dumps(listener)
         except pickle.PicklingError:
-            raise PySparkRuntimeError(
+            raise PySparkPicklingError(
                 error_class="STREAMING_CONNECT_SERIALIZATION_ERROR",
                 message_parameters={"name": "addListener"},
             )
6 changes: 3 additions & 3 deletions python/pyspark/sql/connect/streaming/readwriter.py
@@ -34,7 +34,7 @@
 )
 from pyspark.sql.connect.utils import get_python_ver
 from pyspark.sql.types import Row, StructType
-from pyspark.errors import PySparkTypeError, PySparkValueError, PySparkRuntimeError
+from pyspark.errors import PySparkTypeError, PySparkValueError, PySparkPicklingError

 if TYPE_CHECKING:
     from pyspark.sql.connect.session import SparkSession
@@ -494,7 +494,7 @@ def foreach(self, f: Union[Callable[[Row], None], "SupportsProcess"]) -> "DataStreamWriter":
                 CloudPickleSerializer().dumps(command)
             )
         except pickle.PicklingError:
-            raise PySparkRuntimeError(
+            raise PySparkPicklingError(
                 error_class="STREAMING_CONNECT_SERIALIZATION_ERROR",
                 message_parameters={"name": "foreach"},
             )
@@ -509,7 +509,7 @@ def foreachBatch(self, func: Callable[["DataFrame", int], None]) -> "DataStreamWriter":
                 func
             )
         except pickle.PicklingError:
-            raise PySparkRuntimeError(
+            raise PySparkPicklingError(
                 error_class="STREAMING_CONNECT_SERIALIZATION_ERROR",
                 message_parameters={"name": "foreachBatch"},
            )
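
User-visible effect: over Spark Connect, the foreachBatch function is pickled on the client and shipped to the server, so an unpicklable closure now surfaces as PySparkPicklingError rather than PySparkRuntimeError. A hedged end-to-end sketch; the remote URL and the rate stream are illustrative:

from pyspark.errors import PySparkPicklingError
from pyspark.sql import SparkSession

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

def write_batch(df, batch_id):
    # Capturing the client-side session makes the closure unpicklable.
    spark.range(1).collect()

try:
    spark.readStream.format("rate").load().writeStream.foreachBatch(write_batch).start()
except PySparkPicklingError as e:
    assert e.getErrorClass() == "STREAMING_CONNECT_SERIALIZATION_ERROR"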
6 changes: 3 additions & 3 deletions python/pyspark/sql/tests/connect/streaming/test_parity_foreachBatch.py
@@ -19,7 +19,7 @@

 from pyspark.sql.tests.streaming.test_streaming_foreachBatch import StreamingTestsForeachBatchMixin
 from pyspark.testing.connectutils import ReusedConnectTestCase
-from pyspark.errors import PySparkRuntimeError
+from pyspark.errors import PySparkPicklingError


 class StreamingForeachBatchParityTests(StreamingTestsForeachBatchMixin, ReusedConnectTestCase):
@@ -41,7 +41,7 @@ def func(df, _):
         error_thrown = False
         try:
             self.spark.readStream.format("rate").load().writeStream.foreachBatch(func).start()
-        except PySparkRuntimeError as e:
+        except PySparkPicklingError as e:
             self.assertEqual(e.getErrorClass(), "STREAMING_CONNECT_SERIALIZATION_ERROR")
             error_thrown = True
         self.assertTrue(error_thrown)
@@ -55,7 +55,7 @@ def func(df, _):
         error_thrown = False
         try:
             self.spark.readStream.format("rate").load().writeStream.foreachBatch(func).start()
-        except PySparkRuntimeError as e:
+        except PySparkPicklingError as e:
             self.assertEqual(e.getErrorClass(), "STREAMING_CONNECT_SERIALIZATION_ERROR")
             error_thrown = True
         self.assertTrue(error_thrown)
6 changes: 3 additions & 3 deletions python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
@@ -18,7 +18,7 @@
 import unittest
 import time

-from pyspark.errors import PySparkRuntimeError
+from pyspark.errors import PySparkPicklingError
 from pyspark.sql.tests.streaming.test_streaming_listener import StreamingListenerTestsMixin
 from pyspark.sql.streaming.listener import StreamingQueryListener, QueryStartedEvent
 from pyspark.sql.types import StructType, StructField, StringType
@@ -103,7 +103,7 @@ def onQueryTerminated(self, event):
         error_thrown = False
         try:
             self.spark.streams.addListener(TestListener())
-        except PySparkRuntimeError as e:
+        except PySparkPicklingError as e:
             self.assertEqual(e.getErrorClass(), "STREAMING_CONNECT_SERIALIZATION_ERROR")
             error_thrown = True
         self.assertTrue(error_thrown)
@@ -127,7 +127,7 @@ def onQueryTerminated(self, event):
         error_thrown = False
         try:
             self.spark.streams.addListener(TestListener())
-        except PySparkRuntimeError as e:
+        except PySparkPicklingError as e:
             self.assertEqual(e.getErrorClass(), "STREAMING_CONNECT_SERIALIZATION_ERROR")
             error_thrown = True
         self.assertTrue(error_thrown)
6 changes: 3 additions & 3 deletions python/pyspark/sql/tests/test_udtf.py
@@ -27,7 +27,7 @@
     PythonException,
     PySparkTypeError,
     AnalysisException,
-    PySparkRuntimeError,
+    PySparkPicklingError,
 )
 from pyspark.files import SparkFiles
 from pyspark.rdd import PythonEvalType
@@ -872,7 +872,7 @@ def eval(self):
                 file_obj
                 yield 1,

-        with self.assertRaisesRegex(PySparkRuntimeError, "UDTF_SERIALIZATION_ERROR"):
+        with self.assertRaisesRegex(PySparkPicklingError, "UDTF_SERIALIZATION_ERROR"):
             TestUDTF().collect()
def test_udtf_access_spark_session(self):
Expand All @@ -884,7 +884,7 @@ def eval(self):
df.collect()
yield 1,

with self.assertRaisesRegex(PySparkRuntimeError, "UDTF_SERIALIZATION_ERROR"):
with self.assertRaisesRegex(PySparkPicklingError, "UDTF_SERIALIZATION_ERROR"):
TestUDTF().collect()

def test_udtf_no_eval(self):
6 changes: 3 additions & 3 deletions python/pyspark/sql/udtf.py
@@ -26,7 +26,7 @@

 from py4j.java_gateway import JavaObject

-from pyspark.errors import PySparkAttributeError, PySparkRuntimeError, PySparkTypeError
+from pyspark.errors import PySparkAttributeError, PySparkPicklingError, PySparkTypeError
 from pyspark.rdd import PythonEvalType
 from pyspark.sql.column import _to_java_column, _to_java_expr, _to_seq
 from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version
@@ -234,7 +234,7 @@ def _create_judtf(self, func: Type) -> JavaObject:
             wrapped_func = _wrap_function(sc, func)
         except pickle.PicklingError as e:
             if "CONTEXT_ONLY_VALID_ON_DRIVER" in str(e):
-                raise PySparkRuntimeError(
+                raise PySparkPicklingError(
                     error_class="UDTF_SERIALIZATION_ERROR",
                     message_parameters={
                         "name": self._name,
@@ -244,7 +244,7 @@
                         "and try again.",
                     },
                 ) from None
-            raise PySparkRuntimeError(
+            raise PySparkPicklingError(
                 error_class="UDTF_SERIALIZATION_ERROR",
                 message_parameters={
                     "name": self._name,
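
The classic (non-Connect) path in _create_judtf distinguishes one common failure: the pickled UDTF captured the driver-only SparkContext or SparkSession, which cloudpickle reports via CONTEXT_ONLY_VALID_ON_DRIVER. A sketch of just that branching; the helper is illustrative, the real code inlines it:

import pickle

from pyspark.errors import PySparkPicklingError

def rewrap(e: pickle.PicklingError, name: str) -> PySparkPicklingError:
    if "CONTEXT_ONLY_VALID_ON_DRIVER" in str(e):
        # The UDTF referenced SparkContext/SparkSession, which only exists
        # on the driver; suggest the fix instead of dumping the raw trace.
        message = (
            "it appears that you are attempting to reference SparkContext "
            "from the UDTF; please remove the reference and try again."
        )
    else:
        message = str(e)
    return PySparkPicklingError(
        error_class="UDTF_SERIALIZATION_ERROR",
        message_parameters={"name": name, "message": message},
    )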
