diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/exception/UnrecoverableException.scala b/flint-commons/src/main/scala/org/apache/spark/sql/exception/UnrecoverableException.scala new file mode 100644 index 000000000..c23178f00 --- /dev/null +++ b/flint-commons/src/main/scala/org/apache/spark/sql/exception/UnrecoverableException.scala @@ -0,0 +1,25 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql.exception + +/** + * Represents an unrecoverable exception in session management and statement execution. This + * exception is used for errors that cannot be handled or recovered from. + */ +class UnrecoverableException private (message: String, cause: Throwable) + extends RuntimeException(message, cause) { + + def this(cause: Throwable) = + this(cause.getMessage, cause) +} + +object UnrecoverableException { + def apply(cause: Throwable): UnrecoverableException = + new UnrecoverableException(cause) + + def apply(message: String, cause: Throwable): UnrecoverableException = + new UnrecoverableException(message, cause) +} diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala index 8e037a53e..63c120a2c 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -22,6 +22,7 @@ import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkConfConstants.{DEFAULT_SQL_EXTENSIONS, SQL_EXTENSIONS_KEY} import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.exception.UnrecoverableException import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.flint.config.FlintSparkConf.REFRESH_POLICY import org.apache.spark.sql.types._ @@ -44,12 +45,13 @@ trait FlintJobExecutor { this: Logging => val mapper = new ObjectMapper() + val throwableHandler = new ThrowableHandler() var currentTimeProvider: TimeProvider = new RealTimeProvider() var threadPoolFactory: ThreadPoolFactory = new DefaultThreadPoolFactory() var environmentProvider: EnvironmentProvider = new RealEnvironment() var enableHiveSupport: Boolean = true - // termiante JVM in the presence non-deamon thread before exiting + // terminate JVM in the presence non-daemon thread before exiting var terminateJVM = true // The enabled setting, which can be applied only to the top-level mapping definition and to object fields, @@ -435,11 +437,13 @@ trait FlintJobExecutor { } private def handleQueryException( - e: Exception, + t: Throwable, messagePrefix: String, errorSource: Option[String] = None, statusCode: Option[Int] = None): String = { - val errorMessage = s"$messagePrefix: ${e.getMessage}" + throwableHandler.setThrowable(t) + + val errorMessage = s"$messagePrefix: ${t.getMessage}" val errorDetails = new java.util.LinkedHashMap[String, String]() errorDetails.put("Message", errorMessage) errorSource.foreach(es => errorDetails.put("ErrorSource", es)) @@ -450,25 +454,25 @@ trait FlintJobExecutor { // CustomLogging will call log4j logger.error() underneath statusCode match { case Some(code) => - CustomLogging.logError(new OperationMessage(errorMessage, code), e) + CustomLogging.logError(new OperationMessage(errorMessage, code), t) case None => - CustomLogging.logError(errorMessage, e) + CustomLogging.logError(errorMessage, t) } errorJson } - def getRootCause(e: Throwable): Throwable = { - if (e.getCause == null) e - else getRootCause(e.getCause) + def getRootCause(t: Throwable): Throwable = { + if (t.getCause == null) t + else getRootCause(t.getCause) } /** * This method converts query exception into error string, which then persist to query result * metadata */ - def processQueryException(ex: Exception): String = { - getRootCause(ex) match { + def processQueryException(throwable: Throwable): String = { + getRootCause(throwable) match { case r: ParseException => handleQueryException(r, ExceptionMessages.SyntaxErrorPrefix) case r: AmazonS3Exception => @@ -495,15 +499,15 @@ trait FlintJobExecutor { handleQueryException(r, ExceptionMessages.QueryAnalysisErrorPrefix) case r: SparkException => handleQueryException(r, ExceptionMessages.SparkExceptionErrorPrefix) - case r: Exception => - val rootCauseClassName = r.getClass.getName - val errMsg = r.getMessage + case t: Throwable => + val rootCauseClassName = t.getClass.getName + val errMsg = t.getMessage if (rootCauseClassName == "org.apache.hadoop.hive.metastore.api.MetaException" && errMsg.contains("com.amazonaws.services.glue.model.AccessDeniedException")) { val e = new SecurityException(ExceptionMessages.GlueAccessDeniedMessage) handleQueryException(e, ExceptionMessages.QueryRunErrorPrefix) } else { - handleQueryException(r, ExceptionMessages.QueryRunErrorPrefix) + handleQueryException(t, ExceptionMessages.QueryRunErrorPrefix) } } } @@ -532,6 +536,14 @@ trait FlintJobExecutor { throw t } + def checkAndThrowUnrecoverableExceptions(): Unit = { + throwableHandler.exceptionThrown.foreach { + case e: UnrecoverableException => + throw e + case _ => // Do nothing for other types of exceptions + } + } + def instantiate[T](defaultConstructor: => T, className: String, args: Any*): T = { if (Strings.isNullOrEmpty(className)) { defaultConstructor @@ -551,5 +563,4 @@ trait FlintJobExecutor { } } } - } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 9b6ff4ff6..3b283d23a 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -59,6 +59,10 @@ object FlintREPL extends Logging with FlintJobExecutor { private val statementRunningCount = new AtomicInteger(0) private var queryCountMetric = 0 + // After handling any exceptions from stopping the Spark session, + // check if there's a stored exception and throw it if it's an UnrecoverableException + sys.addShutdownHook(checkAndThrowUnrecoverableExceptions()) + def main(args: Array[String]) { val (queryOption, resultIndexOption) = parseArgs(args) @@ -187,9 +191,9 @@ object FlintREPL extends Logging with FlintJobExecutor { } recordSessionSuccess(sessionTimerContext) } catch { - case e: Exception => + case t: Throwable => handleSessionError( - e, + t, applicationId, jobId, sessionId, @@ -420,15 +424,14 @@ object FlintREPL extends Logging with FlintJobExecutor { } def handleSessionError( - e: Exception, + t: Throwable, applicationId: String, jobId: String, sessionId: String, sessionManager: SessionManager, jobStartTime: Long, sessionTimerContext: Timer.Context): Unit = { - val error = s"Session error: ${e.getMessage}" - CustomLogging.logError(error, e) + throwableHandler.handleException(s"Session error: ${t.getMessage}", t) refreshSessionState( applicationId, @@ -437,7 +440,7 @@ object FlintREPL extends Logging with FlintJobExecutor { sessionManager, jobStartTime, SessionStates.FAIL, - Some(e.getMessage)) + Some(t.getMessage)) recordSessionFailed(sessionTimerContext) } @@ -485,8 +488,8 @@ object FlintREPL extends Logging with FlintJobExecutor { startTime) } - def processQueryException(ex: Exception, flintStatement: FlintStatement): String = { - val error = super.processQueryException(ex) + def processQueryException(t: Throwable, flintStatement: FlintStatement): String = { + val error = super.processQueryException(t) flintStatement.fail() flintStatement.error = Some(error) error @@ -581,11 +584,13 @@ object FlintREPL extends Logging with FlintJobExecutor { } catch { // e.g., maybe due to authentication service connection issue // or invalid catalog (e.g., we are operating on data not defined in provided data source) - case e: Exception => - val error = s"""Fail to write result of ${flintStatement}, cause: ${e.getMessage}""" - CustomLogging.logError(error, e) + case e: Throwable => + throwableHandler.handleException( + s"""Fail to write result of ${flintStatement}, cause: ${e.getMessage}""", + e) flintStatement.fail() } finally { + logInfo(s"command complete: $flintStatement") statementExecutionManager.updateStatement(flintStatement) recordStatementStateChange(flintStatement, statementTimerContext) } @@ -671,8 +676,8 @@ object FlintREPL extends Logging with FlintJobExecutor { flintStatement, sessionId, startTime)) - case e: Exception => - val error = processQueryException(e, flintStatement) + case t: Throwable => + val error = processQueryException(t, flintStatement) Some( handleCommandFailureAndGetFailedData( applicationId, @@ -747,7 +752,7 @@ object FlintREPL extends Logging with FlintJobExecutor { startTime)) case NonFatal(e) => val error = s"An unexpected error occurred: ${e.getMessage}" - CustomLogging.logError(error, e) + throwableHandler.handleException(error, e) dataToWrite = Some( handleCommandFailureAndGetFailedData( applicationId, @@ -786,7 +791,6 @@ object FlintREPL extends Logging with FlintJobExecutor { queryWaitTimeMillis) } - logInfo(s"command complete: $flintStatement") (dataToWrite, verificationResult) } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala index 8582d3037..c38d31fc5 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala @@ -82,9 +82,6 @@ case class JobOperator( LangType.SQL, currentTimeProvider.currentEpochMillis()) - var exceptionThrown = true - var error: String = null - try { val futurePrepareQueryExecution = Future { statementExecutionManager.prepareStatementExecution() @@ -94,7 +91,7 @@ case class JobOperator( ThreadUtils.awaitResult(futurePrepareQueryExecution, Duration(1, MINUTES)) match { case Right(_) => data case Left(err) => - error = err + throwableHandler.setError(err) constructErrorDF( applicationId, jobId, @@ -107,11 +104,9 @@ case class JobOperator( "", startTime) }) - exceptionThrown = false } catch { case e: TimeoutException => - error = s"Preparation for query execution timed out" - logError(error, e) + throwableHandler.handleException(s"Preparation for query execution timed out", e) dataToWrite = Some( constructErrorDF( applicationId, @@ -119,13 +114,13 @@ case class JobOperator( sparkSession, dataSource, "TIMEOUT", - error, + throwableHandler.error, queryId, query, "", startTime)) - case e: Exception => - val error = processQueryException(e) + case t: Throwable => + val error = processQueryException(t) dataToWrite = Some( constructErrorDF( applicationId, @@ -146,27 +141,24 @@ case class JobOperator( try { dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient)) } catch { - case e: Exception => - exceptionThrown = true - error = s"Failed to write to result index. originalError='${error}'" - logError(error, e) + case t: Throwable => + throwableHandler.handleException( + s"Failed to write to result index. originalError='${throwableHandler.error}'", + t) } - if (exceptionThrown) statement.fail() else statement.complete() - statement.error = Some(error) + if (throwableHandler.hasException) statement.fail() else statement.complete() + statement.error = Some(throwableHandler.error) statementExecutionManager.updateStatement(statement) - cleanUpResources(exceptionThrown, threadPool, startTime) + cleanUpResources(threadPool) } } - def cleanUpResources( - exceptionThrown: Boolean, - threadPool: ThreadPoolExecutor, - startTime: Long): Unit = { + def cleanUpResources(threadPool: ThreadPoolExecutor): Unit = { val isStreaming = jobType.equalsIgnoreCase(FlintJobType.STREAMING) try { // Wait for streaming job complete if no error - if (!exceptionThrown && isStreaming) { + if (!throwableHandler.hasException && isStreaming) { // Clean Spark shuffle data after each microBatch. sparkSession.streams.addListener(new ShuffleCleaner(sparkSession)) // Await index monitor before the main thread terminates @@ -174,7 +166,7 @@ case class JobOperator( } else { logInfo(s""" | Skip streaming job await due to conditions not met: - | - exceptionThrown: $exceptionThrown + | - exceptionThrown: ${throwableHandler.hasException} | - streaming: $isStreaming | - activeStreams: ${sparkSession.streams.active.mkString(",")} |""".stripMargin) @@ -190,7 +182,7 @@ case class JobOperator( } catch { case e: Exception => logError("Fail to close threadpool", e) } - recordStreamingCompletionStatus(exceptionThrown) + recordStreamingCompletionStatus(throwableHandler.hasException) // Check for non-daemon threads that may prevent the driver from shutting down. // Non-daemon threads other than the main thread indicate that the driver is still processing tasks, @@ -219,8 +211,13 @@ case class JobOperator( logInfo("Stopped Spark session") } match { case Success(_) => - case Failure(e) => logError("unexpected error while stopping spark session", e) + case Failure(e) => + throwableHandler.handleException("unexpected error while stopping spark session", e) } + + // After handling any exceptions from stopping the Spark session, + // check if there's a stored exception and throw it if it's an UnrecoverableException + checkAndThrowUnrecoverableExceptions() } /** diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/util/ThrowableHandler.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/ThrowableHandler.scala new file mode 100644 index 000000000..d0ffd3782 --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/ThrowableHandler.scala @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql.util + +import org.opensearch.flint.core.logging.CustomLogging + +/** + * Handles and manages exceptions and error messages during each emr job run. Provides methods to + * set, retrieve, and reset exception information. + */ +class ThrowableHandler { + private var _throwableOption: Option[Throwable] = None + private var _error: String = _ + + def exceptionThrown: Option[Throwable] = _throwableOption + def error: String = _error + + def handleException(err: String, t: Throwable): Unit = { + _error = err + _throwableOption = Some(t) + CustomLogging.logError(err, t) + } + + def setError(err: String): Unit = { + _error = err + } + + def setThrowable(t: Throwable): Unit = { + _throwableOption = Some(t) + } + + def reset(): Unit = { + _throwableOption = None + _error = null + } + + def hasException: Boolean = _throwableOption.isDefined +}