diff --git a/integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala b/integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala index 86bf567f5..59016d6bc 100644 --- a/integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala +++ b/integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala @@ -76,7 +76,7 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest { JobOperator(spark, query, dataSourceName, resultIndex, true, streamingRunningCount) job.envinromentProvider = new MockEnvironment( Map("SERVERLESS_EMR_JOB_ID" -> jobRunId, "SERVERLESS_EMR_VIRTUAL_CLUSTER_ID" -> appId)) - + job.terminateJVM = false job.start() } futureResult.onComplete { diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala index 1814a8d8e..ccd5c8f3f 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -34,6 +34,8 @@ trait FlintJobExecutor { var threadPoolFactory: ThreadPoolFactory = new DefaultThreadPoolFactory() var envinromentProvider: EnvironmentProvider = new RealEnvironment() var enableHiveSupport: Boolean = true + // terminate the JVM in the presence of non-daemon threads before exiting + var terminateJVM = true // The enabled setting, which can be applied only to the top-level mapping definition and to object fields, val resultIndexMapping = diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 78314a68b..76e5f692c 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -57,8 +57,6 @@ object FlintREPL extends Logging with FlintJobExecutor { val 
EARLY_TERMIANTION_CHECK_FREQUENCY = 60000L @volatile var earlyExitFlag: Boolean = false - // termiante JVM in the presence non-deamon thread before exiting - var terminateJVM = true def updateSessionIndex(flintCommand: FlintCommand, updater: OpenSearchUpdater): Unit = { updater.update(flintCommand.statementId, FlintCommand.serialize(flintCommand)) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala index bbaceb15d..4fb272938 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala @@ -19,7 +19,7 @@ import org.opensearch.flint.core.storage.OpenSearchUpdater import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.sql.FlintJob.createSparkSession -import org.apache.spark.sql.FlintREPL.{executeQuery, logInfo, updateFlintInstanceBeforeShutdown} +import org.apache.spark.sql.FlintREPL.{executeQuery, logInfo, threadPoolFactory, updateFlintInstanceBeforeShutdown} import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.util.ThreadUtils @@ -106,6 +106,18 @@ case class JobOperator( case e: Exception => logError("Fail to close threadpool", e) } recordStreamingCompletionStatus(exceptionThrown) + + // Check for non-daemon threads that may prevent the driver from shutting down. + // Non-daemon threads other than the main thread indicate that the driver is still processing tasks, + // which may be due to unresolved bugs in dependencies or threads not being properly shut down. + if (terminateJVM && threadPoolFactory.hasNonDaemonThreadsOtherThanMain) { + logInfo("A non-daemon thread in the driver is seen.") + // Exit the JVM to prevent resource leaks and a potential emr-s job hang. + // A zero status code is used for a graceful shutdown without indicating an error. 
+ // If exiting with non-zero status, emr-s job will fail. + // This is part of the fault tolerance mechanism to handle such scenarios gracefully. + System.exit(0) + } } def stop(): Unit = {