[SPARK-XXX][CONNECT][PYTHON] Better error handling
grundprinzip committed Nov 5, 2023
1 parent 9cbc2d1 commit 4097b05
Showing 5 changed files with 112 additions and 39 deletions.
@@ -20,9 +20,7 @@ import io.grpc.stub.StreamObserver

import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.FetchErrorDetailsResponse
import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.utils.ErrorUtils
import org.apache.spark.sql.internal.SQLConf

/**
* Handles [[proto.FetchErrorDetailsRequest]]s for the [[SparkConnectService]]. The handler
@@ -46,9 +44,7 @@ class SparkConnectFetchErrorDetailsHandler(

ErrorUtils.throwableToFetchErrorDetailsResponse(
st = error,
serverStackTraceEnabled = sessionHolder.session.conf.get(
Connect.CONNECT_SERVER_STACKTRACE_ENABLED) || sessionHolder.session.conf.get(
SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED))
serverStackTraceEnabled = true)
}
.getOrElse(FetchErrorDetailsResponse.newBuilder().build())
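
With this change the handler always asks ErrorUtils to serialize the server-side stack trace into the FetchErrorDetailsResponse; whether the trace is actually shown becomes a client-side decision driven by the two configs that used to be read here. A minimal Python sketch of that client-side decision (the helper names below are illustrative, not part of the diff):

```python
# Sketch only: the display decision now lives on the client. `conf_get` stands
# in for whatever config accessor the client uses (e.g. RuntimeConf.get).
def should_display_stacktrace(conf_get) -> bool:
    # Either config turns on display of the server-provided trace.
    return (
        conf_get("spark.sql.connect.serverStacktrace.enabled") == "true"
        or conf_get("spark.sql.pyspark.jvmStacktrace.enabled") == "true"
    )


def render(message: str, stacktrace: str, display: bool) -> str:
    # The trace always travels with the response; it is appended only on demand.
    return message + (f"\n\nJVM stacktrace:\n{stacktrace}" if display else "")
```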

@@ -164,6 +164,20 @@ private[connect] object ErrorUtils extends Logging {
"classes",
JsonMethods.compact(JsonMethods.render(allClasses(st.getClass).map(_.getName))))

// Add the SQL state and error class to the response metadata of the ErrorInfo object.
st match {
case e: SparkThrowable =>
val state = e.getSqlState
if (state != null && state.nonEmpty) {
errorInfo.putMetadata("sqlState", state)
}
val errorClass = e.getErrorClass
if (errorClass != null && errorClass.nonEmpty) {
errorInfo.putMetadata("errorClass", errorClass)
}
case _ =>
}

if (sessionHolderOpt.exists(_.session.conf.get(Connect.CONNECT_ENRICH_ERROR_ENABLED))) {
// Generate a new unique key for this exception.
val errorId = UUID.randomUUID().toString
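
The ErrorUtils change attaches `sqlState` and `errorClass` to the gRPC ErrorInfo metadata next to the existing `classes` entry, so the Python client can expose them as structured fields on the raised exception. A hedged sketch of what the metadata might look like and how a client reads it (the concrete values are invented for illustration):

```python
import json

# Hypothetical ErrorInfo.metadata payload as seen by the Python client;
# the values are made up for illustration only.
metadata = {
    "classes": json.dumps(["org.apache.spark.sql.AnalysisException"]),
    "sqlState": "42703",
    "errorClass": "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION",
}

classes = json.loads(metadata["classes"]) if "classes" in metadata else []
sql_state = metadata.get("sqlState")      # None when the server omitted it
error_class = metadata.get("errorClass")  # None when the server omitted it
print(classes, sql_state, error_class)
```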
93 changes: 71 additions & 22 deletions python/pyspark/errors/exceptions/connect.py
@@ -16,7 +16,7 @@
#
import pyspark.sql.connect.proto as pb2
import json
from typing import Dict, List, Optional, TYPE_CHECKING
from typing import Dict, List, Optional, TYPE_CHECKING, overload

from pyspark.errors.exceptions.base import (
AnalysisException as BaseAnalysisException,
@@ -46,55 +46,68 @@ class SparkConnectException(PySparkException):


def convert_exception(
info: "ErrorInfo", truncated_message: str, resp: Optional[pb2.FetchErrorDetailsResponse]
info: "ErrorInfo",
truncated_message: str,
resp: Optional[pb2.FetchErrorDetailsResponse],
display_stacktrace: bool = False
) -> SparkConnectException:
classes = []
sql_state = None
error_class = None

if "classes" in info.metadata:
classes = json.loads(info.metadata["classes"])

if "sqlState" in info.metadata:
sql_state = info.metadata["sqlState"]

if "errorClass" in info.metadata:
error_class = info.metadata["errorClass"]

if resp is not None and resp.HasField("root_error_idx"):
message = resp.errors[resp.root_error_idx].message
stacktrace = _extract_jvm_stacktrace(resp)
else:
message = truncated_message
stacktrace = info.metadata["stackTrace"] if "stackTrace" in info.metadata else ""

if len(stacktrace) > 0:
message += f"\n\nJVM stacktrace:\n{stacktrace}"
stacktrace = info.metadata["stackTrace"] if "stackTrace" in info.metadata else None
display_stacktrace = display_stacktrace if stacktrace is not None else False

if "org.apache.spark.sql.catalyst.parser.ParseException" in classes:
return ParseException(message)
return ParseException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace)
# Order matters. ParseException inherits AnalysisException.
elif "org.apache.spark.sql.AnalysisException" in classes:
return AnalysisException(message)
return AnalysisException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace)
elif "org.apache.spark.sql.streaming.StreamingQueryException" in classes:
return StreamingQueryException(message)
return StreamingQueryException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace)
elif "org.apache.spark.sql.execution.QueryExecutionException" in classes:
return QueryExecutionException(message)
return QueryExecutionException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace)
# Order matters. NumberFormatException inherits IllegalArgumentException.
elif "java.lang.NumberFormatException" in classes:
return NumberFormatException(message)
return NumberFormatException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace)
elif "java.lang.IllegalArgumentException" in classes:
return IllegalArgumentException(message)
return IllegalArgumentException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace)
elif "java.lang.ArithmeticException" in classes:
return ArithmeticException(message)
return ArithmeticException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace)
elif "java.lang.UnsupportedOperationException" in classes:
return UnsupportedOperationException(message)
return UnsupportedOperationException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace)
elif "java.lang.ArrayIndexOutOfBoundsException" in classes:
return ArrayIndexOutOfBoundsException(message)
return ArrayIndexOutOfBoundsException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace)
elif "java.time.DateTimeException" in classes:
return DateTimeException(message)
return DateTimeException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace)
elif "org.apache.spark.SparkRuntimeException" in classes:
return SparkRuntimeException(message)
return SparkRuntimeException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace)
elif "org.apache.spark.SparkUpgradeException" in classes:
return SparkUpgradeException(message)
return SparkUpgradeException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace)
elif "org.apache.spark.api.python.PythonException" in classes:
return PythonException(
"\n An exception was thrown from the Python worker. "
"Please see the stack trace below.\n%s" % message
)
# Make sure that the generic SparkException is handled last.
elif "org.apache.spark.SparkException" in classes:
return SparkException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace)
else:
return SparkConnectGrpcException(message, reason=info.reason)
return SparkConnectGrpcException(message, reason=info.reason, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace)


def _extract_jvm_stacktrace(resp: pb2.FetchErrorDetailsResponse) -> str:
@@ -106,7 +119,7 @@ def _extract_jvm_stacktrace(resp: pb2.FetchErrorDetailsResponse) -> str:
def format_stacktrace(error: pb2.FetchErrorDetailsResponse.Error) -> None:
message = f"{error.error_type_hierarchy[0]}: {error.message}"
if len(lines) == 0:
lines.append(message)
lines.append(error.error_type_hierarchy[0])
else:
lines.append(f"Caused by: {message}")
for elem in error.stack_trace:
@@ -135,16 +148,48 @@ def __init__(
error_class: Optional[str] = None,
message_parameters: Optional[Dict[str, str]] = None,
reason: Optional[str] = None,
sql_state: Optional[str] = None,
stacktrace: Optional[str] = None,
display_stacktrace: bool = False
) -> None:
self.message = message # type: ignore[assignment]
if reason is not None:
self.message = f"({reason}) {self.message}"

# PySparkException assumes that error_class and message_parameters always occur
# together. If only one of them is set, we treat the message as already fully
# constructed.
tmp_error_class = error_class
tmp_message_parameters = message_parameters
if error_class is not None and message_parameters is None:
tmp_error_class = None
elif error_class is None and message_parameters is not None:
tmp_message_parameters = None

super().__init__(
message=self.message,
error_class=error_class,
message_parameters=message_parameters,
error_class=tmp_error_class,
message_parameters=tmp_message_parameters
)
self.error_class = error_class
self._sql_state: Optional[str] = sql_state
self._stacktrace: Optional[str] = stacktrace
self._display_stacktrace: bool = display_stacktrace

def getSqlState(self) -> Optional[str]:
if self._sql_state is not None:
return self._sql_state
else:
return super().getSqlState()

def getStackTrace(self) -> Optional[str]:
return self._stacktrace

def __str__(self) -> str:
desc = self.message
if self._display_stacktrace:
desc += "\n\nJVM stacktrace:\n%s" % self._stacktrace
return desc
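
The constructor keeps the raw error_class on the instance but only forwards it to the base PySparkException when message parameters arrive with it; otherwise the message is treated as already rendered. A toy, standalone sketch of that pairing rule (not the real class):

```python
from typing import Dict, Optional, Tuple


def pair_for_super(
    error_class: Optional[str],
    message_parameters: Optional[Dict[str, str]],
) -> Tuple[Optional[str], Optional[Dict[str, str]]]:
    # Forward both only when both are present; drop the unpaired value so the
    # base class does not try to render a message from an error class alone.
    if error_class is not None and message_parameters is None:
        return None, None
    if error_class is None and message_parameters is not None:
        return None, None
    return error_class, message_parameters


assert pair_for_super("SOME_ERROR_CLASS", None) == (None, None)
assert pair_for_super(None, {"objectName": "x"}) == (None, None)
assert pair_for_super("SOME_ERROR_CLASS", {"k": "v"}) == ("SOME_ERROR_CLASS", {"k": "v"})
```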


class AnalysisException(SparkConnectGrpcException, BaseAnalysisException):
@@ -223,3 +268,7 @@ class SparkUpgradeException(SparkConnectGrpcException, BaseSparkUpgradeException
"""
Exception thrown because of Spark upgrade from Spark Connect.
"""

class SparkException(SparkConnectGrpcException):
"""
Exception thrown from Spark for errors reported as ``org.apache.spark.SparkException``.
"""
13 changes: 12 additions & 1 deletion python/pyspark/sql/connect/client/core.py
@@ -1564,6 +1564,14 @@ def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDet
except grpc.RpcError:
return None

def _display_stack_trace(self) -> bool:
from pyspark.sql.connect.conf import RuntimeConf

conf = RuntimeConf(self)
if conf.get("spark.sql.connect.serverStacktrace.enabled") == "true":
return True
return conf.get("spark.sql.pyspark.jvmStacktrace.enabled") == "true"

def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn:
"""
Error handling helper for dealing with GRPC Errors. On the server side, certain
@@ -1594,7 +1602,10 @@ def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn:
d.Unpack(info)

raise convert_exception(
info, status.message, self._fetch_enriched_error(info)
info,
status.message,
self._fetch_enriched_error(info),
self._display_stack_trace(),
) from None

raise SparkConnectGrpcException(status.message) from None
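
Together with the handler change above, `_display_stack_trace` means the stack trace now always arrives from the server but is only rendered into the exception text when one of the two confs is enabled for the session; the tests below exercise exactly this. A minimal usage sketch (assumes `spark` is an existing Spark Connect session and `select x` fails analysis):

```python
spark.conf.set("spark.sql.pyspark.jvmStacktrace.enabled", "true")
try:
    spark.sql("select x").collect()
except Exception as e:
    assert "JVM stacktrace" in str(e)

spark.conf.set("spark.sql.pyspark.jvmStacktrace.enabled", "false")
try:
    spark.sql("select x").collect()
except Exception as e:
    assert "JVM stacktrace" not in str(e)
```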
25 changes: 14 additions & 11 deletions python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -3378,35 +3378,37 @@ def test_error_enrichment_jvm_stacktrace(self):
"""select from_json(
'{"d": "02-29"}', 'd date', map('dateFormat', 'MM-dd'))"""
).collect()
self.assertTrue("JVM stacktrace" in e.exception.message)
self.assertTrue("org.apache.spark.SparkUpgradeException:" in e.exception.message)
self.assertTrue("JVM stacktrace" in str(e.exception))
self.assertTrue("org.apache.spark.SparkUpgradeException" in str(e.exception))
self.assertTrue(
"at org.apache.spark.sql.errors.ExecutionErrors"
".failToParseDateTimeInNewParserError" in e.exception.message
".failToParseDateTimeInNewParserError" in str(e.exception)
)
self.assertTrue("Caused by: java.time.DateTimeException:" in e.exception.message)
self.assertTrue("Caused by: java.time.DateTimeException:" in str(e.exception))

def test_not_hitting_netty_header_limit(self):
with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": True}):
with self.assertRaises(AnalysisException):
self.spark.sql("select " + "test" * 10000).collect()
self.spark.sql("select " + "test" * 1).collect()

def test_error_stack_trace(self):
with self.sql_conf({"spark.sql.connect.enrichError.enabled": False}):
with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": True}):
with self.assertRaises(AnalysisException) as e:
self.spark.sql("select x").collect()
self.assertTrue("JVM stacktrace" in e.exception.message)
self.assertTrue("JVM stacktrace" in str(e.exception))
self.assertIsNotNone(e.exception.getStackTrace())
self.assertTrue(
"at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in e.exception.message
"at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in str(e.exception)
)

with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": False}):
with self.assertRaises(AnalysisException) as e:
self.spark.sql("select x").collect()
self.assertFalse("JVM stacktrace" in e.exception.message)
self.assertFalse("JVM stacktrace" in str(e.exception))
self.assertIsNone(e.exception.getStackTrace())
self.assertFalse(
"at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in e.exception.message
"at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in str(e.exception)
)

# Create a new session with a different stack trace size.
@@ -3421,9 +3423,10 @@ def test_error_stack_trace(self):
spark.conf.set("spark.sql.pyspark.jvmStacktrace.enabled", True)
with self.assertRaises(AnalysisException) as e:
spark.sql("select x").collect()
self.assertTrue("JVM stacktrace" in e.exception.message)
self.assertTrue("JVM stacktrace" in str(e.exception))
self.assertIsNotNone(e.exception.getStackTrace())
self.assertFalse(
"at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in e.exception.message
"at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in str(e.exception)
)
spark.stop()
