diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java index 79e70b8c2..b38394b58 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java @@ -135,10 +135,56 @@ public final class MetricConstants { */ public static final String STREAMING_HEARTBEAT_FAILED_METRIC = "streaming.heartbeat.failed.count"; + /** + * Metric for tracking the count of streaming jobs failed during query execution + */ + public static final String STREAMING_EXECUTION_FAILED_METRIC = "streaming.execution.failed.count"; + + /** + * Metric for tracking the count of streaming jobs failed during query result write + */ + public static final String STREAMING_RESULT_WRITER_FAILED_METRIC = "streaming.writer.failed.count"; + /** * Metric for tracking the latency of query execution (start to complete query execution) excluding result write. */ - public static final String QUERY_EXECUTION_TIME_METRIC = "query.execution.processingTime"; + public static final String QUERY_EXECUTION_TIME_METRIC = "streaming.query.execution.processingTime"; + + /** + * Metric for tracking the latency of query result write only (excluding query execution) + */ + public static final String QUERY_RESULT_WRITER_TIME_METRIC = "streaming.result.writer.processingTime"; + + /** + * Metric for tracking the latency of query total execution including result write. + */ + public static final String QUERY_TOTAL_TIME_METRIC = "streaming.query.total.processingTime"; + + /** + * Metric for tracking the latency of query execution (start to complete query execution) excluding result write. + */ + public static final String STATEMENT_QUERY_EXECUTION_TIME_METRIC = "statement.query.execution.processingTime"; + + /** + * Metric for tracking the latency of query result write only (excluding query execution) + */ + public static final String STATEMENT_RESULT_WRITER_TIME_METRIC = "statement.result.writer.processingTime"; + + /** + * Metric for tracking the latency of query total execution including result write. 
+ */ + public static final String STATEMENT_QUERY_TOTAL_TIME_METRIC = "statement.query.total.processingTime"; + + /** + * Metric for tracking the count of interactive jobs failed during query execution + */ + public static final String STATEMENT_EXECUTION_FAILED_METRIC = "statement.execution.failed.count"; + + /** + * Metric for tracking the count of interactive jobs failed during query result write + */ + public static final String STATEMENT_RESULT_WRITER_FAILED_METRIC = "statement.writer.failed.count"; + /** * Metric for query count of each query type (DROP/VACUUM/ALTER/REFRESH/CREATE INDEX) diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala index 364a8a1de..4d5164d7c 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala @@ -171,6 +171,12 @@ object FlintSparkConf { .doc("Enable external scheduler for index refresh") .createWithDefault("false") + val WARMPOOL_ENABLED = + FlintConfig("spark.flint.job.warmpoolEnabled") + .createWithDefault("false") + + val MAX_EXECUTORS_COUNT = FlintConfig("spark.dynamicAllocation.maxExecutors").createOptional() + val EXTERNAL_SCHEDULER_INTERVAL_THRESHOLD = FlintConfig("spark.flint.job.externalScheduler.interval") .doc("Interval threshold in minutes for external scheduler to trigger index refresh") @@ -246,6 +252,10 @@ object FlintSparkConf { FlintConfig(s"spark.flint.job.requestIndex") .doc("Request index") .createOptional() + val RESULT_INDEX = + FlintConfig(s"spark.flint.job.resultIndex") + .doc("Result index") + .createOptional() val EXCLUDE_JOB_IDS = FlintConfig(s"spark.flint.deployment.excludeJobs") .doc("Exclude job ids") diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala index 04609cf3d..fdbab71dc 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala @@ -30,48 +30,57 @@ object FlintJob extends Logging with FlintJobExecutor { val (queryOption, resultIndexOption) = parseArgs(args) val conf = createSparkConf() - val jobType = conf.get("spark.flint.job.type", FlintJobType.BATCH) - CustomLogging.logInfo(s"""Job type is: ${jobType}""") - conf.set(FlintSparkConf.JOB_TYPE.key, jobType) - - val dataSource = conf.get("spark.flint.datasource.name", "") - val query = queryOption.getOrElse(unescapeQuery(conf.get(FlintSparkConf.QUERY.key, ""))) - if (query.isEmpty) { - logAndThrow(s"Query undefined for the ${jobType} job.") - } - val queryId = conf.get(FlintSparkConf.QUERY_ID.key, "") - - if (resultIndexOption.isEmpty) { - logAndThrow("resultIndex is not set") - } - // https://github.com/opensearch-project/opensearch-spark/issues/138 - /* - * To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`, - * it's necessary to set `spark.sql.defaultCatalog=my_glue1`. This is because AWS Glue uses a single database (default) and table (http_logs_plain), - * and we need to configure Spark to recognize `my_glue1` as a reference to AWS Glue's database and table. 
- * By doing this, we effectively map `my_glue1` to AWS Glue, allowing Spark to resolve the database and table names correctly. - * Without this setup, Spark would not recognize names in the format `my_glue1.default`. - */ - conf.set("spark.sql.defaultCatalog", dataSource) - configDYNMaxExecutors(conf, jobType) - + val sparkSession = createSparkSession(conf) val applicationId = environmentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown") val jobId = environmentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown") + val warmpoolEnabled = conf.get(FlintSparkConf.WARMPOOL_ENABLED.key, "false").toBoolean + logInfo(s"WarmpoolEnabled: ${warmpoolEnabled}") - val streamingRunningCount = new AtomicInteger(0) - val jobOperator = - JobOperator( - applicationId, - jobId, - createSparkSession(conf), - query, - queryId, - dataSource, - resultIndexOption.get, - jobType, - streamingRunningCount) - registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount) - jobOperator.start() + if (!warmpoolEnabled) { + val jobType = sparkSession.conf.get("spark.flint.job.type", FlintJobType.BATCH) + CustomLogging.logInfo(s"""Job type is: ${jobType}""") + sparkSession.conf.set(FlintSparkConf.JOB_TYPE.key, jobType) + + val dataSource = conf.get("spark.flint.datasource.name", "") + val query = queryOption.getOrElse(unescapeQuery(conf.get(FlintSparkConf.QUERY.key, ""))) + if (query.isEmpty) { + logAndThrow(s"Query undefined for the ${jobType} job.") + } + val queryId = conf.get(FlintSparkConf.QUERY_ID.key, "") + + if (resultIndexOption.isEmpty) { + logAndThrow("resultIndex is not set") + } + // https://github.com/opensearch-project/opensearch-spark/issues/138 + /* + * To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`, + * it's necessary to set `spark.sql.defaultCatalog=my_glue1`. This is because AWS Glue uses a single database (default) and table (http_logs_plain), + * and we need to configure Spark to recognize `my_glue1` as a reference to AWS Glue's database and table. + * By doing this, we effectively map `my_glue1` to AWS Glue, allowing Spark to resolve the database and table names correctly. + * Without this setup, Spark would not recognize names in the format `my_glue1.default`. 
+ */ + conf.set("spark.sql.defaultCatalog", dataSource) + configDYNMaxExecutors(conf, jobType) + + val streamingRunningCount = new AtomicInteger(0) + val jobOperator = + JobOperator( + applicationId, + jobId, + sparkSession, + query, + queryId, + dataSource, + resultIndexOption.get, + jobType, + streamingRunningCount) + registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount) + jobOperator.start() + } else { + // Fetch and execute queries in warm pool mode + val warmpoolJob = WarmpoolJob(conf, sparkSession, resultIndexOption) + warmpoolJob.start() + } } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala index ad26cf21a..b1147681c 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -12,6 +12,7 @@ import com.amazonaws.services.s3.model.AmazonS3Exception import com.fasterxml.jackson.databind.ObjectMapper import org.apache.commons.text.StringEscapeUtils.unescapeJava import org.opensearch.common.Strings +import org.opensearch.flint.common.model.FlintStatement import org.opensearch.flint.core.IRestHighLevelClient import org.opensearch.flint.core.logging.{CustomLogging, ExceptionMessages, OperationMessage} import org.opensearch.flint.core.metrics.MetricConstants @@ -20,6 +21,7 @@ import play.api.libs.json._ import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging +import org.apache.spark.sql.FlintREPL.instantiate import org.apache.spark.sql.SparkConfConstants.{DEFAULT_SQL_EXTENSIONS, SQL_EXTENSIONS_KEY} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.exception.UnrecoverableException @@ -470,6 +472,13 @@ trait FlintJobExecutor { else getRootCause(t.getCause) } + def processQueryException(t: Throwable, flintStatement: FlintStatement): String = { + val error = processQueryException(t) + flintStatement.fail() + flintStatement.error = Some(error) + error + } + /** * This method converts query exception into error string, which then persist to query result * metadata @@ -515,6 +524,87 @@ trait FlintJobExecutor { } } + def handleCommandTimeout( + applicationId: String, + jobId: String, + spark: SparkSession, + dataSource: String, + error: String, + flintStatement: FlintStatement, + sessionId: String, + startTime: Long): DataFrame = { + /* + * https://tinyurl.com/2ezs5xj9 + * + * This only interrupts active Spark jobs that are actively running. + * This would then throw the error from ExecutePlan and terminate it. + * But if the query is not running a Spark job, but executing code on Spark driver, this + * would be a noop and the execution will keep running. + * + * In Apache Spark, actions that trigger a distributed computation can lead to the creation + * of Spark jobs. In the context of Spark SQL, this typically happens when we perform + * actions that require the computation of results that need to be collected or stored. 
+ */ + spark.sparkContext.cancelJobGroup(flintStatement.queryId) + flintStatement.timeout() + flintStatement.error = Some(error) + constructErrorDF( + applicationId, + jobId, + spark, + dataSource, + flintStatement.state, + error, + flintStatement.queryId, + flintStatement.query, + sessionId, + startTime) + } + + /** + * Handles the case where a command's execution fails: updates the flintStatement with the + * error and failure status, and then writes the result to the result index. Thus, an error is + * written to both the result index and the statement model in the request index + * + * @param spark + * spark session + * @param dataSource + * data source + * @param error + * error message + * @param flintStatement + * flint command + * @param sessionId + * session id + * @param startTime + * start time + * @return + * failed data frame + */ + def handleCommandFailureAndGetFailedData( + applicationId: String, + jobId: String, + spark: SparkSession, + dataSource: String, + error: String, + flintStatement: FlintStatement, + sessionId: String, + startTime: Long): DataFrame = { + flintStatement.fail() + flintStatement.error = Some(error) + constructErrorDF( + applicationId, + jobId, + spark, + dataSource, + flintStatement.state, + error, + flintStatement.queryId, + flintStatement.query, + sessionId, + startTime) + } + /** * Before OS 2.13, there are two arguments from entry point: query and result index Starting * from OS 2.13, query is optional for FlintREPL And since Flint 0.5, result index is also @@ -547,6 +637,39 @@ trait FlintJobExecutor { } } + def getSegmentName(sparkSession: SparkSession): String = { + val maxExecutorsCount = + sparkSession.conf.get(FlintSparkConf.MAX_EXECUTORS_COUNT.key, "unknown") + String.format("%se", maxExecutorsCount) + } + + def instantiateSessionManager( + spark: SparkSession, + resultIndexOption: Option[String]): SessionManager = { + instantiate( + new SessionManagerImpl(spark, resultIndexOption), + spark.conf.get(FlintSparkConf.CUSTOM_SESSION_MANAGER.key, ""), + resultIndexOption.getOrElse("")) + } + + def instantiateStatementExecutionManager( + commandContext: CommandContext): StatementExecutionManager = { + import commandContext._ + instantiate( + new StatementExecutionManagerImpl(commandContext), + spark.conf.get(FlintSparkConf.CUSTOM_STATEMENT_MANAGER.key, ""), + spark, + sessionId) + } + + def instantiateQueryResultWriter( + spark: SparkSession, + commandContext: CommandContext): QueryResultWriter = { + instantiate( + new QueryResultWriterImpl(commandContext), + spark.conf.get(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key, "")) + } + def instantiate[T](defaultConstructor: => T, className: String, args: Any*): T = { if (Strings.isNullOrEmpty(className)) { defaultConstructor diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 6d7dcc0e7..cc5f97144 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -458,57 +458,6 @@ object FlintREPL extends Logging with FlintJobExecutor { recordSessionFailed(sessionTimerContext) } - /** - * handling the case where a command's execution fails, updates the flintStatement with the - * error and failure status, and then write the result to result index. 
Thus, an error is - * written to both result index or statement model in request index - * - * @param spark - * spark session - * @param dataSource - * data source - * @param error - * error message - * @param flintStatement - * flint command - * @param sessionId - * session id - * @param startTime - * start time - * @return - * failed data frame - */ - def handleCommandFailureAndGetFailedData( - applicationId: String, - jobId: String, - spark: SparkSession, - dataSource: String, - error: String, - flintStatement: FlintStatement, - sessionId: String, - startTime: Long): DataFrame = { - flintStatement.fail() - flintStatement.error = Some(error) - super.constructErrorDF( - applicationId, - jobId, - spark, - dataSource, - flintStatement.state, - error, - flintStatement.queryId, - flintStatement.query, - sessionId, - startTime) - } - - def processQueryException(t: Throwable, flintStatement: FlintStatement): String = { - val error = super.processQueryException(t) - flintStatement.fail() - flintStatement.error = Some(error) - error - } - private def processCommands( statementExecutionManager: StatementExecutionManager, queryResultWriter: QueryResultWriter, @@ -610,43 +559,6 @@ object FlintREPL extends Logging with FlintJobExecutor { } } - private def handleCommandTimeout( - applicationId: String, - jobId: String, - spark: SparkSession, - dataSource: String, - error: String, - flintStatement: FlintStatement, - sessionId: String, - startTime: Long) = { - /* - * https://tinyurl.com/2ezs5xj9 - * - * This only interrupts active Spark jobs that are actively running. - * This would then throw the error from ExecutePlan and terminate it. - * But if the query is not running a Spark job, but executing code on Spark driver, this - * would be a noop and the execution will keep running. - * - * In Apache Spark, actions that trigger a distributed computation can lead to the creation - * of Spark jobs. In the context of Spark SQL, this typically happens when we perform - * actions that require the computation of results that need to be collected or stored. 
- */ - spark.sparkContext.cancelJobGroup(flintStatement.queryId) - flintStatement.timeout() - flintStatement.error = Some(error) - super.constructErrorDF( - applicationId, - jobId, - spark, - dataSource, - flintStatement.state, - error, - flintStatement.queryId, - flintStatement.query, - sessionId, - startTime) - } - // scalastyle:off def executeAndHandle( applicationId: String, @@ -1021,33 +933,6 @@ object FlintREPL extends Logging with FlintJobExecutor { } } - private def instantiateSessionManager( - spark: SparkSession, - resultIndexOption: Option[String]): SessionManager = { - instantiate( - new SessionManagerImpl(spark, resultIndexOption), - spark.conf.get(FlintSparkConf.CUSTOM_SESSION_MANAGER.key, ""), - resultIndexOption.getOrElse("")) - } - - private def instantiateStatementExecutionManager( - commandContext: CommandContext): StatementExecutionManager = { - import commandContext._ - instantiate( - new StatementExecutionManagerImpl(commandContext), - spark.conf.get(FlintSparkConf.CUSTOM_STATEMENT_MANAGER.key, ""), - spark, - sessionId) - } - - private def instantiateQueryResultWriter( - spark: SparkSession, - commandContext: CommandContext): QueryResultWriter = { - instantiate( - new QueryResultWriterImpl(commandContext), - spark.conf.get(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key, "")) - } - private def recordSessionSuccess(sessionTimerContext: Timer.Context): Unit = { logInfo("Session Success") stopTimer(sessionTimerContext) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala index 27b0be84f..98b8d6307 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala @@ -32,10 +32,13 @@ case class JobOperator( dataSource: String, resultIndex: String, jobType: String, - streamingRunningCount: AtomicInteger) + streamingRunningCount: AtomicInteger, + statementContext: Map[String, Any] = Map.empty[String, Any]) extends Logging with FlintJobExecutor { + private val segmentName = getSegmentName(sparkSession) + // JVM shutdown hook sys.addShutdownHook(stop()) @@ -80,7 +83,9 @@ case class JobOperator( "", queryId, LangType.SQL, - currentTimeProvider.currentEpochMillis()) + currentTimeProvider.currentEpochMillis(), + Option.empty, + statementContext) try { val futurePrepareQueryExecution = Future { @@ -107,6 +112,7 @@ case class JobOperator( } catch { case e: TimeoutException => throwableHandler.recordThrowable(s"Preparation for query execution timed out", e) + incrementCounter(MetricConstants.STREAMING_EXECUTION_FAILED_METRIC) dataToWrite = Some( constructErrorDF( applicationId, @@ -120,6 +126,7 @@ case class JobOperator( "", startTime)) case t: Throwable => + incrementCounter(MetricConstants.STREAMING_EXECUTION_FAILED_METRIC) val error = processQueryException(t) dataToWrite = Some( constructErrorDF( @@ -134,10 +141,11 @@ case class JobOperator( "", startTime)) } finally { - emitQueryExecutionTimeMetric(startTime) + emitTimeMetric(startTime, MetricConstants.QUERY_EXECUTION_TIME_METRIC) readWriteBytesSparkListener.emitMetrics() sparkSession.sparkContext.removeSparkListener(readWriteBytesSparkListener) + val resultWriteStartTime = System.currentTimeMillis() try { dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient)) } catch { @@ -145,6 +153,10 @@ case class JobOperator( throwableHandler.recordThrowable( s"Failed to write to result index. 
originalError='${throwableHandler.error}'", t) + incrementCounter(MetricConstants.STREAMING_RESULT_WRITER_FAILED_METRIC) + } finally { + emitTimeMetric(resultWriteStartTime, MetricConstants.QUERY_RESULT_WRITER_TIME_METRIC) + emitTimeMetric(startTime, MetricConstants.QUERY_TOTAL_TIME_METRIC) } if (throwableHandler.hasException) statement.fail() else statement.complete() statement.error = Some(throwableHandler.error) @@ -205,13 +217,6 @@ case class JobOperator( } } - private def emitQueryExecutionTimeMetric(startTime: Long): Unit = { - MetricsUtil - .addHistoricGauge( - MetricConstants.QUERY_EXECUTION_TIME_METRIC, - System.currentTimeMillis() - startTime) - } - def stop(): Unit = { Try { logInfo("Stopping Spark session") @@ -243,8 +248,10 @@ case class JobOperator( } exceptionThrown match { - case true => incrementCounter(MetricConstants.STREAMING_FAILED_METRIC) - case false => incrementCounter(MetricConstants.STREAMING_SUCCESS_METRIC) + case true => + incrementCounter(MetricConstants.STREAMING_FAILED_METRIC) + case false => + incrementCounter(MetricConstants.STREAMING_SUCCESS_METRIC) } } @@ -259,4 +266,15 @@ case class JobOperator( spark, sessionId) } + + private def emitTimeMetric(startTime: Long, metricType: String): Unit = { + val metricName = String.format("%s.%s", segmentName, metricType) + MetricsUtil.addHistoricGauge(metricName, System.currentTimeMillis() - startTime) + } + + private def incrementCounter(metricName: String) { + val metricWithSegmentName = String.format("%s.%s", segmentName, metricName) + MetricsUtil.incrementCounter(metricWithSegmentName) + } + } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/WarmpoolJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/WarmpoolJob.scala new file mode 100644 index 000000000..f83c3d969 --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/WarmpoolJob.scala @@ -0,0 +1,410 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import java.util.concurrent.atomic.AtomicInteger + +import scala.concurrent.{ExecutionContext, Future, TimeoutException} +import scala.concurrent.duration._ +import scala.util.control.NonFatal + +import com.codahale.metrics.Timer +import org.opensearch.flint.common.model.FlintStatement +import org.opensearch.flint.core.FlintOptions +import org.opensearch.flint.core.logging.CustomLogging +import org.opensearch.flint.core.metrics.{MetricConstants, MetricsUtil} +import org.opensearch.flint.core.metrics.MetricsUtil.{getTimerContext, registerGauge, stopTimer} + +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging +import org.apache.spark.sql.WarmpoolJobConfConstants._ +import org.apache.spark.sql.flint.config.FlintSparkConf +import org.apache.spark.util.ThreadUtils + +object WarmpoolJobConfConstants { + val MAPPING_CHECK_TIMEOUT = Duration(1, MINUTES) + val DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS = 10 * 60 * 1000 + val DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY = 100L +} + +case class WarmpoolJob( + conf: SparkConf, + sparkSession: SparkSession, + resultIndexOption: Option[String]) + extends Logging + with FlintJobExecutor { + + private val statementRunningCount = new AtomicInteger(0) + private val streamingRunningCount = new AtomicInteger(0) + private val segmentName = getSegmentName(sparkSession) + + def start(): Unit = { + // Read the values from the Spark configuration or fall back to the default values + val inactivityLimitMillis: Long = + conf.getLong( + 
FlintSparkConf.REPL_INACTIVITY_TIMEOUT_MILLIS.key, + FlintOptions.DEFAULT_INACTIVITY_LIMIT_MILLIS) + val queryWaitTimeoutMillis: Long = + conf.getLong("spark.flint.job.queryWaitTimeoutMillis", DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS) + val queryLoopExecutionFrequency: Long = + conf.getLong( + "spark.flint.job.queryLoopExecutionFrequency", + DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) + + val sessionManager = instantiateSessionManager(sparkSession, resultIndexOption) + val commandContext = CommandContext( + environmentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"), + environmentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown"), + sparkSession, + "", // In WP flow, datasource is not known yet + "", // In WP flow, jobType is not known yet + "", // WP doesn't use sessionId + sessionManager, + Duration.Inf, // WP doesn't have queryExecutionTimeout + inactivityLimitMillis, + queryWaitTimeoutMillis, // Used only for interactive queries + queryLoopExecutionFrequency) + registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount) + try { + FlintREPL.exponentialBackoffRetry(maxRetries = 5, initialDelay = 2.seconds) { + queryLoop(commandContext) + } + } finally { + sparkSession.stop() + + // After handling any exceptions from stopping the Spark session, + // check if there's a stored exception and throw it if it's an UnrecoverableException + checkAndThrowUnrecoverableExceptions() + + // Check for non-daemon threads that may prevent the driver from shutting down. + // Non-daemon threads other than the main thread indicate that the driver is still processing tasks, + // which may be due to unresolved bugs in dependencies or threads not being properly shut down. + if (terminateJVM && threadPoolFactory.hasNonDaemonThreadsOtherThanMain) { + logInfo("A non-daemon thread in the driver is seen.") + // Exit the JVM to prevent resource leaks and a potential emr-s job hang. + // A zero status code is used for a graceful shutdown without indicating an error. + // If exiting with non-zero status, emr-s job will fail. 
+ // This is a part of the fault tolerance mechanism to handle such scenarios gracefully + System.exit(0) + } + } + } + + def queryLoop(commandContext: CommandContext): Unit = { + import commandContext._ + + val statementExecutionManager = instantiateStatementExecutionManager(commandContext) + var canProceed = true + + try { + var lastActivityTime = currentTimeProvider.currentEpochMillis() + + while (currentTimeProvider + .currentEpochMillis() - lastActivityTime <= commandContext.inactivityLimitMillis && canProceed) { + statementExecutionManager.getNextStatement() match { + case Some(flintStatement) => + flintStatement.running() + statementExecutionManager.updateStatement(flintStatement) + + val jobType = spark.conf.get(FlintSparkConf.JOB_TYPE.key, FlintJobType.BATCH) + val dataSource = spark.conf.get(FlintSparkConf.DATA_SOURCE_NAME.key) + val resultIndex = spark.conf.get(FlintSparkConf.RESULT_INDEX.key) + + val postQuerySelectionCommandContext = + commandContext.copy(dataSource = dataSource, jobType = jobType) + + CustomLogging.logInfo(s"""Job type is: ${jobType}""") + val queryResultWriter = instantiateQueryResultWriter(spark, commandContext) + + if (jobType.equalsIgnoreCase(FlintJobType.STREAMING) || jobType.equalsIgnoreCase( + FlintJobType.BATCH)) { + processStreamingJob( + applicationId, + jobId, + flintStatement.query, + flintStatement.queryId, + dataSource, + resultIndex, + jobType, + spark, + flintStatement.context) + } else { + processInteractiveJob( + spark, + postQuerySelectionCommandContext, + flintStatement, + statementExecutionManager, + queryResultWriter) + + // Last query finish time is last activity time + lastActivityTime = currentTimeProvider.currentEpochMillis() + } + + case _ => + canProceed = false + } + } + } catch { + case t: Throwable => + throwableHandler.recordThrowable(s"Query loop execution failed.", t) + throw t + } finally { + statementExecutionManager.terminateStatementExecution() + } + + Thread.sleep(commandContext.queryLoopExecutionFrequency) + } + + private def processStreamingJob( + applicationId: String, + jobId: String, + query: String, + queryId: String, + dataSource: String, + resultIndex: String, + jobType: String, + sparkSession: SparkSession, + executionContext: Map[String, Any]): Unit = { + + // https://github.com/opensearch-project/opensearch-spark/issues/138 + /* + * To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`, + * it's necessary to set `spark.sql.defaultCatalog=my_glue1`. This is because AWS Glue uses a single database (default) and table (http_logs_plain), + * and we need to configure Spark to recognize `my_glue1` as a reference to AWS Glue's database and table. + * By doing this, we effectively map `my_glue1` to AWS Glue, allowing Spark to resolve the database and table names correctly. + * Without this setup, Spark would not recognize names in the format `my_glue1.default`. 
+ */ + sparkSession.conf.set("spark.sql.defaultCatalog", dataSource) + + val streamingRunningCount = new AtomicInteger(0) + val jobOperator = JobOperator( + applicationId, + jobId, + sparkSession, + query, + queryId, + dataSource, + resultIndex, + jobType, + streamingRunningCount, + executionContext) + + registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount) + jobOperator.start() + } + + private def processInteractiveJob( + sparkSession: SparkSession, + commandContext: CommandContext, + flintStatement: FlintStatement, + statementExecutionManager: StatementExecutionManager, + queryResultWriter: QueryResultWriter): Unit = { + + import commandContext._ + + var dataToWrite: Option[DataFrame] = None + val startTime: Long = currentTimeProvider.currentEpochMillis() + + statementRunningCount.incrementAndGet() + + val statementTimerContext = getTimerContext(MetricConstants.STATEMENT_PROCESSING_TIME_METRIC) + implicit val ec: ExecutionContext = ExecutionContext.global + + val futurePrepareQueryExecution = Future { + statementExecutionManager.prepareStatementExecution() + } + + try { + ThreadUtils.awaitResult(futurePrepareQueryExecution, MAPPING_CHECK_TIMEOUT) match { + case Right(_) => + dataToWrite = executeAndHandleInteractiveJob( + sparkSession, + commandContext, + flintStatement, + startTime, + statementExecutionManager, + queryResultWriter) + + case Left(error) => + dataToWrite = Some( + handleCommandFailureAndGetFailedData( + applicationId, + jobId, + sparkSession, + dataSource, + error, + flintStatement, + "", + startTime)) + } + } catch { + case e: TimeoutException => + val error = s"Query execution preparation timed out" + CustomLogging.logError(error, e) + dataToWrite = Some( + handleCommandTimeout( + applicationId, + jobId, + spark, + dataSource, + error, + flintStatement, + "", + startTime)) + + case NonFatal(e) => + val error = s"An unexpected error occurred: ${e.getMessage}" + throwableHandler.recordThrowable(error, e) + dataToWrite = Some( + handleCommandFailureAndGetFailedData( + applicationId, + jobId, + spark, + dataSource, + error, + flintStatement, + sessionId, + startTime)) + } finally { + emitTimeMetric(startTime, MetricConstants.STATEMENT_QUERY_EXECUTION_TIME_METRIC) + finalizeCommand( + statementExecutionManager, + queryResultWriter, + dataToWrite, + flintStatement, + statementTimerContext) + emitTimeMetric(startTime, MetricConstants.STATEMENT_QUERY_TOTAL_TIME_METRIC) + } + } + + private def executeAndHandleInteractiveJob( + sparkSession: SparkSession, + commandContext: CommandContext, + flintStatement: FlintStatement, + startTime: Long, + statementExecutionManager: StatementExecutionManager, + queryResultWriter: QueryResultWriter): Option[DataFrame] = { + + import commandContext._ + + try { + if (currentTimeProvider + .currentEpochMillis() - flintStatement.submitTime > queryWaitTimeMillis) { + Some( + handleCommandFailureAndGetFailedData( + applicationId, + jobId, + sparkSession, + dataSource, + "wait timeout", + flintStatement, + "", // WP doesn't use sessionId + startTime)) + } else { + // Execute the statement and get the resulting DataFrame + val df = statementExecutionManager.executeStatement(flintStatement) + // Process the DataFrame, applying any necessary transformations + // and triggering Spark actions to materialize the results + Some(queryResultWriter.processDataFrame(df, flintStatement, startTime)) + } + } catch { + case e: TimeoutException => + incrementCounter(MetricConstants.STATEMENT_EXECUTION_FAILED_METRIC) + val error = s"Query 
execution timed out" + CustomLogging.logError(error, e) + Some( + handleCommandTimeout( + applicationId, + jobId, + sparkSession, + dataSource, + error, + flintStatement, + "", // WP doesn't use sessionId + startTime)) + + case t: Throwable => + incrementCounter(MetricConstants.STATEMENT_EXECUTION_FAILED_METRIC) + val error = FlintREPL.processQueryException(t, flintStatement) + CustomLogging.logError(error, t) + Some( + handleCommandFailureAndGetFailedData( + applicationId, + jobId, + sparkSession, + dataSource, + error, + flintStatement, + "", // WP doesn't use sessionId + startTime)) + } + } + + /** + * Finalize statement after processing + * + * @param dataToWrite + * Data to write + * @param flintStatement + * Flint statement + */ + private def finalizeCommand( + statementExecutionManager: StatementExecutionManager, + queryResultWriter: QueryResultWriter, + dataToWrite: Option[DataFrame], + flintStatement: FlintStatement, + statementTimerContext: Timer.Context): Unit = { + + val resultWriterStartTime: Long = currentTimeProvider.currentEpochMillis() + + try { + dataToWrite.foreach(df => queryResultWriter.writeDataFrame(df, flintStatement)) + + if (flintStatement.isRunning || flintStatement.isWaiting) { + flintStatement.complete() + } + } catch { + case t: Throwable => + incrementCounter(MetricConstants.STATEMENT_RESULT_WRITER_FAILED_METRIC) + val error = + s"""Failed to write result of ${flintStatement}, cause: ${throwableHandler.error}""" + throwableHandler.recordThrowable(error, t) + CustomLogging.logError(error, t) + flintStatement.fail() + } finally { + if (throwableHandler.hasException) flintStatement.fail() else flintStatement.complete() + flintStatement.error = Some(throwableHandler.error) + + emitTimeMetric(resultWriterStartTime, MetricConstants.STATEMENT_RESULT_WRITER_TIME_METRIC) + statementExecutionManager.updateStatement(flintStatement) + recordStatementStateChange(flintStatement, statementTimerContext) + } + } + + private def emitTimeMetric(startTime: Long, metricName: String): Unit = { + val metricNameWithSegment = String.format("%s.%s", segmentName, metricName) + MetricsUtil.addHistoricGauge(metricNameWithSegment, System.currentTimeMillis() - startTime) + } + + private def recordStatementStateChange( + flintStatement: FlintStatement, + statementTimerContext: Timer.Context): Unit = { + stopTimer(statementTimerContext) + if (statementRunningCount.get() > 0) { + statementRunningCount.decrementAndGet() + } + if (flintStatement.isComplete) { + incrementCounter(MetricConstants.STATEMENT_SUCCESS_METRIC) + } else if (flintStatement.isFailed) { + incrementCounter(MetricConstants.STATEMENT_FAILED_METRIC) + } + } + + private def incrementCounter(metricName: String) { + val metricWithSegmentName = String.format("%s.%s", segmentName, metricName) + MetricsUtil.incrementCounter(metricWithSegmentName) + } +} diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala index 339c1870d..de96a5937 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala @@ -5,12 +5,20 @@ package org.apache.spark.sql +import org.mockito.ArgumentMatchersSugar +import org.scalatestplus.mockito.MockitoSugar + import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SparkConfConstants.{DEFAULT_SQL_EXTENSIONS, SQL_EXTENSIONS_KEY} import 
org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.types._ import org.apache.spark.sql.util.{CleanerFactory, MockTimeProvider} -class FlintJobTest extends SparkFunSuite with JobMatchers { +class FlintJobTest + extends SparkFunSuite + with MockitoSugar + with ArgumentMatchersSugar + with JobMatchers { private val jobId = "testJobId" private val applicationId = "testApplicationId" @@ -122,4 +130,28 @@ class FlintJobTest extends SparkFunSuite with JobMatchers { spark.sparkContext.conf.get("spark.dynamicAllocation.maxExecutors") shouldBe "30" } + test("createSparkConf should set the app name and default SQL extensions") { + val conf = FlintJob.createSparkConf() + + // Assert that the app name is set correctly + assert(conf.get("spark.app.name") === "FlintJob$") + + // Assert that the default SQL extensions are set correctly + assert(conf.get(SQL_EXTENSIONS_KEY) === DEFAULT_SQL_EXTENSIONS) + } + + test( + "createSparkConf should not use defaultExtensions if spark.sql.extensions is already set") { + val customExtension = "my.custom.extension" + // Set the spark.sql.extensions property before calling createSparkConf + System.setProperty(SQL_EXTENSIONS_KEY, customExtension) + + try { + val conf = FlintJob.createSparkConf() + assert(conf.get(SQL_EXTENSIONS_KEY) === customExtension) + } finally { + // Clean up the system property after the test + System.clearProperty(SQL_EXTENSIONS_KEY) + } + } }
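Reviewer note (not part of the diff): the short Scala sketch below illustrates the segment-prefixed metric naming that getSegmentName, emitTimeMetric, and incrementCounter introduce in JobOperator and WarmpoolJob above. The object name SegmentMetricNamingSketch and the example maxExecutors value of 10 are assumptions for illustration only; the real segment value comes from spark.dynamicAllocation.maxExecutors (FlintSparkConf.MAX_EXECUTORS_COUNT), falling back to "unknown".

object SegmentMetricNamingSketch {
  // Mirrors getSegmentName: the segment is "<maxExecutors>e", e.g. "10e", or "unknowne"
  // when spark.dynamicAllocation.maxExecutors is not set.
  def segmentName(maxExecutorsCount: String): String = s"${maxExecutorsCount}e"

  // Mirrors emitTimeMetric / incrementCounter: every emitted metric is "<segment>.<metric>".
  def qualifiedMetric(segment: String, metric: String): String = s"$segment.$metric"

  def main(args: Array[String]): Unit = {
    val segment = segmentName("10") // assumption: maxExecutors = 10
    // Prints "10e.streaming.query.execution.processingTime"
    println(qualifiedMetric(segment, "streaming.query.execution.processingTime"))
    // Prints "10e.statement.execution.failed.count"
    println(qualifiedMetric(segment, "statement.execution.failed.count"))
  }
}

One consequence worth flagging: the existing flat gauge query.execution.processingTime is renamed to streaming.query.execution.processingTime and now carries this segment prefix, so dashboards or alarms keyed on the old name will likely need to be updated.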