From 026b36257c79d1d028b7e3c9f6665e73db4925bd Mon Sep 17 00:00:00 2001
From: Sean Kao
Date: Wed, 10 Apr 2024 14:02:22 -0700
Subject: [PATCH] Unescape query from EMR spark-submit parameter

Signed-off-by: Sean Kao
---
 .../src/main/scala/org/apache/spark/sql/FlintJob.scala |  3 ++-
 .../scala/org/apache/spark/sql/FlintJobExecutor.scala  | 10 ++++++++++
 .../main/scala/org/apache/spark/sql/FlintREPL.scala    |  2 +-
 .../scala/org/apache/spark/sql/FlintREPLTest.scala     | 10 ++++++++++
 4 files changed, 23 insertions(+), 2 deletions(-)

diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala
index 8b4bdeeaf..4f17b4c9a 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala
@@ -56,7 +56,8 @@ object FlintJob extends Logging with FlintJobExecutor {
     conf.set(FlintSparkConf.JOB_TYPE.key, jobType)
 
     val dataSource = conf.get("spark.flint.datasource.name", "")
-    val query = queryOption.getOrElse(conf.get(FlintSparkConf.QUERY.key, ""))
+    val query = queryOption.getOrElse(
+      unescapeQuery(conf.get(FlintSparkConf.QUERY.key, "")))
     if (query.isEmpty) {
       throw new IllegalArgumentException(s"Query undefined for the ${jobType} job.")
     }
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala
index ccd5c8f3f..bd2ecd7af 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala
@@ -11,6 +11,7 @@ import scala.concurrent.{ExecutionContext, Future, TimeoutException}
 import scala.concurrent.duration.{Duration, MINUTES}
 
 import com.amazonaws.services.s3.model.AmazonS3Exception
+import org.apache.commons.text.StringEscapeUtils.unescapeJava
 import org.opensearch.flint.core.{FlintClient, IRestHighLevelClient}
 import org.opensearch.flint.core.metadata.FlintMetadata
 import org.opensearch.flint.core.metrics.MetricConstants
@@ -361,6 +362,14 @@ trait FlintJobExecutor {
     }
   }
 
+  /**
+   * Unescape the query string, which is escaped for EMR spark-submit parameter parsing.
+ * Ref: https://github.com/opensearch-project/sql/pull/2587 + */ + def unescapeQuery(query: String): String = { + unescapeJava(query) + } + def executeQuery( spark: SparkSession, query: String, @@ -371,6 +380,7 @@ trait FlintJobExecutor { val startTime = System.currentTimeMillis() // we have to set job group in the same thread that started the query according to spark doc spark.sparkContext.setJobGroup(queryId, "Job group for " + queryId, interruptOnCancel = true) + logInfo(s"Executing query: $query") val result: DataFrame = spark.sql(query) // Get Data getFormattedData( diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 76e5f692c..69b655e57 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -251,7 +251,7 @@ object FlintREPL extends Logging with FlintJobExecutor { if (defaultQuery.isEmpty) { throw new IllegalArgumentException("Query undefined for the streaming job.") } - defaultQuery + unescapeQuery(defaultQuery) } else "" } } diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index 421457c4e..a04a41227 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -95,6 +95,16 @@ class FlintREPLTest query shouldBe "SELECT * FROM table" } + test("getQuery should return unescaped default query for streaming job if queryOption is None") { + val queryOption = None + val jobType = "streaming" + val conf = new SparkConf().set( + FlintSparkConf.QUERY.key, "SELECT \\\"1\\\" UNION SELECT '\\\"1\\\"' UNION SELECT \\\"\\\\\\\"1\\\\\\\"\\\"") + + val query = FlintREPL.getQuery(queryOption, jobType, conf) + query shouldBe "SELECT \"1\" UNION SELECT '\"1\"' UNION SELECT \"\\\"1\\\"\"" + } + test( "getQuery should throw IllegalArgumentException if queryOption is None and default query is not defined for streaming job") { val queryOption = None
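---

For reference, a minimal sketch (not part of the diff above) of the escape
round trip this patch relies on. The object name UnescapeQueryDemo is
invented for illustration; unescapeJava is the same Apache Commons Text
call the new unescapeQuery helper delegates to, and the escaped input
mirrors the simplest case in the test above.

    import org.apache.commons.text.StringEscapeUtils.unescapeJava

    object UnescapeQueryDemo extends App {
      // The spark-submit parameter arrives with Java-style escapes applied,
      // i.e. the literal characters: SELECT \"1\"
      val escaped = "SELECT \\\"1\\\""

      // unescapeJava turns \" back into " (and \\ back into \),
      // restoring the query as originally written: SELECT "1"
      val restored = unescapeJava(escaped)

      assert(restored == "SELECT \"1\"")
      println(restored)
    }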