From 4097b054d573c7a5be90cfca834cbb7543abed1b Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Mon, 6 Nov 2023 00:50:42 +0100 Subject: [PATCH] [SPARK-XXX][CONNECT][PYTHON] Better error handlng --- ...SparkConnectFetchErrorDetailsHandler.scala | 6 +- .../spark/sql/connect/utils/ErrorUtils.scala | 14 +++ python/pyspark/errors/exceptions/connect.py | 93 ++++++++++++++----- python/pyspark/sql/connect/client/core.py | 13 ++- .../sql/tests/connect/test_connect_basic.py | 25 ++--- 5 files changed, 112 insertions(+), 39 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectFetchErrorDetailsHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectFetchErrorDetailsHandler.scala index 17a6e9e434f37..b5a3c986d169b 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectFetchErrorDetailsHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectFetchErrorDetailsHandler.scala @@ -20,9 +20,7 @@ import io.grpc.stub.StreamObserver import org.apache.spark.connect.proto import org.apache.spark.connect.proto.FetchErrorDetailsResponse -import org.apache.spark.sql.connect.config.Connect import org.apache.spark.sql.connect.utils.ErrorUtils -import org.apache.spark.sql.internal.SQLConf /** * Handles [[proto.FetchErrorDetailsRequest]]s for the [[SparkConnectService]]. The handler @@ -46,9 +44,7 @@ class SparkConnectFetchErrorDetailsHandler( ErrorUtils.throwableToFetchErrorDetailsResponse( st = error, - serverStackTraceEnabled = sessionHolder.session.conf.get( - Connect.CONNECT_SERVER_STACKTRACE_ENABLED) || sessionHolder.session.conf.get( - SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED)) + serverStackTraceEnabled = true) } .getOrElse(FetchErrorDetailsResponse.newBuilder().build()) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala index 744fa3c8aa1a4..7cb555ca47ec9 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala @@ -164,6 +164,20 @@ private[connect] object ErrorUtils extends Logging { "classes", JsonMethods.compact(JsonMethods.render(allClasses(st.getClass).map(_.getName)))) + // Add the SQL State and Error Class to the response metadata of the ErrorInfoObject. + st match { + case e: SparkThrowable => + val state = e.getSqlState + if (state != null && state.nonEmpty) { + errorInfo.putMetadata("sqlState", state) + } + val errorClass = e.getErrorClass + if (errorClass != null && errorClass.nonEmpty) { + errorInfo.putMetadata("errorClass", errorClass) + } + case _ => + } + if (sessionHolderOpt.exists(_.session.conf.get(Connect.CONNECT_ENRICH_ERROR_ENABLED))) { // Generate a new unique key for this exception. val errorId = UUID.randomUUID().toString diff --git a/python/pyspark/errors/exceptions/connect.py b/python/pyspark/errors/exceptions/connect.py index 423fb2c6f0acc..29426b190ac93 100644 --- a/python/pyspark/errors/exceptions/connect.py +++ b/python/pyspark/errors/exceptions/connect.py @@ -16,7 +16,7 @@ # import pyspark.sql.connect.proto as pb2 import json -from typing import Dict, List, Optional, TYPE_CHECKING +from typing import Dict, List, Optional, TYPE_CHECKING, overload from pyspark.errors.exceptions.base import ( AnalysisException as BaseAnalysisException, @@ -46,55 +46,68 @@ class SparkConnectException(PySparkException): def convert_exception( - info: "ErrorInfo", truncated_message: str, resp: Optional[pb2.FetchErrorDetailsResponse] + info: "ErrorInfo", + truncated_message: str, + resp: Optional[pb2.FetchErrorDetailsResponse], + display_stacktrace: bool = False ) -> SparkConnectException: classes = [] + sql_state = None + error_class = None + if "classes" in info.metadata: classes = json.loads(info.metadata["classes"]) + if "sqlState" in info.metadata: + sql_state = info.metadata["sqlState"] + + if "errorClass" in info.metadata: + error_class = info.metadata["errorClass"] + if resp is not None and resp.HasField("root_error_idx"): message = resp.errors[resp.root_error_idx].message stacktrace = _extract_jvm_stacktrace(resp) else: message = truncated_message - stacktrace = info.metadata["stackTrace"] if "stackTrace" in info.metadata else "" - - if len(stacktrace) > 0: - message += f"\n\nJVM stacktrace:\n{stacktrace}" + stacktrace = info.metadata["stackTrace"] if "stackTrace" in info.metadata else None + display_stacktrace = display_stacktrace if stacktrace is not None else False if "org.apache.spark.sql.catalyst.parser.ParseException" in classes: - return ParseException(message) + return ParseException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) # Order matters. ParseException inherits AnalysisException. elif "org.apache.spark.sql.AnalysisException" in classes: - return AnalysisException(message) + return AnalysisException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) elif "org.apache.spark.sql.streaming.StreamingQueryException" in classes: - return StreamingQueryException(message) + return StreamingQueryException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) elif "org.apache.spark.sql.execution.QueryExecutionException" in classes: - return QueryExecutionException(message) + return QueryExecutionException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) # Order matters. NumberFormatException inherits IllegalArgumentException. elif "java.lang.NumberFormatException" in classes: - return NumberFormatException(message) + return NumberFormatException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) elif "java.lang.IllegalArgumentException" in classes: - return IllegalArgumentException(message) + return IllegalArgumentException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) elif "java.lang.ArithmeticException" in classes: - return ArithmeticException(message) + return ArithmeticException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) elif "java.lang.UnsupportedOperationException" in classes: - return UnsupportedOperationException(message) + return UnsupportedOperationException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) elif "java.lang.ArrayIndexOutOfBoundsException" in classes: - return ArrayIndexOutOfBoundsException(message) + return ArrayIndexOutOfBoundsException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) elif "java.time.DateTimeException" in classes: - return DateTimeException(message) + return DateTimeException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) elif "org.apache.spark.SparkRuntimeException" in classes: - return SparkRuntimeException(message) + return SparkRuntimeException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) elif "org.apache.spark.SparkUpgradeException" in classes: - return SparkUpgradeException(message) + return SparkUpgradeException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) elif "org.apache.spark.api.python.PythonException" in classes: return PythonException( "\n An exception was thrown from the Python worker. " "Please see the stack trace below.\n%s" % message ) + # Make sure that the generic SparkException is handled last. + elif "org.apache.spark.SparkException" in classes: + return SparkException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) else: - return SparkConnectGrpcException(message, reason=info.reason) + return SparkConnectGrpcException(message, reason=info.reason, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) def _extract_jvm_stacktrace(resp: pb2.FetchErrorDetailsResponse) -> str: @@ -106,7 +119,7 @@ def _extract_jvm_stacktrace(resp: pb2.FetchErrorDetailsResponse) -> str: def format_stacktrace(error: pb2.FetchErrorDetailsResponse.Error) -> None: message = f"{error.error_type_hierarchy[0]}: {error.message}" if len(lines) == 0: - lines.append(message) + lines.append(error.error_type_hierarchy[0]) else: lines.append(f"Caused by: {message}") for elem in error.stack_trace: @@ -135,16 +148,48 @@ def __init__( error_class: Optional[str] = None, message_parameters: Optional[Dict[str, str]] = None, reason: Optional[str] = None, + sql_state: Optional[str] = None, + stacktrace: Optional[str] = None, + display_stacktrace: bool = False ) -> None: self.message = message # type: ignore[assignment] if reason is not None: self.message = f"({reason}) {self.message}" + # PySparkException has the assumption that error_class and message_parameters are + # only occurring together. If only one is set, we assume the message to be fully + # parsed. + tmp_error_class = error_class + tmp_message_parameters = message_parameters + if error_class is not None and message_parameters is None: + tmp_error_class = None + elif error_class is None and message_parameters is not None: + tmp_message_parameters = None + super().__init__( message=self.message, - error_class=error_class, - message_parameters=message_parameters, + error_class=tmp_error_class, + message_parameters=tmp_message_parameters ) + self.error_class = error_class + self._sql_state: Optional[str] = sql_state + self._stacktrace: Optional[str] = stacktrace + self._display_stacktrace: bool = display_stacktrace + + def getSqlState(self) -> None: + if self._sql_state is not None: + return self._sql_state + else: + return super().getSqlState() + + def getStackTrace(self) -> Optional[str]: + return self._stacktrace + + def __str__(self): + desc = self.message + if self._display_stacktrace: + desc += "\n\nJVM stacktrace:\n%s" % self._stacktrace + return desc class AnalysisException(SparkConnectGrpcException, BaseAnalysisException): @@ -223,3 +268,7 @@ class SparkUpgradeException(SparkConnectGrpcException, BaseSparkUpgradeException """ Exception thrown because of Spark upgrade from Spark Connect. """ + +class SparkException(SparkConnectGrpcException): + """ + """ \ No newline at end of file diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 11a1112ad1fe7..69afef992c34f 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -1564,6 +1564,14 @@ def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDet except grpc.RpcError: return None + def _display_stack_trace(self) -> bool: + from pyspark.sql.connect.conf import RuntimeConf + + conf = RuntimeConf(self) + if conf.get("spark.sql.connect.serverStacktrace.enabled") == "true": + return True + return conf.get("spark.sql.pyspark.jvmStacktrace.enabled") == "true" + def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn: """ Error handling helper for dealing with GRPC Errors. On the server side, certain @@ -1594,7 +1602,10 @@ def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn: d.Unpack(info) raise convert_exception( - info, status.message, self._fetch_enriched_error(info) + info, + status.message, + self._fetch_enriched_error(info), + self._display_stack_trace(), ) from None raise SparkConnectGrpcException(status.message) from None diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index f024a03c2686c..daf6772e52bf5 100755 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -3378,35 +3378,37 @@ def test_error_enrichment_jvm_stacktrace(self): """select from_json( '{"d": "02-29"}', 'd date', map('dateFormat', 'MM-dd'))""" ).collect() - self.assertTrue("JVM stacktrace" in e.exception.message) - self.assertTrue("org.apache.spark.SparkUpgradeException:" in e.exception.message) + self.assertTrue("JVM stacktrace" in str(e.exception)) + self.assertTrue("org.apache.spark.SparkUpgradeException" in str(e.exception)) self.assertTrue( "at org.apache.spark.sql.errors.ExecutionErrors" - ".failToParseDateTimeInNewParserError" in e.exception.message + ".failToParseDateTimeInNewParserError" in str(e.exception) ) - self.assertTrue("Caused by: java.time.DateTimeException:" in e.exception.message) + self.assertTrue("Caused by: java.time.DateTimeException:" in str(e.exception)) def test_not_hitting_netty_header_limit(self): with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": True}): with self.assertRaises(AnalysisException): - self.spark.sql("select " + "test" * 10000).collect() + self.spark.sql("select " + "test" * 1).collect() def test_error_stack_trace(self): with self.sql_conf({"spark.sql.connect.enrichError.enabled": False}): with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": True}): with self.assertRaises(AnalysisException) as e: self.spark.sql("select x").collect() - self.assertTrue("JVM stacktrace" in e.exception.message) + self.assertTrue("JVM stacktrace" in str(e.exception)) + self.assertIsNotNone(e.exception.getStackTrace()) self.assertTrue( - "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in e.exception.message + "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in str(e.exception) ) with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": False}): with self.assertRaises(AnalysisException) as e: self.spark.sql("select x").collect() - self.assertFalse("JVM stacktrace" in e.exception.message) + self.assertFalse("JVM stacktrace" in str(e.exception)) + self.assertIsNone(e.exception.getStackTrace()) self.assertFalse( - "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in e.exception.message + "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in str(e.exception) ) # Create a new session with a different stack trace size. @@ -3421,9 +3423,10 @@ def test_error_stack_trace(self): spark.conf.set("spark.sql.pyspark.jvmStacktrace.enabled", True) with self.assertRaises(AnalysisException) as e: spark.sql("select x").collect() - self.assertTrue("JVM stacktrace" in e.exception.message) + self.assertTrue("JVM stacktrace" in str(e.exception)) + self.assertIsNotNone(e.exception.getStackTrace()) self.assertFalse( - "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in e.exception.message + "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in str(e.exception) ) spark.stop()