Move query from entry point to SparkConf #274

Merged · 3 commits · Mar 14, 2024
@@ -150,6 +150,10 @@ object FlintSparkConf {
FlintConfig(s"spark.flint.datasource.name")
.doc("data source name")
.createOptional()
val QUERY =
FlintConfig("spark.flint.job.query")
.doc("Flint query for batch and streaming job")
.createOptional()
val JOB_TYPE =
FlintConfig(s"spark.flint.job.type")
.doc("Flint job type. Including interactive and streaming")
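For context, a rough sketch of how a caller might supply the query through Spark configuration using the new key; this is illustrative only, and the SQL, index name, and variable name below are made up (inside the job itself, the configuration comes from createSparkConf()):

// Hypothetical launcher-side settings; FlintJob and FlintREPL read these keys at startup.
val launchConf = new SparkConf()
  .set(FlintSparkConf.JOB_TYPE.key, "streaming") // spark.flint.job.type
  .set(FlintSparkConf.QUERY.key, "SELECT * FROM my_glue.default.http_logs") // spark.flint.job.query
// The entry point then only needs the result index:
//   FlintJob.main(Array("query_results_index"))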
@@ -168,7 +168,7 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest {
Map("SERVERLESS_EMR_JOB_ID" -> jobRunId, "SERVERLESS_EMR_VIRTUAL_CLUSTER_ID" -> appId))
FlintREPL.enableHiveSupport = false
FlintREPL.terminateJVM = false
- FlintREPL.main(Array("select 1", resultIndex))
+ FlintREPL.main(Array(resultIndex))
}
futureResult.onComplete {
case Success(result) => logInfo(s"Success result: $result")
@@ -37,19 +37,29 @@ import org.apache.spark.sql.types.{StructField, _}
*/
object FlintJob extends Logging with FlintJobExecutor {
def main(args: Array[String]): Unit = {
- // Validate command line arguments
- if (args.length != 2) {
-   throw new IllegalArgumentException("Usage: FlintJob <query> <resultIndex>")
+ val (queryOption, resultIndex) = args.length match {
+   case 1 =>
+     (None, args(0)) // Starting from OS 2.13, resultIndex is the only argument
+   case 2 =>
+     (
+       Some(args(0)),
+       args(1)
+     ) // Before OS 2.13, there are two arguments, the second one is resultIndex
+   case _ =>
+     throw new IllegalArgumentException(
+       "Unsupported number of arguments. Expected 1 or 2 arguments.")
}

- val Array(query, resultIndex) = args

val conf = createSparkConf()
val jobType = conf.get("spark.flint.job.type", "batch")
logInfo(s"""Job type is: ${jobType}""")
conf.set(FlintSparkConf.JOB_TYPE.key, jobType)

val dataSource = conf.get("spark.flint.datasource.name", "")
+ val query = queryOption.getOrElse(conf.get(FlintSparkConf.QUERY.key, ""))
+ if (query.isEmpty) {
+   throw new IllegalArgumentException(s"Query undefined for the ${jobType} job.")
+ }
// https://github.com/opensearch-project/opensearch-spark/issues/138
/*
* To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`,
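To illustrate the argument handling above, both invocation shapes remain accepted; the query and index name here are made up:

// OS 2.13+ style: only the result index is passed, and the query is taken from spark.flint.job.query.
FlintJob.main(Array("query_results_index"))
// Pre-2.13 style: the query is still accepted as the first argument.
FlintJob.main(Array("SELECT 1", "query_results_index"))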
@@ -68,7 +68,8 @@ object FlintREPL extends Logging with FlintJobExecutor {
private val statementRunningCount = new AtomicInteger(0)

def main(args: Array[String]) {
- val Array(query, resultIndex) = args
+ val (queryOption, resultIndex) = parseArgs(args)

if (Strings.isNullOrEmpty(resultIndex)) {
throw new IllegalArgumentException("resultIndex is not set")
}
@@ -90,6 +91,8 @@ object FlintREPL extends Logging with FlintJobExecutor {
logInfo(s"""Job type is: ${FlintSparkConf.JOB_TYPE.defaultValue.get}""")
conf.set(FlintSparkConf.JOB_TYPE.key, jobType)

val query = getQuery(queryOption, jobType, conf)

if (jobType.equalsIgnoreCase("streaming")) {
logInfo(s"""streaming query ${query}""")
val streamingRunningCount = new AtomicInteger(0)
@@ -228,6 +231,33 @@
}
}

def parseArgs(args: Array[String]): (Option[String], String) = {
args.length match {
case 1 =>
(None, args(0)) // Starting from OS 2.13, resultIndex is the only argument
case 2 =>
(
Some(args(0)),
args(1)
) // Before OS 2.13, there are two arguments, the second one is resultIndex
case _ =>
throw new IllegalArgumentException(
"Unsupported number of arguments. Expected 1 or 2 arguments.")
}
}

def getQuery(queryOption: Option[String], jobType: String, conf: SparkConf): String = {
queryOption.getOrElse {
if (jobType.equalsIgnoreCase("streaming")) {
val defaultQuery = conf.get(FlintSparkConf.QUERY.key, "")
if (defaultQuery.isEmpty) {
throw new IllegalArgumentException("Query undefined for the streaming job.")
}
defaultQuery
} else ""
}
}

/**
* Sets up a Flint job with exclusion checks based on the job configuration.
*
@@ -30,6 +30,7 @@ import org.scalatestplus.mockito.MockitoSugar
import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.types.{LongType, NullType, StringType, StructField, StructType}
import org.apache.spark.sql.util.{DefaultThreadPoolFactory, MockThreadPoolFactory, MockTimeProvider, RealTimeProvider, ShutdownHookManagerTrait}
import org.apache.spark.util.ThreadUtils
@@ -42,6 +43,78 @@ class FlintREPLTest
// By using a type alias and casting, I can bypass the type checking error.
type AnyScheduledFuture = ScheduledFuture[_]

test(
"parseArgs with one argument should return None for query and the argument as resultIndex") {
val args = Array("resultIndexName")
val (queryOption, resultIndex) = FlintREPL.parseArgs(args)
queryOption shouldBe None
resultIndex shouldBe "resultIndexName"
}

test(
"parseArgs with two arguments should return the first argument as query and the second as resultIndex") {
val args = Array("SELECT * FROM table", "resultIndexName")
val (queryOption, resultIndex) = FlintREPL.parseArgs(args)
queryOption shouldBe Some("SELECT * FROM table")
resultIndex shouldBe "resultIndexName"
}

test(
"parseArgs with no arguments should throw IllegalArgumentException with specific message") {
val args = Array.empty[String]
val exception = intercept[IllegalArgumentException] {
FlintREPL.parseArgs(args)
}
exception.getMessage shouldBe "Unsupported number of arguments. Expected 1 or 2 arguments."
}

test(
"parseArgs with more than two arguments should throw IllegalArgumentException with specific message") {
val args = Array("arg1", "arg2", "arg3")
val exception = intercept[IllegalArgumentException] {
FlintREPL.parseArgs(args)
}
exception.getMessage shouldBe "Unsupported number of arguments. Expected 1 or 2 arguments."
}

test("getQuery should return query from queryOption if present") {
val queryOption = Some("SELECT * FROM table")
val jobType = "streaming"
val conf = new SparkConf()

val query = FlintREPL.getQuery(queryOption, jobType, conf)
query shouldBe "SELECT * FROM table"
}

test("getQuery should return default query for streaming job if queryOption is None") {
val queryOption = None
val jobType = "streaming"
val conf = new SparkConf().set(FlintSparkConf.QUERY.key, "SELECT * FROM table")

val query = FlintREPL.getQuery(queryOption, jobType, conf)
query shouldBe "SELECT * FROM table"
}

test(
"getQuery should throw IllegalArgumentException if queryOption is None and default query is not defined for streaming job") {
val queryOption = None
val jobType = "streaming"
val conf = new SparkConf() // Default query not set

intercept[IllegalArgumentException] {
FlintREPL.getQuery(queryOption, jobType, conf)
}.getMessage shouldBe "Query undefined for the streaming job."
}

test("getQuery should return empty string for non-streaming job if queryOption is None") {
val queryOption = None
val jobType = "interactive"
val conf = new SparkConf() // Default query not needed

val query = FlintREPL.getQuery(queryOption, jobType, conf)
query shouldBe ""
}

test("createHeartBeatUpdater should update heartbeat correctly") {
// Mocks
val flintSessionUpdater = mock[OpenSearchUpdater]