diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParser.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParser.scala index 0fa146b9d..78a9c0628 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParser.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParser.scala @@ -79,7 +79,6 @@ class FlintSparkSqlParser(sparkParser: ParserInterface) extends ParserInterface override def parseQuery(sqlText: String): LogicalPlan = sparkParser.parseQuery(sqlText) - // Starting from here is copied and modified from Spark 3.3.1 protected def parse[T](sqlText: String)(toResult: FlintSparkSqlExtensionsParser => T): T = { diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/SparkSqlAstBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/SparkSqlAstBuilder.scala index 4dadd4d5e..6ba432ed7 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/SparkSqlAstBuilder.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/SparkSqlAstBuilder.scala @@ -15,8 +15,8 @@ import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser.{PropertyKey import org.apache.spark.sql.catalyst.parser.ParserUtils.string /** - * AST builder that builds for common rule in Spark SQL grammar. The main logic is modified slightly - * from Spark AstBuilder code. + * AST builder that builds for common rule in Spark SQL grammar. The main logic is modified + * slightly from Spark AstBuilder code. */ trait SparkSqlAstBuilder extends FlintSparkSqlExtensionsVisitor[AnyRef] { diff --git a/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/datatype/FlintDataTypeSuite.scala b/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/datatype/FlintDataTypeSuite.scala index b2b6adf81..eb3c2a371 100644 --- a/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/datatype/FlintDataTypeSuite.scala +++ b/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/datatype/FlintDataTypeSuite.scala @@ -163,7 +163,7 @@ class FlintDataTypeSuite extends FlintSuite with Matchers { StructField("varcharTextField", VarcharType(20), true, textMetadata) :: StructField("charTextField", CharType(20), true, textMetadata) :: Nil) - FlintDataType.serialize(sparkStructType) shouldBe compactJson(flintDataType) + FlintDataType.serialize(sparkStructType) shouldBe compactJson(flintDataType) // flint data type should not deserialize to varchar or char FlintDataType.deserialize(flintDataType) should contain theSameElementsAs StructType( StructField("varcharTextField", StringType, true, textMetadata) :: diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala index dece8ae5f..d12d03565 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala @@ -8,12 +8,13 @@ package org.apache.spark.sql import java.util.Locale +import scala.concurrent.{ExecutionContext, Future, TimeoutException} +import scala.concurrent.duration.{Duration, MINUTES} + import org.opensearch.ExceptionsHelper import org.opensearch.flint.core.{FlintClient, FlintClientBuilder} import org.opensearch.flint.core.metadata.FlintMetadata import play.api.libs.json._ -import scala.concurrent.{ExecutionContext, Future, TimeoutException} -import scala.concurrent.duration.{Duration, MINUTES} import org.apache.spark.SparkConf import org.apache.spark.internal.Logging @@ -23,14 +24,14 @@ import org.apache.spark.util.ThreadUtils /** * Spark SQL Application entrypoint - * - * @param args - * (0) sql query - * @param args - * (1) opensearch index name - * @return - * write sql query result to given opensearch index - */ + * + * @param args + * (0) sql query + * @param args + * (1) opensearch index name + * @return + * write sql query result to given opensearch index + */ object FlintJob extends Logging { def main(args: Array[String]): Unit = { // Validate command line arguments @@ -48,7 +49,7 @@ object FlintJob extends Logging { val threadPool = ThreadUtils.newDaemonFixedThreadPool(1, "check-create-index") implicit val executionContext = ExecutionContext.fromExecutor(threadPool) - var dataToWrite : Option[DataFrame] = None + var dataToWrite: Option[DataFrame] = None try { // flintClient needs spark session to be created first. Otherwise, we will have connection // exception from EMR-S to OS. @@ -58,7 +59,8 @@ object FlintJob extends Logging { } val data = executeQuery(spark, query, dataSource) - val (correctMapping, error) = ThreadUtils.awaitResult(futureMappingCheck, Duration(1, MINUTES)) + val (correctMapping, error) = + ThreadUtils.awaitResult(futureMappingCheck, Duration(1, MINUTES)) dataToWrite = Some(if (correctMapping) data else getFailedData(spark, dataSource, error)) } catch { case e: TimeoutException => @@ -100,19 +102,15 @@ object FlintJob extends Logging { /** * Create a new formatted dataframe with json result, json schema and EMR_STEP_ID. - * - * @param result - * sql query result dataframe - * @param spark - * spark session - * @return - * dataframe with result, schema and emr step id - */ - def getFormattedData( - result: DataFrame, - spark: SparkSession, - dataSource: String - ): DataFrame = { + * + * @param result + * sql query result dataframe + * @param spark + * spark session + * @return + * dataframe with result, schema and emr step id + */ + def getFormattedData(result: DataFrame, spark: SparkSession, dataSource: String): DataFrame = { // Create the schema dataframe val schemaRows = result.schema.fields.map { field => Row(field.name, field.dataType.typeName) @@ -122,31 +120,18 @@ object FlintJob extends Logging { StructType( Seq( StructField("column_name", StringType, nullable = false), - StructField("data_type", StringType, nullable = false) - ) - ) - ) + StructField("data_type", StringType, nullable = false)))) // Define the data schema val schema = StructType( Seq( - StructField( - "result", - ArrayType(StringType, containsNull = true), - nullable = true - ), - StructField( - "schema", - ArrayType(StringType, containsNull = true), - nullable = true - ), + StructField("result", ArrayType(StringType, containsNull = true), nullable = true), + StructField("schema", ArrayType(StringType, containsNull = true), nullable = true), StructField("jobRunId", StringType, nullable = true), StructField("applicationId", StringType, nullable = true), StructField("dataSourceName", StringType, nullable = true), StructField("status", StringType, nullable = true), - StructField("error", StringType, nullable = true) - ) - ) + StructField("error", StringType, nullable = true))) // Create the data rows val rows = Seq( @@ -158,9 +143,7 @@ object FlintJob extends Logging { sys.env.getOrElse("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"), dataSource, "SUCCESS", - "" - ) - ) + "")) // Create the DataFrame for data spark.createDataFrame(rows).toDF(schema.fields.map(_.name): _*) @@ -171,23 +154,13 @@ object FlintJob extends Logging { // Define the data schema val schema = StructType( Seq( - StructField( - "result", - ArrayType(StringType, containsNull = true), - nullable = true - ), - StructField( - "schema", - ArrayType(StringType, containsNull = true), - nullable = true - ), + StructField("result", ArrayType(StringType, containsNull = true), nullable = true), + StructField("schema", ArrayType(StringType, containsNull = true), nullable = true), StructField("jobRunId", StringType, nullable = true), StructField("applicationId", StringType, nullable = true), StructField("dataSourceName", StringType, nullable = true), StructField("status", StringType, nullable = true), - StructField("error", StringType, nullable = true) - ) - ) + StructField("error", StringType, nullable = true))) // Create the data rows val rows = Seq( @@ -198,9 +171,7 @@ object FlintJob extends Logging { sys.env.getOrElse("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"), dataSource, "FAILED", - error - ) - ) + error)) // Create the DataFrame for data spark.createDataFrame(rows).toDF(schema.fields.map(_.name): _*) @@ -210,22 +181,23 @@ object FlintJob extends Logging { /** * Determines whether one JSON structure is a superset of another. - * - * This method checks if the `input` JSON structure contains all the fields - * and values present in the `mapping` JSON structure. The comparison is - * recursive and structure-sensitive, ensuring that nested objects and arrays - * are also compared accurately. - * - * Additionally, this method accommodates the edge case where boolean values - * in the JSON are represented as strings (e.g., "true" or "false" instead of - * true or false). This is handled by performing a case-insensitive comparison - * of string representations of boolean values. - * - * @param input The input JSON structure as a String. - * @param mapping The mapping JSON structure as a String. - * @return A Boolean value indicating whether the `input` JSON structure - * is a superset of the `mapping` JSON structure. - */ + * + * This method checks if the `input` JSON structure contains all the fields and values present + * in the `mapping` JSON structure. The comparison is recursive and structure-sensitive, + * ensuring that nested objects and arrays are also compared accurately. + * + * Additionally, this method accommodates the edge case where boolean values in the JSON are + * represented as strings (e.g., "true" or "false" instead of true or false). This is handled + * by performing a case-insensitive comparison of string representations of boolean values. + * + * @param input + * The input JSON structure as a String. + * @param mapping + * The mapping JSON structure as a String. + * @return + * A Boolean value indicating whether the `input` JSON structure is a superset of the + * `mapping` JSON structure. + */ def compareJson(inputJson: JsValue, mappingJson: JsValue): Boolean = { (inputJson, mappingJson) match { case (JsObject(inputFields), JsObject(mappingFields)) => @@ -238,18 +210,13 @@ object FlintJob extends Logging { case (JsArray(inputValues), JsArray(mappingValues)) => logInfo(s"Comparing arrays: $inputValues vs $mappingValues") mappingValues.forall(mappingValue => - inputValues.exists(inputValue => - compareJson(inputValue, mappingValue) - ) - ) + inputValues.exists(inputValue => compareJson(inputValue, mappingValue))) case (JsString(inputValue), JsString(mappingValue)) if (inputValue.toLowerCase(Locale.ROOT) == "true" || inputValue.toLowerCase(Locale.ROOT) == "false") && (mappingValue.toLowerCase(Locale.ROOT) == "true" || mappingValue.toLowerCase(Locale.ROOT) == "false") => - inputValue.toLowerCase(Locale.ROOT) == mappingValue.toLowerCase( - Locale.ROOT - ) + inputValue.toLowerCase(Locale.ROOT) == mappingValue.toLowerCase(Locale.ROOT) case (JsBoolean(inputValue), JsString(mappingValue)) if mappingValue.toLowerCase(Locale.ROOT) == "true" || mappingValue.toLowerCase(Locale.ROOT) == "false" => @@ -271,10 +238,7 @@ object FlintJob extends Logging { compareJson(inputJson, mappingJson) } - def checkAndCreateIndex( - flintClient: FlintClient, - resultIndex: String - ): (Boolean, String) = { + def checkAndCreateIndex(flintClient: FlintClient, resultIndex: String): (Boolean, String) = { // The enabled setting, which can be applied only to the top-level mapping definition and to object fields, val mapping = """{ @@ -314,7 +278,8 @@ object FlintJob extends Logging { (true, "") } } catch { - case e: IllegalStateException if e.getCause().getMessage().contains("index_not_found_exception") => + case e: IllegalStateException + if e.getCause().getMessage().contains("index_not_found_exception") => handleIndexNotFoundException(flintClient, resultIndex, mapping) case e: Exception => val error = "Failed to verify existing mapping" @@ -324,10 +289,9 @@ object FlintJob extends Logging { } def handleIndexNotFoundException( - flintClient: FlintClient, - resultIndex: String, - mapping: String - ): (Boolean, String) = { + flintClient: FlintClient, + resultIndex: String, + mapping: String): (Boolean, String) = { try { logInfo(s"create $resultIndex") flintClient.createIndex(resultIndex, new FlintMetadata(mapping)) @@ -340,11 +304,7 @@ object FlintJob extends Logging { (false, error) } } - def executeQuery( - spark: SparkSession, - query: String, - dataSource: String - ): DataFrame = { + def executeQuery(spark: SparkSession, query: String, dataSource: String): DataFrame = { // Execute SQL query val result: DataFrame = spark.sql(query) // Get Data diff --git a/spark-sql-application/src/main/scala/org/opensearch/sql/SQLJob.scala b/spark-sql-application/src/main/scala/org/opensearch/sql/SQLJob.scala index ee0fa4513..43544c78d 100644 --- a/spark-sql-application/src/main/scala/org/opensearch/sql/SQLJob.scala +++ b/spark-sql-application/src/main/scala/org/opensearch/sql/SQLJob.scala @@ -12,15 +12,15 @@ import org.apache.spark.sql.types._ /** * Spark SQL Application entrypoint * - * @param args (0) - * sql query - * @param args (1) - * opensearch index name - * @param args (2-6) - * opensearch connection values required for flint-integration jar. - * host, port, scheme, auth, region respectively. + * @param args + * (0) sql query + * @param args + * (1) opensearch index name + * @param args + * (2-6) opensearch connection values required for flint-integration jar. host, port, scheme, + * auth, region respectively. * @return - * write sql query result to given opensearch index + * write sql query result to given opensearch index */ case class JobConfig( query: String, @@ -29,13 +29,13 @@ case class JobConfig( port: String, scheme: String, auth: String, - region: String - ) + region: String) object SQLJob { private def parseArgs(args: Array[String]): JobConfig = { if (args.length < 7) { - throw new IllegalArgumentException("Insufficient arguments provided! - args: [extensions, query, index, host, port, scheme, auth, region]") + throw new IllegalArgumentException( + "Insufficient arguments provided! - args: [extensions, query, index, host, port, scheme, auth, region]") } JobConfig( @@ -45,8 +45,7 @@ object SQLJob { port = args(3), scheme = args(4), auth = args(5), - region = args(6) - ) + region = args(6)) } def createSparkConf(config: JobConfig): SparkConf = { @@ -97,35 +96,39 @@ object SQLJob { * Create a new formatted dataframe with json result, json schema and EMR_STEP_ID. * * @param result - * sql query result dataframe + * sql query result dataframe * @param spark - * spark session + * spark session * @return - * dataframe with result, schema and emr step id + * dataframe with result, schema and emr step id */ def getFormattedData(result: DataFrame, spark: SparkSession): DataFrame = { // Create the schema dataframe val schemaRows = result.schema.fields.map { field => Row(field.name, field.dataType.typeName) } - val resultSchema = spark.createDataFrame(spark.sparkContext.parallelize(schemaRows), - StructType(Seq( - StructField("column_name", StringType, nullable = false), - StructField("data_type", StringType, nullable = false)))) + val resultSchema = spark.createDataFrame( + spark.sparkContext.parallelize(schemaRows), + StructType( + Seq( + StructField("column_name", StringType, nullable = false), + StructField("data_type", StringType, nullable = false)))) // Define the data schema - val schema = StructType(Seq( - StructField("result", ArrayType(StringType, containsNull = true), nullable = true), - StructField("schema", ArrayType(StringType, containsNull = true), nullable = true), - StructField("stepId", StringType, nullable = true), - StructField("applicationId", StringType, nullable = true))) + val schema = StructType( + Seq( + StructField("result", ArrayType(StringType, containsNull = true), nullable = true), + StructField("schema", ArrayType(StringType, containsNull = true), nullable = true), + StructField("stepId", StringType, nullable = true), + StructField("applicationId", StringType, nullable = true))) // Create the data rows - val rows = Seq(( - result.toJSON.collect.toList.map(_.replaceAll("'", "\\\\'").replaceAll("\"", "'")), - resultSchema.toJSON.collect.toList.map(_.replaceAll("\"", "'")), - sys.env.getOrElse("EMR_STEP_ID", "unknown"), - spark.sparkContext.applicationId)) + val rows = Seq( + ( + result.toJSON.collect.toList.map(_.replaceAll("'", "\\\\'").replaceAll("\"", "'")), + resultSchema.toJSON.collect.toList.map(_.replaceAll("\"", "'")), + sys.env.getOrElse("EMR_STEP_ID", "unknown"), + spark.sparkContext.applicationId)) // Create the DataFrame for data spark.createDataFrame(rows).toDF(schema.fields.map(_.name): _*) diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala index c32e63194..b891be0e1 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala @@ -35,8 +35,7 @@ class FlintJobTest extends SparkFunSuite with Matchers { StructField("applicationId", StringType, nullable = true), StructField("dataSourceName", StringType, nullable = true), StructField("status", StringType, nullable = true), - StructField("error", StringType, nullable = true) - )) + StructField("error", StringType, nullable = true))) val expectedRows = Seq( Row( Array( @@ -50,8 +49,7 @@ class FlintJobTest extends SparkFunSuite with Matchers { "unknown", dataSourceName, "SUCCESS", - "" - )) + "")) val expected: DataFrame = spark.createDataFrame(spark.sparkContext.parallelize(expectedRows), expectedSchema) diff --git a/spark-sql-application/src/test/scala/org/opensearch/sql/SQLJobTest.scala b/spark-sql-application/src/test/scala/org/opensearch/sql/SQLJobTest.scala index f98608c80..063c76c4d 100644 --- a/spark-sql-application/src/test/scala/org/opensearch/sql/SQLJobTest.scala +++ b/spark-sql-application/src/test/scala/org/opensearch/sql/SQLJobTest.scala @@ -11,45 +11,40 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructField, StructType} - class SQLJobTest extends SparkFunSuite with Matchers { val spark = SparkSession.builder().appName("Test").master("local").getOrCreate() // Define input dataframe - val inputSchema = StructType(Seq( - StructField("Letter", StringType, nullable = false), - StructField("Number", IntegerType, nullable = false) - )) - val inputRows = Seq( - Row("A", 1), - Row("B", 2), - Row("C", 3) - ) - val input: DataFrame = spark.createDataFrame( - spark.sparkContext.parallelize(inputRows), inputSchema) + val inputSchema = StructType( + Seq( + StructField("Letter", StringType, nullable = false), + StructField("Number", IntegerType, nullable = false))) + val inputRows = Seq(Row("A", 1), Row("B", 2), Row("C", 3)) + val input: DataFrame = + spark.createDataFrame(spark.sparkContext.parallelize(inputRows), inputSchema) test("Test getFormattedData method") { // Define expected dataframe - val expectedSchema = StructType(Seq( - StructField("result", ArrayType(StringType, containsNull = true), nullable = true), - StructField("schema", ArrayType(StringType, containsNull = true), nullable = true), - StructField("stepId", StringType, nullable = true), - StructField("applicationId", StringType, nullable = true) - )) + val expectedSchema = StructType( + Seq( + StructField("result", ArrayType(StringType, containsNull = true), nullable = true), + StructField("schema", ArrayType(StringType, containsNull = true), nullable = true), + StructField("stepId", StringType, nullable = true), + StructField("applicationId", StringType, nullable = true))) val expectedRows = Seq( Row( - Array("{'Letter':'A','Number':1}", + Array( + "{'Letter':'A','Number':1}", "{'Letter':'B','Number':2}", "{'Letter':'C','Number':3}"), - Array("{'column_name':'Letter','data_type':'string'}", + Array( + "{'column_name':'Letter','data_type':'string'}", "{'column_name':'Number','data_type':'integer'}"), "unknown", - spark.sparkContext.applicationId - ) - ) - val expected: DataFrame = spark.createDataFrame( - spark.sparkContext.parallelize(expectedRows), expectedSchema) + spark.sparkContext.applicationId)) + val expected: DataFrame = + spark.createDataFrame(spark.sparkContext.parallelize(expectedRows), expectedSchema) // Compare the result val result = SQLJob.getFormattedData(input, spark)