Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 0.6] Support UnrecoverableException #916

Merged
merged 1 commit into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql.exception

/**
* Represents an unrecoverable exception in session management and statement execution. This
* exception is used for errors that cannot be handled or recovered from.
*/
/**
 * Signals a fatal failure in session management or statement execution that the
 * caller is not expected to handle or retry.
 *
 * The primary constructor is private; build instances either through the
 * auxiliary constructor below or via the companion object's factory methods.
 *
 * @param message human-readable description of the failure
 * @param cause   the underlying throwable that triggered this exception
 */
class UnrecoverableException private (message: String, cause: Throwable)
    extends RuntimeException(message, cause) {

  /** Wraps `cause`, reusing the cause's own message as this exception's message. */
  def this(cause: Throwable) = this(cause.getMessage, cause)
}

/** Factory methods for [[UnrecoverableException]]. */
object UnrecoverableException {

  /** Wraps `cause`, reusing its message as the exception message. */
  def apply(cause: Throwable): UnrecoverableException = new UnrecoverableException(cause)

  /** Wraps `cause` under an explicit `message`. */
  def apply(message: String, cause: Throwable): UnrecoverableException =
    new UnrecoverableException(message, cause)
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class InteractiveSession(
val lastUpdateTime: Long,
val jobStartTime: Long = 0,
val excludedJobIds: Seq[String] = Seq.empty[String],
val error: Option[String] = None,
var error: Option[String] = None,
sessionContext: Map[String, Any] = Map.empty[String, Any])
extends ContextualDataStore
with Logging {
Expand All @@ -72,7 +72,7 @@ class InteractiveSession(
val excludedJobIdsStr = excludedJobIds.mkString("[", ", ", "]")
val errorStr = error.getOrElse("None")
// Does not include context, which could contain sensitive information.
s"FlintInstance(applicationId=$applicationId, jobId=$jobId, sessionId=$sessionId, state=$state, " +
s"InteractiveSession(applicationId=$applicationId, jobId=$jobId, sessionId=$sessionId, state=$state, " +
s"lastUpdateTime=$lastUpdateTime, jobStartTime=$jobStartTime, excludedJobIds=$excludedJobIdsStr, error=$errorStr)"
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,49 @@ import org.opensearch.OpenSearchStatusException
import org.opensearch.flint.OpenSearchSuite
import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession}
import org.opensearch.flint.core.{FlintClient, FlintOptions}
import org.opensearch.flint.core.storage.{FlintOpenSearchClient, FlintReader, OpenSearchUpdater}
import org.opensearch.search.sort.SortOrder
import org.opensearch.flint.core.storage.{FlintOpenSearchClient, OpenSearchUpdater}

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.FlintREPLConfConstants.DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY
import org.apache.spark.sql.flint.config.FlintSparkConf.{DATA_SOURCE_NAME, EXCLUDE_JOB_IDS, HOST_ENDPOINT, HOST_PORT, JOB_TYPE, REFRESH_POLICY, REPL_INACTIVITY_TIMEOUT_MILLIS, REQUEST_INDEX, SESSION_ID}
import org.apache.spark.sql.exception.UnrecoverableException
import org.apache.spark.sql.flint.config.FlintSparkConf.{CUSTOM_STATEMENT_MANAGER, DATA_SOURCE_NAME, EXCLUDE_JOB_IDS, HOST_ENDPOINT, HOST_PORT, JOB_TYPE, REFRESH_POLICY, REPL_INACTIVITY_TIMEOUT_MILLIS, REQUEST_INDEX, SESSION_ID}
import org.apache.spark.sql.util.MockEnvironment
import org.apache.spark.util.ThreadUtils

/**
* A StatementExecutionManagerImpl that throws UnrecoverableException during statement execution.
* Used for testing error handling in FlintREPL.
*/
/**
 * A StatementExecutionManager test double whose every operation throws
 * [[UnrecoverableException]]. Used to exercise FlintREPL's error handling.
 *
 * The no-arg auxiliary constructor exists so the class can be instantiated
 * reflectively via the CUSTOM_STATEMENT_MANAGER system property.
 */
class FailingStatementExecutionManager(
    private var spark: SparkSession,
    private var sessionId: String)
    extends StatementExecutionManager {

  def this() = this(null, null)

  // Single shared failure path: every overridden operation delegates here so the
  // simulated error message stays identical across all entry points.
  private def simulateFailure(): Nothing =
    throw UnrecoverableException(new RuntimeException("Simulated execution failure"))

  override def prepareStatementExecution(): Either[String, Unit] = simulateFailure()

  override def executeStatement(statement: FlintStatement): DataFrame = simulateFailure()

  override def getNextStatement(): Option[FlintStatement] = simulateFailure()

  override def updateStatement(statement: FlintStatement): Unit = simulateFailure()

  override def terminateStatementExecution(): Unit = simulateFailure()
}

class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest {

var flintClient: FlintClient = _
Expand Down Expand Up @@ -584,6 +618,27 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest {
}
}

test("REPL should handle unrecoverable exception from statement execution") {
  // Note: This test shares a system property with other test cases so it cannot run alone.
  System.setProperty(
    CUSTOM_STATEMENT_MANAGER.key,
    "org.apache.spark.sql.FailingStatementExecutionManager")
  // Capture the outcome and assert AFTER the try/catch. Calling fail() inside
  // the try block would throw TestFailedException, which the broad
  // `case t: Throwable` handler would then intercept and re-report with a
  // misleading "Unexpected exception type" message.
  var observed: Option[Throwable] = None
  try {
    createSession(jobRunId, "")
    FlintREPL.main(Array(resultIndex))
  } catch {
    case t: Throwable => observed = Some(t)
  } finally {
    // Reset (not remove) the shared property so sibling tests see a benign value.
    System.setProperty(CUSTOM_STATEMENT_MANAGER.key, "")
  }
  observed match {
    case Some(ex: UnrecoverableException) =>
      assert(
        ex.getMessage.contains("Simulated execution failure"),
        s"Unexpected exception message: ${ex.getMessage}")
    case Some(ex) =>
      fail(s"Unexpected exception type: ${ex.getClass} with message: ${ex.getMessage}")
    case None =>
      fail("The REPL should throw an unrecoverable exception, but it succeeded instead.")
  }
}

/**
* JSON does not support raw newlines (\n) in string values. All newlines must be escaped or
* removed when inside a JSON string. The same goes for tab characters, which should be
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkConfConstants.{DEFAULT_SQL_EXTENSIONS, SQL_EXTENSIONS_KEY}
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.exception.UnrecoverableException
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.flint.config.FlintSparkConf.REFRESH_POLICY
import org.apache.spark.sql.types._
Expand All @@ -44,12 +45,13 @@ trait FlintJobExecutor {
this: Logging =>

val mapper = new ObjectMapper()
val throwableHandler = new ThrowableHandler()

var currentTimeProvider: TimeProvider = new RealTimeProvider()
var threadPoolFactory: ThreadPoolFactory = new DefaultThreadPoolFactory()
var environmentProvider: EnvironmentProvider = new RealEnvironment()
var enableHiveSupport: Boolean = true
// termiante JVM in the presence non-deamon thread before exiting
// terminate JVM in the presence non-daemon thread before exiting
var terminateJVM = true

// The enabled setting, which can be applied only to the top-level mapping definition and to object fields,
Expand Down Expand Up @@ -435,11 +437,13 @@ trait FlintJobExecutor {
}

private def handleQueryException(
e: Exception,
t: Throwable,
messagePrefix: String,
errorSource: Option[String] = None,
statusCode: Option[Int] = None): String = {
val errorMessage = s"$messagePrefix: ${e.getMessage}"
throwableHandler.setThrowable(t)

val errorMessage = s"$messagePrefix: ${t.getMessage}"
val errorDetails = new java.util.LinkedHashMap[String, String]()
errorDetails.put("Message", errorMessage)
errorSource.foreach(es => errorDetails.put("ErrorSource", es))
Expand All @@ -450,25 +454,25 @@ trait FlintJobExecutor {
// CustomLogging will call log4j logger.error() underneath
statusCode match {
case Some(code) =>
CustomLogging.logError(new OperationMessage(errorMessage, code), e)
CustomLogging.logError(new OperationMessage(errorMessage, code), t)
case None =>
CustomLogging.logError(errorMessage, e)
CustomLogging.logError(errorMessage, t)
}

errorJson
}

def getRootCause(e: Throwable): Throwable = {
if (e.getCause == null) e
else getRootCause(e.getCause)
def getRootCause(t: Throwable): Throwable = {
if (t.getCause == null) t
else getRootCause(t.getCause)
}

/**
* This method converts query exception into error string, which then persist to query result
* metadata
*/
def processQueryException(ex: Exception): String = {
getRootCause(ex) match {
def processQueryException(throwable: Throwable): String = {
getRootCause(throwable) match {
case r: ParseException =>
handleQueryException(r, ExceptionMessages.SyntaxErrorPrefix)
case r: AmazonS3Exception =>
Expand All @@ -495,15 +499,15 @@ trait FlintJobExecutor {
handleQueryException(r, ExceptionMessages.QueryAnalysisErrorPrefix)
case r: SparkException =>
handleQueryException(r, ExceptionMessages.SparkExceptionErrorPrefix)
case r: Exception =>
val rootCauseClassName = r.getClass.getName
val errMsg = r.getMessage
case t: Throwable =>
val rootCauseClassName = t.getClass.getName
val errMsg = t.getMessage
if (rootCauseClassName == "org.apache.hadoop.hive.metastore.api.MetaException" &&
errMsg.contains("com.amazonaws.services.glue.model.AccessDeniedException")) {
val e = new SecurityException(ExceptionMessages.GlueAccessDeniedMessage)
handleQueryException(e, ExceptionMessages.QueryRunErrorPrefix)
} else {
handleQueryException(r, ExceptionMessages.QueryRunErrorPrefix)
handleQueryException(t, ExceptionMessages.QueryRunErrorPrefix)
}
}
}
Expand Down Expand Up @@ -532,6 +536,14 @@ trait FlintJobExecutor {
throw t
}

/**
 * Re-throws the recorded throwable if (and only if) it is an
 * [[UnrecoverableException]]; any other recorded throwable is left alone.
 */
def checkAndThrowUnrecoverableExceptions(): Unit = {
  throwableHandler.exceptionThrown.foreach { recorded =>
    recorded match {
      case fatal: UnrecoverableException => throw fatal
      case _ => // Recoverable (or unclassified) throwables are intentionally ignored here.
    }
  }
}

def instantiate[T](defaultConstructor: => T, className: String, args: Any*): T = {
if (Strings.isNullOrEmpty(className)) {
defaultConstructor
Expand All @@ -551,5 +563,4 @@ trait FlintJobExecutor {
}
}
}

}
Loading
Loading