Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix shutdown bug due to non-daemon thread in driver #175

Merged
merged 1 commit into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
package org.apache.spark.sql

import java.util.Locale
import java.util.concurrent.ThreadPoolExecutor

import scala.concurrent.{ExecutionContext, Future, TimeoutException}
import scala.concurrent.duration.{Duration, MINUTES}
Expand Down Expand Up @@ -313,6 +312,11 @@ trait FlintJobExecutor {
case e: IllegalStateException
if e.getCause().getMessage().contains("index_not_found_exception") =>
createIndex(osClient, resultIndex, resultIndexMapping)
case e: InterruptedException =>
val error = s"Interrupted by the main thread: ${e.getMessage}"
Thread.currentThread().interrupt() // Preserve the interrupt status
logError(error, e)
Left(error)
case e: Exception =>
val error = s"Failed to verify existing mapping: ${e.getMessage}"
logError(error, e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
package org.apache.spark.sql

import java.net.ConnectException
import java.util.concurrent.ScheduledExecutorService
import java.time.Instant
import java.util.Map
import java.util.concurrent.{ScheduledExecutorService, ScheduledFuture}

import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future, TimeoutException}
import scala.concurrent.duration.{Duration, MINUTES, _}
Expand Down Expand Up @@ -105,14 +107,17 @@ object FlintREPL extends Logging with FlintJobExecutor {

val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex.get)
addShutdownHook(flintSessionIndexUpdater, osClient, sessionIndex.get, sessionId.get)

// 1 thread for updating heart beat
val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor("flint-repl-heartbeat", 1)
val jobStartTime = currentTimeProvider.currentEpochMillis()
val threadPool =
threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-heartbeat", 1)

val jobStartTime = currentTimeProvider.currentEpochMillis()
// update heart beat every 30 seconds
// OpenSearch triggers recovery after 1 minute outdated heart beat
var heartBeatFuture: ScheduledFuture[_] = null
try {
// update heart beat every 30 seconds
// OpenSearch triggers recovery after 1 minute outdated heart beat
createHeartBeatUpdater(
heartBeatFuture = createHeartBeatUpdater(
HEARTBEAT_INTERVAL_MILLIS,
flintSessionIndexUpdater,
sessionId.get,
Expand Down Expand Up @@ -161,9 +166,23 @@ object FlintREPL extends Logging with FlintJobExecutor {
osClient,
sessionIndex.get)
} finally {
spark.stop()
if (threadPool != null) {
threadPool.shutdown()
heartBeatFuture.cancel(true) // Pass `true` to interrupt if running
threadPoolFactory.shutdownThreadPool(threadPool)
}

spark.stop()

// Check for non-daemon threads that may prevent the driver from shutting down.
// Non-daemon threads other than the main thread indicate that the driver is still processing tasks,
// which may be due to unresolved bugs in dependencies or threads not being properly shut down.
if (threadPoolFactory.hasNonDaemonThreadsOtherThanMain) {
logInfo("A non-daemon thread in the driver is seen.")
// Exit the JVM to prevent resource leaks and potential emr-s job hung.
// A zero status code is used for a graceful shutdown without indicating an error.
// If exiting with non-zero status, emr-s job will fail.
// This is a part of the fault tolerance mechanism to handle such scenarios gracefully.
System.exit(0)
}
}
}
Expand Down Expand Up @@ -263,8 +282,10 @@ object FlintREPL extends Logging with FlintJobExecutor {
// 1 thread for updating heart beat
val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-query", 1)
implicit val executionContext = ExecutionContext.fromExecutor(threadPool)

var futureMappingCheck: Future[Either[String, Unit]] = null
try {
val futureMappingCheck = Future {
futureMappingCheck = Future {
checkAndCreateIndex(commandContext.osClient, commandContext.resultIndex)
}

Expand Down Expand Up @@ -307,7 +328,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
}
} finally {
if (threadPool != null) {
threadPool.shutdown()
threadPoolFactory.shutdownThreadPool(threadPool)
}
}
}
Expand Down Expand Up @@ -475,7 +496,6 @@ object FlintREPL extends Logging with FlintJobExecutor {
} else if (!flintReader.hasNext) {
canProceed = false
} else {
lastActivityTime = currentTimeProvider.currentEpochMillis()
val flintCommand = processCommandInitiation(flintReader, flintSessionIndexUpdater)

val (dataToWrite, returnedVerificationResult) = processStatementOnVerification(
Expand All @@ -497,6 +517,8 @@ object FlintREPL extends Logging with FlintJobExecutor {
resultIndex,
flintSessionIndexUpdater,
osClient)
// last query finish time is last activity time
lastActivityTime = currentTimeProvider.currentEpochMillis()
}
}

Expand Down Expand Up @@ -850,12 +872,18 @@ object FlintREPL extends Logging with FlintJobExecutor {
threadPool: ScheduledExecutorService,
osClient: OSClient,
sessionIndex: String,
initialDelayMillis: Long): Unit = {
initialDelayMillis: Long): ScheduledFuture[_] = {

threadPool.scheduleAtFixedRate(
new Runnable {
override def run(): Unit = {
try {
// Check the thread's interrupt status at the beginning of the run method
if (Thread.interrupted()) {
logWarning("HeartBeatUpdater has been interrupted. Terminating.")
return // Exit the run method if the thread is interrupted
}

val getResponse = osClient.getDoc(sessionIndex, sessionId)
if (getResponse.isExists()) {
val source = getResponse.getSourceAsMap
Expand All @@ -871,6 +899,10 @@ object FlintREPL extends Logging with FlintJobExecutor {
}
// do nothing if the session doc does not exist
} catch {
case ie: InterruptedException =>
// Preserve the interrupt status
Thread.currentThread().interrupt()
logError("HeartBeatUpdater task was interrupted", ie)
// maybe due to invalid sequence number or primary term
case e: Exception =>
logWarning(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,66 @@

package org.apache.spark.sql.util

import java.util.concurrent.ScheduledExecutorService
import java.util.concurrent.{ScheduledExecutorService, TimeUnit}

trait ThreadPoolFactory {
import scala.collection.JavaConverters._

import org.apache.spark.internal.Logging

trait ThreadPoolFactory extends Logging {
def newDaemonThreadPoolScheduledExecutor(
threadNamePrefix: String,
numThreads: Int): ScheduledExecutorService

def shutdownThreadPool(executor: ScheduledExecutorService): Unit = {
logInfo(s"terminate executor ${executor}")
executor.shutdown() // Disable new tasks from being submitted

try {
// Wait a while for existing tasks to terminate
if (!executor.awaitTermination(60, TimeUnit.SECONDS)) {
penghuo marked this conversation as resolved.
Show resolved Hide resolved
logWarning("Executor did not terminate in the specified time.")
val tasksNotExecuted = executor.shutdownNow() // Cancel currently executing tasks
penghuo marked this conversation as resolved.
Show resolved Hide resolved
// Log the tasks that were awaiting execution
logInfo(s"The following tasks were cancelled: $tasksNotExecuted")

// Wait a while for tasks to respond to being cancelled
if (!executor.awaitTermination(60, TimeUnit.SECONDS)) {
logError("Thread pool did not terminate after shutdownNow.")
}
}
} catch {
case ie: InterruptedException =>
// (Re-)Cancel if current thread also interrupted
executor.shutdownNow()
// Log the interrupted status
logError("Shutdown interrupted", ie)
// Preserve interrupt status
Thread.currentThread().interrupt()
}
}

/**
* Checks if there are any non-daemon threads other than the "main" thread.
*
* @return
* true if non-daemon threads other than "main" are active, false otherwise.
*/
def hasNonDaemonThreadsOtherThanMain(): Boolean = {
// Log thread information and check for non-daemon threads
Thread.getAllStackTraces.keySet.asScala.exists { t =>
val thread = t.asInstanceOf[Thread]
val isNonDaemon = !thread.isDaemon && thread.getName != "main"

// Log the thread information
logInfo(s"Name: ${thread.getName}; IsDaemon: ${thread.isDaemon}; State: ${thread.getState}")

// Log the stack trace
Option(Thread.getAllStackTraces.get(thread)).foreach(_.foreach(traceElement =>
logInfo(s" at $traceElement")))

// Return true if a non-daemon thread is found
isNonDaemon
}
}
}
Loading