Commit: refactor statement lifecycle

noCharger committed Jun 14, 2024
1 parent dab2343 commit ab1480d

Showing 5 changed files with 156 additions and 96 deletions.
@@ -7,9 +7,10 @@ package org.apache.spark.sql

 import org.opensearch.flint.data.FlintStatement

-trait StatementManager {
-  def prepareCommandLifecycle(): Either[String, Unit]
+trait QueryExecutionManager {
+  def prepareQueryExecution(): Either[String, Unit]
   def initCommandLifecycle(sessionId: String): FlintStatement
   def closeCommandLifecycle(): Unit
-  def updateCommandDetails(commandDetails: FlintStatement): Unit
+  def getNextStatement(statement: FlintStatement): Option[FlintStatement]
+  def updateStatement(statement: FlintStatement): Unit
 }
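
The renamed trait keeps the prepare/iterate/update contract but drops the command-lifecycle vocabulary. For orientation only, here is a minimal sketch of a custom implementation; the class name and no-op bodies are hypothetical, and a no-arg constructor is assumed because the reflective factory later in this diff instantiates providers through one.

package org.apache.spark.sql

import org.opensearch.flint.data.FlintStatement

// Hypothetical no-op provider; only the trait signatures come from this commit.
class NoopQueryExecutionManager extends QueryExecutionManager {
  // Nothing to verify before execution, so always succeed.
  override def prepareQueryExecution(): Either[String, Unit] = Right(())
  override def initCommandLifecycle(sessionId: String): FlintStatement =
    throw new UnsupportedOperationException("no statement source configured")
  override def closeCommandLifecycle(): Unit = ()
  // No statement queue behind this manager, so there is never a next statement.
  override def getNextStatement(statement: FlintStatement): Option[FlintStatement] = None
  override def updateStatement(statement: FlintStatement): Unit = ()
}
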
@@ -15,11 +15,10 @@ trait SessionManager {
   def updateSessionDetails(
       sessionDetails: InteractiveSession,
       updateMode: SessionUpdateMode): Unit
-  def hasPendingStatement(sessionId: String): Boolean
   def recordHeartbeat(sessionId: String): Unit
 }

 object SessionUpdateMode extends Enumeration {
   type SessionUpdateMode = Value
-  val Update, Upsert, UpdateIf = Value
+  val UPDATE, UPSERT, UPDATE_IF = Value
 }
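
The enumeration values move to SCREAMING_SNAKE_CASE, matching how call sites in this diff reference them (UPSERT in refreshSessionState, UPDATE_IF in the shutdown listener). A hypothetical helper illustrating the distinction between the two modes:

package org.apache.spark.sql

import org.opensearch.flint.data.{InteractiveSession, SessionStates}

import org.apache.spark.sql.SessionUpdateMode._

// Illustrative only; the helper object is not part of this commit.
object SessionUpdateExample {
  // UPSERT: create the session document if absent, overwrite it otherwise.
  def markRunning(manager: SessionManager, session: InteractiveSession): Unit = {
    session.state = SessionStates.RUNNING
    manager.updateSessionDetails(session, UPSERT)
  }

  // UPDATE_IF: apply the write only when a concurrency precondition holds,
  // e.g. the stored document has not changed since it was read.
  def markDead(manager: SessionManager, session: InteractiveSession): Unit = {
    session.state = SessionStates.DEAD
    manager.updateSessionDetails(session, UPDATE_IF)
  }
}
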
@@ -22,12 +22,13 @@ import org.opensearch.flint.core.logging.CustomLogging
 import org.opensearch.flint.core.metrics.MetricConstants
 import org.opensearch.flint.core.metrics.MetricsUtil.{getTimerContext, incrementCounter, registerGauge, stopTimer}
 import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater}
-import org.opensearch.flint.data.{FlintStatement, InteractiveSession}
+import org.opensearch.flint.data.{FlintStatement, InteractiveSession, SessionStates}
 import org.opensearch.search.sort.SortOrder

 import org.apache.spark.SparkConf
 import org.apache.spark.internal.Logging
 import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
+import org.apache.spark.sql.SessionUpdateMode.{SessionUpdateMode, UPDATE_IF}
 import org.apache.spark.sql.flint.config.FlintSparkConf
 import org.apache.spark.util.{ThreadUtils, Utils}

@@ -128,7 +129,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
   val queryWaitTimeoutMillis: Long =
     conf.getLong("spark.flint.job.queryWaitTimeoutMillis", DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS)

-  val sessionManager = instantiateSessionManager()
+  val sessionManager = instantiateSessionManager(spark)
   val sessionTimerContext = getTimerContext(MetricConstants.REPL_PROCESSING_TIME_METRIC)

   /**
@@ -189,7 +190,14 @@ object FlintREPL extends Logging with FlintJobExecutor {
         recordSessionSuccess(sessionTimerContext)
       } catch {
         case e: Exception =>
-          handleSessionError(sessionTimerContext = sessionTimerContext, e = e)
+          handleSessionError(
+            applicationId,
+            jobId,
+            sessionId.get,
+            sessionManager,
+            sessionTimerContext,
+            jobStartTime,
+            e)
       } finally {
         if (threadPool != null) {
           heartBeatFuture.cancel(true) // Pass `true` to interrupt if running
@@ -275,16 +283,16 @@ object FlintREPL extends Logging with FlintJobExecutor {

       case Some(confExcludeJobs) =>
         // example: --conf spark.flint.deployment.excludeJobs=job-1,job-2
-        val excludeJobIds = confExcludeJobs.split(",").toList // Convert Array to List
+        val excludedJobIds = confExcludeJobs.split(",").toList // Convert Array to List

-        if (excludeJobIds.contains(jobId)) {
+        if (excludedJobIds.contains(jobId)) {
           logInfo(s"current job is excluded, exit the application.")
           return true
         }

         val sessionDetails = sessionManager.getSessionDetails(sessionId)
         val existingExcludedJobIds = sessionDetails.get.excludedJobIds
-        if (excludeJobIds.sorted == existingExcludedJobIds.sorted) {
+        if (excludedJobIds.sorted == existingExcludedJobIds.sorted) {
           logInfo("duplicate job running, exit the application.")
           return true
         }
@@ -296,7 +304,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
           sessionId,
           sessionManager,
           jobStartTime,
-          excludeJobIds)
+          excludedJobIds = excludedJobIds)
       }
       false
     }
@@ -355,37 +363,52 @@ object FlintREPL extends Logging with FlintJobExecutor {
     }
   }

+  // TODO: Refactor this with getDetails
+  private def refreshSessionState(
+      applicationId: String,
+      jobId: String,
+      sessionId: String,
+      sessionManager: SessionManager,
+      jobStartTime: Long,
+      state: String,
+      error: Option[String] = None,
+      excludedJobIds: Seq[String] = Seq.empty[String]): InteractiveSession = {
+
+    val sessionDetails = sessionManager
+      .getSessionDetails(sessionId)
+      .getOrElse(
+        new InteractiveSession(
+          applicationId,
+          jobId,
+          sessionId,
+          state,
+          currentTimeProvider.currentEpochMillis(),
+          jobStartTime,
+          error = error,
+          excludedJobIds = excludedJobIds))
+    sessionDetails.state = state
+    sessionManager.updateSessionDetails(sessionDetails, updateMode = SessionUpdateMode.UPSERT)
+    sessionDetails
+  }
+
   private def setupFlintJob(
       applicationId: String,
       jobId: String,
       sessionId: String,
       sessionManager: SessionManager,
       jobStartTime: Long,
-      excludeJobIds: Seq[String] = Seq.empty[String]): Unit = {
-    val includeJobId = !excludeJobIds.isEmpty && !excludeJobIds.contains(jobId)
-    val currentTime = currentTimeProvider.currentEpochMillis()
-    val flintJob = new InteractiveSession(
-      applicationId,
-      jobId,
-      sessionId,
-      "running",
-      currentTime,
-      jobStartTime,
-      excludeJobIds)
-
-    // TODO: serialize need to be refactored to be more flexible
-    val serializedFlintInstance = if (includeJobId) {
-      InteractiveSession.serialize(flintJob, currentTime, true)
-    } else {
-      InteractiveSession.serializeWithoutJobId(flintJob, currentTime)
-    }
-    flintSessionIndexUpdater.upsert(sessionId, serializedFlintInstance)
-    logInfo(s"""Updated job: {"jobid": ${flintJob.jobId}, "sessionId": ${flintJob.sessionId}}""")
+      excludedJobIds: Seq[String] = Seq.empty[String]): Unit = {
+    refreshSessionState(
+      applicationId,
+      jobId,
+      sessionId,
+      sessionManager,
+      jobStartTime,
+      SessionStates.RUNNING,
+      excludedJobIds = excludedJobIds)
     sessionRunningCount.incrementAndGet()
   }

-  def handleSessionError(
+  private def handleSessionError(
+      applicationId: String,
+      jobId: String,
+      sessionId: String,
@@ -395,39 +418,15 @@ object FlintREPL extends Logging with FlintJobExecutor {
       e: Exception): Unit = {
     val error = s"Session error: ${e.getMessage}"
     CustomLogging.logError(error, e)
-
-    val sessionDetails = sessionManager
-      .getSessionDetails(sessionId)
-      .getOrElse(createFailedFlintInstance(applicationId, jobId, sessionId, jobStartTime, error))
-
-    updateFlintInstance(sessionDetails, flintSessionIndexUpdater, sessionId)
-    if (sessionDetails.isFail) {
-      recordSessionFailed(sessionTimerContext)
-    }
-  }
-
-  private def createFailedFlintInstance(
-      applicationId: String,
-      jobId: String,
-      sessionId: String,
-      jobStartTime: Long,
-      errorMessage: String): InteractiveSession = new InteractiveSession(
-    applicationId,
-    jobId,
-    sessionId,
-    "fail",
-    currentTimeProvider.currentEpochMillis(),
-    jobStartTime,
-    error = Some(errorMessage))
-
-  private def updateFlintInstance(
-      flintInstance: InteractiveSession,
-      flintSessionIndexUpdater: OpenSearchUpdater,
-      sessionId: String): Unit = {
-    val currentTime = currentTimeProvider.currentEpochMillis()
-    flintSessionIndexUpdater.upsert(
-      sessionId,
-      InteractiveSession.serializeWithoutJobId(flintInstance, currentTime))
+    refreshSessionState(
+      applicationId,
+      jobId,
+      sessionId,
+      sessionManager,
+      jobStartTime,
+      SessionStates.FAIL,
+      Some(e.getMessage))
+    recordSessionFailed(sessionTimerContext)
   }

   /**
@@ -501,8 +500,6 @@ object FlintREPL extends Logging with FlintJobExecutor {
         if (!canPickNextStatementResult) {
           earlyExitFlag = true
           canProceed = false
-        } else if (!flintReader.hasNext) {
-          canProceed = false
         } else {
           val statementTimerContext = getTimerContext(
             MetricConstants.STATEMENT_PROCESSING_TIME_METRIC)
@@ -827,29 +824,27 @@
     extends SparkListener
     with Logging {

-    // TODO: Refactor update
-
     override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = {
       logInfo("Shutting down REPL")
       logInfo("earlyExitFlag: " + earlyExitFlag)

-      sessionManager.getSessionDetails(sessionId).foreach { sessionDetails =>
-        // It's essential to check the earlyExitFlag before marking the session state as 'dead'. When this flag is true,
-        // it indicates that the control plane has already initiated a new session to handle remaining requests for the
-        // current session. In our SQL setup, marking a session as 'dead' automatically triggers the creation of a new
-        // session. However, the newly created session (initiated by the control plane) will enter a spin-wait state,
-        // where it inefficiently waits for certain conditions to be met, leading to unproductive resource consumption
-        // and eventual timeout. To avoid this issue and prevent the creation of redundant sessions by SQL, we ensure
-        // the session state is not set to 'dead' when earlyExitFlag is true, thereby preventing unnecessary duplicate
-        // processing.
-        if (!earlyExitFlag && !sessionDetails.isDead && !sessionDetails.isFail) {
-          updateFlintInstanceBeforeShutdown(
-            source,
-            getResponse,
-            flintSessionIndexUpdater,
-            sessionId,
-            sessionTimerContext)
-        }
-      }
+      try {
+        sessionManager.getSessionDetails(sessionId).foreach { sessionDetails =>
+          // It's essential to check the earlyExitFlag before marking the session state as 'dead'. When this flag is true,
+          // it indicates that the control plane has already initiated a new session to handle remaining requests for the
+          // current session. In our SQL setup, marking a session as 'dead' automatically triggers the creation of a new
+          // session. However, the newly created session (initiated by the control plane) will enter a spin-wait state,
+          // where it inefficiently waits for certain conditions to be met, leading to unproductive resource consumption
+          // and eventual timeout. To avoid this issue and prevent the creation of redundant sessions by SQL, we ensure
+          // the session state is not set to 'dead' when earlyExitFlag is true, thereby preventing unnecessary duplicate
+          // processing.
+          if (!earlyExitFlag && !sessionDetails.isDead && !sessionDetails.isFail) {
+            sessionDetails.state = SessionStates.DEAD
+            sessionManager.updateSessionDetails(sessionDetails, UPDATE_IF)
+          }
+        }
+      } catch {
+        case e: Exception => logError(s"Failed to update session state for $sessionId", e)
+      }
     }
   }
@@ -1008,7 +1003,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
     result.getOrElse(throw new RuntimeException("Failed after retries"))
   }

-  private def instantiateSessionManager(): SessionManager = {
+  private def instantiateSessionManager(spark: SparkSession): SessionManager = {
     val options = FlintSparkConf().flintOptions()
     val className = options.getCustomSessionManager()
@@ -1027,6 +1022,26 @@
     }
   }

+  private def instantiateQueryExecutionManager(
+      context: Map[String, Any]): QueryExecutionManager = {
+    val options = FlintSparkConf().flintOptions()
+    val className = options.getCustomStatementManager
+
+    if (className.isEmpty) {
+      new QueryExecutionManagerImpl(context)
+    } else {
+      try {
+        val providerClass = Utils.classForName(className)
+        val ctor = providerClass.getDeclaredConstructor()
+        ctor.setAccessible(true)
+        ctor.newInstance().asInstanceOf[QueryExecutionManager]
+      } catch {
+        case e: Exception =>
+          throw new RuntimeException(s"Failed to instantiate provider: $className", e)
+      }
+    }
+  }
+
   private def recordSessionSuccess(sessionTimerContext: Timer.Context): Unit = {
     logInfo("Session Success")
     stopTimer(sessionTimerContext)
@@ -0,0 +1,51 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.apache.spark.sql
+
+import org.opensearch.flint.data.FlintStatement
+
+import org.apache.spark.internal.Logging
+
+class QueryExecutionManagerImpl(context: Map[String, Any])
+    extends QueryExecutionManager
+    with FlintJobExecutor
+    with Logging {
+  val osClient = context("osClient").asInstanceOf[OSClient]
+  val resultIndex = context("resultIndex").asInstanceOf[String]
+
+  override def prepareQueryExecution(): Either[String, Unit] = {
+    try {
+      val existingSchema = osClient.getIndexMetadata(resultIndex)
+      if (!isSuperset(existingSchema, resultIndexMapping)) {
+        Left(s"The mapping of $resultIndex is incorrect.")
+      } else {
+        Right(())
+      }
+    } catch {
+      case e: IllegalStateException
+          if e.getCause != null &&
+            e.getCause.getMessage.contains("index_not_found_exception") =>
+        createResultIndex(osClient, resultIndex, resultIndexMapping)
+      case e: InterruptedException =>
+        val error = s"Interrupted by the main thread: ${e.getMessage}"
+        Thread.currentThread().interrupt() // Preserve the interrupt status
+        logError(error, e)
+        Left(error)
+      case e: Exception =>
+        val error = s"Failed to verify existing mapping: ${e.getMessage}"
+        logError(error, e)
+        Left(error)
+    }
+  }
+
+  override def initCommandLifecycle(sessionId: String): FlintStatement = ???
+
+  override def closeCommandLifecycle(): Unit = ???
+
+  override def getNextStatement(statement: FlintStatement): Option[FlintStatement] = ???
+
+  override def updateStatement(statement: FlintStatement): Unit = ???
+}
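
For orientation, a sketch of how the implementation above might be wired up; the context keys mirror the fields the class reads, while the OSClient construction and the index name are assumptions rather than anything this commit shows:

package org.apache.spark.sql

import org.apache.spark.sql.flint.config.FlintSparkConf

// Hypothetical wiring for QueryExecutionManagerImpl; not part of this commit.
object QueryExecutionManagerWiring {
  def build(): QueryExecutionManager = {
    val options = FlintSparkConf().flintOptions()
    val manager = new QueryExecutionManagerImpl(
      Map(
        "osClient" -> new OSClient(options), // read back as OSClient above
        "resultIndex" -> "query_execution_result")) // assumed index name

    manager.prepareQueryExecution() match {
      case Right(_) => manager // result index exists with a compatible mapping
      case Left(error) => throw new IllegalStateException(error)
    }
  }
}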