From 71d67a0dcdd4c91d3cbfc9546ece9b272a0b0c5c Mon Sep 17 00:00:00 2001
From: Kaituo Li
Date: Wed, 4 Oct 2023 13:35:16 -0700
Subject: [PATCH] Implement FlintJob Logic for EMR-S (#52)

* Implement FlintJob Logic for EMR-S

This commit introduces FlintJob logic for EMR-S, mirroring the existing
SQLJob implementation for the EMR cluster. The key differences in FlintJob
are:

1. It reads OpenSearch host information from Spark command parameters.
2. It ensures the existence of a result index with the correct mapping in
   OpenSearch, creating it if necessary. This process occurs in parallel to
   SQL query execution.
3. It reports an error if the result index mapping is incorrect.
4. It saves a failure status if the SQL execution fails.

Testing:
1. Manual testing was conducted using the EMR-S CLI.
2. New unit tests were added to verify the functionality.

Signed-off-by: Kaituo Li

* address comments

Signed-off-by: Kaituo Li

---------

Signed-off-by: Kaituo Li
---
 .scalafmt.conf                                |   1 +
 build.sbt                                     |  29 +-
 spark-sql-application/README.md               |  86 ++++-
 .../scala/org/apache/spark/sql/FlintJob.scala | 354 ++++++++++++++++++
 .../org/apache/spark/sql/FlintJobTest.scala   |  84 +++++
 5 files changed, 547 insertions(+), 7 deletions(-)
 create mode 100644 .scalafmt.conf
 create mode 100644 spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala
 create mode 100644 spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala

diff --git a/.scalafmt.conf b/.scalafmt.conf
new file mode 100644
index 000000000..834f2d20f
--- /dev/null
+++ b/.scalafmt.conf
@@ -0,0 +1 @@
+version = 2.7.5
\ No newline at end of file
diff --git a/build.sbt b/build.sbt
index 6b7c8d53a..bc018c265 100644
--- a/build.sbt
+++ b/build.sbt
@@ -165,12 +165,37 @@ lazy val standaloneCosmetic = project
     Compile / packageBin := (flintSparkIntegration / assembly).value)
 
 lazy val sparkSqlApplication = (project in file("spark-sql-application"))
+  // dependency will be provided at runtime, so it doesn't need to be included in the assembled JAR
+  .dependsOn(flintSparkIntegration % "provided")
   .settings(
     commonSettings,
     name := "sql-job",
     scalaVersion := scala212,
-    libraryDependencies ++= Seq("org.scalatest" %% "scalatest" % "3.2.15" % "test"),
-    libraryDependencies ++= deps(sparkVersion))
+    libraryDependencies ++= Seq(
+      "org.scalatest" %% "scalatest" % "3.2.15" % "test"),
+    libraryDependencies ++= deps(sparkVersion),
+    libraryDependencies += "com.typesafe.play" %% "play-json" % "2.9.2",
+    // Assembly settings
+    // The sbt-assembly plugin found multiple copies of the module-info.class file with
+    // different contents in the jars it was merging for the flintCore dependencies.
+    // This can happen if you have multiple dependencies that include the same library,
+    // but with different versions.
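+    // Discarding the duplicates is safe here: module-info.class only carries Java module-system
+    // metadata, which an assembled Spark job does not read at runtime.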
+ assemblyPackageScala / assembleArtifact := false, + assembly / assemblyOption ~= { + _.withIncludeScala(false) + }, + assembly / assemblyMergeStrategy := { + case PathList(ps@_*) if ps.last endsWith ("module-info.class") => + MergeStrategy.discard + case PathList("module-info.class") => MergeStrategy.discard + case PathList("META-INF", "versions", xs@_, "module-info.class") => + MergeStrategy.discard + case x => + val oldStrategy = (assembly / assemblyMergeStrategy).value + oldStrategy(x) + }, + assembly / test := (Test / test).value + ) lazy val sparkSqlApplicationCosmetic = project .settings( diff --git a/spark-sql-application/README.md b/spark-sql-application/README.md index 07bf46406..bd12ee933 100644 --- a/spark-sql-application/README.md +++ b/spark-sql-application/README.md @@ -1,13 +1,25 @@ # Spark SQL Application -This application execute sql query and store the result in OpenSearch index in following format +We have two applications: SQLJob and FlintJob. + +SQLJob is designed for EMR Spark, executing SQL queries and storing the results in the OpenSearch index in the following format: ``` "stepId":"", -"applicationId":"" +"applicationId":"", "schema": "json blob", "result": "json blob" ``` +FlintJob is designed for EMR Serverless Spark, executing SQL queries and storing the results in the OpenSearch index in the following format: + +``` +"jobRunId":"", +"applicationId":"", +"schema": "json blob", +"result": "json blob", +"dataSourceName":"" +``` + ## Prerequisites + Spark 3.3.1 @@ -16,8 +28,9 @@ This application execute sql query and store the result in OpenSearch index in f ## Usage -To use this application, you can run Spark with Flint extension: +To use these applications, you can run Spark with Flint extension: +SQLJob ``` ./bin/spark-submit \ --class org.opensearch.sql.SQLJob \ @@ -32,11 +45,41 @@ To use this application, you can run Spark with Flint extension: \ ``` +FlintJob +``` +aws emr-serverless start-job-run \ + --region \ + --application-id \ + --execution-role-arn \ + --job-driver '{"sparkSubmit": {"entryPoint": "", \ + "entryPointArguments":["''", "", ""], \ + "sparkSubmitParameters":"--class org.opensearch.sql.FlintJob \ + --conf spark.hadoop.fs.s3.customAWSCredentialsProvider=com.amazonaws.emr.AssumeRoleAWSCredentialsProvider \ + --conf spark.emr-serverless.driverEnv.ASSUME_ROLE_CREDENTIALS_ROLE_ARN= \ + --conf spark.executorEnv.ASSUME_ROLE_CREDENTIALS_ROLE_ARN= \ + --conf spark.hadoop.aws.catalog.credentials.provider.factory.class=com.amazonaws.glue.catalog.metastore.STSAssumeRoleSessionCredentialsProviderFactory \ + --conf spark.hive.metastore.glue.role.arn= \ + --conf spark.jars= \ + --conf spark.jars.packages= \ + --conf spark.jars.repositories= \ + --conf spark.emr-serverless.driverEnv.JAVA_HOME= \ + --conf spark.executorEnv.JAVA_HOME= \ + --conf spark.datasource.flint.host= \ + --conf spark.datasource.flint.port= \ + --conf spark.datasource.flint.scheme= \ + --conf spark.datasource.flint.auth= \ + --conf spark.datasource.flint.region= \ + --conf spark.datasource.flint.customAWSCredentialsProvider=com.amazonaws.emr.AssumeRoleAWSCredentialsProvider \ + --conf spark.sql.extensions=org.opensearch.flint.spark.FlintSparkExtensions \ + --conf spark.hadoop.hive.metastore.client.factory.class=com.amazonaws.glue.catalog.metastore.AWSGlueDataCatalogHiveClientFactory "}}' + +``` + ## Result Specifications Following example shows how the result is written to OpenSearch index after query execution. 
-Let's assume sql query result is
+Let's assume SQL query result is
 ```
 +------+------+
 |Letter|Number|
 +------+------+
 |A     |1     |
 |B     |2     |
 |C     |3     |
 +------+------+
 ```
-OpenSearch index document will look like
+For SQLJob, OpenSearch index document will look like
 ```json
 {
   "_index" : ".query_execution_result",
   "_id" : "A2WOsYgBMUoqCqlDJHrn",
   "_score" : 1.0,
   "_source" : {
     "result" : [
       "{'Letter':'A','Number':1}",
       "{'Letter':'B','Number':2}",
       "{'Letter':'C','Number':3}"
     ],
     "schema" : [
       "{'column_name':'Letter','data_type':'string'}",
       "{'column_name':'Number','data_type':'integer'}"
     ],
     "stepId" : "s-JZSB1139WIVU",
     "applicationId" : "application_1687726870985_0003"
   }
 }
 ```
 
+For FlintJob, OpenSearch index document will look like
+```json
+{
+  "_index" : ".query_execution_result",
+  "_id" : "A2WOsYgBMUoqCqlDJHrn",
+  "_score" : 1.0,
+  "_source" : {
+    "result" : [
+      "{'Letter':'A','Number':1}",
+      "{'Letter':'B','Number':2}",
+      "{'Letter':'C','Number':3}"
+    ],
+    "schema" : [
+      "{'column_name':'Letter','data_type':'string'}",
+      "{'column_name':'Number','data_type':'integer'}"
+    ],
+    "jobRunId" : "s-JZSB1139WIVU",
+    "applicationId" : "application_1687726870985_0003",
+    "dataSourceName": "myS3Glue",
+    "status": "SUCCESS",
+    "error": ""
+  }
+}
+```
+
 ## Build
 
 To build and run this application with Spark, you can run:
 
 ```
 sbt clean sparkSqlApplicationCosmetic/publishM2
 ```
 
+The jar file is located in the `spark-sql-application/target/scala-2.12` folder.
+
 ## Test
 
 To run tests, you can use:
@@ -92,6 +162,12 @@ To check code with scalastyle, you can run:
 sbt scalastyle
 ```
 
+To check test code with scalastyle, you can run:
+
+```
+sbt testScalastyle
+```
+
 ## Code of Conduct
 
 This project has adopted an [Open Source Code of Conduct](../CODE_OF_CONDUCT.md).
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala
new file mode 100644
index 000000000..8da7d2072
--- /dev/null
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala
@@ -0,0 +1,354 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+// defined in spark package so that we can use ThreadUtils
+package org.apache.spark.sql
+
+import java.util.Locale
+
+import org.opensearch.ExceptionsHelper
+import org.opensearch.flint.core.{FlintClient, FlintClientBuilder}
+import org.opensearch.flint.core.metadata.FlintMetadata
+import play.api.libs.json._
+import scala.concurrent.{ExecutionContext, Future, TimeoutException}
+import scala.concurrent.duration.{Duration, MINUTES}
+
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.flint.config.FlintSparkConf
+import org.apache.spark.sql.types.{StructField, _}
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * Spark SQL Application entrypoint
+ *
+ * @param args
+ *   (0) sql query
+ * @param args
+ *   (1) opensearch index name
+ * @return
+ *   write sql query result to given opensearch index
+ */
+object FlintJob extends Logging {
+  def main(args: Array[String]): Unit = {
+    // Validate command line arguments
+    if (args.length != 2) {
+      throw new IllegalArgumentException("Usage: FlintJob <query> <resultIndex>")
+    }
+
+    val Array(query, resultIndex) = args
+
+    val conf = createSparkConf()
+    val wait = conf.get("spark.flint.job.type", "continue")
+    val dataSource = conf.get("spark.flint.datasource.name", "")
+    val spark = createSparkSession(conf)
+
+    val threadPool = ThreadUtils.newDaemonFixedThreadPool(1, "check-create-index")
+    implicit val executionContext = ExecutionContext.fromExecutor(threadPool)
+
+    var dataToWrite: Option[DataFrame] = None
+    try {
+      // flintClient needs spark session to be created first.
+      // Otherwise, we will have a connection exception from EMR-S to OS.
+      val flintClient = FlintClientBuilder.build(FlintSparkConf().flintOptions())
+      val futureMappingCheck = Future {
+        checkAndCreateIndex(flintClient, resultIndex)
+      }
+      val data = executeQuery(spark, query, dataSource)
+
+      val (correctMapping, error) =
+        ThreadUtils.awaitResult(futureMappingCheck, Duration(1, MINUTES))
+      dataToWrite = Some(if (correctMapping) data else getFailedData(spark, dataSource, error))
+    } catch {
+      case e: TimeoutException =>
+        val error = "Future operations timed out"
+        logError(error, e)
+        dataToWrite = Some(getFailedData(spark, dataSource, error))
+      case e: Exception =>
+        val error = "Failed to verify existing mapping or write result"
+        logError(error, e)
+        dataToWrite = Some(getFailedData(spark, dataSource, error))
+    } finally {
+      dataToWrite.foreach(df => writeData(df, resultIndex))
+      // Stop SparkSession if it is not a streaming job
+      if (wait.equalsIgnoreCase("streaming")) {
+        spark.streams.awaitAnyTermination()
+      } else {
+        spark.stop()
+      }
+
+      threadPool.shutdown()
+    }
+  }
+
+  def createSparkConf(): SparkConf = {
+    new SparkConf()
+      .setAppName("FlintJob")
+      .set("spark.sql.extensions", "org.opensearch.flint.spark.FlintSparkExtensions")
+  }
+
+  def createSparkSession(conf: SparkConf): SparkSession = {
+    SparkSession.builder().config(conf).enableHiveSupport().getOrCreate()
+  }
+
+  def writeData(resultData: DataFrame, resultIndex: String): Unit = {
+    resultData.write
+      .format("flint")
+      .mode("append")
+      .save(resultIndex)
+  }
+
+  /**
+   * Create a new formatted dataframe with json result, json schema and EMR-S job information.
+   *
+   * @param result
+   *   sql query result dataframe
+   * @param spark
+   *   spark session
+   * @param dataSource
+   *   data source name
+   * @return
+   *   dataframe with result, schema, job run id and application id
+   */
+  def getFormattedData(
+      result: DataFrame,
+      spark: SparkSession,
+      dataSource: String
+  ): DataFrame = {
+    // Create the schema dataframe
+    val schemaRows = result.schema.fields.map { field =>
+      Row(field.name, field.dataType.typeName)
+    }
+    val resultSchema = spark.createDataFrame(
+      spark.sparkContext.parallelize(schemaRows),
+      StructType(
+        Seq(
+          StructField("column_name", StringType, nullable = false),
+          StructField("data_type", StringType, nullable = false)
+        )
+      )
+    )
+
+    // Define the data schema
+    val schema = StructType(
+      Seq(
+        StructField(
+          "result",
+          ArrayType(StringType, containsNull = true),
+          nullable = true
+        ),
+        StructField(
+          "schema",
+          ArrayType(StringType, containsNull = true),
+          nullable = true
+        ),
+        StructField("jobRunId", StringType, nullable = true),
+        StructField("applicationId", StringType, nullable = true),
+        StructField("dataSourceName", StringType, nullable = true),
+        StructField("status", StringType, nullable = true),
+        StructField("error", StringType, nullable = true)
+      )
+    )
+
+    // Create the data rows
+    val rows = Seq(
+      (
+        result.toJSON.collect.toList
+          .map(_.replaceAll("'", "\\\\'").replaceAll("\"", "'")),
+        resultSchema.toJSON.collect.toList.map(_.replaceAll("\"", "'")),
+        sys.env.getOrElse("SERVERLESS_EMR_JOB_ID", "unknown"),
+        sys.env.getOrElse("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"),
+        dataSource,
+        "SUCCESS",
+        ""
+      )
+    )
+
+    // Create the DataFrame for data
+    spark.createDataFrame(rows).toDF(schema.fields.map(_.name): _*)
+  }
+
+  def getFailedData(spark: SparkSession, dataSource: String, error: String): DataFrame = {
+
+    // Define the data schema
+    val schema = StructType(
+      Seq(
+        StructField(
+          "result",
+          ArrayType(StringType, containsNull = true),
+          nullable = true
+        ),
+        StructField(
+          "schema",
+          ArrayType(StringType, containsNull = true),
+          nullable = true
+        ),
+        StructField("jobRunId", StringType, nullable = true),
+        StructField("applicationId", StringType, nullable = true),
+        StructField("dataSourceName", StringType, nullable = true),
+        StructField("status", StringType, nullable = true),
+        StructField("error", StringType, nullable = true)
+      )
+    )
+
+    // Create the data rows
+    val rows = Seq(
+      (
+        null,
+        null,
+        sys.env.getOrElse("SERVERLESS_EMR_JOB_ID", "unknown"),
+        sys.env.getOrElse("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"),
+        dataSource,
+        "FAILED",
+        error
+      )
+    )
+
+    // Create the DataFrame for data
+    spark.createDataFrame(rows).toDF(schema.fields.map(_.name): _*)
+  }
+
+  /**
+   * Determines whether one JSON structure is a superset of another.
+   *
+   * This method checks if the `input` JSON structure contains all the fields
+   * and values present in the `mapping` JSON structure. The comparison is
+   * recursive and structure-sensitive, ensuring that nested objects and arrays
+   * are also compared accurately.
+   *
+   * Additionally, this method accommodates the edge case where boolean values
+   * in the JSON are represented as strings (e.g., "true" or "false" instead of
+   * true or false). This is handled by performing a case-insensitive comparison
+   * of string representations of boolean values.
+   *
+   * @param input The input JSON structure as a String.
+   * @param mapping The mapping JSON structure as a String.
+   * @return A Boolean value indicating whether the `input` JSON structure
+   *   is a superset of the `mapping` JSON structure.
+   */
+  def isSuperset(input: String, mapping: String): Boolean = {
+    def compareJson(inputJson: JsValue, mappingJson: JsValue): Boolean = {
+      (inputJson, mappingJson) match {
+        case (JsObject(inputFields), JsObject(mappingFields)) =>
+          logInfo(s"Comparing objects: $inputFields vs $mappingFields")
+          mappingFields.forall { case (key, value) =>
+            inputFields
+              .get(key)
+              .exists(inputValue => compareJson(inputValue, value))
+          }
+        case (JsArray(inputValues), JsArray(mappingValues)) =>
+          logInfo(s"Comparing arrays: $inputValues vs $mappingValues")
+          mappingValues.forall(mappingValue =>
+            inputValues.exists(inputValue =>
+              compareJson(inputValue, mappingValue)
+            )
+          )
+        case (JsString(inputValue), JsString(mappingValue))
+            if (inputValue.toLowerCase(Locale.ROOT) == "true" ||
+              inputValue.toLowerCase(Locale.ROOT) == "false") &&
+              (mappingValue.toLowerCase(Locale.ROOT) == "true" ||
+                mappingValue.toLowerCase(Locale.ROOT) == "false") =>
+          inputValue.toLowerCase(Locale.ROOT) == mappingValue.toLowerCase(
+            Locale.ROOT
+          )
+        case (JsBoolean(inputValue), JsString(mappingValue))
+            if mappingValue.toLowerCase(Locale.ROOT) == "true" ||
+              mappingValue.toLowerCase(Locale.ROOT) == "false" =>
+          inputValue.toString.toLowerCase(Locale.ROOT) == mappingValue
+            .toLowerCase(Locale.ROOT)
+        case (JsString(inputValue), JsBoolean(mappingValue))
+            if inputValue.toLowerCase(Locale.ROOT) == "true" ||
+              inputValue.toLowerCase(Locale.ROOT) == "false" =>
+          inputValue.toLowerCase(Locale.ROOT) == mappingValue.toString
+            .toLowerCase(Locale.ROOT)
+        case (inputValue, mappingValue) =>
+          inputValue == mappingValue
+      }
+    }
+
+    val inputJson = Json.parse(input)
+    val mappingJson = Json.parse(mapping)
+
+    compareJson(inputJson, mappingJson)
+  }
+
+  def checkAndCreateIndex(
+      flintClient: FlintClient,
+      resultIndex: String
+  ): (Boolean, String) = {
+    // The enabled setting, which can be applied only to the top-level mapping definition and to object
+    // fields, causes OpenSearch to skip parsing of the contents of the field entirely.
+    val mapping =
+      """{
+        "dynamic": false,
+        "properties": {
+          "result": {
+            "type": "object",
+            "enabled": false
+          },
+          "schema": {
+            "type": "object",
+            "enabled": false
+          },
+          "jobRunId": {
+            "type": "keyword"
+          },
+          "applicationId": {
+            "type": "keyword"
+          },
+          "dataSourceName": {
+            "type": "keyword"
+          },
+          "status": {
+            "type": "keyword"
+          },
+          "error": {
+            "type": "text"
+          }
+        }
+      }""".stripMargin
+
+    try {
+      val existingSchema = flintClient.getIndexMetadata(resultIndex).getContent
+      if (!isSuperset(existingSchema, mapping)) {
+        (false, s"The mapping of $resultIndex is incorrect.")
+      } else {
+        (true, "")
+      }
+    } catch {
+      case e: IllegalStateException
+          if e.getCause().getMessage().contains("index_not_found_exception") =>
+        handleIndexNotFoundException(flintClient, resultIndex, mapping)
+      case e: Exception =>
+        val error = "Failed to verify existing mapping"
+        logError(error, e)
+        (false, error)
+    }
+  }
+
+  def handleIndexNotFoundException(
+      flintClient: FlintClient,
+      resultIndex: String,
+      mapping: String
+  ): (Boolean, String) = {
+    try {
+      logInfo(s"create $resultIndex")
+      flintClient.createIndex(resultIndex, new FlintMetadata(mapping))
+      logInfo(s"create $resultIndex successfully")
+      (true, "")
+    } catch {
+      case e: Exception =>
+        val error = s"Failed to create result index $resultIndex"
+        logError(error, e)
+        (false, error)
+    }
+  }
+
+  def executeQuery(
+      spark: SparkSession,
+      query: String,
+      dataSource: String
+  ): DataFrame = {
+    // Execute SQL query
+    val result: DataFrame = spark.sql(query)
+    // Get Data
+    getFormattedData(result, spark, dataSource)
+  }
+}
diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala
new file mode 100644
index 000000000..c32e63194
--- /dev/null
+++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala
@@ -0,0 +1,84 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.apache.spark.sql
+
+import org.scalatest.matchers.should.Matchers
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types._
+
+class FlintJobTest extends SparkFunSuite with Matchers {
+
+  val spark =
+    SparkSession.builder().appName("Test").master("local").getOrCreate()
+
+  // Define input dataframe
+  val inputSchema = StructType(
+    Seq(
+      StructField("Letter", StringType, nullable = false),
+      StructField("Number", IntegerType, nullable = false)))
+  val inputRows = Seq(Row("A", 1), Row("B", 2), Row("C", 3))
+  val input: DataFrame =
+    spark.createDataFrame(spark.sparkContext.parallelize(inputRows), inputSchema)
+
+  test("Test getFormattedData method") {
+    // Define expected dataframe
+    val dataSourceName = "myGlueS3"
+    val expectedSchema = StructType(
+      Seq(
+        StructField("result", ArrayType(StringType, containsNull = true), nullable = true),
+        StructField("schema", ArrayType(StringType, containsNull = true), nullable = true),
+        StructField("jobRunId", StringType, nullable = true),
+        StructField("applicationId", StringType, nullable = true),
+        StructField("dataSourceName", StringType, nullable = true),
+        StructField("status", StringType, nullable = true),
+        StructField("error", StringType, nullable = true)
+      ))
+    val expectedRows = Seq(
+      Row(
+        Array(
+          "{'Letter':'A','Number':1}",
+          "{'Letter':'B','Number':2}",
+          "{'Letter':'C','Number':3}"),
+        Array(
+          "{'column_name':'Letter','data_type':'string'}",
+          "{'column_name':'Number','data_type':'integer'}"),
+        "unknown",
"unknown", + dataSourceName, + "SUCCESS", + "" + )) + val expected: DataFrame = + spark.createDataFrame(spark.sparkContext.parallelize(expectedRows), expectedSchema) + + // Compare the result + val result = FlintJob.getFormattedData(input, spark, dataSourceName) + assertEqualDataframe(expected, result) + } + + def assertEqualDataframe(expected: DataFrame, result: DataFrame): Unit = { + assert(expected.schema === result.schema) + assert(expected.collect() === result.collect()) + } + + test("test isSuperset") { + // note in input false has enclosed double quotes, while mapping just has false + val input = + """{"dynamic":"false","properties":{"result":{"type":"object"},"schema":{"type":"object"}, + |"applicationId":{"type":"keyword"},"jobRunId":{ + |"type":"keyword"},"dataSourceName":{"type":"keyword"},"status":{"type":"keyword"}, + |"error":{"type":"text"}}} + |""".stripMargin + val mapping = + """{"dynamic":false,"properties":{"result":{"type":"object"},"schema":{"type":"object"}, + |"jobRunId":{"type":"keyword"},"applicationId":{ + |"type":"keyword"},"dataSourceName":{"type":"keyword"},"status":{"type":"keyword"}}} + |"error":{"type":"text"}}} + |""".stripMargin + assert(FlintJob.isSuperset(input, mapping)) + } +}