From 8e02db846bdea2a1e007f2c92813d6e382f76b24 Mon Sep 17 00:00:00 2001 From: Louis Chu Date: Thu, 13 Jun 2024 21:16:58 -0700 Subject: [PATCH] refactor statement lifecycle --- .../org/apache/spark/sql/SessionManager.scala | 4 +- .../spark/sql/StatementLifecycleManager.scala | 14 + .../apache/spark/sql/StatementManager.scala | 15 - .../apache/spark/sql/FlintJobExecutor.scala | 25 -- .../org/apache/spark/sql/FlintREPL.scala | 356 ++++++++---------- .../spark/sql/QueryExecutionContext.scala | 3 +- .../apache/spark/sql/SessionManagerImpl.scala | 112 +++--- .../sql/StatementLifecycleManagerImpl.scala | 51 +++ ...cala => inMemoryQueryExecutionState.scala} | 9 +- 9 files changed, 299 insertions(+), 290 deletions(-) create mode 100644 flint-commons/src/main/scala/org/apache/spark/sql/StatementLifecycleManager.scala delete mode 100644 flint-commons/src/main/scala/org/apache/spark/sql/StatementManager.scala create mode 100644 spark-sql-application/src/main/scala/org/apache/spark/sql/StatementLifecycleManagerImpl.scala rename spark-sql-application/src/main/scala/org/apache/spark/sql/{CommandState.scala => inMemoryQueryExecutionState.scala} (61%) diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala b/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala index 345f97619..9b53240c8 100644 --- a/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala +++ b/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala @@ -15,11 +15,11 @@ trait SessionManager { def updateSessionDetails( sessionDetails: InteractiveSession, updateMode: SessionUpdateMode): Unit - def hasPendingStatement(sessionId: String): Boolean + def getNextStatement(sessionId: String): Option[FlintStatement] def recordHeartbeat(sessionId: String): Unit } object SessionUpdateMode extends Enumeration { type SessionUpdateMode = Value - val Update, Upsert, UpdateIf = Value + val UPDATE, UPSERT, UPDATE_IF = Value } diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/StatementLifecycleManager.scala b/flint-commons/src/main/scala/org/apache/spark/sql/StatementLifecycleManager.scala new file mode 100644 index 000000000..b4af9e5f6 --- /dev/null +++ b/flint-commons/src/main/scala/org/apache/spark/sql/StatementLifecycleManager.scala @@ -0,0 +1,14 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.opensearch.flint.data.FlintStatement + +trait StatementLifecycleManager { + def prepareStatementLifecycle(): Either[String, Unit] + def updateStatement(statement: FlintStatement): Unit + def closeStatementLifecycle(): Unit +} diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/StatementManager.scala b/flint-commons/src/main/scala/org/apache/spark/sql/StatementManager.scala deleted file mode 100644 index c0a24ab33..000000000 --- a/flint-commons/src/main/scala/org/apache/spark/sql/StatementManager.scala +++ /dev/null @@ -1,15 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.apache.spark.sql - -import org.opensearch.flint.data.FlintStatement - -trait StatementManager { - def prepareCommandLifecycle(): Either[String, Unit] - def initCommandLifecycle(sessionId: String): FlintStatement - def closeCommandLifecycle(): Unit - def updateCommandDetails(commandDetails: FlintStatement): Unit -} diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala index 665ec5a27..f4a29f7d2 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -348,31 +348,6 @@ trait FlintJobExecutor { compareJson(inputJson, mappingJson) || compareJson(mappingJson, inputJson) } - def checkAndCreateIndex(osClient: OSClient, resultIndex: String): Either[String, Unit] = { - try { - val existingSchema = osClient.getIndexMetadata(resultIndex) - if (!isSuperset(existingSchema, resultIndexMapping)) { - Left(s"The mapping of $resultIndex is incorrect.") - } else { - Right(()) - } - } catch { - case e: IllegalStateException - if e.getCause != null && - e.getCause.getMessage.contains("index_not_found_exception") => - createResultIndex(osClient, resultIndex, resultIndexMapping) - case e: InterruptedException => - val error = s"Interrupted by the main thread: ${e.getMessage}" - Thread.currentThread().interrupt() // Preserve the interrupt status - logError(error, e) - Left(error) - case e: Exception => - val error = s"Failed to verify existing mapping: ${e.getMessage}" - logError(error, e) - Left(error) - } - } - def createResultIndex( osClient: OSClient, resultIndex: String, diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 9b0b66d28..06215141d 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -22,12 +22,13 @@ import org.opensearch.flint.core.logging.CustomLogging import org.opensearch.flint.core.metrics.MetricConstants import org.opensearch.flint.core.metrics.MetricsUtil.{getTimerContext, incrementCounter, registerGauge, stopTimer} import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} -import org.opensearch.flint.data.{FlintStatement, InteractiveSession} +import org.opensearch.flint.data.{FlintStatement, InteractiveSession, SessionStates} import org.opensearch.search.sort.SortOrder import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} +import org.apache.spark.sql.SessionUpdateMode.UPDATE_IF import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.util.{ThreadUtils, Utils} @@ -47,11 +48,11 @@ import org.apache.spark.util.{ThreadUtils, Utils} object FlintREPL extends Logging with FlintJobExecutor { private val HEARTBEAT_INTERVAL_MILLIS = 60000L - private val MAPPING_CHECK_TIMEOUT = Duration(1, MINUTES) + private val PREPARE_QUERY_EXEC_TIMEOUT = Duration(1, MINUTES) private val DEFAULT_QUERY_EXECUTION_TIMEOUT = Duration(30, MINUTES) private val DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS = 10 * 60 * 1000 - val INITIAL_DELAY_MILLIS = 3000L - val EARLY_TERMIANTION_CHECK_FREQUENCY = 60000L + private val INITIAL_DELAY_MILLIS = 3000L + private val EARLY_TERMINATION_CHECK_FREQUENCY = 60000L @volatile var earlyExitFlag: Boolean = false @@ -111,10 +112,11 @@ object FlintREPL extends Logging with FlintJobExecutor { } val spark = createSparkSession(conf) + spark.sparkContext.getConf + val jobId = envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown") val applicationId = envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown") - // Read the values from the Spark configuration or fall back to the default values val inactivityLimitMillis: Long = conf.getLong( @@ -128,7 +130,7 @@ object FlintREPL extends Logging with FlintJobExecutor { val queryWaitTimeoutMillis: Long = conf.getLong("spark.flint.job.queryWaitTimeoutMillis", DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS) - val sessionManager = instantiateSessionManager() + val sessionManager = instantiateSessionManager val sessionTimerContext = getTimerContext(MetricConstants.REPL_PROCESSING_TIME_METRIC) /** @@ -173,23 +175,33 @@ object FlintREPL extends Logging with FlintJobExecutor { return } - val commandContext = QueryExecutionContext( + val queryExecutionManager = instantiateQueryExecutionManager( + sessionManager.getSessionManagerMetadata) + val queryExecutionContext = QueryExecutionContext( spark, jobId, sessionId.get, sessionManager, + queryExecutionManager, + queryResultWriter, dataSource, - resultIndex, queryExecutionTimeoutSecs, inactivityLimitMillis, queryWaitTimeoutMillis) exponentialBackoffRetry(maxRetries = 5, initialDelay = 2.seconds) { - queryLoop(commandContext) + queryLoop(queryExecutionContext) } recordSessionSuccess(sessionTimerContext) } catch { case e: Exception => - handleSessionError(sessionTimerContext = sessionTimerContext, e = e) + handleSessionError( + applicationId, + jobId, + sessionId.get, + sessionManager, + sessionTimerContext, + jobStartTime, + e) } finally { if (threadPool != null) { heartBeatFuture.cancel(true) // Pass `true` to interrupt if running @@ -275,16 +287,16 @@ object FlintREPL extends Logging with FlintJobExecutor { case Some(confExcludeJobs) => // example: --conf spark.flint.deployment.excludeJobs=job-1,job-2 - val excludeJobIds = confExcludeJobs.split(",").toList // Convert Array to Lis + val excludedJobIds = confExcludeJobs.split(",").toList // Convert Array to Lis - if (excludeJobIds.contains(jobId)) { + if (excludedJobIds.contains(jobId)) { logInfo(s"current job is excluded, exit the application.") return true } val sessionDetails = sessionManager.getSessionDetails(sessionId) val existingExcludedJobIds = sessionDetails.get.excludedJobIds - if (excludeJobIds.sorted == existingExcludedJobIds.sorted) { + if (excludedJobIds.sorted == existingExcludedJobIds.sorted) { logInfo("duplicate job running, exit the application.") return true } @@ -296,20 +308,20 @@ object FlintREPL extends Logging with FlintJobExecutor { sessionId, sessionManager, jobStartTime, - excludeJobIds) + excludedJobIds = excludedJobIds) } false } def queryLoop(queryExecutionContext: QueryExecutionContext): Unit = { - // 1 thread for updating heart beat + val statementLifecycleManager = queryExecutionContext.statementLifecycleManager + // 1 thread for query execution preparation val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-query", 1) - implicit val executionContext = ExecutionContext.fromExecutor(threadPool) - - var futureMappingCheck: Future[Either[String, Unit]] = null + implicit val futureExecutor = ExecutionContext.fromExecutor(threadPool) + var futurePrepareQueryExecution: Future[Either[String, Unit]] = null try { - futureMappingCheck = Future { - checkAndCreateIndex(queryExecutionContext.osClient, queryExecutionContext.resultIndex) + futurePrepareQueryExecution = Future { + statementLifecycleManager.prepareStatementLifecycle() } var lastActivityTime = currentTimeProvider.currentEpochMillis() @@ -318,19 +330,17 @@ object FlintREPL extends Logging with FlintJobExecutor { var lastCanPickCheckTime = 0L while (currentTimeProvider .currentEpochMillis() - lastActivityTime <= queryExecutionContext.inactivityLimitMillis && canPickUpNextStatement) { - logInfo( - s"""read from ${queryExecutionContext.sessionIndex}, sessionId: ${queryExecutionContext.sessionId}""") + logInfo(s"""sessionId: ${queryExecutionContext.sessionId}""") try { - val commandState = CommandState( + val inMemoryQueryExecutionState = inMemoryQueryExecutionState( lastActivityTime, + lastCanPickCheckTime, verificationResult, - flintReader, - futureMappingCheck, - executionContext, - lastCanPickCheckTime) + futurePrepareQueryExecution, + futureExecutor) val result: (Long, VerificationResult, Boolean, Long) = - processCommands(queryExecutionContext, commandState) + processCommands(queryExecutionContext, inMemoryQueryExecutionState) val ( updatedLastActivityTime, @@ -343,7 +353,7 @@ object FlintREPL extends Logging with FlintJobExecutor { canPickUpNextStatement = updatedCanPickUpNextStatement lastCanPickCheckTime = updatedLastCanPickCheckTime } finally { - flintReader.close() + statementLifecycleManager.closeStatementLifecycle() } Thread.sleep(100) @@ -355,37 +365,52 @@ object FlintREPL extends Logging with FlintJobExecutor { } } - // TODO: Refactor this with getDetails + private def refreshSessionState( + applicationId: String, + jobId: String, + sessionId: String, + sessionManager: SessionManager, + jobStartTime: Long, + state: String, + error: Option[String] = None, + excludedJobIds: Seq[String] = Seq.empty[String]): InteractiveSession = { + + val sessionDetails = sessionManager + .getSessionDetails(sessionId) + .getOrElse( + new InteractiveSession( + applicationId, + jobId, + sessionId, + state, + currentTimeProvider.currentEpochMillis(), + jobStartTime, + error = error, + excludedJobIds = excludedJobIds)) + sessionDetails.state = state + sessionManager.updateSessionDetails(sessionDetails, updateMode = SessionUpdateMode.UPSERT) + sessionDetails + } + private def setupFlintJob( applicationId: String, jobId: String, sessionId: String, sessionManager: SessionManager, jobStartTime: Long, - excludeJobIds: Seq[String] = Seq.empty[String]): Unit = { - val includeJobId = !excludeJobIds.isEmpty && !excludeJobIds.contains(jobId) - val currentTime = currentTimeProvider.currentEpochMillis() - val flintJob = new InteractiveSession( + excludedJobIds: Seq[String] = Seq.empty[String]): Unit = { + refreshSessionState( applicationId, jobId, sessionId, - "running", - currentTime, + sessionManager, jobStartTime, - excludeJobIds) - - // TODO: serialize need to be refactored to be more flexible - val serializedFlintInstance = if (includeJobId) { - InteractiveSession.serialize(flintJob, currentTime, true) - } else { - InteractiveSession.serializeWithoutJobId(flintJob, currentTime) - } - flintSessionIndexUpdater.upsert(sessionId, serializedFlintInstance) - logInfo(s"""Updated job: {"jobid": ${flintJob.jobId}, "sessionId": ${flintJob.sessionId}}""") + SessionStates.RUNNING, + excludedJobIds = excludedJobIds) sessionRunningCount.incrementAndGet() } - def handleSessionError( + private def handleSessionError( applicationId: String, jobId: String, sessionId: String, @@ -395,39 +420,15 @@ object FlintREPL extends Logging with FlintJobExecutor { e: Exception): Unit = { val error = s"Session error: ${e.getMessage}" CustomLogging.logError(error, e) - - val sessionDetails = sessionManager - .getSessionDetails(sessionId) - .getOrElse(createFailedFlintInstance(applicationId, jobId, sessionId, jobStartTime, error)) - - updateFlintInstance(sessionDetails, flintSessionIndexUpdater, sessionId) - if (sessionDetails.isFail) { - recordSessionFailed(sessionTimerContext) - } - } - - private def createFailedFlintInstance( - applicationId: String, - jobId: String, - sessionId: String, - jobStartTime: Long, - errorMessage: String): InteractiveSession = new InteractiveSession( - applicationId, - jobId, - sessionId, - "fail", - currentTimeProvider.currentEpochMillis(), - jobStartTime, - error = Some(errorMessage)) - - private def updateFlintInstance( - flintInstance: InteractiveSession, - flintSessionIndexUpdater: OpenSearchUpdater, - sessionId: String): Unit = { - val currentTime = currentTimeProvider.currentEpochMillis() - flintSessionIndexUpdater.upsert( + refreshSessionState( + applicationId, + jobId, sessionId, - InteractiveSession.serializeWithoutJobId(flintInstance, currentTime)) + sessionManager, + jobStartTime, + SessionStates.FAIL, + Some(e.getMessage)) + recordSessionFailed(sessionTimerContext) } /** @@ -479,7 +480,7 @@ object FlintREPL extends Logging with FlintJobExecutor { private def processCommands( context: QueryExecutionContext, - state: CommandState): (Long, VerificationResult, Boolean, Long) = { + state: inMemoryQueryExecutionState): (Long, VerificationResult, Boolean, Long) = { import context._ import state._ @@ -492,8 +493,8 @@ object FlintREPL extends Logging with FlintJobExecutor { while (canProceed) { val currentTime = currentTimeProvider.currentEpochMillis() - // Only call canPickNextStatement if EARLY_TERMIANTION_CHECK_FREQUENCY milliseconds have passed - if (currentTime - lastCanPickCheckTime > EARLY_TERMIANTION_CHECK_FREQUENCY) { + // Only call canPickNextStatement if EARLY_TERMINATION_CHECK_FREQUENCY milliseconds have passed + if (currentTime - lastCanPickCheckTime > EARLY_TERMINATION_CHECK_FREQUENCY) { canPickNextStatementResult = canPickNextStatement(sessionManager, sessionId, jobId) lastCanPickCheckTime = currentTime } @@ -501,35 +502,23 @@ object FlintREPL extends Logging with FlintJobExecutor { if (!canPickNextStatementResult) { earlyExitFlag = true canProceed = false - } else if (!flintReader.hasNext) { - canProceed = false } else { - val statementTimerContext = getTimerContext( - MetricConstants.STATEMENT_PROCESSING_TIME_METRIC) - val flintStatement = processCommandInitiation(flintReader, flintSessionIndexUpdater) + sessionManager.getNextStatement(sessionId) match { + case Some(flintStatement) => + val statementTimerContext = getTimerContext( + MetricConstants.STATEMENT_PROCESSING_TIME_METRIC) - val (dataToWrite, returnedVerificationResult) = processStatementOnVerification( - recordedVerificationResult, - spark, - flintStatement, - dataSource, - sessionId, - executionContext, - futureMappingCheck, - resultIndex, - queryExecutionTimeout, - queryWaitTimeMillis) + val (dataToWrite, returnedVerificationResult) = + processStatementOnVerification(context, state, flintStatement) - verificationResult = returnedVerificationResult - finalizeCommand( - dataToWrite, - flintStatement, - resultIndex, - flintSessionIndexUpdater, - osClient, - statementTimerContext) - // last query finish time is last activity time - lastActivityTime = currentTimeProvider.currentEpochMillis() + verificationResult = returnedVerificationResult + finalizeCommand(dataToWrite, flintStatement, queryResultWriter, statementTimerContext) + // last query finish time is last activity time + lastActivityTime = currentTimeProvider.currentEpochMillis() + + case None => + canProceed = false + } } } @@ -607,25 +596,15 @@ object FlintREPL extends Logging with FlintJobExecutor { } def executeAndHandle( - spark: SparkSession, + context: QueryExecutionContext, + state: inMemoryQueryExecutionState, flintStatement: FlintStatement, - dataSource: String, - sessionId: String, - executionContext: ExecutionContextExecutor, - startTime: Long, - queryExecuitonTimeOut: Duration, - queryWaitTimeMillis: Long): Option[DataFrame] = { + startTime: Long): Option[DataFrame] = { + import context._ + import state._ + try { - Some( - executeQueryAsync( - spark, - flintStatement, - dataSource, - sessionId, - executionContext, - startTime, - queryExecuitonTimeOut, - queryWaitTimeMillis)) + Some(executeQueryAsync(context, state, flintStatement, startTime)) } catch { case e: TimeoutException => val error = s"Executing ${flintStatement.query} timed out" @@ -645,16 +624,12 @@ object FlintREPL extends Logging with FlintJobExecutor { } private def processStatementOnVerification( - recordedVerificationResult: VerificationResult, - spark: SparkSession, - flintStatement: FlintStatement, - dataSource: String, - sessionId: String, - executionContext: ExecutionContextExecutor, - futureMappingCheck: Future[Either[String, Unit]], - resultIndex: String, - queryExecutionTimeout: Duration, - queryWaitTimeMillis: Long) = { + context: QueryExecutionContext, + state: inMemoryQueryExecutionState, + flintStatement: FlintStatement) = { + import context._ + import state._ + val startTime: Long = currentTimeProvider.currentEpochMillis() var verificationResult = recordedVerificationResult var dataToWrite: Option[DataFrame] = None @@ -662,17 +637,9 @@ object FlintREPL extends Logging with FlintJobExecutor { verificationResult match { case NotVerified => try { - ThreadUtils.awaitResult(futureMappingCheck, MAPPING_CHECK_TIMEOUT) match { + ThreadUtils.awaitResult(futurePrepareQueryExecution, PREPARE_QUERY_EXEC_TIMEOUT) match { case Right(_) => - dataToWrite = executeAndHandle( - spark, - flintStatement, - dataSource, - sessionId, - executionContext, - startTime, - queryExecutionTimeout, - queryWaitTimeMillis) + dataToWrite = executeAndHandle(context, state, flintStatement, startTime) verificationResult = VerifiedWithoutError case Left(error) => verificationResult = VerifiedWithError(error) @@ -713,15 +680,7 @@ object FlintREPL extends Logging with FlintJobExecutor { sessionId, startTime)) case VerifiedWithoutError => - dataToWrite = executeAndHandle( - spark, - flintStatement, - dataSource, - sessionId, - executionContext, - startTime, - queryExecutionTimeout, - queryWaitTimeMillis) + dataToWrite = executeAndHandle(context, state, flintStatement, startTime) } logInfo(s"command complete: $flintStatement") @@ -729,14 +688,13 @@ object FlintREPL extends Logging with FlintJobExecutor { } def executeQueryAsync( - spark: SparkSession, + context: QueryExecutionContext, + state: inMemoryQueryExecutionState, flintStatement: FlintStatement, - dataSource: String, - sessionId: String, - executionContext: ExecutionContextExecutor, - startTime: Long, - queryExecutionTimeOut: Duration, - queryWaitTimeMillis: Long): DataFrame = { + startTime: Long): DataFrame = { + import context._ + import state._ + if (currentTimeProvider .currentEpochMillis() - flintStatement.submitTime > queryWaitTimeMillis) { handleCommandFailureAndGetFailedData( @@ -755,7 +713,7 @@ object FlintREPL extends Logging with FlintJobExecutor { flintStatement.queryId, sessionId, false) - }(executionContext) + }(futureExecutor) // time out after 10 minutes ThreadUtils.awaitResult(futureQueryExecution, queryExecutionTimeOut) @@ -827,29 +785,27 @@ object FlintREPL extends Logging with FlintJobExecutor { extends SparkListener with Logging { - // TODO: Refactor update - override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { logInfo("Shutting down REPL") logInfo("earlyExitFlag: " + earlyExitFlag) - sessionManager.getSessionDetails(sessionId).foreach { sessionDetails => - // It's essential to check the earlyExitFlag before marking the session state as 'dead'. When this flag is true, - // it indicates that the control plane has already initiated a new session to handle remaining requests for the - // current session. In our SQL setup, marking a session as 'dead' automatically triggers the creation of a new - // session. However, the newly created session (initiated by the control plane) will enter a spin-wait state, - // where it inefficiently waits for certain conditions to be met, leading to unproductive resource consumption - // and eventual timeout. To avoid this issue and prevent the creation of redundant sessions by SQL, we ensure - // the session state is not set to 'dead' when earlyExitFlag is true, thereby preventing unnecessary duplicate - // processing. - if (!earlyExitFlag && !sessionDetails.isDead && !sessionDetails.isFail) { - updateFlintInstanceBeforeShutdown( - source, - getResponse, - flintSessionIndexUpdater, - sessionId, - sessionTimerContext) + try { + sessionManager.getSessionDetails(sessionId).foreach { sessionDetails => + // It's essential to check the earlyExitFlag before marking the session state as 'dead'. When this flag is true, + // it indicates that the control plane has already initiated a new session to handle remaining requests for the + // current session. In our SQL setup, marking a session as 'dead' automatically triggers the creation of a new + // session. However, the newly created session (initiated by the control plane) will enter a spin-wait state, + // where it inefficiently waits for certain conditions to be met, leading to unproductive resource consumption + // and eventual timeout. To avoid this issue and prevent the creation of redundant sessions by SQL, we ensure + // the session state is not set to 'dead' when earlyExitFlag is true, thereby preventing unnecessary duplicate + // processing. + if (!earlyExitFlag && !sessionDetails.isDead && !sessionDetails.isFail) { + sessionDetails.state = SessionStates.DEAD + sessionManager.updateSessionDetails(sessionDetails, UPDATE_IF) + } } + } catch { + case e: Exception => logError(s"Failed to update session state for $sessionId", e) } } } @@ -1008,25 +964,47 @@ object FlintREPL extends Logging with FlintJobExecutor { result.getOrElse(throw new RuntimeException("Failed after retries")) } - private def instantiateSessionManager(): SessionManager = { - val options = FlintSparkConf().flintOptions() - val className = options.getCustomSessionManager() - - if (className.isEmpty) { - new SessionManagerImpl(options) + private def instantiateProvider[T](defaultProvider: => T, providerClassName: String): T = { + if (providerClassName.isEmpty) { + defaultProvider } else { try { - val providerClass = Utils.classForName(className) + val providerClass = Utils.classForName(providerClassName) val ctor = providerClass.getDeclaredConstructor() ctor.setAccessible(true) - ctor.newInstance().asInstanceOf[SessionManager] + ctor.newInstance().asInstanceOf[T] } catch { case e: Exception => - throw new RuntimeException(s"Failed to instantiate provider: $className", e) + throw new RuntimeException(s"Failed to instantiate provider: $providerClassName", e) } } } + private def instantiateManager[T]( + defaultInstance: => T, + customClassNameOption: () => String): T = { + val customClassName = customClassNameOption() + instantiateProvider(defaultInstance, customClassName) + } + + private def instantiateSessionManager: SessionManager = { + val options = FlintSparkConf().flintOptions() + instantiateManager(new SessionManagerImpl(options), options.getCustomSessionManager) + } + + private def instantiateQueryExecutionManager( + context: Map[String, Any]): StatementLifecycleManager = { + val options = FlintSparkConf().flintOptions() + instantiateManager( + new StatementLifecycleManagerImpl(context), + options.getCustomStatementManager) + } + + private def instantiateQueryResultWriter: QueryResultWriter = { + val options = FlintSparkConf().flintOptions() + instantiateManager(new QueryResultWriterImpl(), options.getCustomQueryResultWriter) + } + private def recordSessionSuccess(sessionTimerContext: Timer.Context): Unit = { logInfo("Session Success") stopTimer(sessionTimerContext) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryExecutionContext.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryExecutionContext.scala index 5108371ef..1d01e2b38 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryExecutionContext.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryExecutionContext.scala @@ -12,8 +12,9 @@ case class QueryExecutionContext( jobId: String, sessionId: String, sessionManager: SessionManager, + statementLifecycleManager: StatementLifecycleManager, + queryResultWriter: QueryResultWriter, dataSource: String, - resultIndex: String, queryExecutionTimeout: Duration, inactivityLimitMillis: Long, queryWaitTimeMillis: Long) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala index 29f70ddbf..fa41234e8 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala @@ -11,7 +11,7 @@ import org.json4s.native.Serialization import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.logging.CustomLogging import org.opensearch.flint.core.storage.FlintReader -import org.opensearch.flint.data.InteractiveSession +import org.opensearch.flint.data.{FlintStatement, InteractiveSession} import org.opensearch.flint.data.InteractiveSession.formats import org.opensearch.search.sort.SortOrder @@ -65,6 +65,64 @@ class SessionManagerImpl(flintOptions: FlintOptions) } } + override def updateSessionDetails( + sessionDetails: InteractiveSession, + sessionUpdateMode: SessionUpdateMode): Unit = { + sessionUpdateMode match { + case SessionUpdateMode.UPDATE => + flintSessionIndexUpdater.update( + sessionDetails.sessionId, + InteractiveSession.serialize(sessionDetails, currentTimeProvider.currentEpochMillis())) + case SessionUpdateMode.UPSERT => + val includeJobId = + !sessionDetails.excludedJobIds.isEmpty && !sessionDetails.excludedJobIds.contains( + sessionDetails.jobId) + val serializedSession = if (includeJobId) { + InteractiveSession.serialize( + sessionDetails, + currentTimeProvider.currentEpochMillis(), + true) + } else { + InteractiveSession.serializeWithoutJobId( + sessionDetails, + currentTimeProvider.currentEpochMillis()) + } + flintSessionIndexUpdater.upsert(sessionDetails.sessionId, serializedSession) + case SessionUpdateMode.UPDATE_IF => + val seqNo = sessionDetails + .getContextValue("_seq_no") + .getOrElse(throw new IllegalArgumentException("Missing _seq_no for conditional update")) + .asInstanceOf[Long] + val primaryTerm = sessionDetails + .getContextValue("_primary_term") + .getOrElse( + throw new IllegalArgumentException("Missing _primary_term for conditional update")) + .asInstanceOf[Long] + flintSessionIndexUpdater.updateIf( + sessionDetails.sessionId, + InteractiveSession.serializeWithoutJobId( + sessionDetails, + currentTimeProvider.currentEpochMillis()), + seqNo, + primaryTerm) + } + + logInfo( + s"""Updated job: {"jobid": ${sessionDetails.jobId}, "sessionId": ${sessionDetails.sessionId}} from $sessionIndex""") + } + + override def getNextStatement(sessionId: String): Option[FlintStatement] = { + if (flintReader.hasNext) { + val rawStatement = flintReader.next() + logDebug(s"raw statement: $rawStatement") + val flintStatement = FlintStatement.deserialize(rawStatement) + logDebug(s"statement: $flintStatement") + Some(flintStatement) + } else { + None + } + } + override def recordHeartbeat(sessionId: String): Unit = { flintSessionIndexUpdater.upsert( sessionId, @@ -72,10 +130,6 @@ class SessionManagerImpl(flintOptions: FlintOptions) Map("lastUpdateTime" -> currentTimeProvider.currentEpochMillis(), "state" -> "running"))) } - override def hasPendingStatement(sessionId: String): Boolean = { - flintReader.hasNext - } - private def createQueryReader(sessionId: String, sessionIndex: String, dataSource: String) = { // all state in index are in lower case // we only search for statement submitted in the last hour in case of unexpected bugs causing infinite loop in the @@ -116,52 +170,4 @@ class SessionManagerImpl(flintOptions: FlintOptions) val flintReader = osClient.createQueryReader(sessionIndex, dsl, "submitTime", SortOrder.ASC) flintReader } - - override def updateSessionDetails( - sessionDetails: InteractiveSession, - sessionUpdateMode: SessionUpdateMode): Unit = { - sessionUpdateMode match { - case SessionUpdateMode.Update => - flintSessionIndexUpdater.update( - sessionDetails.sessionId, - InteractiveSession.serialize(sessionDetails, currentTimeProvider.currentEpochMillis())) - case SessionUpdateMode.Upsert => - val includeJobId = - !sessionDetails.excludedJobIds.isEmpty && !sessionDetails.excludedJobIds.contains( - sessionDetails.jobId) - val serializedSession = if (includeJobId) { - InteractiveSession.serialize( - sessionDetails, - currentTimeProvider.currentEpochMillis(), - true) - } else { - InteractiveSession.serializeWithoutJobId( - sessionDetails, - currentTimeProvider.currentEpochMillis()) - } - flintSessionIndexUpdater.upsert(sessionDetails.sessionId, serializedSession) - case SessionUpdateMode.UpdateIf => - val executionContext = sessionDetails.executionContext.getOrElse( - throw new IllegalArgumentException("Missing executionContext for conditional update")) - val seqNo = executionContext - .get("_seq_no") - .getOrElse(throw new IllegalArgumentException("Missing _seq_no for conditional update")) - .asInstanceOf[Long] - val primaryTerm = executionContext - .get("_primary_term") - .getOrElse( - throw new IllegalArgumentException("Missing _primary_term for conditional update")) - .asInstanceOf[Long] - flintSessionIndexUpdater.updateIf( - sessionDetails.sessionId, - InteractiveSession.serializeWithoutJobId( - sessionDetails, - currentTimeProvider.currentEpochMillis()), - seqNo, - primaryTerm) - } - - logInfo( - s"""Updated job: {"jobid": ${sessionDetails.jobId}, "sessionId": ${sessionDetails.sessionId}} from $sessionIndex""") - } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementLifecycleManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementLifecycleManagerImpl.scala new file mode 100644 index 000000000..e66ebb1ae --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementLifecycleManagerImpl.scala @@ -0,0 +1,51 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.opensearch.flint.data.FlintStatement + +import org.apache.spark.internal.Logging + +class StatementLifecycleManagerImpl(context: Map[String, Any]) + extends StatementLifecycleManager + with FlintJobExecutor + with Logging { + val osClient = context("osClient").asInstanceOf[OSClient] + val resultIndex = context("osClient").asInstanceOf[String] + + override def prepareStatementLifecycle(): Either[String, Unit] = { + try { + val existingSchema = osClient.getIndexMetadata(resultIndex) + if (!isSuperset(existingSchema, resultIndexMapping)) { + Left(s"The mapping of $resultIndex is incorrect.") + } else { + Right(()) + } + } catch { + case e: IllegalStateException + if e.getCause != null && + e.getCause.getMessage.contains("index_not_found_exception") => + createResultIndex(osClient, resultIndex, resultIndexMapping) + case e: InterruptedException => + val error = s"Interrupted by the main thread: ${e.getMessage}" + Thread.currentThread().interrupt() // Preserve the interrupt status + logError(error, e) + Left(error) + case e: Exception => + val error = s"Failed to verify existing mapping: ${e.getMessage}" + logError(error, e) + Left(error) + } + } + + override def initCommandLifecycle(sessionId: String): FlintStatement = ??? + + override def closeCommandLifecycle(): Unit = ??? + + override def getNextStatement(statement: FlintStatement): Option[FlintStatement] = ??? + + override def updateStatement(statement: FlintStatement): Unit = ??? +} diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/inMemoryQueryExecutionState.scala similarity index 61% rename from spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala rename to spark-sql-application/src/main/scala/org/apache/spark/sql/inMemoryQueryExecutionState.scala index ad49201f0..c09ec823b 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/inMemoryQueryExecutionState.scala @@ -9,10 +9,9 @@ import scala.concurrent.{ExecutionContextExecutor, Future} import org.opensearch.flint.core.storage.FlintReader -case class CommandState( +case class inMemoryQueryExecutionState( recordedLastActivityTime: Long, + recordedLastCanPickCheckTime: Long, recordedVerificationResult: VerificationResult, - flintReader: FlintReader, - futureMappingCheck: Future[Either[String, Unit]], - executionContext: ExecutionContextExecutor, - recordedLastCanPickCheckTime: Long) + futurePrepareQueryExecution: Future[Either[String, Unit]], + futureExecutor: ExecutionContextExecutor)