From 28bacdb75c5ac9898a5f3bdca4e94fb3452a6700 Mon Sep 17 00:00:00 2001 From: Louis Chu Date: Sun, 25 Aug 2024 18:21:09 -0700 Subject: [PATCH] Clean up and fix UTs Signed-off-by: Louis Chu --- .../common/model/InteractiveSession.scala | 6 +- .../apache/spark/sql/FlintJobITSuite.scala | 12 +- .../apache/spark/sql/FlintREPLITSuite.scala | 2 +- .../org/apache/spark/sql/CommandContext.scala | 22 +- .../scala/org/apache/spark/sql/FlintJob.scala | 6 + .../apache/spark/sql/FlintJobExecutor.scala | 20 +- .../org/apache/spark/sql/FlintREPL.scala | 123 ++- .../org/apache/spark/sql/JobOperator.scala | 40 +- .../apache/spark/sql/SessionManagerImpl.scala | 10 +- .../sql/StatementExecutionManagerImpl.scala | 29 +- .../org/apache/spark/sql/FlintJobTest.scala | 8 +- .../org/apache/spark/sql/FlintREPLTest.scala | 763 +++++++++++------- 12 files changed, 648 insertions(+), 393 deletions(-) diff --git a/flint-commons/src/main/scala/org/opensearch/flint/common/model/InteractiveSession.scala b/flint-commons/src/main/scala/org/opensearch/flint/common/model/InteractiveSession.scala index 9acdeab5f..915d5e229 100644 --- a/flint-commons/src/main/scala/org/opensearch/flint/common/model/InteractiveSession.scala +++ b/flint-commons/src/main/scala/org/opensearch/flint/common/model/InteractiveSession.scala @@ -14,6 +14,8 @@ import org.json4s.JsonAST.{JArray, JString} import org.json4s.native.JsonMethods.parse import org.json4s.native.Serialization +import org.apache.spark.internal.Logging + object SessionStates { val RUNNING = "running" val DEAD = "dead" @@ -52,7 +54,8 @@ class InteractiveSession( val excludedJobIds: Seq[String] = Seq.empty[String], val error: Option[String] = None, sessionContext: Map[String, Any] = Map.empty[String, Any]) - extends ContextualDataStore { + extends ContextualDataStore + with Logging { context = sessionContext // Initialize the context from the constructor def running(): Unit = state = SessionStates.RUNNING @@ -96,6 +99,7 @@ object InteractiveSession { // Replace extractOpt with jsonOption and map val excludeJobIds: Seq[String] = meta \ "excludeJobIds" match { case JArray(lst) => lst.map(_.extract[String]) + case JString(s) => Seq(s) case _ => Seq.empty[String] } diff --git a/integ-test/src/integration/scala/org/apache/spark/sql/FlintJobITSuite.scala b/integ-test/src/integration/scala/org/apache/spark/sql/FlintJobITSuite.scala index 7318e5c7c..9371557a2 100644 --- a/integ-test/src/integration/scala/org/apache/spark/sql/FlintJobITSuite.scala +++ b/integ-test/src/integration/scala/org/apache/spark/sql/FlintJobITSuite.scala @@ -98,9 +98,15 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest { * JobOperator instance to accommodate specific runtime requirements. 
*/ val job = - JobOperator(spark, query, dataSourceName, resultIndex, true, streamingRunningCount) - job.envinromentProvider = new MockEnvironment( - Map("SERVERLESS_EMR_JOB_ID" -> jobRunId, "SERVERLESS_EMR_VIRTUAL_CLUSTER_ID" -> appId)) + JobOperator( + appId, + jobRunId, + spark, + query, + dataSourceName, + resultIndex, + true, + streamingRunningCount) job.terminateJVM = false job.start() } diff --git a/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala b/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala index 1d86a6589..24f3c9a89 100644 --- a/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala +++ b/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala @@ -169,7 +169,7 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { "spark.flint.job.queryLoopExecutionFrequency", queryLoopExecutionFrequency.toString) - FlintREPL.envinromentProvider = new MockEnvironment( + FlintREPL.environmentProvider = new MockEnvironment( Map("SERVERLESS_EMR_JOB_ID" -> jobRunId, "SERVERLESS_EMR_VIRTUAL_CLUSTER_ID" -> appId)) FlintREPL.enableHiveSupport = false FlintREPL.terminateJVM = false diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala index dcc486922..0d5c062ae 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala @@ -11,14 +11,14 @@ import scala.concurrent.duration.Duration import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} case class CommandContext( - val spark: SparkSession, - val dataSource: String, - val sessionId: String, - val sessionManager: SessionManager, - val jobId: String, - var statementsExecutionManager: StatementExecutionManager, - val queryResultWriter: QueryResultWriter, - val queryExecutionTimeout: Duration, - val inactivityLimitMillis: Long, - val queryWaitTimeMillis: Long, - val queryLoopExecutionFrequency: Long) + applicationId: String, + jobId: String, + spark: SparkSession, + dataSource: String, + sessionId: String, + sessionManager: SessionManager, + queryResultWriter: QueryResultWriter, + queryExecutionTimeout: Duration, + inactivityLimitMillis: Long, + queryWaitTimeMillis: Long, + queryLoopExecutionFrequency: Long) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala index 1278657cb..c556e2786 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala @@ -55,9 +55,15 @@ object FlintJob extends Logging with FlintJobExecutor { conf.set("spark.sql.defaultCatalog", dataSource) configDYNMaxExecutors(conf, jobType) + val applicationId = + environmentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown") + val jobId = environmentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown") + val streamingRunningCount = new AtomicInteger(0) val jobOperator = JobOperator( + applicationId, + jobId, createSparkSession(conf), query, dataSource, diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala index 9b3841c7d..95d3ba0f1 100644 --- 
a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -41,7 +41,7 @@ trait FlintJobExecutor { var currentTimeProvider: TimeProvider = new RealTimeProvider() var threadPoolFactory: ThreadPoolFactory = new DefaultThreadPoolFactory() - var envinromentProvider: EnvironmentProvider = new RealEnvironment() + var environmentProvider: EnvironmentProvider = new RealEnvironment() var enableHiveSupport: Boolean = true // termiante JVM in the presence non-deamon thread before exiting var terminateJVM = true @@ -190,6 +190,7 @@ trait FlintJobExecutor { } } + // scalastyle:off /** * Create a new formatted dataframe with json result, json schema and EMR_STEP_ID. * @@ -201,6 +202,8 @@ trait FlintJobExecutor { * dataframe with result, schema and emr step id */ def getFormattedData( + applicationId: String, + jobId: String, result: DataFrame, spark: SparkSession, dataSource: String, @@ -231,14 +234,13 @@ trait FlintJobExecutor { // after consumed the query result. Streaming query shuffle data is cleaned after each // microBatch execution. cleaner.cleanUp(spark) - // Create the data rows val rows = Seq( ( resultToSave, resultSchemaToSave, - envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown"), - envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"), + jobId, + applicationId, dataSource, "SUCCESS", "", @@ -254,6 +256,8 @@ trait FlintJobExecutor { } def constructErrorDF( + applicationId: String, + jobId: String, spark: SparkSession, dataSource: String, status: String, @@ -270,8 +274,8 @@ trait FlintJobExecutor { ( null, null, - envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown"), - envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"), + jobId, + applicationId, dataSource, status.toUpperCase(Locale.ROOT), error, @@ -396,6 +400,8 @@ trait FlintJobExecutor { } def executeQuery( + applicationId: String, + jobId: String, spark: SparkSession, query: String, dataSource: String, @@ -409,6 +415,8 @@ trait FlintJobExecutor { val result: DataFrame = spark.sql(query) // Get Data getFormattedData( + applicationId, + jobId, result, spark, dataSource, diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 5cdc805dd..340c0656e 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -19,7 +19,6 @@ import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.logging.CustomLogging import org.opensearch.flint.core.metrics.MetricConstants import org.opensearch.flint.core.metrics.MetricsUtil.{getTimerContext, incrementCounter, registerGauge, stopTimer} -import org.opensearch.search.sort.SortOrder import org.apache.spark.SparkConf import org.apache.spark.internal.Logging @@ -79,6 +78,10 @@ object FlintREPL extends Logging with FlintJobExecutor { */ conf.set("spark.sql.defaultCatalog", dataSource) + val applicationId = + environmentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown") + val jobId = environmentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown") + val jobType = conf.get(FlintSparkConf.JOB_TYPE.key, FlintSparkConf.JOB_TYPE.defaultValue.get) CustomLogging.logInfo(s"""Job type is: ${FlintSparkConf.JOB_TYPE.defaultValue.get}""") conf.set(FlintSparkConf.JOB_TYPE.key, jobType) 
@@ -93,6 +96,8 @@ object FlintREPL extends Logging with FlintJobExecutor { val streamingRunningCount = new AtomicInteger(0) val jobOperator = JobOperator( + applicationId, + jobId, createSparkSession(conf), query, dataSource, @@ -108,10 +113,6 @@ object FlintREPL extends Logging with FlintJobExecutor { val spark = createSparkSession(conf) val sessionManager = instantiateSessionManager(spark, resultIndexOption) - val jobId = envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown") - val applicationId = - envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown") - // Read the values from the Spark configuration or fall back to the default values val inactivityLimitMillis: Long = conf.getLong( @@ -169,12 +170,12 @@ object FlintREPL extends Logging with FlintJobExecutor { val queryResultWriter = instantiateQueryResultWriter(conf, sessionManager.getSessionContext) val commandContext = CommandContext( + applicationId, + jobId, spark, dataSource, sessionId, sessionManager, - jobId, - null, // StatementLifecycleManager will be instantiated inside query loop queryResultWriter, queryExecutionTimeoutSecs, inactivityLimitMillis, @@ -315,25 +316,21 @@ object FlintREPL extends Logging with FlintJobExecutor { var futurePrepareQueryExecution: Future[Either[String, Unit]] = null try { + logInfo(s"""Executing session with sessionId: ${sessionId}""") + var lastActivityTime = currentTimeProvider.currentEpochMillis() var verificationResult: VerificationResult = NotVerified var canPickUpNextStatement = true var lastCanPickCheckTime = 0L while (currentTimeProvider .currentEpochMillis() - lastActivityTime <= commandContext.inactivityLimitMillis && canPickUpNextStatement) { - logInfo(s"""Executing session with sessionId: ${sessionId}""") val statementsExecutionManager = - instantiateStatementExecutionManager( - spark, - sessionId, - dataSource, - sessionManager.getSessionContext) + instantiateStatementExecutionManager(commandContext) futurePrepareQueryExecution = Future { statementsExecutionManager.prepareStatementExecution() } - commandContext.statementsExecutionManager = statementsExecutionManager try { val commandState = CommandState( lastActivityTime, @@ -342,7 +339,7 @@ object FlintREPL extends Logging with FlintJobExecutor { executionContext, lastCanPickCheckTime) val result: (Long, VerificationResult, Boolean, Long) = - processCommands(commandContext, commandState) + processCommands(statementsExecutionManager, commandContext, commandState) val ( updatedLastActivityTime, @@ -374,9 +371,7 @@ object FlintREPL extends Logging with FlintJobExecutor { sessionManager: SessionManager, jobStartTime: Long, excludeJobIds: Seq[String] = Seq.empty[String]): Unit = { - val includeJobId = !excludeJobIds.isEmpty && !excludeJobIds.contains(jobId) - val currentTime = currentTimeProvider.currentEpochMillis() - val flintJob = refreshSessionState( + refreshSessionState( applicationId, jobId, sessionId, @@ -397,7 +392,7 @@ object FlintREPL extends Logging with FlintJobExecutor { state: String, error: Option[String] = None, excludedJobIds: Seq[String] = Seq.empty[String]): InteractiveSession = { - + logInfo(s"refreshSessionState: ${jobId}") val sessionDetails = sessionManager .getSessionDetails(sessionId) .getOrElse( @@ -410,6 +405,7 @@ object FlintREPL extends Logging with FlintJobExecutor { jobStartTime, error = error, excludedJobIds = excludedJobIds)) + logInfo(s"Current session: ${sessionDetails}") logInfo(s"State is: ${sessionDetails.state}") sessionDetails.state = state logInfo(s"State is: 
${sessionDetails.state}") @@ -460,6 +456,8 @@ object FlintREPL extends Logging with FlintJobExecutor { * failed data frame */ def handleCommandFailureAndGetFailedData( + applicationId: String, + jobId: String, spark: SparkSession, dataSource: String, error: String, @@ -469,6 +467,8 @@ object FlintREPL extends Logging with FlintJobExecutor { flintStatement.fail() flintStatement.error = Some(error) super.constructErrorDF( + applicationId, + jobId, spark, dataSource, flintStatement.state, @@ -487,6 +487,7 @@ object FlintREPL extends Logging with FlintJobExecutor { } private def processCommands( + statementExecutionManager: StatementExecutionManager, context: CommandContext, state: CommandState): (Long, VerificationResult, Boolean, Long) = { import context._ @@ -500,7 +501,6 @@ object FlintREPL extends Logging with FlintJobExecutor { while (canProceed) { val currentTime = currentTimeProvider.currentEpochMillis() - // Only call canPickNextStatement if EARLY_TERMINATION_CHECK_FREQUENCY milliseconds have passed if (currentTime - lastCanPickCheckTime > EARLY_TERMINATION_CHECK_FREQUENCY) { canPickNextStatementResult = canPickNextStatement(sessionId, sessionManager, jobId) @@ -511,19 +511,28 @@ object FlintREPL extends Logging with FlintJobExecutor { earlyExitFlag = true canProceed = false } else { - statementsExecutionManager.getNextStatement() match { + statementExecutionManager.getNextStatement() match { case Some(flintStatement) => flintStatement.running() - statementsExecutionManager.updateStatement(flintStatement) + statementExecutionManager.updateStatement(flintStatement) statementRunningCount.incrementAndGet() val statementTimerContext = getTimerContext( MetricConstants.STATEMENT_PROCESSING_TIME_METRIC) val (dataToWrite, returnedVerificationResult) = - processStatementOnVerification(flintStatement, state, context) + processStatementOnVerification( + statementExecutionManager, + flintStatement, + state, + context) verificationResult = returnedVerificationResult - finalizeCommand(context, dataToWrite, flintStatement, statementTimerContext) + finalizeCommand( + statementExecutionManager, + context, + dataToWrite, + flintStatement, + statementTimerContext) // last query finish time is last activity time lastActivityTime = currentTimeProvider.currentEpochMillis() case _ => @@ -545,6 +554,7 @@ object FlintREPL extends Logging with FlintJobExecutor { * flint statement */ private def finalizeCommand( + statementExecutionManager: StatementExecutionManager, commandContext: CommandContext, dataToWrite: Option[DataFrame], flintStatement: FlintStatement, @@ -564,18 +574,20 @@ object FlintREPL extends Logging with FlintJobExecutor { CustomLogging.logError(error, e) flintStatement.fail() } finally { - statementsExecutionManager.updateStatement(flintStatement) + statementExecutionManager.updateStatement(flintStatement) recordStatementStateChange(flintStatement, statementTimerContext) } } private def handleCommandTimeout( + applicationId: String, + jobId: String, spark: SparkSession, dataSource: String, error: String, flintStatement: FlintStatement, sessionId: String, - startTime: Long): DataFrame = { + startTime: Long) = { /* * https://tinyurl.com/2ezs5xj9 * @@ -592,6 +604,8 @@ object FlintREPL extends Logging with FlintJobExecutor { flintStatement.timeout() flintStatement.error = Some(error) super.constructErrorDF( + applicationId, + jobId, spark, dataSource, flintStatement.state, @@ -602,10 +616,13 @@ object FlintREPL extends Logging with FlintJobExecutor { startTime) } + // scalastyle:off def 
executeAndHandle( + applicationId: String, + jobId: String, spark: SparkSession, flintStatement: FlintStatement, - statementsExecutionManager: StatementExecutionManager, + statementExecutionManager: StatementExecutionManager, dataSource: String, sessionId: String, executionContext: ExecutionContextExecutor, @@ -615,9 +632,11 @@ object FlintREPL extends Logging with FlintJobExecutor { try { Some( executeQueryAsync( + applicationId, + jobId, spark, flintStatement, - statementsExecutionManager: StatementExecutionManager, + statementExecutionManager, dataSource, sessionId, executionContext, @@ -628,11 +647,22 @@ object FlintREPL extends Logging with FlintJobExecutor { case e: TimeoutException => val error = s"Executing ${flintStatement.query} timed out" CustomLogging.logError(error, e) - Some(handleCommandTimeout(spark, dataSource, error, flintStatement, sessionId, startTime)) + Some( + handleCommandTimeout( + applicationId, + jobId, + spark, + dataSource, + error, + flintStatement, + sessionId, + startTime)) case e: Exception => val error = processQueryException(e, flintStatement) Some( handleCommandFailureAndGetFailedData( + applicationId, + jobId, spark, dataSource, error, @@ -643,6 +673,7 @@ object FlintREPL extends Logging with FlintJobExecutor { } private def processStatementOnVerification( + statementExecutionManager: StatementExecutionManager, flintStatement: FlintStatement, commandState: CommandState, commandContext: CommandContext) = { @@ -659,9 +690,11 @@ object FlintREPL extends Logging with FlintJobExecutor { ThreadUtils.awaitResult(futurePrepareQueryExecution, MAPPING_CHECK_TIMEOUT) match { case Right(_) => dataToWrite = executeAndHandle( + applicationId, + jobId, spark, flintStatement, - statementsExecutionManager, + statementExecutionManager, dataSource, sessionId, executionContext, @@ -673,6 +706,8 @@ object FlintREPL extends Logging with FlintJobExecutor { verificationResult = VerifiedWithError(error) dataToWrite = Some( handleCommandFailureAndGetFailedData( + applicationId, + jobId, spark, dataSource, error, @@ -686,6 +721,8 @@ object FlintREPL extends Logging with FlintJobExecutor { CustomLogging.logError(error, e) dataToWrite = Some( handleCommandTimeout( + applicationId, + jobId, spark, dataSource, error, @@ -697,6 +734,8 @@ object FlintREPL extends Logging with FlintJobExecutor { CustomLogging.logError(error, e) dataToWrite = Some( handleCommandFailureAndGetFailedData( + applicationId, + jobId, spark, dataSource, error, @@ -707,6 +746,8 @@ object FlintREPL extends Logging with FlintJobExecutor { case VerifiedWithError(err) => dataToWrite = Some( handleCommandFailureAndGetFailedData( + applicationId, + jobId, spark, dataSource, err, @@ -715,9 +756,11 @@ object FlintREPL extends Logging with FlintJobExecutor { startTime)) case VerifiedWithoutError => dataToWrite = executeAndHandle( + applicationId, + jobId, spark, flintStatement, - statementsExecutionManager, + statementExecutionManager, dataSource, sessionId, executionContext, @@ -731,6 +774,8 @@ object FlintREPL extends Logging with FlintJobExecutor { } def executeQueryAsync( + applicationId: String, + jobId: String, spark: SparkSession, flintStatement: FlintStatement, statementsExecutionManager: StatementExecutionManager, @@ -743,6 +788,8 @@ object FlintREPL extends Logging with FlintJobExecutor { if (currentTimeProvider .currentEpochMillis() - flintStatement.submitTime > queryWaitTimeMillis) { handleCommandFailureAndGetFailedData( + applicationId, + jobId, spark, dataSource, "wait timeout", @@ -780,6 +827,7 @@ object 
FlintREPL extends Logging with FlintJobExecutor { // processing. if (!earlyExitFlag && !sessionDetails.isComplete && !sessionDetails.isFail) { sessionDetails.complete() + logInfo(s"jobId before shutting down session: ${sessionDetails.jobId}") sessionManager.updateSessionDetails(sessionDetails, updateMode = UPDATE_IF) recordSessionSuccess(sessionTimerContext) } @@ -929,7 +977,6 @@ object FlintREPL extends Logging with FlintJobExecutor { private def instantiate[T](defaultConstructor: => T, className: String, args: Any*): T = { if (className.isEmpty) { - logInfo("Using default constructor") defaultConstructor } else { try { @@ -953,17 +1000,15 @@ object FlintREPL extends Logging with FlintJobExecutor { resultIndexOption: Option[String]): SessionManager = { instantiate( new SessionManagerImpl(spark, resultIndexOption), - spark.sparkContext.getConf.get(FlintSparkConf.CUSTOM_SESSION_MANAGER.key, "")) + spark.conf.get(FlintSparkConf.CUSTOM_SESSION_MANAGER.key, "")) } private def instantiateStatementExecutionManager( - spark: SparkSession, - sessionId: String, - dataSource: String, - context: Map[String, Any]): StatementExecutionManager = { + commandContext: CommandContext): StatementExecutionManager = { + import commandContext._ instantiate( - new StatementExecutionManagerImpl(spark, sessionId, dataSource, context), - spark.sparkContext.getConf.get(FlintSparkConf.CUSTOM_STATEMENT_MANAGER.key, ""), + new StatementExecutionManagerImpl(commandContext), + spark.conf.get(FlintSparkConf.CUSTOM_STATEMENT_MANAGER.key, ""), spark, sessionId) } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala index c079b3e96..b49f4a9ed 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala @@ -22,6 +22,8 @@ import org.apache.spark.sql.util.ShuffleCleaner import org.apache.spark.util.ThreadUtils case class JobOperator( + applicationId: String, + jobId: String, spark: SparkSession, query: String, dataSource: String, @@ -51,13 +53,23 @@ case class JobOperator( val futureMappingCheck = Future { checkAndCreateIndex(osClient, resultIndex) } - val data = executeQuery(spark, query, dataSource, "", "", streaming) + val data = executeQuery(applicationId, jobId, spark, query, dataSource, "", "", streaming) val mappingCheckResult = ThreadUtils.awaitResult(futureMappingCheck, Duration(1, MINUTES)) dataToWrite = Some(mappingCheckResult match { case Right(_) => data case Left(error) => - constructErrorDF(spark, dataSource, "FAILED", error, "", query, "", startTime) + constructErrorDF( + applicationId, + jobId, + spark, + dataSource, + "FAILED", + error, + "", + query, + "", + startTime) }) exceptionThrown = false } catch { @@ -65,11 +77,31 @@ case class JobOperator( val error = s"Getting the mapping of index $resultIndex timed out" logError(error, e) dataToWrite = Some( - constructErrorDF(spark, dataSource, "TIMEOUT", error, "", query, "", startTime)) + constructErrorDF( + applicationId, + jobId, + spark, + dataSource, + "TIMEOUT", + error, + "", + query, + "", + startTime)) case e: Exception => val error = processQueryException(e) dataToWrite = Some( - constructErrorDF(spark, dataSource, "FAILED", error, "", query, "", startTime)) + constructErrorDF( + applicationId, + jobId, + spark, + dataSource, + "FAILED", + error, + "", + query, + "", + startTime)) } finally { cleanUpResources(exceptionThrown, 
threadPool, dataToWrite, resultIndex, osClient) } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala index 79bd63200..2039159e4 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala @@ -8,10 +8,9 @@ package org.apache.spark.sql import scala.util.{Failure, Success, Try} import org.json4s.native.Serialization -import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession} +import org.opensearch.flint.common.model.InteractiveSession import org.opensearch.flint.common.model.InteractiveSession.formats import org.opensearch.flint.core.logging.CustomLogging -import org.opensearch.flint.core.storage.FlintReader import org.apache.spark.internal.Logging import org.apache.spark.sql.SessionUpdateMode.SessionUpdateMode @@ -27,14 +26,14 @@ class SessionManagerImpl(spark: SparkSession, resultIndexOption: Option[String]) } // we don't allow default value for sessionIndex. Throw exception if key not found. - private val sessionIndex: String = spark.conf.get(FlintSparkConf.REQUEST_INDEX.key, "") + val sessionIndex: String = spark.conf.get(FlintSparkConf.REQUEST_INDEX.key, "") if (sessionIndex.isEmpty) { logAndThrow(FlintSparkConf.REQUEST_INDEX.key + " is not set") } - private val osClient = new OSClient(FlintSparkConf().flintOptions()) - private val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex) + val osClient = new OSClient(FlintSparkConf().flintOptions()) + lazy val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex) override def getSessionContext: Map[String, Any] = { Map( @@ -50,7 +49,6 @@ class SessionManagerImpl(spark: SparkSession, resultIndexOption: Option[String]) // Retrieve the source map and create session val sessionOption = Option(getResponse.getSourceAsMap) .map(InteractiveSession.deserializeFromMap) - // Retrieve sequence number and primary term from the response val seqNo = getResponse.getSeqNo val primaryTerm = getResponse.getPrimaryTerm diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala index dc8414a17..0b059f1d3 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala @@ -6,21 +6,22 @@ package org.apache.spark.sql import org.opensearch.flint.common.model.FlintStatement -import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} +import org.opensearch.flint.core.storage.OpenSearchUpdater import org.opensearch.search.sort.SortOrder import org.apache.spark.internal.Logging -import org.apache.spark.sql.FlintJob.{checkAndCreateIndex, createResultIndex, isSuperset, resultIndexMapping} -import org.apache.spark.sql.FlintREPL.executeQuery -class StatementExecutionManagerImpl( - spark: SparkSession, - sessionId: String, - dataSource: String, - context: Map[String, Any]) +/** + * StatementExecutionManagerImpl is a session-based implementation of the StatementExecutionManager + * interface. It uses FlintReader to fetch all pending queries in a micro-batch. + * @param commandContext + */ +class StatementExecutionManagerImpl(commandContext: CommandContext) extends StatementExecutionManager + with 
FlintJobExecutor with Logging { + private val context = commandContext.sessionManager.getSessionContext private val sessionIndex = context("sessionIndex").asInstanceOf[String] private val resultIndex = context("resultIndex").asInstanceOf[String] private val osClient = context("osClient").asInstanceOf[OSClient] @@ -53,10 +54,20 @@ class StatementExecutionManagerImpl( } override def executeStatement(statement: FlintStatement): DataFrame = { - executeQuery(spark, statement.query, dataSource, statement.queryId, sessionId, false) + import commandContext._ + executeQuery( + applicationId, + jobId, + spark, + statement.query, + dataSource, + statement.queryId, + sessionId, + false) } private def createOpenSearchQueryReader() = { + import commandContext._ // all state in index are in lower case // we only search for statement submitted in the last hour in case of unexpected bugs causing infinite loop in the // same doc diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala index 19f596e31..339c1870d 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala @@ -11,6 +11,8 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.util.{CleanerFactory, MockTimeProvider} class FlintJobTest extends SparkFunSuite with JobMatchers { + private val jobId = "testJobId" + private val applicationId = "testApplicationId" val spark = SparkSession.builder().appName("Test").master("local").getOrCreate() @@ -55,8 +57,8 @@ class FlintJobTest extends SparkFunSuite with JobMatchers { Array( "{'column_name':'Letter','data_type':'string'}", "{'column_name':'Number','data_type':'integer'}"), - "unknown", - "unknown", + jobId, + applicationId, dataSourceName, "SUCCESS", "", @@ -72,6 +74,8 @@ class FlintJobTest extends SparkFunSuite with JobMatchers { // Compare the result val result = FlintJob.getFormattedData( + applicationId, + jobId, input, spark, dataSourceName, diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index bba0e40e2..433c2351b 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -18,14 +18,13 @@ import scala.reflect.runtime.universe.TypeTag import com.amazonaws.services.glue.model.AccessDeniedException import com.codahale.metrics.Timer import org.mockito.{ArgumentMatchersSugar, Mockito} -import org.mockito.Mockito.{atLeastOnce, never, times, verify, when} +import org.mockito.Mockito.{atLeastOnce, doNothing, never, times, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.opensearch.action.get.GetResponse import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession, SessionStates} import org.opensearch.flint.core.storage.{FlintReader, OpenSearchReader, OpenSearchUpdater} import org.opensearch.search.sort.SortOrder -import org.scalatest.Ignore import org.scalatest.prop.TableDrivenPropertyChecks._ import org.scalatestplus.mockito.MockitoSugar @@ -49,6 +48,9 @@ class FlintREPLTest // By using a type alias and casting, I can bypass the type checking error. 
type AnyScheduledFuture = ScheduledFuture[_] + private val jobId = "testJobId" + private val applicationId = "testApplicationId" + test("parseArgs with no arguments should return (None, None)") { val args = Array.empty[String] val (queryOption, resultIndexOption) = FlintREPL.parseArgs(args) @@ -176,21 +178,15 @@ class FlintREPLTest } } - ignore("createHeartBeatUpdater should update heartbeat correctly") { + test("createHeartBeatUpdater should update heartbeat correctly") { // Mocks val threadPool = mock[ScheduledExecutorService] val scheduledFutureRaw = mock[ScheduledFuture[_]] val sessionManager = mock[SessionManager] val sessionId = "session1" - val currentInterval = 1000L - val initialDelayMillis = 0L // when scheduled task is scheduled, execute the runnable immediately only once and become no-op afterwards. - when( - threadPool.scheduleAtFixedRate( - any[Runnable], - eqTo(initialDelayMillis), - eqTo(currentInterval), - eqTo(java.util.concurrent.TimeUnit.MILLISECONDS))) + when(threadPool + .scheduleAtFixedRate(any[Runnable], *, *, eqTo(java.util.concurrent.TimeUnit.MILLISECONDS))) .thenAnswer((invocation: InvocationOnMock) => { val runnable = invocation.getArgument[Runnable](0) runnable.run() @@ -201,7 +197,7 @@ class FlintREPLTest FlintREPL.createHeartBeatUpdater(sessionId, sessionManager, threadPool) // Verifications - verify(sessionManager).recordHeartbeat(sessionId) + verify(sessionManager, atLeastOnce()).recordHeartbeat(sessionId) } test("PreShutdownListener updates FlintInstance if conditions are met") { @@ -256,8 +252,8 @@ class FlintREPLTest Row( null, null, - "unknown", - "unknown", + jobId, + applicationId, dataSourceName, "FAILED", error, @@ -280,6 +276,8 @@ class FlintREPLTest // Compare the result val result = FlintREPL.handleCommandFailureAndGetFailedData( + applicationId, + jobId, spark, dataSourceName, error, @@ -338,29 +336,21 @@ class FlintREPLTest assert(!result) // The function should return false } - ignore("test canPickNextStatement: Doc Exists, JobId Matches, but JobId is Excluded") { + test("test canPickNextStatement: Doc Exists, JobId Matches, but JobId is Excluded") { val sessionId = "session123" val jobId = "jobABC" - val osClient = mock[OSClient] - val sessionIndex = "sessionIndex" val sessionManager = mock[SessionManager] - val getResponse = mock[GetResponse] - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) - when(getResponse.isExists()).thenReturn(true) - - val excludeJobIdsList = new java.util.ArrayList[String]() - excludeJobIdsList.add(jobId) // Add the jobId to the list to simulate exclusion - - val sourceMap = new java.util.HashMap[String, Object]() - sourceMap.put("jobId", jobId) // The jobId matches - sourceMap.put("excludeJobIds", excludeJobIdsList) // But jobId is in the exclude list - when(getResponse.getSourceAsMap).thenReturn(sourceMap) - - // Mock the InteractiveSession - val interactiveSession = mock[InteractiveSession] - when(interactiveSession.jobId).thenReturn("jobABC") - when(interactiveSession.excludedJobIds).thenReturn(Seq(jobId)) + val interactiveSession = new InteractiveSession( + "app123", + jobId, + sessionId, + SessionStates.RUNNING, + System.currentTimeMillis(), + System.currentTimeMillis() - 10000, + Seq(jobId) // Add the jobId to the list to simulate exclusion + ) + when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) // Execute the method under test val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId) @@ -372,15 +362,9 @@ class FlintREPLTest 
test("test canPickNextStatement: Doc Exists but Source is Null") { val sessionId = "session123" val jobId = "jobABC" - val osClient = mock[OSClient] - val sessionIndex = "sessionIndex" val sessionManager = mock[SessionManager] - // Mock the getDoc response - val getResponse = mock[GetResponse] - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) - when(getResponse.isExists()).thenReturn(true) - when(getResponse.getSourceAsMap).thenReturn(null) // Simulate the source being null + when(sessionManager.getSessionDetails(sessionId)).thenReturn(None) // Execute the method under test val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId) @@ -392,12 +376,21 @@ class FlintREPLTest test("test canPickNextStatement: Doc Exists with Unexpected Type in excludeJobIds") { val sessionId = "session123" val jobId = "jobABC" - val osClient = mock[OSClient] + val mockOSClient = mock[OSClient] val sessionIndex = "sessionIndex" - val sessionManager = mock[SessionManager] + + val mockSparkSession = mock[SparkSession] + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) + + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) { + override val osClient: OSClient = mockOSClient + } val getResponse = mock[GetResponse] - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) val sourceMap = new java.util.HashMap[String, Object]() @@ -417,13 +410,21 @@ class FlintREPLTest test("test canPickNextStatement: Doc Does Not Exist") { val sessionId = "session123" val jobId = "jobABC" - val osClient = mock[OSClient] + val mockOSClient = mock[OSClient] val sessionIndex = "sessionIndex" - val sessionManager = mock[SessionManager] + val mockSparkSession = mock[SparkSession] + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) + + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) { + override val osClient: OSClient = mockOSClient + } // Set up the mock GetResponse val getResponse = mock[GetResponse] - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(false) // Simulate the document does not exist // Execute the function under test @@ -436,12 +437,20 @@ class FlintREPLTest test("test canPickNextStatement: OSClient Throws Exception") { val sessionId = "session123" val jobId = "jobABC" - val osClient = mock[OSClient] + val mockOSClient = mock[OSClient] val sessionIndex = "sessionIndex" - val sessionManager = mock[SessionManager] + val mockSparkSession = mock[SparkSession] + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) + + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) { + override val osClient: OSClient = mockOSClient + } // Set up the mock OSClient to throw an exception - when(osClient.getDoc(sessionIndex, sessionId)) + when(mockOSClient.getDoc(sessionIndex, sessionId)) .thenThrow(new RuntimeException("OpenSearch 
cluster unresponsive")) // Execute the method under test and expect true, since the method is designed to return true even in case of an exception @@ -456,12 +465,20 @@ class FlintREPLTest val sessionId = "session123" val jobId = "jobABC" val nonMatchingExcludeJobId = "jobXYZ" // This ID does not match the jobId - val osClient = mock[OSClient] + val mockOSClient = mock[OSClient] val sessionIndex = "sessionIndex" - val sessionManager = mock[SessionManager] + val mockSparkSession = mock[SparkSession] + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) + + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) { + override val osClient: OSClient = mockOSClient + } val getResponse = mock[GetResponse] - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) // Create a sourceMap with excludeJobIds as a String that does NOT match jobId @@ -523,18 +540,30 @@ class FlintREPLTest test("Doc Exists and excludeJobIds is an ArrayList Containing JobId") { val sessionId = "session123" val jobId = "jobABC" - val osClient = mock[OSClient] + val mockOSClient = mock[OSClient] val sessionIndex = "sessionIndex" - val handleSessionError = mock[Function1[String, Unit]] - val sessionManager = mock[SessionManager] + val lastUpdateTime = System.currentTimeMillis() + val mockSparkSession = mock[SparkSession] + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) + + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) { + override val osClient: OSClient = mockOSClient + } val getResponse = mock[GetResponse] - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) // Create a sourceMap with excludeJobIds as an ArrayList containing jobId val sourceMap = new java.util.HashMap[String, Object]() + sourceMap.put("applicationId", applicationId.asInstanceOf[Object]) + sourceMap.put("state", "running".asInstanceOf[Object]) sourceMap.put("jobId", jobId.asInstanceOf[Object]) + sourceMap.put("sessionId", sessionId.asInstanceOf[Object]) + sourceMap.put("lastUpdateTime", lastUpdateTime.asInstanceOf[Object]) // Creating an ArrayList and adding the jobId to it val excludeJobIdsList = new java.util.ArrayList[String]() @@ -542,19 +571,6 @@ class FlintREPLTest sourceMap.put("excludeJobIds", excludeJobIdsList.asInstanceOf[Object]) when(getResponse.getSourceAsMap).thenReturn(sourceMap) - - // Mock the InteractiveSession - val interactiveSession = new InteractiveSession( - applicationId = "app123", - sessionId = sessionId, - state = "active", - lastUpdateTime = System.currentTimeMillis(), - jobId = "jobABC", - excludedJobIds = Seq(jobId)) - - // Mock the sessionManager to return the mocked interactiveSession - when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) - // Execute the method under test val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId) @@ -562,20 +578,33 @@ class FlintREPLTest assert(!result) } - ignore("Doc Exists and excludeJobIds is an ArrayList Not Containing JobId") { 
+ test("Doc Exists and excludeJobIds is an ArrayList Not Containing JobId") { val sessionId = "session123" val jobId = "jobABC" - val osClient = mock[OSClient] + val mockOSClient = mock[OSClient] val sessionIndex = "sessionIndex" - val sessionManager = mock[SessionManager] + val lastUpdateTime = System.currentTimeMillis() + val mockSparkSession = mock[SparkSession] + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) + + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) { + override val osClient: OSClient = mockOSClient + } val getResponse = mock[GetResponse] - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) // Create a sourceMap with excludeJobIds as an ArrayList not containing jobId val sourceMap = new java.util.HashMap[String, Object]() + sourceMap.put("applicationId", applicationId.asInstanceOf[Object]) + sourceMap.put("state", "running".asInstanceOf[Object]) sourceMap.put("jobId", jobId.asInstanceOf[Object]) + sourceMap.put("sessionId", sessionId.asInstanceOf[Object]) + sourceMap.put("lastUpdateTime", lastUpdateTime.asInstanceOf[Object]) // Creating an ArrayList and adding a different jobId to it val excludeJobIdsList = new java.util.ArrayList[String]() @@ -584,18 +613,6 @@ class FlintREPLTest when(getResponse.getSourceAsMap).thenReturn(sourceMap) - // Mock the InteractiveSession - val interactiveSession = new InteractiveSession( - applicationId = "app123", - sessionId = sessionId, - state = "active", - lastUpdateTime = System.currentTimeMillis(), - jobId = "jobABC", - excludedJobIds = Seq(jobId)) - - // Mock the sessionManager to return the mocked interactiveSession - when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) - // Execute the method under test val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId) @@ -608,34 +625,52 @@ class FlintREPLTest val exception = new RuntimeException( new ConnectException( "Timeout connecting to [search-foo-1-bar.eu-west-1.es.amazonaws.com:443]")) - val osClient = mock[OSClient] - when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) + val mockOSClient = mock[OSClient] + when( + mockOSClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) .thenReturn(mockReader) when(mockReader.hasNext).thenThrow(exception) + when(mockOSClient.getIndexMetadata(any[String])).thenReturn(FlintREPL.resultIndexMapping) + val maxRetries = 1 var actualRetries = 0 - val resultIndex = "testResultIndex" val dataSource = "testDataSource" - val sessionIndex = "testSessionIndex" val sessionId = "testSessionId" val jobId = "testJobId" val applicationId = "testApplicationId" + val sessionIndex = "sessionIndex" + val lastUpdateTime = System.currentTimeMillis() + + val getResponse = mock[GetResponse] + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(true) + + // Create a sourceMap with excludeJobIds as an ArrayList not containing jobId + val sourceMap = new java.util.HashMap[String, Object]() + sourceMap.put("applicationId", applicationId.asInstanceOf[Object]) + sourceMap.put("state", "running".asInstanceOf[Object]) + sourceMap.put("jobId", jobId.asInstanceOf[Object]) + 
sourceMap.put("sessionId", sessionId.asInstanceOf[Object]) + sourceMap.put("lastUpdateTime", lastUpdateTime.asInstanceOf[Object]) val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() try { - val sessionManager = mock[SessionManager] - val statementLifecycleManager = mock[StatementExecutionManager] + spark.conf.set(FlintSparkConf.REQUEST_INDEX.key, sessionIndex) + val sessionManager = new SessionManagerImpl(spark, Some("resultIndex")) { + override val osClient: OSClient = mockOSClient + } + val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( + applicationId, + jobId, spark, dataSource, sessionId, sessionManager, - jobId, - statementLifecycleManager, queryResultWriter, Duration(10, MINUTES), 60, @@ -656,14 +691,15 @@ class FlintREPLTest } } - ignore("executeAndHandle should handle TimeoutException properly") { + test("executeAndHandle should handle TimeoutException properly") { val mockSparkSession = mock[SparkSession] val mockConf = mock[RuntimeConfig] when(mockSparkSession.conf).thenReturn(mockConf) when(mockSparkSession.conf.get(FlintSparkConf.JOB_TYPE.key)) .thenReturn(FlintSparkConf.JOB_TYPE.defaultValue.get) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn("someSessionIndex") - val mockStatementsExecutionManager = mock[StatementExecutionManager] // val mockExecutionContextExecutor: ExecutionContextExecutor = mock[ExecutionContextExecutor] val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor("flint-repl", 1) implicit val executionContext = ExecutionContext.fromExecutor(threadPool) @@ -688,19 +724,35 @@ class FlintREPLTest .thenReturn(expectedDataFrame) when(expectedDataFrame.toDF(any[Seq[String]]: _*)).thenReturn(expectedDataFrame) - val sparkContext = mock[SparkContext] when(mockSparkSession.sparkContext).thenReturn(sparkContext) + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) + val queryResultWriter = mock[QueryResultWriter] + val commandContext = CommandContext( + applicationId, + jobId, + mockSparkSession, + dataSource, + sessionId, + sessionManager, + queryResultWriter, + Duration(10, MINUTES), + 60, + 60, + DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) + val statementExecutionManager = new StatementExecutionManagerImpl(commandContext) + val result = FlintREPL.executeAndHandle( + applicationId, + jobId, mockSparkSession, flintStatement, - mockStatementsExecutionManager, + statementExecutionManager, dataSource, sessionId, executionContext, startTime, - // make sure it times out before mockSparkSession.sql can return, which takes 60 seconds Duration(1, SECONDS), 600000) @@ -712,13 +764,14 @@ class FlintREPLTest } finally threadPool.shutdown() } - ignore("executeAndHandle should handle ParseException properly") { + test("executeAndHandle should handle ParseException properly") { val mockSparkSession = mock[SparkSession] val mockConf = mock[RuntimeConfig] when(mockSparkSession.conf).thenReturn(mockConf) when(mockSparkSession.conf.get(FlintSparkConf.JOB_TYPE.key)) .thenReturn(FlintSparkConf.JOB_TYPE.defaultValue.get) - val mockStatementsExecutionManager = mock[StatementExecutionManager] + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn("someSessionIndex") val flintStatement = new FlintStatement( @@ -748,15 +801,33 @@ class FlintREPLTest .thenReturn(expectedDataFrame) when(expectedDataFrame.toDF(any[Seq[String]]: _*)).thenReturn(expectedDataFrame) + val sessionManager = new 
SessionManagerImpl(mockSparkSession, Some("resultIndex")) + val queryResultWriter = mock[QueryResultWriter] + val commandContext = CommandContext( + applicationId, + jobId, + mockSparkSession, + dataSource, + sessionId, + sessionManager, + queryResultWriter, + Duration(10, MINUTES), + 60, + 60, + DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) + val statementExecutionManager = new StatementExecutionManagerImpl(commandContext) + val result = FlintREPL.executeAndHandle( + applicationId, + jobId, mockSparkSession, flintStatement, - mockStatementsExecutionManager, + statementExecutionManager, dataSource, sessionId, executionContext, startTime, - Duration.Inf, // Use Duration.Inf or a large enough duration to avoid a timeout, + Duration.Inf, 600000) // Verify that ParseException was caught and handled @@ -767,22 +838,35 @@ class FlintREPLTest } test("setupFlintJobWithExclusionCheck should proceed normally when no jobs are excluded") { - val osClient = mock[OSClient] - val getResponse = mock[GetResponse] - val applicationId = "app1" - val jobId = "job1" + val sessionIndex = "sessionIndex" val sessionId = "session1" - val jobStartTime = System.currentTimeMillis() - val sessionManager = mock[SessionManager] + val mockSparkSession = mock[SparkSession] + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) + + val mockOSClient = mock[OSClient] + val getResponse = mock[GetResponse] + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(true) + + val mockOpenSearchUpdater = mock[OpenSearchUpdater] + doNothing().when(mockOpenSearchUpdater).upsert(any[String], any[String]) + + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) { + override val osClient: OSClient = mockOSClient + override lazy val flintSessionIndexUpdater: OpenSearchUpdater = mockOpenSearchUpdater + } val conf = new SparkConf().set("spark.flint.deployment.excludeJobs", "") - when(osClient.getDoc(*, *)).thenReturn(getResponse) + when(mockOSClient.getDoc(*, *)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) when(getResponse.getSourceAsMap).thenReturn( Map[String, Object]( "applicationId" -> "app1", "jobId" -> "job1", - "sessionId" -> "session1", + "sessionId" -> sessionId, "lastUpdateTime" -> java.lang.Long.valueOf(12345L), "error" -> "someError", "state" -> "running", @@ -790,21 +874,10 @@ class FlintREPLTest when(getResponse.getSeqNo).thenReturn(0L) when(getResponse.getPrimaryTerm).thenReturn(0L) - val interactiveSession = new InteractiveSession( - "app123", - jobId, - sessionId, - SessionStates.RUNNING, - lastUpdateTime = java.lang.Long.valueOf(12345L)) - - when(sessionManager.getSessionDetails(sessionId)) - .thenReturn(Some(interactiveSession)) - val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "") - // other mock objects like osClient, flintSessionIndexUpdater with necessary mocking val result = FlintREPL.setupFlintJobWithExclusionCheck( - mockConf, - sessionId, + conf, + "session1", jobId, applicationId, sessionManager, @@ -812,33 +885,36 @@ class FlintREPLTest assert(!result) // Expecting false as the job should proceed normally } - ignore("setupFlintJobWithExclusionCheck should exit early if current job is excluded") { - val osClient = mock[OSClient] - val getResponse = mock[GetResponse] - val applicationId = "app1" - val jobId = "job1" + 
test("setupFlintJobWithExclusionCheck should exit early if current job is excluded") { + val sessionIndex = "sessionIndex" val sessionId = "session1" - val jobStartTime = System.currentTimeMillis() - val sessionManager = mock[SessionManager] + val mockSparkSession = mock[SparkSession] + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) + + val mockOSClient = mock[OSClient] + val getResponse = mock[GetResponse] + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(true) + + val mockOpenSearchUpdater = mock[OpenSearchUpdater] + doNothing().when(mockOpenSearchUpdater).upsert(any[String], any[String]) - when(osClient.getDoc(*, *)).thenReturn(getResponse) + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) { + override val osClient: OSClient = mockOSClient + override lazy val flintSessionIndexUpdater: OpenSearchUpdater = mockOpenSearchUpdater + } + + when(mockOSClient.getDoc(*, *)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) // Mock the rest of the GetResponse as needed - val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "jobId") - val interactiveSession = new InteractiveSession( - "app123", - jobId, - sessionId, - SessionStates.RUNNING, - excludedJobIds = Seq(jobId), - lastUpdateTime = java.lang.Long.valueOf(12345L)) - - when(sessionManager.getSessionDetails(sessionId)) - .thenReturn(Some(interactiveSession)) + val conf = new SparkConf().set("spark.flint.deployment.excludeJobs", jobId) val result = FlintREPL.setupFlintJobWithExclusionCheck( - mockConf, + conf, sessionId, jobId, applicationId, @@ -848,16 +924,27 @@ class FlintREPLTest } test("setupFlintJobWithExclusionCheck should exit early if a duplicate job is running") { - val osClient = mock[OSClient] - val getResponse = mock[GetResponse] - val applicationId = "app1" - val jobId = "job1" + val sessionIndex = "sessionIndex" val sessionId = "session1" - val jobStartTime = System.currentTimeMillis() - val sessionManager = mock[SessionManager] + val mockSparkSession = mock[SparkSession] + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) - when(osClient.getDoc(*, *)).thenReturn(getResponse) + val mockOSClient = mock[OSClient] + val getResponse = mock[GetResponse] + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) + + val mockOpenSearchUpdater = mock[OpenSearchUpdater] + doNothing().when(mockOpenSearchUpdater).upsert(any[String], any[String]) + + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) { + override val osClient: OSClient = mockOSClient + override lazy val flintSessionIndexUpdater: OpenSearchUpdater = mockOpenSearchUpdater + } + // Mock the GetResponse to simulate a scenario of a duplicate job when(getResponse.getSourceAsMap).thenReturn( Map[String, Object]( @@ -872,20 +959,10 @@ class FlintREPLTest .asList("job-2", "job-1") // Include this inside the Map ).asJava) - val flintSessionIndexUpdater = mock[OpenSearchUpdater] - val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "job-1,job-2") - val interactiveSession = new InteractiveSession( - "app1", - jobId, - sessionId, - SessionStates.RUNNING, - 
excludedJobIds = Seq("job-1", "job-2"), - lastUpdateTime = java.lang.Long.valueOf(12345L)) - // Mock sessionManager to return sessionDetails - when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) + val conf = new SparkConf().set("spark.flint.deployment.excludeJobs", "job-1,job-2") val result = FlintREPL.setupFlintJobWithExclusionCheck( - mockConf, + conf, sessionId, jobId, applicationId, @@ -895,31 +972,40 @@ class FlintREPLTest } test("setupFlintJobWithExclusionCheck should setup job normally when conditions are met") { - val osClient = mock[OSClient] - val getResponse = mock[GetResponse] - val applicationId = "app1" - val jobId = "job1" + val sessionIndex = "sessionIndex" val sessionId = "session1" - val jobStartTime = System.currentTimeMillis() - val sessionManager = mock[SessionManager] + val mockSparkSession = mock[SparkSession] + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) - when(osClient.getDoc(*, *)).thenReturn(getResponse) + val mockOSClient = mock[OSClient] + val getResponse = mock[GetResponse] + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) + when(getResponse.getSourceAsMap).thenReturn( + Map[String, Object]( + "applicationId" -> applicationId, + "jobId" -> jobId, + "sessionId" -> sessionId, + "lastUpdateTime" -> java.lang.Long.valueOf(12345L), + "error" -> "someError", + "state" -> "running", + "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava) - val flintSessionIndexUpdater = mock[OpenSearchUpdater] - val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "job-3,job-4") - val interactiveSession = new InteractiveSession( - "app1", - jobId, - sessionId, - SessionStates.RUNNING, - excludedJobIds = Seq("job-5", "job-6"), - lastUpdateTime = java.lang.Long.valueOf(12345L)) - // Mock sessionManager to return sessionDetails - when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) + val mockOpenSearchUpdater = mock[OpenSearchUpdater] + doNothing().when(mockOpenSearchUpdater).upsert(any[String], any[String]) + + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) { + override val osClient: OSClient = mockOSClient + override lazy val flintSessionIndexUpdater: OpenSearchUpdater = mockOpenSearchUpdater + } + + val conf = new SparkConf().set("spark.flint.deployment.excludeJobs", "job-3,job-4") val result = FlintREPL.setupFlintJobWithExclusionCheck( - mockConf, + conf, sessionId, jobId, applicationId, @@ -928,66 +1014,59 @@ class FlintREPLTest assert(!result) // Expecting false as the job proceeds normally } - ignore( - "setupFlintJobWithExclusionCheck should throw NoSuchElementException if sessionIndex or sessionId is missing") { - val osClient = mock[OSClient] - val flintSessionIndexUpdater = mock[OpenSearchUpdater] - val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "") - val applicationId = "app1" - val jobId = "job1" - val sessionId = "session1" - val jobStartTime = System.currentTimeMillis() - val sessionManager = mock[SessionManager] - - assertThrows[NoSuchElementException] { - FlintREPL.setupFlintJobWithExclusionCheck( - mockConf, - sessionId, - jobId, - applicationId, - sessionManager, - System.currentTimeMillis()) - } - } - - ignore("queryLoop continue until inactivity limit is reached") { - val mockReader = mock[FlintReader] - 
val osClient = mock[OSClient] - when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) - .thenReturn(mockReader) - when(mockReader.hasNext).thenReturn(false) - + test("queryLoop continue until inactivity limit is reached") { val resultIndex = "testResultIndex" val dataSource = "testDataSource" val sessionIndex = "testSessionIndex" val sessionId = "testSessionId" - val jobId = "testJobId" + + val mockReader = mock[FlintReader] + val mockOSClient = mock[OSClient] + val getResponse = mock[GetResponse] + + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(true) + when(getResponse.getSourceAsMap).thenReturn( + Map[String, Object]( + "applicationId" -> applicationId, + "jobId" -> jobId, + "sessionId" -> sessionId, + "lastUpdateTime" -> java.lang.Long.valueOf(12345L), + "error" -> "someError", + "state" -> "running", + "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava) + + when( + mockOSClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) + .thenReturn(mockReader) + when(mockReader.hasNext).thenReturn(false) + when(mockOSClient.doesIndexExist(*)).thenReturn(true) + when(mockOSClient.getIndexMetadata(*)).thenReturn(FlintREPL.resultIndexMapping) val shortInactivityLimit = 500 // 500 milliseconds // Create a SparkSession for testing val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() - val sessionManager = mock[SessionManager] - val statementLifecycleManager = mock[StatementExecutionManager] + spark.conf.set(FlintSparkConf.REQUEST_INDEX.key, sessionIndex) + val sessionManager = new SessionManagerImpl(spark, Some(resultIndex)) { + override val osClient: OSClient = mockOSClient + } val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( + applicationId, + jobId, spark, dataSource, sessionId, sessionManager, - jobId, - statementLifecycleManager, queryResultWriter, Duration(10, MINUTES), shortInactivityLimit, 60, DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) - // Mock processCommands to always allow loop continuation - when(statementLifecycleManager.getNextStatement()).thenReturn(None) - val startTime = System.currentTimeMillis() FlintREPL.queryLoop(commandContext) @@ -1001,55 +1080,62 @@ class FlintREPLTest spark.stop() } - ignore("queryLoop should stop when canPickUpNextStatement is false") { - val mockReader = mock[FlintReader] - val osClient = mock[OSClient] - when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) - .thenReturn(mockReader) - when(mockReader.hasNext).thenReturn(true) - - val sessionManager = mock[SessionManager] - val statementLifecycleManager = mock[StatementExecutionManager] - val queryResultWriter = mock[QueryResultWriter] - + test("queryLoop should stop when canPickUpNextStatement is false") { val resultIndex = "testResultIndex" val dataSource = "testDataSource" val sessionIndex = "testSessionIndex" val sessionId = "testSessionId" - val jobId = "testJobId" + + val mockReader = mock[FlintReader] + val mockOSClient = mock[OSClient] + + // Mocking canPickNextStatement to return false + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenAnswer(_ => { + val mockGetResponse = mock[GetResponse] + when(mockGetResponse.isExists()).thenReturn(true) + when(mockGetResponse.getSourceAsMap).thenReturn( + Map[String, Object]( + "applicationId" -> applicationId, + "jobId" -> "differentJobId", + "sessionId" -> sessionId, + "lastUpdateTime" -> 
java.lang.Long.valueOf(12345L),
+          "error" -> "someError",
+          "state" -> "running",
+          "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava)
+      mockGetResponse
+    })
+
+    when(
+      mockOSClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC)))
+      .thenReturn(mockReader)
+    when(mockReader.hasNext).thenReturn(true)
+    when(mockOSClient.doesIndexExist(*)).thenReturn(true)
+    when(mockOSClient.getIndexMetadata(*)).thenReturn(FlintREPL.resultIndexMapping)
+
     val longInactivityLimit = 10000 // 10 seconds

     // Create a SparkSession for testing
     val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate()
-    val flintSessionIndexUpdater = mock[OpenSearchUpdater]
+    spark.conf.set(FlintSparkConf.REQUEST_INDEX.key, sessionIndex)
+    val sessionManager = new SessionManagerImpl(spark, Some(resultIndex)) {
+      override val osClient: OSClient = mockOSClient
+    }
+    val queryResultWriter = mock[QueryResultWriter]

     val commandContext = CommandContext(
+      applicationId,
+      jobId,
       spark,
       dataSource,
       sessionId,
       sessionManager,
-      jobId,
-      statementLifecycleManager,
       queryResultWriter,
       Duration(10, MINUTES),
       longInactivityLimit,
       60,
       DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY)

-    // Mocking canPickNextStatement to return false
-    when(osClient.getDoc(sessionIndex, sessionId)).thenAnswer(_ => {
-      val mockGetResponse = mock[GetResponse]
-      when(mockGetResponse.isExists()).thenReturn(true)
-      val sourceMap = new java.util.HashMap[String, Object]()
-      sourceMap.put("jobId", "differentJobId")
-      when(mockGetResponse.getSourceAsMap).thenReturn(sourceMap)
-      mockGetResponse
-    })
-
-    // Mock getNextStatement to return None, simulating the end of statements
-    when(statementLifecycleManager.getNextStatement()).thenReturn(None)
-
     val startTime = System.currentTimeMillis()

     FlintREPL.queryLoop(commandContext)
@@ -1063,38 +1149,52 @@ class FlintREPLTest
     spark.stop()
   }

-  ignore("queryLoop should properly shut down the thread pool after execution") {
-    val mockReader = mock[FlintReader]
-    val osClient = mock[OSClient]
-    when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC)))
-      .thenReturn(mockReader)
-    when(mockReader.hasNext).thenReturn(false)
-
+  test("queryLoop should properly shut down the thread pool after execution") {
     val resultIndex = "testResultIndex"
     val dataSource = "testDataSource"
     val sessionIndex = "testSessionIndex"
     val sessionId = "testSessionId"
-    val jobId = "testJobId"
-
-    val sessionManager = mock[SessionManager]
-    val statementLifecycleManager = mock[StatementExecutionManager]
-    val queryResultWriter = mock[QueryResultWriter]
-    when(statementLifecycleManager.getNextStatement()).thenReturn(None)
+    val mockReader = mock[FlintReader]
+    val mockOSClient = mock[OSClient]
+    when(
+      mockOSClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC)))
+      .thenReturn(mockReader)
+    when(mockReader.hasNext).thenReturn(false)
+    when(mockOSClient.doesIndexExist(*)).thenReturn(true)
+    when(mockOSClient.getIndexMetadata(*)).thenReturn(FlintREPL.resultIndexMapping)
+
+    val getResponse = mock[GetResponse]
+    when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse)
+    when(getResponse.isExists()).thenReturn(true)
+    when(getResponse.getSourceAsMap).thenReturn(
+      Map[String, Object](
+        "applicationId" -> applicationId,
+        "jobId" -> jobId,
+        "sessionId" -> sessionId,
+        "lastUpdateTime" -> java.lang.Long.valueOf(12345L),
+        "error" -> "someError",
+        "state" -> "running",
+        "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava)
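+
+    // The stubbed session document above carries every field the session manager is
+    // expected to read back (applicationId, jobId, sessionId, state, error, lastUpdateTime,
+    // jobStartTime); SessionManagerImpl.getSessionDetails presumably rebuilds an
+    // InteractiveSession from it, so dropping a key here would fail the session lookup
+    // before queryLoop is exercised. The SessionManagerImpl constructed below overrides
+    // osClient, so session reads and writes in this test are served by mockOSClient rather
+    // than a live OpenSearch cluster.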
     val inactivityLimit = 500 // 500 milliseconds

     // Create a SparkSession for testing
     val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate()
-    val flintSessionIndexUpdater = mock[OpenSearchUpdater]
+    spark.conf.set(FlintSparkConf.REQUEST_INDEX.key, sessionIndex)
+    val sessionManager = new SessionManagerImpl(spark, Some(resultIndex)) {
+      override val osClient: OSClient = mockOSClient
+    }
+    val queryResultWriter = mock[QueryResultWriter]

     val commandContext = CommandContext(
+      applicationId,
+      jobId,
       spark,
       dataSource,
       sessionId,
       sessionManager,
-      jobId,
-      statementLifecycleManager,
       queryResultWriter,
       Duration(10, MINUTES),
       inactivityLimit,
@@ -1118,36 +1218,53 @@ class FlintREPLTest
   }

   test("queryLoop handle exceptions within the loop gracefully") {
-    val mockReader = mock[FlintReader]
-    val osClient = mock[OSClient]
-    when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC)))
-      .thenReturn(mockReader)
-    // Simulate an exception thrown when hasNext is called
-    when(mockReader.hasNext).thenThrow(new RuntimeException("Test exception"))
-
     val resultIndex = "testResultIndex"
     val dataSource = "testDataSource"
     val sessionIndex = "testSessionIndex"
     val sessionId = "testSessionId"
-    val jobId = "testJobId"
+
+    val mockReader = mock[FlintReader]
+    val mockOSClient = mock[OSClient]
+    val getResponse = mock[GetResponse]
+
+    when(mockOSClient.getDoc(*, *)).thenReturn(getResponse)
+    when(getResponse.isExists()).thenReturn(true)
+    when(getResponse.getSourceAsMap).thenReturn(
+      Map[String, Object](
+        "applicationId" -> applicationId,
+        "jobId" -> jobId,
+        "sessionId" -> sessionId,
+        "lastUpdateTime" -> java.lang.Long.valueOf(12345L),
+        "error" -> "someError",
+        "state" -> "running",
+        "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava)
+
+    when(
+      mockOSClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC)))
+      .thenReturn(mockReader)
+    // Simulate an exception thrown when hasNext is called
+    when(mockReader.hasNext).thenThrow(new RuntimeException("Test exception"))
+    when(mockOSClient.doesIndexExist(*)).thenReturn(true)
+    when(mockOSClient.getIndexMetadata(*)).thenReturn(FlintREPL.resultIndexMapping)
     val inactivityLimit = 500 // 500 milliseconds

     // Create a SparkSession for testing
     val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate()
-    val sessionManager = mock[SessionManager]
-    val statementLifecycleManager = mock[StatementExecutionManager]
-    val queryResultWriter = mock[QueryResultWriter]
-    val flintSessionIndexUpdater = mock[OpenSearchUpdater]
+    spark.conf.set(FlintSparkConf.REQUEST_INDEX.key, sessionIndex)
+    val sessionManager = new SessionManagerImpl(spark, Some(resultIndex)) {
+      override val osClient: OSClient = mockOSClient
+    }
+    val queryResultWriter = mock[QueryResultWriter]

     val commandContext = CommandContext(
+      applicationId,
+      jobId,
       spark,
       dataSource,
       sessionId,
       sessionManager,
-      jobId,
-      statementLifecycleManager,
       queryResultWriter,
       Duration(10, MINUTES),
       inactivityLimit,
@@ -1172,16 +1289,25 @@ class FlintREPLTest
     }
   }

-  ignore("queryLoop should correctly update loop control variables") {
+  test("queryLoop should correctly update loop control variables") {
+    val resultIndex = "testResultIndex"
+    val dataSource = "testDataSource"
+    val sessionIndex = "testSessionIndex"
+    val sessionId = "testSessionId"
+
     val mockReader = mock[FlintReader]
-    val osClient = mock[OSClient]
-    when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC)))
+    
val mockOSClient = mock[OSClient] + when( + mockOSClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) .thenReturn(mockReader) val getResponse = mock[GetResponse] - when(osClient.getDoc(*, *)).thenReturn(getResponse) + when(mockOSClient.getDoc(*, *)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(false) - when(osClient.doesIndexExist(*)).thenReturn(true) - when(osClient.getIndexMetadata(*)).thenReturn(FlintREPL.resultIndexMapping) + when(mockOSClient.doesIndexExist(*)).thenReturn(true) + when(mockOSClient.getIndexMetadata(*)).thenReturn(FlintREPL.resultIndexMapping) + + val mockOpenSearchUpdater = mock[OpenSearchUpdater] + doNothing().when(mockOpenSearchUpdater).upsert(any[String], any[String]) // Configure mockReader to return true once and then false to exit the loop when(mockReader.hasNext).thenReturn(true).thenReturn(false) @@ -1197,12 +1323,6 @@ class FlintREPLTest """ when(mockReader.next).thenReturn(command) - val resultIndex = "testResultIndex" - val dataSource = "testDataSource" - val sessionIndex = "testSessionIndex" - val sessionId = "testSessionId" - val jobId = "testJobId" - val inactivityLimit = 5000 // 5 seconds // Create a SparkSession for testing\ @@ -1217,22 +1337,26 @@ class FlintREPLTest when(mockSparkSession.conf).thenReturn(mockConf) when(mockSparkSession.conf.get(FlintSparkConf.JOB_TYPE.key)) .thenReturn(FlintSparkConf.JOB_TYPE.defaultValue.get) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) + when(mockSparkSession.conf.get(FlintSparkConf.CUSTOM_STATEMENT_MANAGER.key, "")) + .thenReturn("") when(expectedDataFrame.toDF(any[Seq[String]]: _*)).thenReturn(expectedDataFrame) - val flintSessionIndexUpdater = mock[OpenSearchUpdater] - - val sessionManager = mock[SessionManager] - val statementLifecycleManager = mock[StatementExecutionManager] + val sessionManager = new SessionManagerImpl(mockSparkSession, Some(resultIndex)) { + override val osClient: OSClient = mockOSClient + override lazy val flintSessionIndexUpdater: OpenSearchUpdater = mockOpenSearchUpdater + } val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( + applicationId, + jobId, mockSparkSession, dataSource, sessionId, sessionManager, - jobId, - statementLifecycleManager, queryResultWriter, Duration(10, MINUTES), inactivityLimit, @@ -1248,7 +1372,10 @@ class FlintREPLTest // Assuming processCommands updates the lastActivityTime to the current time assert(endTime - startTime >= inactivityLimit) - verify(osClient, times(1)).getIndexMetadata(*) + + val expectedCalls = + Math.ceil(inactivityLimit.toDouble / DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY).toInt + verify(mockOSClient, Mockito.atMost(expectedCalls)).getIndexMetadata(*) } val testCases = Table( @@ -1257,40 +1384,54 @@ class FlintREPLTest (100, 300L) // 100 ms, 300 ms ) - ignore( + test( "queryLoop should execute loop without processing any commands for different inactivity limits and frequencies") { forAll(testCases) { (inactivityLimit, queryLoopExecutionFrequency) => - val mockReader = mock[FlintReader] - val osClient = mock[OSClient] - when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) - .thenReturn(mockReader) - val getResponse = mock[GetResponse] - when(osClient.getDoc(*, *)).thenReturn(getResponse) - when(getResponse.isExists()).thenReturn(false) - when(mockReader.hasNext).thenReturn(false) - val resultIndex = "testResultIndex" val dataSource = "testDataSource" val sessionIndex = 
"testSessionIndex" val sessionId = "testSessionId" - val jobId = "testJobId" - val sessionManager = mock[SessionManager] - val statementLifecycleManager = mock[StatementExecutionManager] - val queryResultWriter = mock[QueryResultWriter] + val mockReader = mock[FlintReader] + val mockOSClient = mock[OSClient] + when( + mockOSClient + .createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) + .thenReturn(mockReader) + val getResponse = mock[GetResponse] + when(mockOSClient.getDoc(*, *)).thenReturn(getResponse) + when(mockOSClient.doesIndexExist(*)).thenReturn(true) + when(mockOSClient.getIndexMetadata(*)).thenReturn(FlintREPL.resultIndexMapping) + + when(getResponse.isExists()).thenReturn(false) + when(getResponse.getSourceAsMap).thenReturn( + Map[String, Object]( + "applicationId" -> applicationId, + "jobId" -> jobId, + "sessionId" -> sessionId, + "lastUpdateTime" -> java.lang.Long.valueOf(12345L), + "error" -> "someError", + "state" -> "running", + "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava) + + when(mockReader.hasNext).thenReturn(false) // Create a SparkSession for testing val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() - val flintSessionIndexUpdater = mock[OpenSearchUpdater] + spark.conf.set(FlintSparkConf.REQUEST_INDEX.key, sessionIndex) + val sessionManager = new SessionManagerImpl(spark, Some(resultIndex)) { + override val osClient: OSClient = mockOSClient + } + val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( + applicationId, + jobId, spark, dataSource, sessionId, sessionManager, - jobId, - statementLifecycleManager, queryResultWriter, Duration(10, MINUTES), inactivityLimit,