From cd8120364fe4124ad4537ed5f98d1a5011af025a Mon Sep 17 00:00:00 2001 From: Shri Saran Raj N Date: Mon, 9 Dec 2024 23:11:11 +0530 Subject: [PATCH] Implement FlintJob to handle all query types in WP mode Signed-off-by: Shri Saran Raj N --- .../flint/core/metrics/MetricConstants.java | 48 +- .../sql/flint/config/FlintSparkConf.scala | 10 + .../scala/org/apache/spark/sql/FlintJob.scala | 432 +++++++++++++++++- .../apache/spark/sql/FlintJobExecutor.scala | 6 + .../org/apache/spark/sql/FlintREPL.scala | 10 +- .../org/apache/spark/sql/JobOperator.scala | 53 ++- .../org/apache/spark/sql/WarmpoolJob.scala | 410 +++++++++++++++++ .../org/apache/spark/sql/FlintJobTest.scala | 305 ++++++++++++- 8 files changed, 1228 insertions(+), 46 deletions(-) create mode 100644 spark-sql-application/src/main/scala/org/apache/spark/sql/WarmpoolJob.scala diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java index 79e70b8c2..b38394b58 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java @@ -135,10 +135,56 @@ public final class MetricConstants { */ public static final String STREAMING_HEARTBEAT_FAILED_METRIC = "streaming.heartbeat.failed.count"; + /** + * Metric for tracking the count of streaming jobs failed during query execution + */ + public static final String STREAMING_EXECUTION_FAILED_METRIC = "streaming.execution.failed.count"; + + /** + * Metric for tracking the count of streaming jobs failed during query result write + */ + public static final String STREAMING_RESULT_WRITER_FAILED_METRIC = "streaming.writer.failed.count"; + /** * Metric for tracking the latency of query execution (start to complete query execution) excluding result write. */ - public static final String QUERY_EXECUTION_TIME_METRIC = "query.execution.processingTime"; + public static final String QUERY_EXECUTION_TIME_METRIC = "streaming.query.execution.processingTime"; + + /** + * Metric for tracking the latency of query result write only (excluding query execution) + */ + public static final String QUERY_RESULT_WRITER_TIME_METRIC = "streaming.result.writer.processingTime"; + + /** + * Metric for tracking the latency of query total execution including result write. + */ + public static final String QUERY_TOTAL_TIME_METRIC = "streaming.query.total.processingTime"; + + /** + * Metric for tracking the latency of query execution (start to complete query execution) excluding result write. + */ + public static final String STATEMENT_QUERY_EXECUTION_TIME_METRIC = "statement.query.execution.processingTime"; + + /** + * Metric for tracking the latency of query result write only (excluding query execution) + */ + public static final String STATEMENT_RESULT_WRITER_TIME_METRIC = "statement.result.writer.processingTime"; + + /** + * Metric for tracking the latency of query total execution including result write. 
+ */ + public static final String STATEMENT_QUERY_TOTAL_TIME_METRIC = "statement.query.total.processingTime"; + + /** + * Metric for tracking the count of interactive jobs failed during query execution + */ + public static final String STATEMENT_EXECUTION_FAILED_METRIC = "statement.execution.failed.count"; + + /** + * Metric for tracking the count of interactive jobs failed during query result write + */ + public static final String STATEMENT_RESULT_WRITER_FAILED_METRIC = "statement.writer.failed.count"; + /** * Metric for query count of each query type (DROP/VACUUM/ALTER/REFRESH/CREATE INDEX) diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala index 364a8a1de..4d5164d7c 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala @@ -171,6 +171,12 @@ object FlintSparkConf { .doc("Enable external scheduler for index refresh") .createWithDefault("false") + val WARMPOOL_ENABLED = + FlintConfig("spark.flint.job.warmpoolEnabled") + .createWithDefault("false") + + val MAX_EXECUTORS_COUNT = FlintConfig("spark.dynamicAllocation.maxExecutors").createOptional() + val EXTERNAL_SCHEDULER_INTERVAL_THRESHOLD = FlintConfig("spark.flint.job.externalScheduler.interval") .doc("Interval threshold in minutes for external scheduler to trigger index refresh") @@ -246,6 +252,10 @@ object FlintSparkConf { FlintConfig(s"spark.flint.job.requestIndex") .doc("Request index") .createOptional() + val RESULT_INDEX = + FlintConfig(s"spark.flint.job.resultIndex") + .doc("Result index") + .createOptional() val EXCLUDE_JOB_IDS = FlintConfig(s"spark.flint.deployment.excludeJobs") .doc("Exclude job ids") diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala index 04609cf3d..5bb08dd13 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala @@ -8,12 +8,22 @@ package org.apache.spark.sql import java.util.concurrent.atomic.AtomicInteger +import scala.concurrent.{ExecutionContext, Future, TimeoutException} +import scala.concurrent.duration._ +import scala.util.control.NonFatal + +import com.codahale.metrics.Timer +import org.opensearch.flint.common.model.FlintStatement +import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.logging.CustomLogging -import org.opensearch.flint.core.metrics.MetricConstants -import org.opensearch.flint.core.metrics.MetricsUtil.registerGauge +import org.opensearch.flint.core.metrics.{MetricConstants, MetricsUtil} +import org.opensearch.flint.core.metrics.MetricsUtil.{getTimerContext, incrementCounter, registerGauge, stopTimer} import org.apache.spark.internal.Logging +import org.apache.spark.sql.FlintREPL.{exponentialBackoffRetry, handleCommandFailureAndGetFailedData, handleCommandTimeout} +import org.apache.spark.sql.FlintREPLConfConstants.{DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY, DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS, MAPPING_CHECK_TIMEOUT} import org.apache.spark.sql.flint.config.FlintSparkConf +import org.apache.spark.util.ThreadUtils /** * Spark SQL Application entrypoint @@ -26,24 +36,176 @@ import org.apache.spark.sql.flint.config.FlintSparkConf * write sql query result to 
given opensearch index */ object FlintJob extends Logging with FlintJobExecutor { + + private val statementRunningCount = new AtomicInteger(0) + private val streamingRunningCount = new AtomicInteger(0) + def main(args: Array[String]): Unit = { val (queryOption, resultIndexOption) = parseArgs(args) val conf = createSparkConf() - val jobType = conf.get("spark.flint.job.type", FlintJobType.BATCH) - CustomLogging.logInfo(s"""Job type is: ${jobType}""") - conf.set(FlintSparkConf.JOB_TYPE.key, jobType) - - val dataSource = conf.get("spark.flint.datasource.name", "") - val query = queryOption.getOrElse(unescapeQuery(conf.get(FlintSparkConf.QUERY.key, ""))) - if (query.isEmpty) { - logAndThrow(s"Query undefined for the ${jobType} job.") + val sparkSession = createSparkSession(conf) + val applicationId = + environmentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown") + val jobId = environmentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown") + val segmentName = getSegmentName(sparkSession) + val warmpoolEnabled = conf.get(FlintSparkConf.WARMPOOL_ENABLED.key, "false").toBoolean + logInfo(s"WarmpoolEnabled: ${warmpoolEnabled}") + + if (!warmpoolEnabled) { + val jobType = sparkSession.conf.get("spark.flint.job.type", FlintJobType.BATCH) + CustomLogging.logInfo(s"""Job type is: ${jobType}""") + sparkSession.conf.set(FlintSparkConf.JOB_TYPE.key, jobType) + + val dataSource = conf.get("spark.flint.datasource.name", "") + val query = queryOption.getOrElse(unescapeQuery(conf.get(FlintSparkConf.QUERY.key, ""))) + if (query.isEmpty) { + logAndThrow(s"Query undefined for the ${jobType} job.") + } + val queryId = conf.get(FlintSparkConf.QUERY_ID.key, "") + + if (resultIndexOption.isEmpty) { + logAndThrow("resultIndex is not set") + } + + processStreamingJob( + applicationId, + jobId, + query, + queryId, + dataSource, + resultIndexOption.get, + jobType, + sparkSession, + Map.empty) + } else { + // Read the values from the Spark configuration or fall back to the default values + val inactivityLimitMillis: Long = + conf.getLong( + FlintSparkConf.REPL_INACTIVITY_TIMEOUT_MILLIS.key, + FlintOptions.DEFAULT_INACTIVITY_LIMIT_MILLIS) + val queryWaitTimeoutMillis: Long = + conf.getLong("spark.flint.job.queryWaitTimeoutMillis", DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS) + val queryLoopExecutionFrequency: Long = + conf.getLong( + "spark.flint.job.queryLoopExecutionFrequency", + DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) + + val sessionManager = FlintREPL.instantiateSessionManager(sparkSession, resultIndexOption) + val commandContext = CommandContext( + applicationId, + jobId, + sparkSession, + "", // In WP flow, FlintJob doesn't know the dataSource yet + "", // In WP flow, FlintJob doesn't know the jobType yet + "", // FlintJob doesn't use sessionId + sessionManager, + Duration.Inf, // FlintJob doesn't have queryExecutionTimeout + inactivityLimitMillis, + queryWaitTimeoutMillis, // Used only for interactive queries + queryLoopExecutionFrequency) + registerGauge( + String.format("%s.%s", segmentName, MetricConstants.STATEMENT_RUNNING_METRIC), + statementRunningCount) + registerGauge( + String.format("%s.%s", segmentName, MetricConstants.STREAMING_RUNNING_METRIC), + streamingRunningCount) + try { + exponentialBackoffRetry(maxRetries = 5, initialDelay = 2.seconds) { + queryLoop(commandContext, segmentName) + } + } finally { + sparkSession.stop() + + // After handling any exceptions from stopping the Spark session, + // check if there's a stored exception and throw it if it's an UnrecoverableException + 
checkAndThrowUnrecoverableExceptions() + + // Check for non-daemon threads that may prevent the driver from shutting down. + // Non-daemon threads other than the main thread indicate that the driver is still processing tasks, + // which may be due to unresolved bugs in dependencies or threads not being properly shut down. + if (terminateJVM && threadPoolFactory.hasNonDaemonThreadsOtherThanMain) { + logInfo("A non-daemon thread in the driver is seen.") + // Exit the JVM to prevent resource leaks and potential emr-s job hung. + // A zero status code is used for a graceful shutdown without indicating an error. + // If exiting with non-zero status, emr-s job will fail. + // This is a part of the fault tolerance mechanism to handle such scenarios gracefully + System.exit(0) + } + } } - val queryId = conf.get(FlintSparkConf.QUERY_ID.key, "") + } + + def queryLoop(commandContext: CommandContext, segmentName: String): Unit = { + import commandContext._ + + val statementExecutionManager = FlintREPL.instantiateStatementExecutionManager(commandContext) + var canProceed = true + + try { + var lastActivityTime = currentTimeProvider.currentEpochMillis() + while (currentTimeProvider + .currentEpochMillis() - lastActivityTime <= commandContext.inactivityLimitMillis && canProceed) { + statementExecutionManager.getNextStatement() match { + case Some(flintStatement) => + flintStatement.running() + statementExecutionManager.updateStatement(flintStatement) + val jobType = spark.conf.get(FlintSparkConf.JOB_TYPE.key, FlintJobType.BATCH) + val dataSource = spark.conf.get(FlintSparkConf.DATA_SOURCE_NAME.key) + val resultIndex = spark.conf.get(FlintSparkConf.RESULT_INDEX.key) + val postQuerySelectionCommandContext = + commandContext.copy(dataSource = dataSource, jobType = jobType) - if (resultIndexOption.isEmpty) { - logAndThrow("resultIndex is not set") + CustomLogging.logInfo(s"""Job type is: ${jobType}""") + val queryResultWriter = FlintREPL.instantiateQueryResultWriter(spark, commandContext) + if (jobType.equalsIgnoreCase(FlintJobType.STREAMING) || jobType.equalsIgnoreCase( + FlintJobType.BATCH)) { + processStreamingJob( + applicationId, + jobId, + flintStatement.query, + flintStatement.queryId, + dataSource, + resultIndex, + jobType, + spark, + flintStatement.context) + } else { + processInteractiveJob( + spark, + postQuerySelectionCommandContext, + flintStatement, + segmentName, + statementExecutionManager, + queryResultWriter) + + // last query finish time is last activity time + lastActivityTime = currentTimeProvider.currentEpochMillis() + } + case _ => + canProceed = false + } + } + } catch { + case t: Throwable => + throwableHandler.recordThrowable(s"Query loop execution failed.", t) + throw t + } finally { + statementExecutionManager.terminateStatementExecution() } + Thread.sleep(commandContext.queryLoopExecutionFrequency) + } + + private def processStreamingJob( + applicationId: String, + jobId: String, + query: String, + queryId: String, + dataSource: String, + resultIndex: String, + jobType: String, + sparkSession: SparkSession, + executionContext: Map[String, Any]): Unit = { // https://github.com/opensearch-project/opensearch-spark/issues/138 /* * To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`, @@ -52,26 +214,250 @@ object FlintJob extends Logging with FlintJobExecutor { * By doing this, we effectively map `my_glue1` to AWS Glue, allowing Spark to resolve the database and table names correctly. 
* Without this setup, Spark would not recognize names in the format `my_glue1.default`. */ - conf.set("spark.sql.defaultCatalog", dataSource) - configDYNMaxExecutors(conf, jobType) - - val applicationId = - environmentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown") - val jobId = environmentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown") + sparkSession.conf.set("spark.sql.defaultCatalog", dataSource) + val segmentName = sparkSession.conf.get("spark.dynamicAllocation.maxExecutors") val streamingRunningCount = new AtomicInteger(0) val jobOperator = JobOperator( applicationId, jobId, - createSparkSession(conf), + sparkSession, query, queryId, dataSource, - resultIndexOption.get, + resultIndex, jobType, - streamingRunningCount) - registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount) + streamingRunningCount, + executionContext) + registerGauge( + String.format("%s.%s", segmentName, MetricConstants.STREAMING_RUNNING_METRIC), + streamingRunningCount) jobOperator.start() } + + def processInteractiveJob( + sparkSession: SparkSession, + commandContext: CommandContext, + flintStatement: FlintStatement, + segmentName: String, + statementExecutionManager: StatementExecutionManager, + queryResultWriter: QueryResultWriter): Unit = { + import commandContext._ + + var dataToWrite: Option[DataFrame] = None + val startTime: Long = currentTimeProvider.currentEpochMillis() + + statementRunningCount.incrementAndGet() + val statementTimerContext = getTimerContext(MetricConstants.STATEMENT_PROCESSING_TIME_METRIC) + implicit val ec: ExecutionContext = ExecutionContext.global + + val futurePrepareQueryExecution = Future { + statementExecutionManager.prepareStatementExecution() + } + + try { + ThreadUtils.awaitResult(futurePrepareQueryExecution, MAPPING_CHECK_TIMEOUT) match { + case Right(_) => + dataToWrite = executeAndHandleInteractiveJob( + sparkSession, + commandContext, + flintStatement, + segmentName, + startTime, + statementExecutionManager, + queryResultWriter) + case Left(error) => + dataToWrite = Some( + handleCommandFailureAndGetFailedData( + applicationId, + jobId, + sparkSession, + dataSource, + error, + flintStatement, + "", + startTime)) + } + } catch { + case e: TimeoutException => + val error = s"Query execution preparation timed out" + CustomLogging.logError(error, e) + dataToWrite = Some( + handleCommandTimeout( + applicationId, + jobId, + spark, + dataSource, + error, + flintStatement, + "", + startTime)) + case NonFatal(e) => + val error = s"An unexpected error occurred: ${e.getMessage}" + throwableHandler.recordThrowable(error, e) + dataToWrite = Some( + handleCommandFailureAndGetFailedData( + applicationId, + jobId, + spark, + dataSource, + error, + flintStatement, + sessionId, + startTime)) + } finally { + emitTimeMetric( + startTime, + segmentName, + MetricConstants.STATEMENT_QUERY_EXECUTION_TIME_METRIC) + finalizeCommand( + statementExecutionManager, + queryResultWriter, + dataToWrite, + flintStatement, + segmentName, + statementTimerContext) + emitTimeMetric(startTime, segmentName, MetricConstants.STATEMENT_QUERY_TOTAL_TIME_METRIC) + } + } + + def executeAndHandleInteractiveJob( + sparkSession: SparkSession, + commandContext: CommandContext, + flintStatement: FlintStatement, + segmentName: String, + startTime: Long, + statementExecutionManager: StatementExecutionManager, + queryResultWriter: QueryResultWriter): Option[DataFrame] = { + import commandContext._ + + try { + if (currentTimeProvider + .currentEpochMillis() - 
flintStatement.submitTime > queryWaitTimeMillis) { + Some( + handleCommandFailureAndGetFailedData( + applicationId, + jobId, + sparkSession, + dataSource, + "wait timeout", + flintStatement, + "", // FlintJob doesn't use sessionId + startTime)) + } else { + // Execute the statement and get the resulting DataFrame + // This step may involve Spark transformations, but not necessarily actions + val df = statementExecutionManager.executeStatement(flintStatement) + // Process the DataFrame, applying any necessary transformations + // and triggering Spark actions to materialize the results + // This is where the actual data processing occurs + Some(queryResultWriter.processDataFrame(df, flintStatement, startTime)) + } + } catch { + case e: TimeoutException => + incrementCounter( + String.format("%s.%s", segmentName, MetricConstants.STATEMENT_EXECUTION_FAILED_METRIC)) + val error = s"Query execution preparation timed out" + CustomLogging.logError(error, e) + Some( + FlintREPL.handleCommandTimeout( + applicationId, + jobId, + sparkSession, + dataSource, + error, + flintStatement, + "", // FlintJob doesn't use sessionId + startTime)) + case t: Throwable => + incrementCounter( + String.format("%s.%s", segmentName, MetricConstants.STATEMENT_EXECUTION_FAILED_METRIC)) + val error = FlintREPL.processQueryException(t, flintStatement) + CustomLogging.logError(error, t) + Some( + FlintREPL.handleCommandFailureAndGetFailedData( + applicationId, + jobId, + sparkSession, + dataSource, + error, + flintStatement, + "", // FlintJob doesn't use sessionId + startTime)) + } + } + + /** + * finalize statement after processing + * + * @param dataToWrite + * data to write + * @param flintStatement + * flint statement + */ + private def finalizeCommand( + statementExecutionManager: StatementExecutionManager, + queryResultWriter: QueryResultWriter, + dataToWrite: Option[DataFrame], + flintStatement: FlintStatement, + segmentName: String, + statementTimerContext: Timer.Context): Unit = { + val resultWriterStartTime: Long = currentTimeProvider.currentEpochMillis() + try { + dataToWrite.foreach(df => queryResultWriter.writeDataFrame(df, flintStatement)) + if (flintStatement.isRunning || flintStatement.isWaiting) { + flintStatement.complete() + } + } catch { + case t: Throwable => + incrementCounter( + String + .format("%s.%s", segmentName, MetricConstants.STATEMENT_RESULT_WRITER_FAILED_METRIC)) + val error = + s"""Fail to write result of ${flintStatement}, cause: ${throwableHandler.error}""" + throwableHandler.recordThrowable(error, t) + CustomLogging.logError(error, t) + flintStatement.fail() + } finally { + if (throwableHandler.hasException) flintStatement.fail() else flintStatement.complete() + flintStatement.error = Some(throwableHandler.error) + + emitTimeMetric( + resultWriterStartTime, + segmentName, + MetricConstants.STATEMENT_RESULT_WRITER_TIME_METRIC) + statementExecutionManager.updateStatement(flintStatement) + recordStatementStateChange( + statementRunningCount, + flintStatement, + statementTimerContext, + segmentName) + } + } + + private def recordStatementStateChange( + statementRunningCount: AtomicInteger, + flintStatement: FlintStatement, + statementTimerContext: Timer.Context, + segmentName: String): Unit = { + + stopTimer(statementTimerContext) + if (statementRunningCount.get() > 0) { + statementRunningCount.decrementAndGet() + } + + if (flintStatement.isComplete) { + incrementCounter( + String.format("%s.%s", segmentName, MetricConstants.STATEMENT_SUCCESS_METRIC)) + } else if (flintStatement.isFailed) { 
+ incrementCounter( + String.format("%s.%s", segmentName, MetricConstants.STATEMENT_FAILED_METRIC)) + } + } + + private def emitTimeMetric(startTime: Long, segmentName: String, metricType: String): Unit = { + val metricName = String.format("%s.%s", segmentName, metricType) + MetricsUtil.addHistoricGauge(metricName, System.currentTimeMillis() - startTime) + } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala index ad26cf21a..586f25f0c 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -547,6 +547,12 @@ trait FlintJobExecutor { } } + def getSegmentName(sparkSession: SparkSession): String = { + val maxExecutorsCount = + sparkSession.conf.get(FlintSparkConf.MAX_EXECUTORS_COUNT.key, "unknown") + String.format("%se", maxExecutorsCount) + } + def instantiate[T](defaultConstructor: => T, className: String, args: Any*): T = { if (Strings.isNullOrEmpty(className)) { defaultConstructor diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 6d7dcc0e7..0c50e365d 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -610,7 +610,7 @@ object FlintREPL extends Logging with FlintJobExecutor { } } - private def handleCommandTimeout( + def handleCommandTimeout( applicationId: String, jobId: String, spark: SparkSession, @@ -618,7 +618,7 @@ object FlintREPL extends Logging with FlintJobExecutor { error: String, flintStatement: FlintStatement, sessionId: String, - startTime: Long) = { + startTime: Long): DataFrame = { /* * https://tinyurl.com/2ezs5xj9 * @@ -1021,7 +1021,7 @@ object FlintREPL extends Logging with FlintJobExecutor { } } - private def instantiateSessionManager( + def instantiateSessionManager( spark: SparkSession, resultIndexOption: Option[String]): SessionManager = { instantiate( @@ -1030,7 +1030,7 @@ object FlintREPL extends Logging with FlintJobExecutor { resultIndexOption.getOrElse("")) } - private def instantiateStatementExecutionManager( + def instantiateStatementExecutionManager( commandContext: CommandContext): StatementExecutionManager = { import commandContext._ instantiate( @@ -1040,7 +1040,7 @@ object FlintREPL extends Logging with FlintJobExecutor { sessionId) } - private def instantiateQueryResultWriter( + def instantiateQueryResultWriter( spark: SparkSession, commandContext: CommandContext): QueryResultWriter = { instantiate( diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala index 27b0be84f..3d9332999 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala @@ -32,7 +32,8 @@ case class JobOperator( dataSource: String, resultIndex: String, jobType: String, - streamingRunningCount: AtomicInteger) + streamingRunningCount: AtomicInteger, + statementContext: Map[String, Any] = Map.empty[String, Any]) extends Logging with FlintJobExecutor { @@ -42,6 +43,7 @@ case class JobOperator( def start(): Unit = { val threadPool = ThreadUtils.newDaemonFixedThreadPool(1, "check-create-index") 
implicit val executionContext = ExecutionContext.fromExecutor(threadPool) + val segmentName = getSegmentName(sparkSession) var dataToWrite: Option[DataFrame] = None @@ -80,7 +82,9 @@ case class JobOperator( "", queryId, LangType.SQL, - currentTimeProvider.currentEpochMillis()) + currentTimeProvider.currentEpochMillis(), + Option.empty, + statementContext) try { val futurePrepareQueryExecution = Future { @@ -107,6 +111,8 @@ case class JobOperator( } catch { case e: TimeoutException => throwableHandler.recordThrowable(s"Preparation for query execution timed out", e) + incrementCounter( + String.format("%s.%s", segmentName, MetricConstants.STREAMING_EXECUTION_FAILED_METRIC)) dataToWrite = Some( constructErrorDF( applicationId, @@ -120,6 +126,8 @@ case class JobOperator( "", startTime)) case t: Throwable => + incrementCounter( + String.format("%s.%s", segmentName, MetricConstants.STREAMING_EXECUTION_FAILED_METRIC)) val error = processQueryException(t) dataToWrite = Some( constructErrorDF( @@ -134,10 +142,11 @@ case class JobOperator( "", startTime)) } finally { - emitQueryExecutionTimeMetric(startTime) + emitTimeMetric(startTime, segmentName, MetricConstants.QUERY_EXECUTION_TIME_METRIC) readWriteBytesSparkListener.emitMetrics() sparkSession.sparkContext.removeSparkListener(readWriteBytesSparkListener) + val resultWriteStartTime = System.currentTimeMillis() try { dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient)) } catch { @@ -145,6 +154,14 @@ case class JobOperator( throwableHandler.recordThrowable( s"Failed to write to result index. originalError='${throwableHandler.error}'", t) + incrementCounter(String + .format("%s.%s", segmentName, MetricConstants.STREAMING_RESULT_WRITER_FAILED_METRIC)) + } finally { + emitTimeMetric( + resultWriteStartTime, + segmentName, + MetricConstants.QUERY_RESULT_WRITER_TIME_METRIC) + emitTimeMetric(startTime, segmentName, MetricConstants.QUERY_TOTAL_TIME_METRIC) } if (throwableHandler.hasException) statement.fail() else statement.complete() statement.error = Some(throwableHandler.error) @@ -158,11 +175,11 @@ case class JobOperator( t) } - cleanUpResources(threadPool) + cleanUpResources(threadPool, segmentName) } } - def cleanUpResources(threadPool: ThreadPoolExecutor): Unit = { + def cleanUpResources(threadPool: ThreadPoolExecutor, segmentName: String): Unit = { val isStreaming = jobType.equalsIgnoreCase(FlintJobType.STREAMING) try { // Wait for streaming job complete if no error @@ -190,7 +207,7 @@ case class JobOperator( } catch { case e: Exception => logError("Fail to close threadpool", e) } - recordStreamingCompletionStatus(throwableHandler.hasException) + recordStreamingCompletionStatus(throwableHandler.hasException, segmentName) // Check for non-daemon threads that may prevent the driver from shutting down. // Non-daemon threads other than the main thread indicate that the driver is still processing tasks, @@ -205,13 +222,6 @@ case class JobOperator( } } - private def emitQueryExecutionTimeMetric(startTime: Long): Unit = { - MetricsUtil - .addHistoricGauge( - MetricConstants.QUERY_EXECUTION_TIME_METRIC, - System.currentTimeMillis() - startTime) - } - def stop(): Unit = { Try { logInfo("Stopping Spark session") @@ -236,15 +246,21 @@ case class JobOperator( * @param exceptionThrown * Indicates whether an exception was thrown during the streaming job execution. 
*/ - private def recordStreamingCompletionStatus(exceptionThrown: Boolean): Unit = { + private def recordStreamingCompletionStatus( + exceptionThrown: Boolean, + segmentName: String): Unit = { // Decrement the metric for running streaming jobs as the job is now completing. if (streamingRunningCount.get() > 0) { streamingRunningCount.decrementAndGet() } exceptionThrown match { - case true => incrementCounter(MetricConstants.STREAMING_FAILED_METRIC) - case false => incrementCounter(MetricConstants.STREAMING_SUCCESS_METRIC) + case true => + incrementCounter( + String.format("%s.%s", segmentName, MetricConstants.STREAMING_FAILED_METRIC)) + case false => + incrementCounter( + String.format("%s.%s", segmentName, MetricConstants.STREAMING_SUCCESS_METRIC)) } } @@ -259,4 +275,9 @@ case class JobOperator( spark, sessionId) } + + private def emitTimeMetric(startTime: Long, segmentName: String, metricType: String): Unit = { + val metricName = String.format("%s.%s", segmentName, metricType) + MetricsUtil.addHistoricGauge(metricName, System.currentTimeMillis() - startTime) + } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/WarmpoolJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/WarmpoolJob.scala new file mode 100644 index 000000000..f83c3d969 --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/WarmpoolJob.scala @@ -0,0 +1,410 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import java.util.concurrent.atomic.AtomicInteger + +import scala.concurrent.{ExecutionContext, Future, TimeoutException} +import scala.concurrent.duration._ +import scala.util.control.NonFatal + +import com.codahale.metrics.Timer +import org.opensearch.flint.common.model.FlintStatement +import org.opensearch.flint.core.FlintOptions +import org.opensearch.flint.core.logging.CustomLogging +import org.opensearch.flint.core.metrics.{MetricConstants, MetricsUtil} +import org.opensearch.flint.core.metrics.MetricsUtil.{getTimerContext, registerGauge, stopTimer} + +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging +import org.apache.spark.sql.WarmpoolJobConfConstants._ +import org.apache.spark.sql.flint.config.FlintSparkConf +import org.apache.spark.util.ThreadUtils + +object WarmpoolJobConfConstants { + val MAPPING_CHECK_TIMEOUT = Duration(1, MINUTES) + val DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS = 10 * 60 * 1000 + val DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY = 100L +} + +case class WarmpoolJob( + conf: SparkConf, + sparkSession: SparkSession, + resultIndexOption: Option[String]) + extends Logging + with FlintJobExecutor { + + private val statementRunningCount = new AtomicInteger(0) + private val streamingRunningCount = new AtomicInteger(0) + private val segmentName = getSegmentName(sparkSession) + + def start(): Unit = { + // Read the values from the Spark configuration or fall back to the default values + val inactivityLimitMillis: Long = + conf.getLong( + FlintSparkConf.REPL_INACTIVITY_TIMEOUT_MILLIS.key, + FlintOptions.DEFAULT_INACTIVITY_LIMIT_MILLIS) + val queryWaitTimeoutMillis: Long = + conf.getLong("spark.flint.job.queryWaitTimeoutMillis", DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS) + val queryLoopExecutionFrequency: Long = + conf.getLong( + "spark.flint.job.queryLoopExecutionFrequency", + DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) + + val sessionManager = instantiateSessionManager(sparkSession, resultIndexOption) + val commandContext = CommandContext( + 
environmentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"), + environmentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown"), + sparkSession, + "", // In WP flow, datasource is not known yet + "", // In WP flow, jobType is not know yet + "", // WP doesn't use sessionId + sessionManager, + Duration.Inf, // WP doesn't have queryExecutionTimeout + inactivityLimitMillis, + queryWaitTimeoutMillis, // Used only for interactive queries + queryLoopExecutionFrequency) + registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount) + try { + FlintREPL.exponentialBackoffRetry(maxRetries = 5, initialDelay = 2.seconds) { + queryLoop(commandContext) + } + } finally { + sparkSession.stop() + + // After handling any exceptions from stopping the Spark session, + // check if there's a stored exception and throw it if it's an UnrecoverableException + checkAndThrowUnrecoverableExceptions() + + // Check for non-daemon threads that may prevent the driver from shutting down. + // Non-daemon threads other than the main thread indicate that the driver is still processing tasks, + // which may be due to unresolved bugs in dependencies or threads not being properly shut down. + if (terminateJVM && threadPoolFactory.hasNonDaemonThreadsOtherThanMain) { + logInfo("A non-daemon thread in the driver is seen.") + // Exit the JVM to prevent resource leaks and potential emr-s job hung. + // A zero status code is used for a graceful shutdown without indicating an error. + // If exiting with non-zero status, emr-s job will fail. + // This is a part of the fault tolerance mechanism to handle such scenarios gracefully + System.exit(0) + } + } + } + + def queryLoop(commandContext: CommandContext): Unit = { + import commandContext._ + + val statementExecutionManager = instantiateStatementExecutionManager(commandContext) + var canProceed = true + + try { + var lastActivityTime = currentTimeProvider.currentEpochMillis() + + while (currentTimeProvider + .currentEpochMillis() - lastActivityTime <= commandContext.inactivityLimitMillis && canProceed) { + statementExecutionManager.getNextStatement() match { + case Some(flintStatement) => + flintStatement.running() + statementExecutionManager.updateStatement(flintStatement) + + val jobType = spark.conf.get(FlintSparkConf.JOB_TYPE.key, FlintJobType.BATCH) + val dataSource = spark.conf.get(FlintSparkConf.DATA_SOURCE_NAME.key) + val resultIndex = spark.conf.get(FlintSparkConf.RESULT_INDEX.key) + + val postQuerySelectionCommandContext = + commandContext.copy(dataSource = dataSource, jobType = jobType) + + CustomLogging.logInfo(s"""Job type is: ${jobType}""") + val queryResultWriter = instantiateQueryResultWriter(spark, commandContext) + + if (jobType.equalsIgnoreCase(FlintJobType.STREAMING) || jobType.equalsIgnoreCase( + FlintJobType.BATCH)) { + processStreamingJob( + applicationId, + jobId, + flintStatement.query, + flintStatement.queryId, + dataSource, + resultIndex, + jobType, + spark, + flintStatement.context) + } else { + processInteractiveJob( + spark, + postQuerySelectionCommandContext, + flintStatement, + statementExecutionManager, + queryResultWriter) + + // Last query finish time is last activity time + lastActivityTime = currentTimeProvider.currentEpochMillis() + } + + case _ => + canProceed = false + } + } + } catch { + case t: Throwable => + throwableHandler.recordThrowable(s"Query loop execution failed.", t) + throw t + } finally { + statementExecutionManager.terminateStatementExecution() + } + + 
Thread.sleep(commandContext.queryLoopExecutionFrequency) + } + + private def processStreamingJob( + applicationId: String, + jobId: String, + query: String, + queryId: String, + dataSource: String, + resultIndex: String, + jobType: String, + sparkSession: SparkSession, + executionContext: Map[String, Any]): Unit = { + + // https://github.com/opensearch-project/opensearch-spark/issues/138 + /* + * To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`, + * it's necessary to set `spark.sql.defaultCatalog=my_glue1`. This is because AWS Glue uses a single database (default) and table (http_logs_plain), + * and we need to configure Spark to recognize `my_glue1` as a reference to AWS Glue's database and table. + * By doing this, we effectively map `my_glue1` to AWS Glue, allowing Spark to resolve the database and table names correctly. + * Without this setup, Spark would not recognize names in the format `my_glue1.default`. + */ + sparkSession.conf.set("spark.sql.defaultCatalog", dataSource) + + val streamingRunningCount = new AtomicInteger(0) + val jobOperator = JobOperator( + applicationId, + jobId, + sparkSession, + query, + queryId, + dataSource, + resultIndex, + jobType, + streamingRunningCount, + executionContext) + + registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount) + jobOperator.start() + } + + private def processInteractiveJob( + sparkSession: SparkSession, + commandContext: CommandContext, + flintStatement: FlintStatement, + statementExecutionManager: StatementExecutionManager, + queryResultWriter: QueryResultWriter): Unit = { + + import commandContext._ + + var dataToWrite: Option[DataFrame] = None + val startTime: Long = currentTimeProvider.currentEpochMillis() + + statementRunningCount.incrementAndGet() + + val statementTimerContext = getTimerContext(MetricConstants.STATEMENT_PROCESSING_TIME_METRIC) + implicit val ec: ExecutionContext = ExecutionContext.global + + val futurePrepareQueryExecution = Future { + statementExecutionManager.prepareStatementExecution() + } + + try { + ThreadUtils.awaitResult(futurePrepareQueryExecution, MAPPING_CHECK_TIMEOUT) match { + case Right(_) => + dataToWrite = executeAndHandleInteractiveJob( + sparkSession, + commandContext, + flintStatement, + startTime, + statementExecutionManager, + queryResultWriter) + + case Left(error) => + dataToWrite = Some( + handleCommandFailureAndGetFailedData( + applicationId, + jobId, + sparkSession, + dataSource, + error, + flintStatement, + "", + startTime)) + } + } catch { + case e: TimeoutException => + val error = s"Query execution preparation timed out" + CustomLogging.logError(error, e) + dataToWrite = Some( + handleCommandTimeout( + applicationId, + jobId, + spark, + dataSource, + error, + flintStatement, + "", + startTime)) + + case NonFatal(e) => + val error = s"An unexpected error occurred: ${e.getMessage}" + throwableHandler.recordThrowable(error, e) + dataToWrite = Some( + handleCommandFailureAndGetFailedData( + applicationId, + jobId, + spark, + dataSource, + error, + flintStatement, + sessionId, + startTime)) + } finally { + emitTimeMetric(startTime, MetricConstants.STATEMENT_QUERY_EXECUTION_TIME_METRIC) + finalizeCommand( + statementExecutionManager, + queryResultWriter, + dataToWrite, + flintStatement, + statementTimerContext) + emitTimeMetric(startTime, MetricConstants.STATEMENT_QUERY_TOTAL_TIME_METRIC) + } + } + + private def executeAndHandleInteractiveJob( + sparkSession: SparkSession, + 
commandContext: CommandContext, + flintStatement: FlintStatement, + startTime: Long, + statementExecutionManager: StatementExecutionManager, + queryResultWriter: QueryResultWriter): Option[DataFrame] = { + + import commandContext._ + + try { + if (currentTimeProvider + .currentEpochMillis() - flintStatement.submitTime > queryWaitTimeMillis) { + Some( + handleCommandFailureAndGetFailedData( + applicationId, + jobId, + sparkSession, + dataSource, + "wait timeout", + flintStatement, + "", // WP doesn't use sessionId + startTime)) + } else { + // Execute the statement and get the resulting DataFrame + val df = statementExecutionManager.executeStatement(flintStatement) + // Process the DataFrame, applying any necessary transformations + // and triggering Spark actions to materialize the results + Some(queryResultWriter.processDataFrame(df, flintStatement, startTime)) + } + } catch { + case e: TimeoutException => + incrementCounter(MetricConstants.STATEMENT_EXECUTION_FAILED_METRIC) + val error = s"Query execution preparation timed out" + CustomLogging.logError(error, e) + Some( + handleCommandTimeout( + applicationId, + jobId, + sparkSession, + dataSource, + error, + flintStatement, + "", // WP doesn't use sessionId + startTime)) + + case t: Throwable => + incrementCounter(MetricConstants.STATEMENT_EXECUTION_FAILED_METRIC) + val error = FlintREPL.processQueryException(t, flintStatement) + CustomLogging.logError(error, t) + Some( + handleCommandFailureAndGetFailedData( + applicationId, + jobId, + sparkSession, + dataSource, + error, + flintStatement, + "", // WP doesn't use sessionId + startTime)) + } + } + + /** + * Finalize statement after processing + * + * @param dataToWrite + * Data to write + * @param flintStatement + * Flint statement + */ + private def finalizeCommand( + statementExecutionManager: StatementExecutionManager, + queryResultWriter: QueryResultWriter, + dataToWrite: Option[DataFrame], + flintStatement: FlintStatement, + statementTimerContext: Timer.Context): Unit = { + + val resultWriterStartTime: Long = currentTimeProvider.currentEpochMillis() + + try { + dataToWrite.foreach(df => queryResultWriter.writeDataFrame(df, flintStatement)) + + if (flintStatement.isRunning || flintStatement.isWaiting) { + flintStatement.complete() + } + } catch { + case t: Throwable => + incrementCounter(MetricConstants.STATEMENT_RESULT_WRITER_FAILED_METRIC) + val error = + s"""Fail to write result of ${flintStatement}, cause: ${throwableHandler.error}""" + throwableHandler.recordThrowable(error, t) + CustomLogging.logError(error, t) + flintStatement.fail() + } finally { + if (throwableHandler.hasException) flintStatement.fail() else flintStatement.complete() + flintStatement.error = Some(throwableHandler.error) + + emitTimeMetric(resultWriterStartTime, MetricConstants.STATEMENT_RESULT_WRITER_TIME_METRIC) + statementExecutionManager.updateStatement(flintStatement) + recordStatementStateChange(flintStatement, statementTimerContext) + } + } + + private def emitTimeMetric(startTime: Long, metricName: String): Unit = { + val metricNameWithSegment = String.format("%s.%s", segmentName, metricName) + MetricsUtil.addHistoricGauge(metricNameWithSegment, System.currentTimeMillis() - startTime) + } + + private def recordStatementStateChange( + flintStatement: FlintStatement, + statementTimerContext: Timer.Context): Unit = { + stopTimer(statementTimerContext) + if (statementRunningCount.get() > 0) { + statementRunningCount.decrementAndGet() + } + if (flintStatement.isComplete) { + 
incrementCounter(MetricConstants.STATEMENT_SUCCESS_METRIC) + } else if (flintStatement.isFailed) { + incrementCounter(MetricConstants.STATEMENT_FAILED_METRIC) + } + } + + private def incrementCounter(metricName: String) { + val metricWithSegmentName = String.format("%s.%s", segmentName, metricName) + MetricsUtil.incrementCounter(metricWithSegmentName) + } +} diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala index 339c1870d..5952859b3 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala @@ -5,14 +5,36 @@ package org.apache.spark.sql +import java.net.ConnectException + +import scala.collection.JavaConverters._ +import scala.concurrent.duration._ +import scala.concurrent.duration.{Duration, MINUTES} + +import org.mockito.ArgumentMatchersSugar +import org.mockito.Mockito.when +import org.opensearch.action.get.GetResponse +import org.opensearch.flint.core.storage.{FlintReader, OpenSearchReader} +import org.opensearch.search.sort.SortOrder +import org.scalatestplus.mockito.MockitoSugar + import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.FlintREPLConfConstants.DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY +import org.apache.spark.sql.SparkConfConstants.{DEFAULT_SQL_EXTENSIONS, SQL_EXTENSIONS_KEY} +import org.apache.spark.sql.exception.UnrecoverableException import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.types._ import org.apache.spark.sql.util.{CleanerFactory, MockTimeProvider} -class FlintJobTest extends SparkFunSuite with JobMatchers { +class FlintJobTest + extends SparkFunSuite + with MockitoSugar + with ArgumentMatchersSugar + with JobMatchers { private val jobId = "testJobId" private val applicationId = "testApplicationId" + private val ONE_EXECUTOR_SEGMENT = "1e" + private val INTERACTIVE_JOB_TYPE = "interactive" val spark = SparkSession.builder().appName("Test").master("local").getOrCreate() @@ -122,4 +144,285 @@ class FlintJobTest extends SparkFunSuite with JobMatchers { spark.sparkContext.conf.get("spark.dynamicAllocation.maxExecutors") shouldBe "30" } + test("createSparkConf should set the app name and default SQL extensions") { + val conf = FlintJob.createSparkConf() + + // Assert that the app name is set correctly + assert(conf.get("spark.app.name") === "FlintJob$") + + // Assert that the default SQL extensions are set correctly + assert(conf.get(SQL_EXTENSIONS_KEY) === DEFAULT_SQL_EXTENSIONS) + } + + test( + "createSparkConf should not use defaultExtensions if spark.sql.extensions is already set") { + val customExtension = "my.custom.extension" + // Set the spark.sql.extensions property before calling createSparkConf + System.setProperty(SQL_EXTENSIONS_KEY, customExtension) + + try { + val conf = FlintJob.createSparkConf() + assert(conf.get(SQL_EXTENSIONS_KEY) === customExtension) + } finally { + // Clean up the system property after the test + System.clearProperty(SQL_EXTENSIONS_KEY) + } + } + + test("exponentialBackoffRetry should retry on ConnectException") { + val mockReader = mock[OpenSearchReader] + val exception = new RuntimeException( + new ConnectException( + "Timeout connecting to [search-foo-1-bar.eu-west-1.es.amazonaws.com:443]")) + val mockOSClient = mock[OSClient] + when( + mockOSClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) + .thenReturn(mockReader) 
+ when(mockReader.hasNext).thenThrow(exception) + + when(mockOSClient.getIndexMetadata(any[String])).thenReturn(FlintJob.resultIndexMapping) + + val maxRetries = 1 + var actualRetries = 0 + + val jobId = "testJobId" + val applicationId = "testApplicationId" + val sessionIndex = "sessionIndex" + val resultIndexOption = Some("testResultIndex") + val lastUpdateTime = System.currentTimeMillis() + + // Create a sourceMap with excludeJobIds as an ArrayList not containing jobId + val sourceMap = new java.util.HashMap[String, Object]() + sourceMap.put("applicationId", applicationId.asInstanceOf[Object]) + sourceMap.put("state", "running".asInstanceOf[Object]) + sourceMap.put("jobId", jobId.asInstanceOf[Object]) + sourceMap.put("lastUpdateTime", lastUpdateTime.asInstanceOf[Object]) + + val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() + try { + spark.conf.set(FlintSparkConf.REQUEST_INDEX.key, sessionIndex) + val sessionManager = new SessionManagerImpl(spark, Some("resultIndex")) { + override val osClient: OSClient = mockOSClient + } + + val commandContext = CommandContext( + applicationId, + jobId, + spark, + "", + "", + "", + sessionManager, + Duration.Inf, + 60, + 60, + DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) + + intercept[RuntimeException] { + FlintREPL.exponentialBackoffRetry(maxRetries, 2.seconds) { + actualRetries += 1 + FlintJob.queryLoop(commandContext, ONE_EXECUTOR_SEGMENT) + } + } + + assert(actualRetries == maxRetries) + } finally { + // Stop the SparkSession + spark.stop() + } + } + + test("queryLoop continue until inactivity limit is reached") { + val resultIndex = "testResultIndex" + val dataSource = "testDataSource" + val sessionIndex = "testSessionIndex" + val sessionId = "testSessionId" + val resultIndexOption = Some("testResultIndex") + + val mockReader = mock[FlintReader] + val mockOSClient = mock[OSClient] + val getResponse = mock[GetResponse] + + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(true) + when(getResponse.getSourceAsMap).thenReturn( + Map[String, Object]( + "applicationId" -> applicationId, + "jobId" -> jobId, + "sessionId" -> sessionId, + "lastUpdateTime" -> java.lang.Long.valueOf(12345L), + "error" -> "someError", + "state" -> "running", + "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava) + + when( + mockOSClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) + .thenReturn(mockReader) + when(mockReader.hasNext).thenReturn(false) + when(mockOSClient.doesIndexExist(*)).thenReturn(true) + when(mockOSClient.getIndexMetadata(*)).thenReturn(FlintREPL.resultIndexMapping) + + val shortInactivityLimit = 50 // 50 milliseconds + + // Create a SparkSession for testing + val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() + + spark.conf.set(FlintSparkConf.REQUEST_INDEX.key, sessionIndex) + val sessionManager = new SessionManagerImpl(spark, Some(resultIndex)) { + override val osClient: OSClient = mockOSClient + } + + val commandContext = CommandContext( + applicationId, + jobId, + spark, + dataSource, + INTERACTIVE_JOB_TYPE, + sessionId, + sessionManager, + Duration(10, MINUTES), + shortInactivityLimit, + 60, + DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) + + val startTime = System.currentTimeMillis() + + FlintJob.queryLoop(commandContext, ONE_EXECUTOR_SEGMENT) + + val endTime = System.currentTimeMillis() + + // Check if the loop ran for approximately the duration of the inactivity limit + 
assert(endTime - startTime >= shortInactivityLimit) + + // Stop the SparkSession + spark.stop() + } + + test("queryLoop should properly shut down the thread pool after execution") { + val resultIndex = "testResultIndex" + val dataSource = "testDataSource" + val sessionIndex = "testSessionIndex" + val sessionId = "testSessionId" + val resultIndexOption = Some("testResultIndex") + + val mockReader = mock[FlintReader] + val mockOSClient = mock[OSClient] + when( + mockOSClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) + .thenReturn(mockReader) + when(mockReader.hasNext).thenReturn(false) + when(mockOSClient.doesIndexExist(*)).thenReturn(true) + when(mockOSClient.getIndexMetadata(*)).thenReturn(FlintREPL.resultIndexMapping) + + val getResponse = mock[GetResponse] + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(true) + when(getResponse.getSourceAsMap).thenReturn( + Map[String, Object]( + "applicationId" -> applicationId, + "jobId" -> jobId, + "sessionId" -> sessionId, + "lastUpdateTime" -> java.lang.Long.valueOf(12345L), + "error" -> "someError", + "state" -> "running", + "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava) + + val inactivityLimit = 500 // 500 milliseconds + + // Create a SparkSession for testing + val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() + + spark.conf.set(FlintSparkConf.REQUEST_INDEX.key, sessionIndex) + val sessionManager = new SessionManagerImpl(spark, Some(resultIndex)) { + override val osClient: OSClient = mockOSClient + } + + val commandContext = CommandContext( + applicationId, + jobId, + spark, + dataSource, + INTERACTIVE_JOB_TYPE, + sessionId, + sessionManager, + Duration(10, MINUTES), + inactivityLimit, + 60, + DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) + + try { + FlintJob.queryLoop(commandContext, ONE_EXECUTOR_SEGMENT) + + } finally { + // Stop the SparkSession + spark.stop() + } + } + + test("queryLoop handle exceptions within the loop gracefully") { + val resultIndex = "testResultIndex" + val dataSource = "testDataSource" + val sessionIndex = "testSessionIndex" + val sessionId = "testSessionId" + val resultIndexOption = Some("testResultIndex") + + val mockReader = mock[FlintReader] + val mockOSClient = mock[OSClient] + val getResponse = mock[GetResponse] + + when(mockOSClient.getDoc(*, *)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(true) + when(getResponse.getSourceAsMap).thenReturn( + Map[String, Object]( + "applicationId" -> applicationId, + "jobId" -> jobId, + "sessionId" -> sessionId, + "lastUpdateTime" -> java.lang.Long.valueOf(12345L), + "error" -> "someError", + "state" -> "running", + "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava) + + when( + mockOSClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) + .thenReturn(mockReader) + // Simulate an exception thrown when hasNext is called + val unrecoverableException = UnrecoverableException(new RuntimeException("Test exception")) + when(mockReader.hasNext).thenThrow(unrecoverableException) + when(mockOSClient.doesIndexExist(*)).thenReturn(true) + when(mockOSClient.getIndexMetadata(*)).thenReturn(FlintREPL.resultIndexMapping) + + val inactivityLimit = 500 // 500 milliseconds + + // Create a SparkSession for testing + val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() + + spark.conf.set(FlintSparkConf.REQUEST_INDEX.key, sessionIndex) + val sessionManager = new 
SessionManagerImpl(spark, Some(resultIndex)) { + override val osClient: OSClient = mockOSClient + } + + val commandContext = CommandContext( + applicationId, + jobId, + spark, + dataSource, + INTERACTIVE_JOB_TYPE, + sessionId, + sessionManager, + Duration(10, MINUTES), + inactivityLimit, + 60, + DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) + + try { + intercept[UnrecoverableException] { + FlintJob.queryLoop(commandContext, ONE_EXECUTOR_SEGMENT) + } + + FlintJob.throwableHandler.exceptionThrown shouldBe Some(unrecoverableException) + } finally { + // Stop the SparkSession + spark.stop() + } + } }
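
A minimal configuration sketch, assuming an EMR Serverless style submission: it shows how the keys this patch adds or relies on in FlintSparkConf are expected to fit together when warmpool mode is turned on. The property keys are taken from this change; the values and the WarmpoolSubmitSketch object name are illustrative placeholders only, not part of the patch.

import org.apache.spark.SparkConf

// Illustrative only: keys come from FlintSparkConf as modified in this patch; values are placeholders.
object WarmpoolSubmitSketch {
  def warmpoolConf(): SparkConf =
    new SparkConf()
      .setAppName("FlintJob")
      // FlintSparkConf.WARMPOOL_ENABLED (default "false"); "true" routes FlintJob.main()
      // into queryLoop() instead of the single-query streaming/batch path.
      .set("spark.flint.job.warmpoolEnabled", "true")
      // FlintSparkConf.MAX_EXECUTORS_COUNT; getSegmentName() turns "10" into the "10e"
      // prefix used on metrics such as "10e.statement.query.execution.processingTime".
      .set("spark.dynamicAllocation.maxExecutors", "10")
      // In warmpool mode the job type, datasource, and result index are read from the
      // Spark conf inside queryLoop() for each picked-up statement, not from job arguments.
      .set("spark.flint.datasource.name", "my_glue1")        // placeholder datasource name
      .set("spark.flint.job.resultIndex", "query_results")   // new FlintSparkConf.RESULT_INDEX; placeholder value
      .set("spark.flint.job.requestIndex", "query_requests") // existing REQUEST_INDEX; placeholder value
}

With spark.flint.job.warmpoolEnabled left at its default of "false", behavior is unchanged: the query, query id, and result index still come from the job arguments and the existing spark.flint.* settings, and FlintJob dispatches a single streaming/batch job as before.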