From 879a5415a9180c5e1f8ce324813c70030ca8a8c2 Mon Sep 17 00:00:00 2001
From: Rupal Mahajan
Date: Fri, 8 Sep 2023 10:03:47 -0700
Subject: [PATCH] Add spark sql app to snapshot workflow (#19)

Signed-off-by: Rupal Mahajan
---
 .github/workflows/snapshot-publish.yml        |   4 +-
 build.sbt                                     |  19 ++-
 spark-sql-application/README.md               | 109 +++++++++++++++++
 .../scala/org/opensearch/sql/SQLJob.scala     | 112 ++++++++++++++++++
 .../scala/org/opensearch/sql/SQLJobTest.scala |  63 ++++++++++
 5 files changed, 305 insertions(+), 2 deletions(-)
 create mode 100644 spark-sql-application/README.md
 create mode 100644 spark-sql-application/src/main/scala/org/opensearch/sql/SQLJob.scala
 create mode 100644 spark-sql-application/src/test/scala/org/opensearch/sql/SQLJobTest.scala

diff --git a/.github/workflows/snapshot-publish.yml b/.github/workflows/snapshot-publish.yml
index 1e3367155..7dae546b3 100644
--- a/.github/workflows/snapshot-publish.yml
+++ b/.github/workflows/snapshot-publish.yml
@@ -27,7 +27,9 @@ jobs:
           java-version: 11
 
       - name: Publish to Local Maven
-        run: sbt standaloneCosmetic/publishM2
+        run: |
+          sbt standaloneCosmetic/publishM2
+          sbt sparkSqlApplicationCosmetic/publishM2
 
       - uses: actions/checkout@v3
         with:
diff --git a/build.sbt b/build.sbt
index 7790104f2..a2c8d050a 100644
--- a/build.sbt
+++ b/build.sbt
@@ -43,7 +43,7 @@ lazy val commonSettings = Seq(
   Test / test := ((Test / test) dependsOn testScalastyle).value)
 
 lazy val root = (project in file("."))
-  .aggregate(flintCore, flintSparkIntegration)
+  .aggregate(flintCore, flintSparkIntegration, sparkSqlApplication)
   .disablePlugins(AssemblyPlugin)
   .settings(name := "flint", publish / skip := true)
 
@@ -125,6 +125,23 @@ lazy val standaloneCosmetic = project
     exportJars := true,
     Compile / packageBin := (flintSparkIntegration / assembly).value)
 
+lazy val sparkSqlApplication = (project in file("spark-sql-application"))
+  .settings(
+    commonSettings,
+    name := "sql-job",
+    scalaVersion := scala212,
+    libraryDependencies ++= Seq(
+      "org.scalatest" %% "scalatest" % "3.2.15" % "test"),
+    libraryDependencies ++= deps(sparkVersion))
+
+lazy val sparkSqlApplicationCosmetic = project
+  .settings(
+    name := "opensearch-spark-sql-application",
+    commonSettings,
+    releaseSettings,
+    exportJars := true,
+    Compile / packageBin := (sparkSqlApplication / assembly).value)
+
 lazy val releaseSettings = Seq(
   publishMavenStyle := true,
   publishArtifact := true,
diff --git a/spark-sql-application/README.md b/spark-sql-application/README.md
new file mode 100644
index 000000000..07bf46406
--- /dev/null
+++ b/spark-sql-application/README.md
@@ -0,0 +1,109 @@
# Spark SQL Application

This application executes a SQL query and stores the result in an OpenSearch index in the following format:
```
"stepId":"<emr-step-id>",
"applicationId":"<spark-application-id>",
"schema": "json blob",
"result": "json blob"
```

## Prerequisites

+ Spark 3.3.1
+ Scala 2.12.15
+ flint-spark-integration

## Usage

To use this application, you can run Spark with the Flint extension:

```
./bin/spark-submit \
    --class org.opensearch.sql.SQLJob \
    --jars <flint-spark-integration-jar> \
    sql-job.jar \
    <spark-sql-query> \
    <opensearch-index> \
    <opensearch-host> \
    <opensearch-port> \
    <opensearch-scheme> \
    <opensearch-auth> \
    <opensearch-region>
```
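
For illustration only, a submission could look like the sketch below. The jar names, endpoint, and the `auth`/`region` values are placeholders, not values prescribed by this project; use whatever your Flint and OpenSearch deployment expects. The seven positional arguments map, in order, to the query, index, host, port, scheme, auth, and region read by `SQLJob.main`.

```
# Illustrative values only; adjust the jar paths and connection settings for your environment.
# The auth ("noauth") and region values here are examples; pass whatever your Flint data source accepts.
./bin/spark-submit \
    --class org.opensearch.sql.SQLJob \
    --jars flint-spark-integration-assembly.jar \
    sql-job.jar \
    "SELECT 1 AS number" \
    .query_execution_result \
    localhost \
    9200 \
    http \
    noauth \
    us-west-2
```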

## Result Specifications

The following example shows how the result is written to the OpenSearch index after query execution.

Let's assume the SQL query result is:
```
+------+------+
|Letter|Number|
+------+------+
|A     |1     |
|B     |2     |
|C     |3     |
+------+------+
```
The OpenSearch index document will then look like:
```json
{
  "_index" : ".query_execution_result",
  "_id" : "A2WOsYgBMUoqCqlDJHrn",
  "_score" : 1.0,
  "_source" : {
    "result" : [
      "{'Letter':'A','Number':1}",
      "{'Letter':'B','Number':2}",
      "{'Letter':'C','Number':3}"
    ],
    "schema" : [
      "{'column_name':'Letter','data_type':'string'}",
      "{'column_name':'Number','data_type':'integer'}"
    ],
    "stepId" : "s-JZSB1139WIVU",
    "applicationId" : "application_1687726870985_0003"
  }
}
```
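
To inspect a stored result, you can search the index directly; the host and port below are placeholders for your own OpenSearch endpoint.

```
# Placeholder endpoint; point this at your OpenSearch cluster.
curl -XGET "http://localhost:9200/.query_execution_result/_search?pretty"
```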

## Build

To build this application and publish it to the local Maven repository, you can run:

```
sbt clean sparkSqlApplicationCosmetic/publishM2
```

## Test

To run tests, you can use:

```
sbt test
```

## Scalastyle

To check the code with Scalastyle, you can run:

```
sbt scalastyle
```

## Code of Conduct

This project has adopted an [Open Source Code of Conduct](../CODE_OF_CONDUCT.md).

## Security

If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.

## License

See the [LICENSE](../LICENSE.txt) file for our project's licensing. We will ask you to confirm the licensing of your contribution.

## Copyright

Copyright OpenSearch Contributors. See [NOTICE](../NOTICE) for details.
\ No newline at end of file
diff --git a/spark-sql-application/src/main/scala/org/opensearch/sql/SQLJob.scala b/spark-sql-application/src/main/scala/org/opensearch/sql/SQLJob.scala
new file mode 100644
index 000000000..9e1d36857
--- /dev/null
+++ b/spark-sql-application/src/main/scala/org/opensearch/sql/SQLJob.scala
@@ -0,0 +1,112 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.types._

/**
 * Spark SQL Application entrypoint
 *
 * @param args(0)
 *   SQL query
 * @param args(1)
 *   OpenSearch index name
 * @param args(2-6)
 *   OpenSearch connection values required by the flint-spark-integration jar:
 *   host, port, scheme, auth, and region, respectively.
 * @return
 *   writes the SQL query result to the given OpenSearch index
 */
object SQLJob {
  def main(args: Array[String]) {
    // Get the SQL query and OpenSearch config from the command line arguments
    val query = args(0)
    val index = args(1)
    val host = args(2)
    val port = args(3)
    val scheme = args(4)
    val auth = args(5)
    val region = args(6)

    val conf: SparkConf = new SparkConf()
      .setAppName("SQLJob")
      .set("spark.sql.extensions", "org.opensearch.flint.spark.FlintSparkExtensions")
      .set("spark.datasource.flint.host", host)
      .set("spark.datasource.flint.port", port)
      .set("spark.datasource.flint.scheme", scheme)
      .set("spark.datasource.flint.auth", auth)
      .set("spark.datasource.flint.region", region)

    // Create a SparkSession
    val spark = SparkSession.builder().config(conf).enableHiveSupport().getOrCreate()

    try {
      // Execute SQL query
      val result: DataFrame = spark.sql(query)

      // Get Data
      val data = getFormattedData(result, spark)

      // Write data to OpenSearch index
      val aos = Map(
        "host" -> host,
        "port" -> port,
        "scheme" -> scheme,
        "auth" -> auth,
        "region" -> region)

      data.write
        .format("flint")
        .options(aos)
        .mode("append")
        .save(index)

    } finally {
      // Stop SparkSession
      spark.stop()
    }
  }

  /**
   * Create a new formatted dataframe with json result, json schema and EMR_STEP_ID.
   *
   * @param result
   *   sql query result dataframe
   * @param spark
   *   spark session
   * @return
   *   dataframe with result, schema and emr step id
   */
  def getFormattedData(result: DataFrame, spark: SparkSession): DataFrame = {
    // Create the schema dataframe
    val schemaRows = result.schema.fields.map { field =>
      Row(field.name, field.dataType.typeName)
    }
    val resultSchema = spark.createDataFrame(spark.sparkContext.parallelize(schemaRows),
      StructType(Seq(
        StructField("column_name", StringType, nullable = false),
        StructField("data_type", StringType, nullable = false))))

    // Define the data schema
    val schema = StructType(Seq(
      StructField("result", ArrayType(StringType, containsNull = true), nullable = true),
      StructField("schema", ArrayType(StringType, containsNull = true), nullable = true),
      StructField("stepId", StringType, nullable = true),
      StructField("applicationId", StringType, nullable = true)))

    // Create the data rows
    val rows = Seq((
      result.toJSON.collect.toList.map(_.replaceAll("'", "\\\\'").replaceAll("\"", "'")),
      resultSchema.toJSON.collect.toList.map(_.replaceAll("\"", "'")),
      sys.env.getOrElse("EMR_STEP_ID", "unknown"),
      spark.sparkContext.applicationId))

    // Create the DataFrame for data
    spark.createDataFrame(rows).toDF(schema.fields.map(_.name): _*)
  }
}
diff --git a/spark-sql-application/src/test/scala/org/opensearch/sql/SQLJobTest.scala b/spark-sql-application/src/test/scala/org/opensearch/sql/SQLJobTest.scala
new file mode 100644
index 000000000..f98608c80
--- /dev/null
+++ b/spark-sql-application/src/test/scala/org/opensearch/sql/SQLJobTest.scala
@@ -0,0 +1,63 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.sql

import org.scalatest.matchers.should.Matchers

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructField, StructType}


class SQLJobTest extends SparkFunSuite with Matchers {

  val spark = SparkSession.builder().appName("Test").master("local").getOrCreate()

  // Define input dataframe
  val inputSchema = StructType(Seq(
    StructField("Letter", StringType, nullable = false),
    StructField("Number", IntegerType, nullable = false)
  ))
  val inputRows = Seq(
    Row("A", 1),
    Row("B", 2),
    Row("C", 3)
  )
  val input: DataFrame = spark.createDataFrame(
    spark.sparkContext.parallelize(inputRows), inputSchema)

  test("Test getFormattedData method") {
    // Define expected dataframe
    val expectedSchema = StructType(Seq(
      StructField("result", ArrayType(StringType, containsNull = true), nullable = true),
      StructField("schema", ArrayType(StringType, containsNull = true), nullable = true),
      StructField("stepId", StringType, nullable = true),
      StructField("applicationId", StringType, nullable = true)
    ))
    val expectedRows = Seq(
      Row(
        Array("{'Letter':'A','Number':1}",
          "{'Letter':'B','Number':2}",
          "{'Letter':'C','Number':3}"),
        Array("{'column_name':'Letter','data_type':'string'}",
          "{'column_name':'Number','data_type':'integer'}"),
        "unknown",
        spark.sparkContext.applicationId
      )
    )
    val expected: DataFrame = spark.createDataFrame(
      spark.sparkContext.parallelize(expectedRows), expectedSchema)

    // Compare the result
    val result = SQLJob.getFormattedData(input, spark)
    assertEqualDataframe(expected, result)
  }

  def assertEqualDataframe(expected: DataFrame, result: DataFrame): Unit = {
    assert(expected.schema === result.schema)
    assert(expected.collect() === result.collect())
  }
}