Support UnrecoverableException
noCharger committed Nov 13, 2024
1 parent dd9c0cf commit 8faba92
Showing 4 changed files with 86 additions and 30 deletions.
UnrecoverableException.scala (new file)
@@ -0,0 +1,21 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql.exception

class UnrecoverableException private (message: String, cause: Throwable)
    extends RuntimeException(message, cause) {

  def this(cause: Throwable) =
    this(cause.getMessage, cause)
}

object UnrecoverableException {
  def apply(cause: Throwable): UnrecoverableException =
    new UnrecoverableException(cause)

  def apply(message: String, cause: Throwable): UnrecoverableException =
    new UnrecoverableException(message, cause)
}
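For context, a minimal usage sketch (the call site below is hypothetical and not part of this commit): wrapping a fatal cause in UnrecoverableException lets downstream handlers distinguish failures that must abort the job from ones that should only be recorded.

import org.apache.spark.sql.exception.UnrecoverableException

object UnrecoverableExceptionSketch {
  // Hypothetical call site: wrap a failure the job cannot recover from so the
  // cleanup path can rethrow it instead of swallowing it.
  def refreshCredentials(refresh: () => Unit): Unit = {
    try {
      refresh() // the actual refresh logic is supplied by the caller
    } catch {
      case e: SecurityException =>
        throw UnrecoverableException("credential refresh failed", e)
    }
  }
}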
FlintJobExecutor.scala
@@ -44,12 +44,13 @@ trait FlintJobExecutor {
  this: Logging =>

  val mapper = new ObjectMapper()
+  val exceptionHandler = new ExceptionHandler()

  var currentTimeProvider: TimeProvider = new RealTimeProvider()
  var threadPoolFactory: ThreadPoolFactory = new DefaultThreadPoolFactory()
  var environmentProvider: EnvironmentProvider = new RealEnvironment()
  var enableHiveSupport: Boolean = true
-  // termiante JVM in the presence non-deamon thread before exiting
+  // terminate JVM in the presence non-daemon thread before exiting
  var terminateJVM = true

  // The enabled setting, which can be applied only to the top-level mapping definition and to object fields,
@@ -435,7 +436,7 @@ trait FlintJobExecutor {
  }

  private def handleQueryException(
-      e: Exception,
+      e: Throwable,
      messagePrefix: String,
      errorSource: Option[String] = None,
      statusCode: Option[Int] = None): String = {
@@ -467,7 +468,7 @@
   * This method converts query exception into error string, which then persist to query result
   * metadata
   */
-  def processQueryException(ex: Exception): String = {
+  def processQueryException(ex: Throwable): String = {
    getRootCause(ex) match {
      case r: ParseException =>
        handleQueryException(r, ExceptionMessages.SyntaxErrorPrefix)
@@ -495,7 +496,7 @@
        handleQueryException(r, ExceptionMessages.QueryAnalysisErrorPrefix)
      case r: SparkException =>
        handleQueryException(r, ExceptionMessages.SparkExceptionErrorPrefix)
-      case r: Exception =>
+      case r: Throwable =>
        val rootCauseClassName = r.getClass.getName
        val errMsg = r.getMessage
        if (rootCauseClassName == "org.apache.hadoop.hive.metastore.api.MetaException" &&
@@ -551,5 +552,4 @@
      }
    }
  }
-
}
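Widening these signatures from Exception to Throwable means that JVM errors (for example OutOfMemoryError or NoClassDefFoundError), which are not Exceptions, now flow through the same root-cause classification instead of bypassing error reporting. A minimal stand-alone sketch of the idea, not the repository implementation:

object ThrowableWideningSketch {
  // Stand-in for the widened handler: it accepts any Throwable, not just Exception.
  def describeFailure(ex: Throwable): String =
    s"Fail to run query. Cause: ${ex.getClass.getName}: ${ex.getMessage}"

  def main(args: Array[String]): Unit = {
    // An Error is not an Exception; with the old signature it would have escaped
    // this reporting path entirely.
    println(describeFailure(new OutOfMemoryError("Java heap space")))
  }
}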
JobOperator.scala
@@ -19,8 +19,9 @@ import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter
import org.opensearch.flint.spark.FlintSpark

import org.apache.spark.internal.Logging
+import org.apache.spark.sql.exception.UnrecoverableException
import org.apache.spark.sql.flint.config.FlintSparkConf
-import org.apache.spark.sql.util.ShuffleCleaner
+import org.apache.spark.sql.util.{ExceptionHandler, ShuffleCleaner}
import org.apache.spark.util.ThreadUtils

case class JobOperator(
Expand Down Expand Up @@ -82,9 +83,6 @@ case class JobOperator(
      LangType.SQL,
      currentTimeProvider.currentEpochMillis())

-    var exceptionThrown = true
-    var error: String = null
-
    try {
      val futurePrepareQueryExecution = Future {
        statementExecutionManager.prepareStatementExecution()
@@ -94,7 +92,7 @@
        ThreadUtils.awaitResult(futurePrepareQueryExecution, Duration(1, MINUTES)) match {
          case Right(_) => data
          case Left(err) =>
-            error = err
+            exceptionHandler.setError(err)
            constructErrorDF(
              applicationId,
              jobId,
@@ -107,24 +105,22 @@
"",
startTime)
})
exceptionThrown = false
} catch {
case e: TimeoutException =>
error = s"Preparation for query execution timed out"
logError(error, e)
exceptionHandler.handleException(s"Preparation for query execution timed out", e)
dataToWrite = Some(
constructErrorDF(
applicationId,
jobId,
sparkSession,
dataSource,
"TIMEOUT",
error,
exceptionHandler.error,
queryId,
query,
"",
startTime))
case e: Exception =>
case e: Throwable =>
val error = processQueryException(e)
dataToWrite = Some(
constructErrorDF(
@@ -146,35 +142,32 @@
      try {
        dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient))
      } catch {
-        case e: Exception =>
-          exceptionThrown = true
-          error = s"Failed to write to result index. originalError='${error}'"
-          logError(error, e)
+        case e: Throwable =>
+          exceptionHandler.handleException(
+            s"Failed to write to result index. originalError='${exceptionHandler.error}'",
+            e)
      }
-      if (exceptionThrown) statement.fail() else statement.complete()
-      statement.error = Some(error)
+      if (exceptionHandler.hasException) statement.fail() else statement.complete()
+      statement.error = Some(exceptionHandler.error)
      statementExecutionManager.updateStatement(statement)

-      cleanUpResources(exceptionThrown, threadPool, startTime)
+      cleanUpResources(threadPool)
    }
  }

-  def cleanUpResources(
-      exceptionThrown: Boolean,
-      threadPool: ThreadPoolExecutor,
-      startTime: Long): Unit = {
+  def cleanUpResources(threadPool: ThreadPoolExecutor): Unit = {
    val isStreaming = jobType.equalsIgnoreCase(FlintJobType.STREAMING)
    try {
      // Wait for streaming job complete if no error
-      if (!exceptionThrown && isStreaming) {
+      if (!exceptionHandler.hasException && isStreaming) {
        // Clean Spark shuffle data after each microBatch.
        sparkSession.streams.addListener(new ShuffleCleaner(sparkSession))
        // Await index monitor before the main thread terminates
        new FlintSpark(sparkSession).flintIndexMonitor.awaitMonitor()
      } else {
        logInfo(s"""
          | Skip streaming job await due to conditions not met:
-          | - exceptionThrown: $exceptionThrown
+          | - exceptionThrown: ${exceptionHandler.hasException}
          | - streaming: $isStreaming
          | - activeStreams: ${sparkSession.streams.active.mkString(",")}
          |""".stripMargin)
@@ ... @@
    } catch {
      case e: Exception => logError("Fail to close threadpool", e)
    }
-    recordStreamingCompletionStatus(exceptionThrown)
+    recordStreamingCompletionStatus(exceptionHandler.hasException)

    // Check for non-daemon threads that may prevent the driver from shutting down.
    // Non-daemon threads other than the main thread indicate that the driver is still processing tasks,
@@ -219,7 +212,16 @@
logInfo("Stopped Spark session")
} match {
case Success(_) =>
case Failure(e) => logError("unexpected error while stopping spark session", e)
case Failure(e) =>
exceptionHandler.handleException("unexpected error while stopping spark session", e)
}

// After handling any exceptions from stopping the Spark session,
// check if there's a stored exception and throw it if it's an UnrecoverableException
exceptionHandler.exceptionThrown.foreach {
case e: UnrecoverableException =>
throw e
case _ => // Do nothing for other types of exceptions
}
}
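Taken together, the JobOperator changes replace the local exceptionThrown/error variables with the shared exceptionHandler, and the final cleanup rethrows only unrecoverable failures. A hedged sketch of that end-of-run decision (simplified names, not the repository code):

import org.apache.spark.sql.exception.UnrecoverableException
import org.apache.spark.sql.util.ExceptionHandler

object EndOfRunSketch {
  // Simplified stand-in for the final cleanup decision: recoverable failures were
  // already recorded and persisted to the result index, so only an
  // UnrecoverableException is rethrown and allowed to fail the driver.
  def rethrowIfUnrecoverable(handler: ExceptionHandler): Unit = {
    handler.exceptionThrown.foreach {
      case e: UnrecoverableException => throw e
      case _ => // recoverable: nothing further to do here
    }
  }
}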

ExceptionHandler.scala (new file)
@@ -0,0 +1,33 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql.util

import org.apache.spark.internal.Logging

class ExceptionHandler extends Logging {
  private var _exceptionThrown: Option[Throwable] = None
  private var _error: String = _

  def exceptionThrown: Option[Throwable] = _exceptionThrown
  def error: String = _error

  def handleException(err: String, e: Throwable): Unit = {
    _error = err
    _exceptionThrown = Some(e)
    logError(err, e)
  }

  def setError(err: String): Unit = {
    _error = err
  }

  def reset(): Unit = {
    _exceptionThrown = None
    _error = null
  }

  def hasException: Boolean = _exceptionThrown.isDefined
}
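A short usage sketch for ExceptionHandler (illustrative only; riskyWrite is a hypothetical stand-in for the result-index write):

import org.apache.spark.sql.util.ExceptionHandler

object ExceptionHandlerSketch {
  // Hypothetical operation standing in for writeDataFrameToOpensearch.
  def riskyWrite(): Unit = throw new RuntimeException("result index unavailable")

  def main(args: Array[String]): Unit = {
    val handler = new ExceptionHandler()
    try riskyWrite()
    catch {
      case e: Throwable => handler.handleException("Failed to write to result index", e)
    }
    if (handler.hasException) {
      println(s"recorded error: ${handler.error}") // would normally feed statement.error
      handler.reset() // clear state if the handler is reused for another statement
    }
  }
}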
