diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/exception/UnrecoverableException.scala b/flint-commons/src/main/scala/org/apache/spark/sql/exception/UnrecoverableException.scala new file mode 100644 index 000000000..c23178f00 --- /dev/null +++ b/flint-commons/src/main/scala/org/apache/spark/sql/exception/UnrecoverableException.scala @@ -0,0 +1,25 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql.exception + +/** + * Represents an unrecoverable exception in session management and statement execution. It marks + * errors that cannot be handled or recovered from and must propagate so the job fails rather + * than report success. + */ +class UnrecoverableException private (message: String, cause: Throwable) + extends RuntimeException(message, cause) { + + def this(cause: Throwable) = + this(cause.getMessage, cause) +} + +object UnrecoverableException { + def apply(cause: Throwable): UnrecoverableException = + new UnrecoverableException(cause) + + def apply(message: String, cause: Throwable): UnrecoverableException = + new UnrecoverableException(message, cause) +} diff --git a/flint-commons/src/main/scala/org/opensearch/flint/common/model/InteractiveSession.scala b/flint-commons/src/main/scala/org/opensearch/flint/common/model/InteractiveSession.scala index 915d5e229..16c9747d9 100644 --- a/flint-commons/src/main/scala/org/opensearch/flint/common/model/InteractiveSession.scala +++ b/flint-commons/src/main/scala/org/opensearch/flint/common/model/InteractiveSession.scala @@ -52,7 +52,7 @@ class InteractiveSession( val lastUpdateTime: Long, val jobStartTime: Long = 0, val excludedJobIds: Seq[String] = Seq.empty[String], - val error: Option[String] = None, + var error: Option[String] = None, sessionContext: Map[String, Any] = Map.empty[String, Any]) extends ContextualDataStore with Logging { @@ -72,7 +72,7 @@ class InteractiveSession( val excludedJobIdsStr = excludedJobIds.mkString("[", ", ", "]") val errorStr = error.getOrElse("None") // Does not include context, which could contain sensitive information.
- s"FlintInstance(applicationId=$applicationId, jobId=$jobId, sessionId=$sessionId, state=$state, " + + s"InteractiveSession(applicationId=$applicationId, jobId=$jobId, sessionId=$sessionId, state=$state, " + s"lastUpdateTime=$lastUpdateTime, jobStartTime=$jobStartTime, excludedJobIds=$excludedJobIdsStr, error=$errorStr)" } } diff --git a/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala b/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala index 1ddfa540b..51bcf8e40 100644 --- a/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala +++ b/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala @@ -17,15 +17,49 @@ import org.opensearch.OpenSearchStatusException import org.opensearch.flint.OpenSearchSuite import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession} import org.opensearch.flint.core.{FlintClient, FlintOptions} -import org.opensearch.flint.core.storage.{FlintOpenSearchClient, FlintReader, OpenSearchUpdater} -import org.opensearch.search.sort.SortOrder +import org.opensearch.flint.core.storage.{FlintOpenSearchClient, OpenSearchUpdater} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.FlintREPLConfConstants.DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY -import org.apache.spark.sql.flint.config.FlintSparkConf.{DATA_SOURCE_NAME, EXCLUDE_JOB_IDS, HOST_ENDPOINT, HOST_PORT, JOB_TYPE, REFRESH_POLICY, REPL_INACTIVITY_TIMEOUT_MILLIS, REQUEST_INDEX, SESSION_ID} +import org.apache.spark.sql.exception.UnrecoverableException +import org.apache.spark.sql.flint.config.FlintSparkConf.{CUSTOM_STATEMENT_MANAGER, DATA_SOURCE_NAME, EXCLUDE_JOB_IDS, HOST_ENDPOINT, HOST_PORT, JOB_TYPE, REFRESH_POLICY, REPL_INACTIVITY_TIMEOUT_MILLIS, REQUEST_INDEX, SESSION_ID} import org.apache.spark.sql.util.MockEnvironment import org.apache.spark.util.ThreadUtils +/** + * A StatementExecutionManagerImpl that throws UnrecoverableException during statement execution. + * Used for testing error handling in FlintREPL. 
+ */ +class FailingStatementExecutionManager( + private var spark: SparkSession, + private var sessionId: String) + extends StatementExecutionManager { + + def this() = { + this(null, null) + } + + override def prepareStatementExecution(): Either[String, Unit] = { + throw UnrecoverableException(new RuntimeException("Simulated execution failure")) + } + + override def executeStatement(statement: FlintStatement): DataFrame = { + throw UnrecoverableException(new RuntimeException("Simulated execution failure")) + } + + override def getNextStatement(): Option[FlintStatement] = { + throw UnrecoverableException(new RuntimeException("Simulated execution failure")) + } + + override def updateStatement(statement: FlintStatement): Unit = { + throw UnrecoverableException(new RuntimeException("Simulated execution failure")) + } + + override def terminateStatementExecution(): Unit = { + throw UnrecoverableException(new RuntimeException("Simulated execution failure")) + } +} + class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { var flintClient: FlintClient = _ @@ -584,6 +618,27 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { } } + test("REPL should handle unrecoverable exception from statement execution") { + // Note: This test shares a system property with other test cases, so it cannot run alone + System.setProperty( + CUSTOM_STATEMENT_MANAGER.key, + "org.apache.spark.sql.FailingStatementExecutionManager") + try { + createSession(jobRunId, "") + FlintREPL.main(Array(resultIndex)) + fail("The REPL should throw an unrecoverable exception, but it succeeded instead.") + } catch { + case ex: UnrecoverableException => + assert( + ex.getMessage.contains("Simulated execution failure"), + s"Unexpected exception message: ${ex.getMessage}") + case ex: Throwable => + fail(s"Unexpected exception type: ${ex.getClass} with message: ${ex.getMessage}") + } finally { + System.setProperty(CUSTOM_STATEMENT_MANAGER.key, "") + } + } + /** * JSON does not support raw newlines (\n) in string values. All newlines must be escaped or * removed when inside a JSON string.
The same goes for tab characters, which should be diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala index 8e037a53e..63c120a2c 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -22,6 +22,7 @@ import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkConfConstants.{DEFAULT_SQL_EXTENSIONS, SQL_EXTENSIONS_KEY} import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.exception.UnrecoverableException import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.flint.config.FlintSparkConf.REFRESH_POLICY import org.apache.spark.sql.types._ @@ -44,12 +45,13 @@ trait FlintJobExecutor { this: Logging => val mapper = new ObjectMapper() + val throwableHandler = new ThrowableHandler() var currentTimeProvider: TimeProvider = new RealTimeProvider() var threadPoolFactory: ThreadPoolFactory = new DefaultThreadPoolFactory() var environmentProvider: EnvironmentProvider = new RealEnvironment() var enableHiveSupport: Boolean = true - // termiante JVM in the presence non-deamon thread before exiting + // terminate JVM in the presence of a non-daemon thread before exiting var terminateJVM = true // The enabled setting, which can be applied only to the top-level mapping definition and to object fields, @@ -435,11 +437,13 @@ trait FlintJobExecutor { } private def handleQueryException( - e: Exception, + t: Throwable, messagePrefix: String, errorSource: Option[String] = None, statusCode: Option[Int] = None): String = { - val errorMessage = s"$messagePrefix: ${e.getMessage}" + throwableHandler.setThrowable(t) + + val errorMessage = s"$messagePrefix: ${t.getMessage}" val errorDetails = new java.util.LinkedHashMap[String, String]() errorDetails.put("Message", errorMessage) errorSource.foreach(es => errorDetails.put("ErrorSource", es)) @@ -450,25 +454,25 @@ trait FlintJobExecutor { // CustomLogging will call log4j logger.error() underneath statusCode match { case Some(code) => - CustomLogging.logError(new OperationMessage(errorMessage, code), e) + CustomLogging.logError(new OperationMessage(errorMessage, code), t) case None => - CustomLogging.logError(errorMessage, e) + CustomLogging.logError(errorMessage, t) } errorJson } - def getRootCause(e: Throwable): Throwable = { - if (e.getCause == null) e - else getRootCause(e.getCause) + def getRootCause(t: Throwable): Throwable = { + if (t.getCause == null) t + else getRootCause(t.getCause) } /** * This method converts query exception into error string, which then persist to query result * metadata */ - def processQueryException(ex: Exception): String = { - getRootCause(ex) match { + def processQueryException(throwable: Throwable): String = { + getRootCause(throwable) match { case r: ParseException => handleQueryException(r, ExceptionMessages.SyntaxErrorPrefix) case r: AmazonS3Exception => @@ -495,15 +499,15 @@ trait FlintJobExecutor { handleQueryException(r, ExceptionMessages.QueryAnalysisErrorPrefix) case r: SparkException => handleQueryException(r, ExceptionMessages.SparkExceptionErrorPrefix) - case r: Exception => - val rootCauseClassName = r.getClass.getName - val errMsg = r.getMessage + case t: Throwable => + val rootCauseClassName = t.getClass.getName + val errMsg = t.getMessage if
(rootCauseClassName == "org.apache.hadoop.hive.metastore.api.MetaException" && errMsg.contains("com.amazonaws.services.glue.model.AccessDeniedException")) { val e = new SecurityException(ExceptionMessages.GlueAccessDeniedMessage) handleQueryException(e, ExceptionMessages.QueryRunErrorPrefix) } else { - handleQueryException(r, ExceptionMessages.QueryRunErrorPrefix) + handleQueryException(t, ExceptionMessages.QueryRunErrorPrefix) } } } @@ -532,6 +536,14 @@ trait FlintJobExecutor { throw t } + def checkAndThrowUnrecoverableExceptions(): Unit = { + throwableHandler.exceptionThrown.foreach { + case e: UnrecoverableException => + throw e + case _ => // Do nothing for other types of exceptions + } + } + def instantiate[T](defaultConstructor: => T, className: String, args: Any*): T = { if (Strings.isNullOrEmpty(className)) { defaultConstructor @@ -551,5 +563,4 @@ trait FlintJobExecutor { } } } - } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 9b6ff4ff6..6d7dcc0e7 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -187,9 +187,9 @@ object FlintREPL extends Logging with FlintJobExecutor { } recordSessionSuccess(sessionTimerContext) } catch { - case e: Exception => + case t: Throwable => handleSessionError( - e, + t, applicationId, jobId, sessionId, @@ -204,6 +204,10 @@ object FlintREPL extends Logging with FlintJobExecutor { stopTimer(sessionTimerContext) spark.stop() + // After handling any exceptions from stopping the Spark session, + // check if there's a stored exception and throw it if it's an UnrecoverableException + checkAndThrowUnrecoverableExceptions() + // Check for non-daemon threads that may prevent the driver from shutting down. // Non-daemon threads other than the main thread indicate that the driver is still processing tasks, // which may be due to unresolved bugs in dependencies or threads not being properly shut down. 
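Taken together, the changes above define a propagation contract: the query loop records any Throwable in the shared throwableHandler rather than letting it escape silently, the session is still marked FAIL and spark.stop() still runs, and only afterwards does checkAndThrowUnrecoverableExceptions() rethrow a stored UnrecoverableException so the driver exits with a failure. A minimal self-contained sketch of that contract, using simplified stand-ins for the classes added in this diff (the real ThrowableHandler also logs each record via CustomLogging):

object UnrecoverableFlowSketch extends App {
  // Simplified stand-in for org.apache.spark.sql.exception.UnrecoverableException.
  final class UnrecoverableException(cause: Throwable)
      extends RuntimeException(cause.getMessage, cause)

  // Simplified stand-in for ThrowableHandler; the real one also logs each record.
  class ThrowableHandler {
    private var recorded: Option[Throwable] = None
    def recordThrowable(err: String, t: Throwable): Unit = recorded = Some(t)
    def exceptionThrown: Option[Throwable] = recorded
  }

  val handler = new ThrowableHandler

  // 1. Query loop: record the failure; cleanup (finally blocks) still runs.
  try throw new UnrecoverableException(new RuntimeException("Simulated execution failure"))
  catch { case t: Throwable => handler.recordThrowable("Query loop execution failed.", t) }

  // 2. ...session state is set to FAIL, spark.stop() completes...

  // 3. Rethrow only unrecoverable errors so the job reports failure, not success.
  def checkAndThrowUnrecoverableExceptions(): Unit =
    handler.exceptionThrown.foreach {
      case e: UnrecoverableException => throw e
      case _ => // recoverable: exit normally
    }

  try checkAndThrowUnrecoverableExceptions()
  catch { case e: UnrecoverableException => println(s"Driver fails with: ${e.getMessage}") }
}

Deferring the rethrow until after cleanup is what lets both FlintREPL and JobOperator release resources unconditionally while still surfacing unrecoverable failures to the caller.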
@@ -356,6 +360,11 @@ object FlintREPL extends Logging with FlintJobExecutor { verificationResult = updatedVerificationResult canPickUpNextStatement = updatedCanPickUpNextStatement lastCanPickCheckTime = updatedLastCanPickCheckTime + } catch { + case t: Throwable => + // Record and rethrow in query loop + throwableHandler.recordThrowable(s"Query loop execution failed.", t) + throw t } finally { statementsExecutionManager.terminateStatementExecution() } @@ -412,32 +421,40 @@ object FlintREPL extends Logging with FlintJobExecutor { error = error, excludedJobIds = excludedJobIds)) logInfo(s"Current session: ${sessionDetails}") - logInfo(s"State is: ${sessionDetails.state}") sessionDetails.state = state - logInfo(s"State is: ${sessionDetails.state}") + sessionDetails.error = error sessionManager.updateSessionDetails(sessionDetails, updateMode = UPSERT) + logInfo(s"Updated session: ${sessionDetails}") sessionDetails } def handleSessionError( - e: Exception, + t: Throwable, applicationId: String, jobId: String, sessionId: String, sessionManager: SessionManager, jobStartTime: Long, sessionTimerContext: Timer.Context): Unit = { - val error = s"Session error: ${e.getMessage}" - CustomLogging.logError(error, e) + val error = s"Session error: ${t.getMessage}" + throwableHandler.recordThrowable(error, t) + + try { + refreshSessionState( + applicationId, + jobId, + sessionId, + sessionManager, + jobStartTime, + SessionStates.FAIL, + Some(error)) + } catch { + case t: Throwable => + throwableHandler.recordThrowable( + s"Failed to update session state. Original error: $error", + t) + } - refreshSessionState( - applicationId, - jobId, - sessionId, - sessionManager, - jobStartTime, - SessionStates.FAIL, - Some(e.getMessage)) recordSessionFailed(sessionTimerContext) } @@ -485,8 +502,8 @@ object FlintREPL extends Logging with FlintJobExecutor { startTime) } - def processQueryException(ex: Exception, flintStatement: FlintStatement): String = { - val error = super.processQueryException(ex) + def processQueryException(t: Throwable, flintStatement: FlintStatement): String = { + val error = super.processQueryException(t) flintStatement.fail() flintStatement.error = Some(error) error @@ -581,11 +598,13 @@ object FlintREPL extends Logging with FlintJobExecutor { } catch { // e.g., maybe due to authentication service connection issue // or invalid catalog (e.g., we are operating on data not defined in provided data source) - case e: Exception => - val error = s"""Fail to write result of ${flintStatement}, cause: ${e.getMessage}""" - CustomLogging.logError(error, e) + case e: Throwable => + throwableHandler.recordThrowable( + s"""Fail to write result of ${flintStatement}, cause: ${e.getMessage}""", + e) flintStatement.fail() } finally { + logInfo(s"command complete: $flintStatement") statementExecutionManager.updateStatement(flintStatement) recordStatementStateChange(flintStatement, statementTimerContext) } @@ -671,8 +690,8 @@ object FlintREPL extends Logging with FlintJobExecutor { flintStatement, sessionId, startTime)) - case e: Exception => - val error = processQueryException(e, flintStatement) + case t: Throwable => + val error = processQueryException(t, flintStatement) Some( handleCommandFailureAndGetFailedData( applicationId, @@ -747,7 +766,7 @@ object FlintREPL extends Logging with FlintJobExecutor { startTime)) case NonFatal(e) => val error = s"An unexpected error occurred: ${e.getMessage}" - CustomLogging.logError(error, e) + throwableHandler.recordThrowable(error, e) dataToWrite = Some( 
handleCommandFailureAndGetFailedData( applicationId, @@ -786,7 +805,6 @@ object FlintREPL extends Logging with FlintJobExecutor { queryWaitTimeMillis) } - logInfo(s"command complete: $flintStatement") (dataToWrite, verificationResult) } @@ -858,7 +876,8 @@ object FlintREPL extends Logging with FlintJobExecutor { } } } catch { - case e: Exception => logError(s"Failed to update session state for $sessionId", e) + case t: Throwable => + throwableHandler.recordThrowable(s"Failed to update session state for $sessionId", t) } } } @@ -897,10 +916,10 @@ object FlintREPL extends Logging with FlintJobExecutor { MetricConstants.REQUEST_METADATA_HEARTBEAT_FAILED_METRIC ) // Record heartbeat failure metric // maybe due to invalid sequence number or primary term - case e: Exception => - CustomLogging.logWarning( + case t: Throwable => + throwableHandler.recordThrowable( s"""Fail to update the last update time of the flint instance ${sessionId}""", - e) + t) incrementCounter( MetricConstants.REQUEST_METADATA_HEARTBEAT_FAILED_METRIC ) // Record heartbeat failure metric @@ -948,8 +967,10 @@ object FlintREPL extends Logging with FlintJobExecutor { } } catch { // still proceed since we are not sure what happened (e.g., OpenSearch cluster may be unresponsive) - case e: Exception => - CustomLogging.logError(s"""Fail to find id ${sessionId} from session index.""", e) + case t: Throwable => + throwableHandler.recordThrowable( + s"""Fail to find id ${sessionId} from session index.""", + t) true } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala index 8582d3037..27b0be84f 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala @@ -82,9 +82,6 @@ case class JobOperator( LangType.SQL, currentTimeProvider.currentEpochMillis()) - var exceptionThrown = true - var error: String = null - try { val futurePrepareQueryExecution = Future { statementExecutionManager.prepareStatementExecution() @@ -94,7 +91,7 @@ case class JobOperator( ThreadUtils.awaitResult(futurePrepareQueryExecution, Duration(1, MINUTES)) match { case Right(_) => data case Left(err) => - error = err + throwableHandler.setError(err) constructErrorDF( applicationId, jobId, @@ -107,11 +104,9 @@ case class JobOperator( "", startTime) }) - exceptionThrown = false } catch { case e: TimeoutException => - error = s"Preparation for query execution timed out" - logError(error, e) + throwableHandler.recordThrowable(s"Preparation for query execution timed out", e) dataToWrite = Some( constructErrorDF( applicationId, @@ -119,13 +114,13 @@ case class JobOperator( sparkSession, dataSource, "TIMEOUT", - error, + throwableHandler.error, queryId, query, "", startTime)) - case e: Exception => - val error = processQueryException(e) + case t: Throwable => + val error = processQueryException(t) dataToWrite = Some( constructErrorDF( applicationId, @@ -146,27 +141,32 @@ case class JobOperator( try { dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient)) } catch { - case e: Exception => - exceptionThrown = true - error = s"Failed to write to result index. originalError='${error}'" - logError(error, e) + case t: Throwable => + throwableHandler.recordThrowable( + s"Failed to write to result index. 
originalError='${throwableHandler.error}'", + t) } - if (exceptionThrown) statement.fail() else statement.complete() - statement.error = Some(error) - statementExecutionManager.updateStatement(statement) + if (throwableHandler.hasException) statement.fail() else statement.complete() + statement.error = Some(throwableHandler.error) - cleanUpResources(exceptionThrown, threadPool, startTime) + try { + statementExecutionManager.updateStatement(statement) + } catch { + case t: Throwable => + throwableHandler.recordThrowable( + s"Failed to update statement. originalError='${throwableHandler.error}'", + t) + } + + cleanUpResources(threadPool) } } - def cleanUpResources( - exceptionThrown: Boolean, - threadPool: ThreadPoolExecutor, - startTime: Long): Unit = { + def cleanUpResources(threadPool: ThreadPoolExecutor): Unit = { val isStreaming = jobType.equalsIgnoreCase(FlintJobType.STREAMING) try { // Wait for streaming job complete if no error - if (!exceptionThrown && isStreaming) { + if (!throwableHandler.hasException && isStreaming) { // Clean Spark shuffle data after each microBatch. sparkSession.streams.addListener(new ShuffleCleaner(sparkSession)) // Await index monitor before the main thread terminates @@ -174,7 +174,7 @@ case class JobOperator( } else { logInfo(s""" | Skip streaming job await due to conditions not met: - | - exceptionThrown: $exceptionThrown + | - exceptionThrown: ${throwableHandler.hasException} | - streaming: $isStreaming | - activeStreams: ${sparkSession.streams.active.mkString(",")} |""".stripMargin) @@ -190,7 +190,7 @@ case class JobOperator( } catch { case e: Exception => logError("Fail to close threadpool", e) } - recordStreamingCompletionStatus(exceptionThrown) + recordStreamingCompletionStatus(throwableHandler.hasException) // Check for non-daemon threads that may prevent the driver from shutting down. // Non-daemon threads other than the main thread indicate that the driver is still processing tasks, @@ -219,8 +219,13 @@ case class JobOperator( logInfo("Stopped Spark session") } match { case Success(_) => - case Failure(e) => logError("unexpected error while stopping spark session", e) + case Failure(e) => + throwableHandler.recordThrowable("unexpected error while stopping spark session", e) } + + // After handling any exceptions from stopping the Spark session, + // check if there's a stored exception and throw it if it's an UnrecoverableException + checkAndThrowUnrecoverableExceptions() } /** diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/util/ThrowableHandler.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/ThrowableHandler.scala new file mode 100644 index 000000000..01c90bdd4 --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/ThrowableHandler.scala @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql.util + +import org.opensearch.flint.core.logging.CustomLogging + +/** + * Handles and manages exceptions and error messages during each EMR job run. Provides methods to + * set, retrieve, and reset exception information.
+ */ +class ThrowableHandler { + private var _throwableOption: Option[Throwable] = None + private var _error: String = _ + + def exceptionThrown: Option[Throwable] = _throwableOption + def error: String = _error + + def recordThrowable(err: String, t: Throwable): Unit = { + _error = err + _throwableOption = Some(t) + CustomLogging.logError(err, t) + } + + def setError(err: String): Unit = { + _error = err + } + + def setThrowable(t: Throwable): Unit = { + _throwableOption = Some(t) + } + + def reset(): Unit = { + _throwableOption = None + _error = null + } + + def hasException: Boolean = _throwableOption.isDefined +} diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index 07ed94bdc..7edb0d4c3 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -33,9 +33,11 @@ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.scheduler.SparkListenerApplicationEnd import org.apache.spark.sql.FlintREPL.PreShutdownListener import org.apache.spark.sql.FlintREPLConfConstants.DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY +import org.apache.spark.sql.SessionUpdateMode.SessionUpdateMode import org.apache.spark.sql.SparkConfConstants.{DEFAULT_SQL_EXTENSIONS, SQL_EXTENSIONS_KEY} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.exception.UnrecoverableException import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.types.{LongType, NullType, StringType, StructField, StructType} import org.apache.spark.sql.util.{DefaultThreadPoolFactory, MockThreadPoolFactory, MockTimeProvider, RealTimeProvider, ShutdownHookManagerTrait} @@ -195,19 +197,44 @@ class FlintREPLTest scheduledFutureRaw }) - // Invoke the method FlintREPL.createHeartBeatUpdater(sessionId, sessionManager, threadPool) - // Verifications verify(sessionManager, atLeastOnce()).recordHeartbeat(sessionId) + FlintREPL.throwableHandler.hasException shouldBe false } - test("PreShutdownListener updates FlintInstance if conditions are met") { + test("createHeartBeatUpdater should handle unrecoverable exception") { + val threadPool = mock[ScheduledExecutorService] + val scheduledFutureRaw = mock[ScheduledFuture[_]] + val sessionManager = mock[SessionManager] + val sessionId = "session1" + + FlintREPL.throwableHandler.reset() + val unrecoverableException = + UnrecoverableException(new RuntimeException("Unrecoverable error")) + when(sessionManager.recordHeartbeat(sessionId)) + .thenThrow(unrecoverableException) + + when(threadPool + .scheduleAtFixedRate(any[Runnable], *, *, eqTo(java.util.concurrent.TimeUnit.MILLISECONDS))) + .thenAnswer((invocation: InvocationOnMock) => { + val runnable = invocation.getArgument[Runnable](0) + runnable.run() + scheduledFutureRaw + }) + + FlintREPL.createHeartBeatUpdater(sessionId, sessionManager, threadPool) + + FlintREPL.throwableHandler.exceptionThrown shouldBe Some(unrecoverableException) + } + + test("PreShutdownListener updates InteractiveSession if conditions are met") { // Mock dependencies val sessionId = "testSessionId" val timerContext = mock[Timer.Context] val sessionManager = mock[SessionManager] + FlintREPL.throwableHandler.reset() val interactiveSession = new InteractiveSession( "app123", "job123", @@ -227,6 +254,28 @@ class 
FlintREPLTest interactiveSession.state shouldBe SessionStates.DEAD } + test("PreShutdownListener handles unrecoverable exception from sessionManager") { + val sessionId = "testSessionId" + val timerContext = mock[Timer.Context] + val sessionManager = mock[SessionManager] + + FlintREPL.throwableHandler.reset() + val unrecoverableException = + UnrecoverableException(new RuntimeException("Unrecoverable database error")) + when(sessionManager.getSessionDetails(sessionId)) + .thenThrow(unrecoverableException) + + val listener = new PreShutdownListener(sessionId, sessionManager, timerContext) + + listener.onApplicationEnd(SparkListenerApplicationEnd(System.currentTimeMillis())) + + FlintREPL.throwableHandler.exceptionThrown shouldBe Some(unrecoverableException) + FlintREPL.throwableHandler.error shouldBe s"Failed to update session state for $sessionId" + + verify(sessionManager, never()) + .updateSessionDetails(any[InteractiveSession], any[SessionUpdateMode]) + } + test("Test super.constructErrorDF should construct dataframe properly") { // Define expected dataframe val dataSourceName = "myGlueS3" @@ -463,6 +512,29 @@ class FlintREPLTest assert(result) } + test("test canPickNextStatement: sessionManager throws unrecoverableException") { + val sessionId = "session123" + val jobId = "jobABC" + val sessionIndex = "sessionIndex" + val mockSparkSession = mock[SparkSession] + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) + + FlintREPL.throwableHandler.reset() + val sessionManager = mock[SessionManager] + val unrecoverableException = + UnrecoverableException(new RuntimeException("OpenSearch cluster unresponsive")) + when(sessionManager.getSessionDetails(sessionId)) + .thenThrow(unrecoverableException) + + val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId) + + assert(result) + FlintREPL.throwableHandler.exceptionThrown shouldBe Some(unrecoverableException) + } + test( "test canPickNextStatement: Doc Exists and excludeJobIds is a Single String Not Matching JobId") { val sessionId = "session123" @@ -521,6 +593,7 @@ class FlintREPLTest verify(mockFlintStatement).error = Some(expectedError) assert(result == expectedError) + FlintREPL.throwableHandler.exceptionThrown shouldBe Some(exception) } test("processQueryException should handle MetaException with AccessDeniedException properly") { @@ -665,8 +738,6 @@ class FlintREPLTest override val osClient: OSClient = mockOSClient } - val queryResultWriter = mock[QueryResultWriter] - val commandContext = CommandContext( applicationId, jobId, @@ -1026,6 +1097,87 @@ class FlintREPLTest assert(!result) // Expecting false as the job proceeds normally } + test("handleSessionError handles unrecoverable exception") { + val sessionManager = mock[SessionManager] + val timerContext = mock[Timer.Context] + val applicationId = "app123" + val jobId = "job123" + val sessionId = "session123" + val jobStartTime = System.currentTimeMillis() + + FlintREPL.throwableHandler.reset() + val unrecoverableException = + UnrecoverableException(new RuntimeException("Unrecoverable error")) + val interactiveSession = new InteractiveSession( + applicationId, + jobId, + sessionId, + SessionStates.RUNNING, + System.currentTimeMillis(), + System.currentTimeMillis() - 10000) + when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) + + FlintREPL.handleSessionError( + unrecoverableException, + 
applicationId, + jobId, + sessionId, + sessionManager, + jobStartTime, + timerContext) + + FlintREPL.throwableHandler.exceptionThrown shouldBe Some(unrecoverableException) + + verify(sessionManager).updateSessionDetails( + argThat { (session: InteractiveSession) => + session.applicationId == applicationId && + session.jobId == jobId && + session.sessionId == sessionId && + session.state == SessionStates.FAIL && + session.error.contains(s"Session error: ${unrecoverableException.getMessage}") + }, + any[SessionUpdateMode]) + + verify(timerContext).stop() + } + + test("handleSessionError handles exception during refreshSessionState") { + val sessionManager = mock[SessionManager] + val timerContext = mock[Timer.Context] + val applicationId = "app123" + val jobId = "job123" + val sessionId = "session123" + val jobStartTime = System.currentTimeMillis() + + FlintREPL.throwableHandler.reset() + val initialException = UnrecoverableException(new RuntimeException("Unrecoverable error")) + val refreshException = + UnrecoverableException(new RuntimeException("Failed to refresh session state")) + + val interactiveSession = new InteractiveSession( + applicationId, + jobId, + sessionId, + SessionStates.RUNNING, + System.currentTimeMillis(), + System.currentTimeMillis() - 10000) + when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) + when(sessionManager.updateSessionDetails(any[InteractiveSession], any[SessionUpdateMode])) + .thenThrow(refreshException) + + FlintREPL.handleSessionError( + initialException, + applicationId, + jobId, + sessionId, + sessionManager, + jobStartTime, + timerContext) + + FlintREPL.throwableHandler.exceptionThrown shouldBe Some(refreshException) + verify(timerContext).stop() + } + test("queryLoop continue until inactivity limit is reached") { val resultIndex = "testResultIndex" val dataSource = "testDataSource" @@ -1064,7 +1216,6 @@ class FlintREPLTest val sessionManager = new SessionManagerImpl(spark, Some(resultIndex)) { override val osClient: OSClient = mockOSClient } - val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( applicationId, @@ -1133,7 +1284,6 @@ class FlintREPLTest val sessionManager = new SessionManagerImpl(spark, Some(resultIndex)) { override val osClient: OSClient = mockOSClient } - val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( applicationId, @@ -1198,7 +1348,6 @@ class FlintREPLTest val sessionManager = new SessionManagerImpl(spark, Some(resultIndex)) { override val osClient: OSClient = mockOSClient } - val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( applicationId, @@ -1255,7 +1404,8 @@ class FlintREPLTest mockOSClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) .thenReturn(mockReader) // Simulate an exception thrown when hasNext is called - when(mockReader.hasNext).thenThrow(new RuntimeException("Test exception")) + val unrecoverableException = UnrecoverableException(new RuntimeException("Test exception")) + when(mockReader.hasNext).thenThrow(unrecoverableException) when(mockOSClient.doesIndexExist(*)).thenReturn(true) when(mockOSClient.getIndexMetadata(*)).thenReturn(FlintREPL.resultIndexMapping) @@ -1268,7 +1418,6 @@ class FlintREPLTest val sessionManager = new SessionManagerImpl(spark, Some(resultIndex)) { override val osClient: OSClient = mockOSClient } - val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( applicationId, @@ -1287,13 +1436,15 @@ 
class FlintREPLTest // Mocking ThreadUtils to track the shutdown call val mockThreadPool = mock[ScheduledExecutorService] FlintREPL.threadPoolFactory = new MockThreadPoolFactory(mockThreadPool) + FlintREPL.throwableHandler.reset() - intercept[RuntimeException] { + intercept[UnrecoverableException] { FlintREPL.queryLoop(commandContext) } // Verify if the shutdown method was called on the thread pool verify(mockThreadPool).shutdown() + FlintREPL.throwableHandler.exceptionThrown shouldBe Some(unrecoverableException) } finally { // Stop the SparkSession spark.stop() @@ -1436,7 +1587,6 @@ class FlintREPLTest val sessionManager = new SessionManagerImpl(spark, Some(resultIndex)) { override val osClient: OSClient = mockOSClient } - val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( applicationId,