From ed4e6bb9e2be313a04bf2f8fb5616b11e1c50501 Mon Sep 17 00:00:00 2001 From: Kaituo Li Date: Tue, 21 Nov 2023 10:28:36 -0800 Subject: [PATCH] Fix shutdown bug due to non-daemon thread in driver Resolve an issue where a non-daemon thread, potentially created by a bug in dependencies, prevents the driver from shutting down properly. The fix ensures the JVM exits gracefully, avoiding resource leaks and preventing hanging EMR-s jobs. Tests: - Reproduced the bug to confirm the issue. - Applied the fix and verified that the driver now shuts down as expected. Signed-off-by: Kaituo Li --- .../apache/spark/sql/FlintJobExecutor.scala | 6 +- .../org/apache/spark/sql/FlintREPL.scala | 56 +++++++++++++---- .../spark/sql/util/ThreadPoolFactory.scala | 60 ++++++++++++++++++- 3 files changed, 107 insertions(+), 15 deletions(-) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala index 903bcaa09..a44e70401 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -6,7 +6,6 @@ package org.apache.spark.sql import java.util.Locale -import java.util.concurrent.ThreadPoolExecutor import scala.concurrent.{ExecutionContext, Future, TimeoutException} import scala.concurrent.duration.{Duration, MINUTES} @@ -313,6 +312,11 @@ trait FlintJobExecutor { case e: IllegalStateException if e.getCause().getMessage().contains("index_not_found_exception") => createIndex(osClient, resultIndex, resultIndexMapping) + case e: InterruptedException => + val error = s"Interrupted by the main thread: ${e.getMessage}" + Thread.currentThread().interrupt() // Preserve the interrupt status + logError(error, e) + Left(error) case e: Exception => val error = s"Failed to verify existing mapping: ${e.getMessage}" logError(error, e) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 0f6c21786..fdde9463f 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -6,7 +6,9 @@ package org.apache.spark.sql import java.net.ConnectException -import java.util.concurrent.ScheduledExecutorService +import java.time.Instant +import java.util.Map +import java.util.concurrent.{ScheduledExecutorService, ScheduledFuture} import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future, TimeoutException} import scala.concurrent.duration.{Duration, MINUTES, _} @@ -105,14 +107,17 @@ object FlintREPL extends Logging with FlintJobExecutor { val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex.get) addShutdownHook(flintSessionIndexUpdater, osClient, sessionIndex.get, sessionId.get) + // 1 thread for updating heart beat - val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor("flint-repl-heartbeat", 1) - val jobStartTime = currentTimeProvider.currentEpochMillis() + val threadPool = + threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-heartbeat", 1) + val jobStartTime = currentTimeProvider.currentEpochMillis() + // update heart beat every 30 seconds + // OpenSearch triggers recovery after 1 minute outdated heart beat + var heartBeatFuture: ScheduledFuture[_] = null try { - // update heart beat every 30 seconds - // OpenSearch triggers recovery after 1 minute outdated heart beat - createHeartBeatUpdater( + heartBeatFuture = createHeartBeatUpdater( HEARTBEAT_INTERVAL_MILLIS, flintSessionIndexUpdater, sessionId.get, @@ -161,9 +166,23 @@ object FlintREPL extends Logging with FlintJobExecutor { osClient, sessionIndex.get) } finally { - spark.stop() if (threadPool != null) { - threadPool.shutdown() + heartBeatFuture.cancel(true) // Pass `true` to interrupt if running + threadPoolFactory.shutdownThreadPool(threadPool) + } + + spark.stop() + + // Check for non-daemon threads that may prevent the driver from shutting down. + // Non-daemon threads other than the main thread indicate that the driver is still processing tasks, + // which may be due to unresolved bugs in dependencies or threads not being properly shut down. + if (threadPoolFactory.hasNonDaemonThreadsOtherThanMain) { + logInfo("A non-daemon thread in the driver is seen.") + // Exit the JVM to prevent resource leaks and potential emr-s job hung. + // A zero status code is used for a graceful shutdown without indicating an error. + // If exiting with non-zero status, emr-s job will fail. + // This is a part of the fault tolerance mechanism to handle such scenarios gracefully. + System.exit(0) } } } @@ -263,8 +282,10 @@ object FlintREPL extends Logging with FlintJobExecutor { // 1 thread for updating heart beat val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-query", 1) implicit val executionContext = ExecutionContext.fromExecutor(threadPool) + + var futureMappingCheck: Future[Either[String, Unit]] = null try { - val futureMappingCheck = Future { + futureMappingCheck = Future { checkAndCreateIndex(commandContext.osClient, commandContext.resultIndex) } @@ -307,7 +328,7 @@ object FlintREPL extends Logging with FlintJobExecutor { } } finally { if (threadPool != null) { - threadPool.shutdown() + threadPoolFactory.shutdownThreadPool(threadPool) } } } @@ -475,7 +496,6 @@ object FlintREPL extends Logging with FlintJobExecutor { } else if (!flintReader.hasNext) { canProceed = false } else { - lastActivityTime = currentTimeProvider.currentEpochMillis() val flintCommand = processCommandInitiation(flintReader, flintSessionIndexUpdater) val (dataToWrite, returnedVerificationResult) = processStatementOnVerification( @@ -497,6 +517,8 @@ object FlintREPL extends Logging with FlintJobExecutor { resultIndex, flintSessionIndexUpdater, osClient) + // last query finish time is last activity time + lastActivityTime = currentTimeProvider.currentEpochMillis() } } @@ -850,12 +872,18 @@ object FlintREPL extends Logging with FlintJobExecutor { threadPool: ScheduledExecutorService, osClient: OSClient, sessionIndex: String, - initialDelayMillis: Long): Unit = { + initialDelayMillis: Long): ScheduledFuture[_] = { threadPool.scheduleAtFixedRate( new Runnable { override def run(): Unit = { try { + // Check the thread's interrupt status at the beginning of the run method + if (Thread.interrupted()) { + logWarning("HeartBeatUpdater has been interrupted. Terminating.") + return // Exit the run method if the thread is interrupted + } + val getResponse = osClient.getDoc(sessionIndex, sessionId) if (getResponse.isExists()) { val source = getResponse.getSourceAsMap @@ -871,6 +899,10 @@ object FlintREPL extends Logging with FlintJobExecutor { } // do nothing if the session doc does not exist } catch { + case ie: InterruptedException => + // Preserve the interrupt status + Thread.currentThread().interrupt() + logError("HeartBeatUpdater task was interrupted", ie) // maybe due to invalid sequence number or primary term case e: Exception => logWarning( diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/util/ThreadPoolFactory.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/ThreadPoolFactory.scala index 9c97dfe96..34a206e5a 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/util/ThreadPoolFactory.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/ThreadPoolFactory.scala @@ -5,10 +5,66 @@ package org.apache.spark.sql.util -import java.util.concurrent.ScheduledExecutorService +import java.util.concurrent.{ScheduledExecutorService, TimeUnit} -trait ThreadPoolFactory { +import scala.collection.JavaConverters._ + +import org.apache.spark.internal.Logging + +trait ThreadPoolFactory extends Logging { def newDaemonThreadPoolScheduledExecutor( threadNamePrefix: String, numThreads: Int): ScheduledExecutorService + + def shutdownThreadPool(executor: ScheduledExecutorService): Unit = { + logInfo(s"terminate executor ${executor}") + executor.shutdown() // Disable new tasks from being submitted + + try { + // Wait a while for existing tasks to terminate + if (!executor.awaitTermination(60, TimeUnit.SECONDS)) { + logWarning("Executor did not terminate in the specified time.") + val tasksNotExecuted = executor.shutdownNow() // Cancel currently executing tasks + // Log the tasks that were awaiting execution + logInfo(s"The following tasks were cancelled: $tasksNotExecuted") + + // Wait a while for tasks to respond to being cancelled + if (!executor.awaitTermination(60, TimeUnit.SECONDS)) { + logError("Thread pool did not terminate after shutdownNow.") + } + } + } catch { + case ie: InterruptedException => + // (Re-)Cancel if current thread also interrupted + executor.shutdownNow() + // Log the interrupted status + logError("Shutdown interrupted", ie) + // Preserve interrupt status + Thread.currentThread().interrupt() + } + } + + /** + * Checks if there are any non-daemon threads other than the "main" thread. + * + * @return + * true if non-daemon threads other than "main" are active, false otherwise. + */ + def hasNonDaemonThreadsOtherThanMain(): Boolean = { + // Log thread information and check for non-daemon threads + Thread.getAllStackTraces.keySet.asScala.exists { t => + val thread = t.asInstanceOf[Thread] + val isNonDaemon = !thread.isDaemon && thread.getName != "main" + + // Log the thread information + logInfo(s"Name: ${thread.getName}; IsDaemon: ${thread.isDaemon}; State: ${thread.getState}") + + // Log the stack trace + Option(Thread.getAllStackTraces.get(thread)).foreach(_.foreach(traceElement => + logInfo(s" at $traceElement"))) + + // Return true if a non-daemon thread is found + isNonDaemon + } + } }