Skip to content

Commit

Permalink
reformat code using sbt scalafmtAll
Browse files Browse the repository at this point in the history
Signed-off-by: Kaituo Li <[email protected]>
  • Loading branch information
kaituo committed Oct 10, 2023
1 parent 6e7c88a commit a6b9598
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 161 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ class FlintSparkSqlParser(sparkParser: ParserInterface) extends ParserInterface

override def parseQuery(sqlText: String): LogicalPlan = sparkParser.parseQuery(sqlText)


// Starting from here is copied and modified from Spark 3.3.1

protected def parse[T](sqlText: String)(toResult: FlintSparkSqlExtensionsParser => T): T = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser.{PropertyKey
import org.apache.spark.sql.catalyst.parser.ParserUtils.string

/**
* AST builder that builds for common rule in Spark SQL grammar. The main logic is modified slightly
* from Spark AstBuilder code.
* AST builder that builds for common rule in Spark SQL grammar. The main logic is modified
* slightly from Spark AstBuilder code.
*/
trait SparkSqlAstBuilder extends FlintSparkSqlExtensionsVisitor[AnyRef] {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ class FlintDataTypeSuite extends FlintSuite with Matchers {
StructField("varcharTextField", VarcharType(20), true, textMetadata) ::
StructField("charTextField", CharType(20), true, textMetadata) ::
Nil)
FlintDataType.serialize(sparkStructType) shouldBe compactJson(flintDataType)
FlintDataType.serialize(sparkStructType) shouldBe compactJson(flintDataType)
// flint data type should not deserialize to varchar or char
FlintDataType.deserialize(flintDataType) should contain theSameElementsAs StructType(
StructField("varcharTextField", StringType, true, textMetadata) ::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@ package org.apache.spark.sql

import java.util.Locale

import scala.concurrent.{ExecutionContext, Future, TimeoutException}
import scala.concurrent.duration.{Duration, MINUTES}

import org.opensearch.ExceptionsHelper
import org.opensearch.flint.core.{FlintClient, FlintClientBuilder}
import org.opensearch.flint.core.metadata.FlintMetadata
import play.api.libs.json._
import scala.concurrent.{ExecutionContext, Future, TimeoutException}
import scala.concurrent.duration.{Duration, MINUTES}

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
Expand All @@ -23,14 +24,14 @@ import org.apache.spark.util.ThreadUtils

/**
* Spark SQL Application entrypoint
*
* @param args
* (0) sql query
* @param args
* (1) opensearch index name
* @return
* write sql query result to given opensearch index
*/
*
* @param args
* (0) sql query
* @param args
* (1) opensearch index name
* @return
* write sql query result to given opensearch index
*/
object FlintJob extends Logging {
def main(args: Array[String]): Unit = {
// Validate command line arguments
Expand All @@ -48,7 +49,7 @@ object FlintJob extends Logging {
val threadPool = ThreadUtils.newDaemonFixedThreadPool(1, "check-create-index")
implicit val executionContext = ExecutionContext.fromExecutor(threadPool)

var dataToWrite : Option[DataFrame] = None
var dataToWrite: Option[DataFrame] = None
try {
// flintClient needs spark session to be created first. Otherwise, we will have connection
// exception from EMR-S to OS.
Expand All @@ -58,7 +59,8 @@ object FlintJob extends Logging {
}
val data = executeQuery(spark, query, dataSource)

val (correctMapping, error) = ThreadUtils.awaitResult(futureMappingCheck, Duration(1, MINUTES))
val (correctMapping, error) =
ThreadUtils.awaitResult(futureMappingCheck, Duration(1, MINUTES))
dataToWrite = Some(if (correctMapping) data else getFailedData(spark, dataSource, error))
} catch {
case e: TimeoutException =>
Expand Down Expand Up @@ -100,19 +102,15 @@ object FlintJob extends Logging {

/**
* Create a new formatted dataframe with json result, json schema and EMR_STEP_ID.
*
* @param result
* sql query result dataframe
* @param spark
* spark session
* @return
* dataframe with result, schema and emr step id
*/
def getFormattedData(
result: DataFrame,
spark: SparkSession,
dataSource: String
): DataFrame = {
*
* @param result
* sql query result dataframe
* @param spark
* spark session
* @return
* dataframe with result, schema and emr step id
*/
def getFormattedData(result: DataFrame, spark: SparkSession, dataSource: String): DataFrame = {
// Create the schema dataframe
val schemaRows = result.schema.fields.map { field =>
Row(field.name, field.dataType.typeName)
Expand All @@ -122,31 +120,18 @@ object FlintJob extends Logging {
StructType(
Seq(
StructField("column_name", StringType, nullable = false),
StructField("data_type", StringType, nullable = false)
)
)
)
StructField("data_type", StringType, nullable = false))))

// Define the data schema
val schema = StructType(
Seq(
StructField(
"result",
ArrayType(StringType, containsNull = true),
nullable = true
),
StructField(
"schema",
ArrayType(StringType, containsNull = true),
nullable = true
),
StructField("result", ArrayType(StringType, containsNull = true), nullable = true),
StructField("schema", ArrayType(StringType, containsNull = true), nullable = true),
StructField("jobRunId", StringType, nullable = true),
StructField("applicationId", StringType, nullable = true),
StructField("dataSourceName", StringType, nullable = true),
StructField("status", StringType, nullable = true),
StructField("error", StringType, nullable = true)
)
)
StructField("error", StringType, nullable = true)))

// Create the data rows
val rows = Seq(
Expand All @@ -158,9 +143,7 @@ object FlintJob extends Logging {
sys.env.getOrElse("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"),
dataSource,
"SUCCESS",
""
)
)
""))

// Create the DataFrame for data
spark.createDataFrame(rows).toDF(schema.fields.map(_.name): _*)
Expand All @@ -171,23 +154,13 @@ object FlintJob extends Logging {
// Define the data schema
val schema = StructType(
Seq(
StructField(
"result",
ArrayType(StringType, containsNull = true),
nullable = true
),
StructField(
"schema",
ArrayType(StringType, containsNull = true),
nullable = true
),
StructField("result", ArrayType(StringType, containsNull = true), nullable = true),
StructField("schema", ArrayType(StringType, containsNull = true), nullable = true),
StructField("jobRunId", StringType, nullable = true),
StructField("applicationId", StringType, nullable = true),
StructField("dataSourceName", StringType, nullable = true),
StructField("status", StringType, nullable = true),
StructField("error", StringType, nullable = true)
)
)
StructField("error", StringType, nullable = true)))

// Create the data rows
val rows = Seq(
Expand All @@ -198,9 +171,7 @@ object FlintJob extends Logging {
sys.env.getOrElse("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"),
dataSource,
"FAILED",
error
)
)
error))

// Create the DataFrame for data
spark.createDataFrame(rows).toDF(schema.fields.map(_.name): _*)
Expand All @@ -210,22 +181,23 @@ object FlintJob extends Logging {

/**
* Determines whether one JSON structure is a superset of another.
*
* This method checks if the `input` JSON structure contains all the fields
* and values present in the `mapping` JSON structure. The comparison is
* recursive and structure-sensitive, ensuring that nested objects and arrays
* are also compared accurately.
*
* Additionally, this method accommodates the edge case where boolean values
* in the JSON are represented as strings (e.g., "true" or "false" instead of
* true or false). This is handled by performing a case-insensitive comparison
* of string representations of boolean values.
*
* @param input The input JSON structure as a String.
* @param mapping The mapping JSON structure as a String.
* @return A Boolean value indicating whether the `input` JSON structure
* is a superset of the `mapping` JSON structure.
*/
*
* This method checks if the `input` JSON structure contains all the fields and values present
* in the `mapping` JSON structure. The comparison is recursive and structure-sensitive,
* ensuring that nested objects and arrays are also compared accurately.
*
* Additionally, this method accommodates the edge case where boolean values in the JSON are
* represented as strings (e.g., "true" or "false" instead of true or false). This is handled
* by performing a case-insensitive comparison of string representations of boolean values.
*
* @param input
* The input JSON structure as a String.
* @param mapping
* The mapping JSON structure as a String.
* @return
* A Boolean value indicating whether the `input` JSON structure is a superset of the
* `mapping` JSON structure.
*/
def compareJson(inputJson: JsValue, mappingJson: JsValue): Boolean = {
(inputJson, mappingJson) match {
case (JsObject(inputFields), JsObject(mappingFields)) =>
Expand All @@ -238,18 +210,13 @@ object FlintJob extends Logging {
case (JsArray(inputValues), JsArray(mappingValues)) =>
logInfo(s"Comparing arrays: $inputValues vs $mappingValues")
mappingValues.forall(mappingValue =>
inputValues.exists(inputValue =>
compareJson(inputValue, mappingValue)
)
)
inputValues.exists(inputValue => compareJson(inputValue, mappingValue)))
case (JsString(inputValue), JsString(mappingValue))
if (inputValue.toLowerCase(Locale.ROOT) == "true" ||
inputValue.toLowerCase(Locale.ROOT) == "false") &&
(mappingValue.toLowerCase(Locale.ROOT) == "true" ||
mappingValue.toLowerCase(Locale.ROOT) == "false") =>
inputValue.toLowerCase(Locale.ROOT) == mappingValue.toLowerCase(
Locale.ROOT
)
inputValue.toLowerCase(Locale.ROOT) == mappingValue.toLowerCase(Locale.ROOT)
case (JsBoolean(inputValue), JsString(mappingValue))
if mappingValue.toLowerCase(Locale.ROOT) == "true" ||
mappingValue.toLowerCase(Locale.ROOT) == "false" =>
Expand All @@ -271,10 +238,7 @@ object FlintJob extends Logging {
compareJson(inputJson, mappingJson)
}

def checkAndCreateIndex(
flintClient: FlintClient,
resultIndex: String
): (Boolean, String) = {
def checkAndCreateIndex(flintClient: FlintClient, resultIndex: String): (Boolean, String) = {
// The enabled setting, which can be applied only to the top-level mapping definition and to object fields,
val mapping =
"""{
Expand Down Expand Up @@ -314,7 +278,8 @@ object FlintJob extends Logging {
(true, "")
}
} catch {
case e: IllegalStateException if e.getCause().getMessage().contains("index_not_found_exception") =>
case e: IllegalStateException
if e.getCause().getMessage().contains("index_not_found_exception") =>
handleIndexNotFoundException(flintClient, resultIndex, mapping)
case e: Exception =>
val error = "Failed to verify existing mapping"
Expand All @@ -324,10 +289,9 @@ object FlintJob extends Logging {
}

def handleIndexNotFoundException(
flintClient: FlintClient,
resultIndex: String,
mapping: String
): (Boolean, String) = {
flintClient: FlintClient,
resultIndex: String,
mapping: String): (Boolean, String) = {
try {
logInfo(s"create $resultIndex")
flintClient.createIndex(resultIndex, new FlintMetadata(mapping))
Expand All @@ -340,11 +304,7 @@ object FlintJob extends Logging {
(false, error)
}
}
def executeQuery(
spark: SparkSession,
query: String,
dataSource: String
): DataFrame = {
def executeQuery(spark: SparkSession, query: String, dataSource: String): DataFrame = {
// Execute SQL query
val result: DataFrame = spark.sql(query)
// Get Data
Expand Down
Loading

0 comments on commit a6b9598

Please sign in to comment.