From 89a5a3fa244c89d71c396dcb034858d754bcb7de Mon Sep 17 00:00:00 2001 From: Kaituo Li Date: Tue, 24 Oct 2023 15:19:06 -0700 Subject: [PATCH] Implement REPL mode in Spark and enhance error handling (#99) - **Read (R)**: Source queries from the OpenSearch flint-query-submission index. - **Execute (E)**: Run queries within the SparkContext environment. - **Publish (P)**: - Push results to the flint-query-result index. - Update query state in the flint-query-submission index. - **Loop (L)**: Continue process until a set exit condition is reached. Additional improvements: - Enable cancelation of running statements in Spark. - Fail statements that wait too long. - Provide detailed error feedback. - Introduce query run time metric. Testing: - Introduced unit tests. - Conducted manual tests. Signed-off-by: Kaituo Li --- build.sbt | 12 +- .../opensearch/flint/core/FlintClient.java | 1 - .../flint/core/storage/OpenSearchUpdater.java | 76 +++ .../opensearch/flint/app/FlintCommand.scala | 70 ++ .../opensearch/flint/app/FlintInstance.scala | 55 ++ .../scala/org/apache/spark/sql/FlintJob.scala | 244 +------ .../apache/spark/sql/FlintJobExecutor.scala | 396 +++++++++++ .../org/apache/spark/sql/FlintREPL.scala | 637 ++++++++++++++++++ .../scala/org/apache/spark/sql/OSClient.scala | 68 +- .../apache/spark/sql/VerificationResult.scala | 12 + .../sql/util/DefaultShutdownHookManager.scala | 14 + .../spark/sql/util/RealTimeProvider.scala | 10 + .../sql/util/ShutdownHookManagerTrait.scala | 13 + .../apache/spark/sql/util/TimeProvider.scala | 14 + .../org/apache/spark/sql/FlintJobTest.scala | 38 +- .../org/apache/spark/sql/FlintREPLTest.scala | 173 +++++ .../org/apache/spark/sql/JobMatchers.scala | 15 + .../apache/spark/sql/MockTimeProvider.scala | 12 + .../scala/org/opensearch/sql/SQLJobTest.scala | 11 +- 19 files changed, 1614 insertions(+), 257 deletions(-) create mode 100644 flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchUpdater.java create mode 100644 flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintCommand.scala create mode 100644 flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala create mode 100644 spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala create mode 100644 spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala create mode 100644 spark-sql-application/src/main/scala/org/apache/spark/sql/VerificationResult.scala create mode 100644 spark-sql-application/src/main/scala/org/apache/spark/sql/util/DefaultShutdownHookManager.scala create mode 100644 spark-sql-application/src/main/scala/org/apache/spark/sql/util/RealTimeProvider.scala create mode 100644 spark-sql-application/src/main/scala/org/apache/spark/sql/util/ShutdownHookManagerTrait.scala create mode 100644 spark-sql-application/src/main/scala/org/apache/spark/sql/util/TimeProvider.scala create mode 100644 spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala create mode 100644 spark-sql-application/src/test/scala/org/apache/spark/sql/JobMatchers.scala create mode 100644 spark-sql-application/src/test/scala/org/apache/spark/sql/MockTimeProvider.scala diff --git a/build.sbt b/build.sbt index 9389384fd..0dcfb8af7 100644 --- a/build.sbt +++ b/build.sbt @@ -179,7 +179,17 @@ lazy val sparkSqlApplication = (project in file("spark-sql-application")) libraryDependencies ++= Seq( "org.scalatest" %% "scalatest" % "3.2.15" % "test"), libraryDependencies ++= deps(sparkVersion), - libraryDependencies += 
"com.typesafe.play" %% "play-json" % "2.9.2", + libraryDependencies ++= Seq( + "com.typesafe.play" %% "play-json" % "2.9.2", + // handle AmazonS3Exception + "com.amazonaws" % "aws-java-sdk-s3" % "1.12.568" % "provided" + // the transitive jackson.core dependency conflicts with existing scala + // error: Scala module 2.13.4 requires Jackson Databind version >= 2.13.0 and < 2.14.0 - + // Found jackson-databind version 2.14.2 + exclude ("com.fasterxml.jackson.core", "jackson-databind"), + "org.scalatest" %% "scalatest" % "3.2.15" % "test", + "org.mockito" %% "mockito-scala" % "1.16.42" % "test", + "org.scalatestplus" %% "mockito-4-6" % "3.2.15.0" % "test"), // Assembly settings // the sbt assembly plugin found multiple copies of the module-info.class file with // different contents in the jars that it was merging flintCore dependencies. diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java b/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java index d50c0002e..9519df8bc 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java @@ -73,7 +73,6 @@ public interface FlintClient { * @return {@link FlintWriter} */ FlintWriter createWriter(String indexName); - /** * Create {@link RestHighLevelClient}. * @return {@link RestHighLevelClient} diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchUpdater.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchUpdater.java new file mode 100644 index 000000000..58963ab74 --- /dev/null +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchUpdater.java @@ -0,0 +1,76 @@ +package org.opensearch.flint.core.storage; + +import org.opensearch.action.support.WriteRequest; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.client.RequestOptions; +import org.opensearch.client.RestHighLevelClient; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.flint.core.FlintClient; +import org.opensearch.flint.core.FlintClientBuilder; +import org.opensearch.flint.core.FlintOptions; + +import java.io.IOException; + +public class OpenSearchUpdater { + private final String indexName; + + private final FlintClient flintClient; + + + public OpenSearchUpdater(String indexName, FlintClient flintClient) { + this.indexName = indexName; + this.flintClient = flintClient; + } + + public void upsert(String id, String doc) { + // we might need to keep the updater for a long time. Reusing the client may not work as the temporary + // credentials may expire. + // also, failure to close the client causes the job to be stuck in the running state as the client resource + // is not released. 
+ try (RestHighLevelClient client = flintClient.createClient()) { + UpdateRequest + updateRequest = + new UpdateRequest(indexName, id).doc(doc, XContentType.JSON) + .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL) + .docAsUpsert(true); + client.update(updateRequest, RequestOptions.DEFAULT); + } catch (IOException e) { + throw new RuntimeException(String.format( + "Failed to execute update request on index: %s, id: %s", + indexName, + id), e); + } + } + + public void update(String id, String doc) { + try (RestHighLevelClient client = flintClient.createClient()) { + UpdateRequest + updateRequest = + new UpdateRequest(indexName, id).doc(doc, XContentType.JSON) + .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); + client.update(updateRequest, RequestOptions.DEFAULT); + } catch (IOException e) { + throw new RuntimeException(String.format( + "Failed to execute update request on index: %s, id: %s", + indexName, + id), e); + } + } + + public void updateIf(String id, String doc, long seqNo, long primaryTerm) { + try (RestHighLevelClient client = flintClient.createClient()) { + UpdateRequest + updateRequest = + new UpdateRequest(indexName, id).doc(doc, XContentType.JSON) + .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL) + .setIfSeqNo(seqNo) + .setIfPrimaryTerm(primaryTerm); + client.update(updateRequest, RequestOptions.DEFAULT); + } catch (IOException e) { + throw new RuntimeException(String.format( + "Failed to execute update request on index: %s, id: %s", + indexName, + id), e); + } + } +} diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintCommand.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintCommand.scala new file mode 100644 index 000000000..288d661cd --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintCommand.scala @@ -0,0 +1,70 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.app + +import org.json4s.{Formats, NoTypeHints} +import org.json4s.JsonAST.JString +import org.json4s.native.JsonMethods.parse +import org.json4s.native.Serialization + +class FlintCommand( + var state: String, + val query: String, + // statementId is the statement type doc id + val statementId: String, + val queryId: String, + val submitTime: Long, + var error: Option[String] = None) { + def running(): Unit = { + state = "running" + } + + def complete(): Unit = { + state = "success" + } + + def fail(): Unit = { + state = "failed" + } + + def isRunning(): Boolean = { + state == "running" + } + + def isComplete(): Boolean = { + state == "success" + } + + def isFailed(): Boolean = { + state == "failed" + } +} + +object FlintCommand { + + implicit val formats: Formats = Serialization.formats(NoTypeHints) + + def deserialize(command: String): FlintCommand = { + val meta = parse(command) + val state = (meta \ "state").extract[String] + val query = (meta \ "query").extract[String] + val statementId = (meta \ "statementId").extract[String] + val queryId = (meta \ "queryId").extract[String] + val submitTime = (meta \ "submitTime").extract[Long] + val maybeError: Option[String] = (meta \ "error") match { + case JString(str) => Some(str) + case _ => None + } + + new FlintCommand(state, query, statementId, queryId, submitTime, maybeError) + } + + def serialize(flintCommand: FlintCommand): String = { + // we only need to modify state and error + Serialization.write( + Map("state" -> flintCommand.state, "error" -> 
flintCommand.error.getOrElse(""))) + } +} diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala new file mode 100644 index 000000000..3b5ed2a74 --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala @@ -0,0 +1,55 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.app + +import org.json4s.{Formats, NoTypeHints} +import org.json4s.JsonAST.JString +import org.json4s.native.JsonMethods.parse +import org.json4s.native.Serialization +import org.opensearch.index.seqno.SequenceNumbers + +// lastUpdateTime is added to FlintInstance to track the last update time of the instance. Its unit is millisecond. +class FlintInstance( + val applicationId: String, + val jobId: String, + // sessionId is the session type doc id + val sessionId: String, + val state: String, + val lastUpdateTime: Long, + val error: Option[String] = None) {} + +object FlintInstance { + + implicit val formats: Formats = Serialization.formats(NoTypeHints) + + def deserialize(job: String): FlintInstance = { + val meta = parse(job) + val applicationId = (meta \ "applicationId").extract[String] + val state = (meta \ "state").extract[String] + val jobId = (meta \ "jobId").extract[String] + val sessionId = (meta \ "sessionId").extract[String] + val lastUpdateTime = (meta \ "lastUpdateTime").extract[Long] + val maybeError: Option[String] = (meta \ "error") match { + case JString(str) => Some(str) + case _ => None + } + + new FlintInstance(applicationId, jobId, sessionId, state, lastUpdateTime, maybeError) + } + + def serialize(job: FlintInstance): String = { + Serialization.write( + Map( + "type" -> "session", + "sessionId" -> job.sessionId, + "error" -> job.error.getOrElse(""), + "applicationId" -> job.applicationId, + "jobId" -> job.jobId, + "state" -> job.state, + // update last update time + "lastUpdateTime" -> System.currentTimeMillis())) + } +} diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala index 51bf4d734..8acc2308d 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala @@ -11,7 +11,6 @@ import java.util.Locale import scala.concurrent.{ExecutionContext, Future, TimeoutException} import scala.concurrent.duration.{Duration, MINUTES} -import org.opensearch.ExceptionsHelper import org.opensearch.client.{RequestOptions, RestHighLevelClient} import org.opensearch.cluster.metadata.MappingMetadata import org.opensearch.common.settings.Settings @@ -22,6 +21,7 @@ import play.api.libs.json._ import org.apache.spark.SparkConf import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.types.{StructField, _} import org.apache.spark.util.ThreadUtils @@ -36,7 +36,7 @@ import org.apache.spark.util.ThreadUtils * @return * write sql query result to given opensearch index */ -object FlintJob extends Logging { +object FlintJob extends Logging with FlintJobExecutor { def main(args: Array[String]): Unit = { // Validate command line arguments if (args.length != 2) { @@ -54,29 +54,32 @@ object FlintJob extends Logging { implicit val executionContext = 
ExecutionContext.fromExecutor(threadPool) var dataToWrite: Option[DataFrame] = None + val startTime = System.currentTimeMillis() try { - // osClient needs spark session to be created first. Otherwise, we will have connection - // exception from EMR-S to OS. + // osClient needs spark session to be created first to get FlintOptions initialized. + // Otherwise, we will have connection exception from EMR-S to OS. val osClient = new OSClient(FlintSparkConf().flintOptions()) val futureMappingCheck = Future { checkAndCreateIndex(osClient, resultIndex) } - val data = executeQuery(spark, query, dataSource) + val data = executeQuery(spark, query, dataSource, "", "") val mappingCheckResult = ThreadUtils.awaitResult(futureMappingCheck, Duration(1, MINUTES)) dataToWrite = Some(mappingCheckResult match { case Right(_) => data - case Left(error) => getFailedData(spark, dataSource, error) + case Left(error) => + getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider) }) } catch { case e: TimeoutException => - val error = "Future operations timed out" + val error = s"Getting the mapping of index $resultIndex timed out" logError(error, e) - dataToWrite = Some(getFailedData(spark, dataSource, error)) + dataToWrite = Some( + getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider)) case e: Exception => - val error = "Fail to verify existing mapping or write result" - logError(error, e) - dataToWrite = Some(getFailedData(spark, dataSource, error)) + val error = processQueryException(e, spark, dataSource, query, "", "") + dataToWrite = Some( + getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider)) } finally { dataToWrite.foreach(df => writeData(df, resultIndex)) // Stop SparkSession if it is not streaming job @@ -89,223 +92,4 @@ object FlintJob extends Logging { threadPool.shutdown() } } - - def createSparkConf(): SparkConf = { - new SparkConf() - .setAppName("FlintJob") - } - - def createSparkSession(conf: SparkConf): SparkSession = { - SparkSession.builder().config(conf).enableHiveSupport().getOrCreate() - } - - def writeData(resultData: DataFrame, resultIndex: String): Unit = { - resultData.write - .format("flint") - .mode("append") - .save(resultIndex) - } - - /** - * Create a new formatted dataframe with json result, json schema and EMR_STEP_ID. 
- * - * @param result - * sql query result dataframe - * @param spark - * spark session - * @return - * dataframe with result, schema and emr step id - */ - def getFormattedData(result: DataFrame, spark: SparkSession, dataSource: String): DataFrame = { - // Create the schema dataframe - val schemaRows = result.schema.fields.map { field => - Row(field.name, field.dataType.typeName) - } - val resultSchema = spark.createDataFrame( - spark.sparkContext.parallelize(schemaRows), - StructType( - Seq( - StructField("column_name", StringType, nullable = false), - StructField("data_type", StringType, nullable = false)))) - - // Define the data schema - val schema = StructType( - Seq( - StructField("result", ArrayType(StringType, containsNull = true), nullable = true), - StructField("schema", ArrayType(StringType, containsNull = true), nullable = true), - StructField("jobRunId", StringType, nullable = true), - StructField("applicationId", StringType, nullable = true), - StructField("dataSourceName", StringType, nullable = true), - StructField("status", StringType, nullable = true), - StructField("error", StringType, nullable = true))) - - // Create the data rows - val rows = Seq( - ( - result.toJSON.collect.toList - .map(_.replaceAll("'", "\\\\'").replaceAll("\"", "'")), - resultSchema.toJSON.collect.toList.map(_.replaceAll("\"", "'")), - sys.env.getOrElse("SERVERLESS_EMR_JOB_ID", "unknown"), - sys.env.getOrElse("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"), - dataSource, - "SUCCESS", - "")) - - // Create the DataFrame for data - spark.createDataFrame(rows).toDF(schema.fields.map(_.name): _*) - } - - def getFailedData(spark: SparkSession, dataSource: String, error: String): DataFrame = { - - // Define the data schema - val schema = StructType( - Seq( - StructField("result", ArrayType(StringType, containsNull = true), nullable = true), - StructField("schema", ArrayType(StringType, containsNull = true), nullable = true), - StructField("jobRunId", StringType, nullable = true), - StructField("applicationId", StringType, nullable = true), - StructField("dataSourceName", StringType, nullable = true), - StructField("status", StringType, nullable = true), - StructField("error", StringType, nullable = true))) - - // Create the data rows - val rows = Seq( - ( - null, - null, - sys.env.getOrElse("SERVERLESS_EMR_JOB_ID", "unknown"), - sys.env.getOrElse("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"), - dataSource, - "FAILED", - error)) - - // Create the DataFrame for data - spark.createDataFrame(rows).toDF(schema.fields.map(_.name): _*) - } - - def isSuperset(input: String, mapping: String): Boolean = { - - /** - * Determines whether one JSON structure is a superset of another. - * - * This method checks if the `input` JSON structure contains all the fields and values present - * in the `mapping` JSON structure. The comparison is recursive and structure-sensitive, - * ensuring that nested objects and arrays are also compared accurately. - * - * Additionally, this method accommodates the edge case where boolean values in the JSON are - * represented as strings (e.g., "true" or "false" instead of true or false). This is handled - * by performing a case-insensitive comparison of string representations of boolean values. - * - * @param input - * The input JSON structure as a String. - * @param mapping - * The mapping JSON structure as a String. - * @return - * A Boolean value indicating whether the `input` JSON structure is a superset of the - * `mapping` JSON structure. 
- */ - def compareJson(inputJson: JsValue, mappingJson: JsValue): Boolean = { - (inputJson, mappingJson) match { - case (JsObject(inputFields), JsObject(mappingFields)) => - logInfo(s"Comparing objects: $inputFields vs $mappingFields") - mappingFields.forall { case (key, value) => - inputFields - .get(key) - .exists(inputValue => compareJson(inputValue, value)) - } - case (JsArray(inputValues), JsArray(mappingValues)) => - logInfo(s"Comparing arrays: $inputValues vs $mappingValues") - mappingValues.forall(mappingValue => - inputValues.exists(inputValue => compareJson(inputValue, mappingValue))) - case (JsString(inputValue), JsString(mappingValue)) - if (inputValue.toLowerCase(Locale.ROOT) == "true" || - inputValue.toLowerCase(Locale.ROOT) == "false") && - (mappingValue.toLowerCase(Locale.ROOT) == "true" || - mappingValue.toLowerCase(Locale.ROOT) == "false") => - inputValue.toLowerCase(Locale.ROOT) == mappingValue.toLowerCase(Locale.ROOT) - case (JsBoolean(inputValue), JsString(mappingValue)) - if mappingValue.toLowerCase(Locale.ROOT) == "true" || - mappingValue.toLowerCase(Locale.ROOT) == "false" => - inputValue.toString.toLowerCase(Locale.ROOT) == mappingValue - .toLowerCase(Locale.ROOT) - case (JsString(inputValue), JsBoolean(mappingValue)) - if inputValue.toLowerCase(Locale.ROOT) == "true" || - inputValue.toLowerCase(Locale.ROOT) == "false" => - inputValue.toLowerCase(Locale.ROOT) == mappingValue.toString - .toLowerCase(Locale.ROOT) - case (inputValue, mappingValue) => - inputValue == mappingValue - } - } - - val inputJson = Json.parse(input) - val mappingJson = Json.parse(mapping) - - compareJson(inputJson, mappingJson) - } - - def checkAndCreateIndex(osClient: OSClient, resultIndex: String): Either[String, Unit] = { - // The enabled setting, which can be applied only to the top-level mapping definition and to object fields, - val mapping = - """{ - "dynamic": false, - "properties": { - "result": { - "type": "object", - "enabled": false - }, - "schema": { - "type": "object", - "enabled": false - }, - "jobRunId": { - "type": "keyword" - }, - "applicationId": { - "type": "keyword" - }, - "dataSourceName": { - "type": "keyword" - }, - "status": { - "type": "keyword" - }, - "error": { - "type": "text" - } - } - }""".stripMargin - - try { - val existingSchema = osClient.getIndexMetadata(resultIndex) - if (!isSuperset(existingSchema, mapping)) { - Left(s"The mapping of $resultIndex is incorrect.") - } else { - Right(()) - } - } catch { - case e: IllegalStateException - if e.getCause().getMessage().contains("index_not_found_exception") => - try { - osClient.createIndex(resultIndex, mapping) - Right(()) - } catch { - case e: Exception => - val error = s"Failed to create result index $resultIndex" - logError(error, e) - Left(error) - } - case e: Exception => - val error = "Failed to verify existing mapping" - logError(error, e) - Left(error) - } - } - - def executeQuery(spark: SparkSession, query: String, dataSource: String): DataFrame = { - // Execute SQL query - val result: DataFrame = spark.sql(query) - // Get Data - getFormattedData(result, spark, dataSource) - } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala new file mode 100644 index 000000000..c54d420e8 --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -0,0 +1,396 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + 
+package org.apache.spark.sql + +import java.util.Locale + +import com.amazonaws.services.s3.model.AmazonS3Exception +import org.opensearch.flint.core.FlintClient +import org.opensearch.flint.core.metadata.FlintMetadata +import play.api.libs.json.{JsArray, JsBoolean, JsObject, Json, JsString, JsValue} + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.FlintJob.{getFormattedData, handleIndexNotFoundException, isSuperset, logError, logInfo} +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.types.{ArrayType, LongType, StringType, StructField, StructType} +import org.apache.spark.sql.util.{RealTimeProvider, TimeProvider} + +trait FlintJobExecutor { + this: Logging => + + var currentTimeProvider: TimeProvider = new RealTimeProvider() + + def createSparkConf(): SparkConf = { + new SparkConf() + .setAppName(getClass.getSimpleName) + .set("spark.sql.extensions", "org.opensearch.flint.spark.FlintSparkExtensions") + } + + def createSparkSession(conf: SparkConf): SparkSession = { + SparkSession.builder().config(conf).enableHiveSupport().getOrCreate() + } + + def writeData(resultData: DataFrame, resultIndex: String): Unit = { + resultData.write + .format("flint") + .mode("append") + .save(resultIndex) + } + + /** + * Create a new formatted dataframe with json result, json schema and EMR_STEP_ID. + * + * @param result + * sql query result dataframe + * @param spark + * spark session + * @return + * dataframe with result, schema and emr step id + */ + def getFormattedData( + result: DataFrame, + spark: SparkSession, + dataSource: String, + queryId: String, + query: String, + sessionId: String, + startTime: Long, + timeProvider: TimeProvider): DataFrame = { + // Create the schema dataframe + val schemaRows = result.schema.fields.map { field => + Row(field.name, field.dataType.typeName) + } + val resultSchema = spark.createDataFrame( + spark.sparkContext.parallelize(schemaRows), + StructType( + Seq( + StructField("column_name", StringType, nullable = false), + StructField("data_type", StringType, nullable = false)))) + + // Define the data schema + val schema = StructType( + Seq( + StructField("result", ArrayType(StringType, containsNull = true), nullable = true), + StructField("schema", ArrayType(StringType, containsNull = true), nullable = true), + StructField("jobRunId", StringType, nullable = true), + StructField("applicationId", StringType, nullable = true), + StructField("dataSourceName", StringType, nullable = true), + StructField("status", StringType, nullable = true), + StructField("error", StringType, nullable = true), + StructField("queryId", StringType, nullable = true), + StructField("queryText", StringType, nullable = true), + StructField("sessionId", StringType, nullable = true), + // number is not nullable + StructField("updateTime", LongType, nullable = false), + StructField("queryRunTime", LongType, nullable = true))) + + val resultToSave = result.toJSON.collect.toList + .map(_.replaceAll("'", "\\\\'").replaceAll("\"", "'")) + + val resultSchemaToSave = resultSchema.toJSON.collect.toList.map(_.replaceAll("\"", "'")) + val endTime = timeProvider.currentEpochMillis() + + // Create the data rows + val rows = Seq( + ( + resultToSave, + resultSchemaToSave, + sys.env.getOrElse("SERVERLESS_EMR_JOB_ID", "unknown"), + sys.env.getOrElse("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"), + dataSource, + "SUCCESS", + "", + 
queryId, + query, + sessionId, + endTime, + endTime - startTime)) + + // Create the DataFrame for data + spark.createDataFrame(rows).toDF(schema.fields.map(_.name): _*) + } + + def getFailedData( + spark: SparkSession, + dataSource: String, + error: String, + queryId: String, + query: String, + sessionId: String, + startTime: Long, + timeProvider: TimeProvider): DataFrame = { + + // Define the data schema + val schema = StructType( + Seq( + StructField("result", ArrayType(StringType, containsNull = true), nullable = true), + StructField("schema", ArrayType(StringType, containsNull = true), nullable = true), + StructField("jobRunId", StringType, nullable = true), + StructField("applicationId", StringType, nullable = true), + StructField("dataSourceName", StringType, nullable = true), + StructField("status", StringType, nullable = true), + StructField("error", StringType, nullable = true), + StructField("queryId", StringType, nullable = true), + StructField("queryText", StringType, nullable = true), + StructField("sessionId", StringType, nullable = true), + // number is not nullable + StructField("updateTime", LongType, nullable = false), + StructField("queryRunTime", LongType, nullable = true))) + + val endTime = timeProvider.currentEpochMillis() + + // Create the data rows + val rows = Seq( + ( + null, + null, + sys.env.getOrElse("SERVERLESS_EMR_JOB_ID", "unknown"), + sys.env.getOrElse("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"), + dataSource, + "FAILED", + error, + queryId, + query, + sessionId, + endTime, + endTime - startTime)) + + // Create the DataFrame for data + spark.createDataFrame(rows).toDF(schema.fields.map(_.name): _*) + } + + def isSuperset(input: String, mapping: String): Boolean = { + + /** + * Determines whether one JSON structure is a superset of another. + * + * This method checks if the `input` JSON structure contains all the fields and values present + * in the `mapping` JSON structure. The comparison is recursive and structure-sensitive, + * ensuring that nested objects and arrays are also compared accurately. + * + * Additionally, this method accommodates the edge case where boolean values in the JSON are + * represented as strings (e.g., "true" or "false" instead of true or false). This is handled + * by performing a case-insensitive comparison of string representations of boolean values. + * + * @param input + * The input JSON structure as a String. + * @param mapping + * The mapping JSON structure as a String. + * @return + * A Boolean value indicating whether the `input` JSON structure is a superset of the + * `mapping` JSON structure. 
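+   *
+   * For example, an existing index mapping that defines fields "a" and "b" is a superset of a
+   * required mapping that declares only field "a" with the same type, so the check passes.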
+ */ + def compareJson(inputJson: JsValue, mappingJson: JsValue): Boolean = { + (inputJson, mappingJson) match { + case (JsObject(inputFields), JsObject(mappingFields)) => + mappingFields.forall { case (key, value) => + inputFields + .get(key) + .exists(inputValue => compareJson(inputValue, value)) + } + case (JsArray(inputValues), JsArray(mappingValues)) => + mappingValues.forall(mappingValue => + inputValues.exists(inputValue => compareJson(inputValue, mappingValue))) + case (JsString(inputValue), JsString(mappingValue)) + if (inputValue.toLowerCase(Locale.ROOT) == "true" || + inputValue.toLowerCase(Locale.ROOT) == "false") && + (mappingValue.toLowerCase(Locale.ROOT) == "true" || + mappingValue.toLowerCase(Locale.ROOT) == "false") => + inputValue.toLowerCase(Locale.ROOT) == mappingValue.toLowerCase(Locale.ROOT) + case (JsBoolean(inputValue), JsString(mappingValue)) + if mappingValue.toLowerCase(Locale.ROOT) == "true" || + mappingValue.toLowerCase(Locale.ROOT) == "false" => + inputValue.toString.toLowerCase(Locale.ROOT) == mappingValue + .toLowerCase(Locale.ROOT) + case (JsString(inputValue), JsBoolean(mappingValue)) + if inputValue.toLowerCase(Locale.ROOT) == "true" || + inputValue.toLowerCase(Locale.ROOT) == "false" => + inputValue.toLowerCase(Locale.ROOT) == mappingValue.toString + .toLowerCase(Locale.ROOT) + case (inputValue, mappingValue) => + inputValue == mappingValue + } + } + + val inputJson = Json.parse(input) + val mappingJson = Json.parse(mapping) + + compareJson(inputJson, mappingJson) + } + + def checkAndCreateIndex(osClient: OSClient, resultIndex: String): Either[String, Unit] = { + // The enabled setting, which can be applied only to the top-level mapping definition and to object fields, + val mapping = + """{ + "dynamic": false, + "properties": { + "result": { + "type": "object", + "enabled": false + }, + "schema": { + "type": "object", + "enabled": false + }, + "jobRunId": { + "type": "keyword" + }, + "applicationId": { + "type": "keyword" + }, + "dataSourceName": { + "type": "keyword" + }, + "status": { + "type": "keyword" + }, + "queryId": { + "type": "keyword" + }, + "queryText": { + "type": "text" + }, + "sessionId": { + "type": "keyword" + }, + "updateTime": { + "type": "date", + "format": "strict_date_time||epoch_millis" + }, + "error": { + "type": "text" + }, + "queryRunTime" : { + "type" : "long" + } + } + }""".stripMargin + + try { + val existingSchema = osClient.getIndexMetadata(resultIndex) + if (!isSuperset(existingSchema, mapping)) { + Left(s"The mapping of $resultIndex is incorrect.") + } else { + Right(()) + } + } catch { + case e: IllegalStateException + if e.getCause().getMessage().contains("index_not_found_exception") => + handleIndexNotFoundException(osClient, resultIndex, mapping) + case e: Exception => + val error = s"Failed to verify existing mapping: ${e.getMessage}" + logError(error, e) + Left(error) + } + } + + def handleIndexNotFoundException( + osClient: OSClient, + resultIndex: String, + mapping: String): Either[String, Unit] = { + try { + logInfo(s"create $resultIndex") + osClient.createIndex(resultIndex, mapping) + logInfo(s"create $resultIndex successfully") + Right(()) + } catch { + case e: Exception => + val error = s"Failed to create result index $resultIndex" + logError(error, e) + Left(error) + } + } + + def executeQuery( + spark: SparkSession, + query: String, + dataSource: String, + queryId: String, + sessionId: String): DataFrame = { + // Execute SQL query + val startTime = System.currentTimeMillis() + val result: DataFrame = 
spark.sql(query) + // Get Data + getFormattedData( + result, + spark, + dataSource, + queryId, + query, + sessionId, + startTime, + currentTimeProvider) + } + + private def handleQueryException( + e: Exception, + message: String, + spark: SparkSession, + dataSource: String, + query: String, + queryId: String, + sessionId: String): String = { + val error = s"$message: ${e.getMessage}" + logError(error, e) + error + } + + def getRootCause(e: Throwable): Throwable = { + if (e.getCause == null) e + else getRootCause(e.getCause) + } + + def processQueryException( + ex: Exception, + spark: SparkSession, + dataSource: String, + query: String, + queryId: String, + sessionId: String): String = { + getRootCause(ex) match { + case r: ParseException => + handleQueryException(r, "Syntax error", spark, dataSource, query, queryId, sessionId) + case r: AmazonS3Exception => + handleQueryException( + r, + "Fail to read data from S3. Cause", + spark, + dataSource, + query, + queryId, + sessionId) + case r: AnalysisException => + handleQueryException( + r, + "Fail to analyze query. Cause", + spark, + dataSource, + query, + queryId, + sessionId) + case r: SparkException => + handleQueryException( + r, + "Fail to run query. Cause", + spark, + dataSource, + query, + queryId, + sessionId) + case r: Exception => + handleQueryException( + r, + "Fail to write result, cause", + spark, + dataSource, + query, + queryId, + sessionId) + } + } +} diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala new file mode 100644 index 000000000..c714682c0 --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -0,0 +1,637 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import java.time.Instant +import java.util.concurrent.{ScheduledExecutorService, ThreadPoolExecutor} + +import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future, TimeoutException} +import scala.concurrent.duration.{Duration, MINUTES, SECONDS} + +import org.opensearch.action.get.GetResponse +import org.opensearch.common.Strings +import org.opensearch.flint.app.{FlintCommand, FlintInstance} +import org.opensearch.flint.app.FlintCommand.serialize +import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} + +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging +import org.apache.spark.sql.FlintJob.{currentTimeProvider, getFailedData, writeData} +import org.apache.spark.sql.FlintREPL.executeQuery +import org.apache.spark.sql.flint.config.FlintSparkConf +import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType} +import org.apache.spark.sql.util.{DefaultShutdownHookManager, ShutdownHookManagerTrait} +import org.apache.spark.util.{ShutdownHookManager, ThreadUtils} + +/** + * Spark SQL Application entrypoint + * + * @param args(0) + * sql query + * @param args(1) + * opensearch result index name + * @param args(2) + * opensearch connection values required for flint-integration jar. host, port, scheme, auth, + * region respectively. 
+ * @return + * write sql query result to given opensearch index + */ +object FlintREPL extends Logging with FlintJobExecutor { + + private val HEARTBEAT_INTERVAL_MILLIS = 60000L + private val INACTIVITY_LIMIT_MILLIS = 30 * 60 * 1000 + private val MAPPING_CHECK_TIMEOUT = Duration(1, MINUTES) + private val QUERY_EXECUTION_TIMEOUT = Duration(10, MINUTES) + private val QUERY_WAIT_TIMEOUT_MILLIS = 10 * 60 * 1000 + + def update(flintCommand: FlintCommand, updater: OpenSearchUpdater): Unit = { + updater.update(flintCommand.statementId, serialize(flintCommand)) + } + + def main(args: Array[String]) { + val Array(query, resultIndex) = args + if (Strings.isNullOrEmpty(resultIndex)) { + throw new IllegalArgumentException("resultIndex is not set") + } + + // init SparkContext + val conf: SparkConf = createSparkConf() + val dataSource = conf.get("spark.flint.datasource.name", "unknown") + val wait = conf.get("spark.flint.job.type", "continue") + // we don't allow default value for sessionIndex and sessionId. Throw exception if key not found. + val sessionIndex: Option[String] = Option(conf.get("spark.flint.job.requestIndex", null)) + val sessionId: Option[String] = Option(conf.get("spark.flint.job.sessionId", null)) + + if (sessionIndex.isEmpty) { + throw new IllegalArgumentException("spark.flint.job.requestIndex is not set") + } + if (sessionId.isEmpty) { + throw new IllegalArgumentException("spark.flint.job.sessionId is not set") + } + + val spark = createSparkSession(conf) + val osClient = new OSClient(FlintSparkConf().flintOptions()) + val jobId = sys.env.getOrElse("SERVERLESS_EMR_JOB_ID", "unknown") + val applicationId = sys.env.getOrElse("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown") + + if (wait.equalsIgnoreCase("streaming")) { + logInfo(s"""streaming query ${query}""") + val result = executeQuery(spark, query, dataSource, "", "") + writeData(result, resultIndex) + spark.streams.awaitAnyTermination() + } else { + try { + val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex.get) + createShutdownHook(flintSessionIndexUpdater, osClient, sessionIndex.get, sessionId.get); + + queryLoop( + resultIndex, + dataSource, + sessionIndex.get, + sessionId.get, + spark, + osClient, + jobId, + applicationId, + flintSessionIndexUpdater) + } finally { + spark.stop() + } + } + } + + private def queryLoop( + resultIndex: String, + dataSource: String, + sessionIndex: String, + sessionId: String, + spark: SparkSession, + osClient: OSClient, + jobId: String, + applicationId: String, + flintSessionIndexUpdater: OpenSearchUpdater): Unit = { + // 1 thread for updating heart beat + val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor("flint-repl", 1) + implicit val executionContext = ExecutionContext.fromExecutor(threadPool) + + try { + val futureMappingCheck = Future { + checkAndCreateIndex(osClient, resultIndex) + } + + setupFlintJob(applicationId, jobId, sessionId, flintSessionIndexUpdater, sessionIndex) + + // update heart beat every 30 seconds + // OpenSearch triggers recovery after 1 minute outdated heart beat + createHeartBeatUpdater( + HEARTBEAT_INTERVAL_MILLIS, + flintSessionIndexUpdater, + sessionId: String, + threadPool, + osClient, + sessionIndex) + + var lastActivityTime = Instant.now().toEpochMilli() + var verificationResult: VerificationResult = NotVerified + + while (Instant.now().toEpochMilli() - lastActivityTime <= INACTIVITY_LIMIT_MILLIS) { + logInfo(s"""read from ${sessionIndex}""") + val flintReader: FlintReader = + createQueryReader(osClient, applicationId, sessionId, 
sessionIndex, dataSource) + + try { + val result: (Long, VerificationResult) = processCommands( + flintReader, + spark, + dataSource, + resultIndex, + sessionId, + futureMappingCheck, + verificationResult, + executionContext, + flintSessionIndexUpdater, + lastActivityTime) + + val (updatedLastActivityTime, updatedVerificationResult) = result + + lastActivityTime = updatedLastActivityTime + verificationResult = updatedVerificationResult + } finally { + flintReader.close() + } + + Thread.sleep(100) + } + + } catch { + case e: Exception => + handleSessionError(e, applicationId, jobId, sessionId, flintSessionIndexUpdater) + } finally { + if (threadPool != null) { + threadPool.shutdown() + } + } + } + + private def setupFlintJob( + applicationId: String, + jobId: String, + sessionId: String, + flintSessionIndexUpdater: OpenSearchUpdater, + sessionIndex: String): Unit = { + val flintJob = + new FlintInstance(applicationId, jobId, sessionId, "running", System.currentTimeMillis()) + flintSessionIndexUpdater.upsert(sessionId, FlintInstance.serialize(flintJob)) + logInfo( + s"""Updated job: {"jobid": ${flintJob.jobId}, "sessionId": ${flintJob.sessionId}} from $sessionIndex""") + } + + private def handleSessionError( + e: Exception, + applicationId: String, + jobId: String, + sessionId: String, + flintSessionIndexUpdater: OpenSearchUpdater): Unit = { + val error = s"Unexpected error: ${e.getMessage}" + logError(error, e) + val flintJob = new FlintInstance( + applicationId, + jobId, + sessionId, + "fail", + System.currentTimeMillis(), + Some(error)) + flintSessionIndexUpdater.upsert(sessionId, FlintInstance.serialize(flintJob)) + } + + /** + * handling the case where a command's execution fails, updates the flintCommand with the error + * and failure status, and then delegates to the second method for actual DataFrame creation + * @param spark + * spark session + * @param dataSource + * data source + * @param error + * error message + * @param flintCommand + * flint command + * @param sessionId + * session id + * @param startTime + * start time + * @return + * failed data frame + */ + def handleCommandFailureAndGetFailedData( + spark: SparkSession, + dataSource: String, + error: String, + flintCommand: FlintCommand, + sessionId: String, + startTime: Long): DataFrame = { + flintCommand.fail() + flintCommand.error = Some(error) + super.getFailedData( + spark, + dataSource, + error, + flintCommand.queryId, + flintCommand.query, + sessionId, + startTime, + currentTimeProvider) + } + + def processQueryException( + ex: Exception, + spark: SparkSession, + dataSource: String, + flintCommand: FlintCommand, + sessionId: String): String = { + val error = super.processQueryException( + ex, + spark, + dataSource, + flintCommand.query, + flintCommand.queryId, + sessionId) + flintCommand.fail() + flintCommand.error = Some(error) + error + } + + private def processCommands( + flintReader: FlintReader, + spark: SparkSession, + dataSource: String, + resultIndex: String, + sessionId: String, + futureMappingCheck: Future[Either[String, Unit]], + recordedVerificationResult: VerificationResult, + executionContext: ExecutionContextExecutor, + flintSessionIndexUpdater: OpenSearchUpdater, + recordedLastActivityTime: Long): (Long, VerificationResult) = { + var lastActivityTime = recordedLastActivityTime + var verificationResult = recordedVerificationResult + + while (flintReader.hasNext) { + lastActivityTime = Instant.now().toEpochMilli() + val flintCommand = processCommandInitiation(flintReader, flintSessionIndexUpdater) + + 
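+      // Run each statement in its own Spark job group keyed by queryId so that a statement which
+      // exceeds the execution timeout can be cancelled via cancelJobGroup without stopping the REPL.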
spark.sparkContext.setJobGroup( + flintCommand.queryId, + "Job group for " + flintCommand.queryId, + interruptOnCancel = true) + val (dataToWrite, returnedVerificationResult) = processStatementOnVerification( + recordedVerificationResult, + spark, + flintCommand, + dataSource, + sessionId, + executionContext, + futureMappingCheck, + resultIndex) + + verificationResult = returnedVerificationResult + finalizeCommand(dataToWrite, flintCommand, resultIndex, flintSessionIndexUpdater) + } + + // return tuple indicating if still active and mapping verification result + (lastActivityTime, verificationResult) + } + + /** + * finalize command after processing + * + * @param dataToWrite + * data to write + * @param flintCommand + * flint command + * @param resultIndex + * result index + * @param flintSessionIndexUpdater + * flint session index updater + */ + private def finalizeCommand( + dataToWrite: Option[DataFrame], + flintCommand: FlintCommand, + resultIndex: String, + flintSessionIndexUpdater: OpenSearchUpdater): Unit = { + try { + dataToWrite.foreach(df => writeData(df, resultIndex)) + if (flintCommand.isRunning()) { + // we have set failed state in exception handling + flintCommand.complete() + } + update(flintCommand, flintSessionIndexUpdater) + } catch { + // e.g., maybe due to authentication service connection issue + case e: Exception => + val error = s"""Fail to write result of ${flintCommand}, cause: ${e.getMessage}""" + logError(error, e) + flintCommand.fail() + update(flintCommand, flintSessionIndexUpdater) + } + } + private def processStatementOnVerification( + recordedVerificationResult: VerificationResult, + spark: SparkSession, + flintCommand: FlintCommand, + dataSource: String, + sessionId: String, + executionContext: ExecutionContextExecutor, + futureMappingCheck: Future[Either[String, Unit]], + resultIndex: String): (Option[DataFrame], VerificationResult) = { + val startTime: Long = System.currentTimeMillis() + var verificationResult = recordedVerificationResult + var dataToWrite: Option[DataFrame] = None + try { + verificationResult match { + case NotVerified => + try { + val mappingCheckResult = + ThreadUtils.awaitResult(futureMappingCheck, MAPPING_CHECK_TIMEOUT) + // time out after 10 minutes + val result = executeQueryAsync( + spark, + flintCommand, + dataSource, + sessionId, + executionContext, + startTime) + + dataToWrite = Some(mappingCheckResult match { + case Right(_) => + verificationResult = VerifiedWithoutError + result + case Left(error) => + verificationResult = VerifiedWithError(error) + handleCommandFailureAndGetFailedData( + spark, + dataSource, + error, + flintCommand, + sessionId, + startTime) + }) + } catch { + case e: TimeoutException => + val error = s"Getting the mapping of index $resultIndex timed out" + logError(error, e) + dataToWrite = Some( + handleCommandFailureAndGetFailedData( + spark, + dataSource, + error, + flintCommand, + sessionId, + startTime)) + } + case VerifiedWithError(err) => + dataToWrite = Some( + handleCommandFailureAndGetFailedData( + spark, + dataSource, + err, + flintCommand, + sessionId, + startTime)) + case VerifiedWithoutError => + dataToWrite = Some( + executeQueryAsync( + spark, + flintCommand, + dataSource, + sessionId, + executionContext, + startTime)) + } + + logDebug(s"""command complete: ${flintCommand}""") + } catch { + case e: TimeoutException => + val error = s"Executing ${flintCommand.query} timed out" + spark.sparkContext.cancelJobGroup(flintCommand.queryId) + logError(error, e) + dataToWrite = Some( + 
handleCommandFailureAndGetFailedData( + spark, + dataSource, + error, + flintCommand, + sessionId, + startTime)) + case e: Exception => + val error = processQueryException(e, spark, dataSource, flintCommand, sessionId) + dataToWrite = Some( + handleCommandFailureAndGetFailedData( + spark, + dataSource, + error, + flintCommand, + sessionId, + startTime)) + } + (dataToWrite, verificationResult) + } + + def executeQueryAsync( + spark: SparkSession, + flintCommand: FlintCommand, + dataSource: String, + sessionId: String, + executionContext: ExecutionContextExecutor, + startTime: Long): DataFrame = { + if (Instant.now().toEpochMilli() - flintCommand.submitTime > QUERY_WAIT_TIMEOUT_MILLIS) { + handleCommandFailureAndGetFailedData( + spark, + dataSource, + "wait timeout", + flintCommand, + sessionId, + startTime) + } else { + val futureQueryExecution = Future { + executeQuery(spark, flintCommand.query, dataSource, flintCommand.queryId, sessionId) + }(executionContext) + + // time out after 10 minutes + ThreadUtils.awaitResult(futureQueryExecution, QUERY_EXECUTION_TIMEOUT) + } + } + private def processCommandInitiation( + flintReader: FlintReader, + flintSessionIndexUpdater: OpenSearchUpdater): FlintCommand = { + val command = flintReader.next() + logDebug(s"raw command: $command") + val flintCommand = FlintCommand.deserialize(command) + logDebug(s"command: $flintCommand") + flintCommand.running() + logDebug(s"command running: $flintCommand") + update(flintCommand, flintSessionIndexUpdater) + flintCommand + } + + private def createQueryReader( + osClient: OSClient, + applicationId: String, + sessionId: String, + sessionIndex: String, + dataSource: String) = { + // all state in index are in lower case + // + // add application to deal with emr-s deployment: + // Should the EMR-S application be terminated or halted (e.g., during deployments), OpenSearch associates each + // query to ascertain the latest EMR-S application ID. While existing EMR-S jobs continue executing queries sharing + // the same applicationId and sessionId, new requests with differing application IDs will not be processed by the + // current EMR-S job. Gradually, the queue with the same applicationId and sessionId diminishes, enabling the + // job to self-terminate due to inactivity timeout. Following the termination of all jobs, the CP can gracefully + // shut down the application. The session has a phantom state where the old application set the state to dead and + // then the new application job sets it to running through heartbeat. Also, application id can help in the case + // of job restart where the job id is different but application id is the same. 
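+    // The query below matches only "statement" documents in "waiting" state that belong to this
+    // application, session and data source; the reader returns them sorted by submitTime in
+    // ascending order, so statements are executed in submission order.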
+ val dsl = + s"""{ + | "bool": { + | "must": [ + | { + | "term": { + | "type": "statement" + | } + | }, + | { + | "term": { + | "state": "waiting" + | } + | }, + | { + | "term": { + | "applicationId": "$applicationId" + | } + | }, + | { + | "term": { + | "sessionId": "$sessionId" + | } + | }, + | { + | "term": { + | "dataSourceName": "$dataSource" + | } + | } + | ] + | } + |}""".stripMargin + val flintReader = osClient.createReader(sessionIndex, dsl, "submitTime") + flintReader + } + + def createShutdownHook( + flintSessionIndexUpdater: OpenSearchUpdater, + osClient: OSClient, + sessionIndex: String, + sessionId: String, + shutdownHookManager: ShutdownHookManagerTrait = DefaultShutdownHookManager): Unit = { + + shutdownHookManager.addShutdownHook(() => { + logInfo("Shutting down REPL") + + val getResponse = osClient.getDoc(sessionIndex, sessionId) + if (!getResponse.isExists()) { + return + } + + val source = getResponse.getSourceAsMap + if (source == null) { + return + } + + val state = Option(source.get("state")).map(_.asInstanceOf[String]) + if (state.isDefined && state.get != "dead" && state.get != "fail") { + updateFlintInstanceBeforeShutdown( + source, + getResponse, + flintSessionIndexUpdater, + sessionId) + } + }) + } + + private def updateFlintInstanceBeforeShutdown( + source: java.util.Map[String, AnyRef], + getResponse: GetResponse, + flintSessionIndexUpdater: OpenSearchUpdater, + sessionId: String): Unit = { + val flintInstant = new FlintInstance( + source.get("applicationId").asInstanceOf[String], + source.get("jobId").asInstanceOf[String], + source.get("sessionId").asInstanceOf[String], + "dead", + source.get("lastUpdateTime").asInstanceOf[Long], + Option(source.get("error").asInstanceOf[String])) + + flintSessionIndexUpdater.updateIf( + sessionId, + FlintInstance.serialize(flintInstant), + getResponse.getSeqNo, + getResponse.getPrimaryTerm) + } + + /** + * Create a new thread to update the last update time of the flint instance. + * @param currentInterval + * the interval of updating the last update time. Unit is millisecond. + * @param flintSessionUpdater + * the updater of the flint instance. + * @param sessionId + * the session id of the flint instance. + * @param threadPool + * the thread pool. + * @param osClient + * the OpenSearch client. 
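+   * @param sessionIndex
+   *   the OpenSearch index that stores the session document.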
+ */ + def createHeartBeatUpdater( + currentInterval: Long, + flintSessionUpdater: OpenSearchUpdater, + sessionId: String, + threadPool: ScheduledExecutorService, + osClient: OSClient, + sessionIndex: String): Unit = { + + threadPool.scheduleAtFixedRate( + new Runnable { + override def run(): Unit = { + try { + val getResponse = osClient.getDoc(sessionIndex, sessionId) + if (getResponse.isExists()) { + val source = getResponse.getSourceAsMap + val flintInstant: FlintInstance = new FlintInstance( + source.get("applicationId").asInstanceOf[String], + source.get("jobId").asInstanceOf[String], + source.get("sessionId").asInstanceOf[String], + "running", + source.get("lastUpdateTime").asInstanceOf[Long], + Option(source.get("error").asInstanceOf[String])) + flintSessionUpdater.updateIf( + sessionId, + FlintInstance.serialize(flintInstant), + getResponse.getSeqNo, + getResponse.getPrimaryTerm) + } + // do nothing if the session doc does not exist + } catch { + // maybe due to invalid sequence number or primary term + case e: Exception => + logWarning( + s"""Fail to update the last update time of the flint instance ${sessionId}""", + e) + } + } + }, + 0L, + currentInterval, + java.util.concurrent.TimeUnit.MILLISECONDS) + } +} diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala index cf2a5860d..f3f118702 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala @@ -5,19 +5,40 @@ package org.apache.spark.sql -import org.opensearch.client.RequestOptions +import java.io.IOException +import java.util.ArrayList +import java.util.Locale + +import org.opensearch.action.get.{GetRequest, GetResponse} +import org.opensearch.client.{RequestOptions, RestHighLevelClient} import org.opensearch.client.indices.{CreateIndexRequest, GetIndexRequest, GetIndexResponse} import org.opensearch.client.indices.CreateIndexRequest -import org.opensearch.common.xcontent.XContentType -import org.opensearch.flint.core.{FlintClientBuilder, FlintOptions} +import org.opensearch.common.Strings +import org.opensearch.common.settings.Settings +import org.opensearch.common.xcontent.{NamedXContentRegistry, XContentParser, XContentType} +import org.opensearch.common.xcontent.DeprecationHandler.IGNORE_DEPRECATIONS +import org.opensearch.flint.core.{FlintClient, FlintClientBuilder, FlintOptions} +import org.opensearch.flint.core.storage.{FlintReader, OpenSearchScrollReader, OpenSearchUpdater} +import org.opensearch.index.query.{AbstractQueryBuilder, MatchAllQueryBuilder, QueryBuilder} +import org.opensearch.plugins.SearchPlugin +import org.opensearch.search.SearchModule +import org.opensearch.search.builder.SearchSourceBuilder +import org.opensearch.search.sort.SortOrder import org.apache.spark.internal.Logging class OSClient(val flintOptions: FlintOptions) extends Logging { + val flintClient: FlintClient = FlintClientBuilder.build(flintOptions) + /** + * {@link NamedXContentRegistry} from {@link SearchModule} used for construct {@link + * QueryBuilder} from DSL query string. 
+ */ + private val xContentRegistry: NamedXContentRegistry = new NamedXContentRegistry( + new SearchModule(Settings.builder.build, new ArrayList[SearchPlugin]).getNamedXContents) def getIndexMetadata(osIndexName: String): String = { - using(FlintClientBuilder.build(flintOptions).createClient()) { client => + using(flintClient.createClient()) { client => val request = new GetIndexRequest(osIndexName) try { val response = client.indices.get(request, RequestOptions.DEFAULT) @@ -45,7 +66,7 @@ class OSClient(val flintOptions: FlintOptions) extends Logging { def createIndex(osIndexName: String, mapping: String): Unit = { logInfo(s"create $osIndexName") - using(FlintClientBuilder.build(flintOptions).createClient()) { client => + using(flintClient.createClient()) { client => val request = new CreateIndexRequest(osIndexName) request.mapping(mapping, XContentType.JSON) @@ -82,4 +103,41 @@ class OSClient(val flintOptions: FlintOptions) extends Logging { } } + def createUpdater(indexName: String): OpenSearchUpdater = + new OpenSearchUpdater(indexName, flintClient) + + def getDoc(osIndexName: String, id: String): GetResponse = { + using(flintClient.createClient()) { client => + try { + val request = new GetRequest(osIndexName, id) + client.get(request, RequestOptions.DEFAULT) + } catch { + case e: Exception => + throw new IllegalStateException( + String.format( + Locale.ROOT, + "Failed to retrieve doc %s from index %s", + osIndexName, + id), + e) + } + } + } + + def createReader(indexName: String, query: String, sort: String): FlintReader = try { + var queryBuilder: QueryBuilder = new MatchAllQueryBuilder + if (!Strings.isNullOrEmpty(query)) { + val parser = + XContentType.JSON.xContent.createParser(xContentRegistry, IGNORE_DEPRECATIONS, query) + queryBuilder = AbstractQueryBuilder.parseInnerQueryBuilder(parser) + } + new OpenSearchScrollReader( + flintClient.createClient(), + indexName, + new SearchSourceBuilder().query(queryBuilder).sort(sort, SortOrder.ASC), + flintOptions) + } catch { + case e: IOException => + throw new RuntimeException(e) + } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/VerificationResult.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/VerificationResult.scala new file mode 100644 index 000000000..c4e84e35a --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/VerificationResult.scala @@ -0,0 +1,12 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +sealed trait VerificationResult + +case object NotVerified extends VerificationResult +case object VerifiedWithoutError extends VerificationResult +case class VerifiedWithError(error: String) extends VerificationResult diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/util/DefaultShutdownHookManager.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/DefaultShutdownHookManager.scala new file mode 100644 index 000000000..770bc7e54 --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/DefaultShutdownHookManager.scala @@ -0,0 +1,14 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql.util + +import org.apache.spark.util.ShutdownHookManager + +object DefaultShutdownHookManager extends ShutdownHookManagerTrait { + override def addShutdownHook(hook: () => Unit): AnyRef = { + ShutdownHookManager.addShutdownHook(hook) + } +} diff --git 
a/spark-sql-application/src/main/scala/org/apache/spark/sql/util/RealTimeProvider.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/RealTimeProvider.scala new file mode 100644 index 000000000..dddb30b2b --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/RealTimeProvider.scala @@ -0,0 +1,10 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql.util + +class RealTimeProvider extends TimeProvider { + override def currentEpochMillis(): Long = System.currentTimeMillis() +} diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/util/ShutdownHookManagerTrait.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/ShutdownHookManagerTrait.scala new file mode 100644 index 000000000..c1615d647 --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/ShutdownHookManagerTrait.scala @@ -0,0 +1,13 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql.util + +/** + * A trait that allows injecting a shutdown hook manager. + */ +trait ShutdownHookManagerTrait { + def addShutdownHook(hook: () => Unit): AnyRef +} diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/util/TimeProvider.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/TimeProvider.scala new file mode 100644 index 000000000..0703875d2 --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/TimeProvider.scala @@ -0,0 +1,14 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql.util + +/** + * A trait that provides the current epoch time in milliseconds, making the current time + * provider mockable.
+ */ +trait TimeProvider { + def currentEpochMillis(): Long +} diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala index b891be0e1..e6b8d8289 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala @@ -5,12 +5,10 @@ package org.apache.spark.sql -import org.scalatest.matchers.should.Matchers - import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ -class FlintJobTest extends SparkFunSuite with Matchers { +class FlintJobTest extends SparkFunSuite with JobMatchers { val spark = SparkSession.builder().appName("Test").master("local").getOrCreate() @@ -35,7 +33,16 @@ class FlintJobTest extends SparkFunSuite with Matchers { StructField("applicationId", StringType, nullable = true), StructField("dataSourceName", StringType, nullable = true), StructField("status", StringType, nullable = true), - StructField("error", StringType, nullable = true))) + StructField("error", StringType, nullable = true), + StructField("queryId", StringType, nullable = true), + StructField("queryText", StringType, nullable = true), + StructField("sessionId", StringType, nullable = true), + StructField("updateTime", LongType, nullable = false), + StructField("queryRunTime", LongType, nullable = false))) + + val currentTime: Long = System.currentTimeMillis() + val queryRunTime: Long = 3000L + val expectedRows = Seq( Row( Array( @@ -49,20 +56,29 @@ class FlintJobTest extends SparkFunSuite with Matchers { "unknown", dataSourceName, "SUCCESS", - "")) + "", + "10", + "select 1", + "20", + currentTime, + queryRunTime)) val expected: DataFrame = spark.createDataFrame(spark.sparkContext.parallelize(expectedRows), expectedSchema) // Compare the result - val result = FlintJob.getFormattedData(input, spark, dataSourceName) + val result = + FlintJob.getFormattedData( + input, + spark, + dataSourceName, + "10", + "select 1", + "20", + currentTime - queryRunTime, + new MockTimeProvider(currentTime)) assertEqualDataframe(expected, result) } - def assertEqualDataframe(expected: DataFrame, result: DataFrame): Unit = { - assert(expected.schema === result.schema) - assert(expected.collect() === result.collect()) - } - test("test isSuperset") { // note in input false has enclosed double quotes, while mapping just has false val input = diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala new file mode 100644 index 000000000..96ea91ebd --- /dev/null +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -0,0 +1,173 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import java.util.concurrent.{ScheduledExecutorService, ScheduledFuture, TimeUnit} + +import scala.collection.JavaConverters._ + +import org.mockito.{ArgumentMatchersSugar, IdiomaticMockito} +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.opensearch.action.get.GetResponse +import org.opensearch.flint.app.FlintCommand +import org.opensearch.flint.core.storage.OpenSearchUpdater +import org.scalatestplus.mockito.MockitoSugar + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.{ArrayType, LongType, NullType, StringType, StructField, StructType} +import 
org.apache.spark.sql.util.ShutdownHookManagerTrait + +class FlintREPLTest + extends SparkFunSuite + with MockitoSugar + with ArgumentMatchersSugar + with JobMatchers { + val spark = SparkSession.builder().appName("Test").master("local").getOrCreate() + // Use a type alias and casting to bypass the type-checking error on ScheduledFuture[_]. + type AnyScheduledFuture = ScheduledFuture[_] + + test("createHeartBeatUpdater should update heartbeat correctly") { + // Mocks + val flintSessionUpdater = mock[OpenSearchUpdater] + val osClient = mock[OSClient] + val threadPool = mock[ScheduledExecutorService] + val getResponse = mock[GetResponse] + val scheduledFutureRaw: ScheduledFuture[_] = mock[ScheduledFuture[_]] + + // Mock behaviors + when(osClient.getDoc(*, *)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(true) + when(getResponse.getSourceAsMap).thenReturn( + Map[String, Object]( + "applicationId" -> "app1", + "jobId" -> "job1", + "sessionId" -> "session1", + "lastUpdateTime" -> java.lang.Long.valueOf(12345L), + "error" -> "someError").asJava) + when(getResponse.getSeqNo).thenReturn(0L) + when(getResponse.getPrimaryTerm).thenReturn(0L) + // When the task is scheduled, run the runnable immediately exactly once; subsequent ticks are no-ops. + when( + threadPool.scheduleAtFixedRate( + any[Runnable], + eqTo(0), + *, + eqTo(java.util.concurrent.TimeUnit.MILLISECONDS))) + .thenAnswer((invocation: InvocationOnMock) => { + val runnable = invocation.getArgument[Runnable](0) + runnable.run() + scheduledFutureRaw + }) + + // Invoke the method + FlintREPL.createHeartBeatUpdater( + 1000L, + flintSessionUpdater, + "session1", + threadPool, + osClient, + "sessionIndex") + + // Verifications + verify(osClient, atLeastOnce()).getDoc("sessionIndex", "session1") + verify(flintSessionUpdater, atLeastOnce()).updateIf(eqTo("session1"), *, eqTo(0L), eqTo(0L)) + } + + test("createShutdownHook adds a shutdown hook and updates FlintInstance if conditions are met") { + val flintSessionIndexUpdater = mock[OpenSearchUpdater] + val osClient = mock[OSClient] + val getResponse = mock[GetResponse] + val sessionIndex = "testIndex" + val sessionId = "testSessionId" + + when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(true) + + when(getResponse.getSourceAsMap).thenReturn( + Map[String, Object]( + "applicationId" -> "app1", + "jobId" -> "job1", + "sessionId" -> "session1", + "state" -> "running", + "lastUpdateTime" -> java.lang.Long.valueOf(12345L), + "error" -> "someError").asJava) + + val mockShutdownHookManager = new ShutdownHookManagerTrait { + override def addShutdownHook(hook: () => Unit): AnyRef = { + hook() // execute the hook immediately + new Object() // return a dummy AnyRef as per the method signature + } + } + + // Here, we're injecting our mockShutdownHookManager into the method + FlintREPL.createShutdownHook( + flintSessionIndexUpdater, + osClient, + sessionIndex, + sessionId, + mockShutdownHookManager) + + verify(flintSessionIndexUpdater).updateIf(*, *, *, *) + } + + test("Test getFailedData method") { + // Define expected dataframe + val dataSourceName = "myGlueS3" + val expectedSchema = StructType( + Seq( + StructField("result", NullType, nullable = true), + StructField("schema", NullType, nullable = true), + StructField("jobRunId", StringType, nullable = true), + StructField("applicationId", StringType, nullable = true), + StructField("dataSourceName", StringType, nullable = true), + StructField("status", StringType, nullable = true), +
StructField("error", StringType, nullable = true), + StructField("queryId", StringType, nullable = true), + StructField("queryText", StringType, nullable = true), + StructField("sessionId", StringType, nullable = true), + StructField("updateTime", LongType, nullable = false), + StructField("queryRunTime", LongType, nullable = false))) + + val currentTime: Long = System.currentTimeMillis() + val queryRunTime: Long = 3000L + + val error = "some error" + val expectedRows = Seq( + Row( + null, + null, + "unknown", + "unknown", + dataSourceName, + "FAILED", + error, + "10", + "select 1", + "20", + currentTime, + queryRunTime)) + val expected: DataFrame = + spark.createDataFrame(spark.sparkContext.parallelize(expectedRows), expectedSchema) + + val flintCommand = new FlintCommand("failed", "select 1", "30", "10", currentTime, None) + + FlintREPL.currentTimeProvider = new MockTimeProvider(currentTime) + + // Compare the result + val result = + FlintREPL.handleCommandFailureAndGetFailedData( + spark, + dataSourceName, + error, + flintCommand, + "20", + currentTime - queryRunTime) + assertEqualDataframe(expected, result) + assert("failed" == flintCommand.state) + assert(error == flintCommand.error.get) + } +} diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/JobMatchers.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/JobMatchers.scala new file mode 100644 index 000000000..61aebad60 --- /dev/null +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/JobMatchers.scala @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.scalatest.matchers.should.Matchers + +trait JobMatchers extends Matchers { + def assertEqualDataframe(expected: DataFrame, result: DataFrame): Unit = { + assert(expected.schema === result.schema) + assert(expected.collect() === result.collect()) + } +} diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/MockTimeProvider.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/MockTimeProvider.scala new file mode 100644 index 000000000..d987577f8 --- /dev/null +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/MockTimeProvider.scala @@ -0,0 +1,12 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.util.TimeProvider + +class MockTimeProvider(fixedTime: Long) extends TimeProvider { + override def currentEpochMillis(): Long = fixedTime +} diff --git a/spark-sql-application/src/test/scala/org/opensearch/sql/SQLJobTest.scala b/spark-sql-application/src/test/scala/org/opensearch/sql/SQLJobTest.scala index 063c76c4d..bd672f390 100644 --- a/spark-sql-application/src/test/scala/org/opensearch/sql/SQLJobTest.scala +++ b/spark-sql-application/src/test/scala/org/opensearch/sql/SQLJobTest.scala @@ -5,13 +5,11 @@ package org.opensearch.sql -import org.scalatest.matchers.should.Matchers - import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{DataFrame, Row, SparkSession} +import org.apache.spark.sql.{DataFrame, JobMatchers, Row, SparkSession} import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructField, StructType} -class SQLJobTest extends SparkFunSuite with Matchers { +class SQLJobTest extends SparkFunSuite with JobMatchers { val spark = SparkSession.builder().appName("Test").master("local").getOrCreate() @@ -50,9 +48,4 @@ class SQLJobTest extends SparkFunSuite with Matchers { val 
result = SQLJob.getFormattedData(input, spark) assertEqualDataframe(expected, result) } - - def assertEqualDataframe(expected: DataFrame, result: DataFrame): Unit = { - assert(expected.schema === result.schema) - assert(expected.collect() === result.collect()) - } }
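
Reviewer note (not part of the patch): below is a minimal, hypothetical sketch of how the pieces added above could be wired together, using only the signatures visible in this diff (OSClient, OSClient.createUpdater, OSClient.createReader, FlintREPL.createHeartBeatUpdater). The index name, session id, query, sort field, heartbeat interval, and the FlintOptions construction from an empty map are illustrative assumptions, not values defined by this change; FlintReader's hasNext/next/close usage is likewise assumed from flint-core.

// Sketch only. Assumes FlintOptions accepts a java.util.Map of options and that
// FlintReader exposes hasNext/next/close; verify both against flint-core.
import java.util.Collections
import java.util.concurrent.Executors

import org.opensearch.flint.core.FlintOptions
import org.opensearch.flint.core.storage.FlintReader

import org.apache.spark.sql.{FlintREPL, OSClient}

object ReplWiringSketch {
  def main(args: Array[String]): Unit = {
    // Placeholder values; a real job would receive these through configuration.
    val sessionIndex = "flint-query-submission"
    val sessionId = "session-123"

    val osClient = new OSClient(new FlintOptions(Collections.emptyMap[String, String]()))

    // Read (R): poll statements for this session, oldest first. createReader parses the
    // string as an inner query (parseInnerQueryBuilder), so no {"query": ...} wrapper is needed.
    val reader: FlintReader = osClient.createReader(
      sessionIndex,
      s"""{"term": {"sessionId": "$sessionId"}}""",
      "submitTime")

    // Publish (P): the updater writes serialized session/statement state back to the index.
    val flintSessionUpdater = osClient.createUpdater(sessionIndex)

    // Keep the session document's lastUpdateTime fresh while the REPL loops (interval is illustrative).
    val threadPool = Executors.newSingleThreadScheduledExecutor()
    FlintREPL.createHeartBeatUpdater(
      60000L,
      flintSessionUpdater,
      sessionId,
      threadPool,
      osClient,
      sessionIndex)

    try {
      while (reader.hasNext) {
        val statementDoc: String = reader.next()
        // Execute (E) and Publish (P) for each statement would happen here in the real FlintREPL loop.
        println(statementDoc)
      }
    } finally {
      reader.close()
      threadPool.shutdown()
    }
  }
}

In the actual FlintREPL entry point, the loop (L) also honors the exit conditions and timeouts described in the commit message; this sketch only shows how the new OSClient helpers and the heartbeat updater fit together.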