diff --git a/build.sbt b/build.sbt index 9d419decb..456b8f967 100644 --- a/build.sbt +++ b/build.sbt @@ -97,7 +97,7 @@ lazy val flintCommons = (project in file("flint-commons")) ), libraryDependencies ++= deps(sparkVersion), publish / skip := true, - assembly / test := (Test / test).value, + assembly / test := {}, assembly / assemblyOption ~= { _.withIncludeScala(false) }, @@ -149,7 +149,7 @@ lazy val pplSparkIntegration = (project in file("ppl-spark-integration")) val oldStrategy = (assembly / assemblyMergeStrategy).value oldStrategy(x) }, - assembly / test := (Test / test).value) + assembly / test := {}) lazy val flintSparkIntegration = (project in file("flint-spark-integration")) .dependsOn(flintCore, flintCommons) @@ -193,7 +193,7 @@ lazy val flintSparkIntegration = (project in file("flint-spark-integration")) val cp = (assembly / fullClasspath).value cp filter { file => file.data.getName.contains("LogsConnectorSpark")} }, - assembly / test := (Test / test).value) + assembly / test := {}) // Test assembly package with integration test. lazy val integtest = (project in file("integ-test")) @@ -269,7 +269,7 @@ lazy val sparkSqlApplication = (project in file("spark-sql-application")) val oldStrategy = (assembly / assemblyMergeStrategy).value oldStrategy(x) }, - assembly / test := (Test / test).value + assembly / test := {} ) lazy val sparkSqlApplicationCosmetic = project diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala b/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala index d69fbc30f..f676a3519 100644 --- a/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala +++ b/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala @@ -8,5 +8,9 @@ package org.apache.spark.sql import org.opensearch.flint.data.FlintStatement trait QueryResultWriter { - def write(dataFrame: DataFrame, flintStatement: FlintStatement): Unit + def reformatQueryResult( + dataFrame: DataFrame, + flintStatement: FlintStatement, + queryExecutionContext: StatementExecutionContext): DataFrame + def persistQueryResult(dataFrame: DataFrame, flintStatement: FlintStatement): Unit } diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala b/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala index 345f97619..91a68ead3 100644 --- a/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala +++ b/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala @@ -9,17 +9,40 @@ import org.opensearch.flint.data.{FlintStatement, InteractiveSession} import org.apache.spark.sql.SessionUpdateMode.SessionUpdateMode +/** + * Trait defining the interface for managing interactive sessions. + */ trait SessionManager { + + /** + * Retrieves metadata about the session manager. + */ def getSessionManagerMetadata: Map[String, Any] + + /** + * Fetches the details of a specific session. + */ def getSessionDetails(sessionId: String): Option[InteractiveSession] + + /** + * Updates the details of a specific session. + */ def updateSessionDetails( sessionDetails: InteractiveSession, updateMode: SessionUpdateMode): Unit - def hasPendingStatement(sessionId: String): Boolean + + /** + * Retrieves the next statement to be executed in a specific session. + */ + def getNextStatement(sessionId: String): Option[FlintStatement] + + /** + * Records a heartbeat for a specific session to indicate it is still active. 
+ */ def recordHeartbeat(sessionId: String): Unit } object SessionUpdateMode extends Enumeration { type SessionUpdateMode = Value - val Update, Upsert, UpdateIf = Value + val UPDATE, UPSERT, UPDATE_IF = Value } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryExecutionContext.scala b/flint-commons/src/main/scala/org/apache/spark/sql/StatementExecutionContext.scala similarity index 73% rename from spark-sql-application/src/main/scala/org/apache/spark/sql/QueryExecutionContext.scala rename to flint-commons/src/main/scala/org/apache/spark/sql/StatementExecutionContext.scala index 5108371ef..91358d6ec 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryExecutionContext.scala +++ b/flint-commons/src/main/scala/org/apache/spark/sql/StatementExecutionContext.scala @@ -7,13 +7,14 @@ package org.apache.spark.sql import scala.concurrent.duration.Duration -case class QueryExecutionContext( +case class StatementExecutionContext( spark: SparkSession, jobId: String, sessionId: String, sessionManager: SessionManager, + statementLifecycleManager: StatementLifecycleManager, + queryResultWriter: QueryResultWriter, dataSource: String, - resultIndex: String, queryExecutionTimeout: Duration, inactivityLimitMillis: Long, queryWaitTimeMillis: Long) diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/StatementLifecycleManager.scala b/flint-commons/src/main/scala/org/apache/spark/sql/StatementLifecycleManager.scala new file mode 100644 index 000000000..5a890f5ed --- /dev/null +++ b/flint-commons/src/main/scala/org/apache/spark/sql/StatementLifecycleManager.scala @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.opensearch.flint.data.FlintStatement + +/** + * Trait defining the interface for managing the lifecycle of statements. + */ +trait StatementLifecycleManager { + + /** + * Prepares the statement lifecycle. + */ + def prepareStatementLifecycle(): Either[String, Unit] + + /** + * Updates a specific statement. + */ + def updateStatement(statement: FlintStatement): Unit + + /** + * Terminates the statement lifecycle. 
+ */ + def terminateStatementLifecycle(): Unit +} diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/StatementManager.scala b/flint-commons/src/main/scala/org/apache/spark/sql/StatementManager.scala deleted file mode 100644 index c0a24ab33..000000000 --- a/flint-commons/src/main/scala/org/apache/spark/sql/StatementManager.scala +++ /dev/null @@ -1,15 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.apache.spark.sql - -import org.opensearch.flint.data.FlintStatement - -trait StatementManager { - def prepareCommandLifecycle(): Either[String, Unit] - def initCommandLifecycle(sessionId: String): FlintStatement - def closeCommandLifecycle(): Unit - def updateCommandDetails(commandDetails: FlintStatement): Unit -} diff --git a/flint-commons/src/main/scala/org/opensearch/flint/data/FlintStatement.scala b/flint-commons/src/main/scala/org/opensearch/flint/data/FlintStatement.scala index dbe73e9a5..80c22df82 100644 --- a/flint-commons/src/main/scala/org/opensearch/flint/data/FlintStatement.scala +++ b/flint-commons/src/main/scala/org/opensearch/flint/data/FlintStatement.scala @@ -42,6 +42,7 @@ class FlintStatement( val statementId: String, val queryId: String, val submitTime: Long, + var queryStartTime: Option[Long] = Some(-1L), var error: Option[String] = None, statementContext: Map[String, Any] = Map.empty[String, Any]) extends ContextualDataStore { @@ -76,7 +77,7 @@ object FlintStatement { case _ => None } - new FlintStatement(state, query, statementId, queryId, submitTime, maybeError) + new FlintStatement(state, query, statementId, queryId, submitTime, error = maybeError) } def serialize(flintStatement: FlintStatement): String = { diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java b/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java index 48107fe8c..0cf643791 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java @@ -49,12 +49,6 @@ public class FlintOptions implements Serializable { public static final String METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER = "spark.metadata.accessAWSCredentialsProvider"; - public static final String CUSTOM_SESSION_MANAGER = "customSessionManager"; - - public static final String CUSTOM_STATEMENT_MANAGER = "customStatementManager"; - - public static final String CUSTOM_QUERY_RESULT_WRITER = "customQueryResultWriter"; - /** * By default, customAWSCredentialsProvider and accessAWSCredentialsProvider are empty. use DefaultAWSCredentialsProviderChain. 
*/ @@ -62,8 +56,6 @@ public class FlintOptions implements Serializable { public static final String SYSTEM_INDEX_KEY_NAME = "spark.flint.job.requestIndex"; - public static final String FLINT_SESSION_ID = "spark.flint.job.sessionId"; - /** * Used by {@link org.opensearch.flint.core.storage.OpenSearchScrollReader} */ @@ -145,18 +137,6 @@ public String getMetadataAccessAwsCredentialsProvider() { return options.getOrDefault(METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER, DEFAULT_AWS_CREDENTIALS_PROVIDER); } - public String getCustomSessionManager() { - return options.getOrDefault(CUSTOM_SESSION_MANAGER, ""); - } - - public String getCustomStatementManager() { - return options.getOrDefault(CUSTOM_STATEMENT_MANAGER, ""); - } - - public String getCustomQueryResultWriter() { - return options.getOrDefault(CUSTOM_QUERY_RESULT_WRITER, ""); - } - public String getUsername() { return options.getOrDefault(USERNAME, "flint"); } @@ -177,10 +157,6 @@ public String getSystemIndexName() { return options.getOrDefault(SYSTEM_INDEX_KEY_NAME, ""); } - public String getSessionId() { - return options.getOrDefault(FLINT_SESSION_ID, null); - } - public int getBatchBytes() { // we did not expect this value could be large than 10mb = 10 * 1024 * 1024 return (int) org.apache.spark.network.util.JavaUtils diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala index bba999110..d6bf611fe 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala @@ -58,7 +58,7 @@ object FlintJob extends Logging with FlintJobExecutor { createSparkSession(conf), query, dataSource, - resultIndex, + resultIndex.get, jobType.equalsIgnoreCase("streaming"), streamingRunningCount) registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala index 665ec5a27..60f2a5f8b 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -89,6 +89,24 @@ trait FlintJobExecutor { } }""".stripMargin + // Define the data schema + val schema = StructType( + Seq( + StructField("result", ArrayType(StringType, containsNull = true), nullable = true), + StructField("schema", ArrayType(StringType, containsNull = true), nullable = true), + StructField("jobRunId", StringType, nullable = true), + StructField("applicationId", StringType, nullable = true), + StructField("dataSourceName", StringType, nullable = true), + StructField("status", StringType, nullable = true), + StructField("error", StringType, nullable = true), + StructField("queryId", StringType, nullable = true), + StructField("queryText", StringType, nullable = true), + StructField("sessionId", StringType, nullable = true), + StructField("jobType", StringType, nullable = true), + // number is not nullable + StructField("updateTime", LongType, nullable = false), + StructField("queryRunTime", LongType, nullable = true))) + def createSparkConf(): SparkConf = { new SparkConf() .setAppName(getClass.getSimpleName) @@ -175,7 +193,6 @@ trait FlintJobExecutor { query: String, sessionId: String, startTime: Long, - timeProvider: TimeProvider, cleaner: Cleaner): DataFrame = { // Create the schema dataframe 
val schemaRows = result.schema.fields.map { field => @@ -188,29 +205,11 @@ trait FlintJobExecutor { StructField("column_name", StringType, nullable = false), StructField("data_type", StringType, nullable = false)))) - // Define the data schema - val schema = StructType( - Seq( - StructField("result", ArrayType(StringType, containsNull = true), nullable = true), - StructField("schema", ArrayType(StringType, containsNull = true), nullable = true), - StructField("jobRunId", StringType, nullable = true), - StructField("applicationId", StringType, nullable = true), - StructField("dataSourceName", StringType, nullable = true), - StructField("status", StringType, nullable = true), - StructField("error", StringType, nullable = true), - StructField("queryId", StringType, nullable = true), - StructField("queryText", StringType, nullable = true), - StructField("sessionId", StringType, nullable = true), - StructField("jobType", StringType, nullable = true), - // number is not nullable - StructField("updateTime", LongType, nullable = false), - StructField("queryRunTime", LongType, nullable = true))) - val resultToSave = result.toJSON.collect.toList .map(_.replaceAll("'", "\\\\'").replaceAll("\"", "'")) val resultSchemaToSave = resultSchema.toJSON.collect.toList.map(_.replaceAll("\"", "'")) - val endTime = timeProvider.currentEpochMillis() + val endTime = currentTimeProvider.currentEpochMillis() // https://github.com/opensearch-project/opensearch-spark/issues/302. Clean shuffle data // after consumed the query result. Streaming query shuffle data is cleaned after each @@ -245,28 +244,9 @@ trait FlintJobExecutor { queryId: String, query: String, sessionId: String, - startTime: Long, - timeProvider: TimeProvider): DataFrame = { - - // Define the data schema - val schema = StructType( - Seq( - StructField("result", ArrayType(StringType, containsNull = true), nullable = true), - StructField("schema", ArrayType(StringType, containsNull = true), nullable = true), - StructField("jobRunId", StringType, nullable = true), - StructField("applicationId", StringType, nullable = true), - StructField("dataSourceName", StringType, nullable = true), - StructField("status", StringType, nullable = true), - StructField("error", StringType, nullable = true), - StructField("queryId", StringType, nullable = true), - StructField("queryText", StringType, nullable = true), - StructField("sessionId", StringType, nullable = true), - StructField("jobType", StringType, nullable = true), - // number is not nullable - StructField("updateTime", LongType, nullable = false), - StructField("queryRunTime", LongType, nullable = true))) - - val endTime = timeProvider.currentEpochMillis() + startTime: Long): DataFrame = { + + val endTime = currentTimeProvider.currentEpochMillis() // Create the data rows val rows = Seq( @@ -419,7 +399,6 @@ trait FlintJobExecutor { query, sessionId, startTime, - currentTimeProvider, CleanerFactory.cleaner(streaming)) } @@ -485,16 +464,19 @@ trait FlintJobExecutor { } } - def parseArgs(args: Array[String]): (Option[String], String) = { + def parseArgs(args: Array[String]): (Option[String], Option[String]) = { args match { + case Array() => + (None, None) case Array(resultIndex) => - (None, resultIndex) // Starting from OS 2.13, resultIndex is the only argument + (None, Some(resultIndex)) // Starting from OS 2.13, resultIndex is the only argument case Array(query, resultIndex) => ( Some(query), - resultIndex + Some(resultIndex) ) // Before OS 2.13, there are two arguments, the second one is resultIndex - case _ 
=> logAndThrow("Unsupported number of arguments. Expected 1 or 2 arguments.") + case _ => + logAndThrow("Unsupported number of arguments. Expected no more than two arguments.") } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 9b0b66d28..6b2f8bd25 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -15,19 +15,16 @@ import scala.util.{Failure, Success, Try} import scala.util.control.NonFatal import com.codahale.metrics.Timer -import org.opensearch.action.get.GetResponse -import org.opensearch.common.Strings import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.logging.CustomLogging import org.opensearch.flint.core.metrics.MetricConstants import org.opensearch.flint.core.metrics.MetricsUtil.{getTimerContext, incrementCounter, registerGauge, stopTimer} -import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} -import org.opensearch.flint.data.{FlintStatement, InteractiveSession} -import org.opensearch.search.sort.SortOrder +import org.opensearch.flint.data.{FlintStatement, InteractiveSession, SessionStates} import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} +import org.apache.spark.sql.SessionUpdateMode.UPDATE_IF import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.util.{ThreadUtils, Utils} @@ -47,28 +44,20 @@ import org.apache.spark.util.{ThreadUtils, Utils} object FlintREPL extends Logging with FlintJobExecutor { private val HEARTBEAT_INTERVAL_MILLIS = 60000L - private val MAPPING_CHECK_TIMEOUT = Duration(1, MINUTES) + private val PREPARE_QUERY_EXEC_TIMEOUT = Duration(1, MINUTES) private val DEFAULT_QUERY_EXECUTION_TIMEOUT = Duration(30, MINUTES) private val DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS = 10 * 60 * 1000 - val INITIAL_DELAY_MILLIS = 3000L - val EARLY_TERMIANTION_CHECK_FREQUENCY = 60000L + private val INITIAL_DELAY_MILLIS = 3000L + private val EARLY_TERMINATION_CHECK_FREQUENCY = 60000L @volatile var earlyExitFlag: Boolean = false - def updateSessionIndex(flintStatement: FlintStatement, updater: OpenSearchUpdater): Unit = { - updater.update(flintStatement.statementId, FlintStatement.serialize(flintStatement)) - } - private val sessionRunningCount = new AtomicInteger(0) private val statementRunningCount = new AtomicInteger(0) def main(args: Array[String]) { val (queryOption, resultIndex) = parseArgs(args) - if (Strings.isNullOrEmpty(resultIndex)) { - logAndThrow("resultIndex is not set") - } - // init SparkContext val conf: SparkConf = createSparkConf() val dataSource = conf.get(FlintSparkConf.DATA_SOURCE_NAME.key, "unknown") @@ -89,19 +78,23 @@ object FlintREPL extends Logging with FlintJobExecutor { val query = getQuery(queryOption, jobType, conf) if (jobType.equalsIgnoreCase("streaming")) { - logInfo(s"""streaming query ${query}""") - configDYNMaxExecutors(conf, jobType) - val streamingRunningCount = new AtomicInteger(0) - val jobOperator = - JobOperator( - createSparkSession(conf), - query, - dataSource, - resultIndex, - true, - streamingRunningCount) - registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount) - jobOperator.start() + logInfo(s"streaming query $query") + resultIndex match { + case Some(index) => + configDYNMaxExecutors(conf, jobType) + val 
streamingRunningCount = new AtomicInteger(0) + val jobOperator = JobOperator( + createSparkSession(conf), + query, + dataSource, + index, // Ensure the correct Option type is passed + true, + streamingRunningCount) + registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount) + jobOperator.start() + + case None => logAndThrow("resultIndex is not set") + } } else { // we don't allow default value for sessionIndex and sessionId. Throw exception if key not found. val sessionId: Option[String] = Option(conf.get(FlintSparkConf.SESSION_ID.key, null)) @@ -111,10 +104,10 @@ object FlintREPL extends Logging with FlintJobExecutor { } val spark = createSparkSession(conf) + val jobId = envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown") val applicationId = envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown") - // Read the values from the Spark configuration or fall back to the default values val inactivityLimitMillis: Long = conf.getLong( @@ -128,7 +121,7 @@ object FlintREPL extends Logging with FlintJobExecutor { val queryWaitTimeoutMillis: Long = conf.getLong("spark.flint.job.queryWaitTimeoutMillis", DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS) - val sessionManager = instantiateSessionManager() + val sessionManager = instantiateSessionManager(spark, resultIndex) val sessionTimerContext = getTimerContext(MetricConstants.REPL_PROCESSING_TIME_METRIC) /** @@ -141,7 +134,7 @@ object FlintREPL extends Logging with FlintJobExecutor { * https://github.com/opensearch-project/opensearch-spark/issues/320 */ spark.sparkContext.addSparkListener( - new PreShutdownListener(sessionManager, sessionId.get, sessionTimerContext)) + new PreShutdownListener(sessionId.get, sessionManager, sessionTimerContext)) // 1 thread for updating heart beat val threadPool = @@ -173,23 +166,35 @@ object FlintREPL extends Logging with FlintJobExecutor { return } - val commandContext = QueryExecutionContext( + val queryExecutionManager = + instantiateQueryExecutionManager(spark, sessionManager.getSessionManagerMetadata) + val queryResultWriter = + instantiateQueryResultWriter(spark, sessionManager.getSessionManagerMetadata) + val queryLoopContext = StatementExecutionContext( spark, jobId, sessionId.get, sessionManager, + queryExecutionManager, + queryResultWriter, dataSource, - resultIndex, queryExecutionTimeoutSecs, inactivityLimitMillis, queryWaitTimeoutMillis) exponentialBackoffRetry(maxRetries = 5, initialDelay = 2.seconds) { - queryLoop(commandContext) + queryLoop(queryLoopContext) } recordSessionSuccess(sessionTimerContext) } catch { case e: Exception => - handleSessionError(sessionTimerContext = sessionTimerContext, e = e) + handleSessionError( + applicationId, + jobId, + sessionId.get, + sessionManager, + sessionTimerContext, + jobStartTime, + e) } finally { if (threadPool != null) { heartBeatFuture.cancel(true) // Pass `true` to interrupt if running @@ -275,16 +280,16 @@ object FlintREPL extends Logging with FlintJobExecutor { case Some(confExcludeJobs) => // example: --conf spark.flint.deployment.excludeJobs=job-1,job-2 - val excludeJobIds = confExcludeJobs.split(",").toList // Convert Array to Lis + val excludedJobIds = confExcludeJobs.split(",").toList // Convert Array to Lis - if (excludeJobIds.contains(jobId)) { + if (excludedJobIds.contains(jobId)) { logInfo(s"current job is excluded, exit the application.") return true } val sessionDetails = sessionManager.getSessionDetails(sessionId) val existingExcludedJobIds = sessionDetails.get.excludedJobIds - if (excludeJobIds.sorted 
== existingExcludedJobIds.sorted) { + if (excludedJobIds.sorted == existingExcludedJobIds.sorted) { logInfo("duplicate job running, exit the application.") return true } @@ -296,20 +301,20 @@ object FlintREPL extends Logging with FlintJobExecutor { sessionId, sessionManager, jobStartTime, - excludeJobIds) + excludedJobIds = excludedJobIds) } false } - def queryLoop(queryExecutionContext: QueryExecutionContext): Unit = { - // 1 thread for updating heart beat + def queryLoop(context: StatementExecutionContext): Unit = { + import context._ + // 1 thread for query execution preparation val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-query", 1) - implicit val executionContext = ExecutionContext.fromExecutor(threadPool) - - var futureMappingCheck: Future[Either[String, Unit]] = null + implicit val futureExecutor = ExecutionContext.fromExecutor(threadPool) + var futurePrepareQueryExecution: Future[Either[String, Unit]] = null try { - futureMappingCheck = Future { - checkAndCreateIndex(queryExecutionContext.osClient, queryExecutionContext.resultIndex) + futurePrepareQueryExecution = Future { + statementLifecycleManager.prepareStatementLifecycle() } var lastActivityTime = currentTimeProvider.currentEpochMillis() @@ -317,20 +322,18 @@ object FlintREPL extends Logging with FlintJobExecutor { var canPickUpNextStatement = true var lastCanPickCheckTime = 0L while (currentTimeProvider - .currentEpochMillis() - lastActivityTime <= queryExecutionContext.inactivityLimitMillis && canPickUpNextStatement) { - logInfo( - s"""read from ${queryExecutionContext.sessionIndex}, sessionId: ${queryExecutionContext.sessionId}""") + .currentEpochMillis() - lastActivityTime <= context.inactivityLimitMillis && canPickUpNextStatement) { + logInfo(s"""sessionId: ${context.sessionId}""") try { - val commandState = CommandState( + val inMemoryQueryExecutionState = InMemoryQueryExecutionState( lastActivityTime, + lastCanPickCheckTime, verificationResult, - flintReader, - futureMappingCheck, - executionContext, - lastCanPickCheckTime) + futurePrepareQueryExecution, + futureExecutor) val result: (Long, VerificationResult, Boolean, Long) = - processCommands(queryExecutionContext, commandState) + processStatements(context, inMemoryQueryExecutionState) val ( updatedLastActivityTime, @@ -343,7 +346,7 @@ object FlintREPL extends Logging with FlintJobExecutor { canPickUpNextStatement = updatedCanPickUpNextStatement lastCanPickCheckTime = updatedLastCanPickCheckTime } finally { - flintReader.close() + statementLifecycleManager.terminateStatementLifecycle() } Thread.sleep(100) @@ -355,37 +358,52 @@ object FlintREPL extends Logging with FlintJobExecutor { } } - // TODO: Refactor this with getDetails + private def refreshSessionState( + applicationId: String, + jobId: String, + sessionId: String, + sessionManager: SessionManager, + jobStartTime: Long, + state: String, + error: Option[String] = None, + excludedJobIds: Seq[String] = Seq.empty[String]): InteractiveSession = { + + val sessionDetails = sessionManager + .getSessionDetails(sessionId) + .getOrElse( + new InteractiveSession( + applicationId, + jobId, + sessionId, + state, + currentTimeProvider.currentEpochMillis(), + jobStartTime, + error = error, + excludedJobIds = excludedJobIds)) + sessionDetails.state = state + sessionManager.updateSessionDetails(sessionDetails, updateMode = SessionUpdateMode.UPSERT) + sessionDetails + } + private def setupFlintJob( applicationId: String, jobId: String, sessionId: String, sessionManager: SessionManager, 
jobStartTime: Long, - excludeJobIds: Seq[String] = Seq.empty[String]): Unit = { - val includeJobId = !excludeJobIds.isEmpty && !excludeJobIds.contains(jobId) - val currentTime = currentTimeProvider.currentEpochMillis() - val flintJob = new InteractiveSession( + excludedJobIds: Seq[String] = Seq.empty[String]): Unit = { + refreshSessionState( applicationId, jobId, sessionId, - "running", - currentTime, + sessionManager, jobStartTime, - excludeJobIds) - - // TODO: serialize need to be refactored to be more flexible - val serializedFlintInstance = if (includeJobId) { - InteractiveSession.serialize(flintJob, currentTime, true) - } else { - InteractiveSession.serializeWithoutJobId(flintJob, currentTime) - } - flintSessionIndexUpdater.upsert(sessionId, serializedFlintInstance) - logInfo(s"""Updated job: {"jobid": ${flintJob.jobId}, "sessionId": ${flintJob.sessionId}}""") + SessionStates.RUNNING, + excludedJobIds = excludedJobIds) sessionRunningCount.incrementAndGet() } - def handleSessionError( + private def handleSessionError( applicationId: String, jobId: String, sessionId: String, @@ -395,39 +413,15 @@ object FlintREPL extends Logging with FlintJobExecutor { e: Exception): Unit = { val error = s"Session error: ${e.getMessage}" CustomLogging.logError(error, e) - - val sessionDetails = sessionManager - .getSessionDetails(sessionId) - .getOrElse(createFailedFlintInstance(applicationId, jobId, sessionId, jobStartTime, error)) - - updateFlintInstance(sessionDetails, flintSessionIndexUpdater, sessionId) - if (sessionDetails.isFail) { - recordSessionFailed(sessionTimerContext) - } - } - - private def createFailedFlintInstance( - applicationId: String, - jobId: String, - sessionId: String, - jobStartTime: Long, - errorMessage: String): InteractiveSession = new InteractiveSession( - applicationId, - jobId, - sessionId, - "fail", - currentTimeProvider.currentEpochMillis(), - jobStartTime, - error = Some(errorMessage)) - - private def updateFlintInstance( - flintInstance: InteractiveSession, - flintSessionIndexUpdater: OpenSearchUpdater, - sessionId: String): Unit = { - val currentTime = currentTimeProvider.currentEpochMillis() - flintSessionIndexUpdater.upsert( + refreshSessionState( + applicationId, + jobId, sessionId, - InteractiveSession.serializeWithoutJobId(flintInstance, currentTime)) + sessionManager, + jobStartTime, + SessionStates.FAIL, + Some(e.getMessage)) + recordSessionFailed(sessionTimerContext) } /** @@ -450,13 +444,12 @@ object FlintREPL extends Logging with FlintJobExecutor { * @return * failed data frame */ - def handleCommandFailureAndGetFailedData( + def handleStatementFailureAndGetFailedData( spark: SparkSession, dataSource: String, error: String, flintStatement: FlintStatement, - sessionId: String, - startTime: Long): DataFrame = { + sessionId: String): DataFrame = { flintStatement.fail() flintStatement.error = Some(error) super.getFailedData( @@ -466,8 +459,7 @@ object FlintREPL extends Logging with FlintJobExecutor { flintStatement.queryId, flintStatement.query, sessionId, - startTime, - currentTimeProvider) + flintStatement.queryStartTime.get) } def processQueryException(ex: Exception, flintStatement: FlintStatement): String = { @@ -477,9 +469,9 @@ object FlintREPL extends Logging with FlintJobExecutor { error } - private def processCommands( - context: QueryExecutionContext, - state: CommandState): (Long, VerificationResult, Boolean, Long) = { + private def processStatements( + context: StatementExecutionContext, + state: InMemoryQueryExecutionState): (Long, 
VerificationResult, Boolean, Long) = { import context._ import state._ @@ -492,8 +484,8 @@ object FlintREPL extends Logging with FlintJobExecutor { while (canProceed) { val currentTime = currentTimeProvider.currentEpochMillis() - // Only call canPickNextStatement if EARLY_TERMIANTION_CHECK_FREQUENCY milliseconds have passed - if (currentTime - lastCanPickCheckTime > EARLY_TERMIANTION_CHECK_FREQUENCY) { + // Only call canPickNextStatement if EARLY_TERMINATION_CHECK_FREQUENCY milliseconds have passed + if (currentTime - lastCanPickCheckTime > EARLY_TERMINATION_CHECK_FREQUENCY) { canPickNextStatementResult = canPickNextStatement(sessionManager, sessionId, jobId) lastCanPickCheckTime = currentTime } @@ -501,35 +493,23 @@ object FlintREPL extends Logging with FlintJobExecutor { if (!canPickNextStatementResult) { earlyExitFlag = true canProceed = false - } else if (!flintReader.hasNext) { - canProceed = false } else { - val statementTimerContext = getTimerContext( - MetricConstants.STATEMENT_PROCESSING_TIME_METRIC) - val flintStatement = processCommandInitiation(flintReader, flintSessionIndexUpdater) + sessionManager.getNextStatement(sessionId) match { + case Some(flintStatement) => + val statementTimerContext = getTimerContext( + MetricConstants.STATEMENT_PROCESSING_TIME_METRIC) - val (dataToWrite, returnedVerificationResult) = processStatementOnVerification( - recordedVerificationResult, - spark, - flintStatement, - dataSource, - sessionId, - executionContext, - futureMappingCheck, - resultIndex, - queryExecutionTimeout, - queryWaitTimeMillis) - - verificationResult = returnedVerificationResult - finalizeCommand( - dataToWrite, - flintStatement, - resultIndex, - flintSessionIndexUpdater, - osClient, - statementTimerContext) - // last query finish time is last activity time - lastActivityTime = currentTimeProvider.currentEpochMillis() + val (dataToWrite, returnedVerificationResult) = + processStatementOnVerification(flintStatement, state, context) + + verificationResult = returnedVerificationResult + finalizeStatement(context, dataToWrite, flintStatement, statementTimerContext) + // last query finish time is last activity time + lastActivityTime = currentTimeProvider.currentEpochMillis() + + case None => + canProceed = false + } } } @@ -549,21 +529,19 @@ object FlintREPL extends Logging with FlintJobExecutor { * @param flintSessionIndexUpdater * flint session index updater */ - private def finalizeCommand( + private def finalizeStatement( + context: StatementExecutionContext, dataToWrite: Option[DataFrame], flintStatement: FlintStatement, - resultIndex: String, - flintSessionIndexUpdater: OpenSearchUpdater, - osClient: OSClient, statementTimerContext: Timer.Context): Unit = { + import context._ + try { - dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient)) + dataToWrite.foreach(df => queryResultWriter.persistQueryResult(df, flintStatement)) if (flintStatement.isRunning || flintStatement.isWaiting) { // we have set failed state in exception handling flintStatement.complete() } - updateSessionIndex(flintStatement, flintSessionIndexUpdater) - recordStatementStateChange(flintStatement, statementTimerContext) } catch { // e.g., maybe due to authentication service connection issue // or invalid catalog (e.g., we are operating on data not defined in provided data source) @@ -571,18 +549,18 @@ object FlintREPL extends Logging with FlintJobExecutor { val error = s"""Fail to write result of ${flintStatement}, cause: ${e.getMessage}""" CustomLogging.logError(error, e) 
flintStatement.fail() - updateSessionIndex(flintStatement, flintSessionIndexUpdater) - recordStatementStateChange(flintStatement, statementTimerContext) + } finally { + statementLifecycleManager.updateStatement(flintStatement) + recordStatementStateChange(flintStatement, statementTimerContext) } } - private def handleCommandTimeout( + private def handleStatementTimeout( spark: SparkSession, dataSource: String, error: String, flintStatement: FlintStatement, - sessionId: String, - startTime: Long): Option[DataFrame] = { + sessionId: String) = { /* * https://tinyurl.com/2ezs5xj9 * @@ -597,131 +575,89 @@ object FlintREPL extends Logging with FlintJobExecutor { */ spark.sparkContext.cancelJobGroup(flintStatement.queryId) Some( - handleCommandFailureAndGetFailedData( - spark, - dataSource, - error, - flintStatement, - sessionId, - startTime)) + handleStatementFailureAndGetFailedData(spark, dataSource, error, flintStatement, sessionId)) } def executeAndHandle( - spark: SparkSession, flintStatement: FlintStatement, - dataSource: String, - sessionId: String, - executionContext: ExecutionContextExecutor, - startTime: Long, - queryExecuitonTimeOut: Duration, - queryWaitTimeMillis: Long): Option[DataFrame] = { + state: InMemoryQueryExecutionState, + context: StatementExecutionContext): Option[DataFrame] = { + import context._ + try { - Some( - executeQueryAsync( - spark, - flintStatement, - dataSource, - sessionId, - executionContext, - startTime, - queryExecuitonTimeOut, - queryWaitTimeMillis)) + Some(executeQueryAsync(flintStatement, state, context)) } catch { case e: TimeoutException => val error = s"Executing ${flintStatement.query} timed out" CustomLogging.logError(error, e) - handleCommandTimeout(spark, dataSource, error, flintStatement, sessionId, startTime) + handleStatementTimeout(spark, dataSource, error, flintStatement, sessionId) case e: Exception => val error = processQueryException(e, flintStatement) Some( - handleCommandFailureAndGetFailedData( + handleStatementFailureAndGetFailedData( spark, dataSource, error, flintStatement, - sessionId, - startTime)) + sessionId)) } } private def processStatementOnVerification( - recordedVerificationResult: VerificationResult, - spark: SparkSession, flintStatement: FlintStatement, - dataSource: String, - sessionId: String, - executionContext: ExecutionContextExecutor, - futureMappingCheck: Future[Either[String, Unit]], - resultIndex: String, - queryExecutionTimeout: Duration, - queryWaitTimeMillis: Long) = { - val startTime: Long = currentTimeProvider.currentEpochMillis() + state: InMemoryQueryExecutionState, + context: StatementExecutionContext): (Option[DataFrame], VerificationResult) = { + import context._ + import state._ + + flintStatement.queryStartTime = Some(currentTimeProvider.currentEpochMillis()) var verificationResult = recordedVerificationResult var dataToWrite: Option[DataFrame] = None verificationResult match { case NotVerified => try { - ThreadUtils.awaitResult(futureMappingCheck, MAPPING_CHECK_TIMEOUT) match { + ThreadUtils.awaitResult(futurePrepareQueryExecution, PREPARE_QUERY_EXEC_TIMEOUT) match { case Right(_) => - dataToWrite = executeAndHandle( - spark, - flintStatement, - dataSource, - sessionId, - executionContext, - startTime, - queryExecutionTimeout, - queryWaitTimeMillis) + dataToWrite = executeAndHandle(flintStatement, state, context) verificationResult = VerifiedWithoutError case Left(error) => verificationResult = VerifiedWithError(error) dataToWrite = Some( - handleCommandFailureAndGetFailedData( + 
handleStatementFailureAndGetFailedData( spark, dataSource, error, flintStatement, - sessionId, - startTime)) + sessionId)) } } catch { case e: TimeoutException => - val error = s"Getting the mapping of index $resultIndex timed out" + val error = s"Query execution preparation timed out" CustomLogging.logError(error, e) dataToWrite = - handleCommandTimeout(spark, dataSource, error, flintStatement, sessionId, startTime) + handleStatementTimeout(spark, dataSource, error, flintStatement, sessionId) case NonFatal(e) => val error = s"An unexpected error occurred: ${e.getMessage}" CustomLogging.logError(error, e) dataToWrite = Some( - handleCommandFailureAndGetFailedData( + handleStatementFailureAndGetFailedData( spark, dataSource, error, flintStatement, - sessionId, - startTime)) + sessionId)) } case VerifiedWithError(err) => dataToWrite = Some( - handleCommandFailureAndGetFailedData( + handleStatementFailureAndGetFailedData( spark, dataSource, err, flintStatement, - sessionId, - startTime)) + sessionId)) case VerifiedWithoutError => - dataToWrite = executeAndHandle( - spark, - flintStatement, - dataSource, - sessionId, - executionContext, - startTime, - queryExecutionTimeout, - queryWaitTimeMillis) + dataToWrite = executeAndHandle(flintStatement, state, context) } logInfo(s"command complete: $flintStatement") @@ -729,149 +665,79 @@ object FlintREPL extends Logging with FlintJobExecutor { } def executeQueryAsync( - spark: SparkSession, flintStatement: FlintStatement, - dataSource: String, - sessionId: String, - executionContext: ExecutionContextExecutor, - startTime: Long, - queryExecutionTimeOut: Duration, - queryWaitTimeMillis: Long): DataFrame = { + state: InMemoryQueryExecutionState, + context: StatementExecutionContext): DataFrame = { + import context._ + import state._ + if (currentTimeProvider .currentEpochMillis() - flintStatement.submitTime > queryWaitTimeMillis) { - handleCommandFailureAndGetFailedData( + handleStatementFailureAndGetFailedData( spark, dataSource, "wait timeout", flintStatement, - sessionId, - startTime) + sessionId) } else { val futureQueryExecution = Future { - executeQuery( - spark, - flintStatement.query, - dataSource, - flintStatement.queryId, - sessionId, - false) - }(executionContext) + executeQuery(flintStatement, context) + }(futureExecutor) // time out after 10 minutes - ThreadUtils.awaitResult(futureQueryExecution, queryExecutionTimeOut) + ThreadUtils.awaitResult(futureQueryExecution, queryExecutionTimeout) } } - private def processCommandInitiation( - flintReader: FlintReader, - flintSessionIndexUpdater: OpenSearchUpdater): FlintStatement = { - val command = flintReader.next() - logDebug(s"raw command: $command") - val flintStatement = FlintStatement.deserialize(command) - logDebug(s"command: $flintStatement") - flintStatement.running() - logDebug(s"command running: $flintStatement") - updateSessionIndex(flintStatement, flintSessionIndexUpdater) - statementRunningCount.incrementAndGet() - flintStatement - } - private def createQueryReader( - osClient: OSClient, - sessionId: String, - sessionIndex: String, - dataSource: String) = { - // all state in index are in lower case - // we only search for statement submitted in the last hour in case of unexpected bugs causing infinite loop in the - // same doc - val dsl = - s"""{ - | "bool": { - | "must": [ - | { - | "term": { - | "type": "statement" - | } - | }, - | { - | "term": { - | "state": "waiting" - | } - | }, - | { - | "term": { - | "sessionId": "$sessionId" - | } - | }, - | { - | "term": { - | 
"dataSourceName": "$dataSource" - | } - | }, - | { - | "range": { - | "submitTime": { "gte": "now-1h" } - | } - | } - | ] - | } - |}""".stripMargin - - val flintReader = osClient.createQueryReader(sessionIndex, dsl, "submitTime", SortOrder.ASC) - flintReader + def executeQuery( + flintStatement: FlintStatement, + context: StatementExecutionContext): DataFrame = { + import context._ + // reset start time + flintStatement.queryStartTime = Some(System.currentTimeMillis()) + // we have to set job group in the same thread that started the query according to spark doc + spark.sparkContext.setJobGroup( + flintStatement.queryId, + "Job group for " + flintStatement.queryId, + interruptOnCancel = true) + // Execute SQL query + val result: DataFrame = spark.sql(flintStatement.query) + queryResultWriter.reformatQueryResult(result, flintStatement, context) } class PreShutdownListener( - sessionManager: SessionManager, sessionId: String, + sessionManager: SessionManager, sessionTimerContext: Timer.Context) extends SparkListener with Logging { - // TODO: Refactor update - override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { logInfo("Shutting down REPL") logInfo("earlyExitFlag: " + earlyExitFlag) - sessionManager.getSessionDetails(sessionId).foreach { sessionDetails => - // It's essential to check the earlyExitFlag before marking the session state as 'dead'. When this flag is true, - // it indicates that the control plane has already initiated a new session to handle remaining requests for the - // current session. In our SQL setup, marking a session as 'dead' automatically triggers the creation of a new - // session. However, the newly created session (initiated by the control plane) will enter a spin-wait state, - // where it inefficiently waits for certain conditions to be met, leading to unproductive resource consumption - // and eventual timeout. To avoid this issue and prevent the creation of redundant sessions by SQL, we ensure - // the session state is not set to 'dead' when earlyExitFlag is true, thereby preventing unnecessary duplicate - // processing. - if (!earlyExitFlag && !sessionDetails.isDead && !sessionDetails.isFail) { - updateFlintInstanceBeforeShutdown( - source, - getResponse, - flintSessionIndexUpdater, - sessionId, - sessionTimerContext) + try { + sessionManager.getSessionDetails(sessionId).foreach { sessionDetails => + // It's essential to check the earlyExitFlag before marking the session state as 'dead'. When this flag is true, + // it indicates that the control plane has already initiated a new session to handle remaining requests for the + // current session. In our SQL setup, marking a session as 'dead' automatically triggers the creation of a new + // session. However, the newly created session (initiated by the control plane) will enter a spin-wait state, + // where it inefficiently waits for certain conditions to be met, leading to unproductive resource consumption + // and eventual timeout. To avoid this issue and prevent the creation of redundant sessions by SQL, we ensure + // the session state is not set to 'dead' when earlyExitFlag is true, thereby preventing unnecessary duplicate + // processing. 
+ if (!earlyExitFlag && !sessionDetails.isDead && !sessionDetails.isFail) { + sessionDetails.state = SessionStates.DEAD + sessionManager.updateSessionDetails(sessionDetails, UPDATE_IF) + recordSessionSuccess(sessionTimerContext) + } } + } catch { + case e: Exception => logError(s"Failed to update session state for $sessionId", e) } } } - private def updateFlintInstanceBeforeShutdown( - source: java.util.Map[String, AnyRef], - getResponse: GetResponse, - flintSessionIndexUpdater: OpenSearchUpdater, - sessionId: String, - sessionTimerContext: Timer.Context): Unit = { - val flintInstance = InteractiveSession.deserializeFromMap(source) - flintInstance.state = "dead" - flintSessionIndexUpdater.updateIf( - sessionId, - InteractiveSession.serializeWithoutJobId( - flintInstance, - currentTimeProvider.currentEpochMillis()), - getResponse.getSeqNo, - getResponse.getPrimaryTerm) - recordSessionSuccess(sessionTimerContext) - } - /** * Create a new thread to update the last update time of the flint instance. * @param currentInterval @@ -1008,25 +874,46 @@ object FlintREPL extends Logging with FlintJobExecutor { result.getOrElse(throw new RuntimeException("Failed after retries")) } - private def instantiateSessionManager(): SessionManager = { - val options = FlintSparkConf().flintOptions() - val className = options.getCustomSessionManager() - - if (className.isEmpty) { - new SessionManagerImpl(options) + private def instantiateProvider[T](defaultProvider: => T, providerClassName: String): T = { + if (providerClassName.isEmpty) { + defaultProvider } else { try { - val providerClass = Utils.classForName(className) + val providerClass = Utils.classForName(providerClassName) val ctor = providerClass.getDeclaredConstructor() ctor.setAccessible(true) - ctor.newInstance().asInstanceOf[SessionManager] + ctor.newInstance().asInstanceOf[T] } catch { case e: Exception => - throw new RuntimeException(s"Failed to instantiate provider: $className", e) + throw new RuntimeException(s"Failed to instantiate provider: $providerClassName", e) } } } + private def instantiateSessionManager( + spark: SparkSession, + resultIndex: Option[String]): SessionManager = { + instantiateProvider( + new SessionManagerImpl(spark, resultIndex), + spark.conf.get(FlintSparkConf.CUSTOM_SESSION_MANAGER.key)) + } + + private def instantiateQueryExecutionManager( + spark: SparkSession, + context: Map[String, Any]): StatementLifecycleManager = { + instantiateProvider( + new StatementLifecycleManagerImpl(context), + spark.conf.get(FlintSparkConf.CUSTOM_STATEMENT_MANAGER.key)) + } + + private def instantiateQueryResultWriter( + spark: SparkSession, + context: Map[String, Any]): QueryResultWriter = { + instantiateProvider( + new QueryResultWriterImpl(context), + spark.conf.get(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key)) + } + private def recordSessionSuccess(sessionTimerContext: Timer.Context): Unit = { logInfo("Session Success") stopTimer(sessionTimerContext) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/InMemoryQueryExecutionState.scala similarity index 61% rename from spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala rename to spark-sql-application/src/main/scala/org/apache/spark/sql/InMemoryQueryExecutionState.scala index ad49201f0..4f87dd643 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala +++ 
b/spark-sql-application/src/main/scala/org/apache/spark/sql/InMemoryQueryExecutionState.scala @@ -9,10 +9,9 @@ import scala.concurrent.{ExecutionContextExecutor, Future} import org.opensearch.flint.core.storage.FlintReader -case class CommandState( +case class InMemoryQueryExecutionState( recordedLastActivityTime: Long, + recordedLastCanPickCheckTime: Long, recordedVerificationResult: VerificationResult, - flintReader: FlintReader, - futureMappingCheck: Future[Either[String, Unit]], - executionContext: ExecutionContextExecutor, - recordedLastCanPickCheckTime: Long) + futurePrepareQueryExecution: Future[Either[String, Unit]], + futureExecutor: ExecutionContextExecutor) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala index f315dc836..d386a5c31 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala @@ -57,19 +57,17 @@ case class JobOperator( dataToWrite = Some(mappingCheckResult match { case Right(_) => data case Left(error) => - getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider) + getFailedData(spark, dataSource, error, "", query, "", startTime) }) exceptionThrown = false } catch { case e: TimeoutException => val error = s"Getting the mapping of index $resultIndex timed out" logError(error, e) - dataToWrite = Some( - getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider)) + dataToWrite = Some(getFailedData(spark, dataSource, error, "", query, "", startTime)) case e: Exception => val error = processQueryException(e) - dataToWrite = Some( - getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider)) + dataToWrite = Some(getFailedData(spark, dataSource, error, "", query, "", startTime)) } finally { cleanUpResources(exceptionThrown, threadPool, dataToWrite, resultIndex, osClient) } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala new file mode 100644 index 000000000..830fbae68 --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.opensearch.flint.data.FlintStatement + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.util.CleanerFactory + +class QueryResultWriterImpl(context: Map[String, Any]) + extends QueryResultWriter + with FlintJobExecutor + with Logging { + + val resultIndex = context("resultIndex").asInstanceOf[String] + val osClient = context("osClient").asInstanceOf[OSClient] + + override def reformatQueryResult( + dataFrame: DataFrame, + flintStatement: FlintStatement, + queryExecutionContext: StatementExecutionContext): DataFrame = { + import queryExecutionContext._ + getFormattedData( + dataFrame, + spark, + dataSource, + flintStatement.queryId, + flintStatement.query, + sessionId, + flintStatement.queryStartTime.get, + CleanerFactory.cleaner(false)) + } + + override def persistQueryResult(dataFrame: DataFrame, flintStatement: FlintStatement): Unit = { + writeDataFrameToOpensearch(dataFrame, resultIndex, osClient) + } +} diff --git 
a/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala index 29f70ddbf..6775c9067 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala @@ -8,10 +8,9 @@ package org.apache.spark.sql import scala.util.{Failure, Success, Try} import org.json4s.native.Serialization -import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.logging.CustomLogging import org.opensearch.flint.core.storage.FlintReader -import org.opensearch.flint.data.InteractiveSession +import org.opensearch.flint.data.{FlintStatement, InteractiveSession} import org.opensearch.flint.data.InteractiveSession.formats import org.opensearch.search.sort.SortOrder @@ -19,19 +18,22 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.SessionUpdateMode.SessionUpdateMode import org.apache.spark.sql.flint.config.FlintSparkConf -class SessionManagerImpl(flintOptions: FlintOptions) +class SessionManagerImpl(spark: SparkSession, resultIndex: Option[String]) extends SessionManager with FlintJobExecutor with Logging { // we don't allow default value for sessionIndex, sessionId and datasource. Throw exception if key not found. - val sessionIndex: String = flintOptions.getSystemIndexName - val sessionId: String = flintOptions.getSessionId - val dataSource: String = flintOptions.getDataSourceName + val sessionIndex: String = spark.conf.get(FlintSparkConf.REQUEST_INDEX.key) + val sessionId: String = spark.conf.get(FlintSparkConf.SESSION_ID.key) + val dataSource: String = spark.conf.get(FlintSparkConf.DATA_SOURCE_NAME.key) if (sessionIndex.isEmpty) { logAndThrow(FlintSparkConf.REQUEST_INDEX.key + " is not set") } + if (resultIndex.isEmpty) { + logAndThrow("resultIndex is not set") + } if (sessionId.isEmpty) { logAndThrow(FlintSparkConf.SESSION_ID.key + " is not set") } @@ -39,13 +41,13 @@ class SessionManagerImpl(flintOptions: FlintOptions) logAndThrow(FlintSparkConf.DATA_SOURCE_NAME.key + " is not set") } - val osClient = new OSClient(flintOptions) + val osClient = new OSClient(FlintSparkConf().flintOptions()) val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex) - val flintReader: FlintReader = createQueryReader(sessionId, sessionIndex, dataSource) + val flintReader: FlintReader = createOpenSearchQueryReader() override def getSessionManagerMetadata: Map[String, Any] = { Map( - "sessionIndex" -> sessionIndex, + "resultIndex" -> resultIndex.get, "osClient" -> osClient, "flintSessionIndexUpdater" -> flintSessionIndexUpdater, "flintReader" -> flintReader) @@ -65,6 +67,64 @@ class SessionManagerImpl(flintOptions: FlintOptions) } } + override def updateSessionDetails( + sessionDetails: InteractiveSession, + sessionUpdateMode: SessionUpdateMode): Unit = { + sessionUpdateMode match { + case SessionUpdateMode.UPDATE => + flintSessionIndexUpdater.update( + sessionDetails.sessionId, + InteractiveSession.serialize(sessionDetails, currentTimeProvider.currentEpochMillis())) + case SessionUpdateMode.UPSERT => + val includeJobId = + !sessionDetails.excludedJobIds.isEmpty && !sessionDetails.excludedJobIds.contains( + sessionDetails.jobId) + val serializedSession = if (includeJobId) { + InteractiveSession.serialize( + sessionDetails, + currentTimeProvider.currentEpochMillis(), + true) + } else { + InteractiveSession.serializeWithoutJobId( + sessionDetails, + 
currentTimeProvider.currentEpochMillis()) + } + flintSessionIndexUpdater.upsert(sessionDetails.sessionId, serializedSession) + case SessionUpdateMode.UPDATE_IF => + val seqNo = sessionDetails + .getContextValue("_seq_no") + .getOrElse(throw new IllegalArgumentException("Missing _seq_no for conditional update")) + .asInstanceOf[Long] + val primaryTerm = sessionDetails + .getContextValue("_primary_term") + .getOrElse( + throw new IllegalArgumentException("Missing _primary_term for conditional update")) + .asInstanceOf[Long] + flintSessionIndexUpdater.updateIf( + sessionDetails.sessionId, + InteractiveSession.serializeWithoutJobId( + sessionDetails, + currentTimeProvider.currentEpochMillis()), + seqNo, + primaryTerm) + } + + logInfo( + s"""Updated job: {"jobid": ${sessionDetails.jobId}, "sessionId": ${sessionDetails.sessionId}} from $sessionIndex""") + } + + override def getNextStatement(sessionId: String): Option[FlintStatement] = { + if (flintReader.hasNext) { + val rawStatement = flintReader.next() + logDebug(s"raw statement: $rawStatement") + val flintStatement = FlintStatement.deserialize(rawStatement) + logDebug(s"statement: $flintStatement") + Some(flintStatement) + } else { + None + } + } + override def recordHeartbeat(sessionId: String): Unit = { flintSessionIndexUpdater.upsert( sessionId, @@ -72,11 +132,7 @@ class SessionManagerImpl(flintOptions: FlintOptions) Map("lastUpdateTime" -> currentTimeProvider.currentEpochMillis(), "state" -> "running"))) } - override def hasPendingStatement(sessionId: String): Boolean = { - flintReader.hasNext - } - - private def createQueryReader(sessionId: String, sessionIndex: String, dataSource: String) = { + private def createOpenSearchQueryReader() = { // all state in index are in lower case // we only search for statement submitted in the last hour in case of unexpected bugs causing infinite loop in the // same doc @@ -116,52 +172,4 @@ class SessionManagerImpl(flintOptions: FlintOptions) val flintReader = osClient.createQueryReader(sessionIndex, dsl, "submitTime", SortOrder.ASC) flintReader } - - override def updateSessionDetails( - sessionDetails: InteractiveSession, - sessionUpdateMode: SessionUpdateMode): Unit = { - sessionUpdateMode match { - case SessionUpdateMode.Update => - flintSessionIndexUpdater.update( - sessionDetails.sessionId, - InteractiveSession.serialize(sessionDetails, currentTimeProvider.currentEpochMillis())) - case SessionUpdateMode.Upsert => - val includeJobId = - !sessionDetails.excludedJobIds.isEmpty && !sessionDetails.excludedJobIds.contains( - sessionDetails.jobId) - val serializedSession = if (includeJobId) { - InteractiveSession.serialize( - sessionDetails, - currentTimeProvider.currentEpochMillis(), - true) - } else { - InteractiveSession.serializeWithoutJobId( - sessionDetails, - currentTimeProvider.currentEpochMillis()) - } - flintSessionIndexUpdater.upsert(sessionDetails.sessionId, serializedSession) - case SessionUpdateMode.UpdateIf => - val executionContext = sessionDetails.executionContext.getOrElse( - throw new IllegalArgumentException("Missing executionContext for conditional update")) - val seqNo = executionContext - .get("_seq_no") - .getOrElse(throw new IllegalArgumentException("Missing _seq_no for conditional update")) - .asInstanceOf[Long] - val primaryTerm = executionContext - .get("_primary_term") - .getOrElse( - throw new IllegalArgumentException("Missing _primary_term for conditional update")) - .asInstanceOf[Long] - flintSessionIndexUpdater.updateIf( - sessionDetails.sessionId, - 
InteractiveSession.serializeWithoutJobId( - sessionDetails, - currentTimeProvider.currentEpochMillis()), - seqNo, - primaryTerm) - } - - logInfo( - s"""Updated job: {"jobid": ${sessionDetails.jobId}, "sessionId": ${sessionDetails.sessionId}} from $sessionIndex""") - } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementLifecycleManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementLifecycleManagerImpl.scala new file mode 100644 index 000000000..8e6e8a644 --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementLifecycleManagerImpl.scala @@ -0,0 +1,54 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} +import org.opensearch.flint.data.FlintStatement + +import org.apache.spark.internal.Logging + +class StatementLifecycleManagerImpl(context: Map[String, Any]) + extends StatementLifecycleManager + with FlintJobExecutor + with Logging { + + val resultIndex = context("resultIndex").asInstanceOf[String] + val osClient = context("osClient").asInstanceOf[OSClient] + val flintSessionIndexUpdater = + context("flintSessionIndexUpdater").asInstanceOf[OpenSearchUpdater] + val flintReader = context("flintReader").asInstanceOf[FlintReader] + + override def prepareStatementLifecycle(): Either[String, Unit] = { + try { + val existingSchema = osClient.getIndexMetadata(resultIndex) + if (!isSuperset(existingSchema, resultIndexMapping)) { + Left(s"The mapping of $resultIndex is incorrect.") + } else { + Right(()) + } + } catch { + case e: IllegalStateException + if e.getCause != null && + e.getCause.getMessage.contains("index_not_found_exception") => + createResultIndex(osClient, resultIndex, resultIndexMapping) + case e: InterruptedException => + val error = s"Interrupted by the main thread: ${e.getMessage}" + Thread.currentThread().interrupt() // Preserve the interrupt status + logError(error, e) + Left(error) + case e: Exception => + val error = s"Failed to verify existing mapping: ${e.getMessage}" + logError(error, e) + Left(error) + } + } + override def updateStatement(statement: FlintStatement): Unit = { + flintSessionIndexUpdater.update(statement.statementId, FlintStatement.serialize(statement)) + } + override def terminateStatementLifecycle(): Unit = { + flintReader.close() + } +} diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala index 19f596e31..a35cb6590 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala @@ -79,7 +79,6 @@ class FlintJobTest extends SparkFunSuite with JobMatchers { "select 1", "20", currentTime - queryRunTime, - new MockTimeProvider(currentTime), CleanerFactory.cleaner(false)) assertEqualDataframe(expected, result) } diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index 45ec7b2cc..11b9f7bb2 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -181,12 +181,7 @@ class FlintREPLTest "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava) // Instantiate the listener - val 
listener = new PreShutdownListener( - flintSessionIndexUpdater, - osClient, - sessionIndex, - sessionId, - timerContext) + val listener = new PreShutdownListener(osClient, flintSessionIndexUpdater) // Simulate application end listener.onApplicationEnd(SparkListenerApplicationEnd(System.currentTimeMillis())) @@ -245,13 +240,12 @@ class FlintREPLTest // Compare the result val result = - FlintREPL.handleCommandFailureAndGetFailedData( + FlintREPL.handleStatementFailureAndGetFailedData( spark, dataSourceName, error, flintStatement, - "20", - currentTime - queryRunTime) + "20") assertEqualDataframe(expected, result) assert("failed" == flintStatement.state) assert(error == flintStatement.error.get) @@ -601,16 +595,7 @@ class FlintREPLTest val sparkContext = mock[SparkContext] when(mockSparkSession.sparkContext).thenReturn(sparkContext) - val result = FlintREPL.executeAndHandle( - mockSparkSession, - mockFlintStatement, - dataSource, - sessionId, - executionContext, - startTime, - // make sure it times out before mockSparkSession.sql can return, which takes 60 seconds - Duration(1, SECONDS), - 600000) + val result = FlintREPL.executeAndHandle() verify(mockSparkSession, times(1)).sql(any[String]) verify(sparkContext, times(1)).cancelJobGroup(any[String]) @@ -652,15 +637,7 @@ class FlintREPLTest .thenReturn(expectedDataFrame) when(expectedDataFrame.toDF(any[Seq[String]]: _*)).thenReturn(expectedDataFrame) - val result = FlintREPL.executeAndHandle( - mockSparkSession, - flintStatement, - dataSource, - sessionId, - executionContext, - startTime, - Duration.Inf, // Use Duration.Inf or a large enough duration to avoid a timeout, - 600000) + val result = FlintREPL.executeAndHandle() // Verify that ParseException was caught and handled result should not be None // or result.isDefined shouldBe true
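For reference, below is a minimal sketch of how the SessionManager extension point introduced in this change could be implemented outside the repository. The package, class name, and in-memory storage are assumptions made purely for illustration; only the trait methods and the no-argument-constructor requirement (imposed by the reflection in instantiateProvider, which is fed by the FlintSparkConf.CUSTOM_SESSION_MANAGER setting) come from this change.

package org.example.flint

import scala.collection.concurrent.TrieMap

import org.opensearch.flint.data.{FlintStatement, InteractiveSession}

import org.apache.spark.sql.SessionManager
import org.apache.spark.sql.SessionUpdateMode.SessionUpdateMode

/**
 * Illustrative only: an in-memory SessionManager standing in for the OpenSearch-backed
 * SessionManagerImpl. The no-argument constructor is required because instantiateProvider()
 * creates the provider via getDeclaredConstructor().
 */
class InMemorySessionManager extends SessionManager {

  // Assumed in-memory stores; the default implementation reads and writes the session index instead.
  private val sessions = TrieMap.empty[String, InteractiveSession]
  private val pendingStatements = TrieMap.empty[String, List[FlintStatement]]

  override def getSessionManagerMetadata: Map[String, Any] = Map.empty

  override def getSessionDetails(sessionId: String): Option[InteractiveSession] =
    sessions.get(sessionId)

  override def updateSessionDetails(
      sessionDetails: InteractiveSession,
      updateMode: SessionUpdateMode): Unit =
    // Every update mode collapses to a plain overwrite in this sketch.
    sessions(sessionDetails.sessionId) = sessionDetails

  override def getNextStatement(sessionId: String): Option[FlintStatement] =
    pendingStatements.get(sessionId).flatMap {
      case head :: tail =>
        // Pop the oldest waiting statement, mirroring the "next waiting statement" contract.
        pendingStatements.update(sessionId, tail)
        Some(head)
      case Nil => None
    }

  override def recordHeartbeat(sessionId: String): Unit = ()
}

The same plug-in pattern would apply to StatementLifecycleManager and QueryResultWriter: their default implementations (StatementLifecycleManagerImpl and QueryResultWriterImpl) receive their collaborators, such as resultIndex and osClient, through the session manager's metadata map, so a custom session manager must expose whatever entries its companion providers expect.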