Move query from entry point to SparkConf (#274)
* Refactor query input

Signed-off-by: Louis Chu <[email protected]>

* Support backward compatibility

Signed-off-by: Louis Chu <[email protected]>

* Add more UTs

Signed-off-by: Louis Chu <[email protected]>

---------

Signed-off-by: Louis Chu <[email protected]>
noCharger authored Mar 14, 2024
1 parent 8cdc171 commit 9ffe857
Showing 5 changed files with 124 additions and 7 deletions.
@@ -150,6 +150,10 @@ object FlintSparkConf {
     FlintConfig(s"spark.flint.datasource.name")
       .doc("data source name")
       .createOptional()
+  val QUERY =
+    FlintConfig("spark.flint.job.query")
+      .doc("Flint query for batch and streaming job")
+      .createOptional()
   val JOB_TYPE =
     FlintConfig(s"spark.flint.job.type")
       .doc("Flint job type. Including interactive and streaming")
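With the new QUERY option registered, a job can receive its query through Spark configuration instead of a positional argument. A minimal submitter-side sketch, assuming a plain SparkConf is how the job is configured (the query text below is illustrative):

    import org.apache.spark.SparkConf

    // The query now travels as a conf entry rather than argv(0);
    // the keys match the options defined in FlintSparkConf above.
    val conf = new SparkConf()
      .set("spark.flint.job.query", "SELECT * FROM my_table LIMIT 10")
      .set("spark.flint.job.type", "batch")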
@@ -168,7 +168,7 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest {
         Map("SERVERLESS_EMR_JOB_ID" -> jobRunId, "SERVERLESS_EMR_VIRTUAL_CLUSTER_ID" -> appId))
       FlintREPL.enableHiveSupport = false
       FlintREPL.terminateJVM = false
-      FlintREPL.main(Array("select 1", resultIndex))
+      FlintREPL.main(Array(resultIndex))
     }
     futureResult.onComplete {
       case Success(result) => logInfo(s"Success result: $result")
@@ -37,19 +37,29 @@ import org.apache.spark.sql.types.{StructField, _}
  */
 object FlintJob extends Logging with FlintJobExecutor {
   def main(args: Array[String]): Unit = {
-    // Validate command line arguments
-    if (args.length != 2) {
-      throw new IllegalArgumentException("Usage: FlintJob <query> <resultIndex>")
+    val (queryOption, resultIndex) = args.length match {
+      case 1 =>
+        (None, args(0)) // Starting from OS 2.13, resultIndex is the only argument
+      case 2 =>
+        (
+          Some(args(0)),
+          args(1)
+        ) // Before OS 2.13, there are two arguments, the second one is resultIndex
+      case _ =>
+        throw new IllegalArgumentException(
+          "Unsupported number of arguments. Expected 1 or 2 arguments.")
     }
 
-    val Array(query, resultIndex) = args
-
     val conf = createSparkConf()
     val jobType = conf.get("spark.flint.job.type", "batch")
     logInfo(s"""Job type is: ${jobType}""")
     conf.set(FlintSparkConf.JOB_TYPE.key, jobType)
 
     val dataSource = conf.get("spark.flint.datasource.name", "")
+    val query = queryOption.getOrElse(conf.get(FlintSparkConf.QUERY.key, ""))
+    if (query.isEmpty) {
+      throw new IllegalArgumentException(s"Query undefined for the ${jobType} job.")
+    }
     // https://github.com/opensearch-project/opensearch-spark/issues/138
     /*
      * To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`,
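The entry point is now backward compatible: one argument means the query is read from spark.flint.job.query, while the old two-argument form keeps working. A hedged sketch of the two calling conventions (index and query values are illustrative):

    // OS 2.13+: result index only; spark.flint.job.query must be set,
    // otherwise main() throws "Query undefined for the batch job."
    FlintJob.main(Array("my_result_index"))

    // Pre-2.13, still supported: query first, result index second
    FlintJob.main(Array("SELECT 1", "my_result_index"))

    // Any other arity is rejected with IllegalArgumentException
    FlintJob.main(Array.empty[String])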
@@ -68,7 +68,8 @@ object FlintREPL extends Logging with FlintJobExecutor {
   private val statementRunningCount = new AtomicInteger(0)
 
   def main(args: Array[String]) {
-    val Array(query, resultIndex) = args
+    val (queryOption, resultIndex) = parseArgs(args)
+
     if (Strings.isNullOrEmpty(resultIndex)) {
       throw new IllegalArgumentException("resultIndex is not set")
     }
@@ -90,6 +91,8 @@ object FlintREPL extends Logging with FlintJobExecutor {
     logInfo(s"""Job type is: ${FlintSparkConf.JOB_TYPE.defaultValue.get}""")
     conf.set(FlintSparkConf.JOB_TYPE.key, jobType)
 
+    val query = getQuery(queryOption, jobType, conf)
+
     if (jobType.equalsIgnoreCase("streaming")) {
       logInfo(s"""streaming query ${query}""")
       val streamingRunningCount = new AtomicInteger(0)
@@ -228,6 +231,33 @@ object FlintREPL extends Logging with FlintJobExecutor {
     }
   }
 
+  def parseArgs(args: Array[String]): (Option[String], String) = {
+    args.length match {
+      case 1 =>
+        (None, args(0)) // Starting from OS 2.13, resultIndex is the only argument
+      case 2 =>
+        (
+          Some(args(0)),
+          args(1)
+        ) // Before OS 2.13, there are two arguments, the second one is resultIndex
+      case _ =>
+        throw new IllegalArgumentException(
+          "Unsupported number of arguments. Expected 1 or 2 arguments.")
+    }
+  }
+
+  def getQuery(queryOption: Option[String], jobType: String, conf: SparkConf): String = {
+    queryOption.getOrElse {
+      if (jobType.equalsIgnoreCase("streaming")) {
+        val defaultQuery = conf.get(FlintSparkConf.QUERY.key, "")
+        if (defaultQuery.isEmpty) {
+          throw new IllegalArgumentException("Query undefined for the streaming job.")
+        }
+        defaultQuery
+      } else ""
+    }
+  }
+
   /**
    * Sets up a Flint job with exclusion checks based on the job configuration.
    *
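Together, parseArgs and getQuery define a clear precedence: a positional query wins, then the spark.flint.job.query entry, and only streaming jobs treat a missing query as fatal. A small usage sketch of that contract, mirroring the unit tests below (query strings are illustrative):

    val conf = new SparkConf().set(FlintSparkConf.QUERY.key, "SELECT 1")

    FlintREPL.getQuery(Some("SELECT 2"), "streaming", conf) // "SELECT 2": positional wins
    FlintREPL.getQuery(None, "streaming", conf)             // "SELECT 1": falls back to conf
    FlintREPL.getQuery(None, "interactive", conf)           // "": conf consulted only for streaming
    FlintREPL.getQuery(None, "streaming", new SparkConf())  // throws IllegalArgumentException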
@@ -30,6 +30,7 @@ import org.scalatestplus.mockito.MockitoSugar
 import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.trees.Origin
+import org.apache.spark.sql.flint.config.FlintSparkConf
 import org.apache.spark.sql.types.{LongType, NullType, StringType, StructField, StructType}
 import org.apache.spark.sql.util.{DefaultThreadPoolFactory, MockThreadPoolFactory, MockTimeProvider, RealTimeProvider, ShutdownHookManagerTrait}
 import org.apache.spark.util.ThreadUtils
@@ -42,6 +43,78 @@ class FlintREPLTest
   // By using a type alias and casting, I can bypass the type checking error.
   type AnyScheduledFuture = ScheduledFuture[_]
 
+  test(
+    "parseArgs with one argument should return None for query and the argument as resultIndex") {
+    val args = Array("resultIndexName")
+    val (queryOption, resultIndex) = FlintREPL.parseArgs(args)
+    queryOption shouldBe None
+    resultIndex shouldBe "resultIndexName"
+  }
+
+  test(
+    "parseArgs with two arguments should return the first argument as query and the second as resultIndex") {
+    val args = Array("SELECT * FROM table", "resultIndexName")
+    val (queryOption, resultIndex) = FlintREPL.parseArgs(args)
+    queryOption shouldBe Some("SELECT * FROM table")
+    resultIndex shouldBe "resultIndexName"
+  }
+
+  test(
+    "parseArgs with no arguments should throw IllegalArgumentException with specific message") {
+    val args = Array.empty[String]
+    val exception = intercept[IllegalArgumentException] {
+      FlintREPL.parseArgs(args)
+    }
+    exception.getMessage shouldBe "Unsupported number of arguments. Expected 1 or 2 arguments."
+  }
+
+  test(
+    "parseArgs with more than two arguments should throw IllegalArgumentException with specific message") {
+    val args = Array("arg1", "arg2", "arg3")
+    val exception = intercept[IllegalArgumentException] {
+      FlintREPL.parseArgs(args)
+    }
+    exception.getMessage shouldBe "Unsupported number of arguments. Expected 1 or 2 arguments."
+  }
+
+  test("getQuery should return query from queryOption if present") {
+    val queryOption = Some("SELECT * FROM table")
+    val jobType = "streaming"
+    val conf = new SparkConf()
+
+    val query = FlintREPL.getQuery(queryOption, jobType, conf)
+    query shouldBe "SELECT * FROM table"
+  }
+
+  test("getQuery should return default query for streaming job if queryOption is None") {
+    val queryOption = None
+    val jobType = "streaming"
+    val conf = new SparkConf().set(FlintSparkConf.QUERY.key, "SELECT * FROM table")
+
+    val query = FlintREPL.getQuery(queryOption, jobType, conf)
+    query shouldBe "SELECT * FROM table"
+  }
+
+  test(
+    "getQuery should throw IllegalArgumentException if queryOption is None and default query is not defined for streaming job") {
+    val queryOption = None
+    val jobType = "streaming"
+    val conf = new SparkConf() // Default query not set
+
+    intercept[IllegalArgumentException] {
+      FlintREPL.getQuery(queryOption, jobType, conf)
+    }.getMessage shouldBe "Query undefined for the streaming job."
+  }
+
+  test("getQuery should return empty string for non-streaming job if queryOption is None") {
+    val queryOption = None
+    val jobType = "interactive"
+    val conf = new SparkConf() // Default query not needed
+
+    val query = FlintREPL.getQuery(queryOption, jobType, conf)
+    query shouldBe ""
+  }
+
   test("createHeartBeatUpdater should update heartbeat correctly") {
     // Mocks
     val flintSessionUpdater = mock[OpenSearchUpdater]
