From 1f31e56da5fd45e9e8a37e05c1acf80bae786f44 Mon Sep 17 00:00:00 2001 From: Kaituo Li Date: Fri, 10 Nov 2023 18:18:29 -0800 Subject: [PATCH] Implement blue/green deployment support in REPL mode (#141) * Implement blue/green deployment support in REPL mode Features: - Add mutual exclusivity in session-based emr-s jobs to ensure only one job runs at a time, enhancing system stability during blue/green deployments. This allows active job exclusion and seamless task pickup by new jobs. Details in #94. Fixes: - Resolve a bug where long-running queries failed to cancel post-index mapping verification, by introducing timely query cancellation checks within the REPL loop. - Address issue #138 with the proposed short-term fix, improving reliability. Tests: - Conducted manual testing to validate blue/green deployment support and query cancellation. - Extended unit tests to cover new features and bug fixes. Signed-off-by: Kaituo Li * add jobStartTime in session doc Signed-off-by: Kaituo Li * read excluded jobs and customize timeout setting Signed-off-by: Kaituo Li --------- Signed-off-by: Kaituo Li --- .../opensearch/flint/app/FlintCommand.scala | 8 + .../opensearch/flint/app/FlintInstance.scala | 77 +- .../flint/app/FlintInstanceTest.scala | 119 +++ .../org/apache/spark/sql/CommandContext.scala | 24 + .../org/apache/spark/sql/CommandState.scala | 17 + .../scala/org/apache/spark/sql/FlintJob.scala | 8 +- .../apache/spark/sql/FlintJobExecutor.scala | 135 +-- .../org/apache/spark/sql/FlintREPL.scala | 763 ++++++++++----- .../scala/org/apache/spark/sql/OSClient.scala | 12 + .../sql/util/DefaultThreadPoolFactory.scala | 18 + .../spark/sql/util/ThreadPoolFactory.scala | 14 + .../org/apache/spark/sql/FlintJobTest.scala | 1 + .../org/apache/spark/sql/FlintREPLTest.scala | 896 +++++++++++++++++- .../sql/util/MockThreadPoolFactory.scala | 14 + .../sql/{ => util}/MockTimeProvider.scala | 4 +- 15 files changed, 1803 insertions(+), 307 deletions(-) create mode 100644 flint-spark-integration/src/test/scala/org/opensearch/flint/app/FlintInstanceTest.scala create mode 100644 spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala create mode 100644 spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala create mode 100644 spark-sql-application/src/main/scala/org/apache/spark/sql/util/DefaultThreadPoolFactory.scala create mode 100644 spark-sql-application/src/main/scala/org/apache/spark/sql/util/ThreadPoolFactory.scala create mode 100644 spark-sql-application/src/test/scala/org/apache/spark/sql/util/MockThreadPoolFactory.scala rename spark-sql-application/src/test/scala/org/apache/spark/sql/{ => util}/MockTimeProvider.scala (72%) diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintCommand.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintCommand.scala index 288d661cd..7624c2c54 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintCommand.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintCommand.scala @@ -41,6 +41,14 @@ class FlintCommand( def isFailed(): Boolean = { state == "failed" } + + def isWaiting(): Boolean = { + state == "waiting" + } + + override def toString: String = { + s"FlintCommand(state=$state, query=$query, statementId=$statementId, queryId=$queryId, submitTime=$submitTime, error=$error)" + } } object FlintCommand { diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala 
b/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala index 3b5ed2a74..52b7d9736 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala @@ -5,11 +5,15 @@ package org.opensearch.flint.app +import java.util.{Map => JavaMap} + +import scala.collection.JavaConverters._ +import scala.collection.mutable + import org.json4s.{Formats, NoTypeHints} -import org.json4s.JsonAST.JString +import org.json4s.JsonAST.{JArray, JString} import org.json4s.native.JsonMethods.parse import org.json4s.native.Serialization -import org.opensearch.index.seqno.SequenceNumbers // lastUpdateTime is added to FlintInstance to track the last update time of the instance. Its unit is millisecond. class FlintInstance( @@ -17,8 +21,11 @@ class FlintInstance( val jobId: String, // sessionId is the session type doc id val sessionId: String, - val state: String, + var state: String, val lastUpdateTime: Long, + // We need jobStartTime to check if HMAC token is expired or not + val jobStartTime: Long, + val excludedJobIds: Seq[String] = Seq.empty[String], val error: Option[String] = None) {} object FlintInstance { @@ -32,24 +39,80 @@ object FlintInstance { val jobId = (meta \ "jobId").extract[String] val sessionId = (meta \ "sessionId").extract[String] val lastUpdateTime = (meta \ "lastUpdateTime").extract[Long] + val jobStartTime = (meta \ "jobStartTime").extract[Long] + // To handle the possibility of excludeJobIds not being present, + // we use extractOpt which gives us an Option[Seq[String]]. + // If it is not present, it will return None, which we can then + // convert to an empty Seq[String] using getOrElse. + // Replace extractOpt with jsonOption and map + val excludeJobIds: Seq[String] = meta \ "excludeJobIds" match { + case JArray(lst) => lst.map(_.extract[String]) + case _ => Seq.empty[String] + } + val maybeError: Option[String] = (meta \ "error") match { case JString(str) => Some(str) case _ => None } - new FlintInstance(applicationId, jobId, sessionId, state, lastUpdateTime, maybeError) + new FlintInstance( + applicationId, + jobId, + sessionId, + state, + lastUpdateTime, + jobStartTime, + excludeJobIds, + maybeError) + } + + def deserializeFromMap(source: JavaMap[String, AnyRef]): FlintInstance = { + // Since we are dealing with JavaMap, we convert it to a Scala mutable Map for ease of use. + val scalaSource = source.asScala + + val applicationId = scalaSource("applicationId").asInstanceOf[String] + val state = scalaSource("state").asInstanceOf[String] + val jobId = scalaSource("jobId").asInstanceOf[String] + val sessionId = scalaSource("sessionId").asInstanceOf[String] + val lastUpdateTime = scalaSource("lastUpdateTime").asInstanceOf[Long] + val jobStartTime = scalaSource("jobStartTime").asInstanceOf[Long] + + // We safely handle the possibility of excludeJobIds being absent or not a list. + val excludeJobIds: Seq[String] = scalaSource.get("excludeJobIds") match { + case Some(lst: java.util.List[_]) => lst.asScala.toList.map(_.asInstanceOf[String]) + case _ => Seq.empty[String] + } + + // Handle error similarly, ensuring we get an Option[String]. + val maybeError: Option[String] = scalaSource.get("error") match { + case Some(str: String) => Some(str) + case _ => None + } + + // Construct a new FlintInstance with the extracted values. 
+ new FlintInstance( + applicationId, + jobId, + sessionId, + state, + lastUpdateTime, + jobStartTime, + excludeJobIds, + maybeError) } - def serialize(job: FlintInstance): String = { + def serialize(job: FlintInstance, currentTime: Long): String = { + // jobId is only readable by spark, thus we don't override jobId Serialization.write( Map( "type" -> "session", "sessionId" -> job.sessionId, "error" -> job.error.getOrElse(""), "applicationId" -> job.applicationId, - "jobId" -> job.jobId, "state" -> job.state, // update last update time - "lastUpdateTime" -> System.currentTimeMillis())) + "lastUpdateTime" -> currentTime, + "excludeJobIds" -> job.excludedJobIds, + "jobStartTime" -> job.jobStartTime)) } } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/app/FlintInstanceTest.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/app/FlintInstanceTest.scala new file mode 100644 index 000000000..31749c794 --- /dev/null +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/app/FlintInstanceTest.scala @@ -0,0 +1,119 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.app + +import java.util.{HashMap => JavaHashMap} + +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite + +class FlintInstanceTest extends SparkFunSuite with Matchers { + + test("deserialize should correctly parse a FlintInstance with excludedJobIds from JSON") { + val json = + """{"applicationId":"app-123","jobId":"job-456","sessionId":"session-789","state":"RUNNING","lastUpdateTime":1620000000000,"jobStartTime":1620000001000,"excludeJobIds":["job-101","job-202"]}""" + val instance = FlintInstance.deserialize(json) + + instance.applicationId shouldBe "app-123" + instance.jobId shouldBe "job-456" + instance.sessionId shouldBe "session-789" + instance.state shouldBe "RUNNING" + instance.lastUpdateTime shouldBe 1620000000000L + instance.jobStartTime shouldBe 1620000001000L + instance.excludedJobIds should contain allOf ("job-101", "job-202") + instance.error shouldBe None + } + + test("serialize should correctly produce JSON from a FlintInstance with excludedJobIds") { + val excludedJobIds = Seq("job-101", "job-202") + val instance = new FlintInstance( + "app-123", + "job-456", + "session-789", + "RUNNING", + 1620000000000L, + 1620000001000L, + excludedJobIds) + val currentTime = System.currentTimeMillis() + val json = FlintInstance.serialize(instance, currentTime) + + json should include(""""applicationId":"app-123"""") + json should not include (""""jobId":"job-456"""") + json should include(""""sessionId":"session-789"""") + json should include(""""state":"RUNNING"""") + json should include(s""""lastUpdateTime":$currentTime""") + json should include(""""excludeJobIds":["job-101","job-202"]""") + json should include(""""jobStartTime":1620000001000""") + json should include(""""error":""""") + } + + test("deserialize should correctly handle an empty excludedJobIds field in JSON") { + val jsonWithoutExcludedJobIds = + """{"applicationId":"app-123","jobId":"job-456","sessionId":"session-789","state":"RUNNING","lastUpdateTime":1620000000000,"jobStartTime":1620000001000}""" + val instance = FlintInstance.deserialize(jsonWithoutExcludedJobIds) + + instance.excludedJobIds shouldBe empty + } + + test("deserialize should correctly handle error field in JSON") { + val jsonWithError = + 
"""{"applicationId":"app-123","jobId":"job-456","sessionId":"session-789","state":"FAILED","lastUpdateTime":1620000000000,"jobStartTime":1620000001000,"error":"Some error occurred"}""" + val instance = FlintInstance.deserialize(jsonWithError) + + instance.error shouldBe Some("Some error occurred") + } + + test("serialize should include error when present in FlintInstance") { + val instance = new FlintInstance( + "app-123", + "job-456", + "session-789", + "FAILED", + 1620000000000L, + 1620000001000L, + Seq.empty[String], + Some("Some error occurred")) + val currentTime = System.currentTimeMillis() + val json = FlintInstance.serialize(instance, currentTime) + + json should include(""""error":"Some error occurred"""") + } + + test("deserializeFromMap should handle normal case") { + val sourceMap = new JavaHashMap[String, AnyRef]() + sourceMap.put("applicationId", "app1") + sourceMap.put("jobId", "job1") + sourceMap.put("sessionId", "session1") + sourceMap.put("state", "running") + sourceMap.put("lastUpdateTime", java.lang.Long.valueOf(1234567890L)) + sourceMap.put("jobStartTime", java.lang.Long.valueOf(9876543210L)) + sourceMap.put("excludeJobIds", java.util.Arrays.asList("job2", "job3")) + sourceMap.put("error", "An error occurred") + + val result = FlintInstance.deserializeFromMap(sourceMap) + + assert(result.applicationId == "app1") + assert(result.jobId == "job1") + assert(result.sessionId == "session1") + assert(result.state == "running") + assert(result.lastUpdateTime == 1234567890L) + assert(result.jobStartTime == 9876543210L) + assert(result.excludedJobIds == Seq("job2", "job3")) + assert(result.error.contains("An error occurred")) + } + + test("deserializeFromMap should handle incorrect field types") { + val sourceMap = new JavaHashMap[String, AnyRef]() + sourceMap.put("applicationId", Integer.valueOf(123)) + sourceMap.put("lastUpdateTime", "1234567890") + + assertThrows[ClassCastException] { + FlintInstance.deserializeFromMap(sourceMap) + } + } + +} diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala new file mode 100644 index 000000000..fe2fa5212 --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala @@ -0,0 +1,24 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import scala.concurrent.{ExecutionContextExecutor, Future} +import scala.concurrent.duration.Duration + +import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} + +case class CommandContext( + spark: SparkSession, + dataSource: String, + resultIndex: String, + sessionId: String, + flintSessionIndexUpdater: OpenSearchUpdater, + osClient: OSClient, + sessionIndex: String, + jobId: String, + queryExecutionTimeout: Duration, + inactivityLimitMillis: Long, + queryWaitTimeMillis: Long) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala new file mode 100644 index 000000000..2e285e3d9 --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import scala.concurrent.{ExecutionContextExecutor, Future} + +import org.opensearch.flint.core.storage.FlintReader + +case class CommandState( + 
recordedLastActivityTime: Long, + recordedVerificationResult: VerificationResult, + flintReader: FlintReader, + futureMappingCheck: Future[Either[String, Unit]], + executionContext: ExecutionContextExecutor) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala index 8acc2308d..9492cc6d9 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala @@ -55,10 +55,10 @@ object FlintJob extends Logging with FlintJobExecutor { var dataToWrite: Option[DataFrame] = None val startTime = System.currentTimeMillis() + // osClient needs spark session to be created first to get FlintOptions initialized. + // Otherwise, we will have connection exception from EMR-S to OS. + val osClient = new OSClient(FlintSparkConf().flintOptions()) try { - // osClient needs spark session to be created first to get FlintOptions initialized. - // Otherwise, we will have connection exception from EMR-S to OS. - val osClient = new OSClient(FlintSparkConf().flintOptions()) val futureMappingCheck = Future { checkAndCreateIndex(osClient, resultIndex) } @@ -81,7 +81,7 @@ object FlintJob extends Logging with FlintJobExecutor { dataToWrite = Some( getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider)) } finally { - dataToWrite.foreach(df => writeData(df, resultIndex)) + dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient)) // Stop SparkSession if it is not streaming job if (wait.equalsIgnoreCase("streaming")) { spark.streams.awaitAnyTermination() diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala index 5c9849dfd..6e7dbb926 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -14,16 +14,64 @@ import play.api.libs.json.{JsArray, JsBoolean, JsObject, Json, JsString, JsValue import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging -import org.apache.spark.sql.FlintJob.{getFormattedData, handleIndexNotFoundException, isSuperset, logError, logInfo} +import org.apache.spark.sql.FlintJob.{createIndex, getFormattedData, isSuperset, logError, logInfo} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.types.{ArrayType, LongType, StringType, StructField, StructType} -import org.apache.spark.sql.util.{RealTimeProvider, TimeProvider} +import org.apache.spark.sql.util.{DefaultThreadPoolFactory, RealTimeProvider, ThreadPoolFactory, TimeProvider} trait FlintJobExecutor { this: Logging => var currentTimeProvider: TimeProvider = new RealTimeProvider() + var threadPoolFactory: ThreadPoolFactory = new DefaultThreadPoolFactory() + + // The enabled setting, which can be applied only to the top-level mapping definition and to object fields, + val resultIndexMapping = + """{ + "dynamic": false, + "properties": { + "result": { + "type": "object", + "enabled": false + }, + "schema": { + "type": "object", + "enabled": false + }, + "jobRunId": { + "type": "keyword" + }, + "applicationId": { + "type": "keyword" + }, + "dataSourceName": { + "type": "keyword" + }, + "status": { + "type": "keyword" + }, + 
"queryId": { + "type": "keyword" + }, + "queryText": { + "type": "text" + }, + "sessionId": { + "type": "keyword" + }, + "updateTime": { + "type": "date", + "format": "strict_date_time||epoch_millis" + }, + "error": { + "type": "text" + }, + "queryRunTime" : { + "type" : "long" + } + } + }""".stripMargin def createSparkConf(): SparkConf = { new SparkConf() @@ -37,13 +85,35 @@ trait FlintJobExecutor { SparkSession.builder().config(conf).enableHiveSupport().getOrCreate() } - def writeData(resultData: DataFrame, resultIndex: String): Unit = { + private def writeData(resultData: DataFrame, resultIndex: String): Unit = { resultData.write .format("flint") .mode("append") .save(resultIndex) } + /** + * writes the DataFrame to the specified Elasticsearch index, and createIndex creates an index + * with the given mapping if it does not exist. + * @param resultData + * data to write + * @param resultIndex + * result index + * @param osClient + * OpenSearch client + */ + def writeDataFrameToOpensearch( + resultData: DataFrame, + resultIndex: String, + osClient: OSClient): Unit = { + if (osClient.doesIndexExist(resultIndex)) { + writeData(resultData, resultIndex) + } else { + createIndex(osClient, resultIndex, resultIndexMapping) + writeData(resultData, resultIndex) + } + } + /** * Create a new formatted dataframe with json result, json schema and EMR_STEP_ID. * @@ -226,56 +296,9 @@ trait FlintJobExecutor { } def checkAndCreateIndex(osClient: OSClient, resultIndex: String): Either[String, Unit] = { - // The enabled setting, which can be applied only to the top-level mapping definition and to object fields, - val mapping = - """{ - "dynamic": false, - "properties": { - "result": { - "type": "object", - "enabled": false - }, - "schema": { - "type": "object", - "enabled": false - }, - "jobRunId": { - "type": "keyword" - }, - "applicationId": { - "type": "keyword" - }, - "dataSourceName": { - "type": "keyword" - }, - "status": { - "type": "keyword" - }, - "queryId": { - "type": "keyword" - }, - "queryText": { - "type": "text" - }, - "sessionId": { - "type": "keyword" - }, - "updateTime": { - "type": "date", - "format": "strict_date_time||epoch_millis" - }, - "error": { - "type": "text" - }, - "queryRunTime" : { - "type" : "long" - } - } - }""".stripMargin - try { val existingSchema = osClient.getIndexMetadata(resultIndex) - if (!isSuperset(existingSchema, mapping)) { + if (!isSuperset(existingSchema, resultIndexMapping)) { Left(s"The mapping of $resultIndex is incorrect.") } else { Right(()) @@ -283,7 +306,7 @@ trait FlintJobExecutor { } catch { case e: IllegalStateException if e.getCause().getMessage().contains("index_not_found_exception") => - handleIndexNotFoundException(osClient, resultIndex, mapping) + createIndex(osClient, resultIndex, resultIndexMapping) case e: Exception => val error = s"Failed to verify existing mapping: ${e.getMessage}" logError(error, e) @@ -291,7 +314,7 @@ trait FlintJobExecutor { } } - def handleIndexNotFoundException( + def createIndex( osClient: OSClient, resultIndex: String, mapping: String): Either[String, Unit] = { @@ -316,6 +339,8 @@ trait FlintJobExecutor { sessionId: String): DataFrame = { // Execute SQL query val startTime = System.currentTimeMillis() + // we have to set job group in the same thread that started the query according to spark doc + spark.sparkContext.setJobGroup(queryId, "Job group for " + queryId, interruptOnCancel = true) val result: DataFrame = spark.sql(query) // Get Data getFormattedData( @@ -378,7 +403,7 @@ trait FlintJobExecutor { case r: 
SparkException => handleQueryException( r, - "Fail to run query. Cause", + "Spark exception. Cause", spark, dataSource, query, @@ -387,7 +412,7 @@ trait FlintJobExecutor { case r: Exception => handleQueryException( r, - "Fail to write result, cause", + "Fail to run query, cause", spark, dataSource, query, diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index c714682c0..e444b71ee 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -5,26 +5,28 @@ package org.apache.spark.sql +import java.net.ConnectException import java.time.Instant -import java.util.concurrent.{ScheduledExecutorService, ThreadPoolExecutor} +import java.util.Map +import java.util.concurrent.ScheduledExecutorService +import scala.collection.JavaConverters._ import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future, TimeoutException} -import scala.concurrent.duration.{Duration, MINUTES, SECONDS} +import scala.concurrent.duration._ +import scala.concurrent.duration.{Duration, MINUTES} +import scala.util.{Failure, Success, Try} +import scala.util.control.NonFatal import org.opensearch.action.get.GetResponse import org.opensearch.common.Strings import org.opensearch.flint.app.{FlintCommand, FlintInstance} -import org.opensearch.flint.app.FlintCommand.serialize import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} import org.apache.spark.SparkConf import org.apache.spark.internal.Logging -import org.apache.spark.sql.FlintJob.{currentTimeProvider, getFailedData, writeData} -import org.apache.spark.sql.FlintREPL.executeQuery import org.apache.spark.sql.flint.config.FlintSparkConf -import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType} import org.apache.spark.sql.util.{DefaultShutdownHookManager, ShutdownHookManagerTrait} -import org.apache.spark.util.{ShutdownHookManager, ThreadUtils} +import org.apache.spark.util.ThreadUtils /** * Spark SQL Application entrypoint @@ -42,13 +44,13 @@ import org.apache.spark.util.{ShutdownHookManager, ThreadUtils} object FlintREPL extends Logging with FlintJobExecutor { private val HEARTBEAT_INTERVAL_MILLIS = 60000L - private val INACTIVITY_LIMIT_MILLIS = 30 * 60 * 1000 + private val DEFAULT_INACTIVITY_LIMIT_MILLIS = 30 * 60 * 1000 private val MAPPING_CHECK_TIMEOUT = Duration(1, MINUTES) - private val QUERY_EXECUTION_TIMEOUT = Duration(10, MINUTES) - private val QUERY_WAIT_TIMEOUT_MILLIS = 10 * 60 * 1000 + private val DEFAULT_QUERY_EXECUTION_TIMEOUT = Duration(10, MINUTES) + private val DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS = 10 * 60 * 1000 def update(flintCommand: FlintCommand, updater: OpenSearchUpdater): Unit = { - updater.update(flintCommand.statementId, serialize(flintCommand)) + updater.update(flintCommand.statementId, FlintCommand.serialize(flintCommand)) } def main(args: Array[String]) { @@ -60,6 +62,8 @@ object FlintREPL extends Logging with FlintJobExecutor { // init SparkContext val conf: SparkConf = createSparkConf() val dataSource = conf.get("spark.flint.datasource.name", "unknown") + // https://github.com/opensearch-project/opensearch-spark/issues/138 + conf.set("spark.sql.defaultCatalog", dataSource) val wait = conf.get("spark.flint.job.type", "continue") // we don't allow default value for sessionIndex and sessionId. Throw exception if key not found. 
val sessionIndex: Option[String] = Option(conf.get("spark.flint.job.requestIndex", null)) @@ -80,95 +84,221 @@ object FlintREPL extends Logging with FlintJobExecutor { if (wait.equalsIgnoreCase("streaming")) { logInfo(s"""streaming query ${query}""") val result = executeQuery(spark, query, dataSource, "", "") - writeData(result, resultIndex) + writeDataFrameToOpensearch(result, resultIndex, osClient) spark.streams.awaitAnyTermination() } else { + // Read the values from the Spark configuration or fall back to the default values + val inactivityLimitMillis: Long = + conf.getLong("spark.flint.job.inactivityLimitMillis", DEFAULT_INACTIVITY_LIMIT_MILLIS) + val queryExecutionTimeoutSecs: Duration = Duration( + conf.getLong( + "spark.flint.job.queryExecutionTimeoutSec", + DEFAULT_QUERY_EXECUTION_TIMEOUT.toSeconds), + SECONDS) + val queryWaitTimeoutMillis: Long = + conf.getLong("spark.flint.job.queryWaitTimeoutMillis", DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS) + + val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex.get) + createShutdownHook(flintSessionIndexUpdater, osClient, sessionIndex.get, sessionId.get) + // 1 thread for updating heart beat + val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor("flint-repl-heartbeat", 1) + val jobStartTime = currentTimeProvider.currentEpochMillis() + try { - val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex.get) - createShutdownHook(flintSessionIndexUpdater, osClient, sessionIndex.get, sessionId.get); + // update heart beat every 30 seconds + // OpenSearch triggers recovery after 1 minute outdated heart beat + createHeartBeatUpdater( + HEARTBEAT_INTERVAL_MILLIS, + flintSessionIndexUpdater, + sessionId.get, + threadPool, + osClient, + sessionIndex.get) - queryLoop( - resultIndex, + if (setupFlintJobWithExclusionCheck( + conf, + sessionIndex, + sessionId, + osClient, + jobId, + applicationId, + flintSessionIndexUpdater, + jobStartTime)) { + return + } + + val commandContext = CommandContext( + spark, dataSource, - sessionIndex.get, + resultIndex, sessionId.get, - spark, + flintSessionIndexUpdater, osClient, + sessionIndex.get, jobId, - applicationId, - flintSessionIndexUpdater) + queryExecutionTimeoutSecs, + inactivityLimitMillis, + queryWaitTimeoutMillis) + + exponentialBackoffRetry(maxRetries = 5, initialDelay = 2.seconds) { + queryLoop(commandContext) + } + } catch { + case e: Exception => + handleSessionError( + e, + applicationId, + jobId, + sessionId.get, + jobStartTime, + flintSessionIndexUpdater, + osClient, + sessionIndex.get) } finally { spark.stop() + if (threadPool != null) { + threadPool.shutdown() + } } } } - private def queryLoop( - resultIndex: String, - dataSource: String, - sessionIndex: String, - sessionId: String, - spark: SparkSession, + /** + * Sets up a Flint job with exclusion checks based on the job configuration. + * + * This method will first check if there are any jobs to exclude from execution based on the + * configuration provided. If the current job's ID is in the exclusion list, the method will + * signal to exit early to avoid redundant execution. This is also true if the job is identified + * as a duplicate of a currently running job. + * + * If there are no conflicts with excluded job IDs or duplicate jobs, the method proceeds to set + * up the Flint job as normal. + * + * @param conf + * A SparkConf object containing the job's configuration. + * @param sessionIndex + * The index within OpenSearch where session information is stored. + * @param sessionId + * The current session's ID. 
+ * @param osClient + * The OpenSearch client used to interact with OpenSearch. + * @param jobId + * The ID of the current job. + * @param applicationId + * The application ID for the current Flint session. + * @param flintSessionIndexUpdater + * An OpenSearch updater for Flint session indices. + * @param jobStartTime + * The start time of the job. + * @return + * A Boolean value indicating whether to exit the job early (true) or not (false). + * @note + * If the sessionIndex or sessionId Options are empty, the method will throw a + * NoSuchElementException, as `.get` is called on these options without checking for their + * presence. + */ + def setupFlintJobWithExclusionCheck( + conf: SparkConf, + sessionIndex: Option[String], + sessionId: Option[String], osClient: OSClient, jobId: String, applicationId: String, - flintSessionIndexUpdater: OpenSearchUpdater): Unit = { + flintSessionIndexUpdater: OpenSearchUpdater, + jobStartTime: Long): Boolean = { + val confExcludeJobsOpt = conf.getOption("spark.flint.deployment.excludeJobs") + + confExcludeJobsOpt match { + case None => + // If confExcludeJobs is None, pass null or an empty sequence as per your setupFlintJob method's signature + setupFlintJob( + applicationId, + jobId, + sessionId.get, + flintSessionIndexUpdater, + sessionIndex.get, + jobStartTime) + + case Some(confExcludeJobs) => + // example: --conf spark.flint.deployment.excludeJobs=job-1,job-2 + val excludeJobIds = confExcludeJobs.split(",").toList // Convert Array to Lis + + if (excludeJobIds.contains(jobId)) { + // Edge case, current job is excluded, exit the application + return true + } + + val getResponse = osClient.getDoc(sessionIndex.get, sessionId.get) + if (getResponse.isExists()) { + val source = getResponse.getSourceAsMap + if (source != null) { + val existingExcludedJobIds = parseExcludedJobIds(source) + if (excludeJobIds.sorted == existingExcludedJobIds.sorted) { + // Edge case, duplicate job running, exit the application + return true + } + } + } + + // If none of the edge cases are met, proceed with setup + setupFlintJob( + applicationId, + jobId, + sessionId.get, + flintSessionIndexUpdater, + sessionIndex.get, + jobStartTime, + excludeJobIds) + } + false + } + + def queryLoop(commandContext: CommandContext): Unit = { // 1 thread for updating heart beat - val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor("flint-repl", 1) + val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-query", 1) implicit val executionContext = ExecutionContext.fromExecutor(threadPool) - try { val futureMappingCheck = Future { - checkAndCreateIndex(osClient, resultIndex) + checkAndCreateIndex(commandContext.osClient, commandContext.resultIndex) } - setupFlintJob(applicationId, jobId, sessionId, flintSessionIndexUpdater, sessionIndex) - - // update heart beat every 30 seconds - // OpenSearch triggers recovery after 1 minute outdated heart beat - createHeartBeatUpdater( - HEARTBEAT_INTERVAL_MILLIS, - flintSessionIndexUpdater, - sessionId: String, - threadPool, - osClient, - sessionIndex) - - var lastActivityTime = Instant.now().toEpochMilli() + var lastActivityTime = currentTimeProvider.currentEpochMillis() var verificationResult: VerificationResult = NotVerified - - while (Instant.now().toEpochMilli() - lastActivityTime <= INACTIVITY_LIMIT_MILLIS) { - logInfo(s"""read from ${sessionIndex}""") + var canPickUpNextStatement = true + while (currentTimeProvider + .currentEpochMillis() - lastActivityTime <= commandContext.inactivityLimitMillis && 
canPickUpNextStatement) { + logDebug(s"""read from ${commandContext.sessionIndex}""") val flintReader: FlintReader = - createQueryReader(osClient, applicationId, sessionId, sessionIndex, dataSource) + createQueryReader( + commandContext.osClient, + commandContext.sessionId, + commandContext.sessionIndex, + commandContext.dataSource) try { - val result: (Long, VerificationResult) = processCommands( + val commandState = CommandState( + lastActivityTime, + verificationResult, flintReader, - spark, - dataSource, - resultIndex, - sessionId, futureMappingCheck, - verificationResult, - executionContext, - flintSessionIndexUpdater, - lastActivityTime) + executionContext) + val result: (Long, VerificationResult, Boolean) = + processCommands(commandContext, commandState) - val (updatedLastActivityTime, updatedVerificationResult) = result + val ( + updatedLastActivityTime, + updatedVerificationResult, + updatedCanPickUpNextStatement) = result lastActivityTime = updatedLastActivityTime verificationResult = updatedVerificationResult + canPickUpNextStatement = updatedCanPickUpNextStatement } finally { flintReader.close() } Thread.sleep(100) } - - } catch { - case e: Exception => - handleSessionError(e, applicationId, jobId, sessionId, flintSessionIndexUpdater) } finally { if (threadPool != null) { threadPool.shutdown() @@ -181,35 +311,86 @@ object FlintREPL extends Logging with FlintJobExecutor { jobId: String, sessionId: String, flintSessionIndexUpdater: OpenSearchUpdater, - sessionIndex: String): Unit = { + sessionIndex: String, + jobStartTime: Long, + excludeJobIds: Seq[String] = Seq.empty[String]): Unit = { val flintJob = - new FlintInstance(applicationId, jobId, sessionId, "running", System.currentTimeMillis()) - flintSessionIndexUpdater.upsert(sessionId, FlintInstance.serialize(flintJob)) - logInfo( + new FlintInstance( + applicationId, + jobId, + sessionId, + "running", + currentTimeProvider.currentEpochMillis(), + jobStartTime, + excludeJobIds) + flintSessionIndexUpdater.upsert( + sessionId, + FlintInstance.serialize(flintJob, currentTimeProvider.currentEpochMillis())) + logDebug( s"""Updated job: {"jobid": ${flintJob.jobId}, "sessionId": ${flintJob.sessionId}} from $sessionIndex""") } - private def handleSessionError( + def handleSessionError( e: Exception, applicationId: String, jobId: String, sessionId: String, - flintSessionIndexUpdater: OpenSearchUpdater): Unit = { - val error = s"Unexpected error: ${e.getMessage}" + jobStartTime: Long, + flintSessionIndexUpdater: OpenSearchUpdater, + osClient: OSClient, + sessionIndex: String): Unit = { + val error = s"Session error: ${e.getMessage}" logError(error, e) - val flintJob = new FlintInstance( - applicationId, - jobId, + + val flintInstance = getExistingFlintInstance(osClient, sessionIndex, sessionId) + .getOrElse(createFailedFlintInstance(applicationId, jobId, sessionId, jobStartTime, error)) + + updateFlintInstance(flintInstance, flintSessionIndexUpdater, sessionId) + } + + private def getExistingFlintInstance( + osClient: OSClient, + sessionIndex: String, + sessionId: String): Option[FlintInstance] = Try( + osClient.getDoc(sessionIndex, sessionId)) match { + case Success(getResponse) if getResponse.isExists() => + Option(getResponse.getSourceAsMap) + .map(FlintInstance.deserializeFromMap) + case Failure(exception) => + logError(s"Failed to retrieve existing FlintInstance: ${exception.getMessage}", exception) + None + case _ => None + } + + private def createFailedFlintInstance( + applicationId: String, + jobId: String, + sessionId: String, + 
jobStartTime: Long, + errorMessage: String): FlintInstance = new FlintInstance( + applicationId, + jobId, + sessionId, + "fail", + currentTimeProvider.currentEpochMillis(), + jobStartTime, + error = Some(errorMessage)) + + private def updateFlintInstance( + flintInstance: FlintInstance, + flintSessionIndexUpdater: OpenSearchUpdater, + sessionId: String): Unit = { + val currentTime = currentTimeProvider.currentEpochMillis() + flintSessionIndexUpdater.upsert( sessionId, - "fail", - System.currentTimeMillis(), - Some(error)) - flintSessionIndexUpdater.upsert(sessionId, FlintInstance.serialize(flintJob)) + FlintInstance.serialize(flintInstance, currentTime)) } /** * handling the case where a command's execution fails, updates the flintCommand with the error - * and failure status, and then delegates to the second method for actual DataFrame creation + * and failure status, and then write the result to result index. Thus, an error is written to + * both result index or statement model in request index + * * @param spark * spark session * @param dataSource @@ -264,43 +445,50 @@ object FlintREPL extends Logging with FlintJobExecutor { } private def processCommands( - flintReader: FlintReader, - spark: SparkSession, - dataSource: String, - resultIndex: String, - sessionId: String, - futureMappingCheck: Future[Either[String, Unit]], - recordedVerificationResult: VerificationResult, - executionContext: ExecutionContextExecutor, - flintSessionIndexUpdater: OpenSearchUpdater, - recordedLastActivityTime: Long): (Long, VerificationResult) = { + context: CommandContext, + state: CommandState): (Long, VerificationResult, Boolean) = { + import context._ + import state._ + var lastActivityTime = recordedLastActivityTime var verificationResult = recordedVerificationResult + var canProceed = true + var canPickNextStatementResult = true // Add this line to keep track of canPickNextStatement + + while (canProceed) { + if (!canPickNextStatement(sessionId, jobId, osClient, sessionIndex)) { + canPickNextStatementResult = false + canProceed = false + } else if (!flintReader.hasNext) { + canProceed = false + } else { + lastActivityTime = currentTimeProvider.currentEpochMillis() + val flintCommand = processCommandInitiation(flintReader, flintSessionIndexUpdater) + + val (dataToWrite, returnedVerificationResult) = processStatementOnVerification( + recordedVerificationResult, + spark, + flintCommand, + dataSource, + sessionId, + executionContext, + futureMappingCheck, + resultIndex, + queryExecutionTimeout, + queryWaitTimeMillis) - while (flintReader.hasNext) { - lastActivityTime = Instant.now().toEpochMilli() - val flintCommand = processCommandInitiation(flintReader, flintSessionIndexUpdater) - - spark.sparkContext.setJobGroup( - flintCommand.queryId, - "Job group for " + flintCommand.queryId, - interruptOnCancel = true) - val (dataToWrite, returnedVerificationResult) = processStatementOnVerification( - recordedVerificationResult, - spark, - flintCommand, - dataSource, - sessionId, - executionContext, - futureMappingCheck, - resultIndex) - - verificationResult = returnedVerificationResult - finalizeCommand(dataToWrite, flintCommand, resultIndex, flintSessionIndexUpdater) + verificationResult = returnedVerificationResult + finalizeCommand( + dataToWrite, + flintCommand, + resultIndex, + flintSessionIndexUpdater, + osClient) + } } // return tuple indicating if still active and mapping verification result - (lastActivityTime, verificationResult) + (lastActivityTime, verificationResult, canPickNextStatementResult) } /** 
@@ -319,16 +507,18 @@ object FlintREPL extends Logging with FlintJobExecutor { dataToWrite: Option[DataFrame], flintCommand: FlintCommand, resultIndex: String, - flintSessionIndexUpdater: OpenSearchUpdater): Unit = { + flintSessionIndexUpdater: OpenSearchUpdater, + osClient: OSClient): Unit = { try { - dataToWrite.foreach(df => writeData(df, resultIndex)) - if (flintCommand.isRunning()) { + dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient)) + if (flintCommand.isRunning() || flintCommand.isWaiting()) { // we have set failed state in exception handling flintCommand.complete() } update(flintCommand, flintSessionIndexUpdater) } catch { // e.g., maybe due to authentication service connection issue + // or invalid catalog (e.g., we are operating on data not defined in provided data source) case e: Exception => val error = s"""Fail to write result of ${flintCommand}, cause: ${e.getMessage}""" logError(error, e) @@ -336,6 +526,80 @@ object FlintREPL extends Logging with FlintJobExecutor { update(flintCommand, flintSessionIndexUpdater) } } + + private def handleCommandTimeout( + spark: SparkSession, + dataSource: String, + error: String, + flintCommand: FlintCommand, + sessionId: String, + startTime: Long): Option[DataFrame] = { + /* + * https://tinyurl.com/2ezs5xj9 + * + * This only interrupts active Spark jobs that are actively running. + * This would then throw the error from ExecutePlan and terminate it. + * But if the query is not running a Spark job, but executing code on Spark driver, this + * would be a noop and the execution will keep running. + * + * In Apache Spark, actions that trigger a distributed computation can lead to the creation + * of Spark jobs. In the context of Spark SQL, this typically happens when we perform + * actions that require the computation of results that need to be collected or stored. 
+ */ + spark.sparkContext.cancelJobGroup(flintCommand.queryId) + logError(error) + Some( + handleCommandFailureAndGetFailedData( + spark, + dataSource, + error, + flintCommand, + sessionId, + startTime)) + } + + def executeAndHandle( + spark: SparkSession, + flintCommand: FlintCommand, + dataSource: String, + sessionId: String, + executionContext: ExecutionContextExecutor, + startTime: Long, + queryExecuitonTimeOut: Duration, + queryWaitTimeMillis: Long): Option[DataFrame] = { + try { + Some( + executeQueryAsync( + spark, + flintCommand, + dataSource, + sessionId, + executionContext, + startTime, + queryExecuitonTimeOut, + queryWaitTimeMillis)) + } catch { + case e: TimeoutException => + handleCommandTimeout( + spark, + dataSource, + s"Executing ${flintCommand.query} timed out", + flintCommand, + sessionId, + startTime) + case e: Exception => + val error = processQueryException(e, spark, dataSource, flintCommand.query, "", "") + Some( + handleCommandFailureAndGetFailedData( + spark, + dataSource, + error, + flintCommand, + sessionId, + startTime)) + } + } + private def processStatementOnVerification( recordedVerificationResult: VerificationResult, spark: SparkSession, @@ -344,43 +608,30 @@ object FlintREPL extends Logging with FlintJobExecutor { sessionId: String, executionContext: ExecutionContextExecutor, futureMappingCheck: Future[Either[String, Unit]], - resultIndex: String): (Option[DataFrame], VerificationResult) = { - val startTime: Long = System.currentTimeMillis() + resultIndex: String, + queryExecutionTimeout: Duration, + queryWaitTimeMillis: Long): (Option[DataFrame], VerificationResult) = { + val startTime: Long = currentTimeProvider.currentEpochMillis() var verificationResult = recordedVerificationResult var dataToWrite: Option[DataFrame] = None - try { - verificationResult match { - case NotVerified => - try { - val mappingCheckResult = - ThreadUtils.awaitResult(futureMappingCheck, MAPPING_CHECK_TIMEOUT) - // time out after 10 minutes - val result = executeQueryAsync( - spark, - flintCommand, - dataSource, - sessionId, - executionContext, - startTime) - - dataToWrite = Some(mappingCheckResult match { - case Right(_) => - verificationResult = VerifiedWithoutError - result - case Left(error) => - verificationResult = VerifiedWithError(error) - handleCommandFailureAndGetFailedData( - spark, - dataSource, - error, - flintCommand, - sessionId, - startTime) - }) - } catch { - case e: TimeoutException => - val error = s"Getting the mapping of index $resultIndex timed out" - logError(error, e) + + verificationResult match { + case NotVerified => + try { + ThreadUtils.awaitResult(futureMappingCheck, MAPPING_CHECK_TIMEOUT) match { + case Right(_) => + dataToWrite = executeAndHandle( + spark, + flintCommand, + dataSource, + sessionId, + executionContext, + startTime, + queryExecutionTimeout, + queryWaitTimeMillis) + verificationResult = VerifiedWithoutError + case Left(error) => + verificationResult = VerifiedWithError(error) dataToWrite = Some( handleCommandFailureAndGetFailedData( spark, @@ -390,51 +641,44 @@ object FlintREPL extends Logging with FlintJobExecutor { sessionId, startTime)) } - case VerifiedWithError(err) => - dataToWrite = Some( - handleCommandFailureAndGetFailedData( - spark, - dataSource, - err, - flintCommand, - sessionId, - startTime)) - case VerifiedWithoutError => - dataToWrite = Some( - executeQueryAsync( - spark, - flintCommand, - dataSource, - sessionId, - executionContext, - startTime)) - } - - logDebug(s"""command complete: ${flintCommand}""") - } catch { - 
case e: TimeoutException => - val error = s"Executing ${flintCommand.query} timed out" - spark.sparkContext.cancelJobGroup(flintCommand.queryId) - logError(error, e) + } catch { + case e: TimeoutException => + val error = s"Getting the mapping of index $resultIndex timed out" + dataToWrite = + handleCommandTimeout(spark, dataSource, error, flintCommand, sessionId, startTime) + case NonFatal(e) => + val error = s"An unexpected error occurred: ${e.getMessage}" + dataToWrite = Some( + handleCommandFailureAndGetFailedData( + spark, + dataSource, + error, + flintCommand, + sessionId, + startTime)) + } + case VerifiedWithError(err) => dataToWrite = Some( handleCommandFailureAndGetFailedData( spark, dataSource, - error, - flintCommand, - sessionId, - startTime)) - case e: Exception => - val error = processQueryException(e, spark, dataSource, flintCommand, sessionId) - dataToWrite = Some( - handleCommandFailureAndGetFailedData( - spark, - dataSource, - error, + err, flintCommand, sessionId, startTime)) + case VerifiedWithoutError => + dataToWrite = executeAndHandle( + spark, + flintCommand, + dataSource, + sessionId, + executionContext, + startTime, + queryExecutionTimeout, + queryWaitTimeMillis) } + + logDebug(s"command complete: $flintCommand") (dataToWrite, verificationResult) } @@ -444,8 +688,11 @@ object FlintREPL extends Logging with FlintJobExecutor { dataSource: String, sessionId: String, executionContext: ExecutionContextExecutor, - startTime: Long): DataFrame = { - if (Instant.now().toEpochMilli() - flintCommand.submitTime > QUERY_WAIT_TIMEOUT_MILLIS) { + startTime: Long, + queryExecutionTimeOut: Duration, + queryWaitTimeMillis: Long): DataFrame = { + if (currentTimeProvider + .currentEpochMillis() - flintCommand.submitTime > queryWaitTimeMillis) { handleCommandFailureAndGetFailedData( spark, dataSource, @@ -459,7 +706,7 @@ object FlintREPL extends Logging with FlintJobExecutor { }(executionContext) // time out after 10 minutes - ThreadUtils.awaitResult(futureQueryExecution, QUERY_EXECUTION_TIMEOUT) + ThreadUtils.awaitResult(futureQueryExecution, queryExecutionTimeOut) } } private def processCommandInitiation( @@ -477,21 +724,12 @@ object FlintREPL extends Logging with FlintJobExecutor { private def createQueryReader( osClient: OSClient, - applicationId: String, sessionId: String, sessionIndex: String, dataSource: String) = { // all state in index are in lower case - // - // add application to deal with emr-s deployment: - // Should the EMR-S application be terminated or halted (e.g., during deployments), OpenSearch associates each - // query to ascertain the latest EMR-S application ID. While existing EMR-S jobs continue executing queries sharing - // the same applicationId and sessionId, new requests with differing application IDs will not be processed by the - // current EMR-S job. Gradually, the queue with the same applicationId and sessionId diminishes, enabling the - // job to self-terminate due to inactivity timeout. Following the termination of all jobs, the CP can gracefully - // shut down the application. The session has a phantom state where the old application set the state to dead and - // then the new application job sets it to running through heartbeat. Also, application id can help in the case - // of job restart where the job id is different but application id is the same. 
+ // we only search for statement submitted in the last hour in case of unexpected bugs causing infinite loop in the + // same doc val dsl = s"""{ | "bool": { @@ -508,22 +746,23 @@ object FlintREPL extends Logging with FlintJobExecutor { | }, | { | "term": { - | "applicationId": "$applicationId" + | "sessionId": "$sessionId" | } | }, | { | "term": { - | "sessionId": "$sessionId" + | "dataSourceName": "$dataSource" | } | }, | { - | "term": { - | "dataSourceName": "$dataSource" + | "range": { + | "submitTime": { "gte": "now-1h" } | } | } | ] | } |}""".stripMargin + val flintReader = osClient.createReader(sessionIndex, dsl, "submitTime") flintReader } @@ -564,17 +803,12 @@ object FlintREPL extends Logging with FlintJobExecutor { getResponse: GetResponse, flintSessionIndexUpdater: OpenSearchUpdater, sessionId: String): Unit = { - val flintInstant = new FlintInstance( - source.get("applicationId").asInstanceOf[String], - source.get("jobId").asInstanceOf[String], - source.get("sessionId").asInstanceOf[String], - "dead", - source.get("lastUpdateTime").asInstanceOf[Long], - Option(source.get("error").asInstanceOf[String])) + val flintInstance = FlintInstance.deserializeFromMap(source) + flintInstance.state = "dead" flintSessionIndexUpdater.updateIf( sessionId, - FlintInstance.serialize(flintInstant), + FlintInstance.serialize(flintInstance, currentTimeProvider.currentEpochMillis()), getResponse.getSeqNo, getResponse.getPrimaryTerm) } @@ -607,16 +841,11 @@ object FlintREPL extends Logging with FlintJobExecutor { val getResponse = osClient.getDoc(sessionIndex, sessionId) if (getResponse.isExists()) { val source = getResponse.getSourceAsMap - val flintInstant: FlintInstance = new FlintInstance( - source.get("applicationId").asInstanceOf[String], - source.get("jobId").asInstanceOf[String], - source.get("sessionId").asInstanceOf[String], - "running", - source.get("lastUpdateTime").asInstanceOf[Long], - Option(source.get("error").asInstanceOf[String])) + val flintInstance = FlintInstance.deserializeFromMap(source) + flintInstance.state = "running" flintSessionUpdater.updateIf( sessionId, - FlintInstance.serialize(flintInstant), + FlintInstance.serialize(flintInstance, currentTimeProvider.currentEpochMillis()), getResponse.getSeqNo, getResponse.getPrimaryTerm) } @@ -634,4 +863,106 @@ object FlintREPL extends Logging with FlintJobExecutor { currentInterval, java.util.concurrent.TimeUnit.MILLISECONDS) } + + /** + * Reads the session store to get excluded jobs and the current job ID. If the current job ID + * (myJobId) is not the running job ID (runJobId), or if myJobId is in the list of excluded + * jobs, it returns false. The first condition ensures we only have one active job for one + * session thus avoid race conditions on statement execution and states. The 2nd condition + * ensures we don't pick up a job that has been excluded from the session store and thus CP has + * a way to notify spark when deployments are happening. If excludeJobs is null or none of the + * above conditions are met, it returns true. 
+ * @return + * whether we can start fetching next statement or not + */ + def canPickNextStatement( + sessionId: String, + jobId: String, + osClient: OSClient, + sessionIndex: String): Boolean = { + try { + val getResponse = osClient.getDoc(sessionIndex, sessionId) + if (getResponse.isExists()) { + val source = getResponse.getSourceAsMap + if (source == null) { + logError(s"""Session id ${sessionId} is empty""") + // still proceed since we are not sure what happened (e.g., OpenSearch cluster may be unresponsive) + return true + } + + val runJobId = Option(source.get("jobId")).map(_.asInstanceOf[String]).orNull + val excludeJobIds: Seq[String] = parseExcludedJobIds(source) + + if (runJobId != null && jobId != runJobId) { + logInfo(s"""the current job ID ${jobId} is not the running job ID ${runJobId}""") + return false + } + if (excludeJobIds != null && excludeJobIds.contains(jobId)) { + logInfo(s"""${jobId} is in the list of excluded jobs""") + return false + } + true + } else { + // still proceed since we are not sure what happened (e.g., session doc may not be available yet) + logError(s"""Fail to find id ${sessionId} from session index""") + true + } + } catch { + // still proceed since we are not sure what happened (e.g., OpenSearch cluster may be unresponsive) + case e: Exception => + logError(s"""Fail to find id ${sessionId} from session index.""", e) + true + } + } + + private def parseExcludedJobIds(source: java.util.Map[String, AnyRef]): Seq[String] = { + + val rawExcludeJobIds = source.get("excludeJobIds") + Option(rawExcludeJobIds) + .map { + case s: String => Seq(s) + case list: java.util.List[_] @unchecked => + import scala.collection.JavaConverters._ + list.asScala.toList + .collect { case str: String => str } // Collect only strings from the list + case other => + logInfo(s"Unexpected type: ${other.getClass.getName}") + Seq.empty + } + .getOrElse(Seq.empty[String]) // In case of null, return an empty Seq + } + + def exponentialBackoffRetry[T](maxRetries: Int, initialDelay: FiniteDuration)( + block: => T): T = { + var retries = 0 + var result: Option[T] = None + var toContinue = true + while (retries < maxRetries && toContinue) { + try { + result = Some(block) + toContinue = false + } catch { + /* + * If `request_index` is unavailable, the system attempts to retry up to five times. After unsuccessful retries, + * the session state is set to 'failed', and the job will terminate. There are cases where `request_index` + * unavailability might prevent the Spark job from updating the session state, leading to it erroneously remaining + * as 'not_started' or 'running'. While 'not_started' is not problematic, a 'running' status requires the plugin + * to handle it effectively to prevent inconsistencies. Notably, there's a bug where InteractiveHandler fails to + * invalidate a REPL session with an outdated `lastUpdateTime`. This issue is documented as sql#2415 and must be + * resolved to maintain system reliability. + */ + case e: RuntimeException + if e.getCause != null && e.getCause.isInstanceOf[ConnectException] => + retries += 1 + val delay = initialDelay * math.pow(2, retries - 1).toLong + logError(s"Fail to connect to OpenSearch cluster. 
Retrying in $delay...", e) + Thread.sleep(delay.toMillis) + + case e: Exception => + throw e + } + } + + result.getOrElse(throw new RuntimeException("Failed after retries")) + } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala index f3f118702..ceacc7bcd 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala @@ -140,4 +140,16 @@ class OSClient(val flintOptions: FlintOptions) extends Logging { case e: IOException => throw new RuntimeException(e) } + + def doesIndexExist(indexName: String): Boolean = { + using(flintClient.createClient()) { client => + try { + val request = new GetIndexRequest(indexName) + client.indices().exists(request, RequestOptions.DEFAULT) + } catch { + case e: Exception => + throw new IllegalStateException(s"Failed to check if index $indexName exists", e) + } + } + } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/util/DefaultThreadPoolFactory.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/DefaultThreadPoolFactory.scala new file mode 100644 index 000000000..33d44a985 --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/DefaultThreadPoolFactory.scala @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql.util + +import java.util.concurrent.ScheduledExecutorService + +import org.apache.spark.util.ThreadUtils + +class DefaultThreadPoolFactory extends ThreadPoolFactory { + override def newDaemonThreadPoolScheduledExecutor( + threadNamePrefix: String, + numThreads: Int): ScheduledExecutorService = { + ThreadUtils.newDaemonThreadPoolScheduledExecutor(threadNamePrefix, numThreads) + } +} diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/util/ThreadPoolFactory.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/ThreadPoolFactory.scala new file mode 100644 index 000000000..9c97dfe96 --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/ThreadPoolFactory.scala @@ -0,0 +1,14 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql.util + +import java.util.concurrent.ScheduledExecutorService + +trait ThreadPoolFactory { + def newDaemonThreadPoolScheduledExecutor( + threadNamePrefix: String, + numThreads: Int): ScheduledExecutorService +} diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala index e6b8d8289..065c0bb67 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala @@ -7,6 +7,7 @@ package org.apache.spark.sql import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.MockTimeProvider class FlintJobTest extends SparkFunSuite with JobMatchers { diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index 96ea91ebd..704045e8a 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ 
b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -5,28 +5,37 @@ package org.apache.spark.sql +import java.net.ConnectException +import java.time.Instant import java.util.concurrent.{ScheduledExecutorService, ScheduledFuture, TimeUnit} import scala.collection.JavaConverters._ +import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, TimeoutException} +import scala.concurrent.duration._ +import scala.concurrent.duration.{Duration, MINUTES} +import scala.reflect.runtime.universe.TypeTag -import org.mockito.{ArgumentMatchersSugar, IdiomaticMockito} +import org.mockito.ArgumentMatchersSugar import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer import org.opensearch.action.get.GetResponse import org.opensearch.flint.app.FlintCommand -import org.opensearch.flint.core.storage.OpenSearchUpdater +import org.opensearch.flint.core.storage.{FlintReader, OpenSearchReader, OpenSearchUpdater} import org.scalatestplus.mockito.MockitoSugar -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.types.{ArrayType, LongType, NullType, StringType, StructField, StructType} -import org.apache.spark.sql.util.ShutdownHookManagerTrait +import org.apache.spark.sql.util.{DefaultThreadPoolFactory, MockThreadPoolFactory, MockTimeProvider, RealTimeProvider, ShutdownHookManagerTrait} +import org.apache.spark.util.ThreadUtils class FlintREPLTest extends SparkFunSuite with MockitoSugar with ArgumentMatchersSugar with JobMatchers { - val spark = SparkSession.builder().appName("Test").master("local").getOrCreate() // By using a type alias and casting, I can bypass the type checking error. type AnyScheduledFuture = ScheduledFuture[_] @@ -36,7 +45,7 @@ class FlintREPLTest val osClient = mock[OSClient] val threadPool = mock[ScheduledExecutorService] val getResponse = mock[GetResponse] - val scheduledFutureRaw: ScheduledFuture[_] = mock[ScheduledFuture[_]] + val scheduledFutureRaw = mock[ScheduledFuture[_]] // Mock behaviors when(osClient.getDoc(*, *)).thenReturn(getResponse) @@ -47,7 +56,9 @@ class FlintREPLTest "jobId" -> "job1", "sessionId" -> "session1", "lastUpdateTime" -> java.lang.Long.valueOf(12345L), - "error" -> "someError").asJava) + "error" -> "someError", + "state" -> "running", + "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava) when(getResponse.getSeqNo).thenReturn(0L) when(getResponse.getPrimaryTerm).thenReturn(0L) // when scheduled task is scheduled, execute the runnable immediately only once and become no-op afterwards. 
@@ -94,7 +105,9 @@ class FlintREPLTest "sessionId" -> "session1", "state" -> "running", "lastUpdateTime" -> java.lang.Long.valueOf(12345L), - "error" -> "someError").asJava) + "error" -> "someError", + "state" -> "running", + "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava) val mockShutdownHookManager = new ShutdownHookManagerTrait { override def addShutdownHook(hook: () => Unit): AnyRef = { @@ -132,8 +145,8 @@ class FlintREPLTest StructField("updateTime", LongType, nullable = false), StructField("queryRunTime", LongType, nullable = false))) - val currentTime: Long = System.currentTimeMillis() - val queryRunTime: Long = 3000L + val currentTime = System.currentTimeMillis() + val queryRunTime = 3000L val error = "some error" val expectedRows = Seq( @@ -150,24 +163,863 @@ class FlintREPLTest "20", currentTime, queryRunTime)) - val expected: DataFrame = + val spark = SparkSession.builder().appName("Test").master("local").getOrCreate() + + val expected = spark.createDataFrame(spark.sparkContext.parallelize(expectedRows), expectedSchema) val flintCommand = new FlintCommand("failed", "select 1", "30", "10", currentTime, None) - FlintREPL.currentTimeProvider = new MockTimeProvider(currentTime) + try { + FlintREPL.currentTimeProvider = new MockTimeProvider(currentTime) + + // Compare the result + val result = + FlintREPL.handleCommandFailureAndGetFailedData( + spark, + dataSourceName, + error, + flintCommand, + "20", + currentTime - queryRunTime) + assertEqualDataframe(expected, result) + assert("failed" == flintCommand.state) + assert(error == flintCommand.error.get) + } finally { + spark.close() + FlintREPL.currentTimeProvider = new RealTimeProvider() + + } + } + + test("test canPickNextStatement: Doc Exists and Valid JobId") { + val sessionId = "session123" + val jobId = "jobABC" + val osClient = mock[OSClient] + val sessionIndex = "sessionIndex" + + val getResponse = mock[GetResponse] + when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(true) + + val sourceMap = new java.util.HashMap[String, Object]() + sourceMap.put("jobId", jobId.asInstanceOf[Object]) + when(getResponse.getSourceAsMap).thenReturn(sourceMap) + + val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + + assert(result) + } + + test("test canPickNextStatement: Doc Exists but Different JobId") { + val sessionId = "session123" + val jobId = "jobABC" + val differentJobId = "jobXYZ" + val osClient = mock[OSClient] + val sessionIndex = "sessionIndex" + + val getResponse = mock[GetResponse] + when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(true) + + val sourceMap = new java.util.HashMap[String, Object]() + sourceMap.put("jobId", differentJobId.asInstanceOf[Object]) + when(getResponse.getSourceAsMap).thenReturn(sourceMap) + + // Execute the method under test + val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + + // Assertions + assert(!result) // The function should return false + } + + test("test canPickNextStatement: Doc Exists, JobId Matches, but JobId is Excluded") { + val sessionId = "session123" + val jobId = "jobABC" + val osClient = mock[OSClient] + val sessionIndex = "sessionIndex" + + val getResponse = mock[GetResponse] + when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(true) + + val excludeJobIdsList = new java.util.ArrayList[String]() + excludeJobIdsList.add(jobId) 
// Add the jobId to the list to simulate exclusion + + val sourceMap = new java.util.HashMap[String, Object]() + sourceMap.put("jobId", jobId) // The jobId matches + sourceMap.put("excludeJobIds", excludeJobIdsList) // But jobId is in the exclude list + when(getResponse.getSourceAsMap).thenReturn(sourceMap) + + // Execute the method under test + val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + + // Assertions + assert(!result) // The function should return false because jobId is excluded + } + + test("test canPickNextStatement: Doc Exists but Source is Null") { + val sessionId = "session123" + val jobId = "jobABC" + val osClient = mock[OSClient] + val sessionIndex = "sessionIndex" + + // Mock the getDoc response + val getResponse = mock[GetResponse] + when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(true) + when(getResponse.getSourceAsMap).thenReturn(null) // Simulate the source being null + + // Execute the method under test + val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + + // Assertions + assert(result) // The function should return true despite the null source + } + + test("test canPickNextStatement: Doc Exists with Unexpected Type in excludeJobIds") { + val sessionId = "session123" + val jobId = "jobABC" + val osClient = mock[OSClient] + val sessionIndex = "sessionIndex" + + val getResponse = mock[GetResponse] + when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(true) + + val sourceMap = new java.util.HashMap[String, Object]() + sourceMap.put("jobId", jobId) + sourceMap.put( + "excludeJobIds", + Integer.valueOf(123) + ) // Using an Integer here to represent an unexpected type + + when(getResponse.getSourceAsMap).thenReturn(sourceMap) + + val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + + assert(result) // The function should return true + } + + test("test canPickNextStatement: Doc Does Not Exist") { + val sessionId = "session123" + val jobId = "jobABC" + val osClient = mock[OSClient] + val sessionIndex = "sessionIndex" + + // Set up the mock GetResponse + val getResponse = mock[GetResponse] + when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(false) // Simulate the document does not exist + + // Execute the function under test + val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + + // Assert the function returns true + assert(result) + } + + test("test canPickNextStatement: OSClient Throws Exception") { + val sessionId = "session123" + val jobId = "jobABC" + val osClient = mock[OSClient] + val sessionIndex = "sessionIndex" + + // Set up the mock OSClient to throw an exception + when(osClient.getDoc(sessionIndex, sessionId)) + .thenThrow(new RuntimeException("OpenSearch cluster unresponsive")) + + // Execute the method under test and expect true, since the method is designed to return true even in case of an exception + val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + + // Verify the result is true despite the exception + assert(result) + } + + test( + "test canPickNextStatement: Doc Exists and excludeJobIds is a Single String Not Matching JobId") { + val sessionId = "session123" + val jobId = "jobABC" + val nonMatchingExcludeJobId = "jobXYZ" // This ID does not match the jobId + val osClient = mock[OSClient] + 
val sessionIndex = "sessionIndex" + + val getResponse = mock[GetResponse] + when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(true) + + // Create a sourceMap with excludeJobIds as a String that does NOT match jobId + val sourceMap = new java.util.HashMap[String, Object]() + sourceMap.put("jobId", jobId.asInstanceOf[Object]) + sourceMap.put("excludeJobIds", nonMatchingExcludeJobId.asInstanceOf[Object]) + + when(getResponse.getSourceAsMap).thenReturn(sourceMap) + + // Execute the method under test + val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + + // The function should return true since jobId is not excluded + assert(result) + } - // Compare the result - val result = - FlintREPL.handleCommandFailureAndGetFailedData( + test("Doc Exists and excludeJobIds is an ArrayList Containing JobId") { + val sessionId = "session123" + val jobId = "jobABC" + val osClient = mock[OSClient] + val sessionIndex = "sessionIndex" + val handleSessionError = mock[Function1[String, Unit]] + + val getResponse = mock[GetResponse] + when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(true) + + // Create a sourceMap with excludeJobIds as an ArrayList containing jobId + val sourceMap = new java.util.HashMap[String, Object]() + sourceMap.put("jobId", jobId.asInstanceOf[Object]) + + // Creating an ArrayList and adding the jobId to it + val excludeJobIdsList = new java.util.ArrayList[String]() + excludeJobIdsList.add(jobId) + sourceMap.put("excludeJobIds", excludeJobIdsList.asInstanceOf[Object]) + + when(getResponse.getSourceAsMap).thenReturn(sourceMap) + + // Execute the method under test + val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + + // The function should return false since jobId is excluded + assert(!result) + } + + test("Doc Exists and excludeJobIds is an ArrayList Not Containing JobId") { + val sessionId = "session123" + val jobId = "jobABC" + val osClient = mock[OSClient] + val sessionIndex = "sessionIndex" + + val getResponse = mock[GetResponse] + when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(true) + + // Create a sourceMap with excludeJobIds as an ArrayList not containing jobId + val sourceMap = new java.util.HashMap[String, Object]() + sourceMap.put("jobId", jobId.asInstanceOf[Object]) + + // Creating an ArrayList and adding a different jobId to it + val excludeJobIdsList = new java.util.ArrayList[String]() + excludeJobIdsList.add("jobXYZ") // This ID does not match the jobId + sourceMap.put("excludeJobIds", excludeJobIdsList.asInstanceOf[Object]) + + when(getResponse.getSourceAsMap).thenReturn(sourceMap) + + // Execute the method under test + val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + + // The function should return true since the jobId is not in the excludeJobIds list + assert(result) + } + + test("exponentialBackoffRetry should retry on ConnectException") { + val mockReader = mock[OpenSearchReader] + val exception = new RuntimeException( + new ConnectException( + "Timeout connecting to [search-foo-1-bar.eu-west-1.es.amazonaws.com:443]")) + val osClient = mock[OSClient] + when(osClient.createReader(any[String], any[String], any[String])).thenReturn(mockReader) + when(mockReader.hasNext).thenThrow(exception) + + val maxRetries = 1 + var actualRetries = 0 + + val resultIndex = "testResultIndex" + 
val dataSource = "testDataSource" + val sessionIndex = "testSessionIndex" + val sessionId = "testSessionId" + val jobId = "testJobId" + val applicationId = "testApplicationId" + + val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() + try { + val flintSessionIndexUpdater = mock[OpenSearchUpdater] + + val commandContext = CommandContext( spark, - dataSourceName, - error, - flintCommand, + dataSource, + resultIndex, + sessionId, + flintSessionIndexUpdater, + osClient, + sessionIndex, + jobId, + Duration(10, MINUTES), + 60, + 60) + + intercept[RuntimeException] { + FlintREPL.exponentialBackoffRetry(maxRetries, 2.seconds) { + actualRetries += 1 + FlintREPL.queryLoop(commandContext) + } + } + + assert(actualRetries == maxRetries) + } finally { + // Stop the SparkSession + spark.stop() + } + } + + test("executeAndHandle should handle TimeoutException properly") { + val mockSparkSession = mock[SparkSession] + val mockFlintCommand = mock[FlintCommand] + // val mockExecutionContextExecutor: ExecutionContextExecutor = mock[ExecutionContextExecutor] + val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor("flint-repl", 1) + implicit val executionContext = ExecutionContext.fromExecutor(threadPool) + + try { + val dataSource = "someDataSource" + val sessionId = "someSessionId" + val startTime = System.currentTimeMillis() + val expectedDataFrame = mock[DataFrame] + + when(mockFlintCommand.query).thenReturn("SELECT 1") + when(mockFlintCommand.submitTime).thenReturn(Instant.now().toEpochMilli()) + // When the `sql` method is called, execute the custom Answer that introduces a delay + when(mockSparkSession.sql(any[String])).thenAnswer(new Answer[DataFrame] { + override def answer(invocation: InvocationOnMock): DataFrame = { + // Introduce a delay of 60 seconds + Thread.sleep(60000) + + expectedDataFrame + } + }) + + when(mockSparkSession.createDataFrame(any[Seq[Product]])(any[TypeTag[Product]])) + .thenReturn(expectedDataFrame) + + when(expectedDataFrame.toDF(any[Seq[String]]: _*)).thenReturn(expectedDataFrame) + + val sparkContext = mock[SparkContext] + when(mockSparkSession.sparkContext).thenReturn(sparkContext) + + val result = FlintREPL.executeAndHandle( + mockSparkSession, + mockFlintCommand, + dataSource, + sessionId, + executionContext, + startTime, + // make sure it times out before mockSparkSession.sql can return, which takes 60 seconds + Duration(1, SECONDS), + 600000) + + verify(mockSparkSession, times(1)).sql(any[String]) + verify(sparkContext, times(1)).cancelJobGroup(any[String]) + result should not be None + } finally threadPool.shutdown() + } + + test("executeAndHandle should handle ParseException properly") { + val mockSparkSession = mock[SparkSession] + val flintCommand = + new FlintCommand( + "Running", + "select * from default.http_logs limit1 1", + "10", "20", - currentTime - queryRunTime) - assertEqualDataframe(expected, result) - assert("failed" == flintCommand.state) - assert(error == flintCommand.error.get) + Instant.now().toEpochMilli, + None) + val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor("flint-repl", 1) + implicit val executionContext = ExecutionContext.fromExecutor(threadPool) + + try { + val dataSource = "someDataSource" + val sessionId = "someSessionId" + val startTime = System.currentTimeMillis() + val expectedDataFrame = mock[DataFrame] + + // sql method can only throw RuntimeException + when(mockSparkSession.sql(any[String])).thenThrow( + new RuntimeException(new ParseException(None, "INVALID QUERY", 
Origin(), Origin()))) + val sparkContext = mock[SparkContext] + when(mockSparkSession.sparkContext).thenReturn(sparkContext) + + // Assume handleQueryException logs the error and returns an error message string + val mockErrorString = "Error due to syntax" + when(mockSparkSession.createDataFrame(any[Seq[Product]])(any[TypeTag[Product]])) + .thenReturn(expectedDataFrame) + when(expectedDataFrame.toDF(any[Seq[String]]: _*)).thenReturn(expectedDataFrame) + + val result = FlintREPL.executeAndHandle( + mockSparkSession, + flintCommand, + dataSource, + sessionId, + executionContext, + startTime, + Duration.Inf, // Use Duration.Inf or a large enough duration to avoid a timeout, + 600000) + + // Verify that ParseException was caught and handled + result should not be None // or result.isDefined shouldBe true + flintCommand.error should not be None + flintCommand.error.get should include("Syntax error:") + } finally threadPool.shutdown() + + } + + test("setupFlintJobWithExclusionCheck should proceed normally when no jobs are excluded") { + val osClient = mock[OSClient] + val getResponse = mock[GetResponse] + when(osClient.getDoc(*, *)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(true) + when(getResponse.getSourceAsMap).thenReturn( + Map[String, Object]( + "applicationId" -> "app1", + "jobId" -> "job1", + "sessionId" -> "session1", + "lastUpdateTime" -> java.lang.Long.valueOf(12345L), + "error" -> "someError", + "state" -> "running", + "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava) + when(getResponse.getSeqNo).thenReturn(0L) + when(getResponse.getPrimaryTerm).thenReturn(0L) + + val flintSessionIndexUpdater = mock[OpenSearchUpdater] + val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "") + + // other mock objects like osClient, flintSessionIndexUpdater with necessary mocking + val result = FlintREPL.setupFlintJobWithExclusionCheck( + mockConf, + Some("sessionIndex"), + Some("sessionId"), + osClient, + "jobId", + "appId", + flintSessionIndexUpdater, + System.currentTimeMillis()) + assert(!result) // Expecting false as the job should proceed normally + } + + test("setupFlintJobWithExclusionCheck should exit early if current job is excluded") { + val osClient = mock[OSClient] + val getResponse = mock[GetResponse] + when(osClient.getDoc(*, *)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(true) + // Mock the rest of the GetResponse as needed + + val flintSessionIndexUpdater = mock[OpenSearchUpdater] + val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "jobId") + + val result = FlintREPL.setupFlintJobWithExclusionCheck( + mockConf, + Some("sessionIndex"), + Some("sessionId"), + osClient, + "jobId", + "appId", + flintSessionIndexUpdater, + System.currentTimeMillis()) + assert(result) // Expecting true as the job should exit early + } + + test("setupFlintJobWithExclusionCheck should exit early if a duplicate job is running") { + val osClient = mock[OSClient] + val getResponse = mock[GetResponse] + when(osClient.getDoc(*, *)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(true) + // Mock the GetResponse to simulate a scenario of a duplicate job + when(getResponse.getSourceAsMap).thenReturn( + Map[String, Object]( + "applicationId" -> "app1", + "jobId" -> "job1", + "sessionId" -> "session1", + "lastUpdateTime" -> java.lang.Long.valueOf(12345L), + "error" -> "someError", + "state" -> "running", + "jobStartTime" -> java.lang.Long.valueOf(0L), + "excludeJobIds" -> java.util.Arrays + .asList("job-2", 
"job-1") // Include this inside the Map + ).asJava) + + val flintSessionIndexUpdater = mock[OpenSearchUpdater] + val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "job-1,job-2") + + val result = FlintREPL.setupFlintJobWithExclusionCheck( + mockConf, + Some("sessionIndex"), + Some("sessionId"), + osClient, + "jobId", + "appId", + flintSessionIndexUpdater, + System.currentTimeMillis()) + assert(result) // Expecting true for early exit due to duplicate job + } + + test("setupFlintJobWithExclusionCheck should setup job normally when conditions are met") { + val osClient = mock[OSClient] + val getResponse = mock[GetResponse] + when(osClient.getDoc(*, *)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(true) + + val flintSessionIndexUpdater = mock[OpenSearchUpdater] + val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "job-3,job-4") + + val result = FlintREPL.setupFlintJobWithExclusionCheck( + mockConf, + Some("sessionIndex"), + Some("sessionId"), + osClient, + "jobId", + "appId", + flintSessionIndexUpdater, + System.currentTimeMillis()) + assert(!result) // Expecting false as the job proceeds normally + } + + test( + "setupFlintJobWithExclusionCheck should throw NoSuchElementException if sessionIndex or sessionId is missing") { + val osClient = mock[OSClient] + val flintSessionIndexUpdater = mock[OpenSearchUpdater] + val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "") + + assertThrows[NoSuchElementException] { + FlintREPL.setupFlintJobWithExclusionCheck( + mockConf, + None, // No sessionIndex provided + None, // No sessionId provided + osClient, + "jobId", + "appId", + flintSessionIndexUpdater, + System.currentTimeMillis()) + } + } + + test("queryLoop continue until inactivity limit is reached") { + val mockReader = mock[FlintReader] + val osClient = mock[OSClient] + when(osClient.createReader(any[String], any[String], any[String])).thenReturn(mockReader) + when(mockReader.hasNext).thenReturn(false) + + val resultIndex = "testResultIndex" + val dataSource = "testDataSource" + val sessionIndex = "testSessionIndex" + val sessionId = "testSessionId" + val jobId = "testJobId" + + val shortInactivityLimit = 500 // 500 milliseconds + + // Create a SparkSession for testing + val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() + + val flintSessionIndexUpdater = mock[OpenSearchUpdater] + + val commandContext = CommandContext( + spark, + dataSource, + resultIndex, + sessionId, + flintSessionIndexUpdater, + osClient, + sessionIndex, + jobId, + Duration(10, MINUTES), + shortInactivityLimit, + 60) + + // Mock processCommands to always allow loop continuation + val getResponse = mock[GetResponse] + when(osClient.getDoc(*, *)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(false) + + val startTime = System.currentTimeMillis() + + FlintREPL.queryLoop(commandContext) + + val endTime = System.currentTimeMillis() + + // Check if the loop ran for approximately the duration of the inactivity limit + assert(endTime - startTime >= shortInactivityLimit) + + // Stop the SparkSession + spark.stop() + } + + test("queryLoop should stop when canPickUpNextStatement is false") { + val mockReader = mock[FlintReader] + val osClient = mock[OSClient] + when(osClient.createReader(any[String], any[String], any[String])).thenReturn(mockReader) + when(mockReader.hasNext).thenReturn(true) + + val resultIndex = "testResultIndex" + val dataSource = "testDataSource" + val sessionIndex = 
"testSessionIndex" + val sessionId = "testSessionId" + val jobId = "testJobId" + val longInactivityLimit = 10000 // 10 seconds + + // Create a SparkSession for testing + val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() + + val flintSessionIndexUpdater = mock[OpenSearchUpdater] + + val commandContext = CommandContext( + spark, + dataSource, + resultIndex, + sessionId, + flintSessionIndexUpdater, + osClient, + sessionIndex, + jobId, + Duration(10, MINUTES), + longInactivityLimit, + 60) + + // Mocking canPickNextStatement to return false + when(osClient.getDoc(sessionIndex, sessionId)).thenAnswer(_ => { + val mockGetResponse = mock[GetResponse] + when(mockGetResponse.isExists()).thenReturn(true) + val sourceMap = new java.util.HashMap[String, Object]() + sourceMap.put("jobId", "differentJobId") + when(mockGetResponse.getSourceAsMap).thenReturn(sourceMap) + mockGetResponse + }) + + val startTime = System.currentTimeMillis() + + FlintREPL.queryLoop(commandContext) + + val endTime = System.currentTimeMillis() + + // Check if the loop stopped before the inactivity limit + assert(endTime - startTime < longInactivityLimit) + + // Stop the SparkSession + spark.stop() + } + + test("queryLoop should properly shut down the thread pool after execution") { + val mockReader = mock[FlintReader] + val osClient = mock[OSClient] + when(osClient.createReader(any[String], any[String], any[String])).thenReturn(mockReader) + when(mockReader.hasNext).thenReturn(false) + + val resultIndex = "testResultIndex" + val dataSource = "testDataSource" + val sessionIndex = "testSessionIndex" + val sessionId = "testSessionId" + val jobId = "testJobId" + + val inactivityLimit = 500 // 500 milliseconds + + // Create a SparkSession for testing + val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() + + val flintSessionIndexUpdater = mock[OpenSearchUpdater] + + val commandContext = CommandContext( + spark, + dataSource, + resultIndex, + sessionId, + flintSessionIndexUpdater, + osClient, + sessionIndex, + jobId, + Duration(10, MINUTES), + inactivityLimit, + 60) + + try { + // Mocking ThreadUtils to track the shutdown call + val mockThreadPool = mock[ScheduledExecutorService] + FlintREPL.threadPoolFactory = new MockThreadPoolFactory(mockThreadPool) + + FlintREPL.queryLoop(commandContext) + + // Verify if the shutdown method was called on the thread pool + verify(mockThreadPool).shutdown() + } finally { + // Stop the SparkSession + spark.stop() + FlintREPL.threadPoolFactory = new DefaultThreadPoolFactory() + } + } + + test("queryLoop handle exceptions within the loop gracefully") { + val mockReader = mock[FlintReader] + val osClient = mock[OSClient] + when(osClient.createReader(any[String], any[String], any[String])).thenReturn(mockReader) + // Simulate an exception thrown when hasNext is called + when(mockReader.hasNext).thenThrow(new RuntimeException("Test exception")) + + val resultIndex = "testResultIndex" + val dataSource = "testDataSource" + val sessionIndex = "testSessionIndex" + val sessionId = "testSessionId" + val jobId = "testJobId" + + val inactivityLimit = 500 // 500 milliseconds + + // Create a SparkSession for testing + val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() + + val flintSessionIndexUpdater = mock[OpenSearchUpdater] + + val commandContext = CommandContext( + spark, + dataSource, + resultIndex, + sessionId, + flintSessionIndexUpdater, + osClient, + sessionIndex, + jobId, + Duration(10, 
MINUTES), + inactivityLimit, + 60) + + try { + // Mocking ThreadUtils to track the shutdown call + val mockThreadPool = mock[ScheduledExecutorService] + FlintREPL.threadPoolFactory = new MockThreadPoolFactory(mockThreadPool) + + intercept[RuntimeException] { + FlintREPL.queryLoop(commandContext) + } + + // Verify if the shutdown method was called on the thread pool + verify(mockThreadPool).shutdown() + } finally { + // Stop the SparkSession + spark.stop() + FlintREPL.threadPoolFactory = new DefaultThreadPoolFactory() + } + } + + test("queryLoop should correctly update loop control variables") { + val mockReader = mock[FlintReader] + val osClient = mock[OSClient] + when(osClient.createReader(any[String], any[String], any[String])).thenReturn(mockReader) + val getResponse = mock[GetResponse] + when(osClient.getDoc(*, *)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(false) + when(osClient.doesIndexExist(*)).thenReturn(true) + when(osClient.getIndexMetadata(*)).thenReturn(FlintREPL.resultIndexMapping) + + // Configure mockReader to return true once and then false to exit the loop + when(mockReader.hasNext).thenReturn(true).thenReturn(false) + val command = + """ { + "state": "running", + "query": "SELECT * FROM table", + "statementId": "stmt123", + "queryId": "query456", + "submitTime": 1234567890, + "error": "Some error" + } + """ + when(mockReader.next).thenReturn(command) + + val resultIndex = "testResultIndex" + val dataSource = "testDataSource" + val sessionIndex = "testSessionIndex" + val sessionId = "testSessionId" + val jobId = "testJobId" + + val inactivityLimit = 5000 // 5 seconds + + // Create a SparkSession for testing\ + val mockSparkSession = mock[SparkSession] + val expectedDataFrame = mock[DataFrame] + when(mockSparkSession.createDataFrame(any[Seq[Product]])(any[TypeTag[Product]])) + .thenReturn(expectedDataFrame) + val sparkContext = mock[SparkContext] + when(mockSparkSession.sparkContext).thenReturn(sparkContext) + + when(expectedDataFrame.toDF(any[Seq[String]]: _*)).thenReturn(expectedDataFrame) + + val flintSessionIndexUpdater = mock[OpenSearchUpdater] + + val commandContext = CommandContext( + mockSparkSession, + dataSource, + resultIndex, + sessionId, + flintSessionIndexUpdater, + osClient, + sessionIndex, + jobId, + Duration(10, MINUTES), + inactivityLimit, + 60) + + val startTime = Instant.now().toEpochMilli() + + // Running the queryLoop + FlintREPL.queryLoop(commandContext) + + val endTime = Instant.now().toEpochMilli() + + // Assuming processCommands updates the lastActivityTime to the current time + assert(endTime - startTime >= inactivityLimit) + verify(osClient, times(1)).getIndexMetadata(*) + } + + test("queryLoop should execute loop without processing any commands") { + val mockReader = mock[FlintReader] + val osClient = mock[OSClient] + when(osClient.createReader(any[String], any[String], any[String])).thenReturn(mockReader) + val getResponse = mock[GetResponse] + when(osClient.getDoc(*, *)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(false) + + // Configure mockReader to always return false, indicating no commands to process + when(mockReader.hasNext).thenReturn(false) + + val resultIndex = "testResultIndex" + val dataSource = "testDataSource" + val sessionIndex = "testSessionIndex" + val sessionId = "testSessionId" + val jobId = "testJobId" + + val inactivityLimit = 5000 // 5 seconds + + // Create a SparkSession for testing + val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() + 
+ val flintSessionIndexUpdater = mock[OpenSearchUpdater] + + val commandContext = CommandContext( + spark, + dataSource, + resultIndex, + sessionId, + flintSessionIndexUpdater, + osClient, + sessionIndex, + jobId, + Duration(10, MINUTES), + inactivityLimit, + 60) + + val startTime = Instant.now().toEpochMilli() + + // Running the queryLoop + FlintREPL.queryLoop(commandContext) + + val endTime = Instant.now().toEpochMilli() + + // Assert that the loop ran for at least the duration of the inactivity limit + assert(endTime - startTime >= inactivityLimit) + + // Verify that no command was actually processed + verify(mockReader, never()).next() + + // Stop the SparkSession + spark.stop() } } diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/util/MockThreadPoolFactory.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/util/MockThreadPoolFactory.scala new file mode 100644 index 000000000..4aa8e4577 --- /dev/null +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/util/MockThreadPoolFactory.scala @@ -0,0 +1,14 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql.util + +import java.util.concurrent.ScheduledExecutorService + +class MockThreadPoolFactory(mockExecutor: ScheduledExecutorService) extends ThreadPoolFactory { + override def newDaemonThreadPoolScheduledExecutor( + threadNamePrefix: String, + numThreads: Int): ScheduledExecutorService = mockExecutor +} diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/MockTimeProvider.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/util/MockTimeProvider.scala similarity index 72% rename from spark-sql-application/src/test/scala/org/apache/spark/sql/MockTimeProvider.scala rename to spark-sql-application/src/test/scala/org/apache/spark/sql/util/MockTimeProvider.scala index d987577f8..b2b0167c2 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/MockTimeProvider.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/util/MockTimeProvider.scala @@ -3,9 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.apache.spark.sql - -import org.apache.spark.sql.util.TimeProvider +package org.apache.spark.sql.util class MockTimeProvider(fixedTime: Long) extends TimeProvider { override def currentEpochMillis(): Long = fixedTime
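
A note on the session hand-off introduced above: FlintREPL.canPickNextStatement lets exactly one job consume statements for a session. A job keeps going only when the session doc's jobId matches its own and its id is not listed in excludeJobIds; a missing doc, a null source, or an OpenSearch error deliberately falls back to "proceed" so a flaky metadata read cannot stall the REPL. The sketch below is illustrative only (the object and method names are not part of the patch) and condenses that decision once the session doc has already been parsed into a runJobId and a flattened exclude list:

// Sketch only: mirrors the decision made by canPickNextStatement after the session doc
// has been reduced to runJobId / excludedJobIds. Not part of the patch.
object PickNextStatementSketch {
  def mayPickNextStatement(
      currentJobId: String,
      runJobId: Option[String], // "jobId" field of the session doc, if present
      excludedJobIds: Seq[String] // "excludeJobIds", string or list, already flattened
  ): Boolean = {
    val isRunningJob = runJobId.forall(_ == currentJobId) // a missing jobId means "proceed"
    isRunningJob && !excludedJobIds.contains(currentJobId)
  }

  def main(args: Array[String]): Unit = {
    // During a blue/green hand-off the session doc might carry
    //   { "jobId": "job-green", "excludeJobIds": ["job-blue"] }
    // so the old job stops picking statements and the new one takes over.
    assert(!mayPickNextStatement("job-blue", Some("job-green"), Seq("job-blue")))
    assert(mayPickNextStatement("job-green", Some("job-green"), Seq("job-blue")))
  }
}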
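
Likewise, exponentialBackoffRetry retries only failures that wrap a ConnectException, sleeping initialDelay * 2^(n - 1) before the n-th retry and rethrowing everything else; the tests above exercise it around queryLoop so a transient request_index outage does not immediately kill the session. A rough, self-contained sketch of the same shape, with a hypothetical flaky block used purely for illustration:

import java.net.ConnectException
import scala.concurrent.duration._

// Sketch of the retry shape used by exponentialBackoffRetry: retry only when the failure
// wraps a ConnectException, double the delay on each attempt, and let anything else
// (or an exhausted retry budget) propagate to the caller.
object BackoffSketch {
  def retryOnConnectError[T](maxRetries: Int, initialDelay: FiniteDuration)(block: => T): T = {
    var retries = 0
    var result: Option[T] = None
    var keepTrying = true
    while (keepTrying) {
      try {
        result = Some(block)
        keepTrying = false
      } catch {
        case e: RuntimeException
            if e.getCause != null && e.getCause.isInstanceOf[ConnectException] &&
              retries < maxRetries =>
          retries += 1
          // 1s, 2s, 4s, ... when initialDelay = 1.second
          Thread.sleep((initialDelay * math.pow(2, retries - 1).toLong).toMillis)
      }
    }
    result.getOrElse(throw new RuntimeException(s"Failed after $retries retries"))
  }

  def main(args: Array[String]): Unit = {
    var calls = 0
    // Hypothetical flaky block: fails twice with a wrapped ConnectException, then succeeds.
    val answer = retryOnConnectError(maxRetries = 5, initialDelay = 1.second) {
      calls += 1
      if (calls < 3) throw new RuntimeException(new ConnectException("cluster unresponsive"))
      "ok"
    }
    println(s"$answer after ${calls - 1} retries") // prints: ok after 2 retries
  }
}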