diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala b/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala new file mode 100644 index 000000000..7ddf6604b --- /dev/null +++ b/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala @@ -0,0 +1,12 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.opensearch.flint.common.model.FlintStatement + +trait QueryResultWriter { + def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit +} diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala b/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala new file mode 100644 index 000000000..1c270fbc1 --- /dev/null +++ b/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala @@ -0,0 +1,48 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession} + +import org.apache.spark.sql.SessionUpdateMode.SessionUpdateMode + +/** + * Trait defining the interface for managing interactive sessions. + */ +trait SessionManager { + + /** + * Retrieves metadata about the session manager. + */ + def getSessionContext: Map[String, Any] + + /** + * Fetches the details of a specific session. + */ + def getSessionDetails(sessionId: String): Option[InteractiveSession] + + /** + * Updates the details of a specific session. + */ + def updateSessionDetails( + sessionDetails: InteractiveSession, + updateMode: SessionUpdateMode): Unit + + /** + * Retrieves the next statement to be executed in a specific session. + */ + def getNextStatement(sessionId: String): Option[FlintStatement] + + /** + * Records a heartbeat for a specific session to indicate it is still active. + */ + def recordHeartbeat(sessionId: String): Unit +} + +object SessionUpdateMode extends Enumeration { + type SessionUpdateMode = Value + val UPDATE, UPSERT, UPDATE_IF = Value +} diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/StatementLifecycleManager.scala b/flint-commons/src/main/scala/org/apache/spark/sql/StatementLifecycleManager.scala new file mode 100644 index 000000000..1db736b97 --- /dev/null +++ b/flint-commons/src/main/scala/org/apache/spark/sql/StatementLifecycleManager.scala @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.opensearch.flint.common.model.FlintStatement + +/** + * Trait defining the interface for managing the lifecycle of executing a FlintStatement. + */ +trait StatementLifecycleManager { + + /** + * Prepares the statement lifecycle. + */ + def prepareStatementLifecycle(): Either[String, Unit] + + def executeStatement(statement: FlintStatement): DataFrame + + /** + * Updates a specific statement. + */ + def updateStatement(statement: FlintStatement): Unit + + /** + * Terminates the statement lifecycle. 
+ */ + def terminateStatementLifecycle(): Unit +} diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala index dc110afb9..f90989f17 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala @@ -219,6 +219,15 @@ object FlintSparkConf { FlintConfig("spark.metadata.accessAWSCredentialsProvider") .doc("AWS credentials provider for metadata access permission") .createOptional() + val CUSTOM_SESSION_MANAGER = + FlintConfig("spark.flint.job.customSessionManager") + .createOptional() + val CUSTOM_STATEMENT_MANAGER = + FlintConfig("spark.flint.job.customStatementManager") + .createOptional() + val CUSTOM_QUERY_RESULT_WRITER = + FlintConfig("spark.flint.job.customQueryResultWriter") + .createOptional() } /** diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala index 048f69ced..42e4f82d8 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala @@ -13,12 +13,11 @@ import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} case class CommandContext( spark: SparkSession, dataSource: String, - resultIndex: String, sessionId: String, - flintSessionIndexUpdater: OpenSearchUpdater, - osClient: OSClient, - sessionIndex: String, + sessionManager: SessionManager, jobId: String, + statementLifecycleManager: StatementLifecycleManager, + queryResultWriter: QueryResultWriter, queryExecutionTimeout: Duration, inactivityLimitMillis: Long, queryWaitTimeMillis: Long, diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala index ad49201f0..45b7e81cc 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala @@ -12,7 +12,6 @@ import org.opensearch.flint.core.storage.FlintReader case class CommandState( recordedLastActivityTime: Long, recordedVerificationResult: VerificationResult, - flintReader: FlintReader, - futureMappingCheck: Future[Either[String, Unit]], + futurePrepareQueryExecution: Future[Either[String, Unit]], executionContext: ExecutionContextExecutor, recordedLastCanPickCheckTime: Long) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala index bba999110..1278657cb 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.types._ */ object FlintJob extends Logging with FlintJobExecutor { def main(args: Array[String]): Unit = { - val (queryOption, resultIndex) = parseArgs(args) + val (queryOption, resultIndexOption) = parseArgs(args) val conf = createSparkConf() val jobType = conf.get("spark.flint.job.type", "batch") @@ -41,6 +41,9 @@ object FlintJob extends Logging with FlintJobExecutor { if (query.isEmpty) { logAndThrow(s"Query undefined for the ${jobType} job.") } + if 
(resultIndexOption.isEmpty) { + logAndThrow("resultIndex is not set") + } // https://github.com/opensearch-project/opensearch-spark/issues/138 /* * To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`, @@ -58,7 +61,7 @@ object FlintJob extends Logging with FlintJobExecutor { createSparkSession(conf), query, dataSource, - resultIndex, + resultIndexOption.get, jobType.equalsIgnoreCase("streaming"), streamingRunningCount) registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala index 37801a9e8..9b3841c7d 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -493,16 +493,21 @@ trait FlintJobExecutor { } } - def parseArgs(args: Array[String]): (Option[String], String) = { + /** + * Before OS 2.13, the entry point takes two arguments: query and result index. Starting + * from OS 2.13, the query is optional for FlintREPL. Since Flint 0.5, the result index is + * also optional to support persisting results outside OpenSearch. + */ + def parseArgs(args: Array[String]): (Option[String], Option[String]) = { args match { + case Array() => + (None, None) case Array(resultIndex) => - (None, resultIndex) // Starting from OS 2.13, resultIndex is the only argument + (None, Some(resultIndex)) case Array(query, resultIndex) => - ( - Some(query), - resultIndex - ) // Before OS 2.13, there are two arguments, the second one is resultIndex - case _ => logAndThrow("Unsupported number of arguments. Expected 1 or 2 arguments.") + (Some(query), Some(resultIndex)) + case _ => + logAndThrow("Unsupported number of arguments. 
Expected no more than two arguments.") } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index e6b8b11ce..51bfffbe2 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -18,7 +18,7 @@ import com.codahale.metrics.Timer import org.json4s.native.Serialization import org.opensearch.action.get.GetResponse import org.opensearch.common.Strings -import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession} +import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession, SessionStates} import org.opensearch.flint.common.model.InteractiveSession.formats import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.logging.CustomLogging @@ -31,8 +31,9 @@ import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.FlintREPLConfConstants._ +import org.apache.spark.sql.SessionUpdateMode.UPDATE_IF import org.apache.spark.sql.flint.config.FlintSparkConf -import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.{ThreadUtils, Utils} object FlintREPLConfConstants { val HEARTBEAT_INTERVAL_MILLIS = 60000L @@ -61,19 +62,11 @@ object FlintREPL extends Logging with FlintJobExecutor { @volatile var earlyExitFlag: Boolean = false - def updateSessionIndex(flintStatement: FlintStatement, updater: OpenSearchUpdater): Unit = { - updater.update(flintStatement.statementId, FlintStatement.serialize(flintStatement)) - } - private val sessionRunningCount = new AtomicInteger(0) private val statementRunningCount = new AtomicInteger(0) def main(args: Array[String]) { - val (queryOption, resultIndex) = parseArgs(args) - - if (Strings.isNullOrEmpty(resultIndex)) { - logAndThrow("resultIndex is not set") - } + val (queryOption, resultIndexOption) = parseArgs(args) // init SparkContext val conf: SparkConf = createSparkConf() @@ -95,7 +88,9 @@ object FlintREPL extends Logging with FlintJobExecutor { val query = getQuery(queryOption, jobType, conf) if (jobType.equalsIgnoreCase("streaming")) { - logInfo(s"""streaming query ${query}""") + if (resultIndexOption.isEmpty) { + logAndThrow("resultIndex is not set") + } configDYNMaxExecutors(conf, jobType) val streamingRunningCount = new AtomicInteger(0) val jobOperator = @@ -103,25 +98,17 @@ object FlintREPL extends Logging with FlintJobExecutor { createSparkSession(conf), query, dataSource, - resultIndex, + resultIndexOption.get, true, streamingRunningCount) registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount) jobOperator.start() } else { - // we don't allow default value for sessionIndex and sessionId. Throw exception if key not found. - val sessionIndex: Option[String] = Option(conf.get(FlintSparkConf.REQUEST_INDEX.key, null)) - val sessionId: Option[String] = Option(conf.get(FlintSparkConf.SESSION_ID.key, null)) - - if (sessionIndex.isEmpty) { - logAndThrow(FlintSparkConf.REQUEST_INDEX.key + " is not set") - } - if (sessionId.isEmpty) { - logAndThrow(FlintSparkConf.SESSION_ID.key + " is not set") - } - + // we don't allow default value for sessionId. Throw exception if key not found. 
+ val sessionId = getSessionId(conf) val spark = createSparkSession(conf) - val osClient = new OSClient(FlintSparkConf().flintOptions()) + val sessionManager = instantiateSessionManager(spark, resultIndexOption) + val jobId = envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown") val applicationId = envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown") @@ -142,7 +129,6 @@ object FlintREPL extends Logging with FlintJobExecutor { conf.getLong( "spark.flint.job.queryLoopExecutionFrequency", DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) - val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex.get) val sessionTimerContext = getTimerContext(MetricConstants.REPL_PROCESSING_TIME_METRIC) /** @@ -155,12 +141,7 @@ object FlintREPL extends Logging with FlintJobExecutor { * https://github.com/opensearch-project/opensearch-spark/issues/320 */ spark.sparkContext.addSparkListener( - new PreShutdownListener( - flintSessionIndexUpdater, - osClient, - sessionIndex.get, - sessionId.get, - sessionTimerContext)) + new PreShutdownListener(sessionId, sessionManager, sessionTimerContext)) // 1 thread for updating heart beat val threadPool = @@ -173,37 +154,31 @@ object FlintREPL extends Logging with FlintJobExecutor { // OpenSearch triggers recovery after 1 minute outdated heart beat var heartBeatFuture: ScheduledFuture[_] = null try { - heartBeatFuture = createHeartBeatUpdater( - HEARTBEAT_INTERVAL_MILLIS, - flintSessionIndexUpdater, - sessionId.get, - threadPool, - osClient, - sessionIndex.get, - INITIAL_DELAY_MILLIS) + heartBeatFuture = createHeartBeatUpdater(sessionId, sessionManager, threadPool) if (setupFlintJobWithExclusionCheck( conf, - sessionIndex, sessionId, - osClient, jobId, applicationId, - flintSessionIndexUpdater, + sessionManager, jobStartTime)) { earlyExitFlag = true return } + val statementLifecycleManager = + instantiateStatementLifecycleManager(spark, sessionManager.getSessionContext) + val queryResultWriter = + instantiateQueryResultWriter(spark, sessionManager.getSessionContext) val commandContext = CommandContext( spark, dataSource, - resultIndex, - sessionId.get, - flintSessionIndexUpdater, - osClient, - sessionIndex.get, + sessionId, + sessionManager, jobId, + statementLifecycleManager, + queryResultWriter, queryExecutionTimeoutSecs, inactivityLimitMillis, queryWaitTimeoutMillis, @@ -218,11 +193,9 @@ object FlintREPL extends Logging with FlintJobExecutor { e, applicationId, jobId, - sessionId.get, + sessionId, + sessionManager, jobStartTime, - flintSessionIndexUpdater, - osClient, - sessionIndex.get, sessionTimerContext) } finally { if (threadPool != null) { @@ -295,25 +268,17 @@ object FlintREPL extends Logging with FlintJobExecutor { */ def setupFlintJobWithExclusionCheck( conf: SparkConf, - sessionIndex: Option[String], - sessionId: Option[String], - osClient: OSClient, + sessionId: String, jobId: String, applicationId: String, - flintSessionIndexUpdater: OpenSearchUpdater, + sessionManager: SessionManager, jobStartTime: Long): Boolean = { val confExcludeJobsOpt = conf.getOption(FlintSparkConf.EXCLUDE_JOB_IDS.key) confExcludeJobsOpt match { case None => // If confExcludeJobs is None, pass null or an empty sequence as per your setupFlintJob method's signature - setupFlintJob( - applicationId, - jobId, - sessionId.get, - flintSessionIndexUpdater, - sessionIndex.get, - jobStartTime) + setupFlintJob(applicationId, jobId, sessionId, sessionManager, jobStartTime) case Some(confExcludeJobs) => // example: --conf 
spark.flint.deployment.excludeJobs=job-1,job-2 @@ -324,25 +289,19 @@ object FlintREPL extends Logging with FlintJobExecutor { return true } - val getResponse = osClient.getDoc(sessionIndex.get, sessionId.get) - if (getResponse.isExists()) { - val source = getResponse.getSourceAsMap - if (source != null) { - val existingExcludedJobIds = parseExcludedJobIds(source) - if (excludeJobIds.sorted == existingExcludedJobIds.sorted) { - logInfo("duplicate job running, exit the application.") - return true - } - } + val sessionDetails = sessionManager.getSessionDetails(sessionId) + val existingExcludedJobIds = sessionDetails.get.excludedJobIds + if (excludeJobIds.sorted == existingExcludedJobIds.sorted) { + logInfo("duplicate job running, exit the application.") + return true } // If none of the edge cases are met, proceed with setup setupFlintJob( applicationId, jobId, - sessionId.get, - flintSessionIndexUpdater, - sessionIndex.get, + sessionId, + sessionManager, jobStartTime, excludeJobIds) } @@ -350,14 +309,15 @@ object FlintREPL extends Logging with FlintJobExecutor { } def queryLoop(commandContext: CommandContext): Unit = { + import commandContext._ // 1 thread for async query execution val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-query", 1) implicit val executionContext = ExecutionContext.fromExecutor(threadPool) - var futureMappingCheck: Future[Either[String, Unit]] = null + var futurePrepareQueryExecution: Future[Either[String, Unit]] = null try { - futureMappingCheck = Future { - checkAndCreateIndex(commandContext.osClient, commandContext.resultIndex) + futurePrepareQueryExecution = Future { + statementLifecycleManager.prepareStatementLifecycle() } var lastActivityTime = currentTimeProvider.currentEpochMillis() @@ -366,21 +326,13 @@ object FlintREPL extends Logging with FlintJobExecutor { var lastCanPickCheckTime = 0L while (currentTimeProvider .currentEpochMillis() - lastActivityTime <= commandContext.inactivityLimitMillis && canPickUpNextStatement) { - logInfo( - s"""read from ${commandContext.sessionIndex}, sessionId: ${commandContext.sessionId}""") - val flintReader: FlintReader = - createQueryReader( - commandContext.osClient, - commandContext.sessionId, - commandContext.sessionIndex, - commandContext.dataSource) + logInfo(s"""Executing session with sessionId: ${sessionId}""") try { val commandState = CommandState( lastActivityTime, verificationResult, - flintReader, - futureMappingCheck, + futurePrepareQueryExecution, executionContext, lastCanPickCheckTime) val result: (Long, VerificationResult, Boolean, Long) = @@ -397,7 +349,7 @@ object FlintREPL extends Logging with FlintJobExecutor { canPickUpNextStatement = updatedCanPickUpNextStatement lastCanPickCheckTime = updatedLastCanPickCheckTime } finally { - flintReader.close() + statementLifecycleManager.terminateStatementLifecycle() } Thread.sleep(commandContext.queryLoopExecutionFrequency) @@ -413,84 +365,72 @@ object FlintREPL extends Logging with FlintJobExecutor { applicationId: String, jobId: String, sessionId: String, - flintSessionIndexUpdater: OpenSearchUpdater, - sessionIndex: String, + sessionManager: SessionManager, jobStartTime: Long, excludeJobIds: Seq[String] = Seq.empty[String]): Unit = { val includeJobId = !excludeJobIds.isEmpty && !excludeJobIds.contains(jobId) val currentTime = currentTimeProvider.currentEpochMillis() - val flintJob = new InteractiveSession( + val flintJob = refreshSessionState( applicationId, jobId, sessionId, - "running", - currentTime, + sessionManager, 
jobStartTime, - excludeJobIds) + SessionStates.RUNNING, + excludedJobIds = excludeJobIds) - val serializedFlintInstance = if (includeJobId) { - InteractiveSession.serialize(flintJob, currentTime, true) - } else { - InteractiveSession.serializeWithoutJobId(flintJob, currentTime) - } - flintSessionIndexUpdater.upsert(sessionId, serializedFlintInstance) - logInfo( - s"""Updated job: {"jobid": ${flintJob.jobId}, "sessionId": ${flintJob.sessionId}} from $sessionIndex""") sessionRunningCount.incrementAndGet() } + private def refreshSessionState( + applicationId: String, + jobId: String, + sessionId: String, + sessionManager: SessionManager, + jobStartTime: Long, + state: String, + error: Option[String] = None, + excludedJobIds: Seq[String] = Seq.empty[String]): InteractiveSession = { + + val sessionDetails = sessionManager + .getSessionDetails(sessionId) + .getOrElse( + new InteractiveSession( + applicationId, + jobId, + sessionId, + state, + currentTimeProvider.currentEpochMillis(), + jobStartTime, + error = error, + excludedJobIds = excludedJobIds)) + sessionDetails.state = state + sessionManager.updateSessionDetails(sessionDetails, updateMode = SessionUpdateMode.UPSERT) + sessionDetails + } + def handleSessionError( e: Exception, applicationId: String, jobId: String, sessionId: String, + sessionManager: SessionManager, jobStartTime: Long, - flintSessionIndexUpdater: OpenSearchUpdater, - osClient: OSClient, - sessionIndex: String, sessionTimerContext: Timer.Context): Unit = { val error = s"Session error: ${e.getMessage}" CustomLogging.logError(error, e) - val flintInstance = getExistingFlintInstance(osClient, sessionIndex, sessionId) - .getOrElse(createFailedFlintInstance(applicationId, jobId, sessionId, jobStartTime, error)) - - updateFlintInstance(flintInstance, flintSessionIndexUpdater, sessionId) - if (flintInstance.isFail) { - recordSessionFailed(sessionTimerContext) - } - } - - private def getExistingFlintInstance( - osClient: OSClient, - sessionIndex: String, - sessionId: String): Option[InteractiveSession] = Try( - osClient.getDoc(sessionIndex, sessionId)) match { - case Success(getResponse) if getResponse.isExists() => - Option(getResponse.getSourceAsMap) - .map(InteractiveSession.deserializeFromMap) - case Failure(exception) => - CustomLogging.logError( - s"Failed to retrieve existing FlintInstance: ${exception.getMessage}", - exception) - None - case _ => None + refreshSessionState( + applicationId, + jobId, + sessionId, + sessionManager, + jobStartTime, + SessionStates.FAIL, + Some(e.getMessage)) + recordSessionFailed(sessionTimerContext) } - private def createFailedFlintInstance( - applicationId: String, - jobId: String, - sessionId: String, - jobStartTime: Long, - errorMessage: String): InteractiveSession = new InteractiveSession( - applicationId, - jobId, - sessionId, - "fail", - currentTimeProvider.currentEpochMillis(), - jobStartTime, - error = Some(errorMessage)) - private def updateFlintInstance( flintInstance: InteractiveSession, flintSessionIndexUpdater: OpenSearchUpdater, @@ -565,43 +505,28 @@ object FlintREPL extends Logging with FlintJobExecutor { // Only call canPickNextStatement if EARLY_TERMINATION_CHECK_FREQUENCY milliseconds have passed if (currentTime - lastCanPickCheckTime > EARLY_TERMINATION_CHECK_FREQUENCY) { - canPickNextStatementResult = - canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + canPickNextStatementResult = canPickNextStatement(sessionId, sessionManager, jobId) lastCanPickCheckTime = currentTime } if (!canPickNextStatementResult) { 
earlyExitFlag = true canProceed = false - } else if (!flintReader.hasNext) { - canProceed = false } else { - val statementTimerContext = getTimerContext( - MetricConstants.STATEMENT_PROCESSING_TIME_METRIC) - val flintStatement = processCommandInitiation(flintReader, flintSessionIndexUpdater) - - val (dataToWrite, returnedVerificationResult) = processStatementOnVerification( - recordedVerificationResult, - spark, - flintStatement, - dataSource, - sessionId, - executionContext, - futureMappingCheck, - resultIndex, - queryExecutionTimeout, - queryWaitTimeMillis) - - verificationResult = returnedVerificationResult - finalizeCommand( - dataToWrite, - flintStatement, - resultIndex, - flintSessionIndexUpdater, - osClient, - statementTimerContext) - // last query finish time is last activity time - lastActivityTime = currentTimeProvider.currentEpochMillis() + sessionManager.getNextStatement(sessionId) match { + case Some(flintStatement) => + val statementTimerContext = getTimerContext( + MetricConstants.STATEMENT_PROCESSING_TIME_METRIC) + val (dataToWrite, returnedVerificationResult) = + processStatementOnVerification(flintStatement, state, context) + + verificationResult = returnedVerificationResult + finalizeCommand(context, dataToWrite, flintStatement, statementTimerContext) + // last query finish time is last activity time + lastActivityTime = currentTimeProvider.currentEpochMillis() + case None => + canProceed = false + } } } @@ -610,32 +535,25 @@ object FlintREPL extends Logging with FlintJobExecutor { } /** - * finalize command after processing + * finalize statement after processing * * @param dataToWrite * data to write * @param flintStatement - * flint command - * @param resultIndex - * result index - * @param flintSessionIndexUpdater - * flint session index updater + * flint statement */ private def finalizeCommand( + commandContext: CommandContext, dataToWrite: Option[DataFrame], flintStatement: FlintStatement, - resultIndex: String, - flintSessionIndexUpdater: OpenSearchUpdater, - osClient: OSClient, statementTimerContext: Timer.Context): Unit = { + import commandContext._ try { - dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient)) + dataToWrite.foreach(df => queryResultWriter.writeDataFrame(df, flintStatement)) if (flintStatement.isRunning || flintStatement.isWaiting) { // we have set failed state in exception handling flintStatement.complete() } - updateSessionIndex(flintStatement, flintSessionIndexUpdater) - recordStatementStateChange(flintStatement, statementTimerContext) } catch { // e.g., maybe due to authentication service connection issue // or invalid catalog (e.g., we are operating on data not defined in provided data source) @@ -643,8 +561,9 @@ object FlintREPL extends Logging with FlintJobExecutor { val error = s"""Fail to write result of ${flintStatement}, cause: ${e.getMessage}""" CustomLogging.logError(error, e) flintStatement.fail() - updateSessionIndex(flintStatement, flintSessionIndexUpdater) - recordStatementStateChange(flintStatement, statementTimerContext) + } finally { + statementLifecycleManager.updateStatement(flintStatement) + recordStatementStateChange(flintStatement, statementTimerContext) } } @@ -720,16 +639,12 @@ object FlintREPL extends Logging with FlintJobExecutor { } private def processStatementOnVerification( - recordedVerificationResult: VerificationResult, - spark: SparkSession, flintStatement: FlintStatement, - dataSource: String, - sessionId: String, - executionContext: ExecutionContextExecutor, - futureMappingCheck: 
Future[Either[String, Unit]], - resultIndex: String, - queryExecutionTimeout: Duration, - queryWaitTimeMillis: Long) = { + commandState: CommandState, + commandContext: CommandContext) = { + import commandState._ + import commandContext._ + val startTime: Long = currentTimeProvider.currentEpochMillis() var verificationResult = recordedVerificationResult var dataToWrite: Option[DataFrame] = None @@ -737,7 +652,7 @@ object FlintREPL extends Logging with FlintJobExecutor { verificationResult match { case NotVerified => try { - ThreadUtils.awaitResult(futureMappingCheck, MAPPING_CHECK_TIMEOUT) match { + ThreadUtils.awaitResult(futurePrepareQueryExecution, MAPPING_CHECK_TIMEOUT) match { case Right(_) => dataToWrite = executeAndHandle( spark, @@ -762,7 +677,7 @@ object FlintREPL extends Logging with FlintJobExecutor { } } catch { case e: TimeoutException => - val error = s"Getting the mapping of index $resultIndex timed out" + val error = s"Query execution preparation timed out" CustomLogging.logError(error, e) dataToWrite = Some( handleCommandTimeout( @@ -842,70 +757,10 @@ object FlintREPL extends Logging with FlintJobExecutor { ThreadUtils.awaitResult(futureQueryExecution, queryExecutionTimeOut) } } - private def processCommandInitiation( - flintReader: FlintReader, - flintSessionIndexUpdater: OpenSearchUpdater): FlintStatement = { - val command = flintReader.next() - logDebug(s"raw command: $command") - val flintStatement = FlintStatement.deserialize(command) - logDebug(s"command: $flintStatement") - flintStatement.running() - logDebug(s"command running: $flintStatement") - updateSessionIndex(flintStatement, flintSessionIndexUpdater) - statementRunningCount.incrementAndGet() - flintStatement - } - - private def createQueryReader( - osClient: OSClient, - sessionId: String, - sessionIndex: String, - dataSource: String) = { - // all state in index are in lower case - // we only search for statement submitted in the last hour in case of unexpected bugs causing infinite loop in the - // same doc - val dsl = - s"""{ - | "bool": { - | "must": [ - | { - | "term": { - | "type": "statement" - | } - | }, - | { - | "term": { - | "state": "waiting" - | } - | }, - | { - | "term": { - | "sessionId": "$sessionId" - | } - | }, - | { - | "term": { - | "dataSourceName": "$dataSource" - | } - | }, - | { - | "range": { - | "submitTime": { "gte": "now-1h" } - | } - | } - | ] - | } - |}""".stripMargin - - val flintReader = osClient.createQueryReader(sessionIndex, dsl, "submitTime", SortOrder.ASC) - flintReader - } class PreShutdownListener( - flintSessionIndexUpdater: OpenSearchUpdater, - osClient: OSClient, - sessionIndex: String, sessionId: String, + sessionManager: SessionManager, sessionTimerContext: Timer.Context) extends SparkListener with Logging { @@ -913,77 +768,42 @@ object FlintREPL extends Logging with FlintJobExecutor { override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { logInfo("Shutting down REPL") logInfo("earlyExitFlag: " + earlyExitFlag) - val getResponse = osClient.getDoc(sessionIndex, sessionId) - if (!getResponse.isExists()) { - return - } - - val source = getResponse.getSourceAsMap - if (source == null) { - return - } - - val state = Option(source.get("state")).map(_.asInstanceOf[String]) - // It's essential to check the earlyExitFlag before marking the session state as 'dead'. When this flag is true, - // it indicates that the control plane has already initiated a new session to handle remaining requests for the - // current session. 
In our SQL setup, marking a session as 'dead' automatically triggers the creation of a new - // session. However, the newly created session (initiated by the control plane) will enter a spin-wait state, - // where it inefficiently waits for certain conditions to be met, leading to unproductive resource consumption - // and eventual timeout. To avoid this issue and prevent the creation of redundant sessions by SQL, we ensure - // the session state is not set to 'dead' when earlyExitFlag is true, thereby preventing unnecessary duplicate - // processing. - if (!earlyExitFlag && state.isDefined && state.get != "dead" && state.get != "fail") { - updateFlintInstanceBeforeShutdown( - source, - getResponse, - flintSessionIndexUpdater, - sessionId, - sessionTimerContext) + try { + sessionManager.getSessionDetails(sessionId).foreach { sessionDetails => + // It's essential to check the earlyExitFlag before marking the session state as 'dead'. When this flag is true, + // it indicates that the control plane has already initiated a new session to handle remaining requests for the + // current session. In our SQL setup, marking a session as 'dead' automatically triggers the creation of a new + // session. However, the newly created session (initiated by the control plane) will enter a spin-wait state, + // where it inefficiently waits for certain conditions to be met, leading to unproductive resource consumption + // and eventual timeout. To avoid this issue and prevent the creation of redundant sessions by SQL, we ensure + // the session state is not set to 'dead' when earlyExitFlag is true, thereby preventing unnecessary duplicate + // processing. + if (!earlyExitFlag && !sessionDetails.isComplete && !sessionDetails.isFail) { + sessionDetails.complete() + sessionManager.updateSessionDetails(sessionDetails, UPDATE_IF) + recordSessionSuccess(sessionTimerContext) + } + } + } catch { + case e: Exception => logError(s"Failed to update session state for $sessionId", e) } } } - private def updateFlintInstanceBeforeShutdown( - source: java.util.Map[String, AnyRef], - getResponse: GetResponse, - flintSessionIndexUpdater: OpenSearchUpdater, - sessionId: String, - sessionTimerContext: Timer.Context): Unit = { - val flintInstance = InteractiveSession.deserializeFromMap(source) - flintInstance.complete() - flintSessionIndexUpdater.updateIf( - sessionId, - InteractiveSession.serializeWithoutJobId( - flintInstance, - currentTimeProvider.currentEpochMillis()), - getResponse.getSeqNo, - getResponse.getPrimaryTerm) - recordSessionSuccess(sessionTimerContext) - } - /** - * Create a new thread to update the last update time of the flint instance. - * @param currentInterval - * the interval of updating the last update time. Unit is millisecond. - * @param flintSessionUpdater - * the updater of the flint instance. + * Create a new thread to update the last update time of the flint interactive session. + * * @param sessionId - * the session id of the flint instance. + * the session id of the flint interactive session. + * @param sessionManager + * the manager of the flint interactive session. * @param threadPool * the thread pool. - * @param osClient - * the OpenSearch client. 
- * @param initialDelayMillis - * the intial delay to start heartbeat */ def createHeartBeatUpdater( - currentInterval: Long, - flintSessionUpdater: OpenSearchUpdater, sessionId: String, - threadPool: ScheduledExecutorService, - osClient: OSClient, - sessionIndex: String, - initialDelayMillis: Long): ScheduledFuture[_] = { + sessionManager: SessionManager, + threadPool: ScheduledExecutorService): ScheduledFuture[_] = { threadPool.scheduleAtFixedRate( new Runnable { @@ -994,13 +814,7 @@ object FlintREPL extends Logging with FlintJobExecutor { logWarning("HeartBeatUpdater has been interrupted. Terminating.") return // Exit the run method if the thread is interrupted } - - flintSessionUpdater.upsert( - sessionId, - Serialization.write( - Map( - "lastUpdateTime" -> currentTimeProvider.currentEpochMillis(), - "state" -> "running"))) + sessionManager.recordHeartbeat(sessionId) } catch { case ie: InterruptedException => // Preserve the interrupt status @@ -1020,8 +834,8 @@ object FlintREPL extends Logging with FlintJobExecutor { } } }, - initialDelayMillis, - currentInterval, + INITIAL_DELAY_MILLIS, + HEARTBEAT_INTERVAL_MILLIS, java.util.concurrent.TimeUnit.MILLISECONDS) } @@ -1038,35 +852,26 @@ object FlintREPL extends Logging with FlintJobExecutor { */ def canPickNextStatement( sessionId: String, - jobId: String, - osClient: OSClient, - sessionIndex: String): Boolean = { + sessionManager: SessionManager, + jobId: String): Boolean = { try { - val getResponse = osClient.getDoc(sessionIndex, sessionId) - if (getResponse.isExists()) { - val source = getResponse.getSourceAsMap - if (source == null) { - logError(s"""Session id ${sessionId} is empty""") - // still proceed since we are not sure what happened (e.g., OpenSearch cluster may be unresponsive) - return true - } - - val runJobId = Option(source.get("jobId")).map(_.asInstanceOf[String]).orNull - val excludeJobIds: Seq[String] = parseExcludedJobIds(source) - - if (runJobId != null && jobId != runJobId) { - logInfo(s"""the current job ID ${jobId} is not the running job ID ${runJobId}""") - return false - } - if (excludeJobIds != null && excludeJobIds.contains(jobId)) { - logInfo(s"""${jobId} is in the list of excluded jobs""") - return false - } - true - } else { - // still proceed since we are not sure what happened (e.g., session doc may not be available yet) - logError(s"""Fail to find id ${sessionId} from session index""") - true + sessionManager.getSessionDetails(sessionId) match { + case Some(sessionDetails) => + val runJobId = sessionDetails.jobId + val excludeJobIds = sessionDetails.excludedJobIds + + if (!runJobId.isEmpty && jobId != runJobId) { + logInfo(s"the current job ID $jobId is not the running job ID ${runJobId}") + return false + } + if (excludeJobIds.contains(jobId)) { + logInfo(s"$jobId is in the list of excluded jobs") + return false + } + true + case None => + logError(s"Failed to fetch sessionDetails by sessionId: $sessionId.") + true } } catch { // still proceed since we are not sure what happened (e.g., OpenSearch cluster may be unresponsive) @@ -1076,23 +881,6 @@ object FlintREPL extends Logging with FlintJobExecutor { } } - private def parseExcludedJobIds(source: java.util.Map[String, AnyRef]): Seq[String] = { - - val rawExcludeJobIds = source.get("excludeJobIds") - Option(rawExcludeJobIds) - .map { - case s: String => Seq(s) - case list: java.util.List[_] @unchecked => - import scala.collection.JavaConverters._ - list.asScala.toList - .collect { case str: String => str } // Collect only strings from the list - case 
other => - logInfo(s"Unexpected type: ${other.getClass.getName}") - Seq.empty - } - .getOrElse(Seq.empty[String]) // In case of null, return an empty Seq - } - def exponentialBackoffRetry[T](maxRetries: Int, initialDelay: FiniteDuration)( block: => T): T = { var retries = 0 @@ -1130,6 +918,54 @@ object FlintREPL extends Logging with FlintJobExecutor { result.getOrElse(throw new RuntimeException("Failed after retries")) } + private def getSessionId(conf: SparkConf): String = { + val sessionIdOption: Option[String] = Option(conf.get(FlintSparkConf.SESSION_ID.key, null)) + if (sessionIdOption.isEmpty) { + logAndThrow(FlintSparkConf.SESSION_ID.key + " is not set") + } + sessionIdOption.get + } + + private def instantiate[T](defaultConstructor: => T, className: String): T = { + if (className.isEmpty) { + defaultConstructor + } else { + try { + val classObject = Utils.classForName(className) + val ctor = classObject.getDeclaredConstructor() + ctor.setAccessible(true) + ctor.newInstance().asInstanceOf[T] + } catch { + case e: Exception => + throw new RuntimeException(s"Failed to instantiate provider: $className", e) + } + } + } + + private def instantiateSessionManager( + spark: SparkSession, + resultIndex: Option[String]): SessionManager = { + instantiate( + new SessionManagerImpl(spark, resultIndex), + spark.conf.get(FlintSparkConf.CUSTOM_SESSION_MANAGER.key)) + } + + private def instantiateStatementLifecycleManager( + spark: SparkSession, + context: Map[String, Any]): StatementLifecycleManager = { + instantiate( + new StatementLifecycleManagerImpl(context), + spark.conf.get(FlintSparkConf.CUSTOM_STATEMENT_MANAGER.key)) + } + + private def instantiateQueryResultWriter( + spark: SparkSession, + context: Map[String, Any]): QueryResultWriter = { + instantiate( + new QueryResultWriterImpl(context), + spark.conf.get(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key)) + } + private def recordSessionSuccess(sessionTimerContext: Timer.Context): Unit = { logInfo("Session Success") stopTimer(sessionTimerContext) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala new file mode 100644 index 000000000..8d07e91ae --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.opensearch.flint.common.model.FlintStatement + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.FlintJob.writeDataFrameToOpensearch + +class QueryResultWriterImpl(context: Map[String, Any]) extends QueryResultWriter with Logging { + + val resultIndex = context("resultIndex").asInstanceOf[String] + val osClient = context("osClient").asInstanceOf[OSClient] + + override def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit = { + writeDataFrameToOpensearch(dataFrame, resultIndex, osClient) + } +} diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala new file mode 100644 index 000000000..9ac4b4f39 --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala @@ -0,0 +1,174 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import 
scala.util.{Failure, Success, Try} + +import org.json4s.native.Serialization +import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession} +import org.opensearch.flint.core.logging.CustomLogging +import org.opensearch.flint.core.storage.FlintReader +import org.opensearch.search.sort.SortOrder + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SessionUpdateMode.SessionUpdateMode +import org.apache.spark.sql.flint.config.FlintSparkConf + +class SessionManagerImpl(spark: SparkSession, resultIndex: Option[String]) + extends SessionManager + with FlintJobExecutor + with Logging { + + // we don't allow default value for sessionIndex, sessionId and datasource. Throw exception if key not found. + val sessionIndex: String = spark.conf.get(FlintSparkConf.REQUEST_INDEX.key) + val sessionId: String = spark.conf.get(FlintSparkConf.SESSION_ID.key) + val dataSource: String = spark.conf.get(FlintSparkConf.DATA_SOURCE_NAME.key) + + if (sessionIndex.isEmpty) { + logAndThrow(FlintSparkConf.REQUEST_INDEX.key + " is not set") + } + if (resultIndex.isEmpty) { + logAndThrow("resultIndex is not set") + } + if (sessionId.isEmpty) { + logAndThrow(FlintSparkConf.SESSION_ID.key + " is not set") + } + if (dataSource.isEmpty) { + logAndThrow(FlintSparkConf.DATA_SOURCE_NAME.key + " is not set") + } + + val osClient = new OSClient(FlintSparkConf().flintOptions()) + val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex) + val flintReader: FlintReader = createOpenSearchQueryReader() + + override def getSessionContext: Map[String, Any] = { + Map( + "resultIndex" -> resultIndex.get, + "osClient" -> osClient, + "flintSessionIndexUpdater" -> flintSessionIndexUpdater, + "flintReader" -> flintReader) + } + + override def getSessionDetails(sessionId: String): Option[InteractiveSession] = { + Try(osClient.getDoc(sessionIndex, sessionId)) match { + case Success(getResponse) if getResponse.isExists => + Option(getResponse.getSourceAsMap) + .map(InteractiveSession.deserializeFromMap) + case Failure(exception) => + CustomLogging.logError( + s"Failed to retrieve existing InteractiveSession: ${exception.getMessage}", + exception) + None + case _ => None + } + } + + override def updateSessionDetails( + sessionDetails: InteractiveSession, + sessionUpdateMode: SessionUpdateMode): Unit = { + sessionUpdateMode match { + case SessionUpdateMode.UPDATE => + flintSessionIndexUpdater.update( + sessionDetails.sessionId, + InteractiveSession.serialize(sessionDetails, currentTimeProvider.currentEpochMillis())) + case SessionUpdateMode.UPSERT => + val includeJobId = + !sessionDetails.excludedJobIds.isEmpty && !sessionDetails.excludedJobIds.contains( + sessionDetails.jobId) + val serializedSession = if (includeJobId) { + InteractiveSession.serialize( + sessionDetails, + currentTimeProvider.currentEpochMillis(), + true) + } else { + InteractiveSession.serializeWithoutJobId( + sessionDetails, + currentTimeProvider.currentEpochMillis()) + } + flintSessionIndexUpdater.upsert(sessionDetails.sessionId, serializedSession) + case SessionUpdateMode.UPDATE_IF => + val seqNo = sessionDetails + .getContextValue("_seq_no") + .getOrElse(throw new IllegalArgumentException("Missing _seq_no for conditional update")) + .asInstanceOf[Long] + val primaryTerm = sessionDetails + .getContextValue("_primary_term") + .getOrElse( + throw new IllegalArgumentException("Missing _primary_term for conditional update")) + .asInstanceOf[Long] + flintSessionIndexUpdater.updateIf( + sessionDetails.sessionId, + 
InteractiveSession.serializeWithoutJobId( + sessionDetails, + currentTimeProvider.currentEpochMillis()), + seqNo, + primaryTerm) + } + + logInfo( + s"""Updated job: {"jobid": ${sessionDetails.jobId}, "sessionId": ${sessionDetails.sessionId}} from $sessionIndex""") + } + + override def getNextStatement(sessionId: String): Option[FlintStatement] = { + if (flintReader.hasNext) { + val rawStatement = flintReader.next() + logDebug(s"raw statement: $rawStatement") + val flintStatement = FlintStatement.deserialize(rawStatement) + logDebug(s"statement: $flintStatement") + Some(flintStatement) + } else { + None + } + } + + override def recordHeartbeat(sessionId: String): Unit = { + flintSessionIndexUpdater.upsert( + sessionId, + Serialization.write( + Map("lastUpdateTime" -> currentTimeProvider.currentEpochMillis(), "state" -> "running"))) + } + + private def createOpenSearchQueryReader() = { + // all state in index are in lower case + // we only search for statement submitted in the last hour in case of unexpected bugs causing infinite loop in the + // same doc + val dsl = + s"""{ + | "bool": { + | "must": [ + | { + | "term": { + | "type": "statement" + | } + | }, + | { + | "term": { + | "state": "waiting" + | } + | }, + | { + | "term": { + | "sessionId": "$sessionId" + | } + | }, + | { + | "term": { + | "dataSourceName": "$dataSource" + | } + | }, + | { + | "range": { + | "submitTime": { "gte": "now-1h" } + | } + | } + | ] + | } + |}""".stripMargin + + val flintReader = osClient.createQueryReader(sessionIndex, dsl, "submitTime", SortOrder.ASC) + flintReader + } +} diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementLifecycleManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementLifecycleManagerImpl.scala new file mode 100644 index 000000000..5f8775adb --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementLifecycleManagerImpl.scala @@ -0,0 +1,60 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.opensearch.flint.common.model.FlintStatement +import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.FlintJob.{createResultIndex, isSuperset, resultIndexMapping} +import org.apache.spark.sql.FlintREPL.executeQuery + +class StatementLifecycleManagerImpl(context: Map[String, Any]) + extends StatementLifecycleManager + with Logging { + + val sessionId = context("sessionId").asInstanceOf[String] + val resultIndex = context("resultIndex").asInstanceOf[String] + val osClient = context("osClient").asInstanceOf[OSClient] + val flintSessionIndexUpdater = + context("flintSessionIndexUpdater").asInstanceOf[OpenSearchUpdater] + val flintReader = context("flintReader").asInstanceOf[FlintReader] + + override def prepareStatementLifecycle(): Either[String, Unit] = { + try { + val existingSchema = osClient.getIndexMetadata(resultIndex) + if (!isSuperset(existingSchema, resultIndexMapping)) { + Left(s"The mapping of $resultIndex is incorrect.") + } else { + Right(()) + } + } catch { + case e: IllegalStateException + if e.getCause != null && + e.getCause.getMessage.contains("index_not_found_exception") => + createResultIndex(osClient, resultIndex, resultIndexMapping) + case e: InterruptedException => + val error = s"Interrupted by the main thread: ${e.getMessage}" + Thread.currentThread().interrupt() // Preserve the interrupt status + logError(error, e) + 
Left(error) + case e: Exception => + val error = s"Failed to verify existing mapping: ${e.getMessage}" + logError(error, e) + Left(error) + } + } + override def updateStatement(statement: FlintStatement): Unit = { + flintSessionIndexUpdater.update(statement.statementId, FlintStatement.serialize(statement)) + } + override def terminateStatementLifecycle(): Unit = { + flintReader.close() + } + + override def executeStatement(statement: FlintStatement): DataFrame = { + executeQuery() + } +} diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index 9c193fc9a..acc6f6bcd 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -22,7 +22,7 @@ import org.mockito.Mockito.{atLeastOnce, never, times, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.opensearch.action.get.GetResponse -import org.opensearch.flint.common.model.FlintStatement +import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession, SessionStates} import org.opensearch.flint.core.storage.{FlintReader, OpenSearchReader, OpenSearchUpdater} import org.opensearch.search.sort.SortOrder import org.scalatest.prop.TableDrivenPropertyChecks._ @@ -159,18 +159,18 @@ class FlintREPLTest test("createHeartBeatUpdater should update heartbeat correctly") { // Mocks - val flintSessionUpdater = mock[OpenSearchUpdater] - val osClient = mock[OSClient] val threadPool = mock[ScheduledExecutorService] - val getResponse = mock[GetResponse] val scheduledFutureRaw = mock[ScheduledFuture[_]] - + val sessionManager = mock[SessionManager] + val sessionId = "session1" + val currentInterval = 1000L + val initialDelayMillis = 0L // when scheduled task is scheduled, execute the runnable immediately only once and become no-op afterwards. 
when( threadPool.scheduleAtFixedRate( any[Runnable], - eqTo(0), - *, + eqTo(initialDelayMillis), + eqTo(currentInterval), eqTo(java.util.concurrent.TimeUnit.MILLISECONDS))) .thenAnswer((invocation: InvocationOnMock) => { val runnable = invocation.getArgument[Runnable](0) @@ -179,55 +179,35 @@ class FlintREPLTest }) // Invoke the method - FlintREPL.createHeartBeatUpdater( - 1000L, - flintSessionUpdater, - "session1", - threadPool, - osClient, - "sessionIndex", - 0) + FlintREPL.createHeartBeatUpdater(sessionId, sessionManager, threadPool) // Verifications - verify(flintSessionUpdater, atLeastOnce()).upsert(eqTo("session1"), *) + verify(sessionManager).recordHeartbeat(sessionId) } test("PreShutdownListener updates FlintInstance if conditions are met") { // Mock dependencies - val osClient = mock[OSClient] - val getResponse = mock[GetResponse] - val flintSessionIndexUpdater = mock[OpenSearchUpdater] - val sessionIndex = "testIndex" val sessionId = "testSessionId" val timerContext = mock[Timer.Context] + val sessionManager = mock[SessionManager] - // Setup the getDoc to return a document indicating the session is running - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) - when(getResponse.isExists()).thenReturn(true) - when(getResponse.getSourceAsMap).thenReturn( - Map[String, Object]( - "applicationId" -> "app1", - "jobId" -> "job1", - "sessionId" -> "session1", - "state" -> "running", - "lastUpdateTime" -> java.lang.Long.valueOf(12345L), - "error" -> "someError", - "state" -> "running", - "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava) + val interactiveSession = new InteractiveSession( + "app123", + "job123", + sessionId, + SessionStates.RUNNING, + System.currentTimeMillis(), + System.currentTimeMillis() - 10000) + when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) // Instantiate the listener - val listener = new PreShutdownListener( - flintSessionIndexUpdater, - osClient, - sessionIndex, - sessionId, - timerContext) + val listener = new PreShutdownListener(sessionId, sessionManager, timerContext) // Simulate application end listener.onApplicationEnd(SparkListenerApplicationEnd(System.currentTimeMillis())) - // Verify the update is called with the correct arguments - verify(flintSessionIndexUpdater).updateIf(*, *, *, *) + verify(sessionManager).updateSessionDetails(interactiveSession, SessionUpdateMode.UPDATE_IF) + interactiveSession.state shouldBe SessionStates.DEAD } test("Test super.constructErrorDF should construct dataframe properly") { @@ -300,18 +280,18 @@ class FlintREPLTest test("test canPickNextStatement: Doc Exists and Valid JobId") { val sessionId = "session123" val jobId = "jobABC" - val osClient = mock[OSClient] - val sessionIndex = "sessionIndex" - - val getResponse = mock[GetResponse] - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) - when(getResponse.isExists()).thenReturn(true) + val sessionManager = mock[SessionManager] - val sourceMap = new java.util.HashMap[String, Object]() - sourceMap.put("jobId", jobId.asInstanceOf[Object]) - when(getResponse.getSourceAsMap).thenReturn(sourceMap) + val interactiveSession = new InteractiveSession( + "app123", + jobId, + sessionId, + SessionStates.RUNNING, + System.currentTimeMillis(), + System.currentTimeMillis() - 10000) + when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = 
FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId)
     assert(result)
   }
@@ -320,19 +300,20 @@ class FlintREPLTest
     val sessionId = "session123"
     val jobId = "jobABC"
     val differentJobId = "jobXYZ"
-    val osClient = mock[OSClient]
-    val sessionIndex = "sessionIndex"
+    val sessionManager = mock[SessionManager]
-    val getResponse = mock[GetResponse]
-    when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse)
-    when(getResponse.isExists()).thenReturn(true)
+    val interactiveSession = new InteractiveSession(
+      "app123",
+      jobId,
+      sessionId,
+      SessionStates.RUNNING,
+      System.currentTimeMillis(),
+      System.currentTimeMillis() - 10000)
-    val sourceMap = new java.util.HashMap[String, Object]()
-    sourceMap.put("jobId", differentJobId.asInstanceOf[Object])
-    when(getResponse.getSourceAsMap).thenReturn(sourceMap)
+    when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession))
     // Execute the method under test
-    val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex)
+    val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, differentJobId)
     // Assertions
     assert(!result) // The function should return false
@@ -343,6 +324,7 @@ class FlintREPLTest
     val jobId = "jobABC"
     val osClient = mock[OSClient]
     val sessionIndex = "sessionIndex"
+    val sessionManager = mock[SessionManager]
     val getResponse = mock[GetResponse]
     when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse)
@@ -356,8 +338,13 @@ class FlintREPLTest
     sourceMap.put("excludeJobIds", excludeJobIdsList) // But jobId is in the exclude list
     when(getResponse.getSourceAsMap).thenReturn(sourceMap)
+    // Mock the InteractiveSession
+    val interactiveSession = mock[InteractiveSession]
+    when(interactiveSession.jobId).thenReturn("jobABC")
+    when(interactiveSession.excludedJobIds).thenReturn(Seq(jobId))
+
     // Execute the method under test
-    val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex)
+    val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId)
     // Assertions
     assert(!result) // The function should return false because jobId is excluded
@@ -368,6 +355,7 @@ class FlintREPLTest
     val jobId = "jobABC"
     val osClient = mock[OSClient]
     val sessionIndex = "sessionIndex"
+    val sessionManager = mock[SessionManager]
     // Mock the getDoc response
     val getResponse = mock[GetResponse]
@@ -376,7 +364,7 @@ class FlintREPLTest
     when(getResponse.getSourceAsMap).thenReturn(null) // Simulate the source being null
     // Execute the method under test
-    val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex)
+    val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId)
     // Assertions
     assert(result) // The function should return true despite the null source
@@ -387,6 +375,7 @@ class FlintREPLTest
     val jobId = "jobABC"
     val osClient = mock[OSClient]
     val sessionIndex = "sessionIndex"
+    val sessionManager = mock[SessionManager]
     val getResponse = mock[GetResponse]
     when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse)
@@ -401,7 +390,7 @@ class FlintREPLTest
     when(getResponse.getSourceAsMap).thenReturn(sourceMap)
-    val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex)
+    val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId)
     assert(result) // The function should return true
   }
@@ -411,6 +400,7 @@ class FlintREPLTest
     val jobId = "jobABC"
     val osClient = mock[OSClient]
     val sessionIndex = "sessionIndex"
+    val sessionManager = mock[SessionManager]
     // Set up the mock GetResponse
     val getResponse = mock[GetResponse]
@@ -418,7 +408,7 @@ class FlintREPLTest
     when(getResponse.isExists()).thenReturn(false) // Simulate the document does not exist
     // Execute the function under test
-    val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex)
+    val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId)
     // Assert the function returns true
     assert(result)
@@ -429,13 +419,14 @@ class FlintREPLTest
     val jobId = "jobABC"
     val osClient = mock[OSClient]
     val sessionIndex = "sessionIndex"
+    val sessionManager = mock[SessionManager]
     // Set up the mock OSClient to throw an exception
     when(osClient.getDoc(sessionIndex, sessionId))
       .thenThrow(new RuntimeException("OpenSearch cluster unresponsive"))
     // Execute the method under test and expect true, since the method is designed to return true even in case of an exception
-    val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex)
+    val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId)
     // Verify the result is true despite the exception
     assert(result)
@@ -448,6 +439,7 @@ class FlintREPLTest
     val nonMatchingExcludeJobId = "jobXYZ" // This ID does not match the jobId
     val osClient = mock[OSClient]
     val sessionIndex = "sessionIndex"
+    val sessionManager = mock[SessionManager]
     val getResponse = mock[GetResponse]
     when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse)
@@ -461,7 +453,7 @@ class FlintREPLTest
     when(getResponse.getSourceAsMap).thenReturn(sourceMap)
     // Execute the method under test
-    val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex)
+    val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId)
     // The function should return true since jobId is not excluded
     assert(result)
@@ -515,6 +507,7 @@ class FlintREPLTest
     val osClient = mock[OSClient]
     val sessionIndex = "sessionIndex"
     val handleSessionError = mock[Function1[String, Unit]]
+    val sessionManager = mock[SessionManager]
     val getResponse = mock[GetResponse]
     when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse)
@@ -531,8 +524,20 @@ class FlintREPLTest
     when(getResponse.getSourceAsMap).thenReturn(sourceMap)
+    // Mock the InteractiveSession
+    val interactiveSession = new InteractiveSession(
+      applicationId = "app123",
+      sessionId = sessionId,
+      state = "active",
+      lastUpdateTime = System.currentTimeMillis(),
+      jobId = "jobABC",
+      excludedJobIds = Seq(jobId))
+
+    // Mock the sessionManager to return the mocked interactiveSession
+    when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession))
+
     // Execute the method under test
-    val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex)
+    val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId)
     // The function should return false since jobId is excluded
     assert(!result)
@@ -543,6 +548,7 @@ class FlintREPLTest
     val jobId = "jobABC"
     val osClient = mock[OSClient]
     val sessionIndex = "sessionIndex"
+    val sessionManager = mock[SessionManager]
     val getResponse = mock[GetResponse]
     when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse)
@@ -559,8 +565,20 @@ class FlintREPLTest
     when(getResponse.getSourceAsMap).thenReturn(sourceMap)
+    // Mock the InteractiveSession
+    val interactiveSession = new InteractiveSession(
+      applicationId = "app123",
+      sessionId = sessionId,
+      state = "active",
+      lastUpdateTime = System.currentTimeMillis(),
+      jobId = "jobABC",
+      excludedJobIds = Seq(jobId))
+
+    // Mock the sessionManager to return the mocked interactiveSession
+    when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession))
+
     // Execute the method under test
-    val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex)
+    val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId)
     // The function should return true since the jobId is not in the excludeJobIds list
     assert(result)
@@ -588,17 +606,18 @@ class FlintREPLTest
     val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate()
     try {
-      val flintSessionIndexUpdater = mock[OpenSearchUpdater]
+      val sessionManager = mock[SessionManager]
+      val statementLifecycleManager = mock[StatementLifecycleManager]
+      val queryResultWriter = mock[QueryResultWriter]
       val commandContext = CommandContext(
         spark,
         dataSource,
-        resultIndex,
         sessionId,
-        flintSessionIndexUpdater,
-        osClient,
-        sessionIndex,
+        sessionManager,
         jobId,
+        statementLifecycleManager,
+        queryResultWriter,
         Duration(10, MINUTES),
         60,
         60,
@@ -720,12 +739,18 @@ class FlintREPLTest
       flintStatement.error should not be None
       flintStatement.error.get should include("Syntax error:")
     } finally threadPool.shutdown()
-
   }

   test("setupFlintJobWithExclusionCheck should proceed normally when no jobs are excluded") {
     val osClient = mock[OSClient]
     val getResponse = mock[GetResponse]
+    val applicationId = "app1"
+    val jobId = "job1"
+    val sessionId = "session1"
+    val jobStartTime = System.currentTimeMillis()
+    val sessionManager = mock[SessionManager]
+    val conf = new SparkConf().set("spark.flint.deployment.excludeJobs", "")
+
     when(osClient.getDoc(*, *)).thenReturn(getResponse)
     when(getResponse.isExists()).thenReturn(true)
     when(getResponse.getSourceAsMap).thenReturn(
@@ -740,18 +765,24 @@ class FlintREPLTest
     when(getResponse.getSeqNo).thenReturn(0L)
     when(getResponse.getPrimaryTerm).thenReturn(0L)
-    val flintSessionIndexUpdater = mock[OpenSearchUpdater]
+    val interactiveSession = new InteractiveSession(
+      "app123",
+      jobId,
+      sessionId,
+      SessionStates.RUNNING,
+      lastUpdateTime = java.lang.Long.valueOf(12345L))
+
+    when(sessionManager.getSessionDetails(sessionId))
+      .thenReturn(Some(interactiveSession))
     val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "")
     // other mock objects like osClient, flintSessionIndexUpdater with necessary mocking
     val result = FlintREPL.setupFlintJobWithExclusionCheck(
       mockConf,
-      Some("sessionIndex"),
-      Some("sessionId"),
-      osClient,
-      "jobId",
-      "appId",
-      flintSessionIndexUpdater,
+      sessionId,
+      jobId,
+      applicationId,
+      sessionManager,
       System.currentTimeMillis())
     assert(!result) // Expecting false as the job should proceed normally
   }
@@ -759,21 +790,34 @@ class FlintREPLTest
   test("setupFlintJobWithExclusionCheck should exit early if current job is excluded") {
     val osClient = mock[OSClient]
     val getResponse = mock[GetResponse]
+    val applicationId = "app1"
+    val jobId = "job1"
+    val sessionId = "session1"
+    val jobStartTime = System.currentTimeMillis()
+    val sessionManager = mock[SessionManager]
+
     when(osClient.getDoc(*, *)).thenReturn(getResponse)
     when(getResponse.isExists()).thenReturn(true)
     // Mock the rest of the GetResponse as needed
-    val flintSessionIndexUpdater = mock[OpenSearchUpdater]
     val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "jobId")
+    val interactiveSession = new InteractiveSession(
+      "app123",
+      jobId,
+      sessionId,
+      SessionStates.RUNNING,
+      excludedJobIds = Seq(jobId),
+      lastUpdateTime = java.lang.Long.valueOf(12345L))
+
+    when(sessionManager.getSessionDetails(sessionId))
+      .thenReturn(Some(interactiveSession))
     val result = FlintREPL.setupFlintJobWithExclusionCheck(
       mockConf,
-      Some("sessionIndex"),
-      Some("sessionId"),
-      osClient,
-      "jobId",
-      "appId",
-      flintSessionIndexUpdater,
+      sessionId,
+      jobId,
+      applicationId,
+      sessionManager,
      System.currentTimeMillis())
     assert(result) // Expecting true as the job should exit early
   }
@@ -781,6 +825,12 @@ class FlintREPLTest
   test("setupFlintJobWithExclusionCheck should exit early if a duplicate job is running") {
     val osClient = mock[OSClient]
     val getResponse = mock[GetResponse]
+    val applicationId = "app1"
+    val jobId = "job1"
+    val sessionId = "session1"
+    val jobStartTime = System.currentTimeMillis()
+    val sessionManager = mock[SessionManager]
+
     when(osClient.getDoc(*, *)).thenReturn(getResponse)
     when(getResponse.isExists()).thenReturn(true)
     // Mock the GetResponse to simulate a scenario of a duplicate job
@@ -799,15 +849,22 @@ class FlintREPLTest
     val flintSessionIndexUpdater = mock[OpenSearchUpdater]
     val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "job-1,job-2")
+    val interactiveSession = new InteractiveSession(
+      "app1",
+      jobId,
+      sessionId,
+      SessionStates.RUNNING,
+      excludedJobIds = Seq("job-1", "job-2"),
+      lastUpdateTime = java.lang.Long.valueOf(12345L))
+    // Mock sessionManager to return sessionDetails
+    when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession))
     val result = FlintREPL.setupFlintJobWithExclusionCheck(
       mockConf,
-      Some("sessionIndex"),
-      Some("sessionId"),
-      osClient,
-      "jobId",
-      "appId",
-      flintSessionIndexUpdater,
+      sessionId,
+      jobId,
+      applicationId,
+      sessionManager,
       System.currentTimeMillis())
     assert(result) // Expecting true for early exit due to duplicate job
   }
@@ -815,20 +872,33 @@ class FlintREPLTest
   test("setupFlintJobWithExclusionCheck should setup job normally when conditions are met") {
     val osClient = mock[OSClient]
     val getResponse = mock[GetResponse]
+    val applicationId = "app1"
+    val jobId = "job1"
+    val sessionId = "session1"
+    val jobStartTime = System.currentTimeMillis()
+    val sessionManager = mock[SessionManager]
+
     when(osClient.getDoc(*, *)).thenReturn(getResponse)
     when(getResponse.isExists()).thenReturn(true)
     val flintSessionIndexUpdater = mock[OpenSearchUpdater]
     val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "job-3,job-4")
+    val interactiveSession = new InteractiveSession(
+      "app1",
+      jobId,
+      sessionId,
+      SessionStates.RUNNING,
+      excludedJobIds = Seq("job-5", "job-6"),
+      lastUpdateTime = java.lang.Long.valueOf(12345L))
+    // Mock sessionManager to return sessionDetails
+    when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession))
     val result = FlintREPL.setupFlintJobWithExclusionCheck(
       mockConf,
-      Some("sessionIndex"),
-      Some("sessionId"),
-      osClient,
-      "jobId",
-      "appId",
-      flintSessionIndexUpdater,
+      sessionId,
+      jobId,
+      applicationId,
+      sessionManager,
       System.currentTimeMillis())
     assert(!result) // Expecting false as the job proceeds normally
   }
@@ -838,16 +908,19 @@ class FlintREPLTest
     val osClient = mock[OSClient]
     val flintSessionIndexUpdater = mock[OpenSearchUpdater]
     val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "")
+    val applicationId = "app1"
+    val jobId = "job1"
+    val sessionId = "session1"
+    val jobStartTime = System.currentTimeMillis()
+    val sessionManager = mock[SessionManager]
     assertThrows[NoSuchElementException] {
       FlintREPL.setupFlintJobWithExclusionCheck(
         mockConf,
-        None, // No sessionIndex provided
-        None, // No sessionId provided
-        osClient,
-        "jobId",
-        "appId",
-        flintSessionIndexUpdater,
+        sessionId,
+        jobId,
+        applicationId,
+        sessionManager,
         System.currentTimeMillis())
     }
   }
@@ -870,26 +943,25 @@ class FlintREPLTest
     // Create a SparkSession for testing
     val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate()
-    val flintSessionIndexUpdater = mock[OpenSearchUpdater]
+    val sessionManager = mock[SessionManager]
+    val statementLifecycleManager = mock[StatementLifecycleManager]
+    val queryResultWriter = mock[QueryResultWriter]
     val commandContext = CommandContext(
       spark,
       dataSource,
-      resultIndex,
       sessionId,
-      flintSessionIndexUpdater,
-      osClient,
-      sessionIndex,
+      sessionManager,
       jobId,
+      statementLifecycleManager,
+      queryResultWriter,
       Duration(10, MINUTES),
       shortInactivityLimit,
       60,
       DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY)
     // Mock processCommands to always allow loop continuation
-    val getResponse = mock[GetResponse]
-    when(osClient.getDoc(*, *)).thenReturn(getResponse)
-    when(getResponse.isExists()).thenReturn(false)
+    when(sessionManager.getNextStatement(sessionId)).thenReturn(None)
     val startTime = System.currentTimeMillis()
@@ -911,6 +983,10 @@ class FlintREPLTest
       .thenReturn(mockReader)
     when(mockReader.hasNext).thenReturn(true)
+    val sessionManager = mock[SessionManager]
+    val statementLifecycleManager = mock[StatementLifecycleManager]
+    val queryResultWriter = mock[QueryResultWriter]
+
     val resultIndex = "testResultIndex"
     val dataSource = "testDataSource"
     val sessionIndex = "testSessionIndex"
@@ -926,12 +1002,11 @@ class FlintREPLTest
     val commandContext = CommandContext(
       spark,
       dataSource,
-      resultIndex,
       sessionId,
-      flintSessionIndexUpdater,
-      osClient,
-      sessionIndex,
+      sessionManager,
       jobId,
+      statementLifecycleManager,
+      queryResultWriter,
       Duration(10, MINUTES),
       longInactivityLimit,
       60,
@@ -947,6 +1022,9 @@ class FlintREPLTest
       mockGetResponse
     })
+    // Mock getNextStatement to return None, simulating the end of statements
+    when(sessionManager.getNextStatement(sessionId)).thenReturn(None)
+
     val startTime = System.currentTimeMillis()
     FlintREPL.queryLoop(commandContext)
@@ -973,6 +1051,11 @@ class FlintREPLTest
     val sessionId = "testSessionId"
     val jobId = "testJobId"
+    val sessionManager = mock[SessionManager]
+    val statementLifecycleManager = mock[StatementLifecycleManager]
+    val queryResultWriter = mock[QueryResultWriter]
+    when(sessionManager.getNextStatement(sessionId)).thenReturn(None)
+
     val inactivityLimit = 500 // 500 milliseconds
     // Create a SparkSession for testing
@@ -983,12 +1066,11 @@ class FlintREPLTest
     val commandContext = CommandContext(
       spark,
      dataSource,
-      resultIndex,
       sessionId,
-      flintSessionIndexUpdater,
-      osClient,
-      sessionIndex,
+      sessionManager,
       jobId,
+      statementLifecycleManager,
+      queryResultWriter,
       Duration(10, MINUTES),
       inactivityLimit,
       60,
@@ -1028,18 +1110,20 @@ class FlintREPLTest
     // Create a SparkSession for testing
     val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate()
+    val sessionManager = mock[SessionManager]
+    val statementLifecycleManager = mock[StatementLifecycleManager]
+    val queryResultWriter = mock[QueryResultWriter]
     val flintSessionIndexUpdater = mock[OpenSearchUpdater]
     val commandContext = CommandContext(
       spark,
       dataSource,
-      resultIndex,
       sessionId,
-      flintSessionIndexUpdater,
-      osClient,
-      sessionIndex,
+      sessionManager,
       jobId,
+      statementLifecycleManager,
+      queryResultWriter,
       Duration(10, MINUTES),
       inactivityLimit,
       60,
@@ -1113,15 +1197,18 @@ class FlintREPLTest
     val flintSessionIndexUpdater = mock[OpenSearchUpdater]
+    val sessionManager = mock[SessionManager]
+    val statementLifecycleManager = mock[StatementLifecycleManager]
+    val queryResultWriter = mock[QueryResultWriter]
+
     val commandContext = CommandContext(
       mockSparkSession,
       dataSource,
-      resultIndex,
       sessionId,
-      flintSessionIndexUpdater,
-      osClient,
-      sessionIndex,
+      sessionManager,
       jobId,
+      statementLifecycleManager,
+      queryResultWriter,
       Duration(10, MINUTES),
       inactivityLimit,
       60,
@@ -1163,6 +1250,10 @@ class FlintREPLTest
     val sessionId = "testSessionId"
     val jobId = "testJobId"
+    val sessionManager = mock[SessionManager]
+    val statementLifecycleManager = mock[StatementLifecycleManager]
+    val queryResultWriter = mock[QueryResultWriter]
+
     // Create a SparkSession for testing
     val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate()
@@ -1171,12 +1262,11 @@ class FlintREPLTest
     val commandContext = CommandContext(
       spark,
       dataSource,
-      resultIndex,
       sessionId,
-      flintSessionIndexUpdater,
-      osClient,
-      sessionIndex,
+      sessionManager,
       jobId,
+      statementLifecycleManager,
+      queryResultWriter,
       Duration(10, MINUTES),
       inactivityLimit,
       60,