diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala
index 8da7d2072..2e77dad44 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala
@@ -25,22 +25,23 @@ import org.apache.spark.util.ThreadUtils
  * Spark SQL Application entrypoint
  *
  * @param args
- *   (0) sql query
+ *   (0) spark extensions (Flint, PPL)
+ *   (1) sql query
  * @param args
- *   (1) opensearch index name
+ *   (2) opensearch index name
  * @return
  *   write sql query result to given opensearch index
  */
 object FlintJob extends Logging {
   def main(args: Array[String]): Unit = {
     // Validate command line arguments
-    if (args.length != 2) {
-      throw new IllegalArgumentException("Usage: FlintJob <query> <resultIndex>")
+    if (args.length != 3) {
+      throw new IllegalArgumentException("Usage: FlintJob <extensions> <query> <resultIndex>")
     }
 
-    val Array(query, resultIndex) = args
+    val Array(sparkExtensions, query, resultIndex) = args
 
-    val conf = createSparkConf()
+    val conf = createSparkConf(sparkExtensions)
     val wait = conf.get("spark.flint.job.type", "continue")
     val dataSource = conf.get("spark.flint.datasource.name", "")
     val spark = createSparkSession(conf)
@@ -82,10 +83,10 @@ object FlintJob extends Logging {
     }
   }
 
-  def createSparkConf(): SparkConf = {
+  def createSparkConf(sparkExtensions: String): SparkConf = {
     new SparkConf()
       .setAppName("FlintJob")
-      .set("spark.sql.extensions", "org.opensearch.flint.spark.FlintSparkExtensions")
+      .set("spark.sql.extensions", sparkExtensions)
   }
 
   def createSparkSession(conf: SparkConf): SparkSession = {
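
For context (not part of the diff): a minimal sketch of how a caller might drive the updated three-argument entrypoint. The query and result index values are hypothetical, and the PPL extension class name is an assumption inferred from the "(Flint, PPL)" doc comment; what is certain is that Spark's spark.sql.extensions setting accepts a comma-separated list of extension class names, which is what lets a single CLI argument enable several extensions.

// Hypothetical caller; FlintJob is the object defined in the diff above.
object FlintJobInvocationSketch {
  def main(cliArgs: Array[String]): Unit = {
    // spark.sql.extensions takes a comma-separated list, so one argument
    // can enable both the Flint SQL and PPL extensions at once.
    val sparkExtensions =
      "org.opensearch.flint.spark.FlintSparkExtensions," +
        "org.opensearch.flint.spark.FlintPPLSparkExtensions" // assumed PPL class name
    val query = "SELECT 1"            // hypothetical SQL query
    val resultIndex = "flint_results" // hypothetical OpenSearch index name

    FlintJob.main(Array(sparkExtensions, query, resultIndex))
  }
}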