From ab1480db5db911d0a539c2e4a848e3b13237c52b Mon Sep 17 00:00:00 2001
From: Louis Chu
Date: Thu, 13 Jun 2024 21:16:58 -0700
Subject: [PATCH] refactor statement lifecycle

---
 ...ager.scala => QueryExecutionManager.scala} |   7 +-
 .../org/apache/spark/sql/SessionManager.scala |   3 +-
 .../org/apache/spark/sql/FlintREPL.scala      | 171 ++++++++++--------
 .../spark/sql/QueryExecutionManagerImpl.scala |  51 ++++++
 .../apache/spark/sql/SessionManagerImpl.scala |  20 +-
 5 files changed, 156 insertions(+), 96 deletions(-)
 rename flint-commons/src/main/scala/org/apache/spark/sql/{StatementManager.scala => QueryExecutionManager.scala} (55%)
 create mode 100644 spark-sql-application/src/main/scala/org/apache/spark/sql/QueryExecutionManagerImpl.scala

diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/StatementManager.scala b/flint-commons/src/main/scala/org/apache/spark/sql/QueryExecutionManager.scala
similarity index 55%
rename from flint-commons/src/main/scala/org/apache/spark/sql/StatementManager.scala
rename to flint-commons/src/main/scala/org/apache/spark/sql/QueryExecutionManager.scala
index c0a24ab33..68d24cfa9 100644
--- a/flint-commons/src/main/scala/org/apache/spark/sql/StatementManager.scala
+++ b/flint-commons/src/main/scala/org/apache/spark/sql/QueryExecutionManager.scala
@@ -7,9 +7,10 @@ package org.apache.spark.sql
 
 import org.opensearch.flint.data.FlintStatement
 
-trait StatementManager {
-  def prepareCommandLifecycle(): Either[String, Unit]
+trait QueryExecutionManager {
+  def prepareQueryExecution(): Either[String, Unit]
   def initCommandLifecycle(sessionId: String): FlintStatement
   def closeCommandLifecycle(): Unit
-  def updateCommandDetails(commandDetails: FlintStatement): Unit
+  def getNextStatement(statement: FlintStatement): Option[FlintStatement]
+  def updateStatement(statement: FlintStatement): Unit
 }
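
The renamed trait reads as a statement lifecycle: prepare once, open the lifecycle for a session, pull statements until none remain, then close. A minimal sketch of a driver loop over these methods (the loop shape and error handling are illustrative assumptions, not code from this patch):

    import org.opensearch.flint.data.FlintStatement

    // Illustrative driver for any QueryExecutionManager implementation.
    def runSession(manager: QueryExecutionManager, sessionId: String): Unit = {
      manager.prepareQueryExecution() match {
        case Left(error) => throw new IllegalStateException(error)
        case Right(())   => // result index verified, safe to proceed
      }
      try {
        // First statement of the session, then keep pulling until none remain.
        var current: Option[FlintStatement] = Some(manager.initCommandLifecycle(sessionId))
        while (current.isDefined) {
          val statement = current.get
          // ... execute the statement and write its result here ...
          manager.updateStatement(statement) // persist the state transition
          current = manager.getNextStatement(statement)
        }
      } finally {
        manager.closeCommandLifecycle()
      }
    }
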
diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala b/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala
index 345f97619..e35eb843b 100644
--- a/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala
+++ b/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala
@@ -15,11 +15,10 @@ trait SessionManager {
   def updateSessionDetails(
       sessionDetails: InteractiveSession,
       updateMode: SessionUpdateMode): Unit
-  def hasPendingStatement(sessionId: String): Boolean
   def recordHeartbeat(sessionId: String): Unit
 }
 
 object SessionUpdateMode extends Enumeration {
   type SessionUpdateMode = Value
-  val Update, Upsert, UpdateIf = Value
+  val UPDATE, UPSERT, UPDATE_IF = Value
 }
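
The constants move to screaming snake case, matching the match arms in SessionManagerImpl at the end of this patch. The three modes map onto three write behaviors; a sketch of a typical call site (the session value and the state mutation are assumed, following the FlintREPL usage later in this patch):

    // UPDATE    -> plain overwrite of an existing session document
    // UPSERT    -> create the document if absent, otherwise overwrite
    // UPDATE_IF -> conditional write guarded by _seq_no/_primary_term
    def markRunning(sessionManager: SessionManager, session: InteractiveSession): Unit = {
      session.state = SessionStates.RUNNING // mutable field, as used in FlintREPL below
      sessionManager.updateSessionDetails(session, SessionUpdateMode.UPSERT)
    }
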
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
index 9b0b66d28..37d9f8bcf 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
@@ -22,12 +22,13 @@ import org.opensearch.flint.core.logging.CustomLogging
 import org.opensearch.flint.core.metrics.MetricConstants
 import org.opensearch.flint.core.metrics.MetricsUtil.{getTimerContext, incrementCounter, registerGauge, stopTimer}
 import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater}
-import org.opensearch.flint.data.{FlintStatement, InteractiveSession}
+import org.opensearch.flint.data.{FlintStatement, InteractiveSession, SessionStates}
 import org.opensearch.search.sort.SortOrder
 
 import org.apache.spark.SparkConf
 import org.apache.spark.internal.Logging
 import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
+import org.apache.spark.sql.SessionUpdateMode.{SessionUpdateMode, UPDATE_IF}
 import org.apache.spark.sql.flint.config.FlintSparkConf
 import org.apache.spark.util.{ThreadUtils, Utils}
 
@@ -128,7 +129,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
     val queryWaitTimeoutMillis: Long =
       conf.getLong("spark.flint.job.queryWaitTimeoutMillis", DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS)
 
-    val sessionManager = instantiateSessionManager()
+    val sessionManager = instantiateSessionManager(spark)
     val sessionTimerContext = getTimerContext(MetricConstants.REPL_PROCESSING_TIME_METRIC)
 
     /**
@@ -189,7 +190,14 @@ object FlintREPL extends Logging with FlintJobExecutor {
         recordSessionSuccess(sessionTimerContext)
       } catch {
         case e: Exception =>
-          handleSessionError(sessionTimerContext = sessionTimerContext, e = e)
+          handleSessionError(
+            applicationId,
+            jobId,
+            sessionId.get,
+            sessionManager,
+            sessionTimerContext,
+            jobStartTime,
+            e)
       } finally {
         if (threadPool != null) {
           heartBeatFuture.cancel(true) // Pass `true` to interrupt if running
@@ -275,16 +283,16 @@ object FlintREPL extends Logging with FlintJobExecutor {
       case Some(confExcludeJobs) =>
         // example: --conf spark.flint.deployment.excludeJobs=job-1,job-2
-        val excludeJobIds = confExcludeJobs.split(",").toList // Convert Array to Lis
+        val excludedJobIds = confExcludeJobs.split(",").toList // Convert Array to List
 
-        if (excludeJobIds.contains(jobId)) {
+        if (excludedJobIds.contains(jobId)) {
           logInfo(s"current job is excluded, exit the application.")
           return true
         }
 
         val sessionDetails = sessionManager.getSessionDetails(sessionId)
         val existingExcludedJobIds = sessionDetails.get.excludedJobIds
-        if (excludeJobIds.sorted == existingExcludedJobIds.sorted) {
+        if (excludedJobIds.sorted == existingExcludedJobIds.sorted) {
           logInfo("duplicate job running, exit the application.")
           return true
         }
@@ -296,7 +304,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
           sessionId,
           sessionManager,
           jobStartTime,
-          excludeJobIds)
+          excludedJobIds = excludedJobIds)
       }
       false
   }
@@ -355,37 +363,52 @@ object FlintREPL extends Logging with FlintJobExecutor {
     }
   }
 
-  // TODO: Refactor this with getDetails
+  private def refreshSessionState(
+      applicationId: String,
+      jobId: String,
+      sessionId: String,
+      sessionManager: SessionManager,
+      jobStartTime: Long,
+      state: String,
+      error: Option[String] = None,
+      excludedJobIds: Seq[String] = Seq.empty[String]): InteractiveSession = {
+
+    val sessionDetails = sessionManager
+      .getSessionDetails(sessionId)
+      .getOrElse(
+        new InteractiveSession(
+          applicationId,
+          jobId,
+          sessionId,
+          state,
+          currentTimeProvider.currentEpochMillis(),
+          jobStartTime,
+          error = error,
+          excludedJobIds = excludedJobIds))
+    sessionDetails.state = state
+    sessionManager.updateSessionDetails(sessionDetails, updateMode = SessionUpdateMode.UPSERT)
+    sessionDetails
+  }
+
   private def setupFlintJob(
       applicationId: String,
       jobId: String,
       sessionId: String,
       sessionManager: SessionManager,
       jobStartTime: Long,
-      excludeJobIds: Seq[String] = Seq.empty[String]): Unit = {
-    val includeJobId = !excludeJobIds.isEmpty && !excludeJobIds.contains(jobId)
-    val currentTime = currentTimeProvider.currentEpochMillis()
-    val flintJob = new InteractiveSession(
+      excludedJobIds: Seq[String] = Seq.empty[String]): Unit = {
+    refreshSessionState(
       applicationId,
       jobId,
       sessionId,
-      "running",
-      currentTime,
+      sessionManager,
       jobStartTime,
-      excludeJobIds)
-
-    // TODO: serialize need to be refactored to be more flexible
-    val serializedFlintInstance = if (includeJobId) {
-      InteractiveSession.serialize(flintJob, currentTime, true)
-    } else {
-      InteractiveSession.serializeWithoutJobId(flintJob, currentTime)
-    }
-    flintSessionIndexUpdater.upsert(sessionId, serializedFlintInstance)
-    logInfo(s"""Updated job: {"jobid": ${flintJob.jobId}, "sessionId": ${flintJob.sessionId}}""")
+      SessionStates.RUNNING,
+      excludedJobIds = excludedJobIds)
     sessionRunningCount.incrementAndGet()
   }
 
-  def handleSessionError(
+  private def handleSessionError(
       applicationId: String,
       jobId: String,
       sessionId: String,
@@ -395,39 +418,15 @@ object FlintREPL extends Logging with FlintJobExecutor {
       e: Exception): Unit = {
     val error = s"Session error: ${e.getMessage}"
     CustomLogging.logError(error, e)
-
-    val sessionDetails = sessionManager
-      .getSessionDetails(sessionId)
-      .getOrElse(createFailedFlintInstance(applicationId, jobId, sessionId, jobStartTime, error))
-
-    updateFlintInstance(sessionDetails, flintSessionIndexUpdater, sessionId)
-    if (sessionDetails.isFail) {
-      recordSessionFailed(sessionTimerContext)
-    }
-  }
-
-  private def createFailedFlintInstance(
-      applicationId: String,
-      jobId: String,
-      sessionId: String,
-      jobStartTime: Long,
-      errorMessage: String): InteractiveSession = new InteractiveSession(
-    applicationId,
-    jobId,
-    sessionId,
-    "fail",
-    currentTimeProvider.currentEpochMillis(),
-    jobStartTime,
-    error = Some(errorMessage))
-
-  private def updateFlintInstance(
-      flintInstance: InteractiveSession,
-      flintSessionIndexUpdater: OpenSearchUpdater,
-      sessionId: String): Unit = {
-    val currentTime = currentTimeProvider.currentEpochMillis()
-    flintSessionIndexUpdater.upsert(
+    refreshSessionState(
+      applicationId,
+      jobId,
       sessionId,
-      InteractiveSession.serializeWithoutJobId(flintInstance, currentTime))
+      sessionManager,
+      jobStartTime,
+      SessionStates.FAIL,
+      Some(e.getMessage))
+    recordSessionFailed(sessionTimerContext)
   }
 
   /**
@@ -501,8 +500,6 @@ object FlintREPL extends Logging with FlintJobExecutor {
       if (!canPickNextStatementResult) {
         earlyExitFlag = true
         canProceed = false
-      } else if (!flintReader.hasNext) {
-        canProceed = false
       } else {
         val statementTimerContext = getTimerContext(
           MetricConstants.STATEMENT_PROCESSING_TIME_METRIC)
@@ -827,29 +824,27 @@ object FlintREPL extends Logging with FlintJobExecutor {
       extends SparkListener
       with Logging {
 
-    // TODO: Refactor update
-
     override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = {
       logInfo("Shutting down REPL")
       logInfo("earlyExitFlag: " + earlyExitFlag)
-      sessionManager.getSessionDetails(sessionId).foreach { sessionDetails =>
-        // It's essential to check the earlyExitFlag before marking the session state as 'dead'. When this flag is true,
-        // it indicates that the control plane has already initiated a new session to handle remaining requests for the
-        // current session. In our SQL setup, marking a session as 'dead' automatically triggers the creation of a new
-        // session. However, the newly created session (initiated by the control plane) will enter a spin-wait state,
-        // where it inefficiently waits for certain conditions to be met, leading to unproductive resource consumption
-        // and eventual timeout. To avoid this issue and prevent the creation of redundant sessions by SQL, we ensure
-        // the session state is not set to 'dead' when earlyExitFlag is true, thereby preventing unnecessary duplicate
-        // processing.
-        if (!earlyExitFlag && !sessionDetails.isDead && !sessionDetails.isFail) {
-          updateFlintInstanceBeforeShutdown(
-            source,
-            getResponse,
-            flintSessionIndexUpdater,
-            sessionId,
-            sessionTimerContext)
+      try {
+        sessionManager.getSessionDetails(sessionId).foreach { sessionDetails =>
+          // It's essential to check the earlyExitFlag before marking the session state as 'dead'. When this flag is true,
+          // it indicates that the control plane has already initiated a new session to handle remaining requests for the
+          // current session. In our SQL setup, marking a session as 'dead' automatically triggers the creation of a new
+          // session. However, the newly created session (initiated by the control plane) will enter a spin-wait state,
+          // where it inefficiently waits for certain conditions to be met, leading to unproductive resource consumption
+          // and eventual timeout. To avoid this issue and prevent the creation of redundant sessions by SQL, we ensure
+          // the session state is not set to 'dead' when earlyExitFlag is true, thereby preventing unnecessary duplicate
+          // processing.
+          if (!earlyExitFlag && !sessionDetails.isDead && !sessionDetails.isFail) {
+            sessionDetails.state = SessionStates.DEAD
+            sessionManager.updateSessionDetails(sessionDetails, UPDATE_IF)
+          }
         }
+      } catch {
+        case e: Exception => logError(s"Failed to update session state for $sessionId", e)
       }
     }
   }
@@ -1008,7 +1003,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
     result.getOrElse(throw new RuntimeException("Failed after retries"))
   }
 
-  private def instantiateSessionManager(): SessionManager = {
+  private def instantiateSessionManager(spark: SparkSession): SessionManager = {
     val options = FlintSparkConf().flintOptions()
     val className = options.getCustomSessionManager()
 
@@ -1027,6 +1022,26 @@ object FlintREPL extends Logging with FlintJobExecutor {
     }
   }
 
+  private def instantiateQueryExecutionManager(
+      context: Map[String, Any]): QueryExecutionManager = {
+    val options = FlintSparkConf().flintOptions()
+    val className = options.getCustomStatementManager
+
+    if (className.isEmpty) {
+      new QueryExecutionManagerImpl(context)
+    } else {
+      try {
+        val providerClass = Utils.classForName(className)
+        val ctor = providerClass.getDeclaredConstructor()
+        ctor.setAccessible(true)
+        ctor.newInstance().asInstanceOf[QueryExecutionManager]
+      } catch {
+        case e: Exception =>
+          throw new RuntimeException(s"Failed to instantiate provider: $className", e)
+      }
+    }
+  }
+
   private def recordSessionSuccess(sessionTimerContext: Timer.Context): Unit = {
     logInfo("Session Success")
     stopTimer(sessionTimerContext)
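
Note that instantiateQueryExecutionManager resolves a custom provider reflectively through a zero-argument constructor, so any replacement implementation must supply one. A hypothetical provider (class name and stub bodies invented for illustration) would look like:

    // Must expose a no-arg constructor so that
    // Utils.classForName(name).getDeclaredConstructor().newInstance() can build it.
    class MyQueryExecutionManager extends QueryExecutionManager {
      override def prepareQueryExecution(): Either[String, Unit] = Right(())
      override def initCommandLifecycle(sessionId: String): FlintStatement = ???
      override def closeCommandLifecycle(): Unit = ()
      override def getNextStatement(statement: FlintStatement): Option[FlintStatement] = None
      override def updateStatement(statement: FlintStatement): Unit = ()
    }

Its fully qualified class name is supplied through whatever Flint option getCustomStatementManager reads; when that option is empty, the default QueryExecutionManagerImpl below is used.
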
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryExecutionManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryExecutionManagerImpl.scala
new file mode 100644
index 000000000..807f11783
--- /dev/null
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryExecutionManagerImpl.scala
@@ -0,0 +1,51 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.apache.spark.sql
+
+import org.opensearch.flint.data.FlintStatement
+
+import org.apache.spark.internal.Logging
+
+class QueryExecutionManagerImpl(context: Map[String, Any])
+    extends QueryExecutionManager
+    with FlintJobExecutor
+    with Logging {
+  val osClient = context("osClient").asInstanceOf[OSClient]
+  val resultIndex = context("resultIndex").asInstanceOf[String]
+
+  override def prepareQueryExecution(): Either[String, Unit] = {
+    try {
+      val existingSchema = osClient.getIndexMetadata(resultIndex)
+      if (!isSuperset(existingSchema, resultIndexMapping)) {
+        Left(s"The mapping of $resultIndex is incorrect.")
+      } else {
+        Right(())
+      }
+    } catch {
+      case e: IllegalStateException
+          if e.getCause != null &&
+            e.getCause.getMessage.contains("index_not_found_exception") =>
+        createResultIndex(osClient, resultIndex, resultIndexMapping)
+      case e: InterruptedException =>
+        val error = s"Interrupted by the main thread: ${e.getMessage}"
+        Thread.currentThread().interrupt() // Preserve the interrupt status
+        logError(error, e)
+        Left(error)
+      case e: Exception =>
+        val error = s"Failed to verify existing mapping: ${e.getMessage}"
+        logError(error, e)
+        Left(error)
+    }
+  }
+
+  override def initCommandLifecycle(sessionId: String): FlintStatement = ???
+
+  override def closeCommandLifecycle(): Unit = ???
+
+  override def getNextStatement(statement: FlintStatement): Option[FlintStatement] = ???
+
+  override def updateStatement(statement: FlintStatement): Unit = ???
+}
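
QueryExecutionManagerImpl pulls its collaborators out of an untyped context map, and both entries are read eagerly in the class body, so a missing or mistyped key fails at construction time with a runtime cast error. A sketch of the expected wiring (the "resultIndex" key mirrors the field it populates, corrected above from the original patch's duplicated "osClient" key; osClient and resultIndex are assumed to be in scope):

    val context: Map[String, Any] = Map(
      "osClient" -> osClient,       // OSClient used to inspect/create the result index
      "resultIndex" -> resultIndex) // name of the OpenSearch result index
    val queryExecutionManager = new QueryExecutionManagerImpl(context)

A small typed holder would move these failures to compile time; the Map[String, Any] form keeps the signature open for the reflective provider path.
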
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala
index 29f70ddbf..7fdb6623d 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala
@@ -72,10 +72,6 @@ class SessionManagerImpl(flintOptions: FlintOptions)
           Map("lastUpdateTime" -> currentTimeProvider.currentEpochMillis(), "state" -> "running")))
   }
 
-  override def hasPendingStatement(sessionId: String): Boolean = {
-    flintReader.hasNext
-  }
-
   private def createQueryReader(sessionId: String, sessionIndex: String, dataSource: String) = {
     // all state in index are in lower case
     // we only search for statement submitted in the last hour in case of unexpected bugs causing infinite loop in the
@@ -121,11 +117,11 @@ class SessionManagerImpl(flintOptions: FlintOptions)
       sessionDetails: InteractiveSession,
       sessionUpdateMode: SessionUpdateMode): Unit = {
     sessionUpdateMode match {
-      case SessionUpdateMode.Update =>
+      case SessionUpdateMode.UPDATE =>
         flintSessionIndexUpdater.update(
           sessionDetails.sessionId,
           InteractiveSession.serialize(sessionDetails, currentTimeProvider.currentEpochMillis()))
-      case SessionUpdateMode.Upsert =>
+      case SessionUpdateMode.UPSERT =>
        val includeJobId =
          !sessionDetails.excludedJobIds.isEmpty && !sessionDetails.excludedJobIds.contains(
            sessionDetails.jobId)
@@ -140,15 +136,13 @@ class SessionManagerImpl(flintOptions: FlintOptions)
             currentTimeProvider.currentEpochMillis())
         }
         flintSessionIndexUpdater.upsert(sessionDetails.sessionId, serializedSession)
-      case SessionUpdateMode.UpdateIf =>
-        val executionContext = sessionDetails.executionContext.getOrElse(
-          throw new IllegalArgumentException("Missing executionContext for conditional update"))
-        val seqNo = executionContext
-          .get("_seq_no")
+      case SessionUpdateMode.UPDATE_IF =>
+        val seqNo = sessionDetails
+          .getContextValue("_seq_no")
           .getOrElse(throw new IllegalArgumentException("Missing _seq_no for conditional update"))
           .asInstanceOf[Long]
-        val primaryTerm = executionContext
-          .get("_primary_term")
+        val primaryTerm = sessionDetails
+          .getContextValue("_primary_term")
           .getOrElse(
             throw new IllegalArgumentException("Missing _primary_term for conditional update"))
           .asInstanceOf[Long]
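
The UPDATE_IF branch is optimistic concurrency control: the _seq_no/_primary_term pair captured when the session document was read must travel back with the conditional write, which succeeds only if no other writer has touched the document in between. A sketch of the intended round trip (it assumes getSessionDetails stores both values in the session's context; this patch only shows them being read via getContextValue):

    // Read: fetch the session; the implementation is expected to stash
    // _seq_no and _primary_term in the session's context at this point.
    val session = sessionManager
      .getSessionDetails(sessionId)
      .getOrElse(throw new IllegalStateException(s"No session found for $sessionId"))

    // Modify locally, then write back conditionally: the update is rejected
    // if another writer bumped _seq_no/_primary_term since the read above.
    session.state = SessionStates.DEAD
    sessionManager.updateSessionDetails(session, SessionUpdateMode.UPDATE_IF)
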