diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala b/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala
new file mode 100644
index 000000000..d69fbc30f
--- /dev/null
+++ b/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala
@@ -0,0 +1,12 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.apache.spark.sql
+
+import org.opensearch.flint.data.FlintStatement
+
+trait QueryResultWriter {
+  def write(dataFrame: DataFrame, flintStatement: FlintStatement): Unit
+}
diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala b/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala
new file mode 100644
index 000000000..345f97619
--- /dev/null
+++ b/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala
@@ -0,0 +1,25 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.apache.spark.sql
+
+import org.opensearch.flint.data.{FlintStatement, InteractiveSession}
+
+import org.apache.spark.sql.SessionUpdateMode.SessionUpdateMode
+
+trait SessionManager {
+  def getSessionManagerMetadata: Map[String, Any]
+  def getSessionDetails(sessionId: String): Option[InteractiveSession]
+  def updateSessionDetails(
+      sessionDetails: InteractiveSession,
+      updateMode: SessionUpdateMode): Unit
+  def hasPendingStatement(sessionId: String): Boolean
+  def recordHeartbeat(sessionId: String): Unit
+}
+
+object SessionUpdateMode extends Enumeration {
+  type SessionUpdateMode = Value
+  val Update, Upsert, UpdateIf = Value
+}
diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/StatementManager.scala b/flint-commons/src/main/scala/org/apache/spark/sql/StatementManager.scala
new file mode 100644
index 000000000..c0a24ab33
--- /dev/null
+++ b/flint-commons/src/main/scala/org/apache/spark/sql/StatementManager.scala
@@ -0,0 +1,15 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.apache.spark.sql
+
+import org.opensearch.flint.data.FlintStatement
+
+trait StatementManager {
+  def prepareCommandLifecycle(): Either[String, Unit]
+  def initCommandLifecycle(sessionId: String): FlintStatement
+  def closeCommandLifecycle(): Unit
+  def updateCommandDetails(commandDetails: FlintStatement): Unit
+}
diff --git a/flint-commons/src/main/scala/org/opensearch/flint/data/InteractiveSession.scala b/flint-commons/src/main/scala/org/opensearch/flint/data/InteractiveSession.scala
index c5eaee4f1..3727d14ee 100644
--- a/flint-commons/src/main/scala/org/opensearch/flint/data/InteractiveSession.scala
+++ b/flint-commons/src/main/scala/org/opensearch/flint/data/InteractiveSession.scala
@@ -5,7 +5,7 @@
 
 package org.opensearch.flint.data
 
-import java.util.{Map => JavaMap}
+import java.util.{List => JavaList, Map => JavaMap}
 
 import scala.collection.JavaConverters._
 
@@ -16,9 +16,9 @@ import org.json4s.native.Serialization
 
 object SessionStates {
   val RUNNING = "running"
-  val COMPLETE = "complete"
-  val FAILED = "failed"
-  val WAITING = "waiting"
+  val DEAD = "dead"
+  val FAIL = "fail"
+  val NOT_STARTED = "not_started"
 }
 
 /**
@@ -57,9 +57,9 @@ class InteractiveSession(
   context = sessionContext // Initialize the context from the constructor
 
   def isRunning: Boolean = state == SessionStates.RUNNING
-  def isComplete: Boolean = state == SessionStates.COMPLETE
-  def isFailed: Boolean = state == SessionStates.FAILED
-  def isWaiting: Boolean = state == SessionStates.WAITING
+ def isDead: Boolean = state == SessionStates.DEAD + def isFail: Boolean = state == SessionStates.FAIL + def isNotStarted: Boolean = state == SessionStates.NOT_STARTED override def toString: String = { val excludedJobIdsStr = excludedJobIds.mkString("[", ", ", "]") @@ -129,10 +129,7 @@ object InteractiveSession { } // We safely handle the possibility of excludeJobIds being absent or not a list. - val excludeJobIds: Seq[String] = scalaSource.get("excludeJobIds") match { - case Some(lst: java.util.List[_]) => lst.asScala.toList.map(_.asInstanceOf[String]) - case _ => Seq.empty[String] - } + val excludeJobIds: Seq[String] = parseExcludedJobIds(scalaSource.get("excludeJobIds")) // Handle error similarly, ensuring we get an Option[String]. val maybeError: Option[String] = scalaSource.get("error") match { @@ -201,4 +198,13 @@ object InteractiveSession { def serializeWithoutJobId(job: InteractiveSession, currentTime: Long): String = { serialize(job, currentTime, includeJobId = false) } + private def parseExcludedJobIds(source: Option[Any]): Seq[String] = { + source match { + case Some(s: String) => Seq(s) + case Some(list: JavaList[_]) => list.asScala.toList.collect { case str: String => str } + case None => Seq.empty[String] + case _ => + Seq.empty + } + } } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java b/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java index 0cf643791..48107fe8c 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java @@ -49,6 +49,12 @@ public class FlintOptions implements Serializable { public static final String METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER = "spark.metadata.accessAWSCredentialsProvider"; + public static final String CUSTOM_SESSION_MANAGER = "customSessionManager"; + + public static final String CUSTOM_STATEMENT_MANAGER = "customStatementManager"; + + public static final String CUSTOM_QUERY_RESULT_WRITER = "customQueryResultWriter"; + /** * By default, customAWSCredentialsProvider and accessAWSCredentialsProvider are empty. use DefaultAWSCredentialsProviderChain. 
*/ @@ -56,6 +62,8 @@ public class FlintOptions implements Serializable { public static final String SYSTEM_INDEX_KEY_NAME = "spark.flint.job.requestIndex"; + public static final String FLINT_SESSION_ID = "spark.flint.job.sessionId"; + /** * Used by {@link org.opensearch.flint.core.storage.OpenSearchScrollReader} */ @@ -137,6 +145,18 @@ public String getMetadataAccessAwsCredentialsProvider() { return options.getOrDefault(METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER, DEFAULT_AWS_CREDENTIALS_PROVIDER); } + public String getCustomSessionManager() { + return options.getOrDefault(CUSTOM_SESSION_MANAGER, ""); + } + + public String getCustomStatementManager() { + return options.getOrDefault(CUSTOM_STATEMENT_MANAGER, ""); + } + + public String getCustomQueryResultWriter() { + return options.getOrDefault(CUSTOM_QUERY_RESULT_WRITER, ""); + } + public String getUsername() { return options.getOrDefault(USERNAME, "flint"); } @@ -157,6 +177,10 @@ public String getSystemIndexName() { return options.getOrDefault(SYSTEM_INDEX_KEY_NAME, ""); } + public String getSessionId() { + return options.getOrDefault(FLINT_SESSION_ID, null); + } + public int getBatchBytes() { // we did not expect this value could be large than 10mb = 10 * 1024 * 1024 return (int) org.apache.spark.network.util.JavaUtils diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala index c6638c0b2..a70f3630b 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala @@ -203,6 +203,15 @@ object FlintSparkConf { FlintConfig("spark.metadata.accessAWSCredentialsProvider") .doc("AWS credentials provider for metadata access permission") .createOptional() + val CUSTOM_SESSION_MANAGER = + FlintConfig("spark.flint.job.customSessionManager") + .createOptional() + val CUSTOM_STATEMENT_MANAGER = + FlintConfig("spark.flint.job.customStatementManager") + .createOptional() + val CUSTOM_QUERY_RESULT_WRITER = + FlintConfig("spark.flint.job.customQueryResultWriter") + .createOptional() } /** @@ -277,6 +286,9 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable SESSION_ID, REQUEST_INDEX, METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER, + CUSTOM_SESSION_MANAGER, + CUSTOM_STATEMENT_MANAGER, + CUSTOM_QUERY_RESULT_WRITER, EXCLUDE_JOB_IDS) .map(conf => (conf.optionKey, conf.readFrom(reader))) .flatMap { diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 8cad8844b..9b0b66d28 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -15,7 +15,6 @@ import scala.util.{Failure, Success, Try} import scala.util.control.NonFatal import com.codahale.metrics.Timer -import org.json4s.native.Serialization import org.opensearch.action.get.GetResponse import org.opensearch.common.Strings import org.opensearch.flint.core.FlintOptions @@ -24,14 +23,13 @@ import org.opensearch.flint.core.metrics.MetricConstants import org.opensearch.flint.core.metrics.MetricsUtil.{getTimerContext, incrementCounter, registerGauge, stopTimer} import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} import org.opensearch.flint.data.{FlintStatement, 
InteractiveSession} -import org.opensearch.flint.data.InteractiveSession.formats import org.opensearch.search.sort.SortOrder import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.flint.config.FlintSparkConf -import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.{ThreadUtils, Utils} /** * Spark SQL Application entrypoint @@ -106,18 +104,13 @@ object FlintREPL extends Logging with FlintJobExecutor { jobOperator.start() } else { // we don't allow default value for sessionIndex and sessionId. Throw exception if key not found. - val sessionIndex: Option[String] = Option(conf.get(FlintSparkConf.REQUEST_INDEX.key, null)) val sessionId: Option[String] = Option(conf.get(FlintSparkConf.SESSION_ID.key, null)) - if (sessionIndex.isEmpty) { - logAndThrow(FlintSparkConf.REQUEST_INDEX.key + " is not set") - } if (sessionId.isEmpty) { logAndThrow(FlintSparkConf.SESSION_ID.key + " is not set") } val spark = createSparkSession(conf) - val osClient = new OSClient(FlintSparkConf().flintOptions()) val jobId = envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown") val applicationId = envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown") @@ -135,7 +128,7 @@ object FlintREPL extends Logging with FlintJobExecutor { val queryWaitTimeoutMillis: Long = conf.getLong("spark.flint.job.queryWaitTimeoutMillis", DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS) - val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex.get) + val sessionManager = instantiateSessionManager() val sessionTimerContext = getTimerContext(MetricConstants.REPL_PROCESSING_TIME_METRIC) /** @@ -148,12 +141,7 @@ object FlintREPL extends Logging with FlintJobExecutor { * https://github.com/opensearch-project/opensearch-spark/issues/320 */ spark.sparkContext.addSparkListener( - new PreShutdownListener( - flintSessionIndexUpdater, - osClient, - sessionIndex.get, - sessionId.get, - sessionTimerContext)) + new PreShutdownListener(sessionManager, sessionId.get, sessionTimerContext)) // 1 thread for updating heart beat val threadPool = @@ -161,42 +149,37 @@ object FlintREPL extends Logging with FlintJobExecutor { registerGauge(MetricConstants.REPL_RUNNING_METRIC, sessionRunningCount) registerGauge(MetricConstants.STATEMENT_RUNNING_METRIC, statementRunningCount) + val jobStartTime = currentTimeProvider.currentEpochMillis() // update heart beat every 30 seconds // OpenSearch triggers recovery after 1 minute outdated heart beat var heartBeatFuture: ScheduledFuture[_] = null try { heartBeatFuture = createHeartBeatUpdater( - HEARTBEAT_INTERVAL_MILLIS, - flintSessionIndexUpdater, sessionId.get, - threadPool, - osClient, - sessionIndex.get, - INITIAL_DELAY_MILLIS) + sessionManager, + HEARTBEAT_INTERVAL_MILLIS, + INITIAL_DELAY_MILLIS, + threadPool) if (setupFlintJobWithExclusionCheck( conf, - sessionIndex, - sessionId, - osClient, - jobId, applicationId, - flintSessionIndexUpdater, + jobId, + sessionId.get, + sessionManager, jobStartTime)) { earlyExitFlag = true return } - val commandContext = CommandContext( + val commandContext = QueryExecutionContext( spark, + jobId, + sessionId.get, + sessionManager, dataSource, resultIndex, - sessionId.get, - flintSessionIndexUpdater, - osClient, - sessionIndex.get, - jobId, queryExecutionTimeoutSecs, inactivityLimitMillis, queryWaitTimeoutMillis) @@ -206,16 +189,7 @@ object FlintREPL extends Logging with FlintJobExecutor { 
recordSessionSuccess(sessionTimerContext) } catch { case e: Exception => - handleSessionError( - e, - applicationId, - jobId, - sessionId.get, - jobStartTime, - flintSessionIndexUpdater, - osClient, - sessionIndex.get, - sessionTimerContext) + handleSessionError(sessionTimerContext = sessionTimerContext, e = e) } finally { if (threadPool != null) { heartBeatFuture.cancel(true) // Pass `true` to interrupt if running @@ -287,25 +261,17 @@ object FlintREPL extends Logging with FlintJobExecutor { */ def setupFlintJobWithExclusionCheck( conf: SparkConf, - sessionIndex: Option[String], - sessionId: Option[String], - osClient: OSClient, - jobId: String, applicationId: String, - flintSessionIndexUpdater: OpenSearchUpdater, + jobId: String, + sessionId: String, + sessionManager: SessionManager, jobStartTime: Long): Boolean = { val confExcludeJobsOpt = conf.getOption(FlintSparkConf.EXCLUDE_JOB_IDS.key) confExcludeJobsOpt match { case None => // If confExcludeJobs is None, pass null or an empty sequence as per your setupFlintJob method's signature - setupFlintJob( - applicationId, - jobId, - sessionId.get, - flintSessionIndexUpdater, - sessionIndex.get, - jobStartTime) + setupFlintJob(applicationId, jobId, sessionId, sessionManager, jobStartTime) case Some(confExcludeJobs) => // example: --conf spark.flint.deployment.excludeJobs=job-1,job-2 @@ -316,32 +282,26 @@ object FlintREPL extends Logging with FlintJobExecutor { return true } - val getResponse = osClient.getDoc(sessionIndex.get, sessionId.get) - if (getResponse.isExists()) { - val source = getResponse.getSourceAsMap - if (source != null) { - val existingExcludedJobIds = parseExcludedJobIds(source) - if (excludeJobIds.sorted == existingExcludedJobIds.sorted) { - logInfo("duplicate job running, exit the application.") - return true - } - } + val sessionDetails = sessionManager.getSessionDetails(sessionId) + val existingExcludedJobIds = sessionDetails.get.excludedJobIds + if (excludeJobIds.sorted == existingExcludedJobIds.sorted) { + logInfo("duplicate job running, exit the application.") + return true } // If none of the edge cases are met, proceed with setup setupFlintJob( applicationId, jobId, - sessionId.get, - flintSessionIndexUpdater, - sessionIndex.get, + sessionId, + sessionManager, jobStartTime, excludeJobIds) } false } - def queryLoop(commandContext: CommandContext): Unit = { + def queryLoop(queryExecutionContext: QueryExecutionContext): Unit = { // 1 thread for updating heart beat val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-query", 1) implicit val executionContext = ExecutionContext.fromExecutor(threadPool) @@ -349,7 +309,7 @@ object FlintREPL extends Logging with FlintJobExecutor { var futureMappingCheck: Future[Either[String, Unit]] = null try { futureMappingCheck = Future { - checkAndCreateIndex(commandContext.osClient, commandContext.resultIndex) + checkAndCreateIndex(queryExecutionContext.osClient, queryExecutionContext.resultIndex) } var lastActivityTime = currentTimeProvider.currentEpochMillis() @@ -357,15 +317,9 @@ object FlintREPL extends Logging with FlintJobExecutor { var canPickUpNextStatement = true var lastCanPickCheckTime = 0L while (currentTimeProvider - .currentEpochMillis() - lastActivityTime <= commandContext.inactivityLimitMillis && canPickUpNextStatement) { + .currentEpochMillis() - lastActivityTime <= queryExecutionContext.inactivityLimitMillis && canPickUpNextStatement) { logInfo( - s"""read from ${commandContext.sessionIndex}, sessionId: ${commandContext.sessionId}""") - 
val flintReader: FlintReader = - createQueryReader( - commandContext.osClient, - commandContext.sessionId, - commandContext.sessionIndex, - commandContext.dataSource) + s"""read from ${queryExecutionContext.sessionIndex}, sessionId: ${queryExecutionContext.sessionId}""") try { val commandState = CommandState( @@ -376,7 +330,7 @@ object FlintREPL extends Logging with FlintJobExecutor { executionContext, lastCanPickCheckTime) val result: (Long, VerificationResult, Boolean, Long) = - processCommands(commandContext, commandState) + processCommands(queryExecutionContext, commandState) val ( updatedLastActivityTime, @@ -401,12 +355,12 @@ object FlintREPL extends Logging with FlintJobExecutor { } } + // TODO: Refactor this with getDetails private def setupFlintJob( applicationId: String, jobId: String, sessionId: String, - flintSessionIndexUpdater: OpenSearchUpdater, - sessionIndex: String, + sessionManager: SessionManager, jobStartTime: Long, excludeJobIds: Seq[String] = Seq.empty[String]): Unit = { val includeJobId = !excludeJobIds.isEmpty && !excludeJobIds.contains(jobId) @@ -420,55 +374,38 @@ object FlintREPL extends Logging with FlintJobExecutor { jobStartTime, excludeJobIds) + // TODO: serialize need to be refactored to be more flexible val serializedFlintInstance = if (includeJobId) { InteractiveSession.serialize(flintJob, currentTime, true) } else { InteractiveSession.serializeWithoutJobId(flintJob, currentTime) } flintSessionIndexUpdater.upsert(sessionId, serializedFlintInstance) - logInfo( - s"""Updated job: {"jobid": ${flintJob.jobId}, "sessionId": ${flintJob.sessionId}} from $sessionIndex""") + logInfo(s"""Updated job: {"jobid": ${flintJob.jobId}, "sessionId": ${flintJob.sessionId}}""") sessionRunningCount.incrementAndGet() } def handleSessionError( - e: Exception, applicationId: String, jobId: String, sessionId: String, + sessionManager: SessionManager, + sessionTimerContext: Timer.Context, jobStartTime: Long, - flintSessionIndexUpdater: OpenSearchUpdater, - osClient: OSClient, - sessionIndex: String, - sessionTimerContext: Timer.Context): Unit = { + e: Exception): Unit = { val error = s"Session error: ${e.getMessage}" CustomLogging.logError(error, e) - val flintInstance = getExistingFlintInstance(osClient, sessionIndex, sessionId) + val sessionDetails = sessionManager + .getSessionDetails(sessionId) .getOrElse(createFailedFlintInstance(applicationId, jobId, sessionId, jobStartTime, error)) - updateFlintInstance(flintInstance, flintSessionIndexUpdater, sessionId) - if (flintInstance.state.equals("fail")) { + updateFlintInstance(sessionDetails, flintSessionIndexUpdater, sessionId) + if (sessionDetails.isFail) { recordSessionFailed(sessionTimerContext) } } - private def getExistingFlintInstance( - osClient: OSClient, - sessionIndex: String, - sessionId: String): Option[InteractiveSession] = Try( - osClient.getDoc(sessionIndex, sessionId)) match { - case Success(getResponse) if getResponse.isExists() => - Option(getResponse.getSourceAsMap) - .map(InteractiveSession.deserializeFromMap) - case Failure(exception) => - CustomLogging.logError( - s"Failed to retrieve existing FlintInstance: ${exception.getMessage}", - exception) - None - case _ => None - } - private def createFailedFlintInstance( applicationId: String, jobId: String, @@ -541,7 +478,7 @@ object FlintREPL extends Logging with FlintJobExecutor { } private def processCommands( - context: CommandContext, + context: QueryExecutionContext, state: CommandState): (Long, VerificationResult, Boolean, Long) = { import context._ import 
state._ @@ -557,8 +494,7 @@ object FlintREPL extends Logging with FlintJobExecutor { // Only call canPickNextStatement if EARLY_TERMIANTION_CHECK_FREQUENCY milliseconds have passed if (currentTime - lastCanPickCheckTime > EARLY_TERMIANTION_CHECK_FREQUENCY) { - canPickNextStatementResult = - canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + canPickNextStatementResult = canPickNextStatement(sessionManager, sessionId, jobId) lastCanPickCheckTime = currentTime } @@ -885,43 +821,35 @@ object FlintREPL extends Logging with FlintJobExecutor { } class PreShutdownListener( - flintSessionIndexUpdater: OpenSearchUpdater, - osClient: OSClient, - sessionIndex: String, + sessionManager: SessionManager, sessionId: String, sessionTimerContext: Timer.Context) extends SparkListener with Logging { + // TODO: Refactor update + override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { logInfo("Shutting down REPL") logInfo("earlyExitFlag: " + earlyExitFlag) - val getResponse = osClient.getDoc(sessionIndex, sessionId) - if (!getResponse.isExists()) { - return - } - val source = getResponse.getSourceAsMap - if (source == null) { - return - } - - val state = Option(source.get("state")).map(_.asInstanceOf[String]) - // It's essential to check the earlyExitFlag before marking the session state as 'dead'. When this flag is true, - // it indicates that the control plane has already initiated a new session to handle remaining requests for the - // current session. In our SQL setup, marking a session as 'dead' automatically triggers the creation of a new - // session. However, the newly created session (initiated by the control plane) will enter a spin-wait state, - // where it inefficiently waits for certain conditions to be met, leading to unproductive resource consumption - // and eventual timeout. To avoid this issue and prevent the creation of redundant sessions by SQL, we ensure - // the session state is not set to 'dead' when earlyExitFlag is true, thereby preventing unnecessary duplicate - // processing. - if (!earlyExitFlag && state.isDefined && state.get != "dead" && state.get != "fail") { - updateFlintInstanceBeforeShutdown( - source, - getResponse, - flintSessionIndexUpdater, - sessionId, - sessionTimerContext) + sessionManager.getSessionDetails(sessionId).foreach { sessionDetails => + // It's essential to check the earlyExitFlag before marking the session state as 'dead'. When this flag is true, + // it indicates that the control plane has already initiated a new session to handle remaining requests for the + // current session. In our SQL setup, marking a session as 'dead' automatically triggers the creation of a new + // session. However, the newly created session (initiated by the control plane) will enter a spin-wait state, + // where it inefficiently waits for certain conditions to be met, leading to unproductive resource consumption + // and eventual timeout. To avoid this issue and prevent the creation of redundant sessions by SQL, we ensure + // the session state is not set to 'dead' when earlyExitFlag is true, thereby preventing unnecessary duplicate + // processing. 
+ if (!earlyExitFlag && !sessionDetails.isDead && !sessionDetails.isFail) { + updateFlintInstanceBeforeShutdown( + source, + getResponse, + flintSessionIndexUpdater, + sessionId, + sessionTimerContext) + } } } } @@ -960,13 +888,11 @@ object FlintREPL extends Logging with FlintJobExecutor { * the intial delay to start heartbeat */ def createHeartBeatUpdater( - currentInterval: Long, - flintSessionUpdater: OpenSearchUpdater, sessionId: String, - threadPool: ScheduledExecutorService, - osClient: OSClient, - sessionIndex: String, - initialDelayMillis: Long): ScheduledFuture[_] = { + sessionManager: SessionManager, + currentInterval: Long, + initialDelayMillis: Long, + threadPool: ScheduledExecutorService): ScheduledFuture[_] = { threadPool.scheduleAtFixedRate( new Runnable { @@ -978,12 +904,7 @@ object FlintREPL extends Logging with FlintJobExecutor { return // Exit the run method if the thread is interrupted } - flintSessionUpdater.upsert( - sessionId, - Serialization.write( - Map( - "lastUpdateTime" -> currentTimeProvider.currentEpochMillis(), - "state" -> "running"))) + sessionManager.recordHeartbeat(sessionId) } catch { case ie: InterruptedException => // Preserve the interrupt status @@ -1020,62 +941,36 @@ object FlintREPL extends Logging with FlintJobExecutor { * whether we can start fetching next statement or not */ def canPickNextStatement( + sessionManager: SessionManager, sessionId: String, - jobId: String, - osClient: OSClient, - sessionIndex: String): Boolean = { + jobId: String): Boolean = { try { - val getResponse = osClient.getDoc(sessionIndex, sessionId) - if (getResponse.isExists()) { - val source = getResponse.getSourceAsMap - if (source == null) { - logError(s"""Session id ${sessionId} is empty""") - // still proceed since we are not sure what happened (e.g., OpenSearch cluster may be unresponsive) - return true - } - - val runJobId = Option(source.get("jobId")).map(_.asInstanceOf[String]).orNull - val excludeJobIds: Seq[String] = parseExcludedJobIds(source) - - if (runJobId != null && jobId != runJobId) { - logInfo(s"""the current job ID ${jobId} is not the running job ID ${runJobId}""") - return false - } - if (excludeJobIds != null && excludeJobIds.contains(jobId)) { - logInfo(s"""${jobId} is in the list of excluded jobs""") - return false - } - true - } else { - // still proceed since we are not sure what happened (e.g., session doc may not be available yet) - logError(s"""Fail to find id ${sessionId} from session index""") - true + sessionManager.getSessionDetails(sessionId) match { + case Some(sessionDetails) => + val runJobId = sessionDetails.jobId + val excludeJobIds = sessionDetails.excludedJobIds + + if (!runJobId.isEmpty && jobId != runJobId) { + logInfo(s"the current job ID $jobId is not the running job ID ${runJobId}") + return false + } + if (excludeJobIds.contains(jobId)) { + logInfo(s"$jobId is in the list of excluded jobs") + return false + } + true + case None => + logError(s"Failed to fetch sessionDetails by sessionId: $sessionId.") + true } } catch { - // still proceed since we are not sure what happened (e.g., OpenSearch cluster may be unresponsive) + // still proceed with exception case e: Exception => CustomLogging.logError(s"""Fail to find id ${sessionId} from session index.""", e) true } } - private def parseExcludedJobIds(source: java.util.Map[String, AnyRef]): Seq[String] = { - - val rawExcludeJobIds = source.get("excludeJobIds") - Option(rawExcludeJobIds) - .map { - case s: String => Seq(s) - case list: java.util.List[_] @unchecked => - import 
scala.collection.JavaConverters._ - list.asScala.toList - .collect { case str: String => str } // Collect only strings from the list - case other => - logInfo(s"Unexpected type: ${other.getClass.getName}") - Seq.empty - } - .getOrElse(Seq.empty[String]) // In case of null, return an empty Seq - } - def exponentialBackoffRetry[T](maxRetries: Int, initialDelay: FiniteDuration)( block: => T): T = { var retries = 0 @@ -1113,6 +1008,25 @@ object FlintREPL extends Logging with FlintJobExecutor { result.getOrElse(throw new RuntimeException("Failed after retries")) } + private def instantiateSessionManager(): SessionManager = { + val options = FlintSparkConf().flintOptions() + val className = options.getCustomSessionManager() + + if (className.isEmpty) { + new SessionManagerImpl(options) + } else { + try { + val providerClass = Utils.classForName(className) + val ctor = providerClass.getDeclaredConstructor() + ctor.setAccessible(true) + ctor.newInstance().asInstanceOf[SessionManager] + } catch { + case e: Exception => + throw new RuntimeException(s"Failed to instantiate provider: $className", e) + } + } + } + private def recordSessionSuccess(sessionTimerContext: Timer.Context): Unit = { logInfo("Session Success") stopTimer(sessionTimerContext) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryExecutionContext.scala similarity index 58% rename from spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala rename to spark-sql-application/src/main/scala/org/apache/spark/sql/QueryExecutionContext.scala index fe2fa5212..5108371ef 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryExecutionContext.scala @@ -5,20 +5,15 @@ package org.apache.spark.sql -import scala.concurrent.{ExecutionContextExecutor, Future} import scala.concurrent.duration.Duration -import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} - -case class CommandContext( +case class QueryExecutionContext( spark: SparkSession, + jobId: String, + sessionId: String, + sessionManager: SessionManager, dataSource: String, resultIndex: String, - sessionId: String, - flintSessionIndexUpdater: OpenSearchUpdater, - osClient: OSClient, - sessionIndex: String, - jobId: String, queryExecutionTimeout: Duration, inactivityLimitMillis: Long, queryWaitTimeMillis: Long) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala new file mode 100644 index 000000000..29f70ddbf --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala @@ -0,0 +1,167 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import scala.util.{Failure, Success, Try} + +import org.json4s.native.Serialization +import org.opensearch.flint.core.FlintOptions +import org.opensearch.flint.core.logging.CustomLogging +import org.opensearch.flint.core.storage.FlintReader +import org.opensearch.flint.data.InteractiveSession +import org.opensearch.flint.data.InteractiveSession.formats +import org.opensearch.search.sort.SortOrder + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SessionUpdateMode.SessionUpdateMode +import org.apache.spark.sql.flint.config.FlintSparkConf + +class 
SessionManagerImpl(flintOptions: FlintOptions) + extends SessionManager + with FlintJobExecutor + with Logging { + + // we don't allow default value for sessionIndex, sessionId and datasource. Throw exception if key not found. + val sessionIndex: String = flintOptions.getSystemIndexName + val sessionId: String = flintOptions.getSessionId + val dataSource: String = flintOptions.getDataSourceName + + if (sessionIndex.isEmpty) { + logAndThrow(FlintSparkConf.REQUEST_INDEX.key + " is not set") + } + if (sessionId.isEmpty) { + logAndThrow(FlintSparkConf.SESSION_ID.key + " is not set") + } + if (dataSource.isEmpty) { + logAndThrow(FlintSparkConf.DATA_SOURCE_NAME.key + " is not set") + } + + val osClient = new OSClient(flintOptions) + val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex) + val flintReader: FlintReader = createQueryReader(sessionId, sessionIndex, dataSource) + + override def getSessionManagerMetadata: Map[String, Any] = { + Map( + "sessionIndex" -> sessionIndex, + "osClient" -> osClient, + "flintSessionIndexUpdater" -> flintSessionIndexUpdater, + "flintReader" -> flintReader) + } + + override def getSessionDetails(sessionId: String): Option[InteractiveSession] = { + Try(osClient.getDoc(sessionIndex, sessionId)) match { + case Success(getResponse) if getResponse.isExists => + Option(getResponse.getSourceAsMap) + .map(InteractiveSession.deserializeFromMap) + case Failure(exception) => + CustomLogging.logError( + s"Failed to retrieve existing InteractiveSession: ${exception.getMessage}", + exception) + None + case _ => None + } + } + + override def recordHeartbeat(sessionId: String): Unit = { + flintSessionIndexUpdater.upsert( + sessionId, + Serialization.write( + Map("lastUpdateTime" -> currentTimeProvider.currentEpochMillis(), "state" -> "running"))) + } + + override def hasPendingStatement(sessionId: String): Boolean = { + flintReader.hasNext + } + + private def createQueryReader(sessionId: String, sessionIndex: String, dataSource: String) = { + // all state in index are in lower case + // we only search for statement submitted in the last hour in case of unexpected bugs causing infinite loop in the + // same doc + val dsl = + s"""{ + | "bool": { + | "must": [ + | { + | "term": { + | "type": "statement" + | } + | }, + | { + | "term": { + | "state": "waiting" + | } + | }, + | { + | "term": { + | "sessionId": "$sessionId" + | } + | }, + | { + | "term": { + | "dataSourceName": "$dataSource" + | } + | }, + | { + | "range": { + | "submitTime": { "gte": "now-1h" } + | } + | } + | ] + | } + |}""".stripMargin + + val flintReader = osClient.createQueryReader(sessionIndex, dsl, "submitTime", SortOrder.ASC) + flintReader + } + + override def updateSessionDetails( + sessionDetails: InteractiveSession, + sessionUpdateMode: SessionUpdateMode): Unit = { + sessionUpdateMode match { + case SessionUpdateMode.Update => + flintSessionIndexUpdater.update( + sessionDetails.sessionId, + InteractiveSession.serialize(sessionDetails, currentTimeProvider.currentEpochMillis())) + case SessionUpdateMode.Upsert => + val includeJobId = + !sessionDetails.excludedJobIds.isEmpty && !sessionDetails.excludedJobIds.contains( + sessionDetails.jobId) + val serializedSession = if (includeJobId) { + InteractiveSession.serialize( + sessionDetails, + currentTimeProvider.currentEpochMillis(), + true) + } else { + InteractiveSession.serializeWithoutJobId( + sessionDetails, + currentTimeProvider.currentEpochMillis()) + } + flintSessionIndexUpdater.upsert(sessionDetails.sessionId, serializedSession) + case 
SessionUpdateMode.UpdateIf => + val executionContext = sessionDetails.executionContext.getOrElse( + throw new IllegalArgumentException("Missing executionContext for conditional update")) + val seqNo = executionContext + .get("_seq_no") + .getOrElse(throw new IllegalArgumentException("Missing _seq_no for conditional update")) + .asInstanceOf[Long] + val primaryTerm = executionContext + .get("_primary_term") + .getOrElse( + throw new IllegalArgumentException("Missing _primary_term for conditional update")) + .asInstanceOf[Long] + flintSessionIndexUpdater.updateIf( + sessionDetails.sessionId, + InteractiveSession.serializeWithoutJobId( + sessionDetails, + currentTimeProvider.currentEpochMillis()), + seqNo, + primaryTerm) + } + + logInfo( + s"""Updated job: {"jobid": ${sessionDetails.jobId}, "sessionId": ${sessionDetails.sessionId}} from $sessionIndex""") + } +} diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index 546cd8e97..45ec7b2cc 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -151,14 +151,7 @@ class FlintREPLTest }) // Invoke the method - FlintREPL.createHeartBeatUpdater( - 1000L, - flintSessionUpdater, - "session1", - threadPool, - osClient, - "sessionIndex", - 0) + FlintREPL.createHeartBeatUpdater() // Verifications verify(flintSessionUpdater, atLeastOnce()).upsert(eqTo("session1"), *) @@ -545,7 +538,7 @@ class FlintREPLTest try { val flintSessionIndexUpdater = mock[OpenSearchUpdater] - val commandContext = CommandContext( + val commandContext = QueryExecutionContext( spark, dataSource, resultIndex, @@ -698,15 +691,7 @@ class FlintREPLTest val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "") // other mock objects like osClient, flintSessionIndexUpdater with necessary mocking - val result = FlintREPL.setupFlintJobWithExclusionCheck( - mockConf, - Some("sessionIndex"), - Some("sessionId"), - osClient, - "jobId", - "appId", - flintSessionIndexUpdater, - System.currentTimeMillis()) + val result = FlintREPL.setupFlintJobWithExclusionCheck() assert(!result) // Expecting false as the job should proceed normally } @@ -720,15 +705,7 @@ class FlintREPLTest val flintSessionIndexUpdater = mock[OpenSearchUpdater] val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "jobId") - val result = FlintREPL.setupFlintJobWithExclusionCheck( - mockConf, - Some("sessionIndex"), - Some("sessionId"), - osClient, - "jobId", - "appId", - flintSessionIndexUpdater, - System.currentTimeMillis()) + val result = FlintREPL.setupFlintJobWithExclusionCheck() assert(result) // Expecting true as the job should exit early } @@ -754,15 +731,7 @@ class FlintREPLTest val flintSessionIndexUpdater = mock[OpenSearchUpdater] val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "job-1,job-2") - val result = FlintREPL.setupFlintJobWithExclusionCheck( - mockConf, - Some("sessionIndex"), - Some("sessionId"), - osClient, - "jobId", - "appId", - flintSessionIndexUpdater, - System.currentTimeMillis()) + val result = FlintREPL.setupFlintJobWithExclusionCheck() assert(result) // Expecting true for early exit due to duplicate job } @@ -775,15 +744,7 @@ class FlintREPLTest val flintSessionIndexUpdater = mock[OpenSearchUpdater] val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "job-3,job-4") - val result = 
FlintREPL.setupFlintJobWithExclusionCheck( - mockConf, - Some("sessionIndex"), - Some("sessionId"), - osClient, - "jobId", - "appId", - flintSessionIndexUpdater, - System.currentTimeMillis()) + val result = FlintREPL.setupFlintJobWithExclusionCheck() assert(!result) // Expecting false as the job proceeds normally } @@ -794,15 +755,7 @@ class FlintREPLTest val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "") assertThrows[NoSuchElementException] { - FlintREPL.setupFlintJobWithExclusionCheck( - mockConf, - None, // No sessionIndex provided - None, // No sessionId provided - osClient, - "jobId", - "appId", - flintSessionIndexUpdater, - System.currentTimeMillis()) + FlintREPL.setupFlintJobWithExclusionCheck() } } @@ -826,7 +779,7 @@ class FlintREPLTest val flintSessionIndexUpdater = mock[OpenSearchUpdater] - val commandContext = CommandContext( + val commandContext = QueryExecutionContext( spark, dataSource, resultIndex, @@ -876,7 +829,7 @@ class FlintREPLTest val flintSessionIndexUpdater = mock[OpenSearchUpdater] - val commandContext = CommandContext( + val commandContext = QueryExecutionContext( spark, dataSource, resultIndex, @@ -932,7 +885,7 @@ class FlintREPLTest val flintSessionIndexUpdater = mock[OpenSearchUpdater] - val commandContext = CommandContext( + val commandContext = QueryExecutionContext( spark, dataSource, resultIndex, @@ -982,7 +935,7 @@ class FlintREPLTest val flintSessionIndexUpdater = mock[OpenSearchUpdater] - val commandContext = CommandContext( + val commandContext = QueryExecutionContext( spark, dataSource, resultIndex, @@ -1063,7 +1016,7 @@ class FlintREPLTest val flintSessionIndexUpdater = mock[OpenSearchUpdater] - val commandContext = CommandContext( + val commandContext = QueryExecutionContext( mockSparkSession, dataSource, resultIndex, @@ -1113,7 +1066,7 @@ class FlintREPLTest val flintSessionIndexUpdater = mock[OpenSearchUpdater] - val commandContext = CommandContext( + val commandContext = QueryExecutionContext( spark, dataSource, resultIndex,
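
Note on the new pluggable session layer introduced above: FlintREPL.instantiateSessionManager() falls back to SessionManagerImpl unless spark.flint.job.customSessionManager names a class, which is then loaded reflectively through its no-arg constructor. The sketch below is a hypothetical custom implementation written against the new SessionManager trait in flint-commons; the package, class name, and in-memory store are illustrative only and not part of this change.

package com.example.flint

import scala.collection.concurrent.TrieMap

import org.opensearch.flint.data.InteractiveSession

import org.apache.spark.sql.{SessionManager, SessionUpdateMode}
import org.apache.spark.sql.SessionUpdateMode.SessionUpdateMode

// Hypothetical example, not part of this PR. A custom implementation needs a no-arg
// constructor because instantiateSessionManager() calls getDeclaredConstructor().newInstance().
class InMemorySessionManager extends SessionManager {

  // Simple thread-safe stores standing in for the OpenSearch-backed session index.
  private val sessions = TrieMap.empty[String, InteractiveSession]
  private val heartbeats = TrieMap.empty[String, Long]

  override def getSessionManagerMetadata: Map[String, Any] =
    Map("store" -> "in-memory", "sessionCount" -> sessions.size)

  override def getSessionDetails(sessionId: String): Option[InteractiveSession] =
    sessions.get(sessionId)

  override def updateSessionDetails(
      sessionDetails: InteractiveSession,
      updateMode: SessionUpdateMode): Unit = {
    // This toy store treats Update, Upsert, and UpdateIf identically; a real backend
    // would honor the conditional (_seq_no/_primary_term) semantics of UpdateIf.
    sessions.update(sessionDetails.sessionId, sessionDetails)
  }

  override def hasPendingStatement(sessionId: String): Boolean =
    false // no statement queue in this sketch

  override def recordHeartbeat(sessionId: String): Unit =
    heartbeats.update(sessionId, System.currentTimeMillis())
}

Wiring would presumably be --conf spark.flint.job.customSessionManager=com.example.flint.InMemorySessionManager (the key added to FlintSparkConf above); when the option is unset, the OpenSearch-backed SessionManagerImpl remains the default.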
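Similarly, spark.flint.job.customQueryResultWriter is meant to point at a QueryResultWriter implementation. This diff only shows the reflective loader for SessionManager, so assume the writer is instantiated the same way; the sketch below is hypothetical (class name, output path, and sink choice are illustrative) and simply persists each statement's result with the standard DataFrame writer instead of indexing it.

package com.example.flint

import org.opensearch.flint.data.FlintStatement

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, QueryResultWriter}

// Hypothetical example, not part of this PR.
class ParquetQueryResultWriter extends QueryResultWriter with Logging {

  // Illustrative destination; a real writer would read this from configuration.
  private val outputPath = "s3://example-bucket/flint-query-results"

  override def write(dataFrame: DataFrame, flintStatement: FlintStatement): Unit = {
    // Append this statement's result set as Parquet rather than writing to the result index.
    dataFrame.write.mode("append").parquet(outputPath)
    logInfo(s"Persisted query result for statement $flintStatement to $outputPath")
  }
}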