Skip to content

Commit

Permalink
Fix shutdown bug due to non-daemon thread in driver
Browse files Browse the repository at this point in the history
Resolve an issue where a non-daemon thread, potentially created by a bug in dependencies, prevents the driver from shutting down properly. The fix ensures the JVM exits gracefully, avoiding resource leaks and preventing hanging EMR-s jobs.

Tests:
- Reproduced the bug to confirm the issue.
- Applied the fix and verified that the driver now shuts down as expected.

Signed-off-by: Kaituo Li <[email protected]>
  • Loading branch information
kaituo committed Nov 22, 2023
1 parent e7f4c73 commit ed4e6bb
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
package org.apache.spark.sql

import java.util.Locale
import java.util.concurrent.ThreadPoolExecutor

import scala.concurrent.{ExecutionContext, Future, TimeoutException}
import scala.concurrent.duration.{Duration, MINUTES}
Expand Down Expand Up @@ -313,6 +312,11 @@ trait FlintJobExecutor {
case e: IllegalStateException
if e.getCause().getMessage().contains("index_not_found_exception") =>
createIndex(osClient, resultIndex, resultIndexMapping)
case e: InterruptedException =>
val error = s"Interrupted by the main thread: ${e.getMessage}"
Thread.currentThread().interrupt() // Preserve the interrupt status
logError(error, e)
Left(error)
case e: Exception =>
val error = s"Failed to verify existing mapping: ${e.getMessage}"
logError(error, e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
package org.apache.spark.sql

import java.net.ConnectException
import java.util.concurrent.ScheduledExecutorService
import java.time.Instant
import java.util.Map
import java.util.concurrent.{ScheduledExecutorService, ScheduledFuture}

import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future, TimeoutException}
import scala.concurrent.duration.{Duration, MINUTES, _}
Expand Down Expand Up @@ -105,14 +107,17 @@ object FlintREPL extends Logging with FlintJobExecutor {

val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex.get)
addShutdownHook(flintSessionIndexUpdater, osClient, sessionIndex.get, sessionId.get)

// 1 thread for updating heart beat
val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor("flint-repl-heartbeat", 1)
val jobStartTime = currentTimeProvider.currentEpochMillis()
val threadPool =
threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-heartbeat", 1)

val jobStartTime = currentTimeProvider.currentEpochMillis()
// update heart beat every 30 seconds
// OpenSearch triggers recovery after 1 minute outdated heart beat
var heartBeatFuture: ScheduledFuture[_] = null
try {
// update heart beat every 30 seconds
// OpenSearch triggers recovery after 1 minute outdated heart beat
createHeartBeatUpdater(
heartBeatFuture = createHeartBeatUpdater(
HEARTBEAT_INTERVAL_MILLIS,
flintSessionIndexUpdater,
sessionId.get,
Expand Down Expand Up @@ -161,9 +166,23 @@ object FlintREPL extends Logging with FlintJobExecutor {
osClient,
sessionIndex.get)
} finally {
spark.stop()
if (threadPool != null) {
threadPool.shutdown()
heartBeatFuture.cancel(true) // Pass `true` to interrupt if running
threadPoolFactory.shutdownThreadPool(threadPool)
}

spark.stop()

// Check for non-daemon threads that may prevent the driver from shutting down.
// Non-daemon threads other than the main thread indicate that the driver is still processing tasks,
// which may be due to unresolved bugs in dependencies or threads not being properly shut down.
if (threadPoolFactory.hasNonDaemonThreadsOtherThanMain) {
logInfo("A non-daemon thread in the driver is seen.")
// Exit the JVM to prevent resource leaks and potential emr-s job hung.
// A zero status code is used for a graceful shutdown without indicating an error.
// If exiting with non-zero status, emr-s job will fail.
// This is a part of the fault tolerance mechanism to handle such scenarios gracefully.
System.exit(0)
}
}
}
Expand Down Expand Up @@ -263,8 +282,10 @@ object FlintREPL extends Logging with FlintJobExecutor {
// 1 thread for updating heart beat
val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-query", 1)
implicit val executionContext = ExecutionContext.fromExecutor(threadPool)

var futureMappingCheck: Future[Either[String, Unit]] = null
try {
val futureMappingCheck = Future {
futureMappingCheck = Future {
checkAndCreateIndex(commandContext.osClient, commandContext.resultIndex)
}

Expand Down Expand Up @@ -307,7 +328,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
}
} finally {
if (threadPool != null) {
threadPool.shutdown()
threadPoolFactory.shutdownThreadPool(threadPool)
}
}
}
Expand Down Expand Up @@ -475,7 +496,6 @@ object FlintREPL extends Logging with FlintJobExecutor {
} else if (!flintReader.hasNext) {
canProceed = false
} else {
lastActivityTime = currentTimeProvider.currentEpochMillis()
val flintCommand = processCommandInitiation(flintReader, flintSessionIndexUpdater)

val (dataToWrite, returnedVerificationResult) = processStatementOnVerification(
Expand All @@ -497,6 +517,8 @@ object FlintREPL extends Logging with FlintJobExecutor {
resultIndex,
flintSessionIndexUpdater,
osClient)
// last query finish time is last activity time
lastActivityTime = currentTimeProvider.currentEpochMillis()
}
}

Expand Down Expand Up @@ -850,12 +872,18 @@ object FlintREPL extends Logging with FlintJobExecutor {
threadPool: ScheduledExecutorService,
osClient: OSClient,
sessionIndex: String,
initialDelayMillis: Long): Unit = {
initialDelayMillis: Long): ScheduledFuture[_] = {

threadPool.scheduleAtFixedRate(
new Runnable {
override def run(): Unit = {
try {
// Check the thread's interrupt status at the beginning of the run method
if (Thread.interrupted()) {
logWarning("HeartBeatUpdater has been interrupted. Terminating.")
return // Exit the run method if the thread is interrupted
}

val getResponse = osClient.getDoc(sessionIndex, sessionId)
if (getResponse.isExists()) {
val source = getResponse.getSourceAsMap
Expand All @@ -871,6 +899,10 @@ object FlintREPL extends Logging with FlintJobExecutor {
}
// do nothing if the session doc does not exist
} catch {
case ie: InterruptedException =>
// Preserve the interrupt status
Thread.currentThread().interrupt()
logError("HeartBeatUpdater task was interrupted", ie)
// maybe due to invalid sequence number or primary term
case e: Exception =>
logWarning(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,66 @@

package org.apache.spark.sql.util

import java.util.concurrent.ScheduledExecutorService
import java.util.concurrent.{ScheduledExecutorService, TimeUnit}

trait ThreadPoolFactory {
import scala.collection.JavaConverters._

import org.apache.spark.internal.Logging

trait ThreadPoolFactory extends Logging {
def newDaemonThreadPoolScheduledExecutor(
threadNamePrefix: String,
numThreads: Int): ScheduledExecutorService

def shutdownThreadPool(executor: ScheduledExecutorService): Unit = {
logInfo(s"terminate executor ${executor}")
executor.shutdown() // Disable new tasks from being submitted

try {
// Wait a while for existing tasks to terminate
if (!executor.awaitTermination(60, TimeUnit.SECONDS)) {
logWarning("Executor did not terminate in the specified time.")
val tasksNotExecuted = executor.shutdownNow() // Cancel currently executing tasks
// Log the tasks that were awaiting execution
logInfo(s"The following tasks were cancelled: $tasksNotExecuted")

// Wait a while for tasks to respond to being cancelled
if (!executor.awaitTermination(60, TimeUnit.SECONDS)) {
logError("Thread pool did not terminate after shutdownNow.")
}
}
} catch {
case ie: InterruptedException =>
// (Re-)Cancel if current thread also interrupted
executor.shutdownNow()
// Log the interrupted status
logError("Shutdown interrupted", ie)
// Preserve interrupt status
Thread.currentThread().interrupt()
}
}

/**
* Checks if there are any non-daemon threads other than the "main" thread.
*
* @return
* true if non-daemon threads other than "main" are active, false otherwise.
*/
def hasNonDaemonThreadsOtherThanMain(): Boolean = {
// Log thread information and check for non-daemon threads
Thread.getAllStackTraces.keySet.asScala.exists { t =>
val thread = t.asInstanceOf[Thread]
val isNonDaemon = !thread.isDaemon && thread.getName != "main"

// Log the thread information
logInfo(s"Name: ${thread.getName}; IsDaemon: ${thread.isDaemon}; State: ${thread.getState}")

// Log the stack trace
Option(Thread.getAllStackTraces.get(thread)).foreach(_.foreach(traceElement =>
logInfo(s" at $traceElement")))

// Return true if a non-daemon thread is found
isNonDaemon
}
}
}

0 comments on commit ed4e6bb

Please sign in to comment.