From 64cbfd9ed3ed239402ece3249a6a7d83a0664e91 Mon Sep 17 00:00:00 2001
From: Kaituo Li
Date: Thu, 21 Sep 2023 11:36:13 -0700
Subject: [PATCH] Implement FlintJob Logic for EMR-S

This commit introduces FlintJob logic for EMR-S, mirroring the existing
SQLJob implementation for the EMR cluster. The key differences in FlintJob are:

1. It reads OpenSearch host information from Spark command parameters.
2. It ensures the existence of a result index with the correct mapping in
   OpenSearch, creating it if necessary. This process occurs in parallel
   with SQL query execution.
3. It reports an error if the result index mapping is incorrect.
4. It saves a failure status if the SQL execution fails.

Testing:
1. Manual testing was conducted using the EMR-S CLI.
2. New unit tests were added to verify the functionality.

Signed-off-by: Kaituo Li
---
 .scalafmt.conf                                |   1 +
 build.sbt                                     |  30 +-
 spark-sql-application/README.md               |  84 ++++-
 .../scala/org/apache/spark/sql/FlintJob.scala | 349 ++++++++++++++++++
 .../org/apache/spark/sql/FlintJobTest.scala   |  82 ++++
 5 files changed, 539 insertions(+), 7 deletions(-)
 create mode 100644 .scalafmt.conf
 create mode 100644 spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala
 create mode 100644 spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala

diff --git a/.scalafmt.conf b/.scalafmt.conf
new file mode 100644
index 000000000..834f2d20f
--- /dev/null
+++ b/.scalafmt.conf
@@ -0,0 +1 @@
+version = 2.7.5
\ No newline at end of file
diff --git a/build.sbt b/build.sbt
index 6b7c8d53a..7bb01226d 100644
--- a/build.sbt
+++ b/build.sbt
@@ -114,6 +114,7 @@ lazy val flintSparkIntegration = (project in file("flint-spark-integration"))
       "com.stephenn" %% "scalatest-json-jsonassert" % "0.2.5" % "test",
       "com.github.sbt" % "junit-interface" % "0.13.3" % "test"),
     libraryDependencies ++= deps(sparkVersion),
+    libraryDependencies += "com.typesafe.play" %% "play-json" % "2.9.2",
     // ANTLR settings
     Antlr4 / antlr4Version := "4.8",
     Antlr4 / antlr4PackageName := Some("org.opensearch.flint.spark.sql"),
@@ -165,12 +166,37 @@ lazy val standaloneCosmetic = project
     Compile / packageBin := (flintSparkIntegration / assembly).value)
 
 lazy val sparkSqlApplication = (project in file("spark-sql-application"))
+  // dependency will be provided at runtime, so it doesn't need to be included in the assembled JAR
+  .dependsOn(flintSparkIntegration % "provided")
   .settings(
     commonSettings,
     name := "sql-job",
     scalaVersion := scala212,
-    libraryDependencies ++= Seq("org.scalatest" %% "scalatest" % "3.2.15" % "test"),
-    libraryDependencies ++= deps(sparkVersion))
+    libraryDependencies ++= Seq(
+      "org.scalatest" %% "scalatest" % "3.2.15" % "test"),
+    libraryDependencies ++= deps(sparkVersion),
+    libraryDependencies += "com.typesafe.play" %% "play-json" % "2.9.2",
+    // Assembly settings
+    // The sbt assembly plugin found multiple copies of module-info.class, with different
+    // contents, in the jars it was merging for the flintCore dependencies.
+    // This can happen if you have multiple dependencies that include the same library,
+    // but with different versions.
+    assemblyPackageScala / assembleArtifact := false,
+    assembly / assemblyOption ~= {
+      _.withIncludeScala(false)
+    },
+    assembly / assemblyMergeStrategy := {
+      case PathList(ps @ _*) if ps.last endsWith ("module-info.class") =>
+        MergeStrategy.discard
+      case PathList("module-info.class") => MergeStrategy.discard
+      case PathList("META-INF", "versions", xs @ _, "module-info.class") =>
+        MergeStrategy.discard
+      case x =>
+        val oldStrategy = (assembly / assemblyMergeStrategy).value
+        oldStrategy(x)
+    },
+    assembly / test := (Test / test).value
+  )
 
 lazy val sparkSqlApplicationCosmetic = project
   .settings(
diff --git a/spark-sql-application/README.md b/spark-sql-application/README.md
index 07bf46406..533ee81cb 100644
--- a/spark-sql-application/README.md
+++ b/spark-sql-application/README.md
@@ -1,13 +1,25 @@
 # Spark SQL Application
 
-This application execute sql query and store the result in OpenSearch index in following format
+We have two applications: SQLJob and FlintJob.
+
+SQLJob is designed for EMR Spark, executing SQL queries and storing the results in the OpenSearch index in the following format:
 ```
 "stepId":"<emr-step-id>",
-"applicationId":"<spark-application-id>"
+"applicationId":"<spark-application-id>",
 "schema": "json blob",
 "result": "json blob"
 ```
+
+FlintJob is designed for EMR Serverless Spark, executing SQL queries and storing the results in the OpenSearch index in the following format:
+
+```
+"jobRunId":"<emr-s-job-run-id>",
+"applicationId":"<emr-s-application-id>",
+"schema": "json blob",
+"result": "json blob",
+"dataSourceName":"<data-source-name>"
+```
 
 ## Prerequisites
 
 + Spark 3.3.1
@@ -16,8 +28,9 @@
 
 ## Usage
 
-To use this application, you can run Spark with Flint extension:
+To use these applications, you can run Spark with the Flint extension:
 
+SQLJob
 ```
 ./bin/spark-submit \
     --class org.opensearch.sql.SQLJob \
     --jars <flint-spark-integration-jar> \
     sql-job.jar \
     <spark-sql-query> \
 ```
 
+FlintJob
+```
+aws emr-serverless start-job-run \
+    --region <region> \
+    --application-id <application-id> \
+    --execution-role-arn <execution-role-arn> \
+    --job-driver '{"sparkSubmit": {"entryPoint": "<s3-path-to-sql-job-jar>", \
+      "entryPointArguments":["'<sql-query>'", "<result-index>", "<data-source-name>"], \
+      "sparkSubmitParameters":"--class org.opensearch.sql.FlintJob \
+        --conf spark.hadoop.fs.s3.customAWSCredentialsProvider=com.amazonaws.emr.AssumeRoleAWSCredentialsProvider \
+        --conf spark.emr-serverless.driverEnv.ASSUME_ROLE_CREDENTIALS_ROLE_ARN=<role-arn> \
+        --conf spark.executorEnv.ASSUME_ROLE_CREDENTIALS_ROLE_ARN=<role-arn> \
+        --conf spark.hadoop.aws.catalog.credentials.provider.factory.class=com.amazonaws.glue.catalog.metastore.STSAssumeRoleSessionCredentialsProviderFactory \
+        --conf spark.hive.metastore.glue.role.arn=<glue-role-arn> \
+        --conf spark.jars=<jar-paths> \
+        --conf spark.jars.packages=<maven-packages> \
+        --conf spark.jars.repositories=<maven-repositories> \
+        --conf spark.emr-serverless.driverEnv.JAVA_HOME=<java-home> \
+        --conf spark.executorEnv.JAVA_HOME=<java-home> \
+        --conf spark.datasource.flint.host=<opensearch-host> \
+        --conf spark.datasource.flint.port=<opensearch-port> \
+        --conf spark.datasource.flint.scheme=<http-or-https> \
+        --conf spark.datasource.flint.auth=<auth-type> \
+        --conf spark.datasource.flint.region=<opensearch-region> \
+        --conf spark.datasource.flint.customAWSCredentialsProvider=com.amazonaws.emr.AssumeRoleAWSCredentialsProvider \
+        --conf spark.sql.extensions=org.opensearch.flint.spark.FlintSparkExtensions \
+        --conf spark.hadoop.hive.metastore.client.factory.class=com.amazonaws.glue.catalog.metastore.AWSGlueDataCatalogHiveClientFactory "}}'
+```
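+
+The three `entryPointArguments` are, in order, the SQL query to execute, the OpenSearch result index to write to, and the data source name; they correspond to the three command line arguments expected by `FlintJob`.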
+
 ## Result Specifications
 
 Following example shows how the result is written to OpenSearch index after query execution.
 
-Let's assume sql query result is
+Let's assume SQL query result is
 ```
 +------+------+
 |Letter|Number|
 +------+------+
 |A     |1     |
 |B     |2     |
 |C     |3     |
 +------+------+
 ```
-OpenSearch index document will look like
+For SQLJob, OpenSearch index document will look like
 ```json
 {
   "_index" : ".query_execution_result",
   "_id" : "A2WOsYgBMUoqCqlDJHrn",
   "_score" : 1.0,
   "_source" : {
     "result" : [
       "{'Letter':'A','Number':1}",
       "{'Letter':'B','Number':2}",
       "{'Letter':'C','Number':3}"
     ],
     "schema" : [
       "{'column_name':'Letter','data_type':'string'}",
       "{'column_name':'Number','data_type':'integer'}"
     ],
     "stepId" : "s-JZSB1139WIVU",
     "applicationId" : "application_1687726870985_0003"
   }
 }
 ```
 
+For FlintJob, OpenSearch index document will look like
+```json
+{
+  "_index" : ".query_execution_result",
+  "_id" : "A2WOsYgBMUoqCqlDJHrn",
+  "_score" : 1.0,
+  "_source" : {
+    "result" : [
+      "{'Letter':'A','Number':1}",
+      "{'Letter':'B','Number':2}",
+      "{'Letter':'C','Number':3}"
+    ],
+    "schema" : [
+      "{'column_name':'Letter','data_type':'string'}",
+      "{'column_name':'Number','data_type':'integer'}"
+    ],
+    "jobRunId" : "s-JZSB1139WIVU",
+    "applicationId" : "application_1687726870985_0003",
+    "dataSourceName": "myS3Glue"
+  }
+}
+```
+
 ## Build
 
 To build and run this application with Spark, you can run:
 
 ```
 sbt clean sparkSqlApplicationCosmetic/publishM2
 ```
 
+The jar file is located in the `spark-sql-application/target/scala-2.12` folder.
+
 ## Test
 
 To run tests, you can use:
 
 ```
 sbt test
 ```
 
 To check code with scalastyle, you can run:
 
 ```
 sbt scalastyle
 ```
 
+To check test code with scalastyle, you can run:
+
+```
+sbt testScalastyle
+```
+
 ## Code of Conduct
 
 This project has adopted an [Open Source Code of Conduct](../CODE_OF_CONDUCT.md).
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala
new file mode 100644
index 000000000..9f4a464e7
--- /dev/null
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala
@@ -0,0 +1,349 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+// defined in spark package so that we can use ThreadUtils
+package org.apache.spark.sql
+
+import java.util.Locale
+
+import org.opensearch.ExceptionsHelper
+import org.opensearch.flint.core.{FlintClient, FlintClientBuilder}
+import org.opensearch.flint.core.metadata.FlintMetadata
+import play.api.libs.json._
+import scala.concurrent.{ExecutionContext, Future, TimeoutException}
+import scala.concurrent.duration.{Duration, MINUTES}
+
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.flint.config.FlintSparkConf
+import org.apache.spark.sql.types._
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * Spark SQL Application entrypoint
+ *
+ * @param args
+ *   (0) sql query
+ * @param args
+ *   (1) opensearch index name
+ * @param args
+ *   (2) opensearch data source name
+ * @return
+ *   write sql query result to given opensearch index
+ */
+object FlintJob extends Logging {
+  def main(args: Array[String]): Unit = {
+    // Validate command line arguments
+    if (args.length != 3) {
+      throw new IllegalArgumentException("Usage: FlintJob <query> <resultIndex> <dataSource>")
+    }
+
+    val Array(query, resultIndex, dataSource) = args
+
+    val conf = createSparkConf()
+    val spark = createSparkSession(conf)
+
+    val threadPool = ThreadUtils.newDaemonFixedThreadPool(1, "check-create-index")
+    implicit val executionContext = ExecutionContext.fromExecutor(threadPool)
+
+    try {
+      // flintClient needs the Spark session to be created first. Otherwise, we get a connection
+      // exception from EMR-S to OpenSearch.
+      val flintClient = FlintClientBuilder.build(FlintSparkConf().flintOptions())
+      val futureMappingCheck = Future {
+        checkAndCreateIndex(flintClient, resultIndex)
+      }
+      val data = executeQuery(spark, query, dataSource)
+
+      val correctMapping = ThreadUtils.awaitResult(futureMappingCheck, Duration(10, MINUTES))
+      writeData(spark, data, resultIndex, correctMapping, dataSource)
+
+    } catch {
+      case e: TimeoutException =>
+        logError("Future operations timed out", e)
+        throw e
+      case e: Exception =>
+        logError("Failed to verify existing mapping or write result", e)
+        throw e
+    } finally {
+      // Stop SparkSession
+      spark.stop()
+      threadPool.shutdown()
+    }
+  }
+
+  def createSparkConf(): SparkConf = {
+    new SparkConf()
+      .setAppName("FlintJob")
+      .set("spark.sql.extensions", "org.opensearch.flint.spark.FlintSparkExtensions")
+  }
+
+  def createSparkSession(conf: SparkConf): SparkSession = {
+    SparkSession.builder().config(conf).enableHiveSupport().getOrCreate()
+  }
+
+  def writeData(spark: SparkSession, data: DataFrame, resultIndex: String, correctMapping: Boolean,
+      dataSource: String): Unit = {
+    val resultData = if (correctMapping) data else getFailedData(spark, dataSource)
+    resultData.write
+      .format("flint")
+      .mode("append")
+      .save(resultIndex)
+  }
+
+  /**
+   * Create a new formatted dataframe with the json result, json schema and EMR-S job metadata
+   * (job run id, application id, data source name and status).
+   *
+   * @param result
+   *   sql query result dataframe
+   * @param spark
+   *   spark session
+   * @return
+   *   dataframe with result, schema, job run id, application id, data source name and status
+   */
+  def getFormattedData(
+      result: DataFrame,
+      spark: SparkSession,
+      dataSource: String
+  ): DataFrame = {
+    // Create the schema dataframe
+    val schemaRows = result.schema.fields.map { field =>
+      Row(field.name, field.dataType.typeName)
+    }
+    val resultSchema = spark.createDataFrame(
+      spark.sparkContext.parallelize(schemaRows),
+      StructType(
+        Seq(
+          StructField("column_name", StringType, nullable = false),
+          StructField("data_type", StringType, nullable = false)
+        )
+      )
+    )
+
+    // Define the data schema
+    val schema = StructType(
+      Seq(
+        StructField(
+          "result",
+          ArrayType(StringType, containsNull = true),
+          nullable = true
+        ),
+        StructField(
+          "schema",
+          ArrayType(StringType, containsNull = true),
+          nullable = true
+        ),
+        StructField("jobRunId", StringType, nullable = true),
+        StructField("applicationId", StringType, nullable = true),
+        StructField("dataSourceName", StringType, nullable = true),
+        StructField("status", StringType, nullable = true)
+      )
+    )
+
+    // Create the data rows
+    val rows = Seq(
+      (
+        result.toJSON.collect.toList
+          .map(_.replaceAll("'", "\\\\'").replaceAll("\"", "'")),
+        resultSchema.toJSON.collect.toList.map(_.replaceAll("\"", "'")),
+        sys.env.getOrElse("SERVERLESS_EMR_JOB_ID", "unknown"),
+        sys.env.getOrElse("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"),
+        dataSource,
+        "SUCCESS"
+      )
+    )
+
+    // Create the DataFrame for data
+    spark.createDataFrame(rows).toDF(schema.fields.map(_.name): _*)
+  }
+
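+  /**
+   * Builds the record that is written when the result index mapping check fails: result and
+   * schema are left null, and only the job run id, application id, data source name and a
+   * "FAILED" status are populated, so the failure is still recorded in the result index.
+   */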
+  def getFailedData(spark: SparkSession, dataSource: String): DataFrame = {
+
+    // Define the data schema
+    val schema = StructType(
+      Seq(
+        StructField(
+          "result",
+          ArrayType(StringType, containsNull = true),
+          nullable = true
+        ),
+        StructField(
+          "schema",
+          ArrayType(StringType, containsNull = true),
+          nullable = true
+        ),
+        StructField("jobRunId", StringType, nullable = true),
+        StructField("applicationId", StringType, nullable = true),
+        StructField("dataSourceName", StringType, nullable = true),
+        StructField("status", StringType, nullable = true)
+      )
+    )
+
+    // Create the data rows
+    val rows = Seq(
+      (
+        null,
+        null,
+        sys.env.getOrElse("SERVERLESS_EMR_JOB_ID", "unknown"),
+        sys.env.getOrElse("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"),
+        dataSource,
+        "FAILED"
+      )
+    )
+
+    // Create the DataFrame for data
+    spark.createDataFrame(rows).toDF(schema.fields.map(_.name): _*)
+  }
+
+  def isSuperset(input: String, mapping: String): Boolean = {
+
+    /**
+     * Determines whether one JSON structure is a superset of another.
+     *
+     * This method checks if the `input` JSON structure contains all the fields
+     * and values present in the `mapping` JSON structure. The comparison is
+     * recursive and structure-sensitive, ensuring that nested objects and arrays
+     * are also compared accurately.
+     *
+     * Additionally, this method accommodates the edge case where boolean values
+     * in the JSON are represented as strings (e.g., "true" or "false" instead of
+     * true or false). This is handled by performing a case-insensitive comparison
+     * of string representations of boolean values.
+     *
+     * @param input The input JSON structure as a String.
+     * @param mapping The mapping JSON structure as a String.
+     * @return A Boolean value indicating whether the `input` JSON structure
+     *         is a superset of the `mapping` JSON structure.
+     */
+    def compareJson(inputJson: JsValue, mappingJson: JsValue): Boolean = {
+      (inputJson, mappingJson) match {
+        case (JsObject(inputFields), JsObject(mappingFields)) =>
+          logInfo(s"Comparing objects: $inputFields vs $mappingFields")
+          mappingFields.forall { case (key, value) =>
+            inputFields
+              .get(key)
+              .exists(inputValue => compareJson(inputValue, value))
+          }
+        case (JsArray(inputValues), JsArray(mappingValues)) =>
+          logInfo(s"Comparing arrays: $inputValues vs $mappingValues")
+          mappingValues.forall(mappingValue =>
+            inputValues.exists(inputValue =>
+              compareJson(inputValue, mappingValue)
+            )
+          )
+        case (JsString(inputValue), JsString(mappingValue))
+            if (inputValue.toLowerCase(Locale.ROOT) == "true" ||
+              inputValue.toLowerCase(Locale.ROOT) == "false") &&
+              (mappingValue.toLowerCase(Locale.ROOT) == "true" ||
+                mappingValue.toLowerCase(Locale.ROOT) == "false") =>
+          inputValue.toLowerCase(Locale.ROOT) == mappingValue.toLowerCase(
+            Locale.ROOT
+          )
+        case (JsBoolean(inputValue), JsString(mappingValue))
+            if mappingValue.toLowerCase(Locale.ROOT) == "true" ||
+              mappingValue.toLowerCase(Locale.ROOT) == "false" =>
+          inputValue.toString.toLowerCase(Locale.ROOT) == mappingValue
+            .toLowerCase(Locale.ROOT)
+        case (JsString(inputValue), JsBoolean(mappingValue))
+            if inputValue.toLowerCase(Locale.ROOT) == "true" ||
+              inputValue.toLowerCase(Locale.ROOT) == "false" =>
+          inputValue.toLowerCase(Locale.ROOT) == mappingValue.toString
+            .toLowerCase(Locale.ROOT)
+        case (inputValue, mappingValue) =>
+          inputValue == mappingValue
+      }
+    }
+
+    val inputJson = Json.parse(input)
+    val mappingJson = Json.parse(mapping)
+    logInfo(s"inputJson $inputJson")
+    logInfo(s"mappingJson $mappingJson")
+
+    compareJson(inputJson, mappingJson)
+  }
+
+  def checkAndCreateIndex(
+      flintClient: FlintClient,
+      resultIndex: String
+  ): Boolean = {
+    var correctMapping = false
+
+    val mapping =
+      """{
+      "dynamic": false,
+      "properties": {
+        "result": {
+          "type": "text",
+          "fields": {
+            "keyword": {
+              "type": "keyword",
+              "ignore_above": 256
+            }
+          }
+        },
+        "schema": {
+          "type": "text",
+          "fields": {
+            "keyword": {
+              "type": "keyword",
+              "ignore_above": 256
+            }
+          }
+        },
+        "jobRunId": {
+          "type": "keyword"
+        },
+        "applicationId": {
+          "type": "keyword"
+        },
+        "dataSourceName": {
+          "type": "keyword"
+        },
+        "status": {
+          "type": "keyword"
+        }
+      }
+    }""".stripMargin
+
+    try {
+      val existingSchema = flintClient.getIndexMetadata(resultIndex).getContent
+      if (!isSuperset(existingSchema, mapping)) {
+        logError(s"The mapping of $resultIndex is incorrect.")
+      } else {
+        correctMapping = true
+      }
+    } catch {
+      case e: IllegalStateException =>
+        logInfo("get mapping exception", e)
+        val cause = ExceptionsHelper.unwrapCause(e.getCause())
+        logInfo("cause", cause)
+        logInfo("cause2", cause.getCause())
+        if (cause.getMessage().contains("index_not_found_exception")) {
+          try {
+            logInfo(s"create $resultIndex")
+            flintClient.createIndex(resultIndex, new FlintMetadata(mapping))
+            logInfo(s"create $resultIndex successfully")
+            correctMapping = true
+          } catch {
+            case _: Exception =>
+              logError(s"Failed to create result index $resultIndex")
+          }
+        }
+      case e: Exception => logError("Failed to verify existing mapping", e)
+    }
+    correctMapping
+  }
+
+  def executeQuery(
+      spark: SparkSession,
+      query: String,
+      dataSource: String
+  ): DataFrame = {
+    // Execute SQL query
+    val result: DataFrame = spark.sql(query)
+    // Get Data
+    getFormattedData(result, spark, dataSource)
+  }
+}
diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala
new file mode 100644
index 000000000..56ac45c2f
--- /dev/null
+++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala
@@ -0,0 +1,82 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.apache.spark.sql
+
+import org.scalatest.matchers.should.Matchers
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types._
+
+class FlintJobTest extends SparkFunSuite with Matchers {
+
+  val spark =
+    SparkSession.builder().appName("Test").master("local").getOrCreate()
+
+  // Define input dataframe
+  val inputSchema = StructType(
+    Seq(
+      StructField("Letter", StringType, nullable = false),
+      StructField("Number", IntegerType, nullable = false)))
+  val inputRows = Seq(Row("A", 1), Row("B", 2), Row("C", 3))
+  val input: DataFrame =
+    spark.createDataFrame(spark.sparkContext.parallelize(inputRows), inputSchema)
+
+  test("Test getFormattedData method") {
+    // Define expected dataframe
+    val dataSourceName = "myGlueS3"
+    val expectedSchema = StructType(
+      Seq(
+        StructField("result", ArrayType(StringType, containsNull = true), nullable = true),
+        StructField("schema", ArrayType(StringType, containsNull = true), nullable = true),
+        StructField("jobRunId", StringType, nullable = true),
+        StructField("applicationId", StringType, nullable = true),
+        StructField("dataSourceName", StringType, nullable = true),
+        StructField("status", StringType, nullable = true)
+      ))
+    val expectedRows = Seq(
+      Row(
+        Array(
+          "{'Letter':'A','Number':1}",
+          "{'Letter':'B','Number':2}",
+          "{'Letter':'C','Number':3}"),
+        Array(
+          "{'column_name':'Letter','data_type':'string'}",
+          "{'column_name':'Number','data_type':'integer'}"),
+        "unknown",
+        "unknown",
+        dataSourceName,
+        "SUCCESS"
+      ))
+    val expected: DataFrame =
+      spark.createDataFrame(spark.sparkContext.parallelize(expectedRows), expectedSchema)
+
+    // Compare the result
+    val result = FlintJob.getFormattedData(input, spark, dataSourceName)
+    assertEqualDataframe(expected, result)
+  }
+
+  def assertEqualDataframe(expected: DataFrame, result: DataFrame): Unit = {
+    assert(expected.schema === result.schema)
+    assert(expected.collect() === result.collect())
+  }
+
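+  // A minimal negative-case sketch for isSuperset (an illustrative addition, not required by
+  // the change itself): when the existing mapping lacks a field that the expected mapping
+  // requires (here "status"), isSuperset must return false.
+  test("test isSuperset returns false when a required field is missing") {
+    val input =
+      """{"dynamic":"false","properties":{"result":{"type":"text"}}}"""
+    val mapping =
+      """{"dynamic":false,"properties":{"result":{"type":"text"},"status":{"type":"keyword"}}}"""
+    assert(!FlintJob.isSuperset(input, mapping))
+  }
+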
+  test("test isSuperset") {
+    // Note that in the input, "false" is quoted (a JSON string), while the mapping has a bare false.
+    val input =
+      """{"dynamic":"false","properties":{"result":{"type":"text","fields":{"keyword":{
+        |"ignore_above":256,"type":"keyword"}}},"schema":{"type":"text","fields":{"keyword":{
+        |"ignore_above":256,"type":"keyword"}}},"applicationId":{"type":"keyword"},"jobRunId":{
+        |"type":"keyword"},"dataSourceName":{"type":"keyword"},"status":{"type":"keyword"}}}
+        |""".stripMargin
+    val mapping =
+      """{"dynamic":false,"properties":{"result":{"type":"text","fields":{"keyword":{
+        |"type":"keyword","ignore_above":256}}},"schema":{"type":"text","fields":{"keyword":{
+        |"type":"keyword","ignore_above":256}}},"jobRunId":{"type":"keyword"},"applicationId":{
+        |"type":"keyword"},"dataSourceName":{"type":"keyword"},"status":{"type":"keyword"}}}
+        |""".stripMargin
+    assert(FlintJob.isSuperset(input, mapping))
+  }
+}