Skip to content

Commit

Permalink
Unescape the query taken from the EMR spark-submit parameter
Browse files Browse the repository at this point in the history
Signed-off-by: Sean Kao <[email protected]>
  • Loading branch information
seankao-az committed Apr 10, 2024
1 parent 77d0078 commit 026b362
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ object FlintJob extends Logging with FlintJobExecutor {
conf.set(FlintSparkConf.JOB_TYPE.key, jobType)

val dataSource = conf.get("spark.flint.datasource.name", "")
val query = queryOption.getOrElse(conf.get(FlintSparkConf.QUERY.key, ""))
val query = queryOption.getOrElse(
unescapeQuery(conf.get(FlintSparkConf.QUERY.key, "")))
if (query.isEmpty) {
throw new IllegalArgumentException(s"Query undefined for the ${jobType} job.")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import scala.concurrent.{ExecutionContext, Future, TimeoutException}
import scala.concurrent.duration.{Duration, MINUTES}

import com.amazonaws.services.s3.model.AmazonS3Exception
import org.apache.commons.text.StringEscapeUtils.unescapeJava
import org.opensearch.flint.core.{FlintClient, IRestHighLevelClient}
import org.opensearch.flint.core.metadata.FlintMetadata
import org.opensearch.flint.core.metrics.MetricConstants
Expand Down Expand Up @@ -361,6 +362,14 @@ trait FlintJobExecutor {
}
}

/**
 * Reverses the Java-style escaping that was applied to the query string so it
 * could survive EMR spark-submit parameter parsing.
 * Ref: https://github.com/opensearch-project/sql/pull/2587
 *
 * @param query the escaped query string taken from the Spark configuration
 * @return the query with Java escape sequences (e.g. \" and \\) resolved
 */
def unescapeQuery(query: String): String = unescapeJava(query)

def executeQuery(
spark: SparkSession,
query: String,
Expand All @@ -371,6 +380,7 @@ trait FlintJobExecutor {
val startTime = System.currentTimeMillis()
// we have to set job group in the same thread that started the query according to spark doc
spark.sparkContext.setJobGroup(queryId, "Job group for " + queryId, interruptOnCancel = true)
logInfo(s"Executing query: $query")
val result: DataFrame = spark.sql(query)
// Get Data
getFormattedData(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
if (defaultQuery.isEmpty) {
throw new IllegalArgumentException("Query undefined for the streaming job.")
}
defaultQuery
unescapeQuery(defaultQuery)
} else ""
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,16 @@ class FlintREPLTest
query shouldBe "SELECT * FROM table"
}

test("getQuery should return unescaped default query for streaming job if queryOption is None") {
  // Query as it arrives from spark-submit: quotes and backslashes Java-escaped.
  val escapedQuery =
    "SELECT \\\"1\\\" UNION SELECT '\\\"1\\\"' UNION SELECT \\\"\\\\\\\"1\\\\\\\"\\\""
  val conf = new SparkConf().set(FlintSparkConf.QUERY.key, escapedQuery)

  val result = FlintREPL.getQuery(None, "streaming", conf)
  result shouldBe "SELECT \"1\" UNION SELECT '\"1\"' UNION SELECT \"\\\"1\\\"\""
}

test(
"getQuery should throw IllegalArgumentException if queryOption is None and default query is not defined for streaming job") {
val queryOption = None
Expand Down

0 comments on commit 026b362

Please sign in to comment.