Skip to content

Commit

Permalink
refactor statement lifecycle
Browse files Browse the repository at this point in the history
  • Loading branch information
noCharger committed Jun 14, 2024
1 parent dab2343 commit a17dce3
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@ trait SessionManager {
def updateSessionDetails(
sessionDetails: InteractiveSession,
updateMode: SessionUpdateMode): Unit
def hasPendingStatement(sessionId: String): Boolean
def recordHeartbeat(sessionId: String): Unit
}

object SessionUpdateMode extends Enumeration {
type SessionUpdateMode = Value
val Update, Upsert, UpdateIf = Value
val UPDATE, UPSERT, UPDATE_IF = Value
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ trait StatementManager {
def prepareCommandLifecycle(): Either[String, Unit]
def initCommandLifecycle(sessionId: String): FlintStatement
def closeCommandLifecycle(): Unit
def updateCommandDetails(commandDetails: FlintStatement): Unit
def getNextStatement(statement: FlintStatement): Option[FlintStatement]
def updateStatement(statement: FlintStatement): Unit
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@ import org.opensearch.flint.core.logging.CustomLogging
import org.opensearch.flint.core.metrics.MetricConstants
import org.opensearch.flint.core.metrics.MetricsUtil.{getTimerContext, incrementCounter, registerGauge, stopTimer}
import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater}
import org.opensearch.flint.data.{FlintStatement, InteractiveSession}
import org.opensearch.flint.data.{FlintStatement, InteractiveSession, SessionStates}
import org.opensearch.search.sort.SortOrder

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.SessionUpdateMode.SessionUpdateMode
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.util.{ThreadUtils, Utils}

Expand Down Expand Up @@ -355,37 +356,52 @@ object FlintREPL extends Logging with FlintJobExecutor {
}
}

/**
 * Loads the session stored under `sessionId`, sets its state to `state`, upserts it through
 * `sessionManager` and returns the instance that was written.
 *
 * When `sessionManager.getSessionDetails` yields no session, a new [[InteractiveSession]] is
 * created from the supplied identifiers, `currentTimeProvider.currentEpochMillis()` and
 * `jobStartTime`.
 *
 * Note: `error` and `excludedJobIds` are only used when a new session is created here; a
 * session returned by `getSessionDetails` keeps the values it was loaded with and only has
 * its `state` overwritten.
 *
 * @param applicationId application id used when a new session has to be created
 * @param jobId job id used when a new session has to be created
 * @param sessionId id of the session to look up and upsert
 * @param sessionManager manager used to read and to persist the session
 * @param jobStartTime job start time used when a new session has to be created
 * @param state state assigned to the session before it is persisted
 * @param error optional error passed to a newly created session
 * @param excludedJobIds excluded job ids passed to a newly created session
 * @return the session instance that was handed to `updateSessionDetails`
 */
// TODO: Refactor this with getDetails
private def refreshSessionState(
    applicationId: String,
    jobId: String,
    sessionId: String,
    sessionManager: SessionManager,
    jobStartTime: Long,
    state: String,
    error: Option[String] = None,
    excludedJobIds: Seq[String] = Seq.empty[String]): InteractiveSession = {

  // Kept as a def so the fallback (and its clock read) is only evaluated when no session exists.
  def freshSession: InteractiveSession =
    new InteractiveSession(
      applicationId,
      jobId,
      sessionId,
      state,
      currentTimeProvider.currentEpochMillis(),
      jobStartTime,
      error = error,
      excludedJobIds = excludedJobIds)

  val session = sessionManager.getSessionDetails(sessionId).getOrElse(freshSession)

  // An existing session may carry a different state; always overwrite it before persisting.
  session.state = state
  sessionManager.updateSessionDetails(session, updateMode = SessionUpdateMode.UPSERT)
  session
}

private def setupFlintJob(
applicationId: String,
jobId: String,
sessionId: String,
sessionManager: SessionManager,
jobStartTime: Long,
excludeJobIds: Seq[String] = Seq.empty[String]): Unit = {
val includeJobId = !excludeJobIds.isEmpty && !excludeJobIds.contains(jobId)
val currentTime = currentTimeProvider.currentEpochMillis()
val flintJob = new InteractiveSession(
refreshSessionState(
applicationId,
jobId,
sessionId,
"running",
currentTime,
sessionManager,
jobStartTime,
excludeJobIds)

// TODO: serialize need to be refactored to be more flexible
val serializedFlintInstance = if (includeJobId) {
InteractiveSession.serialize(flintJob, currentTime, true)
} else {
InteractiveSession.serializeWithoutJobId(flintJob, currentTime)
}
flintSessionIndexUpdater.upsert(sessionId, serializedFlintInstance)
logInfo(s"""Updated job: {"jobid": ${flintJob.jobId}, "sessionId": ${flintJob.sessionId}}""")
SessionStates.RUNNING,
excludeJobIds = excludeJobIds)
sessionRunningCount.incrementAndGet()
}

def handleSessionError(
private def handleSessionError(
applicationId: String,
jobId: String,
sessionId: String,
Expand All @@ -395,39 +411,15 @@ object FlintREPL extends Logging with FlintJobExecutor {
e: Exception): Unit = {
val error = s"Session error: ${e.getMessage}"
CustomLogging.logError(error, e)

val sessionDetails = sessionManager
.getSessionDetails(sessionId)
.getOrElse(createFailedFlintInstance(applicationId, jobId, sessionId, jobStartTime, error))

updateFlintInstance(sessionDetails, flintSessionIndexUpdater, sessionId)
if (sessionDetails.isFail) {
recordSessionFailed(sessionTimerContext)
}
}

private def createFailedFlintInstance(
applicationId: String,
jobId: String,
sessionId: String,
jobStartTime: Long,
errorMessage: String): InteractiveSession = new InteractiveSession(
applicationId,
jobId,
sessionId,
"fail",
currentTimeProvider.currentEpochMillis(),
jobStartTime,
error = Some(errorMessage))

private def updateFlintInstance(
flintInstance: InteractiveSession,
flintSessionIndexUpdater: OpenSearchUpdater,
sessionId: String): Unit = {
val currentTime = currentTimeProvider.currentEpochMillis()
flintSessionIndexUpdater.upsert(
refreshSessionState(
applicationId,
jobId,
sessionId,
InteractiveSession.serializeWithoutJobId(flintInstance, currentTime))
sessionManager,
jobStartTime,
SessionStates.FAIL,
Some(e.getMessage))
recordSessionFailed(sessionTimerContext)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,6 @@ class SessionManagerImpl(flintOptions: FlintOptions)
Map("lastUpdateTime" -> currentTimeProvider.currentEpochMillis(), "state" -> "running")))
}

override def hasPendingStatement(sessionId: String): Boolean = {
flintReader.hasNext
}

private def createQueryReader(sessionId: String, sessionIndex: String, dataSource: String) = {
// all state in index are in lower case
// we only search for statement submitted in the last hour in case of unexpected bugs causing infinite loop in the
Expand Down Expand Up @@ -121,11 +117,11 @@ class SessionManagerImpl(flintOptions: FlintOptions)
sessionDetails: InteractiveSession,
sessionUpdateMode: SessionUpdateMode): Unit = {
sessionUpdateMode match {
case SessionUpdateMode.Update =>
case SessionUpdateMode.UPDATE =>
flintSessionIndexUpdater.update(
sessionDetails.sessionId,
InteractiveSession.serialize(sessionDetails, currentTimeProvider.currentEpochMillis()))
case SessionUpdateMode.Upsert =>
case SessionUpdateMode.UPSERT =>
val includeJobId =
!sessionDetails.excludedJobIds.isEmpty && !sessionDetails.excludedJobIds.contains(
sessionDetails.jobId)
Expand All @@ -140,7 +136,7 @@ class SessionManagerImpl(flintOptions: FlintOptions)
currentTimeProvider.currentEpochMillis())
}
flintSessionIndexUpdater.upsert(sessionDetails.sessionId, serializedSession)
case SessionUpdateMode.UpdateIf =>
case SessionUpdateMode.UPDATE_IF =>
val executionContext = sessionDetails.executionContext.getOrElse(
throw new IllegalArgumentException("Missing executionContext for conditional update"))
val seqNo = executionContext
Expand Down

0 comments on commit a17dce3

Please sign in to comment.