From a9651be83bcc2b1c39ace129439b59cad3b09e22 Mon Sep 17 00:00:00 2001 From: Louis Chu Date: Fri, 9 Aug 2024 14:24:41 -0700 Subject: [PATCH 1/7] REPL refactor Signed-off-by: Louis Chu --- .../apache/spark/sql/QueryResultWriter.scala | 12 + .../org/apache/spark/sql/SessionManager.scala | 48 ++ .../spark/sql/StatementLifecycleManager.scala | 31 + .../sql/flint/config/FlintSparkConf.scala | 9 + .../opensearch/flint/OpenSearchSuite.scala | 5 +- .../org/apache/spark/sql/CommandContext.scala | 7 +- .../org/apache/spark/sql/CommandState.scala | 3 +- .../scala/org/apache/spark/sql/FlintJob.scala | 7 +- .../apache/spark/sql/FlintJobExecutor.scala | 19 +- .../org/apache/spark/sql/FlintREPL.scala | 582 +++++++----------- .../scala/org/apache/spark/sql/OSClient.scala | 1 + .../spark/sql/QueryResultWriterImpl.scala | 21 + .../apache/spark/sql/SessionManagerImpl.scala | 175 ++++++ .../sql/StatementLifecycleManagerImpl.scala | 59 ++ .../org/apache/spark/sql/FlintREPLTest.scala | 358 +++++++---- 15 files changed, 815 insertions(+), 522 deletions(-) create mode 100644 flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala create mode 100644 flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala create mode 100644 flint-commons/src/main/scala/org/apache/spark/sql/StatementLifecycleManager.scala create mode 100644 spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala create mode 100644 spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala create mode 100644 spark-sql-application/src/main/scala/org/apache/spark/sql/StatementLifecycleManagerImpl.scala diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala b/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala new file mode 100644 index 000000000..7ddf6604b --- /dev/null +++ b/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala @@ -0,0 +1,12 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.opensearch.flint.common.model.FlintStatement + +trait QueryResultWriter { + def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit +} diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala b/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala new file mode 100644 index 000000000..1c270fbc1 --- /dev/null +++ b/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala @@ -0,0 +1,48 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession} + +import org.apache.spark.sql.SessionUpdateMode.SessionUpdateMode + +/** + * Trait defining the interface for managing interactive sessions. + */ +trait SessionManager { + + /** + * Retrieves metadata about the session manager. + */ + def getSessionContext: Map[String, Any] + + /** + * Fetches the details of a specific session. + */ + def getSessionDetails(sessionId: String): Option[InteractiveSession] + + /** + * Updates the details of a specific session. + */ + def updateSessionDetails( + sessionDetails: InteractiveSession, + updateMode: SessionUpdateMode): Unit + + /** + * Retrieves the next statement to be executed in a specific session. 
+ */ + def getNextStatement(sessionId: String): Option[FlintStatement] + + /** + * Records a heartbeat for a specific session to indicate it is still active. + */ + def recordHeartbeat(sessionId: String): Unit +} + +object SessionUpdateMode extends Enumeration { + type SessionUpdateMode = Value + val UPDATE, UPSERT, UPDATE_IF = Value +} diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/StatementLifecycleManager.scala b/flint-commons/src/main/scala/org/apache/spark/sql/StatementLifecycleManager.scala new file mode 100644 index 000000000..18bf1d819 --- /dev/null +++ b/flint-commons/src/main/scala/org/apache/spark/sql/StatementLifecycleManager.scala @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.opensearch.flint.common.model.FlintStatement + +/** + * Trait defining the interface for managing the lifecycle of executing a FlintStatement. + */ +trait StatementLifecycleManager { + + /** + * Prepares the statement lifecycle. + */ + def prepareStatementLifecycle(): Either[String, Unit] + +// def executeStatement(statement: FlintStatement): DataFrame + + /** + * Updates a specific statement. + */ + def updateStatement(statement: FlintStatement): Unit + + /** + * Terminates the statement lifecycle. + */ + def terminateStatementLifecycle(): Unit +} diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala index 1d12d004e..0a40e9b1b 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala @@ -231,6 +231,15 @@ object FlintSparkConf { FlintConfig("spark.metadata.accessAWSCredentialsProvider") .doc("AWS credentials provider for metadata access permission") .createOptional() + val CUSTOM_SESSION_MANAGER = + FlintConfig("spark.flint.job.customSessionManager") + .createOptional() + val CUSTOM_STATEMENT_MANAGER = + FlintConfig("spark.flint.job.customStatementManager") + .createOptional() + val CUSTOM_QUERY_RESULT_WRITER = + FlintConfig("spark.flint.job.customQueryResultWriter") + .createOptional() } /** diff --git a/integ-test/src/integration/scala/org/opensearch/flint/OpenSearchSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/OpenSearchSuite.scala index e1e967ded..cde3230d4 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/OpenSearchSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/OpenSearchSuite.scala @@ -16,12 +16,13 @@ import org.opensearch.common.xcontent.XContentType import org.opensearch.testcontainers.OpenSearchContainer import org.scalatest.{BeforeAndAfterAll, Suite} +import org.apache.spark.internal.Logging import org.apache.spark.sql.flint.config.FlintSparkConf.{HOST_ENDPOINT, HOST_PORT, IGNORE_DOC_ID_COLUMN, REFRESH_POLICY} /** * Test required OpenSearch domain should extend OpenSearchSuite. 
*/ -trait OpenSearchSuite extends BeforeAndAfterAll { +trait OpenSearchSuite extends BeforeAndAfterAll with Logging { self: Suite => protected lazy val container = new OpenSearchContainer() @@ -145,7 +146,7 @@ trait OpenSearchSuite extends BeforeAndAfterAll { val response = openSearchClient.bulk(request, RequestOptions.DEFAULT) - + logInfo(response.toString) assume( !response.hasFailures, s"bulk index docs to $index failed: ${response.buildFailureMessage()}") diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala index 048f69ced..42e4f82d8 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala @@ -13,12 +13,11 @@ import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} case class CommandContext( spark: SparkSession, dataSource: String, - resultIndex: String, sessionId: String, - flintSessionIndexUpdater: OpenSearchUpdater, - osClient: OSClient, - sessionIndex: String, + sessionManager: SessionManager, jobId: String, + statementLifecycleManager: StatementLifecycleManager, + queryResultWriter: QueryResultWriter, queryExecutionTimeout: Duration, inactivityLimitMillis: Long, queryWaitTimeMillis: Long, diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala index ad49201f0..45b7e81cc 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala @@ -12,7 +12,6 @@ import org.opensearch.flint.core.storage.FlintReader case class CommandState( recordedLastActivityTime: Long, recordedVerificationResult: VerificationResult, - flintReader: FlintReader, - futureMappingCheck: Future[Either[String, Unit]], + futurePrepareQueryExecution: Future[Either[String, Unit]], executionContext: ExecutionContextExecutor, recordedLastCanPickCheckTime: Long) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala index bba999110..1278657cb 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.types._ */ object FlintJob extends Logging with FlintJobExecutor { def main(args: Array[String]): Unit = { - val (queryOption, resultIndex) = parseArgs(args) + val (queryOption, resultIndexOption) = parseArgs(args) val conf = createSparkConf() val jobType = conf.get("spark.flint.job.type", "batch") @@ -41,6 +41,9 @@ object FlintJob extends Logging with FlintJobExecutor { if (query.isEmpty) { logAndThrow(s"Query undefined for the ${jobType} job.") } + if (resultIndexOption.isEmpty) { + logAndThrow("resultIndex is not set") + } // https://github.com/opensearch-project/opensearch-spark/issues/138 /* * To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`, @@ -58,7 +61,7 @@ object FlintJob extends Logging with FlintJobExecutor { createSparkSession(conf), query, dataSource, - resultIndex, + resultIndexOption.get, jobType.equalsIgnoreCase("streaming"), streamingRunningCount) 
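The new FlintSparkConf options above (spark.flint.job.customSessionManager, spark.flint.job.customStatementManager, spark.flint.job.customQueryResultWriter) make the session, statement, and result-writer components pluggable. As a hedged illustration only (the package and class name are hypothetical, not part of this patch), a custom SessionManager could look like the sketch below; the instantiate() helper added later in this patch loads such a class reflectively through a public no-arg constructor.

package com.example.flint  // hypothetical package, for illustration only

import scala.collection.mutable

import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession}

import org.apache.spark.sql.SessionManager
import org.apache.spark.sql.SessionUpdateMode.SessionUpdateMode

// Sketch of a pluggable SessionManager. The class must expose a public no-arg
// constructor because FlintREPL.instantiate() creates it via
// getDeclaredConstructor().newInstance(). It would be registered with:
//   --conf spark.flint.job.customSessionManager=com.example.flint.InMemorySessionManager
class InMemorySessionManager extends SessionManager {
  private val sessions = mutable.Map.empty[String, InteractiveSession]

  // No shared resources (OSClient, reader, index updater) to expose to other components.
  override def getSessionContext: Map[String, Any] = Map.empty

  override def getSessionDetails(sessionId: String): Option[InteractiveSession] =
    sessions.get(sessionId)

  override def updateSessionDetails(
      sessionDetails: InteractiveSession,
      updateMode: SessionUpdateMode): Unit =
    sessions.update(sessionDetails.sessionId, sessionDetails)

  // No statement queue in this sketch; the REPL loop simply finds no work.
  override def getNextStatement(sessionId: String): Option[FlintStatement] = None

  // Heartbeats are a no-op because there is no external session store to refresh.
  override def recordHeartbeat(sessionId: String): Unit = ()
}

Note that the default StatementLifecycleManagerImpl and QueryResultWriterImpl read osClient, resultIndex, flintReader, and flintSessionIndexUpdater from the session context map, so a deployment that replaces the session manager would typically also supply matching custom statement and result-writer implementations.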
registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala index 37801a9e8..9b3841c7d 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -493,16 +493,21 @@ trait FlintJobExecutor { } } - def parseArgs(args: Array[String]): (Option[String], String) = { + /** + * Before OS 2.13, there are two arguments from entry point: query and result index Starting + * from OS 2.13, query is optional for FlintREPL And since Flint 0.5, result index is also + * optional for non-OpenSearch result persist + */ + def parseArgs(args: Array[String]): (Option[String], Option[String]) = { args match { + case Array() => + (None, None) case Array(resultIndex) => - (None, resultIndex) // Starting from OS 2.13, resultIndex is the only argument + (None, Some(resultIndex)) case Array(query, resultIndex) => - ( - Some(query), - resultIndex - ) // Before OS 2.13, there are two arguments, the second one is resultIndex - case _ => logAndThrow("Unsupported number of arguments. Expected 1 or 2 arguments.") + (Some(query), Some(resultIndex)) + case _ => + logAndThrow("Unsupported number of arguments. Expected no more than two arguments.") } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index e6b8b11ce..bc5782fc5 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -18,7 +18,7 @@ import com.codahale.metrics.Timer import org.json4s.native.Serialization import org.opensearch.action.get.GetResponse import org.opensearch.common.Strings -import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession} +import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession, SessionStates} import org.opensearch.flint.common.model.InteractiveSession.formats import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.logging.CustomLogging @@ -31,8 +31,9 @@ import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.FlintREPLConfConstants._ +import org.apache.spark.sql.SessionUpdateMode.UPDATE_IF import org.apache.spark.sql.flint.config.FlintSparkConf -import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.{ThreadUtils, Utils} object FlintREPLConfConstants { val HEARTBEAT_INTERVAL_MILLIS = 60000L @@ -61,19 +62,11 @@ object FlintREPL extends Logging with FlintJobExecutor { @volatile var earlyExitFlag: Boolean = false - def updateSessionIndex(flintStatement: FlintStatement, updater: OpenSearchUpdater): Unit = { - updater.update(flintStatement.statementId, FlintStatement.serialize(flintStatement)) - } - private val sessionRunningCount = new AtomicInteger(0) private val statementRunningCount = new AtomicInteger(0) def main(args: Array[String]) { - val (queryOption, resultIndex) = parseArgs(args) - - if (Strings.isNullOrEmpty(resultIndex)) { - logAndThrow("resultIndex is not set") - } + val (queryOption, resultIndexOption) = parseArgs(args) // init SparkContext val conf: SparkConf = createSparkConf() @@ -95,7 +88,9 
@@ object FlintREPL extends Logging with FlintJobExecutor { val query = getQuery(queryOption, jobType, conf) if (jobType.equalsIgnoreCase("streaming")) { - logInfo(s"""streaming query ${query}""") + if (resultIndexOption.isEmpty) { + logAndThrow("resultIndex is not set") + } configDYNMaxExecutors(conf, jobType) val streamingRunningCount = new AtomicInteger(0) val jobOperator = @@ -103,25 +98,18 @@ object FlintREPL extends Logging with FlintJobExecutor { createSparkSession(conf), query, dataSource, - resultIndex, + resultIndexOption.get, true, streamingRunningCount) registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount) jobOperator.start() } else { - // we don't allow default value for sessionIndex and sessionId. Throw exception if key not found. - val sessionIndex: Option[String] = Option(conf.get(FlintSparkConf.REQUEST_INDEX.key, null)) - val sessionId: Option[String] = Option(conf.get(FlintSparkConf.SESSION_ID.key, null)) - - if (sessionIndex.isEmpty) { - logAndThrow(FlintSparkConf.REQUEST_INDEX.key + " is not set") - } - if (sessionId.isEmpty) { - logAndThrow(FlintSparkConf.SESSION_ID.key + " is not set") - } - + // we don't allow default value for sessionId. Throw exception if key not found. + val sessionId = getSessionId(conf) + logInfo(s"sessionId: ${sessionId}") val spark = createSparkSession(conf) - val osClient = new OSClient(FlintSparkConf().flintOptions()) + val sessionManager = instantiateSessionManager(spark, resultIndexOption) + val jobId = envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown") val applicationId = envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown") @@ -142,7 +130,6 @@ object FlintREPL extends Logging with FlintJobExecutor { conf.getLong( "spark.flint.job.queryLoopExecutionFrequency", DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) - val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex.get) val sessionTimerContext = getTimerContext(MetricConstants.REPL_PROCESSING_TIME_METRIC) /** @@ -155,12 +142,7 @@ object FlintREPL extends Logging with FlintJobExecutor { * https://github.com/opensearch-project/opensearch-spark/issues/320 */ spark.sparkContext.addSparkListener( - new PreShutdownListener( - flintSessionIndexUpdater, - osClient, - sessionIndex.get, - sessionId.get, - sessionTimerContext)) + new PreShutdownListener(sessionId, sessionManager, sessionTimerContext)) // 1 thread for updating heart beat val threadPool = @@ -173,37 +155,31 @@ object FlintREPL extends Logging with FlintJobExecutor { // OpenSearch triggers recovery after 1 minute outdated heart beat var heartBeatFuture: ScheduledFuture[_] = null try { - heartBeatFuture = createHeartBeatUpdater( - HEARTBEAT_INTERVAL_MILLIS, - flintSessionIndexUpdater, - sessionId.get, - threadPool, - osClient, - sessionIndex.get, - INITIAL_DELAY_MILLIS) + heartBeatFuture = createHeartBeatUpdater(sessionId, sessionManager, threadPool) if (setupFlintJobWithExclusionCheck( conf, - sessionIndex, sessionId, - osClient, jobId, applicationId, - flintSessionIndexUpdater, + sessionManager, jobStartTime)) { earlyExitFlag = true return } + val statementLifecycleManager = + instantiateStatementLifecycleManager(conf, sessionManager.getSessionContext) + val queryResultWriter = + instantiateQueryResultWriter(conf, sessionManager.getSessionContext) val commandContext = CommandContext( spark, dataSource, - resultIndex, - sessionId.get, - flintSessionIndexUpdater, - osClient, - sessionIndex.get, + sessionId, + sessionManager, jobId, + statementLifecycleManager, + 
queryResultWriter, queryExecutionTimeoutSecs, inactivityLimitMillis, queryWaitTimeoutMillis, @@ -218,11 +194,9 @@ object FlintREPL extends Logging with FlintJobExecutor { e, applicationId, jobId, - sessionId.get, + sessionId, + sessionManager, jobStartTime, - flintSessionIndexUpdater, - osClient, - sessionIndex.get, sessionTimerContext) } finally { if (threadPool != null) { @@ -295,25 +269,16 @@ object FlintREPL extends Logging with FlintJobExecutor { */ def setupFlintJobWithExclusionCheck( conf: SparkConf, - sessionIndex: Option[String], - sessionId: Option[String], - osClient: OSClient, + sessionId: String, jobId: String, applicationId: String, - flintSessionIndexUpdater: OpenSearchUpdater, + sessionManager: SessionManager, jobStartTime: Long): Boolean = { val confExcludeJobsOpt = conf.getOption(FlintSparkConf.EXCLUDE_JOB_IDS.key) - confExcludeJobsOpt match { case None => // If confExcludeJobs is None, pass null or an empty sequence as per your setupFlintJob method's signature - setupFlintJob( - applicationId, - jobId, - sessionId.get, - flintSessionIndexUpdater, - sessionIndex.get, - jobStartTime) + setupFlintJob(applicationId, jobId, sessionId, sessionManager, jobStartTime) case Some(confExcludeJobs) => // example: --conf spark.flint.deployment.excludeJobs=job-1,job-2 @@ -324,25 +289,22 @@ object FlintREPL extends Logging with FlintJobExecutor { return true } - val getResponse = osClient.getDoc(sessionIndex.get, sessionId.get) - if (getResponse.isExists()) { - val source = getResponse.getSourceAsMap - if (source != null) { - val existingExcludedJobIds = parseExcludedJobIds(source) + sessionManager.getSessionDetails(sessionId) match { + case Some(sessionDetails) => + val existingExcludedJobIds = sessionDetails.excludedJobIds if (excludeJobIds.sorted == existingExcludedJobIds.sorted) { logInfo("duplicate job running, exit the application.") return true } - } + case _ => // Do nothing } // If none of the edge cases are met, proceed with setup setupFlintJob( applicationId, jobId, - sessionId.get, - flintSessionIndexUpdater, - sessionIndex.get, + sessionId, + sessionManager, jobStartTime, excludeJobIds) } @@ -350,14 +312,15 @@ object FlintREPL extends Logging with FlintJobExecutor { } def queryLoop(commandContext: CommandContext): Unit = { + import commandContext._ // 1 thread for async query execution val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-query", 1) implicit val executionContext = ExecutionContext.fromExecutor(threadPool) - var futureMappingCheck: Future[Either[String, Unit]] = null + var futurePrepareQueryExecution: Future[Either[String, Unit]] = null try { - futureMappingCheck = Future { - checkAndCreateIndex(commandContext.osClient, commandContext.resultIndex) + futurePrepareQueryExecution = Future { + statementLifecycleManager.prepareStatementLifecycle() } var lastActivityTime = currentTimeProvider.currentEpochMillis() @@ -366,21 +329,13 @@ object FlintREPL extends Logging with FlintJobExecutor { var lastCanPickCheckTime = 0L while (currentTimeProvider .currentEpochMillis() - lastActivityTime <= commandContext.inactivityLimitMillis && canPickUpNextStatement) { - logInfo( - s"""read from ${commandContext.sessionIndex}, sessionId: ${commandContext.sessionId}""") - val flintReader: FlintReader = - createQueryReader( - commandContext.osClient, - commandContext.sessionId, - commandContext.sessionIndex, - commandContext.dataSource) + logInfo(s"""Executing session with sessionId: ${sessionId}""") try { val commandState = CommandState( 
lastActivityTime, verificationResult, - flintReader, - futureMappingCheck, + futurePrepareQueryExecution, executionContext, lastCanPickCheckTime) val result: (Long, VerificationResult, Boolean, Long) = @@ -397,7 +352,7 @@ object FlintREPL extends Logging with FlintJobExecutor { canPickUpNextStatement = updatedCanPickUpNextStatement lastCanPickCheckTime = updatedLastCanPickCheckTime } finally { - flintReader.close() + statementLifecycleManager.terminateStatementLifecycle() } Thread.sleep(commandContext.queryLoopExecutionFrequency) @@ -413,84 +368,72 @@ object FlintREPL extends Logging with FlintJobExecutor { applicationId: String, jobId: String, sessionId: String, - flintSessionIndexUpdater: OpenSearchUpdater, - sessionIndex: String, + sessionManager: SessionManager, jobStartTime: Long, excludeJobIds: Seq[String] = Seq.empty[String]): Unit = { val includeJobId = !excludeJobIds.isEmpty && !excludeJobIds.contains(jobId) val currentTime = currentTimeProvider.currentEpochMillis() - val flintJob = new InteractiveSession( + val flintJob = refreshSessionState( applicationId, jobId, sessionId, - "running", - currentTime, + sessionManager, jobStartTime, - excludeJobIds) + SessionStates.RUNNING, + excludedJobIds = excludeJobIds) - val serializedFlintInstance = if (includeJobId) { - InteractiveSession.serialize(flintJob, currentTime, true) - } else { - InteractiveSession.serializeWithoutJobId(flintJob, currentTime) - } - flintSessionIndexUpdater.upsert(sessionId, serializedFlintInstance) - logInfo( - s"""Updated job: {"jobid": ${flintJob.jobId}, "sessionId": ${flintJob.sessionId}} from $sessionIndex""") sessionRunningCount.incrementAndGet() } + private def refreshSessionState( + applicationId: String, + jobId: String, + sessionId: String, + sessionManager: SessionManager, + jobStartTime: Long, + state: String, + error: Option[String] = None, + excludedJobIds: Seq[String] = Seq.empty[String]): InteractiveSession = { + + val sessionDetails = sessionManager + .getSessionDetails(sessionId) + .getOrElse( + new InteractiveSession( + applicationId, + jobId, + sessionId, + state, + currentTimeProvider.currentEpochMillis(), + jobStartTime, + error = error, + excludedJobIds = excludedJobIds)) + sessionDetails.state = state + sessionManager.updateSessionDetails(sessionDetails, updateMode = SessionUpdateMode.UPSERT) + sessionDetails + } + def handleSessionError( e: Exception, applicationId: String, jobId: String, sessionId: String, + sessionManager: SessionManager, jobStartTime: Long, - flintSessionIndexUpdater: OpenSearchUpdater, - osClient: OSClient, - sessionIndex: String, sessionTimerContext: Timer.Context): Unit = { val error = s"Session error: ${e.getMessage}" CustomLogging.logError(error, e) - val flintInstance = getExistingFlintInstance(osClient, sessionIndex, sessionId) - .getOrElse(createFailedFlintInstance(applicationId, jobId, sessionId, jobStartTime, error)) - - updateFlintInstance(flintInstance, flintSessionIndexUpdater, sessionId) - if (flintInstance.isFail) { - recordSessionFailed(sessionTimerContext) - } - } - - private def getExistingFlintInstance( - osClient: OSClient, - sessionIndex: String, - sessionId: String): Option[InteractiveSession] = Try( - osClient.getDoc(sessionIndex, sessionId)) match { - case Success(getResponse) if getResponse.isExists() => - Option(getResponse.getSourceAsMap) - .map(InteractiveSession.deserializeFromMap) - case Failure(exception) => - CustomLogging.logError( - s"Failed to retrieve existing FlintInstance: ${exception.getMessage}", - exception) - None - case _ => 
None + refreshSessionState( + applicationId, + jobId, + sessionId, + sessionManager, + jobStartTime, + SessionStates.FAIL, + Some(e.getMessage)) + recordSessionFailed(sessionTimerContext) } - private def createFailedFlintInstance( - applicationId: String, - jobId: String, - sessionId: String, - jobStartTime: Long, - errorMessage: String): InteractiveSession = new InteractiveSession( - applicationId, - jobId, - sessionId, - "fail", - currentTimeProvider.currentEpochMillis(), - jobStartTime, - error = Some(errorMessage)) - private def updateFlintInstance( flintInstance: InteractiveSession, flintSessionIndexUpdater: OpenSearchUpdater, @@ -565,43 +508,28 @@ object FlintREPL extends Logging with FlintJobExecutor { // Only call canPickNextStatement if EARLY_TERMINATION_CHECK_FREQUENCY milliseconds have passed if (currentTime - lastCanPickCheckTime > EARLY_TERMINATION_CHECK_FREQUENCY) { - canPickNextStatementResult = - canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + canPickNextStatementResult = canPickNextStatement(sessionId, sessionManager, jobId) lastCanPickCheckTime = currentTime } if (!canPickNextStatementResult) { earlyExitFlag = true canProceed = false - } else if (!flintReader.hasNext) { - canProceed = false } else { - val statementTimerContext = getTimerContext( - MetricConstants.STATEMENT_PROCESSING_TIME_METRIC) - val flintStatement = processCommandInitiation(flintReader, flintSessionIndexUpdater) - - val (dataToWrite, returnedVerificationResult) = processStatementOnVerification( - recordedVerificationResult, - spark, - flintStatement, - dataSource, - sessionId, - executionContext, - futureMappingCheck, - resultIndex, - queryExecutionTimeout, - queryWaitTimeMillis) - - verificationResult = returnedVerificationResult - finalizeCommand( - dataToWrite, - flintStatement, - resultIndex, - flintSessionIndexUpdater, - osClient, - statementTimerContext) - // last query finish time is last activity time - lastActivityTime = currentTimeProvider.currentEpochMillis() + sessionManager.getNextStatement(sessionId) match { + case Some(flintStatement) => + val statementTimerContext = getTimerContext( + MetricConstants.STATEMENT_PROCESSING_TIME_METRIC) + val (dataToWrite, returnedVerificationResult) = + processStatementOnVerification(flintStatement, state, context) + + verificationResult = returnedVerificationResult + finalizeCommand(context, dataToWrite, flintStatement, statementTimerContext) + // last query finish time is last activity time + lastActivityTime = currentTimeProvider.currentEpochMillis() + case None => + canProceed = false + } } } @@ -610,32 +538,25 @@ object FlintREPL extends Logging with FlintJobExecutor { } /** - * finalize command after processing + * finalize statement after processing * * @param dataToWrite * data to write * @param flintStatement - * flint command - * @param resultIndex - * result index - * @param flintSessionIndexUpdater - * flint session index updater + * flint statement */ private def finalizeCommand( + commandContext: CommandContext, dataToWrite: Option[DataFrame], flintStatement: FlintStatement, - resultIndex: String, - flintSessionIndexUpdater: OpenSearchUpdater, - osClient: OSClient, statementTimerContext: Timer.Context): Unit = { + import commandContext._ try { - dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient)) + dataToWrite.foreach(df => queryResultWriter.writeDataFrame(df, flintStatement)) if (flintStatement.isRunning || flintStatement.isWaiting) { // we have set failed state in exception handling 
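With finalizeCommand now routing persistence through QueryResultWriter.writeDataFrame, the default QueryResultWriterImpl (added later in this patch) still writes to the OpenSearch result index, but results can also be persisted elsewhere, which is why the result index argument became optional. A minimal sketch of a non-OpenSearch writer, assuming a hypothetical class name and output path:

package com.example.flint  // hypothetical, for illustration only

import org.opensearch.flint.common.model.FlintStatement

import org.apache.spark.sql.{DataFrame, QueryResultWriter}

// Sketch of a result writer that persists outside OpenSearch. It needs a public
// no-arg constructor and would be registered with:
//   --conf spark.flint.job.customQueryResultWriter=com.example.flint.ParquetQueryResultWriter
class ParquetQueryResultWriter extends QueryResultWriter {
  override def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit = {
    // The output location is an assumption for this sketch; statementId only keeps
    // results of different statements apart.
    dataFrame.write
      .mode("append")
      .parquet(s"/tmp/flint-results/${flintStatement.statementId}")
  }
}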
flintStatement.complete() } - updateSessionIndex(flintStatement, flintSessionIndexUpdater) - recordStatementStateChange(flintStatement, statementTimerContext) } catch { // e.g., maybe due to authentication service connection issue // or invalid catalog (e.g., we are operating on data not defined in provided data source) @@ -643,8 +564,9 @@ object FlintREPL extends Logging with FlintJobExecutor { val error = s"""Fail to write result of ${flintStatement}, cause: ${e.getMessage}""" CustomLogging.logError(error, e) flintStatement.fail() - updateSessionIndex(flintStatement, flintSessionIndexUpdater) - recordStatementStateChange(flintStatement, statementTimerContext) + } finally { + statementLifecycleManager.updateStatement(flintStatement) + recordStatementStateChange(flintStatement, statementTimerContext) } } @@ -720,16 +642,12 @@ object FlintREPL extends Logging with FlintJobExecutor { } private def processStatementOnVerification( - recordedVerificationResult: VerificationResult, - spark: SparkSession, flintStatement: FlintStatement, - dataSource: String, - sessionId: String, - executionContext: ExecutionContextExecutor, - futureMappingCheck: Future[Either[String, Unit]], - resultIndex: String, - queryExecutionTimeout: Duration, - queryWaitTimeMillis: Long) = { + commandState: CommandState, + commandContext: CommandContext) = { + import commandState._ + import commandContext._ + val startTime: Long = currentTimeProvider.currentEpochMillis() var verificationResult = recordedVerificationResult var dataToWrite: Option[DataFrame] = None @@ -737,7 +655,7 @@ object FlintREPL extends Logging with FlintJobExecutor { verificationResult match { case NotVerified => try { - ThreadUtils.awaitResult(futureMappingCheck, MAPPING_CHECK_TIMEOUT) match { + ThreadUtils.awaitResult(futurePrepareQueryExecution, MAPPING_CHECK_TIMEOUT) match { case Right(_) => dataToWrite = executeAndHandle( spark, @@ -762,7 +680,7 @@ object FlintREPL extends Logging with FlintJobExecutor { } } catch { case e: TimeoutException => - val error = s"Getting the mapping of index $resultIndex timed out" + val error = s"Query execution preparation timed out" CustomLogging.logError(error, e) dataToWrite = Some( handleCommandTimeout( @@ -842,70 +760,10 @@ object FlintREPL extends Logging with FlintJobExecutor { ThreadUtils.awaitResult(futureQueryExecution, queryExecutionTimeOut) } } - private def processCommandInitiation( - flintReader: FlintReader, - flintSessionIndexUpdater: OpenSearchUpdater): FlintStatement = { - val command = flintReader.next() - logDebug(s"raw command: $command") - val flintStatement = FlintStatement.deserialize(command) - logDebug(s"command: $flintStatement") - flintStatement.running() - logDebug(s"command running: $flintStatement") - updateSessionIndex(flintStatement, flintSessionIndexUpdater) - statementRunningCount.incrementAndGet() - flintStatement - } - - private def createQueryReader( - osClient: OSClient, - sessionId: String, - sessionIndex: String, - dataSource: String) = { - // all state in index are in lower case - // we only search for statement submitted in the last hour in case of unexpected bugs causing infinite loop in the - // same doc - val dsl = - s"""{ - | "bool": { - | "must": [ - | { - | "term": { - | "type": "statement" - | } - | }, - | { - | "term": { - | "state": "waiting" - | } - | }, - | { - | "term": { - | "sessionId": "$sessionId" - | } - | }, - | { - | "term": { - | "dataSourceName": "$dataSource" - | } - | }, - | { - | "range": { - | "submitTime": { "gte": "now-1h" } - | } - | } - | ] - | 
} - |}""".stripMargin - - val flintReader = osClient.createQueryReader(sessionIndex, dsl, "submitTime", SortOrder.ASC) - flintReader - } class PreShutdownListener( - flintSessionIndexUpdater: OpenSearchUpdater, - osClient: OSClient, - sessionIndex: String, sessionId: String, + sessionManager: SessionManager, sessionTimerContext: Timer.Context) extends SparkListener with Logging { @@ -913,77 +771,42 @@ object FlintREPL extends Logging with FlintJobExecutor { override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { logInfo("Shutting down REPL") logInfo("earlyExitFlag: " + earlyExitFlag) - val getResponse = osClient.getDoc(sessionIndex, sessionId) - if (!getResponse.isExists()) { - return - } - - val source = getResponse.getSourceAsMap - if (source == null) { - return - } - - val state = Option(source.get("state")).map(_.asInstanceOf[String]) - // It's essential to check the earlyExitFlag before marking the session state as 'dead'. When this flag is true, - // it indicates that the control plane has already initiated a new session to handle remaining requests for the - // current session. In our SQL setup, marking a session as 'dead' automatically triggers the creation of a new - // session. However, the newly created session (initiated by the control plane) will enter a spin-wait state, - // where it inefficiently waits for certain conditions to be met, leading to unproductive resource consumption - // and eventual timeout. To avoid this issue and prevent the creation of redundant sessions by SQL, we ensure - // the session state is not set to 'dead' when earlyExitFlag is true, thereby preventing unnecessary duplicate - // processing. - if (!earlyExitFlag && state.isDefined && state.get != "dead" && state.get != "fail") { - updateFlintInstanceBeforeShutdown( - source, - getResponse, - flintSessionIndexUpdater, - sessionId, - sessionTimerContext) + try { + sessionManager.getSessionDetails(sessionId).foreach { sessionDetails => + // It's essential to check the earlyExitFlag before marking the session state as 'dead'. When this flag is true, + // it indicates that the control plane has already initiated a new session to handle remaining requests for the + // current session. In our SQL setup, marking a session as 'dead' automatically triggers the creation of a new + // session. However, the newly created session (initiated by the control plane) will enter a spin-wait state, + // where it inefficiently waits for certain conditions to be met, leading to unproductive resource consumption + // and eventual timeout. To avoid this issue and prevent the creation of redundant sessions by SQL, we ensure + // the session state is not set to 'dead' when earlyExitFlag is true, thereby preventing unnecessary duplicate + // processing. 
+ if (!earlyExitFlag && !sessionDetails.isComplete && !sessionDetails.isFail) { + sessionDetails.complete() + sessionManager.updateSessionDetails(sessionDetails, UPDATE_IF) + recordSessionSuccess(sessionTimerContext) + } + } + } catch { + case e: Exception => logError(s"Failed to update session state for $sessionId", e) } } } - private def updateFlintInstanceBeforeShutdown( - source: java.util.Map[String, AnyRef], - getResponse: GetResponse, - flintSessionIndexUpdater: OpenSearchUpdater, - sessionId: String, - sessionTimerContext: Timer.Context): Unit = { - val flintInstance = InteractiveSession.deserializeFromMap(source) - flintInstance.complete() - flintSessionIndexUpdater.updateIf( - sessionId, - InteractiveSession.serializeWithoutJobId( - flintInstance, - currentTimeProvider.currentEpochMillis()), - getResponse.getSeqNo, - getResponse.getPrimaryTerm) - recordSessionSuccess(sessionTimerContext) - } - /** - * Create a new thread to update the last update time of the flint instance. - * @param currentInterval - * the interval of updating the last update time. Unit is millisecond. - * @param flintSessionUpdater - * the updater of the flint instance. + * Create a new thread to update the last update time of the flint interactive session. + * * @param sessionId - * the session id of the flint instance. + * the session id of the flint interactive session. + * @param sessionManager + * the manager of the flint interactive session. * @param threadPool * the thread pool. - * @param osClient - * the OpenSearch client. - * @param initialDelayMillis - * the intial delay to start heartbeat */ def createHeartBeatUpdater( - currentInterval: Long, - flintSessionUpdater: OpenSearchUpdater, sessionId: String, - threadPool: ScheduledExecutorService, - osClient: OSClient, - sessionIndex: String, - initialDelayMillis: Long): ScheduledFuture[_] = { + sessionManager: SessionManager, + threadPool: ScheduledExecutorService): ScheduledFuture[_] = { threadPool.scheduleAtFixedRate( new Runnable { @@ -994,13 +817,7 @@ object FlintREPL extends Logging with FlintJobExecutor { logWarning("HeartBeatUpdater has been interrupted. 
Terminating.") return // Exit the run method if the thread is interrupted } - - flintSessionUpdater.upsert( - sessionId, - Serialization.write( - Map( - "lastUpdateTime" -> currentTimeProvider.currentEpochMillis(), - "state" -> "running"))) + sessionManager.recordHeartbeat(sessionId) } catch { case ie: InterruptedException => // Preserve the interrupt status @@ -1020,8 +837,8 @@ object FlintREPL extends Logging with FlintJobExecutor { } } }, - initialDelayMillis, - currentInterval, + INITIAL_DELAY_MILLIS, + HEARTBEAT_INTERVAL_MILLIS, java.util.concurrent.TimeUnit.MILLISECONDS) } @@ -1038,35 +855,26 @@ object FlintREPL extends Logging with FlintJobExecutor { */ def canPickNextStatement( sessionId: String, - jobId: String, - osClient: OSClient, - sessionIndex: String): Boolean = { + sessionManager: SessionManager, + jobId: String): Boolean = { try { - val getResponse = osClient.getDoc(sessionIndex, sessionId) - if (getResponse.isExists()) { - val source = getResponse.getSourceAsMap - if (source == null) { - logError(s"""Session id ${sessionId} is empty""") - // still proceed since we are not sure what happened (e.g., OpenSearch cluster may be unresponsive) - return true - } - - val runJobId = Option(source.get("jobId")).map(_.asInstanceOf[String]).orNull - val excludeJobIds: Seq[String] = parseExcludedJobIds(source) - - if (runJobId != null && jobId != runJobId) { - logInfo(s"""the current job ID ${jobId} is not the running job ID ${runJobId}""") - return false - } - if (excludeJobIds != null && excludeJobIds.contains(jobId)) { - logInfo(s"""${jobId} is in the list of excluded jobs""") - return false - } - true - } else { - // still proceed since we are not sure what happened (e.g., session doc may not be available yet) - logError(s"""Fail to find id ${sessionId} from session index""") - true + sessionManager.getSessionDetails(sessionId) match { + case Some(sessionDetails) => + val runJobId = sessionDetails.jobId + val excludeJobIds = sessionDetails.excludedJobIds + + if (!runJobId.isEmpty && jobId != runJobId) { + logInfo(s"the current job ID $jobId is not the running job ID ${runJobId}") + return false + } + if (excludeJobIds.contains(jobId)) { + logInfo(s"$jobId is in the list of excluded jobs") + return false + } + true + case None => + logError(s"Failed to fetch sessionDetails by sessionId: $sessionId.") + true } } catch { // still proceed since we are not sure what happened (e.g., OpenSearch cluster may be unresponsive) @@ -1076,23 +884,6 @@ object FlintREPL extends Logging with FlintJobExecutor { } } - private def parseExcludedJobIds(source: java.util.Map[String, AnyRef]): Seq[String] = { - - val rawExcludeJobIds = source.get("excludeJobIds") - Option(rawExcludeJobIds) - .map { - case s: String => Seq(s) - case list: java.util.List[_] @unchecked => - import scala.collection.JavaConverters._ - list.asScala.toList - .collect { case str: String => str } // Collect only strings from the list - case other => - logInfo(s"Unexpected type: ${other.getClass.getName}") - Seq.empty - } - .getOrElse(Seq.empty[String]) // In case of null, return an empty Seq - } - def exponentialBackoffRetry[T](maxRetries: Int, initialDelay: FiniteDuration)( block: => T): T = { var retries = 0 @@ -1130,6 +921,55 @@ object FlintREPL extends Logging with FlintJobExecutor { result.getOrElse(throw new RuntimeException("Failed after retries")) } + private def getSessionId(conf: SparkConf): String = { + val sessionIdOption: Option[String] = Option(conf.get(FlintSparkConf.SESSION_ID.key, null)) + if 
(sessionIdOption.isEmpty) { + logAndThrow(FlintSparkConf.SESSION_ID.key + " is not set") + } + sessionIdOption.get + } + + private def instantiate[T](defaultConstructor: => T, className: String): T = { + if (className.isEmpty) { + logInfo("Using default constructor") + defaultConstructor + } else { + try { + val classObject = Utils.classForName(className) + val ctor = classObject.getDeclaredConstructor() + ctor.setAccessible(true) + ctor.newInstance().asInstanceOf[T] + } catch { + case e: Exception => + throw new RuntimeException(s"Failed to instantiate provider: $className", e) + } + } + } + + private def instantiateSessionManager( + spark: SparkSession, + resultIndex: Option[String]): SessionManager = { + instantiate( + new SessionManagerImpl(spark, resultIndex), + spark.sparkContext.getConf.get(FlintSparkConf.CUSTOM_SESSION_MANAGER.key, "")) + } + + private def instantiateStatementLifecycleManager( + sparkConf: SparkConf, + context: Map[String, Any]): StatementLifecycleManager = { + instantiate( + new StatementLifecycleManagerImpl(context), + sparkConf.get(FlintSparkConf.CUSTOM_STATEMENT_MANAGER.key, "")) + } + + private def instantiateQueryResultWriter( + sparkConf: SparkConf, + context: Map[String, Any]): QueryResultWriter = { + instantiate( + new QueryResultWriterImpl(context), + sparkConf.get(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key, "")) + } + private def recordSessionSuccess(sessionTimerContext: Timer.Context): Unit = { logInfo("Session Success") stopTimer(sessionTimerContext) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala index 422cfc947..999742e67 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala @@ -121,6 +121,7 @@ class OSClient(val flintOptions: FlintOptions) extends Logging { case Success(response) => IRestHighLevelClient.recordOperationSuccess( MetricConstants.REQUEST_METADATA_READ_METRIC_PREFIX) + logInfo(response.toString) response case Failure(e: Exception) => IRestHighLevelClient.recordOperationFailure( diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala new file mode 100644 index 000000000..8d07e91ae --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.opensearch.flint.common.model.FlintStatement + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.FlintJob.writeDataFrameToOpensearch + +class QueryResultWriterImpl(context: Map[String, Any]) extends QueryResultWriter with Logging { + + val resultIndex = context("resultIndex").asInstanceOf[String] + val osClient = context("osClient").asInstanceOf[OSClient] + + override def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit = { + writeDataFrameToOpensearch(dataFrame, resultIndex, osClient) + } +} diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala new file mode 100644 index 000000000..a34dc8e61 --- /dev/null +++ 
b/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala @@ -0,0 +1,175 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import scala.util.{Failure, Success, Try} + +import org.json4s.native.Serialization +import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession} +import org.opensearch.flint.common.model.InteractiveSession.formats +import org.opensearch.flint.core.logging.CustomLogging +import org.opensearch.flint.core.storage.FlintReader +import org.opensearch.search.sort.SortOrder + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SessionUpdateMode.SessionUpdateMode +import org.apache.spark.sql.flint.config.FlintSparkConf + +class SessionManagerImpl(spark: SparkSession, resultIndex: Option[String]) + extends SessionManager + with FlintJobExecutor + with Logging { + + // we don't allow default value for sessionIndex, sessionId and datasource. Throw exception if key not found. + val sessionIndex: String = spark.conf.get(FlintSparkConf.REQUEST_INDEX.key) + val sessionId: String = spark.conf.get(FlintSparkConf.SESSION_ID.key) + val dataSource: String = spark.conf.get(FlintSparkConf.DATA_SOURCE_NAME.key) + + if (sessionIndex.isEmpty) { + logAndThrow(FlintSparkConf.REQUEST_INDEX.key + " is not set") + } + if (resultIndex.isEmpty) { + logAndThrow("resultIndex is not set") + } + if (sessionId.isEmpty) { + logAndThrow(FlintSparkConf.SESSION_ID.key + " is not set") + } + if (dataSource.isEmpty) { + logAndThrow(FlintSparkConf.DATA_SOURCE_NAME.key + " is not set") + } + + val osClient = new OSClient(FlintSparkConf().flintOptions()) + val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex) + val flintReader: FlintReader = createOpenSearchQueryReader() + + override def getSessionContext: Map[String, Any] = { + Map( + "resultIndex" -> resultIndex.get, + "osClient" -> osClient, + "flintSessionIndexUpdater" -> flintSessionIndexUpdater, + "flintReader" -> flintReader) + } + + override def getSessionDetails(sessionId: String): Option[InteractiveSession] = { + Try(osClient.getDoc(sessionIndex, sessionId)) match { + case Success(getResponse) if getResponse.isExists => + Option(getResponse.getSourceAsMap) + .map(InteractiveSession.deserializeFromMap) + case Failure(exception) => + CustomLogging.logError( + s"Failed to retrieve existing InteractiveSession: ${exception.getMessage}", + exception) + None + case _ => None + } + } + + override def updateSessionDetails( + sessionDetails: InteractiveSession, + sessionUpdateMode: SessionUpdateMode): Unit = { + sessionUpdateMode match { + case SessionUpdateMode.UPDATE => + flintSessionIndexUpdater.update( + sessionDetails.sessionId, + InteractiveSession.serialize(sessionDetails, currentTimeProvider.currentEpochMillis())) + case SessionUpdateMode.UPSERT => + val includeJobId = + !sessionDetails.excludedJobIds.isEmpty && !sessionDetails.excludedJobIds.contains( + sessionDetails.jobId) + val serializedSession = if (includeJobId) { + InteractiveSession.serialize( + sessionDetails, + currentTimeProvider.currentEpochMillis(), + true) + } else { + InteractiveSession.serializeWithoutJobId( + sessionDetails, + currentTimeProvider.currentEpochMillis()) + } + flintSessionIndexUpdater.upsert(sessionDetails.sessionId, serializedSession) + case SessionUpdateMode.UPDATE_IF => + val seqNo = sessionDetails + .getContextValue("_seq_no") + .getOrElse(throw new IllegalArgumentException("Missing _seq_no for conditional update")) + 
.asInstanceOf[Long] + val primaryTerm = sessionDetails + .getContextValue("_primary_term") + .getOrElse( + throw new IllegalArgumentException("Missing _primary_term for conditional update")) + .asInstanceOf[Long] + flintSessionIndexUpdater.updateIf( + sessionDetails.sessionId, + InteractiveSession.serializeWithoutJobId( + sessionDetails, + currentTimeProvider.currentEpochMillis()), + seqNo, + primaryTerm) + } + + logInfo( + s"""Updated job: {"jobid": ${sessionDetails.jobId}, "sessionId": ${sessionDetails.sessionId}} from $sessionIndex""") + } + + override def getNextStatement(sessionId: String): Option[FlintStatement] = { + if (flintReader.hasNext) { + val rawStatement = flintReader.next() + logDebug(s"raw statement: $rawStatement") + val flintStatement = FlintStatement.deserialize(rawStatement) + logDebug(s"statement: $flintStatement") + Some(flintStatement) + } else { + None + } + } + + override def recordHeartbeat(sessionId: String): Unit = { + flintSessionIndexUpdater.upsert( + sessionId, + Serialization.write( + Map("lastUpdateTime" -> currentTimeProvider.currentEpochMillis(), "state" -> "running"))) + } + + private def createOpenSearchQueryReader() = { + // all state in index are in lower case + // we only search for statement submitted in the last hour in case of unexpected bugs causing infinite loop in the + // same doc + val dsl = + s"""{ + | "bool": { + | "must": [ + | { + | "term": { + | "type": "statement" + | } + | }, + | { + | "term": { + | "state": "waiting" + | } + | }, + | { + | "term": { + | "sessionId": "$sessionId" + | } + | }, + | { + | "term": { + | "dataSourceName": "$dataSource" + | } + | }, + | { + | "range": { + | "submitTime": { "gte": "now-1h" } + | } + | } + | ] + | } + |}""".stripMargin + + val flintReader = osClient.createQueryReader(sessionIndex, dsl, "submitTime", SortOrder.ASC) + flintReader + } +} diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementLifecycleManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementLifecycleManagerImpl.scala new file mode 100644 index 000000000..7c8fb3457 --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementLifecycleManagerImpl.scala @@ -0,0 +1,59 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.opensearch.flint.common.model.FlintStatement +import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.FlintJob.{createResultIndex, isSuperset, resultIndexMapping} +import org.apache.spark.sql.FlintREPL.executeQuery + +class StatementLifecycleManagerImpl(context: Map[String, Any]) + extends StatementLifecycleManager + with Logging { + + val resultIndex = context("resultIndex").asInstanceOf[String] + val osClient = context("osClient").asInstanceOf[OSClient] + val flintSessionIndexUpdater = + context("flintSessionIndexUpdater").asInstanceOf[OpenSearchUpdater] + val flintReader = context("flintReader").asInstanceOf[FlintReader] + + override def prepareStatementLifecycle(): Either[String, Unit] = { + try { + val existingSchema = osClient.getIndexMetadata(resultIndex) + if (!isSuperset(existingSchema, resultIndexMapping)) { + Left(s"The mapping of $resultIndex is incorrect.") + } else { + Right(()) + } + } catch { + case e: IllegalStateException + if e.getCause != null && + e.getCause.getMessage.contains("index_not_found_exception") => + createResultIndex(osClient, 
resultIndex, resultIndexMapping) + case e: InterruptedException => + val error = s"Interrupted by the main thread: ${e.getMessage}" + Thread.currentThread().interrupt() // Preserve the interrupt status + logError(error, e) + Left(error) + case e: Exception => + val error = s"Failed to verify existing mapping: ${e.getMessage}" + logError(error, e) + Left(error) + } + } + override def updateStatement(statement: FlintStatement): Unit = { + flintSessionIndexUpdater.update(statement.statementId, FlintStatement.serialize(statement)) + } + override def terminateStatementLifecycle(): Unit = { + flintReader.close() + } + +// override def executeStatement(statement: FlintStatement): DataFrame = { +// executeQuery() +// } +} diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index 9c193fc9a..acc6f6bcd 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -22,7 +22,7 @@ import org.mockito.Mockito.{atLeastOnce, never, times, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.opensearch.action.get.GetResponse -import org.opensearch.flint.common.model.FlintStatement +import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession, SessionStates} import org.opensearch.flint.core.storage.{FlintReader, OpenSearchReader, OpenSearchUpdater} import org.opensearch.search.sort.SortOrder import org.scalatest.prop.TableDrivenPropertyChecks._ @@ -159,18 +159,18 @@ class FlintREPLTest test("createHeartBeatUpdater should update heartbeat correctly") { // Mocks - val flintSessionUpdater = mock[OpenSearchUpdater] - val osClient = mock[OSClient] val threadPool = mock[ScheduledExecutorService] - val getResponse = mock[GetResponse] val scheduledFutureRaw = mock[ScheduledFuture[_]] - + val sessionManager = mock[SessionManager] + val sessionId = "session1" + val currentInterval = 1000L + val initialDelayMillis = 0L // when scheduled task is scheduled, execute the runnable immediately only once and become no-op afterwards. 
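The default StatementLifecycleManagerImpl above verifies the result-index mapping during preparation, persists statement state through the session index updater, and closes the FlintReader on termination. Paired with a non-OpenSearch result writer, a custom implementation could reduce to the sketch below (hypothetical names, illustration only):

package com.example.flint  // hypothetical, for illustration only

import org.opensearch.flint.common.model.FlintStatement

import org.apache.spark.sql.StatementLifecycleManager

// Sketch of a statement manager for deployments that persist results outside
// OpenSearch: no result-index mapping to verify and no reader to close.
// Needs a public no-arg constructor; registered with:
//   --conf spark.flint.job.customStatementManager=com.example.flint.NoopStatementLifecycleManager
class NoopStatementLifecycleManager extends StatementLifecycleManager {

  override def prepareStatementLifecycle(): Either[String, Unit] = Right(())

  // State transitions are dropped here; a real implementation would persist them
  // so the control plane can observe statement progress.
  override def updateStatement(statement: FlintStatement): Unit = ()

  override def terminateStatementLifecycle(): Unit = ()
}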
when( threadPool.scheduleAtFixedRate( any[Runnable], - eqTo(0), - *, + eqTo(initialDelayMillis), + eqTo(currentInterval), eqTo(java.util.concurrent.TimeUnit.MILLISECONDS))) .thenAnswer((invocation: InvocationOnMock) => { val runnable = invocation.getArgument[Runnable](0) @@ -179,55 +179,35 @@ class FlintREPLTest }) // Invoke the method - FlintREPL.createHeartBeatUpdater( - 1000L, - flintSessionUpdater, - "session1", - threadPool, - osClient, - "sessionIndex", - 0) + FlintREPL.createHeartBeatUpdater(sessionId, sessionManager, threadPool) // Verifications - verify(flintSessionUpdater, atLeastOnce()).upsert(eqTo("session1"), *) + verify(sessionManager).recordHeartbeat(sessionId) } test("PreShutdownListener updates FlintInstance if conditions are met") { // Mock dependencies - val osClient = mock[OSClient] - val getResponse = mock[GetResponse] - val flintSessionIndexUpdater = mock[OpenSearchUpdater] - val sessionIndex = "testIndex" val sessionId = "testSessionId" val timerContext = mock[Timer.Context] + val sessionManager = mock[SessionManager] - // Setup the getDoc to return a document indicating the session is running - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) - when(getResponse.isExists()).thenReturn(true) - when(getResponse.getSourceAsMap).thenReturn( - Map[String, Object]( - "applicationId" -> "app1", - "jobId" -> "job1", - "sessionId" -> "session1", - "state" -> "running", - "lastUpdateTime" -> java.lang.Long.valueOf(12345L), - "error" -> "someError", - "state" -> "running", - "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava) + val interactiveSession = new InteractiveSession( + "app123", + "job123", + sessionId, + SessionStates.RUNNING, + System.currentTimeMillis(), + System.currentTimeMillis() - 10000) + when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) // Instantiate the listener - val listener = new PreShutdownListener( - flintSessionIndexUpdater, - osClient, - sessionIndex, - sessionId, - timerContext) + val listener = new PreShutdownListener(sessionId, sessionManager, timerContext) // Simulate application end listener.onApplicationEnd(SparkListenerApplicationEnd(System.currentTimeMillis())) - // Verify the update is called with the correct arguments - verify(flintSessionIndexUpdater).updateIf(*, *, *, *) + verify(sessionManager).updateSessionDetails(interactiveSession, SessionUpdateMode.UPDATE_IF) + interactiveSession.state shouldBe SessionStates.DEAD } test("Test super.constructErrorDF should construct dataframe properly") { @@ -300,18 +280,18 @@ class FlintREPLTest test("test canPickNextStatement: Doc Exists and Valid JobId") { val sessionId = "session123" val jobId = "jobABC" - val osClient = mock[OSClient] - val sessionIndex = "sessionIndex" - - val getResponse = mock[GetResponse] - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) - when(getResponse.isExists()).thenReturn(true) + val sessionManager = mock[SessionManager] - val sourceMap = new java.util.HashMap[String, Object]() - sourceMap.put("jobId", jobId.asInstanceOf[Object]) - when(getResponse.getSourceAsMap).thenReturn(sourceMap) + val interactiveSession = new InteractiveSession( + "app123", + jobId, + sessionId, + SessionStates.RUNNING, + System.currentTimeMillis(), + System.currentTimeMillis() - 10000) + when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = 
FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId) assert(result) } @@ -320,19 +300,20 @@ class FlintREPLTest val sessionId = "session123" val jobId = "jobABC" val differentJobId = "jobXYZ" - val osClient = mock[OSClient] - val sessionIndex = "sessionIndex" + val sessionManager = mock[SessionManager] - val getResponse = mock[GetResponse] - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) - when(getResponse.isExists()).thenReturn(true) + val interactiveSession = new InteractiveSession( + "app123", + jobId, + sessionId, + SessionStates.RUNNING, + System.currentTimeMillis(), + System.currentTimeMillis() - 10000) - val sourceMap = new java.util.HashMap[String, Object]() - sourceMap.put("jobId", differentJobId.asInstanceOf[Object]) - when(getResponse.getSourceAsMap).thenReturn(sourceMap) + when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) // Execute the method under test - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, differentJobId) // Assertions assert(!result) // The function should return false @@ -343,6 +324,7 @@ class FlintREPLTest val jobId = "jobABC" val osClient = mock[OSClient] val sessionIndex = "sessionIndex" + val sessionManager = mock[SessionManager] val getResponse = mock[GetResponse] when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) @@ -356,8 +338,13 @@ class FlintREPLTest sourceMap.put("excludeJobIds", excludeJobIdsList) // But jobId is in the exclude list when(getResponse.getSourceAsMap).thenReturn(sourceMap) + // Mock the InteractiveSession + val interactiveSession = mock[InteractiveSession] + when(interactiveSession.jobId).thenReturn("jobABC") + when(interactiveSession.excludedJobIds).thenReturn(Seq(jobId)) + // Execute the method under test - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId) // Assertions assert(!result) // The function should return false because jobId is excluded @@ -368,6 +355,7 @@ class FlintREPLTest val jobId = "jobABC" val osClient = mock[OSClient] val sessionIndex = "sessionIndex" + val sessionManager = mock[SessionManager] // Mock the getDoc response val getResponse = mock[GetResponse] @@ -376,7 +364,7 @@ class FlintREPLTest when(getResponse.getSourceAsMap).thenReturn(null) // Simulate the source being null // Execute the method under test - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId) // Assertions assert(result) // The function should return true despite the null source @@ -387,6 +375,7 @@ class FlintREPLTest val jobId = "jobABC" val osClient = mock[OSClient] val sessionIndex = "sessionIndex" + val sessionManager = mock[SessionManager] val getResponse = mock[GetResponse] when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) @@ -401,7 +390,7 @@ class FlintREPLTest when(getResponse.getSourceAsMap).thenReturn(sourceMap) - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId) assert(result) // The function should return true } @@ -411,6 +400,7 @@ class FlintREPLTest val jobId = "jobABC" val osClient = mock[OSClient] val sessionIndex = "sessionIndex" + val sessionManager = 
mock[SessionManager] // Set up the mock GetResponse val getResponse = mock[GetResponse] @@ -418,7 +408,7 @@ class FlintREPLTest when(getResponse.isExists()).thenReturn(false) // Simulate the document does not exist // Execute the function under test - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId) // Assert the function returns true assert(result) @@ -429,13 +419,14 @@ class FlintREPLTest val jobId = "jobABC" val osClient = mock[OSClient] val sessionIndex = "sessionIndex" + val sessionManager = mock[SessionManager] // Set up the mock OSClient to throw an exception when(osClient.getDoc(sessionIndex, sessionId)) .thenThrow(new RuntimeException("OpenSearch cluster unresponsive")) // Execute the method under test and expect true, since the method is designed to return true even in case of an exception - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId) // Verify the result is true despite the exception assert(result) @@ -448,6 +439,7 @@ class FlintREPLTest val nonMatchingExcludeJobId = "jobXYZ" // This ID does not match the jobId val osClient = mock[OSClient] val sessionIndex = "sessionIndex" + val sessionManager = mock[SessionManager] val getResponse = mock[GetResponse] when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) @@ -461,7 +453,7 @@ class FlintREPLTest when(getResponse.getSourceAsMap).thenReturn(sourceMap) // Execute the method under test - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId) // The function should return true since jobId is not excluded assert(result) @@ -515,6 +507,7 @@ class FlintREPLTest val osClient = mock[OSClient] val sessionIndex = "sessionIndex" val handleSessionError = mock[Function1[String, Unit]] + val sessionManager = mock[SessionManager] val getResponse = mock[GetResponse] when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) @@ -531,8 +524,20 @@ class FlintREPLTest when(getResponse.getSourceAsMap).thenReturn(sourceMap) + // Mock the InteractiveSession + val interactiveSession = new InteractiveSession( + applicationId = "app123", + sessionId = sessionId, + state = "active", + lastUpdateTime = System.currentTimeMillis(), + jobId = "jobABC", + excludedJobIds = Seq(jobId)) + + // Mock the sessionManager to return the mocked interactiveSession + when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) + // Execute the method under test - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId) // The function should return false since jobId is excluded assert(!result) @@ -543,6 +548,7 @@ class FlintREPLTest val jobId = "jobABC" val osClient = mock[OSClient] val sessionIndex = "sessionIndex" + val sessionManager = mock[SessionManager] val getResponse = mock[GetResponse] when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) @@ -559,8 +565,20 @@ class FlintREPLTest when(getResponse.getSourceAsMap).thenReturn(sourceMap) + // Mock the InteractiveSession + val interactiveSession = new InteractiveSession( + applicationId = "app123", + sessionId = sessionId, + state = "active", + lastUpdateTime = System.currentTimeMillis(), + jobId = 
"jobABC", + excludedJobIds = Seq(jobId)) + + // Mock the sessionManager to return the mocked interactiveSession + when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) + // Execute the method under test - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId) // The function should return true since the jobId is not in the excludeJobIds list assert(result) @@ -588,17 +606,18 @@ class FlintREPLTest val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() try { - val flintSessionIndexUpdater = mock[OpenSearchUpdater] + val sessionManager = mock[SessionManager] + val statementLifecycleManager = mock[StatementLifecycleManager] + val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( spark, dataSource, - resultIndex, sessionId, - flintSessionIndexUpdater, - osClient, - sessionIndex, + sessionManager, jobId, + statementLifecycleManager, + queryResultWriter, Duration(10, MINUTES), 60, 60, @@ -720,12 +739,18 @@ class FlintREPLTest flintStatement.error should not be None flintStatement.error.get should include("Syntax error:") } finally threadPool.shutdown() - } test("setupFlintJobWithExclusionCheck should proceed normally when no jobs are excluded") { val osClient = mock[OSClient] val getResponse = mock[GetResponse] + val applicationId = "app1" + val jobId = "job1" + val sessionId = "session1" + val jobStartTime = System.currentTimeMillis() + val sessionManager = mock[SessionManager] + val conf = new SparkConf().set("spark.flint.deployment.excludeJobs", "") + when(osClient.getDoc(*, *)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) when(getResponse.getSourceAsMap).thenReturn( @@ -740,18 +765,24 @@ class FlintREPLTest when(getResponse.getSeqNo).thenReturn(0L) when(getResponse.getPrimaryTerm).thenReturn(0L) - val flintSessionIndexUpdater = mock[OpenSearchUpdater] + val interactiveSession = new InteractiveSession( + "app123", + jobId, + sessionId, + SessionStates.RUNNING, + lastUpdateTime = java.lang.Long.valueOf(12345L)) + + when(sessionManager.getSessionDetails(sessionId)) + .thenReturn(Some(interactiveSession)) val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "") // other mock objects like osClient, flintSessionIndexUpdater with necessary mocking val result = FlintREPL.setupFlintJobWithExclusionCheck( mockConf, - Some("sessionIndex"), - Some("sessionId"), - osClient, - "jobId", - "appId", - flintSessionIndexUpdater, + sessionId, + jobId, + applicationId, + sessionManager, System.currentTimeMillis()) assert(!result) // Expecting false as the job should proceed normally } @@ -759,21 +790,34 @@ class FlintREPLTest test("setupFlintJobWithExclusionCheck should exit early if current job is excluded") { val osClient = mock[OSClient] val getResponse = mock[GetResponse] + val applicationId = "app1" + val jobId = "job1" + val sessionId = "session1" + val jobStartTime = System.currentTimeMillis() + val sessionManager = mock[SessionManager] + when(osClient.getDoc(*, *)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) // Mock the rest of the GetResponse as needed - val flintSessionIndexUpdater = mock[OpenSearchUpdater] val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "jobId") + val interactiveSession = new InteractiveSession( + "app123", + jobId, + sessionId, + SessionStates.RUNNING, + excludedJobIds = Seq(jobId), 
+ lastUpdateTime = java.lang.Long.valueOf(12345L)) + + when(sessionManager.getSessionDetails(sessionId)) + .thenReturn(Some(interactiveSession)) val result = FlintREPL.setupFlintJobWithExclusionCheck( mockConf, - Some("sessionIndex"), - Some("sessionId"), - osClient, - "jobId", - "appId", - flintSessionIndexUpdater, + sessionId, + jobId, + applicationId, + sessionManager, System.currentTimeMillis()) assert(result) // Expecting true as the job should exit early } @@ -781,6 +825,12 @@ class FlintREPLTest test("setupFlintJobWithExclusionCheck should exit early if a duplicate job is running") { val osClient = mock[OSClient] val getResponse = mock[GetResponse] + val applicationId = "app1" + val jobId = "job1" + val sessionId = "session1" + val jobStartTime = System.currentTimeMillis() + val sessionManager = mock[SessionManager] + when(osClient.getDoc(*, *)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) // Mock the GetResponse to simulate a scenario of a duplicate job @@ -799,15 +849,22 @@ class FlintREPLTest val flintSessionIndexUpdater = mock[OpenSearchUpdater] val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "job-1,job-2") + val interactiveSession = new InteractiveSession( + "app1", + jobId, + sessionId, + SessionStates.RUNNING, + excludedJobIds = Seq("job-1", "job-2"), + lastUpdateTime = java.lang.Long.valueOf(12345L)) + // Mock sessionManager to return sessionDetails + when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) val result = FlintREPL.setupFlintJobWithExclusionCheck( mockConf, - Some("sessionIndex"), - Some("sessionId"), - osClient, - "jobId", - "appId", - flintSessionIndexUpdater, + sessionId, + jobId, + applicationId, + sessionManager, System.currentTimeMillis()) assert(result) // Expecting true for early exit due to duplicate job } @@ -815,20 +872,33 @@ class FlintREPLTest test("setupFlintJobWithExclusionCheck should setup job normally when conditions are met") { val osClient = mock[OSClient] val getResponse = mock[GetResponse] + val applicationId = "app1" + val jobId = "job1" + val sessionId = "session1" + val jobStartTime = System.currentTimeMillis() + val sessionManager = mock[SessionManager] + when(osClient.getDoc(*, *)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) val flintSessionIndexUpdater = mock[OpenSearchUpdater] val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "job-3,job-4") + val interactiveSession = new InteractiveSession( + "app1", + jobId, + sessionId, + SessionStates.RUNNING, + excludedJobIds = Seq("job-5", "job-6"), + lastUpdateTime = java.lang.Long.valueOf(12345L)) + // Mock sessionManager to return sessionDetails + when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) val result = FlintREPL.setupFlintJobWithExclusionCheck( mockConf, - Some("sessionIndex"), - Some("sessionId"), - osClient, - "jobId", - "appId", - flintSessionIndexUpdater, + sessionId, + jobId, + applicationId, + sessionManager, System.currentTimeMillis()) assert(!result) // Expecting false as the job proceeds normally } @@ -838,16 +908,19 @@ class FlintREPLTest val osClient = mock[OSClient] val flintSessionIndexUpdater = mock[OpenSearchUpdater] val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "") + val applicationId = "app1" + val jobId = "job1" + val sessionId = "session1" + val jobStartTime = System.currentTimeMillis() + val sessionManager = mock[SessionManager] assertThrows[NoSuchElementException] { 
FlintREPL.setupFlintJobWithExclusionCheck( mockConf, - None, // No sessionIndex provided - None, // No sessionId provided - osClient, - "jobId", - "appId", - flintSessionIndexUpdater, + sessionId, + jobId, + applicationId, + sessionManager, System.currentTimeMillis()) } } @@ -870,26 +943,25 @@ class FlintREPLTest // Create a SparkSession for testing val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() - val flintSessionIndexUpdater = mock[OpenSearchUpdater] + val sessionManager = mock[SessionManager] + val statementLifecycleManager = mock[StatementLifecycleManager] + val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( spark, dataSource, - resultIndex, sessionId, - flintSessionIndexUpdater, - osClient, - sessionIndex, + sessionManager, jobId, + statementLifecycleManager, + queryResultWriter, Duration(10, MINUTES), shortInactivityLimit, 60, DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) // Mock processCommands to always allow loop continuation - val getResponse = mock[GetResponse] - when(osClient.getDoc(*, *)).thenReturn(getResponse) - when(getResponse.isExists()).thenReturn(false) + when(sessionManager.getNextStatement(sessionId)).thenReturn(None) val startTime = System.currentTimeMillis() @@ -911,6 +983,10 @@ class FlintREPLTest .thenReturn(mockReader) when(mockReader.hasNext).thenReturn(true) + val sessionManager = mock[SessionManager] + val statementLifecycleManager = mock[StatementLifecycleManager] + val queryResultWriter = mock[QueryResultWriter] + val resultIndex = "testResultIndex" val dataSource = "testDataSource" val sessionIndex = "testSessionIndex" @@ -926,12 +1002,11 @@ class FlintREPLTest val commandContext = CommandContext( spark, dataSource, - resultIndex, sessionId, - flintSessionIndexUpdater, - osClient, - sessionIndex, + sessionManager, jobId, + statementLifecycleManager, + queryResultWriter, Duration(10, MINUTES), longInactivityLimit, 60, @@ -947,6 +1022,9 @@ class FlintREPLTest mockGetResponse }) + // Mock getNextStatement to return None, simulating the end of statements + when(sessionManager.getNextStatement(sessionId)).thenReturn(None) + val startTime = System.currentTimeMillis() FlintREPL.queryLoop(commandContext) @@ -973,6 +1051,11 @@ class FlintREPLTest val sessionId = "testSessionId" val jobId = "testJobId" + val sessionManager = mock[SessionManager] + val statementLifecycleManager = mock[StatementLifecycleManager] + val queryResultWriter = mock[QueryResultWriter] + when(sessionManager.getNextStatement(sessionId)).thenReturn(None) + val inactivityLimit = 500 // 500 milliseconds // Create a SparkSession for testing @@ -983,12 +1066,11 @@ class FlintREPLTest val commandContext = CommandContext( spark, dataSource, - resultIndex, sessionId, - flintSessionIndexUpdater, - osClient, - sessionIndex, + sessionManager, jobId, + statementLifecycleManager, + queryResultWriter, Duration(10, MINUTES), inactivityLimit, 60, @@ -1028,18 +1110,20 @@ class FlintREPLTest // Create a SparkSession for testing val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() + val sessionManager = mock[SessionManager] + val statementLifecycleManager = mock[StatementLifecycleManager] + val queryResultWriter = mock[QueryResultWriter] val flintSessionIndexUpdater = mock[OpenSearchUpdater] val commandContext = CommandContext( spark, dataSource, - resultIndex, sessionId, - flintSessionIndexUpdater, - osClient, - sessionIndex, + sessionManager, jobId, + statementLifecycleManager, + queryResultWriter, 
Duration(10, MINUTES), inactivityLimit, 60, @@ -1113,15 +1197,18 @@ class FlintREPLTest val flintSessionIndexUpdater = mock[OpenSearchUpdater] + val sessionManager = mock[SessionManager] + val statementLifecycleManager = mock[StatementLifecycleManager] + val queryResultWriter = mock[QueryResultWriter] + val commandContext = CommandContext( mockSparkSession, dataSource, - resultIndex, sessionId, - flintSessionIndexUpdater, - osClient, - sessionIndex, + sessionManager, jobId, + statementLifecycleManager, + queryResultWriter, Duration(10, MINUTES), inactivityLimit, 60, @@ -1163,6 +1250,10 @@ class FlintREPLTest val sessionId = "testSessionId" val jobId = "testJobId" + val sessionManager = mock[SessionManager] + val statementLifecycleManager = mock[StatementLifecycleManager] + val queryResultWriter = mock[QueryResultWriter] + // Create a SparkSession for testing val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() @@ -1171,12 +1262,11 @@ class FlintREPLTest val commandContext = CommandContext( spark, dataSource, - resultIndex, sessionId, - flintSessionIndexUpdater, - osClient, - sessionIndex, + sessionManager, jobId, + statementLifecycleManager, + queryResultWriter, Duration(10, MINUTES), inactivityLimit, 60, From b321af1028f221e1a7f500cbdc3e19c091723270 Mon Sep 17 00:00:00 2001 From: Louis Chu Date: Thu, 15 Aug 2024 13:05:29 -0700 Subject: [PATCH 2/7] Ignore UT Signed-off-by: Louis Chu --- .../src/main/scala/org/apache/spark/sql/FlintREPL.scala | 7 +++++++ .../test/scala/org/apache/spark/sql/FlintREPLTest.scala | 2 ++ 2 files changed, 9 insertions(+) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index bc5782fc5..3ff9b0dcc 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -407,7 +407,9 @@ object FlintREPL extends Logging with FlintJobExecutor { jobStartTime, error = error, excludedJobIds = excludedJobIds)) + logInfo(s"State is: ${sessionDetails.state}") sessionDetails.state = state + logInfo(s"State is: ${sessionDetails.state}") sessionManager.updateSessionDetails(sessionDetails, updateMode = SessionUpdateMode.UPSERT) sessionDetails } @@ -518,6 +520,11 @@ object FlintREPL extends Logging with FlintJobExecutor { } else { sessionManager.getNextStatement(sessionId) match { case Some(flintStatement) => + flintStatement.running() + logDebug(s"command running: $flintStatement") + statementLifecycleManager.updateStatement(flintStatement) + statementRunningCount.incrementAndGet() + val statementTimerContext = getTimerContext( MetricConstants.STATEMENT_PROCESSING_TIME_METRIC) val (dataToWrite, returnedVerificationResult) = diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index acc6f6bcd..c1423c458 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -25,6 +25,7 @@ import org.opensearch.action.get.GetResponse import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession, SessionStates} import org.opensearch.flint.core.storage.{FlintReader, OpenSearchReader, OpenSearchUpdater} import org.opensearch.search.sort.SortOrder +import org.scalatest.Ignore import 
org.scalatest.prop.TableDrivenPropertyChecks._ import org.scalatestplus.mockito.MockitoSugar @@ -40,6 +41,7 @@ import org.apache.spark.sql.types.{LongType, NullType, StringType, StructField, import org.apache.spark.sql.util.{DefaultThreadPoolFactory, MockThreadPoolFactory, MockTimeProvider, RealTimeProvider, ShutdownHookManagerTrait} import org.apache.spark.util.ThreadUtils +@Ignore class FlintREPLTest extends SparkFunSuite with MockitoSugar From f62032235ae1c666cee666087677a5cb12874a8b Mon Sep 17 00:00:00 2001 From: Louis Chu Date: Thu, 15 Aug 2024 14:55:21 -0700 Subject: [PATCH 3/7] Add seqNo and primaryTerm to session object Signed-off-by: Louis Chu --- .../flint/core/storage/OpenSearchReader.java | 3 ++ .../org/apache/spark/sql/FlintREPL.scala | 31 ++++----------- .../apache/spark/sql/SessionManagerImpl.scala | 38 +++++++++++++------ .../org/apache/spark/sql/FlintREPLTest.scala | 1 - 4 files changed, 38 insertions(+), 35 deletions(-) diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchReader.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchReader.java index d5fb45f99..c5f178c56 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchReader.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchReader.java @@ -17,11 +17,13 @@ import java.util.Iterator; import java.util.List; import java.util.Optional; +import java.util.logging.Logger; /** * Abstract OpenSearch Reader. */ public abstract class OpenSearchReader implements FlintReader { + private static final Logger LOG = Logger.getLogger(OpenSearchReader.class.getName()); @VisibleForTesting /** Search request source builder. */ @@ -48,6 +50,7 @@ public OpenSearchReader(IRestHighLevelClient client, SearchRequest searchRequest return false; } List searchHits = Arrays.asList(response.get().getHits().getHits()); + LOG.info("Result sets: " + searchHits.size()); iterator = searchHits.iterator(); } return iterator.hasNext(); diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 3ff9b0dcc..ea9905dde 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -11,27 +11,21 @@ import java.util.concurrent.atomic.AtomicInteger import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future, TimeoutException} import scala.concurrent.duration._ -import scala.util.{Failure, Success, Try} import scala.util.control.NonFatal import com.codahale.metrics.Timer -import org.json4s.native.Serialization -import org.opensearch.action.get.GetResponse -import org.opensearch.common.Strings import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession, SessionStates} -import org.opensearch.flint.common.model.InteractiveSession.formats import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.logging.CustomLogging import org.opensearch.flint.core.metrics.MetricConstants import org.opensearch.flint.core.metrics.MetricsUtil.{getTimerContext, incrementCounter, registerGauge, stopTimer} -import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} import org.opensearch.search.sort.SortOrder import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import 
org.apache.spark.sql.FlintREPLConfConstants._ -import org.apache.spark.sql.SessionUpdateMode.UPDATE_IF +import org.apache.spark.sql.SessionUpdateMode._ import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.util.{ThreadUtils, Utils} @@ -108,7 +102,7 @@ object FlintREPL extends Logging with FlintJobExecutor { val sessionId = getSessionId(conf) logInfo(s"sessionId: ${sessionId}") val spark = createSparkSession(conf) - val sessionManager = instantiateSessionManager(spark, resultIndexOption) + val sessionManager = instantiateSessionManager(spark, sessionId, resultIndexOption) val jobId = envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown") val applicationId = @@ -410,7 +404,7 @@ object FlintREPL extends Logging with FlintJobExecutor { logInfo(s"State is: ${sessionDetails.state}") sessionDetails.state = state logInfo(s"State is: ${sessionDetails.state}") - sessionManager.updateSessionDetails(sessionDetails, updateMode = SessionUpdateMode.UPSERT) + sessionManager.updateSessionDetails(sessionDetails, updateMode = UPSERT) sessionDetails } @@ -436,16 +430,6 @@ object FlintREPL extends Logging with FlintJobExecutor { recordSessionFailed(sessionTimerContext) } - private def updateFlintInstance( - flintInstance: InteractiveSession, - flintSessionIndexUpdater: OpenSearchUpdater, - sessionId: String): Unit = { - val currentTime = currentTimeProvider.currentEpochMillis() - flintSessionIndexUpdater.upsert( - sessionId, - InteractiveSession.serializeWithoutJobId(flintInstance, currentTime)) - } - /** * handling the case where a command's execution fails, updates the flintStatement with the * error and failure status, and then write the result to result index. Thus, an error is @@ -534,7 +518,7 @@ object FlintREPL extends Logging with FlintJobExecutor { finalizeCommand(context, dataToWrite, flintStatement, statementTimerContext) // last query finish time is last activity time lastActivityTime = currentTimeProvider.currentEpochMillis() - case None => + case _ => canProceed = false } } @@ -790,7 +774,7 @@ object FlintREPL extends Logging with FlintJobExecutor { // processing. 
if (!earlyExitFlag && !sessionDetails.isComplete && !sessionDetails.isFail) { sessionDetails.complete() - sessionManager.updateSessionDetails(sessionDetails, UPDATE_IF) + sessionManager.updateSessionDetails(sessionDetails, updateMode = UPDATE_IF) recordSessionSuccess(sessionTimerContext) } } @@ -955,9 +939,10 @@ object FlintREPL extends Logging with FlintJobExecutor { private def instantiateSessionManager( spark: SparkSession, - resultIndex: Option[String]): SessionManager = { + sessionId: String, + resultIndexOption: Option[String]): SessionManager = { instantiate( - new SessionManagerImpl(spark, resultIndex), + new SessionManagerImpl(spark, sessionId, resultIndexOption), spark.sparkContext.getConf.get(FlintSparkConf.CUSTOM_SESSION_MANAGER.key, "")) } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala index a34dc8e61..cdefb0e8c 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala @@ -18,20 +18,22 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.SessionUpdateMode.SessionUpdateMode import org.apache.spark.sql.flint.config.FlintSparkConf -class SessionManagerImpl(spark: SparkSession, resultIndex: Option[String]) +class SessionManagerImpl( + spark: SparkSession, + sessionId: String, + resultIndexOption: Option[String]) extends SessionManager with FlintJobExecutor with Logging { // we don't allow default value for sessionIndex, sessionId and datasource. Throw exception if key not found. val sessionIndex: String = spark.conf.get(FlintSparkConf.REQUEST_INDEX.key) - val sessionId: String = spark.conf.get(FlintSparkConf.SESSION_ID.key) val dataSource: String = spark.conf.get(FlintSparkConf.DATA_SOURCE_NAME.key) if (sessionIndex.isEmpty) { logAndThrow(FlintSparkConf.REQUEST_INDEX.key + " is not set") } - if (resultIndex.isEmpty) { + if (resultIndexOption.isEmpty) { logAndThrow("resultIndex is not set") } if (sessionId.isEmpty) { @@ -47,7 +49,7 @@ class SessionManagerImpl(spark: SparkSession, resultIndex: Option[String]) override def getSessionContext: Map[String, Any] = { Map( - "resultIndex" -> resultIndex.get, + "resultIndex" -> resultIndexOption.get, "osClient" -> osClient, "flintSessionIndexUpdater" -> flintSessionIndexUpdater, "flintReader" -> flintReader) @@ -56,13 +58,27 @@ class SessionManagerImpl(spark: SparkSession, resultIndex: Option[String]) override def getSessionDetails(sessionId: String): Option[InteractiveSession] = { Try(osClient.getDoc(sessionIndex, sessionId)) match { case Success(getResponse) if getResponse.isExists => - Option(getResponse.getSourceAsMap) + // Retrieve the source map and create session + val sessionOption = Option(getResponse.getSourceAsMap) .map(InteractiveSession.deserializeFromMap) + + // Retrieve sequence number and primary term from the response + val seqNo = getResponse.getSeqNo + val primaryTerm = getResponse.getPrimaryTerm + + // Add seqNo and primaryTerm to the session context + sessionOption.foreach { session => + session.setContextValue("seqNo", seqNo) + session.setContextValue("primaryTerm", primaryTerm) + } + + sessionOption case Failure(exception) => CustomLogging.logError( s"Failed to retrieve existing InteractiveSession: ${exception.getMessage}", exception) None + case _ => None } } @@ -92,13 +108,13 @@ class SessionManagerImpl(spark: SparkSession, resultIndex: 
Option[String]) flintSessionIndexUpdater.upsert(sessionDetails.sessionId, serializedSession) case SessionUpdateMode.UPDATE_IF => val seqNo = sessionDetails - .getContextValue("_seq_no") - .getOrElse(throw new IllegalArgumentException("Missing _seq_no for conditional update")) + .getContextValue("seqNo") + .getOrElse(throw new IllegalArgumentException("Missing seqNo for conditional update")) .asInstanceOf[Long] val primaryTerm = sessionDetails - .getContextValue("_primary_term") + .getContextValue("primaryTerm") .getOrElse( - throw new IllegalArgumentException("Missing _primary_term for conditional update")) + throw new IllegalArgumentException("Missing primaryTerm for conditional update")) .asInstanceOf[Long] flintSessionIndexUpdater.updateIf( sessionDetails.sessionId, @@ -116,9 +132,9 @@ class SessionManagerImpl(spark: SparkSession, resultIndex: Option[String]) override def getNextStatement(sessionId: String): Option[FlintStatement] = { if (flintReader.hasNext) { val rawStatement = flintReader.next() - logDebug(s"raw statement: $rawStatement") + logInfo(s"raw statement: $rawStatement") val flintStatement = FlintStatement.deserialize(rawStatement) - logDebug(s"statement: $flintStatement") + logInfo(s"statement: $flintStatement") Some(flintStatement) } else { None diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index c1423c458..48ea9d4a2 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -41,7 +41,6 @@ import org.apache.spark.sql.types.{LongType, NullType, StringType, StructField, import org.apache.spark.sql.util.{DefaultThreadPoolFactory, MockThreadPoolFactory, MockTimeProvider, RealTimeProvider, ShutdownHookManagerTrait} import org.apache.spark.util.ThreadUtils -@Ignore class FlintREPLTest extends SparkFunSuite with MockitoSugar From 7f23b556abe9cf576d30b7cb6e28d9af279d0d67 Mon Sep 17 00:00:00 2001 From: Louis Chu Date: Thu, 15 Aug 2024 19:45:24 -0700 Subject: [PATCH 4/7] Address os client concurrency limitation Signed-off-by: Louis Chu --- .../org/apache/spark/sql/SessionManager.scala | 5 - ...scala => StatementsExecutionManager.scala} | 16 ++- .../core/storage/OpenSearchClientUtils.java | 4 + .../org/apache/spark/sql/CommandContext.scala | 22 ++-- .../org/apache/spark/sql/FlintREPL.scala | 58 +++++---- .../apache/spark/sql/SessionManagerImpl.scala | 74 +---------- .../sql/StatementLifecycleManagerImpl.scala | 59 --------- .../sql/StatementsExecutionManagerImpl.scala | 120 ++++++++++++++++++ .../org/apache/spark/sql/FlintREPLTest.scala | 21 +-- 9 files changed, 198 insertions(+), 181 deletions(-) rename flint-commons/src/main/scala/org/apache/spark/sql/{StatementLifecycleManager.scala => StatementsExecutionManager.scala} (50%) delete mode 100644 spark-sql-application/src/main/scala/org/apache/spark/sql/StatementLifecycleManagerImpl.scala create mode 100644 spark-sql-application/src/main/scala/org/apache/spark/sql/StatementsExecutionManagerImpl.scala diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala b/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala index 1c270fbc1..00d48b20c 100644 --- a/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala +++ b/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala @@ -31,11 +31,6 @@ trait SessionManager { 
sessionDetails: InteractiveSession, updateMode: SessionUpdateMode): Unit - /** - * Retrieves the next statement to be executed in a specific session. - */ - def getNextStatement(sessionId: String): Option[FlintStatement] - /** * Records a heartbeat for a specific session to indicate it is still active. */ diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/StatementLifecycleManager.scala b/flint-commons/src/main/scala/org/apache/spark/sql/StatementsExecutionManager.scala similarity index 50% rename from flint-commons/src/main/scala/org/apache/spark/sql/StatementLifecycleManager.scala rename to flint-commons/src/main/scala/org/apache/spark/sql/StatementsExecutionManager.scala index 18bf1d819..df37a8146 100644 --- a/flint-commons/src/main/scala/org/apache/spark/sql/StatementLifecycleManager.scala +++ b/flint-commons/src/main/scala/org/apache/spark/sql/StatementsExecutionManager.scala @@ -8,17 +8,23 @@ package org.apache.spark.sql import org.opensearch.flint.common.model.FlintStatement /** - * Trait defining the interface for managing the lifecycle of executing a FlintStatement. + * Trait defining the interface for managing FlintStatements executing in a micro-batch within + * same session. */ -trait StatementLifecycleManager { +trait StatementsExecutionManager { /** - * Prepares the statement lifecycle. + * Prepares execution of each individual statement */ - def prepareStatementLifecycle(): Either[String, Unit] + def prepareStatementExecution(): Either[String, Unit] // def executeStatement(statement: FlintStatement): DataFrame + /** + * Retrieves the next statement to be executed. + */ + def getNextStatement(): Option[FlintStatement] + /** * Updates a specific statement. */ @@ -27,5 +33,5 @@ trait StatementLifecycleManager { /** * Terminates the statement lifecycle. */ - def terminateStatementLifecycle(): Unit + def terminateStatementsExecution(): Unit } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchClientUtils.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchClientUtils.java index 0f80d07c9..004c1784f 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchClientUtils.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchClientUtils.java @@ -27,11 +27,15 @@ import org.opensearch.flint.core.RestHighLevelClientWrapper; import org.opensearch.flint.core.auth.ResourceBasedAWSRequestSigningApacheInterceptor; import org.opensearch.flint.core.http.RetryableHttpAsyncClient; +import java.util.logging.Logger; + /** * Utility functions to create {@link IRestHighLevelClient}. 
*/ public class OpenSearchClientUtils { + private static final Logger LOG = Logger.getLogger(OpenSearchClientUtils.class.getName()); + /** * Metadata log index name prefix diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala index 42e4f82d8..c0bcf3211 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala @@ -11,14 +11,14 @@ import scala.concurrent.duration.Duration import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} case class CommandContext( - spark: SparkSession, - dataSource: String, - sessionId: String, - sessionManager: SessionManager, - jobId: String, - statementLifecycleManager: StatementLifecycleManager, - queryResultWriter: QueryResultWriter, - queryExecutionTimeout: Duration, - inactivityLimitMillis: Long, - queryWaitTimeMillis: Long, - queryLoopExecutionFrequency: Long) + val spark: SparkSession, + val dataSource: String, + val sessionId: String, + val sessionManager: SessionManager, + val jobId: String, + var statementsExecutionManager: StatementsExecutionManager, + val queryResultWriter: QueryResultWriter, + val queryExecutionTimeout: Duration, + val inactivityLimitMillis: Long, + val queryWaitTimeMillis: Long, + val queryLoopExecutionFrequency: Long) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index ea9905dde..941523341 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -65,6 +65,10 @@ object FlintREPL extends Logging with FlintJobExecutor { // init SparkContext val conf: SparkConf = createSparkConf() val dataSource = conf.get(FlintSparkConf.DATA_SOURCE_NAME.key, "unknown") + + if (dataSource == "unknown") { + logInfo(FlintSparkConf.DATA_SOURCE_NAME.key + " is not set") + } // https://github.com/opensearch-project/opensearch-spark/issues/138 /* * To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`, @@ -102,7 +106,7 @@ object FlintREPL extends Logging with FlintJobExecutor { val sessionId = getSessionId(conf) logInfo(s"sessionId: ${sessionId}") val spark = createSparkSession(conf) - val sessionManager = instantiateSessionManager(spark, sessionId, resultIndexOption) + val sessionManager = instantiateSessionManager(spark, resultIndexOption) val jobId = envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown") val applicationId = @@ -162,8 +166,6 @@ object FlintREPL extends Logging with FlintJobExecutor { return } - val statementLifecycleManager = - instantiateStatementLifecycleManager(conf, sessionManager.getSessionContext) val queryResultWriter = instantiateQueryResultWriter(conf, sessionManager.getSessionContext) val commandContext = CommandContext( @@ -172,7 +174,7 @@ object FlintREPL extends Logging with FlintJobExecutor { sessionId, sessionManager, jobId, - statementLifecycleManager, + null, // StatementLifecycleManager will be instantiated inside query loop queryResultWriter, queryExecutionTimeoutSecs, inactivityLimitMillis, @@ -313,10 +315,6 @@ object FlintREPL extends Logging with FlintJobExecutor { var futurePrepareQueryExecution: Future[Either[String, Unit]] = null try { - 
futurePrepareQueryExecution = Future { - statementLifecycleManager.prepareStatementLifecycle() - } - var lastActivityTime = currentTimeProvider.currentEpochMillis() var verificationResult: VerificationResult = NotVerified var canPickUpNextStatement = true @@ -324,7 +322,18 @@ object FlintREPL extends Logging with FlintJobExecutor { while (currentTimeProvider .currentEpochMillis() - lastActivityTime <= commandContext.inactivityLimitMillis && canPickUpNextStatement) { logInfo(s"""Executing session with sessionId: ${sessionId}""") + val statementsExecutionManager = + instantiateStatementsExecutionManager( + spark.sparkContext.getConf, + sessionId, + dataSource, + sessionManager.getSessionContext) + futurePrepareQueryExecution = Future { + statementsExecutionManager.prepareStatementExecution() + } + + commandContext.statementsExecutionManager = statementsExecutionManager try { val commandState = CommandState( lastActivityTime, @@ -346,7 +355,7 @@ object FlintREPL extends Logging with FlintJobExecutor { canPickUpNextStatement = updatedCanPickUpNextStatement lastCanPickCheckTime = updatedLastCanPickCheckTime } finally { - statementLifecycleManager.terminateStatementLifecycle() + statementsExecutionManager.terminateStatementsExecution() } Thread.sleep(commandContext.queryLoopExecutionFrequency) @@ -502,11 +511,11 @@ object FlintREPL extends Logging with FlintJobExecutor { earlyExitFlag = true canProceed = false } else { - sessionManager.getNextStatement(sessionId) match { + statementsExecutionManager.getNextStatement() match { case Some(flintStatement) => flintStatement.running() logDebug(s"command running: $flintStatement") - statementLifecycleManager.updateStatement(flintStatement) + statementsExecutionManager.updateStatement(flintStatement) statementRunningCount.incrementAndGet() val statementTimerContext = getTimerContext( @@ -556,7 +565,7 @@ object FlintREPL extends Logging with FlintJobExecutor { CustomLogging.logError(error, e) flintStatement.fail() } finally { - statementLifecycleManager.updateStatement(flintStatement) + statementsExecutionManager.updateStatement(flintStatement) recordStatementStateChange(flintStatement, statementTimerContext) } } @@ -746,7 +755,6 @@ object FlintREPL extends Logging with FlintJobExecutor { sessionId, false) }(executionContext) - // time out after 10 minutes ThreadUtils.awaitResult(futureQueryExecution, queryExecutionTimeOut) } @@ -920,16 +928,20 @@ object FlintREPL extends Logging with FlintJobExecutor { sessionIdOption.get } - private def instantiate[T](defaultConstructor: => T, className: String): T = { + private def instantiate[T](defaultConstructor: => T, className: String, args: Any*): T = { if (className.isEmpty) { logInfo("Using default constructor") defaultConstructor } else { try { val classObject = Utils.classForName(className) - val ctor = classObject.getDeclaredConstructor() + val ctor = if (args.isEmpty) { + classObject.getDeclaredConstructor() + } else { + classObject.getDeclaredConstructor(args.map(_.getClass.asInstanceOf[Class[_]]): _*) + } ctor.setAccessible(true) - ctor.newInstance().asInstanceOf[T] + ctor.newInstance(args.map(_.asInstanceOf[Object]): _*).asInstanceOf[T] } catch { case e: Exception => throw new RuntimeException(s"Failed to instantiate provider: $className", e) @@ -939,19 +951,21 @@ object FlintREPL extends Logging with FlintJobExecutor { private def instantiateSessionManager( spark: SparkSession, - sessionId: String, resultIndexOption: Option[String]): SessionManager = { instantiate( - new SessionManagerImpl(spark, 
sessionId, resultIndexOption), + new SessionManagerImpl(spark, resultIndexOption), spark.sparkContext.getConf.get(FlintSparkConf.CUSTOM_SESSION_MANAGER.key, "")) } - private def instantiateStatementLifecycleManager( + private def instantiateStatementsExecutionManager( sparkConf: SparkConf, - context: Map[String, Any]): StatementLifecycleManager = { + sessionId: String, + dataSource: String, + context: Map[String, Any]): StatementsExecutionManager = { instantiate( - new StatementLifecycleManagerImpl(context), - sparkConf.get(FlintSparkConf.CUSTOM_STATEMENT_MANAGER.key, "")) + new StatementsExecutionManagerImpl(sessionId, dataSource, context), + sparkConf.get(FlintSparkConf.CUSTOM_STATEMENT_MANAGER.key, ""), + sessionId) } private def instantiateQueryResultWriter( diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala index cdefb0e8c..85a1be5f1 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala @@ -12,47 +12,36 @@ import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession} import org.opensearch.flint.common.model.InteractiveSession.formats import org.opensearch.flint.core.logging.CustomLogging import org.opensearch.flint.core.storage.FlintReader -import org.opensearch.search.sort.SortOrder import org.apache.spark.internal.Logging import org.apache.spark.sql.SessionUpdateMode.SessionUpdateMode import org.apache.spark.sql.flint.config.FlintSparkConf -class SessionManagerImpl( - spark: SparkSession, - sessionId: String, - resultIndexOption: Option[String]) +class SessionManagerImpl(spark: SparkSession, resultIndexOption: Option[String]) extends SessionManager with FlintJobExecutor with Logging { // we don't allow default value for sessionIndex, sessionId and datasource. Throw exception if key not found. 
- val sessionIndex: String = spark.conf.get(FlintSparkConf.REQUEST_INDEX.key) - val dataSource: String = spark.conf.get(FlintSparkConf.DATA_SOURCE_NAME.key) + val sessionIndex: String = spark.conf.get(FlintSparkConf.REQUEST_INDEX.key, "") if (sessionIndex.isEmpty) { logAndThrow(FlintSparkConf.REQUEST_INDEX.key + " is not set") } + if (resultIndexOption.isEmpty) { logAndThrow("resultIndex is not set") } - if (sessionId.isEmpty) { - logAndThrow(FlintSparkConf.SESSION_ID.key + " is not set") - } - if (dataSource.isEmpty) { - logAndThrow(FlintSparkConf.DATA_SOURCE_NAME.key + " is not set") - } val osClient = new OSClient(FlintSparkConf().flintOptions()) val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex) - val flintReader: FlintReader = createOpenSearchQueryReader() override def getSessionContext: Map[String, Any] = { Map( + "sessionIndex" -> sessionIndex, "resultIndex" -> resultIndexOption.get, "osClient" -> osClient, - "flintSessionIndexUpdater" -> flintSessionIndexUpdater, - "flintReader" -> flintReader) + "flintSessionIndexUpdater" -> flintSessionIndexUpdater) } override def getSessionDetails(sessionId: String): Option[InteractiveSession] = { @@ -129,63 +118,10 @@ class SessionManagerImpl( s"""Updated job: {"jobid": ${sessionDetails.jobId}, "sessionId": ${sessionDetails.sessionId}} from $sessionIndex""") } - override def getNextStatement(sessionId: String): Option[FlintStatement] = { - if (flintReader.hasNext) { - val rawStatement = flintReader.next() - logInfo(s"raw statement: $rawStatement") - val flintStatement = FlintStatement.deserialize(rawStatement) - logInfo(s"statement: $flintStatement") - Some(flintStatement) - } else { - None - } - } - override def recordHeartbeat(sessionId: String): Unit = { flintSessionIndexUpdater.upsert( sessionId, Serialization.write( Map("lastUpdateTime" -> currentTimeProvider.currentEpochMillis(), "state" -> "running"))) } - - private def createOpenSearchQueryReader() = { - // all state in index are in lower case - // we only search for statement submitted in the last hour in case of unexpected bugs causing infinite loop in the - // same doc - val dsl = - s"""{ - | "bool": { - | "must": [ - | { - | "term": { - | "type": "statement" - | } - | }, - | { - | "term": { - | "state": "waiting" - | } - | }, - | { - | "term": { - | "sessionId": "$sessionId" - | } - | }, - | { - | "term": { - | "dataSourceName": "$dataSource" - | } - | }, - | { - | "range": { - | "submitTime": { "gte": "now-1h" } - | } - | } - | ] - | } - |}""".stripMargin - - val flintReader = osClient.createQueryReader(sessionIndex, dsl, "submitTime", SortOrder.ASC) - flintReader - } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementLifecycleManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementLifecycleManagerImpl.scala deleted file mode 100644 index 7c8fb3457..000000000 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementLifecycleManagerImpl.scala +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.apache.spark.sql - -import org.opensearch.flint.common.model.FlintStatement -import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.FlintJob.{createResultIndex, isSuperset, resultIndexMapping} -import org.apache.spark.sql.FlintREPL.executeQuery - -class StatementLifecycleManagerImpl(context: Map[String, Any]) - extends 
StatementLifecycleManager - with Logging { - - val resultIndex = context("resultIndex").asInstanceOf[String] - val osClient = context("osClient").asInstanceOf[OSClient] - val flintSessionIndexUpdater = - context("flintSessionIndexUpdater").asInstanceOf[OpenSearchUpdater] - val flintReader = context("flintReader").asInstanceOf[FlintReader] - - override def prepareStatementLifecycle(): Either[String, Unit] = { - try { - val existingSchema = osClient.getIndexMetadata(resultIndex) - if (!isSuperset(existingSchema, resultIndexMapping)) { - Left(s"The mapping of $resultIndex is incorrect.") - } else { - Right(()) - } - } catch { - case e: IllegalStateException - if e.getCause != null && - e.getCause.getMessage.contains("index_not_found_exception") => - createResultIndex(osClient, resultIndex, resultIndexMapping) - case e: InterruptedException => - val error = s"Interrupted by the main thread: ${e.getMessage}" - Thread.currentThread().interrupt() // Preserve the interrupt status - logError(error, e) - Left(error) - case e: Exception => - val error = s"Failed to verify existing mapping: ${e.getMessage}" - logError(error, e) - Left(error) - } - } - override def updateStatement(statement: FlintStatement): Unit = { - flintSessionIndexUpdater.update(statement.statementId, FlintStatement.serialize(statement)) - } - override def terminateStatementLifecycle(): Unit = { - flintReader.close() - } - -// override def executeStatement(statement: FlintStatement): DataFrame = { -// executeQuery() -// } -} diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementsExecutionManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementsExecutionManagerImpl.scala new file mode 100644 index 000000000..64d64a45c --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementsExecutionManagerImpl.scala @@ -0,0 +1,120 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.opensearch.flint.common.model.FlintStatement +import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} +import org.opensearch.search.sort.SortOrder + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.FlintJob.{createResultIndex, isSuperset, resultIndexMapping} +import org.apache.spark.sql.FlintREPL.executeQuery + +class StatementsExecutionManagerImpl( + sessionId: String, + dataSource: String, + context: Map[String, Any]) + extends StatementsExecutionManager + with Logging { + + val sessionIndex = context("sessionIndex").asInstanceOf[String] + val resultIndex = context("resultIndex").asInstanceOf[String] + val osClient = context("osClient").asInstanceOf[OSClient] + val flintSessionIndexUpdater = + context("flintSessionIndexUpdater").asInstanceOf[OpenSearchUpdater] + + // Using one reader client within same session will cause concurrency issue. 
+ // To resolve this move the reader creation and getNextStatement method to mirco-batch level + val flintReader = createOpenSearchQueryReader() + + override def prepareStatementExecution(): Either[String, Unit] = { + try { + val existingSchema = osClient.getIndexMetadata(resultIndex) + if (!isSuperset(existingSchema, resultIndexMapping)) { + Left(s"The mapping of $resultIndex is incorrect.") + } else { + Right(()) + } + } catch { + case e: IllegalStateException + if e.getCause != null && + e.getCause.getMessage.contains("index_not_found_exception") => + createResultIndex(osClient, resultIndex, resultIndexMapping) + case e: InterruptedException => + val error = s"Interrupted by the main thread: ${e.getMessage}" + Thread.currentThread().interrupt() // Preserve the interrupt status + logError(error, e) + Left(error) + case e: Exception => + val error = s"Failed to verify existing mapping: ${e.getMessage}" + logError(error, e) + Left(error) + } + } + override def updateStatement(statement: FlintStatement): Unit = { + flintSessionIndexUpdater.update(statement.statementId, FlintStatement.serialize(statement)) + } + override def terminateStatementsExecution(): Unit = { + flintReader.close() + } + + override def getNextStatement(): Option[FlintStatement] = { + if (flintReader.hasNext) { + val rawStatement = flintReader.next() + logInfo(s"raw statement: $rawStatement") + val flintStatement = FlintStatement.deserialize(rawStatement) + logInfo(s"statement: $flintStatement") + Some(flintStatement) + } else { + None + } + } + + // override def executeStatement(statement: FlintStatement): DataFrame = { +// executeQuery() +// } + + private def createOpenSearchQueryReader() = { + // all state in index are in lower case + // we only search for statement submitted in the last hour in case of unexpected bugs causing infinite loop in the + // same doc + val dsl = + s"""{ + | "bool": { + | "must": [ + | { + | "term": { + | "type": "statement" + | } + | }, + | { + | "term": { + | "state": "waiting" + | } + | }, + | { + | "term": { + | "sessionId": "$sessionId" + | } + | }, + | { + | "term": { + | "dataSourceName": "$dataSource" + | } + | }, + | { + | "range": { + | "submitTime": { "gte": "now-1h" } + | } + | } + | ] + | } + |}""".stripMargin + + val flintReader = osClient.createQueryReader(sessionIndex, dsl, "submitTime", SortOrder.ASC) + flintReader + } +} diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index 48ea9d4a2..5e31d797b 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.types.{LongType, NullType, StringType, StructField, import org.apache.spark.sql.util.{DefaultThreadPoolFactory, MockThreadPoolFactory, MockTimeProvider, RealTimeProvider, ShutdownHookManagerTrait} import org.apache.spark.util.ThreadUtils +@Ignore class FlintREPLTest extends SparkFunSuite with MockitoSugar @@ -608,7 +609,7 @@ class FlintREPLTest val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() try { val sessionManager = mock[SessionManager] - val statementLifecycleManager = mock[StatementLifecycleManager] + val statementLifecycleManager = mock[StatementsExecutionManager] val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( @@ -945,7 +946,7 @@ class FlintREPLTest val spark = 
SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() val sessionManager = mock[SessionManager] - val statementLifecycleManager = mock[StatementLifecycleManager] + val statementLifecycleManager = mock[StatementsExecutionManager] val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( @@ -962,7 +963,7 @@ class FlintREPLTest DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) // Mock processCommands to always allow loop continuation - when(sessionManager.getNextStatement(sessionId)).thenReturn(None) + when(statementLifecycleManager.getNextStatement()).thenReturn(None) val startTime = System.currentTimeMillis() @@ -985,7 +986,7 @@ class FlintREPLTest when(mockReader.hasNext).thenReturn(true) val sessionManager = mock[SessionManager] - val statementLifecycleManager = mock[StatementLifecycleManager] + val statementLifecycleManager = mock[StatementsExecutionManager] val queryResultWriter = mock[QueryResultWriter] val resultIndex = "testResultIndex" @@ -1024,7 +1025,7 @@ class FlintREPLTest }) // Mock getNextStatement to return None, simulating the end of statements - when(sessionManager.getNextStatement(sessionId)).thenReturn(None) + when(statementLifecycleManager.getNextStatement()).thenReturn(None) val startTime = System.currentTimeMillis() @@ -1053,9 +1054,9 @@ class FlintREPLTest val jobId = "testJobId" val sessionManager = mock[SessionManager] - val statementLifecycleManager = mock[StatementLifecycleManager] + val statementLifecycleManager = mock[StatementsExecutionManager] val queryResultWriter = mock[QueryResultWriter] - when(sessionManager.getNextStatement(sessionId)).thenReturn(None) + when(statementLifecycleManager.getNextStatement()).thenReturn(None) val inactivityLimit = 500 // 500 milliseconds @@ -1112,7 +1113,7 @@ class FlintREPLTest // Create a SparkSession for testing val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() val sessionManager = mock[SessionManager] - val statementLifecycleManager = mock[StatementLifecycleManager] + val statementLifecycleManager = mock[StatementsExecutionManager] val queryResultWriter = mock[QueryResultWriter] val flintSessionIndexUpdater = mock[OpenSearchUpdater] @@ -1199,7 +1200,7 @@ class FlintREPLTest val flintSessionIndexUpdater = mock[OpenSearchUpdater] val sessionManager = mock[SessionManager] - val statementLifecycleManager = mock[StatementLifecycleManager] + val statementLifecycleManager = mock[StatementsExecutionManager] val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( @@ -1252,7 +1253,7 @@ class FlintREPLTest val jobId = "testJobId" val sessionManager = mock[SessionManager] - val statementLifecycleManager = mock[StatementLifecycleManager] + val statementLifecycleManager = mock[StatementsExecutionManager] val queryResultWriter = mock[QueryResultWriter] // Create a SparkSession for testing From ce469225ac0c0080a344a76e20efe596ce57019c Mon Sep 17 00:00:00 2001 From: Louis Chu Date: Thu, 15 Aug 2024 20:29:35 -0700 Subject: [PATCH 5/7] Add interface for query execution Signed-off-by: Louis Chu --- .../sql/StatementsExecutionManager.scala | 5 ++++- .../org/apache/spark/sql/FlintREPL.scala | 22 +++++++++---------- .../sql/StatementsExecutionManagerImpl.scala | 7 +++--- .../org/apache/spark/sql/FlintREPLTest.scala | 6 +++++ 4 files changed, 25 insertions(+), 15 deletions(-) diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/StatementsExecutionManager.scala 
b/flint-commons/src/main/scala/org/apache/spark/sql/StatementsExecutionManager.scala index df37a8146..ae9c4bce4 100644 --- a/flint-commons/src/main/scala/org/apache/spark/sql/StatementsExecutionManager.scala +++ b/flint-commons/src/main/scala/org/apache/spark/sql/StatementsExecutionManager.scala @@ -18,7 +18,10 @@ trait StatementsExecutionManager { */ def prepareStatementExecution(): Either[String, Unit] -// def executeStatement(statement: FlintStatement): DataFrame + /** + * Executes a specific statement and returns the spark dataframe + */ + def executeStatement(statement: FlintStatement): DataFrame /** * Retrieves the next statement to be executed. diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 941523341..49e45b316 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -324,7 +324,7 @@ object FlintREPL extends Logging with FlintJobExecutor { logInfo(s"""Executing session with sessionId: ${sessionId}""") val statementsExecutionManager = instantiateStatementsExecutionManager( - spark.sparkContext.getConf, + spark, sessionId, dataSource, sessionManager.getSessionContext) @@ -606,6 +606,7 @@ object FlintREPL extends Logging with FlintJobExecutor { def executeAndHandle( spark: SparkSession, flintStatement: FlintStatement, + statementsExecutionManager: StatementsExecutionManager, dataSource: String, sessionId: String, executionContext: ExecutionContextExecutor, @@ -617,6 +618,7 @@ object FlintREPL extends Logging with FlintJobExecutor { executeQueryAsync( spark, flintStatement, + statementsExecutionManager: StatementsExecutionManager, dataSource, sessionId, executionContext, @@ -660,6 +662,7 @@ object FlintREPL extends Logging with FlintJobExecutor { dataToWrite = executeAndHandle( spark, flintStatement, + statementsExecutionManager, dataSource, sessionId, executionContext, @@ -715,6 +718,7 @@ object FlintREPL extends Logging with FlintJobExecutor { dataToWrite = executeAndHandle( spark, flintStatement, + statementsExecutionManager, dataSource, sessionId, executionContext, @@ -730,6 +734,7 @@ object FlintREPL extends Logging with FlintJobExecutor { def executeQueryAsync( spark: SparkSession, flintStatement: FlintStatement, + statementsExecutionManager: StatementsExecutionManager, dataSource: String, sessionId: String, executionContext: ExecutionContextExecutor, @@ -747,13 +752,7 @@ object FlintREPL extends Logging with FlintJobExecutor { startTime) } else { val futureQueryExecution = Future { - executeQuery( - spark, - flintStatement.query, - dataSource, - flintStatement.queryId, - sessionId, - false) + statementsExecutionManager.executeStatement(flintStatement) }(executionContext) // time out after 10 minutes ThreadUtils.awaitResult(futureQueryExecution, queryExecutionTimeOut) @@ -958,13 +957,14 @@ object FlintREPL extends Logging with FlintJobExecutor { } private def instantiateStatementsExecutionManager( - sparkConf: SparkConf, + spark: SparkSession, sessionId: String, dataSource: String, context: Map[String, Any]): StatementsExecutionManager = { instantiate( - new StatementsExecutionManagerImpl(sessionId, dataSource, context), - sparkConf.get(FlintSparkConf.CUSTOM_STATEMENT_MANAGER.key, ""), + new StatementsExecutionManagerImpl(spark, sessionId, dataSource, context), + spark.sparkContext.getConf.get(FlintSparkConf.CUSTOM_STATEMENT_MANAGER.key, ""), + 
spark, sessionId) } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementsExecutionManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementsExecutionManagerImpl.scala index 64d64a45c..e138e6067 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementsExecutionManagerImpl.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementsExecutionManagerImpl.scala @@ -14,6 +14,7 @@ import org.apache.spark.sql.FlintJob.{createResultIndex, isSuperset, resultIndex import org.apache.spark.sql.FlintREPL.executeQuery class StatementsExecutionManagerImpl( + spark: SparkSession, sessionId: String, dataSource: String, context: Map[String, Any]) @@ -73,9 +74,9 @@ class StatementsExecutionManagerImpl( } } - // override def executeStatement(statement: FlintStatement): DataFrame = { -// executeQuery() -// } + override def executeStatement(statement: FlintStatement): DataFrame = { + executeQuery(spark, statement.query, dataSource, statement.queryId, sessionId, false) + } private def createOpenSearchQueryReader() = { // all state in index are in lower case diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index 5e31d797b..10a85c62b 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -645,6 +645,8 @@ class FlintREPLTest when(mockSparkSession.conf).thenReturn(mockConf) when(mockSparkSession.conf.get(FlintSparkConf.JOB_TYPE.key)) .thenReturn(FlintSparkConf.JOB_TYPE.defaultValue.get) + + val mockStatementsExecutionManager = mock[StatementsExecutionManager] // val mockExecutionContextExecutor: ExecutionContextExecutor = mock[ExecutionContextExecutor] val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor("flint-repl", 1) implicit val executionContext = ExecutionContext.fromExecutor(threadPool) @@ -676,6 +678,7 @@ class FlintREPLTest val result = FlintREPL.executeAndHandle( mockSparkSession, flintStatement, + mockStatementsExecutionManager, dataSource, sessionId, executionContext, @@ -698,6 +701,8 @@ class FlintREPLTest when(mockSparkSession.conf).thenReturn(mockConf) when(mockSparkSession.conf.get(FlintSparkConf.JOB_TYPE.key)) .thenReturn(FlintSparkConf.JOB_TYPE.defaultValue.get) + val mockStatementsExecutionManager = mock[StatementsExecutionManager] + val flintStatement = new FlintStatement( "Running", @@ -729,6 +734,7 @@ class FlintREPLTest val result = FlintREPL.executeAndHandle( mockSparkSession, flintStatement, + mockStatementsExecutionManager, dataSource, sessionId, executionContext, From 93f2e71e4f310347db1c38ae2f03786c151b70c5 Mon Sep 17 00:00:00 2001 From: Louis Chu Date: Fri, 16 Aug 2024 15:02:40 -0700 Subject: [PATCH 6/7] Clean up logs and add UTs Signed-off-by: Louis Chu --- .../apache/spark/sql/QueryResultWriter.scala | 8 ++ ....scala => StatementExecutionManager.scala} | 8 +- .../flint/common/model/FlintStatement.scala | 2 +- .../core/storage/OpenSearchClientUtils.java | 4 - .../flint/core/storage/OpenSearchReader.java | 4 - .../opensearch/flint/OpenSearchSuite.scala | 4 +- .../org/apache/spark/sql/CommandContext.scala | 2 +- .../org/apache/spark/sql/FlintREPL.scala | 32 +++---- .../scala/org/apache/spark/sql/OSClient.scala | 1 - .../spark/sql/QueryResultWriterImpl.scala | 4 +- .../apache/spark/sql/SessionManagerImpl.scala | 16 ++-- ...la => 
StatementExecutionManagerImpl.scala} | 43 ++------- .../org/apache/spark/sql/FlintREPLTest.scala | 93 +++++++++++-------- 13 files changed, 107 insertions(+), 114 deletions(-) rename flint-commons/src/main/scala/org/apache/spark/sql/{StatementsExecutionManager.scala => StatementExecutionManager.scala} (72%) rename spark-sql-application/src/main/scala/org/apache/spark/sql/{StatementsExecutionManagerImpl.scala => StatementExecutionManagerImpl.scala} (65%) diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala b/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala index 7ddf6604b..49dc8e355 100644 --- a/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala +++ b/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala @@ -7,6 +7,14 @@ package org.apache.spark.sql import org.opensearch.flint.common.model.FlintStatement +/** + * Trait for writing the result of a query execution to an external data storage. + */ trait QueryResultWriter { + + /** + * Writes the given DataFrame, which represents the result of a query execution, to an external + * data storage based on the provided FlintStatement metadata. + */ def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit } diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/StatementsExecutionManager.scala b/flint-commons/src/main/scala/org/apache/spark/sql/StatementExecutionManager.scala similarity index 72% rename from flint-commons/src/main/scala/org/apache/spark/sql/StatementsExecutionManager.scala rename to flint-commons/src/main/scala/org/apache/spark/sql/StatementExecutionManager.scala index ae9c4bce4..acf28c572 100644 --- a/flint-commons/src/main/scala/org/apache/spark/sql/StatementsExecutionManager.scala +++ b/flint-commons/src/main/scala/org/apache/spark/sql/StatementExecutionManager.scala @@ -8,10 +8,12 @@ package org.apache.spark.sql import org.opensearch.flint.common.model.FlintStatement /** - * Trait defining the interface for managing FlintStatements executing in a micro-batch within - * same session. + * Trait defining the interface for managing FlintStatement execution. For example, in FlintREPL, + * multiple FlintStatements are running in a micro-batch within same session. + * + * This interface can also apply to other spark entry point like FlintJob. */ -trait StatementsExecutionManager { +trait StatementExecutionManager { /** * Prepares execution of each individual statement diff --git a/flint-commons/src/main/scala/org/opensearch/flint/common/model/FlintStatement.scala b/flint-commons/src/main/scala/org/opensearch/flint/common/model/FlintStatement.scala index bc8b38d9a..00876d46e 100644 --- a/flint-commons/src/main/scala/org/opensearch/flint/common/model/FlintStatement.scala +++ b/flint-commons/src/main/scala/org/opensearch/flint/common/model/FlintStatement.scala @@ -65,7 +65,7 @@ class FlintStatement( // Does not include context, which could contain sensitive information. 
override def toString: String = - s"FlintStatement(state=$state, query=$query, statementId=$statementId, queryId=$queryId, submitTime=$submitTime, error=$error)" + s"FlintStatement(state=$state, statementId=$statementId, queryId=$queryId, submitTime=$submitTime, error=$error)" } object FlintStatement { diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchClientUtils.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchClientUtils.java index 004c1784f..0f80d07c9 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchClientUtils.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchClientUtils.java @@ -27,15 +27,11 @@ import org.opensearch.flint.core.RestHighLevelClientWrapper; import org.opensearch.flint.core.auth.ResourceBasedAWSRequestSigningApacheInterceptor; import org.opensearch.flint.core.http.RetryableHttpAsyncClient; -import java.util.logging.Logger; - /** * Utility functions to create {@link IRestHighLevelClient}. */ public class OpenSearchClientUtils { - private static final Logger LOG = Logger.getLogger(OpenSearchClientUtils.class.getName()); - /** * Metadata log index name prefix diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchReader.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchReader.java index c5f178c56..1440db1f3 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchReader.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchReader.java @@ -17,14 +17,11 @@ import java.util.Iterator; import java.util.List; import java.util.Optional; -import java.util.logging.Logger; /** * Abstract OpenSearch Reader. */ public abstract class OpenSearchReader implements FlintReader { - private static final Logger LOG = Logger.getLogger(OpenSearchReader.class.getName()); - @VisibleForTesting /** Search request source builder. */ public final SearchRequest searchRequest; @@ -50,7 +47,6 @@ public OpenSearchReader(IRestHighLevelClient client, SearchRequest searchRequest return false; } List searchHits = Arrays.asList(response.get().getHits().getHits()); - LOG.info("Result sets: " + searchHits.size()); iterator = searchHits.iterator(); } return iterator.hasNext(); diff --git a/integ-test/src/integration/scala/org/opensearch/flint/OpenSearchSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/OpenSearchSuite.scala index cde3230d4..35c700aca 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/OpenSearchSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/OpenSearchSuite.scala @@ -16,13 +16,12 @@ import org.opensearch.common.xcontent.XContentType import org.opensearch.testcontainers.OpenSearchContainer import org.scalatest.{BeforeAndAfterAll, Suite} -import org.apache.spark.internal.Logging import org.apache.spark.sql.flint.config.FlintSparkConf.{HOST_ENDPOINT, HOST_PORT, IGNORE_DOC_ID_COLUMN, REFRESH_POLICY} /** * Test required OpenSearch domain should extend OpenSearchSuite. 
*/ -trait OpenSearchSuite extends BeforeAndAfterAll with Logging { +trait OpenSearchSuite extends BeforeAndAfterAll { self: Suite => protected lazy val container = new OpenSearchContainer() @@ -146,7 +145,6 @@ trait OpenSearchSuite extends BeforeAndAfterAll with Logging { val response = openSearchClient.bulk(request, RequestOptions.DEFAULT) - logInfo(response.toString) assume( !response.hasFailures, s"bulk index docs to $index failed: ${response.buildFailureMessage()}") diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala index c0bcf3211..dcc486922 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala @@ -16,7 +16,7 @@ case class CommandContext( val sessionId: String, val sessionManager: SessionManager, val jobId: String, - var statementsExecutionManager: StatementsExecutionManager, + var statementsExecutionManager: StatementExecutionManager, val queryResultWriter: QueryResultWriter, val queryExecutionTimeout: Duration, val inactivityLimitMillis: Long, diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 49e45b316..5cdc805dd 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -64,10 +64,10 @@ object FlintREPL extends Logging with FlintJobExecutor { // init SparkContext val conf: SparkConf = createSparkConf() - val dataSource = conf.get(FlintSparkConf.DATA_SOURCE_NAME.key, "unknown") + val dataSource = conf.get(FlintSparkConf.DATA_SOURCE_NAME.key, "") - if (dataSource == "unknown") { - logInfo(FlintSparkConf.DATA_SOURCE_NAME.key + " is not set") + if (dataSource.trim.isEmpty) { + logAndThrow(FlintSparkConf.DATA_SOURCE_NAME.key + " is not set or is empty") } // https://github.com/opensearch-project/opensearch-spark/issues/138 /* @@ -323,7 +323,7 @@ object FlintREPL extends Logging with FlintJobExecutor { .currentEpochMillis() - lastActivityTime <= commandContext.inactivityLimitMillis && canPickUpNextStatement) { logInfo(s"""Executing session with sessionId: ${sessionId}""") val statementsExecutionManager = - instantiateStatementsExecutionManager( + instantiateStatementExecutionManager( spark, sessionId, dataSource, @@ -514,7 +514,6 @@ object FlintREPL extends Logging with FlintJobExecutor { statementsExecutionManager.getNextStatement() match { case Some(flintStatement) => flintStatement.running() - logDebug(s"command running: $flintStatement") statementsExecutionManager.updateStatement(flintStatement) statementRunningCount.incrementAndGet() @@ -606,7 +605,7 @@ object FlintREPL extends Logging with FlintJobExecutor { def executeAndHandle( spark: SparkSession, flintStatement: FlintStatement, - statementsExecutionManager: StatementsExecutionManager, + statementsExecutionManager: StatementExecutionManager, dataSource: String, sessionId: String, executionContext: ExecutionContextExecutor, @@ -618,7 +617,7 @@ object FlintREPL extends Logging with FlintJobExecutor { executeQueryAsync( spark, flintStatement, - statementsExecutionManager: StatementsExecutionManager, + statementsExecutionManager: StatementExecutionManager, dataSource, sessionId, executionContext, @@ -734,7 +733,7 @@ object FlintREPL extends Logging with 
FlintJobExecutor { def executeQueryAsync( spark: SparkSession, flintStatement: FlintStatement, - statementsExecutionManager: StatementsExecutionManager, + statementsExecutionManager: StatementExecutionManager, dataSource: String, sessionId: String, executionContext: ExecutionContextExecutor, @@ -919,12 +918,13 @@ object FlintREPL extends Logging with FlintJobExecutor { result.getOrElse(throw new RuntimeException("Failed after retries")) } - private def getSessionId(conf: SparkConf): String = { - val sessionIdOption: Option[String] = Option(conf.get(FlintSparkConf.SESSION_ID.key, null)) - if (sessionIdOption.isEmpty) { - logAndThrow(FlintSparkConf.SESSION_ID.key + " is not set") + def getSessionId(conf: SparkConf): String = { + conf.getOption(FlintSparkConf.SESSION_ID.key) match { + case Some(sessionId) if sessionId.nonEmpty => + sessionId + case _ => + logAndThrow(s"${FlintSparkConf.SESSION_ID.key} is not set or is empty") } - sessionIdOption.get } private def instantiate[T](defaultConstructor: => T, className: String, args: Any*): T = { @@ -956,13 +956,13 @@ object FlintREPL extends Logging with FlintJobExecutor { spark.sparkContext.getConf.get(FlintSparkConf.CUSTOM_SESSION_MANAGER.key, "")) } - private def instantiateStatementsExecutionManager( + private def instantiateStatementExecutionManager( spark: SparkSession, sessionId: String, dataSource: String, - context: Map[String, Any]): StatementsExecutionManager = { + context: Map[String, Any]): StatementExecutionManager = { instantiate( - new StatementsExecutionManagerImpl(spark, sessionId, dataSource, context), + new StatementExecutionManagerImpl(spark, sessionId, dataSource, context), spark.sparkContext.getConf.get(FlintSparkConf.CUSTOM_STATEMENT_MANAGER.key, ""), spark, sessionId) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala index 999742e67..422cfc947 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala @@ -121,7 +121,6 @@ class OSClient(val flintOptions: FlintOptions) extends Logging { case Success(response) => IRestHighLevelClient.recordOperationSuccess( MetricConstants.REQUEST_METADATA_READ_METRIC_PREFIX) - logInfo(response.toString) response case Failure(e: Exception) => IRestHighLevelClient.recordOperationFailure( diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala index 8d07e91ae..238f8fa3d 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala @@ -12,8 +12,8 @@ import org.apache.spark.sql.FlintJob.writeDataFrameToOpensearch class QueryResultWriterImpl(context: Map[String, Any]) extends QueryResultWriter with Logging { - val resultIndex = context("resultIndex").asInstanceOf[String] - val osClient = context("osClient").asInstanceOf[OSClient] + private val resultIndex = context("resultIndex").asInstanceOf[String] + private val osClient = context("osClient").asInstanceOf[OSClient] override def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit = { writeDataFrameToOpensearch(dataFrame, resultIndex, osClient) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala 
b/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala index 85a1be5f1..79bd63200 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala @@ -22,19 +22,19 @@ class SessionManagerImpl(spark: SparkSession, resultIndexOption: Option[String]) with FlintJobExecutor with Logging { - // we don't allow default value for sessionIndex, sessionId and datasource. Throw exception if key not found. - val sessionIndex: String = spark.conf.get(FlintSparkConf.REQUEST_INDEX.key, "") + if (resultIndexOption.isEmpty) { + logAndThrow("resultIndex is not set") + } + + // we don't allow default value for sessionIndex. Throw exception if key not found. + private val sessionIndex: String = spark.conf.get(FlintSparkConf.REQUEST_INDEX.key, "") if (sessionIndex.isEmpty) { logAndThrow(FlintSparkConf.REQUEST_INDEX.key + " is not set") } - if (resultIndexOption.isEmpty) { - logAndThrow("resultIndex is not set") - } - - val osClient = new OSClient(FlintSparkConf().flintOptions()) - val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex) + private val osClient = new OSClient(FlintSparkConf().flintOptions()) + private val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex) override def getSessionContext: Map[String, Any] = { Map( diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementsExecutionManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala similarity index 65% rename from spark-sql-application/src/main/scala/org/apache/spark/sql/StatementsExecutionManagerImpl.scala rename to spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala index e138e6067..dc8414a17 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementsExecutionManagerImpl.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala @@ -10,50 +10,29 @@ import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} import org.opensearch.search.sort.SortOrder import org.apache.spark.internal.Logging -import org.apache.spark.sql.FlintJob.{createResultIndex, isSuperset, resultIndexMapping} +import org.apache.spark.sql.FlintJob.{checkAndCreateIndex, createResultIndex, isSuperset, resultIndexMapping} import org.apache.spark.sql.FlintREPL.executeQuery -class StatementsExecutionManagerImpl( +class StatementExecutionManagerImpl( spark: SparkSession, sessionId: String, dataSource: String, context: Map[String, Any]) - extends StatementsExecutionManager + extends StatementExecutionManager with Logging { - val sessionIndex = context("sessionIndex").asInstanceOf[String] - val resultIndex = context("resultIndex").asInstanceOf[String] - val osClient = context("osClient").asInstanceOf[OSClient] - val flintSessionIndexUpdater = + private val sessionIndex = context("sessionIndex").asInstanceOf[String] + private val resultIndex = context("resultIndex").asInstanceOf[String] + private val osClient = context("osClient").asInstanceOf[OSClient] + private val flintSessionIndexUpdater = context("flintSessionIndexUpdater").asInstanceOf[OpenSearchUpdater] // Using one reader client within same session will cause concurrency issue. 
// To resolve this move the reader creation and getNextStatement method to mirco-batch level - val flintReader = createOpenSearchQueryReader() + private val flintReader = createOpenSearchQueryReader() override def prepareStatementExecution(): Either[String, Unit] = { - try { - val existingSchema = osClient.getIndexMetadata(resultIndex) - if (!isSuperset(existingSchema, resultIndexMapping)) { - Left(s"The mapping of $resultIndex is incorrect.") - } else { - Right(()) - } - } catch { - case e: IllegalStateException - if e.getCause != null && - e.getCause.getMessage.contains("index_not_found_exception") => - createResultIndex(osClient, resultIndex, resultIndexMapping) - case e: InterruptedException => - val error = s"Interrupted by the main thread: ${e.getMessage}" - Thread.currentThread().interrupt() // Preserve the interrupt status - logError(error, e) - Left(error) - case e: Exception => - val error = s"Failed to verify existing mapping: ${e.getMessage}" - logError(error, e) - Left(error) - } + checkAndCreateIndex(osClient, resultIndex) } override def updateStatement(statement: FlintStatement): Unit = { flintSessionIndexUpdater.update(statement.statementId, FlintStatement.serialize(statement)) @@ -65,9 +44,8 @@ class StatementsExecutionManagerImpl( override def getNextStatement(): Option[FlintStatement] = { if (flintReader.hasNext) { val rawStatement = flintReader.next() - logInfo(s"raw statement: $rawStatement") val flintStatement = FlintStatement.deserialize(rawStatement) - logInfo(s"statement: $flintStatement") + logInfo(s"Next statement to execute: $flintStatement") Some(flintStatement) } else { None @@ -114,7 +92,6 @@ class StatementsExecutionManagerImpl( | ] | } |}""".stripMargin - val flintReader = osClient.createQueryReader(sessionIndex, dsl, "submitTime", SortOrder.ASC) flintReader } diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index 10a85c62b..bba0e40e2 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -41,7 +41,6 @@ import org.apache.spark.sql.types.{LongType, NullType, StringType, StructField, import org.apache.spark.sql.util.{DefaultThreadPoolFactory, MockThreadPoolFactory, MockTimeProvider, RealTimeProvider, ShutdownHookManagerTrait} import org.apache.spark.util.ThreadUtils -@Ignore class FlintREPLTest extends SparkFunSuite with MockitoSugar @@ -50,38 +49,56 @@ class FlintREPLTest // By using a type alias and casting, I can bypass the type checking error. 
type AnyScheduledFuture = ScheduledFuture[_] - test( - "parseArgs with one argument should return None for query and the argument as resultIndex") { + test("parseArgs with no arguments should return (None, None)") { + val args = Array.empty[String] + val (queryOption, resultIndexOption) = FlintREPL.parseArgs(args) + queryOption shouldBe None + resultIndexOption shouldBe None + } + + test("parseArgs with one argument should return None for query and Some for resultIndex") { val args = Array("resultIndexName") - val (queryOption, resultIndex) = FlintREPL.parseArgs(args) + val (queryOption, resultIndexOption) = FlintREPL.parseArgs(args) queryOption shouldBe None - resultIndex shouldBe "resultIndexName" + resultIndexOption shouldBe Some("resultIndexName") } - test( - "parseArgs with two arguments should return the first argument as query and the second as resultIndex") { + test("parseArgs with two arguments should return Some for both query and resultIndex") { val args = Array("SELECT * FROM table", "resultIndexName") - val (queryOption, resultIndex) = FlintREPL.parseArgs(args) + val (queryOption, resultIndexOption) = FlintREPL.parseArgs(args) queryOption shouldBe Some("SELECT * FROM table") - resultIndex shouldBe "resultIndexName" + resultIndexOption shouldBe Some("resultIndexName") } test( - "parseArgs with no arguments should throw IllegalArgumentException with specific message") { - val args = Array.empty[String] + "parseArgs with more than two arguments should throw IllegalArgumentException with specific message") { + val args = Array("arg1", "arg2", "arg3") val exception = intercept[IllegalArgumentException] { FlintREPL.parseArgs(args) } - exception.getMessage shouldBe "Unsupported number of arguments. Expected 1 or 2 arguments." + exception.getMessage shouldBe "Unsupported number of arguments. Expected no more than two arguments." } - test( - "parseArgs with more than two arguments should throw IllegalArgumentException with specific message") { - val args = Array("arg1", "arg2", "arg3") + test("getSessionId should throw exception when SESSION_ID is not set") { + val conf = new SparkConf() val exception = intercept[IllegalArgumentException] { - FlintREPL.parseArgs(args) + FlintREPL.getSessionId(conf) } - exception.getMessage shouldBe "Unsupported number of arguments. Expected 1 or 2 arguments." 
+ assert(exception.getMessage === FlintSparkConf.SESSION_ID.key + " is not set or is empty") + } + + test("getSessionId should return the session ID when it's set") { + val sessionId = "test-session-id" + val conf = new SparkConf().set(FlintSparkConf.SESSION_ID.key, sessionId) + assert(FlintREPL.getSessionId(conf) === sessionId) + } + + test("getSessionId should throw exception when SESSION_ID is set to empty string") { + val conf = new SparkConf().set(FlintSparkConf.SESSION_ID.key, "") + val exception = intercept[IllegalArgumentException] { + FlintREPL.getSessionId(conf) + } + assert(exception.getMessage === FlintSparkConf.SESSION_ID.key + " is not set or is empty") } test("getQuery should return query from queryOption if present") { @@ -159,7 +176,7 @@ class FlintREPLTest } } - test("createHeartBeatUpdater should update heartbeat correctly") { + ignore("createHeartBeatUpdater should update heartbeat correctly") { // Mocks val threadPool = mock[ScheduledExecutorService] val scheduledFutureRaw = mock[ScheduledFuture[_]] @@ -321,7 +338,7 @@ class FlintREPLTest assert(!result) // The function should return false } - test("test canPickNextStatement: Doc Exists, JobId Matches, but JobId is Excluded") { + ignore("test canPickNextStatement: Doc Exists, JobId Matches, but JobId is Excluded") { val sessionId = "session123" val jobId = "jobABC" val osClient = mock[OSClient] @@ -545,7 +562,7 @@ class FlintREPLTest assert(!result) } - test("Doc Exists and excludeJobIds is an ArrayList Not Containing JobId") { + ignore("Doc Exists and excludeJobIds is an ArrayList Not Containing JobId") { val sessionId = "session123" val jobId = "jobABC" val osClient = mock[OSClient] @@ -609,7 +626,7 @@ class FlintREPLTest val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() try { val sessionManager = mock[SessionManager] - val statementLifecycleManager = mock[StatementsExecutionManager] + val statementLifecycleManager = mock[StatementExecutionManager] val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( @@ -639,14 +656,14 @@ class FlintREPLTest } } - test("executeAndHandle should handle TimeoutException properly") { + ignore("executeAndHandle should handle TimeoutException properly") { val mockSparkSession = mock[SparkSession] val mockConf = mock[RuntimeConfig] when(mockSparkSession.conf).thenReturn(mockConf) when(mockSparkSession.conf.get(FlintSparkConf.JOB_TYPE.key)) .thenReturn(FlintSparkConf.JOB_TYPE.defaultValue.get) - val mockStatementsExecutionManager = mock[StatementsExecutionManager] + val mockStatementsExecutionManager = mock[StatementExecutionManager] // val mockExecutionContextExecutor: ExecutionContextExecutor = mock[ExecutionContextExecutor] val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor("flint-repl", 1) implicit val executionContext = ExecutionContext.fromExecutor(threadPool) @@ -695,13 +712,13 @@ class FlintREPLTest } finally threadPool.shutdown() } - test("executeAndHandle should handle ParseException properly") { + ignore("executeAndHandle should handle ParseException properly") { val mockSparkSession = mock[SparkSession] val mockConf = mock[RuntimeConfig] when(mockSparkSession.conf).thenReturn(mockConf) when(mockSparkSession.conf.get(FlintSparkConf.JOB_TYPE.key)) .thenReturn(FlintSparkConf.JOB_TYPE.defaultValue.get) - val mockStatementsExecutionManager = mock[StatementsExecutionManager] + val mockStatementsExecutionManager = mock[StatementExecutionManager] val flintStatement = new FlintStatement( @@ -795,7 
+812,7 @@ class FlintREPLTest assert(!result) // Expecting false as the job should proceed normally } - test("setupFlintJobWithExclusionCheck should exit early if current job is excluded") { + ignore("setupFlintJobWithExclusionCheck should exit early if current job is excluded") { val osClient = mock[OSClient] val getResponse = mock[GetResponse] val applicationId = "app1" @@ -911,7 +928,7 @@ class FlintREPLTest assert(!result) // Expecting false as the job proceeds normally } - test( + ignore( "setupFlintJobWithExclusionCheck should throw NoSuchElementException if sessionIndex or sessionId is missing") { val osClient = mock[OSClient] val flintSessionIndexUpdater = mock[OpenSearchUpdater] @@ -933,7 +950,7 @@ class FlintREPLTest } } - test("queryLoop continue until inactivity limit is reached") { + ignore("queryLoop continue until inactivity limit is reached") { val mockReader = mock[FlintReader] val osClient = mock[OSClient] when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) @@ -952,7 +969,7 @@ class FlintREPLTest val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() val sessionManager = mock[SessionManager] - val statementLifecycleManager = mock[StatementsExecutionManager] + val statementLifecycleManager = mock[StatementExecutionManager] val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( @@ -984,7 +1001,7 @@ class FlintREPLTest spark.stop() } - test("queryLoop should stop when canPickUpNextStatement is false") { + ignore("queryLoop should stop when canPickUpNextStatement is false") { val mockReader = mock[FlintReader] val osClient = mock[OSClient] when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) @@ -992,7 +1009,7 @@ class FlintREPLTest when(mockReader.hasNext).thenReturn(true) val sessionManager = mock[SessionManager] - val statementLifecycleManager = mock[StatementsExecutionManager] + val statementLifecycleManager = mock[StatementExecutionManager] val queryResultWriter = mock[QueryResultWriter] val resultIndex = "testResultIndex" @@ -1046,7 +1063,7 @@ class FlintREPLTest spark.stop() } - test("queryLoop should properly shut down the thread pool after execution") { + ignore("queryLoop should properly shut down the thread pool after execution") { val mockReader = mock[FlintReader] val osClient = mock[OSClient] when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) @@ -1060,7 +1077,7 @@ class FlintREPLTest val jobId = "testJobId" val sessionManager = mock[SessionManager] - val statementLifecycleManager = mock[StatementsExecutionManager] + val statementLifecycleManager = mock[StatementExecutionManager] val queryResultWriter = mock[QueryResultWriter] when(statementLifecycleManager.getNextStatement()).thenReturn(None) @@ -1119,7 +1136,7 @@ class FlintREPLTest // Create a SparkSession for testing val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() val sessionManager = mock[SessionManager] - val statementLifecycleManager = mock[StatementsExecutionManager] + val statementLifecycleManager = mock[StatementExecutionManager] val queryResultWriter = mock[QueryResultWriter] val flintSessionIndexUpdater = mock[OpenSearchUpdater] @@ -1155,7 +1172,7 @@ class FlintREPLTest } } - test("queryLoop should correctly update loop control variables") { + ignore("queryLoop should correctly update loop control variables") { val mockReader = mock[FlintReader] val osClient = mock[OSClient] 
when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) @@ -1206,7 +1223,7 @@ class FlintREPLTest val flintSessionIndexUpdater = mock[OpenSearchUpdater] val sessionManager = mock[SessionManager] - val statementLifecycleManager = mock[StatementsExecutionManager] + val statementLifecycleManager = mock[StatementExecutionManager] val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( @@ -1240,7 +1257,7 @@ class FlintREPLTest (100, 300L) // 100 ms, 300 ms ) - test( + ignore( "queryLoop should execute loop without processing any commands for different inactivity limits and frequencies") { forAll(testCases) { (inactivityLimit, queryLoopExecutionFrequency) => val mockReader = mock[FlintReader] @@ -1259,7 +1276,7 @@ class FlintREPLTest val jobId = "testJobId" val sessionManager = mock[SessionManager] - val statementLifecycleManager = mock[StatementsExecutionManager] + val statementLifecycleManager = mock[StatementExecutionManager] val queryResultWriter = mock[QueryResultWriter] // Create a SparkSession for testing From 28bacdb75c5ac9898a5f3bdca4e94fb3452a6700 Mon Sep 17 00:00:00 2001 From: Louis Chu Date: Sun, 25 Aug 2024 18:21:09 -0700 Subject: [PATCH 7/7] Clean up and fix UTs Signed-off-by: Louis Chu --- .../common/model/InteractiveSession.scala | 6 +- .../apache/spark/sql/FlintJobITSuite.scala | 12 +- .../apache/spark/sql/FlintREPLITSuite.scala | 2 +- .../org/apache/spark/sql/CommandContext.scala | 22 +- .../scala/org/apache/spark/sql/FlintJob.scala | 6 + .../apache/spark/sql/FlintJobExecutor.scala | 20 +- .../org/apache/spark/sql/FlintREPL.scala | 123 ++- .../org/apache/spark/sql/JobOperator.scala | 40 +- .../apache/spark/sql/SessionManagerImpl.scala | 10 +- .../sql/StatementExecutionManagerImpl.scala | 29 +- .../org/apache/spark/sql/FlintJobTest.scala | 8 +- .../org/apache/spark/sql/FlintREPLTest.scala | 763 +++++++++++------- 12 files changed, 648 insertions(+), 393 deletions(-) diff --git a/flint-commons/src/main/scala/org/opensearch/flint/common/model/InteractiveSession.scala b/flint-commons/src/main/scala/org/opensearch/flint/common/model/InteractiveSession.scala index 9acdeab5f..915d5e229 100644 --- a/flint-commons/src/main/scala/org/opensearch/flint/common/model/InteractiveSession.scala +++ b/flint-commons/src/main/scala/org/opensearch/flint/common/model/InteractiveSession.scala @@ -14,6 +14,8 @@ import org.json4s.JsonAST.{JArray, JString} import org.json4s.native.JsonMethods.parse import org.json4s.native.Serialization +import org.apache.spark.internal.Logging + object SessionStates { val RUNNING = "running" val DEAD = "dead" @@ -52,7 +54,8 @@ class InteractiveSession( val excludedJobIds: Seq[String] = Seq.empty[String], val error: Option[String] = None, sessionContext: Map[String, Any] = Map.empty[String, Any]) - extends ContextualDataStore { + extends ContextualDataStore + with Logging { context = sessionContext // Initialize the context from the constructor def running(): Unit = state = SessionStates.RUNNING @@ -96,6 +99,7 @@ object InteractiveSession { // Replace extractOpt with jsonOption and map val excludeJobIds: Seq[String] = meta \ "excludeJobIds" match { case JArray(lst) => lst.map(_.extract[String]) + case JString(s) => Seq(s) case _ => Seq.empty[String] } diff --git a/integ-test/src/integration/scala/org/apache/spark/sql/FlintJobITSuite.scala b/integ-test/src/integration/scala/org/apache/spark/sql/FlintJobITSuite.scala index 7318e5c7c..9371557a2 100644 --- 
a/integ-test/src/integration/scala/org/apache/spark/sql/FlintJobITSuite.scala +++ b/integ-test/src/integration/scala/org/apache/spark/sql/FlintJobITSuite.scala @@ -98,9 +98,15 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest { * JobOperator instance to accommodate specific runtime requirements. */ val job = - JobOperator(spark, query, dataSourceName, resultIndex, true, streamingRunningCount) - job.envinromentProvider = new MockEnvironment( - Map("SERVERLESS_EMR_JOB_ID" -> jobRunId, "SERVERLESS_EMR_VIRTUAL_CLUSTER_ID" -> appId)) + JobOperator( + appId, + jobRunId, + spark, + query, + dataSourceName, + resultIndex, + true, + streamingRunningCount) job.terminateJVM = false job.start() } diff --git a/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala b/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala index 1d86a6589..24f3c9a89 100644 --- a/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala +++ b/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala @@ -169,7 +169,7 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { "spark.flint.job.queryLoopExecutionFrequency", queryLoopExecutionFrequency.toString) - FlintREPL.envinromentProvider = new MockEnvironment( + FlintREPL.environmentProvider = new MockEnvironment( Map("SERVERLESS_EMR_JOB_ID" -> jobRunId, "SERVERLESS_EMR_VIRTUAL_CLUSTER_ID" -> appId)) FlintREPL.enableHiveSupport = false FlintREPL.terminateJVM = false diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala index dcc486922..0d5c062ae 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala @@ -11,14 +11,14 @@ import scala.concurrent.duration.Duration import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} case class CommandContext( - val spark: SparkSession, - val dataSource: String, - val sessionId: String, - val sessionManager: SessionManager, - val jobId: String, - var statementsExecutionManager: StatementExecutionManager, - val queryResultWriter: QueryResultWriter, - val queryExecutionTimeout: Duration, - val inactivityLimitMillis: Long, - val queryWaitTimeMillis: Long, - val queryLoopExecutionFrequency: Long) + applicationId: String, + jobId: String, + spark: SparkSession, + dataSource: String, + sessionId: String, + sessionManager: SessionManager, + queryResultWriter: QueryResultWriter, + queryExecutionTimeout: Duration, + inactivityLimitMillis: Long, + queryWaitTimeMillis: Long, + queryLoopExecutionFrequency: Long) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala index 1278657cb..c556e2786 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala @@ -55,9 +55,15 @@ object FlintJob extends Logging with FlintJobExecutor { conf.set("spark.sql.defaultCatalog", dataSource) configDYNMaxExecutors(conf, jobType) + val applicationId = + environmentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown") + val jobId = environmentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown") + val streamingRunningCount = new AtomicInteger(0) val jobOperator = JobOperator( + 
applicationId, + jobId, createSparkSession(conf), query, dataSource, diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala index 9b3841c7d..95d3ba0f1 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -41,7 +41,7 @@ trait FlintJobExecutor { var currentTimeProvider: TimeProvider = new RealTimeProvider() var threadPoolFactory: ThreadPoolFactory = new DefaultThreadPoolFactory() - var envinromentProvider: EnvironmentProvider = new RealEnvironment() + var environmentProvider: EnvironmentProvider = new RealEnvironment() var enableHiveSupport: Boolean = true // termiante JVM in the presence non-deamon thread before exiting var terminateJVM = true @@ -190,6 +190,7 @@ trait FlintJobExecutor { } } + // scalastyle:off /** * Create a new formatted dataframe with json result, json schema and EMR_STEP_ID. * @@ -201,6 +202,8 @@ trait FlintJobExecutor { * dataframe with result, schema and emr step id */ def getFormattedData( + applicationId: String, + jobId: String, result: DataFrame, spark: SparkSession, dataSource: String, @@ -231,14 +234,13 @@ trait FlintJobExecutor { // after consumed the query result. Streaming query shuffle data is cleaned after each // microBatch execution. cleaner.cleanUp(spark) - // Create the data rows val rows = Seq( ( resultToSave, resultSchemaToSave, - envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown"), - envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"), + jobId, + applicationId, dataSource, "SUCCESS", "", @@ -254,6 +256,8 @@ trait FlintJobExecutor { } def constructErrorDF( + applicationId: String, + jobId: String, spark: SparkSession, dataSource: String, status: String, @@ -270,8 +274,8 @@ trait FlintJobExecutor { ( null, null, - envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown"), - envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"), + jobId, + applicationId, dataSource, status.toUpperCase(Locale.ROOT), error, @@ -396,6 +400,8 @@ trait FlintJobExecutor { } def executeQuery( + applicationId: String, + jobId: String, spark: SparkSession, query: String, dataSource: String, @@ -409,6 +415,8 @@ trait FlintJobExecutor { val result: DataFrame = spark.sql(query) // Get Data getFormattedData( + applicationId, + jobId, result, spark, dataSource, diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 5cdc805dd..340c0656e 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -19,7 +19,6 @@ import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.logging.CustomLogging import org.opensearch.flint.core.metrics.MetricConstants import org.opensearch.flint.core.metrics.MetricsUtil.{getTimerContext, incrementCounter, registerGauge, stopTimer} -import org.opensearch.search.sort.SortOrder import org.apache.spark.SparkConf import org.apache.spark.internal.Logging @@ -79,6 +78,10 @@ object FlintREPL extends Logging with FlintJobExecutor { */ conf.set("spark.sql.defaultCatalog", dataSource) + val applicationId = + environmentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown") + val jobId = 
environmentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown") + val jobType = conf.get(FlintSparkConf.JOB_TYPE.key, FlintSparkConf.JOB_TYPE.defaultValue.get) CustomLogging.logInfo(s"""Job type is: ${FlintSparkConf.JOB_TYPE.defaultValue.get}""") conf.set(FlintSparkConf.JOB_TYPE.key, jobType) @@ -93,6 +96,8 @@ object FlintREPL extends Logging with FlintJobExecutor { val streamingRunningCount = new AtomicInteger(0) val jobOperator = JobOperator( + applicationId, + jobId, createSparkSession(conf), query, dataSource, @@ -108,10 +113,6 @@ object FlintREPL extends Logging with FlintJobExecutor { val spark = createSparkSession(conf) val sessionManager = instantiateSessionManager(spark, resultIndexOption) - val jobId = envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown") - val applicationId = - envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown") - // Read the values from the Spark configuration or fall back to the default values val inactivityLimitMillis: Long = conf.getLong( @@ -169,12 +170,12 @@ object FlintREPL extends Logging with FlintJobExecutor { val queryResultWriter = instantiateQueryResultWriter(conf, sessionManager.getSessionContext) val commandContext = CommandContext( + applicationId, + jobId, spark, dataSource, sessionId, sessionManager, - jobId, - null, // StatementLifecycleManager will be instantiated inside query loop queryResultWriter, queryExecutionTimeoutSecs, inactivityLimitMillis, @@ -315,25 +316,21 @@ object FlintREPL extends Logging with FlintJobExecutor { var futurePrepareQueryExecution: Future[Either[String, Unit]] = null try { + logInfo(s"""Executing session with sessionId: ${sessionId}""") + var lastActivityTime = currentTimeProvider.currentEpochMillis() var verificationResult: VerificationResult = NotVerified var canPickUpNextStatement = true var lastCanPickCheckTime = 0L while (currentTimeProvider .currentEpochMillis() - lastActivityTime <= commandContext.inactivityLimitMillis && canPickUpNextStatement) { - logInfo(s"""Executing session with sessionId: ${sessionId}""") val statementsExecutionManager = - instantiateStatementExecutionManager( - spark, - sessionId, - dataSource, - sessionManager.getSessionContext) + instantiateStatementExecutionManager(commandContext) futurePrepareQueryExecution = Future { statementsExecutionManager.prepareStatementExecution() } - commandContext.statementsExecutionManager = statementsExecutionManager try { val commandState = CommandState( lastActivityTime, @@ -342,7 +339,7 @@ object FlintREPL extends Logging with FlintJobExecutor { executionContext, lastCanPickCheckTime) val result: (Long, VerificationResult, Boolean, Long) = - processCommands(commandContext, commandState) + processCommands(statementsExecutionManager, commandContext, commandState) val ( updatedLastActivityTime, @@ -374,9 +371,7 @@ object FlintREPL extends Logging with FlintJobExecutor { sessionManager: SessionManager, jobStartTime: Long, excludeJobIds: Seq[String] = Seq.empty[String]): Unit = { - val includeJobId = !excludeJobIds.isEmpty && !excludeJobIds.contains(jobId) - val currentTime = currentTimeProvider.currentEpochMillis() - val flintJob = refreshSessionState( + refreshSessionState( applicationId, jobId, sessionId, @@ -397,7 +392,7 @@ object FlintREPL extends Logging with FlintJobExecutor { state: String, error: Option[String] = None, excludedJobIds: Seq[String] = Seq.empty[String]): InteractiveSession = { - + logInfo(s"refreshSessionState: ${jobId}") val sessionDetails = sessionManager .getSessionDetails(sessionId) 
.getOrElse( @@ -410,6 +405,7 @@ object FlintREPL extends Logging with FlintJobExecutor { jobStartTime, error = error, excludedJobIds = excludedJobIds)) + logInfo(s"Current session: ${sessionDetails}") logInfo(s"State is: ${sessionDetails.state}") sessionDetails.state = state logInfo(s"State is: ${sessionDetails.state}") @@ -460,6 +456,8 @@ object FlintREPL extends Logging with FlintJobExecutor { * failed data frame */ def handleCommandFailureAndGetFailedData( + applicationId: String, + jobId: String, spark: SparkSession, dataSource: String, error: String, @@ -469,6 +467,8 @@ object FlintREPL extends Logging with FlintJobExecutor { flintStatement.fail() flintStatement.error = Some(error) super.constructErrorDF( + applicationId, + jobId, spark, dataSource, flintStatement.state, @@ -487,6 +487,7 @@ object FlintREPL extends Logging with FlintJobExecutor { } private def processCommands( + statementExecutionManager: StatementExecutionManager, context: CommandContext, state: CommandState): (Long, VerificationResult, Boolean, Long) = { import context._ @@ -500,7 +501,6 @@ object FlintREPL extends Logging with FlintJobExecutor { while (canProceed) { val currentTime = currentTimeProvider.currentEpochMillis() - // Only call canPickNextStatement if EARLY_TERMINATION_CHECK_FREQUENCY milliseconds have passed if (currentTime - lastCanPickCheckTime > EARLY_TERMINATION_CHECK_FREQUENCY) { canPickNextStatementResult = canPickNextStatement(sessionId, sessionManager, jobId) @@ -511,19 +511,28 @@ object FlintREPL extends Logging with FlintJobExecutor { earlyExitFlag = true canProceed = false } else { - statementsExecutionManager.getNextStatement() match { + statementExecutionManager.getNextStatement() match { case Some(flintStatement) => flintStatement.running() - statementsExecutionManager.updateStatement(flintStatement) + statementExecutionManager.updateStatement(flintStatement) statementRunningCount.incrementAndGet() val statementTimerContext = getTimerContext( MetricConstants.STATEMENT_PROCESSING_TIME_METRIC) val (dataToWrite, returnedVerificationResult) = - processStatementOnVerification(flintStatement, state, context) + processStatementOnVerification( + statementExecutionManager, + flintStatement, + state, + context) verificationResult = returnedVerificationResult - finalizeCommand(context, dataToWrite, flintStatement, statementTimerContext) + finalizeCommand( + statementExecutionManager, + context, + dataToWrite, + flintStatement, + statementTimerContext) // last query finish time is last activity time lastActivityTime = currentTimeProvider.currentEpochMillis() case _ => @@ -545,6 +554,7 @@ object FlintREPL extends Logging with FlintJobExecutor { * flint statement */ private def finalizeCommand( + statementExecutionManager: StatementExecutionManager, commandContext: CommandContext, dataToWrite: Option[DataFrame], flintStatement: FlintStatement, @@ -564,18 +574,20 @@ object FlintREPL extends Logging with FlintJobExecutor { CustomLogging.logError(error, e) flintStatement.fail() } finally { - statementsExecutionManager.updateStatement(flintStatement) + statementExecutionManager.updateStatement(flintStatement) recordStatementStateChange(flintStatement, statementTimerContext) } } private def handleCommandTimeout( + applicationId: String, + jobId: String, spark: SparkSession, dataSource: String, error: String, flintStatement: FlintStatement, sessionId: String, - startTime: Long): DataFrame = { + startTime: Long) = { /* * https://tinyurl.com/2ezs5xj9 * @@ -592,6 +604,8 @@ object FlintREPL extends Logging with 
FlintJobExecutor { flintStatement.timeout() flintStatement.error = Some(error) super.constructErrorDF( + applicationId, + jobId, spark, dataSource, flintStatement.state, @@ -602,10 +616,13 @@ object FlintREPL extends Logging with FlintJobExecutor { startTime) } + // scalastyle:off def executeAndHandle( + applicationId: String, + jobId: String, spark: SparkSession, flintStatement: FlintStatement, - statementsExecutionManager: StatementExecutionManager, + statementExecutionManager: StatementExecutionManager, dataSource: String, sessionId: String, executionContext: ExecutionContextExecutor, @@ -615,9 +632,11 @@ object FlintREPL extends Logging with FlintJobExecutor { try { Some( executeQueryAsync( + applicationId, + jobId, spark, flintStatement, - statementsExecutionManager: StatementExecutionManager, + statementExecutionManager, dataSource, sessionId, executionContext, @@ -628,11 +647,22 @@ object FlintREPL extends Logging with FlintJobExecutor { case e: TimeoutException => val error = s"Executing ${flintStatement.query} timed out" CustomLogging.logError(error, e) - Some(handleCommandTimeout(spark, dataSource, error, flintStatement, sessionId, startTime)) + Some( + handleCommandTimeout( + applicationId, + jobId, + spark, + dataSource, + error, + flintStatement, + sessionId, + startTime)) case e: Exception => val error = processQueryException(e, flintStatement) Some( handleCommandFailureAndGetFailedData( + applicationId, + jobId, spark, dataSource, error, @@ -643,6 +673,7 @@ object FlintREPL extends Logging with FlintJobExecutor { } private def processStatementOnVerification( + statementExecutionManager: StatementExecutionManager, flintStatement: FlintStatement, commandState: CommandState, commandContext: CommandContext) = { @@ -659,9 +690,11 @@ object FlintREPL extends Logging with FlintJobExecutor { ThreadUtils.awaitResult(futurePrepareQueryExecution, MAPPING_CHECK_TIMEOUT) match { case Right(_) => dataToWrite = executeAndHandle( + applicationId, + jobId, spark, flintStatement, - statementsExecutionManager, + statementExecutionManager, dataSource, sessionId, executionContext, @@ -673,6 +706,8 @@ object FlintREPL extends Logging with FlintJobExecutor { verificationResult = VerifiedWithError(error) dataToWrite = Some( handleCommandFailureAndGetFailedData( + applicationId, + jobId, spark, dataSource, error, @@ -686,6 +721,8 @@ object FlintREPL extends Logging with FlintJobExecutor { CustomLogging.logError(error, e) dataToWrite = Some( handleCommandTimeout( + applicationId, + jobId, spark, dataSource, error, @@ -697,6 +734,8 @@ object FlintREPL extends Logging with FlintJobExecutor { CustomLogging.logError(error, e) dataToWrite = Some( handleCommandFailureAndGetFailedData( + applicationId, + jobId, spark, dataSource, error, @@ -707,6 +746,8 @@ object FlintREPL extends Logging with FlintJobExecutor { case VerifiedWithError(err) => dataToWrite = Some( handleCommandFailureAndGetFailedData( + applicationId, + jobId, spark, dataSource, err, @@ -715,9 +756,11 @@ object FlintREPL extends Logging with FlintJobExecutor { startTime)) case VerifiedWithoutError => dataToWrite = executeAndHandle( + applicationId, + jobId, spark, flintStatement, - statementsExecutionManager, + statementExecutionManager, dataSource, sessionId, executionContext, @@ -731,6 +774,8 @@ object FlintREPL extends Logging with FlintJobExecutor { } def executeQueryAsync( + applicationId: String, + jobId: String, spark: SparkSession, flintStatement: FlintStatement, statementsExecutionManager: StatementExecutionManager, @@ -743,6 +788,8 
@@ object FlintREPL extends Logging with FlintJobExecutor { if (currentTimeProvider .currentEpochMillis() - flintStatement.submitTime > queryWaitTimeMillis) { handleCommandFailureAndGetFailedData( + applicationId, + jobId, spark, dataSource, "wait timeout", @@ -780,6 +827,7 @@ object FlintREPL extends Logging with FlintJobExecutor { // processing. if (!earlyExitFlag && !sessionDetails.isComplete && !sessionDetails.isFail) { sessionDetails.complete() + logInfo(s"jobId before shutting down session: ${sessionDetails.jobId}") sessionManager.updateSessionDetails(sessionDetails, updateMode = UPDATE_IF) recordSessionSuccess(sessionTimerContext) } @@ -929,7 +977,6 @@ object FlintREPL extends Logging with FlintJobExecutor { private def instantiate[T](defaultConstructor: => T, className: String, args: Any*): T = { if (className.isEmpty) { - logInfo("Using default constructor") defaultConstructor } else { try { @@ -953,17 +1000,15 @@ object FlintREPL extends Logging with FlintJobExecutor { resultIndexOption: Option[String]): SessionManager = { instantiate( new SessionManagerImpl(spark, resultIndexOption), - spark.sparkContext.getConf.get(FlintSparkConf.CUSTOM_SESSION_MANAGER.key, "")) + spark.conf.get(FlintSparkConf.CUSTOM_SESSION_MANAGER.key, "")) } private def instantiateStatementExecutionManager( - spark: SparkSession, - sessionId: String, - dataSource: String, - context: Map[String, Any]): StatementExecutionManager = { + commandContext: CommandContext): StatementExecutionManager = { + import commandContext._ instantiate( - new StatementExecutionManagerImpl(spark, sessionId, dataSource, context), - spark.sparkContext.getConf.get(FlintSparkConf.CUSTOM_STATEMENT_MANAGER.key, ""), + new StatementExecutionManagerImpl(commandContext), + spark.conf.get(FlintSparkConf.CUSTOM_STATEMENT_MANAGER.key, ""), spark, sessionId) } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala index c079b3e96..b49f4a9ed 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala @@ -22,6 +22,8 @@ import org.apache.spark.sql.util.ShuffleCleaner import org.apache.spark.util.ThreadUtils case class JobOperator( + applicationId: String, + jobId: String, spark: SparkSession, query: String, dataSource: String, @@ -51,13 +53,23 @@ case class JobOperator( val futureMappingCheck = Future { checkAndCreateIndex(osClient, resultIndex) } - val data = executeQuery(spark, query, dataSource, "", "", streaming) + val data = executeQuery(applicationId, jobId, spark, query, dataSource, "", "", streaming) val mappingCheckResult = ThreadUtils.awaitResult(futureMappingCheck, Duration(1, MINUTES)) dataToWrite = Some(mappingCheckResult match { case Right(_) => data case Left(error) => - constructErrorDF(spark, dataSource, "FAILED", error, "", query, "", startTime) + constructErrorDF( + applicationId, + jobId, + spark, + dataSource, + "FAILED", + error, + "", + query, + "", + startTime) }) exceptionThrown = false } catch { @@ -65,11 +77,31 @@ case class JobOperator( val error = s"Getting the mapping of index $resultIndex timed out" logError(error, e) dataToWrite = Some( - constructErrorDF(spark, dataSource, "TIMEOUT", error, "", query, "", startTime)) + constructErrorDF( + applicationId, + jobId, + spark, + dataSource, + "TIMEOUT", + error, + "", + query, + "", + startTime)) case e: Exception => val error = 
processQueryException(e) dataToWrite = Some( - constructErrorDF(spark, dataSource, "FAILED", error, "", query, "", startTime)) + constructErrorDF( + applicationId, + jobId, + spark, + dataSource, + "FAILED", + error, + "", + query, + "", + startTime)) } finally { cleanUpResources(exceptionThrown, threadPool, dataToWrite, resultIndex, osClient) } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala index 79bd63200..2039159e4 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala @@ -8,10 +8,9 @@ package org.apache.spark.sql import scala.util.{Failure, Success, Try} import org.json4s.native.Serialization -import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession} +import org.opensearch.flint.common.model.InteractiveSession import org.opensearch.flint.common.model.InteractiveSession.formats import org.opensearch.flint.core.logging.CustomLogging -import org.opensearch.flint.core.storage.FlintReader import org.apache.spark.internal.Logging import org.apache.spark.sql.SessionUpdateMode.SessionUpdateMode @@ -27,14 +26,14 @@ class SessionManagerImpl(spark: SparkSession, resultIndexOption: Option[String]) } // we don't allow default value for sessionIndex. Throw exception if key not found. - private val sessionIndex: String = spark.conf.get(FlintSparkConf.REQUEST_INDEX.key, "") + val sessionIndex: String = spark.conf.get(FlintSparkConf.REQUEST_INDEX.key, "") if (sessionIndex.isEmpty) { logAndThrow(FlintSparkConf.REQUEST_INDEX.key + " is not set") } - private val osClient = new OSClient(FlintSparkConf().flintOptions()) - private val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex) + val osClient = new OSClient(FlintSparkConf().flintOptions()) + lazy val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex) override def getSessionContext: Map[String, Any] = { Map( @@ -50,7 +49,6 @@ class SessionManagerImpl(spark: SparkSession, resultIndexOption: Option[String]) // Retrieve the source map and create session val sessionOption = Option(getResponse.getSourceAsMap) .map(InteractiveSession.deserializeFromMap) - // Retrieve sequence number and primary term from the response val seqNo = getResponse.getSeqNo val primaryTerm = getResponse.getPrimaryTerm diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala index dc8414a17..0b059f1d3 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala @@ -6,21 +6,22 @@ package org.apache.spark.sql import org.opensearch.flint.common.model.FlintStatement -import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} +import org.opensearch.flint.core.storage.OpenSearchUpdater import org.opensearch.search.sort.SortOrder import org.apache.spark.internal.Logging -import org.apache.spark.sql.FlintJob.{checkAndCreateIndex, createResultIndex, isSuperset, resultIndexMapping} -import org.apache.spark.sql.FlintREPL.executeQuery -class StatementExecutionManagerImpl( - spark: SparkSession, - sessionId: String, - dataSource: String, - context: Map[String, Any]) +/** + * 
StatementExecutionManagerImpl is session based implementation of StatementExecutionManager + * interface It uses FlintReader to fetch all pending queries in a mirco-batch + * @param commandContext + */ +class StatementExecutionManagerImpl(commandContext: CommandContext) extends StatementExecutionManager + with FlintJobExecutor with Logging { + private val context = commandContext.sessionManager.getSessionContext private val sessionIndex = context("sessionIndex").asInstanceOf[String] private val resultIndex = context("resultIndex").asInstanceOf[String] private val osClient = context("osClient").asInstanceOf[OSClient] @@ -53,10 +54,20 @@ class StatementExecutionManagerImpl( } override def executeStatement(statement: FlintStatement): DataFrame = { - executeQuery(spark, statement.query, dataSource, statement.queryId, sessionId, false) + import commandContext._ + executeQuery( + applicationId, + jobId, + spark, + statement.query, + dataSource, + statement.queryId, + sessionId, + false) } private def createOpenSearchQueryReader() = { + import commandContext._ // all state in index are in lower case // we only search for statement submitted in the last hour in case of unexpected bugs causing infinite loop in the // same doc diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala index 19f596e31..339c1870d 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala @@ -11,6 +11,8 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.util.{CleanerFactory, MockTimeProvider} class FlintJobTest extends SparkFunSuite with JobMatchers { + private val jobId = "testJobId" + private val applicationId = "testApplicationId" val spark = SparkSession.builder().appName("Test").master("local").getOrCreate() @@ -55,8 +57,8 @@ class FlintJobTest extends SparkFunSuite with JobMatchers { Array( "{'column_name':'Letter','data_type':'string'}", "{'column_name':'Number','data_type':'integer'}"), - "unknown", - "unknown", + jobId, + applicationId, dataSourceName, "SUCCESS", "", @@ -72,6 +74,8 @@ class FlintJobTest extends SparkFunSuite with JobMatchers { // Compare the result val result = FlintJob.getFormattedData( + applicationId, + jobId, input, spark, dataSourceName, diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index bba0e40e2..433c2351b 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -18,14 +18,13 @@ import scala.reflect.runtime.universe.TypeTag import com.amazonaws.services.glue.model.AccessDeniedException import com.codahale.metrics.Timer import org.mockito.{ArgumentMatchersSugar, Mockito} -import org.mockito.Mockito.{atLeastOnce, never, times, verify, when} +import org.mockito.Mockito.{atLeastOnce, doNothing, never, times, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.opensearch.action.get.GetResponse import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession, SessionStates} import org.opensearch.flint.core.storage.{FlintReader, OpenSearchReader, OpenSearchUpdater} import org.opensearch.search.sort.SortOrder -import org.scalatest.Ignore import 
org.scalatest.prop.TableDrivenPropertyChecks._ import org.scalatestplus.mockito.MockitoSugar @@ -49,6 +48,9 @@ class FlintREPLTest // By using a type alias and casting, I can bypass the type checking error. type AnyScheduledFuture = ScheduledFuture[_] + private val jobId = "testJobId" + private val applicationId = "testApplicationId" + test("parseArgs with no arguments should return (None, None)") { val args = Array.empty[String] val (queryOption, resultIndexOption) = FlintREPL.parseArgs(args) @@ -176,21 +178,15 @@ class FlintREPLTest } } - ignore("createHeartBeatUpdater should update heartbeat correctly") { + test("createHeartBeatUpdater should update heartbeat correctly") { // Mocks val threadPool = mock[ScheduledExecutorService] val scheduledFutureRaw = mock[ScheduledFuture[_]] val sessionManager = mock[SessionManager] val sessionId = "session1" - val currentInterval = 1000L - val initialDelayMillis = 0L // when scheduled task is scheduled, execute the runnable immediately only once and become no-op afterwards. - when( - threadPool.scheduleAtFixedRate( - any[Runnable], - eqTo(initialDelayMillis), - eqTo(currentInterval), - eqTo(java.util.concurrent.TimeUnit.MILLISECONDS))) + when(threadPool + .scheduleAtFixedRate(any[Runnable], *, *, eqTo(java.util.concurrent.TimeUnit.MILLISECONDS))) .thenAnswer((invocation: InvocationOnMock) => { val runnable = invocation.getArgument[Runnable](0) runnable.run() @@ -201,7 +197,7 @@ class FlintREPLTest FlintREPL.createHeartBeatUpdater(sessionId, sessionManager, threadPool) // Verifications - verify(sessionManager).recordHeartbeat(sessionId) + verify(sessionManager, atLeastOnce()).recordHeartbeat(sessionId) } test("PreShutdownListener updates FlintInstance if conditions are met") { @@ -256,8 +252,8 @@ class FlintREPLTest Row( null, null, - "unknown", - "unknown", + jobId, + applicationId, dataSourceName, "FAILED", error, @@ -280,6 +276,8 @@ class FlintREPLTest // Compare the result val result = FlintREPL.handleCommandFailureAndGetFailedData( + applicationId, + jobId, spark, dataSourceName, error, @@ -338,29 +336,21 @@ class FlintREPLTest assert(!result) // The function should return false } - ignore("test canPickNextStatement: Doc Exists, JobId Matches, but JobId is Excluded") { + test("test canPickNextStatement: Doc Exists, JobId Matches, but JobId is Excluded") { val sessionId = "session123" val jobId = "jobABC" - val osClient = mock[OSClient] - val sessionIndex = "sessionIndex" val sessionManager = mock[SessionManager] - val getResponse = mock[GetResponse] - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) - when(getResponse.isExists()).thenReturn(true) - - val excludeJobIdsList = new java.util.ArrayList[String]() - excludeJobIdsList.add(jobId) // Add the jobId to the list to simulate exclusion - - val sourceMap = new java.util.HashMap[String, Object]() - sourceMap.put("jobId", jobId) // The jobId matches - sourceMap.put("excludeJobIds", excludeJobIdsList) // But jobId is in the exclude list - when(getResponse.getSourceAsMap).thenReturn(sourceMap) - - // Mock the InteractiveSession - val interactiveSession = mock[InteractiveSession] - when(interactiveSession.jobId).thenReturn("jobABC") - when(interactiveSession.excludedJobIds).thenReturn(Seq(jobId)) + val interactiveSession = new InteractiveSession( + "app123", + jobId, + sessionId, + SessionStates.RUNNING, + System.currentTimeMillis(), + System.currentTimeMillis() - 10000, + Seq(jobId) // Add the jobId to the list to simulate exclusion + ) + 
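    // Stub the SessionManager to return this session: its excludedJobIds already contains
    // the current jobId, so canPickNextStatement is expected to return false below.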
when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) // Execute the method under test val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId) @@ -372,15 +362,9 @@ class FlintREPLTest test("test canPickNextStatement: Doc Exists but Source is Null") { val sessionId = "session123" val jobId = "jobABC" - val osClient = mock[OSClient] - val sessionIndex = "sessionIndex" val sessionManager = mock[SessionManager] - // Mock the getDoc response - val getResponse = mock[GetResponse] - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) - when(getResponse.isExists()).thenReturn(true) - when(getResponse.getSourceAsMap).thenReturn(null) // Simulate the source being null + when(sessionManager.getSessionDetails(sessionId)).thenReturn(None) // Execute the method under test val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId) @@ -392,12 +376,21 @@ class FlintREPLTest test("test canPickNextStatement: Doc Exists with Unexpected Type in excludeJobIds") { val sessionId = "session123" val jobId = "jobABC" - val osClient = mock[OSClient] + val mockOSClient = mock[OSClient] val sessionIndex = "sessionIndex" - val sessionManager = mock[SessionManager] + + val mockSparkSession = mock[SparkSession] + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) + + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) { + override val osClient: OSClient = mockOSClient + } val getResponse = mock[GetResponse] - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) val sourceMap = new java.util.HashMap[String, Object]() @@ -417,13 +410,21 @@ class FlintREPLTest test("test canPickNextStatement: Doc Does Not Exist") { val sessionId = "session123" val jobId = "jobABC" - val osClient = mock[OSClient] + val mockOSClient = mock[OSClient] val sessionIndex = "sessionIndex" - val sessionManager = mock[SessionManager] + val mockSparkSession = mock[SparkSession] + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) + + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) { + override val osClient: OSClient = mockOSClient + } // Set up the mock GetResponse val getResponse = mock[GetResponse] - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(false) // Simulate the document does not exist // Execute the function under test @@ -436,12 +437,20 @@ class FlintREPLTest test("test canPickNextStatement: OSClient Throws Exception") { val sessionId = "session123" val jobId = "jobABC" - val osClient = mock[OSClient] + val mockOSClient = mock[OSClient] val sessionIndex = "sessionIndex" - val sessionManager = mock[SessionManager] + val mockSparkSession = mock[SparkSession] + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) + + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) { + override val 
osClient: OSClient = mockOSClient + } // Set up the mock OSClient to throw an exception - when(osClient.getDoc(sessionIndex, sessionId)) + when(mockOSClient.getDoc(sessionIndex, sessionId)) .thenThrow(new RuntimeException("OpenSearch cluster unresponsive")) // Execute the method under test and expect true, since the method is designed to return true even in case of an exception @@ -456,12 +465,20 @@ class FlintREPLTest val sessionId = "session123" val jobId = "jobABC" val nonMatchingExcludeJobId = "jobXYZ" // This ID does not match the jobId - val osClient = mock[OSClient] + val mockOSClient = mock[OSClient] val sessionIndex = "sessionIndex" - val sessionManager = mock[SessionManager] + val mockSparkSession = mock[SparkSession] + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) + + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) { + override val osClient: OSClient = mockOSClient + } val getResponse = mock[GetResponse] - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) // Create a sourceMap with excludeJobIds as a String that does NOT match jobId @@ -523,18 +540,30 @@ class FlintREPLTest test("Doc Exists and excludeJobIds is an ArrayList Containing JobId") { val sessionId = "session123" val jobId = "jobABC" - val osClient = mock[OSClient] + val mockOSClient = mock[OSClient] val sessionIndex = "sessionIndex" - val handleSessionError = mock[Function1[String, Unit]] - val sessionManager = mock[SessionManager] + val lastUpdateTime = System.currentTimeMillis() + val mockSparkSession = mock[SparkSession] + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) + + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) { + override val osClient: OSClient = mockOSClient + } val getResponse = mock[GetResponse] - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) // Create a sourceMap with excludeJobIds as an ArrayList containing jobId val sourceMap = new java.util.HashMap[String, Object]() + sourceMap.put("applicationId", applicationId.asInstanceOf[Object]) + sourceMap.put("state", "running".asInstanceOf[Object]) sourceMap.put("jobId", jobId.asInstanceOf[Object]) + sourceMap.put("sessionId", sessionId.asInstanceOf[Object]) + sourceMap.put("lastUpdateTime", lastUpdateTime.asInstanceOf[Object]) // Creating an ArrayList and adding the jobId to it val excludeJobIdsList = new java.util.ArrayList[String]() @@ -542,19 +571,6 @@ class FlintREPLTest sourceMap.put("excludeJobIds", excludeJobIdsList.asInstanceOf[Object]) when(getResponse.getSourceAsMap).thenReturn(sourceMap) - - // Mock the InteractiveSession - val interactiveSession = new InteractiveSession( - applicationId = "app123", - sessionId = sessionId, - state = "active", - lastUpdateTime = System.currentTimeMillis(), - jobId = "jobABC", - excludedJobIds = Seq(jobId)) - - // Mock the sessionManager to return the mocked interactiveSession - when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) - // Execute the 
method under test val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId) @@ -562,20 +578,33 @@ class FlintREPLTest assert(!result) } - ignore("Doc Exists and excludeJobIds is an ArrayList Not Containing JobId") { + test("Doc Exists and excludeJobIds is an ArrayList Not Containing JobId") { val sessionId = "session123" val jobId = "jobABC" - val osClient = mock[OSClient] + val mockOSClient = mock[OSClient] val sessionIndex = "sessionIndex" - val sessionManager = mock[SessionManager] + val lastUpdateTime = System.currentTimeMillis() + val mockSparkSession = mock[SparkSession] + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) + + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) { + override val osClient: OSClient = mockOSClient + } val getResponse = mock[GetResponse] - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) // Create a sourceMap with excludeJobIds as an ArrayList not containing jobId val sourceMap = new java.util.HashMap[String, Object]() + sourceMap.put("applicationId", applicationId.asInstanceOf[Object]) + sourceMap.put("state", "running".asInstanceOf[Object]) sourceMap.put("jobId", jobId.asInstanceOf[Object]) + sourceMap.put("sessionId", sessionId.asInstanceOf[Object]) + sourceMap.put("lastUpdateTime", lastUpdateTime.asInstanceOf[Object]) // Creating an ArrayList and adding a different jobId to it val excludeJobIdsList = new java.util.ArrayList[String]() @@ -584,18 +613,6 @@ class FlintREPLTest when(getResponse.getSourceAsMap).thenReturn(sourceMap) - // Mock the InteractiveSession - val interactiveSession = new InteractiveSession( - applicationId = "app123", - sessionId = sessionId, - state = "active", - lastUpdateTime = System.currentTimeMillis(), - jobId = "jobABC", - excludedJobIds = Seq(jobId)) - - // Mock the sessionManager to return the mocked interactiveSession - when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) - // Execute the method under test val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId) @@ -608,34 +625,52 @@ class FlintREPLTest val exception = new RuntimeException( new ConnectException( "Timeout connecting to [search-foo-1-bar.eu-west-1.es.amazonaws.com:443]")) - val osClient = mock[OSClient] - when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) + val mockOSClient = mock[OSClient] + when( + mockOSClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) .thenReturn(mockReader) when(mockReader.hasNext).thenThrow(exception) + when(mockOSClient.getIndexMetadata(any[String])).thenReturn(FlintREPL.resultIndexMapping) + val maxRetries = 1 var actualRetries = 0 - val resultIndex = "testResultIndex" val dataSource = "testDataSource" - val sessionIndex = "testSessionIndex" val sessionId = "testSessionId" val jobId = "testJobId" val applicationId = "testApplicationId" + val sessionIndex = "sessionIndex" + val lastUpdateTime = System.currentTimeMillis() + + val getResponse = mock[GetResponse] + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(true) + + // Create a sourceMap with excludeJobIds as an ArrayList not containing jobId + val 
sourceMap = new java.util.HashMap[String, Object]() + sourceMap.put("applicationId", applicationId.asInstanceOf[Object]) + sourceMap.put("state", "running".asInstanceOf[Object]) + sourceMap.put("jobId", jobId.asInstanceOf[Object]) + sourceMap.put("sessionId", sessionId.asInstanceOf[Object]) + sourceMap.put("lastUpdateTime", lastUpdateTime.asInstanceOf[Object]) val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() try { - val sessionManager = mock[SessionManager] - val statementLifecycleManager = mock[StatementExecutionManager] + spark.conf.set(FlintSparkConf.REQUEST_INDEX.key, sessionIndex) + val sessionManager = new SessionManagerImpl(spark, Some("resultIndex")) { + override val osClient: OSClient = mockOSClient + } + val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( + applicationId, + jobId, spark, dataSource, sessionId, sessionManager, - jobId, - statementLifecycleManager, queryResultWriter, Duration(10, MINUTES), 60, @@ -656,14 +691,15 @@ class FlintREPLTest } } - ignore("executeAndHandle should handle TimeoutException properly") { + test("executeAndHandle should handle TimeoutException properly") { val mockSparkSession = mock[SparkSession] val mockConf = mock[RuntimeConfig] when(mockSparkSession.conf).thenReturn(mockConf) when(mockSparkSession.conf.get(FlintSparkConf.JOB_TYPE.key)) .thenReturn(FlintSparkConf.JOB_TYPE.defaultValue.get) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn("someSessionIndex") - val mockStatementsExecutionManager = mock[StatementExecutionManager] // val mockExecutionContextExecutor: ExecutionContextExecutor = mock[ExecutionContextExecutor] val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor("flint-repl", 1) implicit val executionContext = ExecutionContext.fromExecutor(threadPool) @@ -688,19 +724,35 @@ class FlintREPLTest .thenReturn(expectedDataFrame) when(expectedDataFrame.toDF(any[Seq[String]]: _*)).thenReturn(expectedDataFrame) - val sparkContext = mock[SparkContext] when(mockSparkSession.sparkContext).thenReturn(sparkContext) + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) + val queryResultWriter = mock[QueryResultWriter] + val commandContext = CommandContext( + applicationId, + jobId, + mockSparkSession, + dataSource, + sessionId, + sessionManager, + queryResultWriter, + Duration(10, MINUTES), + 60, + 60, + DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) + val statementExecutionManager = new StatementExecutionManagerImpl(commandContext) + val result = FlintREPL.executeAndHandle( + applicationId, + jobId, mockSparkSession, flintStatement, - mockStatementsExecutionManager, + statementExecutionManager, dataSource, sessionId, executionContext, startTime, - // make sure it times out before mockSparkSession.sql can return, which takes 60 seconds Duration(1, SECONDS), 600000) @@ -712,13 +764,14 @@ class FlintREPLTest } finally threadPool.shutdown() } - ignore("executeAndHandle should handle ParseException properly") { + test("executeAndHandle should handle ParseException properly") { val mockSparkSession = mock[SparkSession] val mockConf = mock[RuntimeConfig] when(mockSparkSession.conf).thenReturn(mockConf) when(mockSparkSession.conf.get(FlintSparkConf.JOB_TYPE.key)) .thenReturn(FlintSparkConf.JOB_TYPE.defaultValue.get) - val mockStatementsExecutionManager = mock[StatementExecutionManager] + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn("someSessionIndex") val flintStatement 
= new FlintStatement( @@ -748,15 +801,33 @@ class FlintREPLTest .thenReturn(expectedDataFrame) when(expectedDataFrame.toDF(any[Seq[String]]: _*)).thenReturn(expectedDataFrame) + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) + val queryResultWriter = mock[QueryResultWriter] + val commandContext = CommandContext( + applicationId, + jobId, + mockSparkSession, + dataSource, + sessionId, + sessionManager, + queryResultWriter, + Duration(10, MINUTES), + 60, + 60, + DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) + val statementExecutionManager = new StatementExecutionManagerImpl(commandContext) + val result = FlintREPL.executeAndHandle( + applicationId, + jobId, mockSparkSession, flintStatement, - mockStatementsExecutionManager, + statementExecutionManager, dataSource, sessionId, executionContext, startTime, - Duration.Inf, // Use Duration.Inf or a large enough duration to avoid a timeout, + Duration.Inf, 600000) // Verify that ParseException was caught and handled @@ -767,22 +838,35 @@ class FlintREPLTest } test("setupFlintJobWithExclusionCheck should proceed normally when no jobs are excluded") { - val osClient = mock[OSClient] - val getResponse = mock[GetResponse] - val applicationId = "app1" - val jobId = "job1" + val sessionIndex = "sessionIndex" val sessionId = "session1" - val jobStartTime = System.currentTimeMillis() - val sessionManager = mock[SessionManager] + val mockSparkSession = mock[SparkSession] + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) + + val mockOSClient = mock[OSClient] + val getResponse = mock[GetResponse] + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(true) + + val mockOpenSearchUpdater = mock[OpenSearchUpdater] + doNothing().when(mockOpenSearchUpdater).upsert(any[String], any[String]) + + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) { + override val osClient: OSClient = mockOSClient + override lazy val flintSessionIndexUpdater: OpenSearchUpdater = mockOpenSearchUpdater + } val conf = new SparkConf().set("spark.flint.deployment.excludeJobs", "") - when(osClient.getDoc(*, *)).thenReturn(getResponse) + when(mockOSClient.getDoc(*, *)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) when(getResponse.getSourceAsMap).thenReturn( Map[String, Object]( "applicationId" -> "app1", "jobId" -> "job1", - "sessionId" -> "session1", + "sessionId" -> sessionId, "lastUpdateTime" -> java.lang.Long.valueOf(12345L), "error" -> "someError", "state" -> "running", @@ -790,21 +874,10 @@ class FlintREPLTest when(getResponse.getSeqNo).thenReturn(0L) when(getResponse.getPrimaryTerm).thenReturn(0L) - val interactiveSession = new InteractiveSession( - "app123", - jobId, - sessionId, - SessionStates.RUNNING, - lastUpdateTime = java.lang.Long.valueOf(12345L)) - - when(sessionManager.getSessionDetails(sessionId)) - .thenReturn(Some(interactiveSession)) - val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "") - // other mock objects like osClient, flintSessionIndexUpdater with necessary mocking val result = FlintREPL.setupFlintJobWithExclusionCheck( - mockConf, - sessionId, + conf, + "session1", jobId, applicationId, sessionManager, @@ -812,33 +885,36 @@ class FlintREPLTest assert(!result) // Expecting false as the job should proceed normally } - 
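A minimal sketch of the exclusion behaviour the surrounding setupFlintJobWithExclusionCheck tests pin down; this is a hypothetical distillation of what the assertions require, not the production implementation: the job exits early when its own id appears on spark.flint.deployment.excludeJobs, or when that list matches the session's excludeJobIds (a duplicate deployment), and proceeds otherwise.

    // Hypothetical helper mirroring the asserted behaviour; the name and signature are
    // illustrative only and do not exist in the patch.
    def shouldExitEarly(
        jobId: String,
        confExcluded: Seq[String],
        sessionExcluded: Seq[String]): Boolean =
      confExcluded.contains(jobId) ||
        (confExcluded.nonEmpty && confExcluded.toSet == sessionExcluded.toSet)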
ignore("setupFlintJobWithExclusionCheck should exit early if current job is excluded") { - val osClient = mock[OSClient] - val getResponse = mock[GetResponse] - val applicationId = "app1" - val jobId = "job1" + test("setupFlintJobWithExclusionCheck should exit early if current job is excluded") { + val sessionIndex = "sessionIndex" val sessionId = "session1" - val jobStartTime = System.currentTimeMillis() - val sessionManager = mock[SessionManager] + val mockSparkSession = mock[SparkSession] + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) + + val mockOSClient = mock[OSClient] + val getResponse = mock[GetResponse] + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(true) + + val mockOpenSearchUpdater = mock[OpenSearchUpdater] + doNothing().when(mockOpenSearchUpdater).upsert(any[String], any[String]) - when(osClient.getDoc(*, *)).thenReturn(getResponse) + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) { + override val osClient: OSClient = mockOSClient + override lazy val flintSessionIndexUpdater: OpenSearchUpdater = mockOpenSearchUpdater + } + + when(mockOSClient.getDoc(*, *)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) // Mock the rest of the GetResponse as needed - val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "jobId") - val interactiveSession = new InteractiveSession( - "app123", - jobId, - sessionId, - SessionStates.RUNNING, - excludedJobIds = Seq(jobId), - lastUpdateTime = java.lang.Long.valueOf(12345L)) - - when(sessionManager.getSessionDetails(sessionId)) - .thenReturn(Some(interactiveSession)) + val conf = new SparkConf().set("spark.flint.deployment.excludeJobs", jobId) val result = FlintREPL.setupFlintJobWithExclusionCheck( - mockConf, + conf, sessionId, jobId, applicationId, @@ -848,16 +924,27 @@ class FlintREPLTest } test("setupFlintJobWithExclusionCheck should exit early if a duplicate job is running") { - val osClient = mock[OSClient] - val getResponse = mock[GetResponse] - val applicationId = "app1" - val jobId = "job1" + val sessionIndex = "sessionIndex" val sessionId = "session1" - val jobStartTime = System.currentTimeMillis() - val sessionManager = mock[SessionManager] + val mockSparkSession = mock[SparkSession] + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) - when(osClient.getDoc(*, *)).thenReturn(getResponse) + val mockOSClient = mock[OSClient] + val getResponse = mock[GetResponse] + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) + + val mockOpenSearchUpdater = mock[OpenSearchUpdater] + doNothing().when(mockOpenSearchUpdater).upsert(any[String], any[String]) + + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) { + override val osClient: OSClient = mockOSClient + override lazy val flintSessionIndexUpdater: OpenSearchUpdater = mockOpenSearchUpdater + } + // Mock the GetResponse to simulate a scenario of a duplicate job when(getResponse.getSourceAsMap).thenReturn( Map[String, Object]( @@ -872,20 +959,10 @@ class FlintREPLTest .asList("job-2", "job-1") // Include this inside the Map ).asJava) - val flintSessionIndexUpdater = 
mock[OpenSearchUpdater] - val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "job-1,job-2") - val interactiveSession = new InteractiveSession( - "app1", - jobId, - sessionId, - SessionStates.RUNNING, - excludedJobIds = Seq("job-1", "job-2"), - lastUpdateTime = java.lang.Long.valueOf(12345L)) - // Mock sessionManager to return sessionDetails - when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) + val conf = new SparkConf().set("spark.flint.deployment.excludeJobs", "job-1,job-2") val result = FlintREPL.setupFlintJobWithExclusionCheck( - mockConf, + conf, sessionId, jobId, applicationId, @@ -895,31 +972,40 @@ class FlintREPLTest } test("setupFlintJobWithExclusionCheck should setup job normally when conditions are met") { - val osClient = mock[OSClient] - val getResponse = mock[GetResponse] - val applicationId = "app1" - val jobId = "job1" + val sessionIndex = "sessionIndex" val sessionId = "session1" - val jobStartTime = System.currentTimeMillis() - val sessionManager = mock[SessionManager] + val mockSparkSession = mock[SparkSession] + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) - when(osClient.getDoc(*, *)).thenReturn(getResponse) + val mockOSClient = mock[OSClient] + val getResponse = mock[GetResponse] + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) + when(getResponse.getSourceAsMap).thenReturn( + Map[String, Object]( + "applicationId" -> applicationId, + "jobId" -> jobId, + "sessionId" -> sessionId, + "lastUpdateTime" -> java.lang.Long.valueOf(12345L), + "error" -> "someError", + "state" -> "running", + "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava) - val flintSessionIndexUpdater = mock[OpenSearchUpdater] - val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "job-3,job-4") - val interactiveSession = new InteractiveSession( - "app1", - jobId, - sessionId, - SessionStates.RUNNING, - excludedJobIds = Seq("job-5", "job-6"), - lastUpdateTime = java.lang.Long.valueOf(12345L)) - // Mock sessionManager to return sessionDetails - when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) + val mockOpenSearchUpdater = mock[OpenSearchUpdater] + doNothing().when(mockOpenSearchUpdater).upsert(any[String], any[String]) + + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) { + override val osClient: OSClient = mockOSClient + override lazy val flintSessionIndexUpdater: OpenSearchUpdater = mockOpenSearchUpdater + } + + val conf = new SparkConf().set("spark.flint.deployment.excludeJobs", "job-3,job-4") val result = FlintREPL.setupFlintJobWithExclusionCheck( - mockConf, + conf, sessionId, jobId, applicationId, @@ -928,66 +1014,59 @@ class FlintREPLTest assert(!result) // Expecting false as the job proceeds normally } - ignore( - "setupFlintJobWithExclusionCheck should throw NoSuchElementException if sessionIndex or sessionId is missing") { - val osClient = mock[OSClient] - val flintSessionIndexUpdater = mock[OpenSearchUpdater] - val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "") - val applicationId = "app1" - val jobId = "job1" - val sessionId = "session1" - val jobStartTime = System.currentTimeMillis() - val sessionManager = mock[SessionManager] - - assertThrows[NoSuchElementException] { - 
FlintREPL.setupFlintJobWithExclusionCheck( - mockConf, - sessionId, - jobId, - applicationId, - sessionManager, - System.currentTimeMillis()) - } - } - - ignore("queryLoop continue until inactivity limit is reached") { - val mockReader = mock[FlintReader] - val osClient = mock[OSClient] - when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) - .thenReturn(mockReader) - when(mockReader.hasNext).thenReturn(false) - + test("queryLoop continue until inactivity limit is reached") { val resultIndex = "testResultIndex" val dataSource = "testDataSource" val sessionIndex = "testSessionIndex" val sessionId = "testSessionId" - val jobId = "testJobId" + + val mockReader = mock[FlintReader] + val mockOSClient = mock[OSClient] + val getResponse = mock[GetResponse] + + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(true) + when(getResponse.getSourceAsMap).thenReturn( + Map[String, Object]( + "applicationId" -> applicationId, + "jobId" -> jobId, + "sessionId" -> sessionId, + "lastUpdateTime" -> java.lang.Long.valueOf(12345L), + "error" -> "someError", + "state" -> "running", + "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava) + + when( + mockOSClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) + .thenReturn(mockReader) + when(mockReader.hasNext).thenReturn(false) + when(mockOSClient.doesIndexExist(*)).thenReturn(true) + when(mockOSClient.getIndexMetadata(*)).thenReturn(FlintREPL.resultIndexMapping) val shortInactivityLimit = 500 // 500 milliseconds // Create a SparkSession for testing val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() - val sessionManager = mock[SessionManager] - val statementLifecycleManager = mock[StatementExecutionManager] + spark.conf.set(FlintSparkConf.REQUEST_INDEX.key, sessionIndex) + val sessionManager = new SessionManagerImpl(spark, Some(resultIndex)) { + override val osClient: OSClient = mockOSClient + } val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( + applicationId, + jobId, spark, dataSource, sessionId, sessionManager, - jobId, - statementLifecycleManager, queryResultWriter, Duration(10, MINUTES), shortInactivityLimit, 60, DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) - // Mock processCommands to always allow loop continuation - when(statementLifecycleManager.getNextStatement()).thenReturn(None) - val startTime = System.currentTimeMillis() FlintREPL.queryLoop(commandContext) @@ -1001,55 +1080,62 @@ class FlintREPLTest spark.stop() } - ignore("queryLoop should stop when canPickUpNextStatement is false") { - val mockReader = mock[FlintReader] - val osClient = mock[OSClient] - when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) - .thenReturn(mockReader) - when(mockReader.hasNext).thenReturn(true) - - val sessionManager = mock[SessionManager] - val statementLifecycleManager = mock[StatementExecutionManager] - val queryResultWriter = mock[QueryResultWriter] - + test("queryLoop should stop when canPickUpNextStatement is false") { val resultIndex = "testResultIndex" val dataSource = "testDataSource" val sessionIndex = "testSessionIndex" val sessionId = "testSessionId" - val jobId = "testJobId" + + val mockReader = mock[FlintReader] + val mockOSClient = mock[OSClient] + + // Mocking canPickNextStatement to return false + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenAnswer(_ => { + val mockGetResponse = mock[GetResponse] 
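        // The stubbed session doc below carries "differentJobId", so this job no longer owns
        // the session; canPickNextStatement returns false and queryLoop is expected to stop.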
+ when(mockGetResponse.isExists()).thenReturn(true) + when(mockGetResponse.getSourceAsMap).thenReturn( + Map[String, Object]( + "applicationId" -> applicationId, + "jobId" -> "differentJobId", + "sessionId" -> sessionId, + "lastUpdateTime" -> java.lang.Long.valueOf(12345L), + "error" -> "someError", + "state" -> "running", + "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava) + mockGetResponse + }) + + when( + mockOSClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) + .thenReturn(mockReader) + when(mockReader.hasNext).thenReturn(true) + when(mockOSClient.doesIndexExist(*)).thenReturn(true) + when(mockOSClient.getIndexMetadata(*)).thenReturn(FlintREPL.resultIndexMapping) + val longInactivityLimit = 10000 // 10 seconds // Create a SparkSession for testing val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() - val flintSessionIndexUpdater = mock[OpenSearchUpdater] + spark.conf.set(FlintSparkConf.REQUEST_INDEX.key, sessionIndex) + val sessionManager = new SessionManagerImpl(spark, Some(resultIndex)) { + override val osClient: OSClient = mockOSClient + } + val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( + applicationId, + jobId, spark, dataSource, sessionId, sessionManager, - jobId, - statementLifecycleManager, queryResultWriter, Duration(10, MINUTES), longInactivityLimit, 60, DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) - // Mocking canPickNextStatement to return false - when(osClient.getDoc(sessionIndex, sessionId)).thenAnswer(_ => { - val mockGetResponse = mock[GetResponse] - when(mockGetResponse.isExists()).thenReturn(true) - val sourceMap = new java.util.HashMap[String, Object]() - sourceMap.put("jobId", "differentJobId") - when(mockGetResponse.getSourceAsMap).thenReturn(sourceMap) - mockGetResponse - }) - - // Mock getNextStatement to return None, simulating the end of statements - when(statementLifecycleManager.getNextStatement()).thenReturn(None) - val startTime = System.currentTimeMillis() FlintREPL.queryLoop(commandContext) @@ -1063,38 +1149,52 @@ class FlintREPLTest spark.stop() } - ignore("queryLoop should properly shut down the thread pool after execution") { - val mockReader = mock[FlintReader] - val osClient = mock[OSClient] - when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) - .thenReturn(mockReader) - when(mockReader.hasNext).thenReturn(false) - + test("queryLoop should properly shut down the thread pool after execution") { val resultIndex = "testResultIndex" val dataSource = "testDataSource" val sessionIndex = "testSessionIndex" val sessionId = "testSessionId" - val jobId = "testJobId" - val sessionManager = mock[SessionManager] - val statementLifecycleManager = mock[StatementExecutionManager] - val queryResultWriter = mock[QueryResultWriter] - when(statementLifecycleManager.getNextStatement()).thenReturn(None) + val mockReader = mock[FlintReader] + val mockOSClient = mock[OSClient] + when( + mockOSClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) + .thenReturn(mockReader) + when(mockReader.hasNext).thenReturn(false) + when(mockOSClient.doesIndexExist(*)).thenReturn(true) + when(mockOSClient.getIndexMetadata(*)).thenReturn(FlintREPL.resultIndexMapping) + + val getResponse = mock[GetResponse] + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(true) + when(getResponse.getSourceAsMap).thenReturn( + Map[String, Object]( + 
"applicationId" -> applicationId, + "jobId" -> jobId, + "sessionId" -> sessionId, + "lastUpdateTime" -> java.lang.Long.valueOf(12345L), + "error" -> "someError", + "state" -> "running", + "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava) val inactivityLimit = 500 // 500 milliseconds // Create a SparkSession for testing val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() - val flintSessionIndexUpdater = mock[OpenSearchUpdater] + spark.conf.set(FlintSparkConf.REQUEST_INDEX.key, sessionIndex) + val sessionManager = new SessionManagerImpl(spark, Some(resultIndex)) { + override val osClient: OSClient = mockOSClient + } + val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( + applicationId, + jobId, spark, dataSource, sessionId, sessionManager, - jobId, - statementLifecycleManager, queryResultWriter, Duration(10, MINUTES), inactivityLimit, @@ -1118,36 +1218,53 @@ class FlintREPLTest } test("queryLoop handle exceptions within the loop gracefully") { - val mockReader = mock[FlintReader] - val osClient = mock[OSClient] - when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) - .thenReturn(mockReader) - // Simulate an exception thrown when hasNext is called - when(mockReader.hasNext).thenThrow(new RuntimeException("Test exception")) - val resultIndex = "testResultIndex" val dataSource = "testDataSource" val sessionIndex = "testSessionIndex" val sessionId = "testSessionId" - val jobId = "testJobId" + + val mockReader = mock[FlintReader] + val mockOSClient = mock[OSClient] + val getResponse = mock[GetResponse] + + when(mockOSClient.getDoc(*, *)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(true) + when(getResponse.getSourceAsMap).thenReturn( + Map[String, Object]( + "applicationId" -> applicationId, + "jobId" -> jobId, + "sessionId" -> sessionId, + "lastUpdateTime" -> java.lang.Long.valueOf(12345L), + "error" -> "someError", + "state" -> "running", + "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava) + + when( + mockOSClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) + .thenReturn(mockReader) + // Simulate an exception thrown when hasNext is called + when(mockReader.hasNext).thenThrow(new RuntimeException("Test exception")) + when(mockOSClient.doesIndexExist(*)).thenReturn(true) + when(mockOSClient.getIndexMetadata(*)).thenReturn(FlintREPL.resultIndexMapping) val inactivityLimit = 500 // 500 milliseconds // Create a SparkSession for testing val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() - val sessionManager = mock[SessionManager] - val statementLifecycleManager = mock[StatementExecutionManager] - val queryResultWriter = mock[QueryResultWriter] - val flintSessionIndexUpdater = mock[OpenSearchUpdater] + spark.conf.set(FlintSparkConf.REQUEST_INDEX.key, sessionIndex) + val sessionManager = new SessionManagerImpl(spark, Some(resultIndex)) { + override val osClient: OSClient = mockOSClient + } + val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( + applicationId, + jobId, spark, dataSource, sessionId, sessionManager, - jobId, - statementLifecycleManager, queryResultWriter, Duration(10, MINUTES), inactivityLimit, @@ -1172,16 +1289,25 @@ class FlintREPLTest } } - ignore("queryLoop should correctly update loop control variables") { + test("queryLoop should correctly update loop control variables") { + val resultIndex = "testResultIndex" + val dataSource = 
"testDataSource" + val sessionIndex = "testSessionIndex" + val sessionId = "testSessionId" + val mockReader = mock[FlintReader] - val osClient = mock[OSClient] - when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) + val mockOSClient = mock[OSClient] + when( + mockOSClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) .thenReturn(mockReader) val getResponse = mock[GetResponse] - when(osClient.getDoc(*, *)).thenReturn(getResponse) + when(mockOSClient.getDoc(*, *)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(false) - when(osClient.doesIndexExist(*)).thenReturn(true) - when(osClient.getIndexMetadata(*)).thenReturn(FlintREPL.resultIndexMapping) + when(mockOSClient.doesIndexExist(*)).thenReturn(true) + when(mockOSClient.getIndexMetadata(*)).thenReturn(FlintREPL.resultIndexMapping) + + val mockOpenSearchUpdater = mock[OpenSearchUpdater] + doNothing().when(mockOpenSearchUpdater).upsert(any[String], any[String]) // Configure mockReader to return true once and then false to exit the loop when(mockReader.hasNext).thenReturn(true).thenReturn(false) @@ -1197,12 +1323,6 @@ class FlintREPLTest """ when(mockReader.next).thenReturn(command) - val resultIndex = "testResultIndex" - val dataSource = "testDataSource" - val sessionIndex = "testSessionIndex" - val sessionId = "testSessionId" - val jobId = "testJobId" - val inactivityLimit = 5000 // 5 seconds // Create a SparkSession for testing\ @@ -1217,22 +1337,26 @@ class FlintREPLTest when(mockSparkSession.conf).thenReturn(mockConf) when(mockSparkSession.conf.get(FlintSparkConf.JOB_TYPE.key)) .thenReturn(FlintSparkConf.JOB_TYPE.defaultValue.get) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) + when(mockSparkSession.conf.get(FlintSparkConf.CUSTOM_STATEMENT_MANAGER.key, "")) + .thenReturn("") when(expectedDataFrame.toDF(any[Seq[String]]: _*)).thenReturn(expectedDataFrame) - val flintSessionIndexUpdater = mock[OpenSearchUpdater] - - val sessionManager = mock[SessionManager] - val statementLifecycleManager = mock[StatementExecutionManager] + val sessionManager = new SessionManagerImpl(mockSparkSession, Some(resultIndex)) { + override val osClient: OSClient = mockOSClient + override lazy val flintSessionIndexUpdater: OpenSearchUpdater = mockOpenSearchUpdater + } val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( + applicationId, + jobId, mockSparkSession, dataSource, sessionId, sessionManager, - jobId, - statementLifecycleManager, queryResultWriter, Duration(10, MINUTES), inactivityLimit, @@ -1248,7 +1372,10 @@ class FlintREPLTest // Assuming processCommands updates the lastActivityTime to the current time assert(endTime - startTime >= inactivityLimit) - verify(osClient, times(1)).getIndexMetadata(*) + + val expectedCalls = + Math.ceil(inactivityLimit.toDouble / DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY).toInt + verify(mockOSClient, Mockito.atMost(expectedCalls)).getIndexMetadata(*) } val testCases = Table( @@ -1257,40 +1384,54 @@ class FlintREPLTest (100, 300L) // 100 ms, 300 ms ) - ignore( + test( "queryLoop should execute loop without processing any commands for different inactivity limits and frequencies") { forAll(testCases) { (inactivityLimit, queryLoopExecutionFrequency) => - val mockReader = mock[FlintReader] - val osClient = mock[OSClient] - when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) - .thenReturn(mockReader) - val getResponse = 
mock[GetResponse] - when(osClient.getDoc(*, *)).thenReturn(getResponse) - when(getResponse.isExists()).thenReturn(false) - when(mockReader.hasNext).thenReturn(false) - val resultIndex = "testResultIndex" val dataSource = "testDataSource" val sessionIndex = "testSessionIndex" val sessionId = "testSessionId" - val jobId = "testJobId" - val sessionManager = mock[SessionManager] - val statementLifecycleManager = mock[StatementExecutionManager] - val queryResultWriter = mock[QueryResultWriter] + val mockReader = mock[FlintReader] + val mockOSClient = mock[OSClient] + when( + mockOSClient + .createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) + .thenReturn(mockReader) + val getResponse = mock[GetResponse] + when(mockOSClient.getDoc(*, *)).thenReturn(getResponse) + when(mockOSClient.doesIndexExist(*)).thenReturn(true) + when(mockOSClient.getIndexMetadata(*)).thenReturn(FlintREPL.resultIndexMapping) + + when(getResponse.isExists()).thenReturn(false) + when(getResponse.getSourceAsMap).thenReturn( + Map[String, Object]( + "applicationId" -> applicationId, + "jobId" -> jobId, + "sessionId" -> sessionId, + "lastUpdateTime" -> java.lang.Long.valueOf(12345L), + "error" -> "someError", + "state" -> "running", + "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava) + + when(mockReader.hasNext).thenReturn(false) // Create a SparkSession for testing val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() - val flintSessionIndexUpdater = mock[OpenSearchUpdater] + spark.conf.set(FlintSparkConf.REQUEST_INDEX.key, sessionIndex) + val sessionManager = new SessionManagerImpl(spark, Some(resultIndex)) { + override val osClient: OSClient = mockOSClient + } + val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( + applicationId, + jobId, spark, dataSource, sessionId, sessionManager, - jobId, - statementLifecycleManager, queryResultWriter, Duration(10, MINUTES), inactivityLimit,