diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala index 4c99a61ca..f38a27ef4 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -20,12 +20,19 @@ import play.api.libs.json._ import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkConfConstants.{DEFAULT_SQL_EXTENSIONS, SQL_EXTENSIONS_KEY} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.flint.config.FlintSparkConf.REFRESH_POLICY import org.apache.spark.sql.types._ import org.apache.spark.sql.util._ +object SparkConfConstants { + val SQL_EXTENSIONS_KEY = "spark.sql.extensions" + val DEFAULT_SQL_EXTENSIONS = + "org.opensearch.flint.spark.FlintPPLSparkExtensions,org.opensearch.flint.spark.FlintSparkExtensions" +} + trait FlintJobExecutor { this: Logging => @@ -90,11 +97,15 @@ trait FlintJobExecutor { }""".stripMargin def createSparkConf(): SparkConf = { - new SparkConf() - .setAppName(getClass.getSimpleName) - .set( - "spark.sql.extensions", - "org.opensearch.flint.spark.FlintPPLSparkExtensions,org.opensearch.flint.spark.FlintSparkExtensions") + val conf = new SparkConf().setAppName(getClass.getSimpleName) + + if (!conf.contains(SQL_EXTENSIONS_KEY)) { + conf.set(SQL_EXTENSIONS_KEY, DEFAULT_SQL_EXTENSIONS) + } + + logInfo(s"Value of $SQL_EXTENSIONS_KEY: ${conf.get(SQL_EXTENSIONS_KEY)}") + + conf } /* diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index c1c32c492..d8ddcb665 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -30,6 +30,7 @@ import org.scalatestplus.mockito.MockitoSugar import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.scheduler.SparkListenerApplicationEnd import org.apache.spark.sql.FlintREPL.PreShutdownListener +import org.apache.spark.sql.SparkConfConstants.{DEFAULT_SQL_EXTENSIONS, SQL_EXTENSIONS_KEY} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.flint.config.FlintSparkConf @@ -129,6 +130,31 @@ class FlintREPLTest query shouldBe "" } + test("createSparkConf should set the app name and default SQL extensions") { + val conf = FlintREPL.createSparkConf() + + // Assert that the app name is set correctly + assert(conf.get("spark.app.name") === "FlintREPL$") + + // Assert that the default SQL extensions are set correctly + assert(conf.get(SQL_EXTENSIONS_KEY) === DEFAULT_SQL_EXTENSIONS) + } + + test( + "createSparkConf should not use defaultExtensions if spark.sql.extensions is already set") { + val customExtension = "my.custom.extension" + // Set the spark.sql.extensions property before calling createSparkConf + System.setProperty(SQL_EXTENSIONS_KEY, customExtension) + + try { + val conf = FlintREPL.createSparkConf() + assert(conf.get(SQL_EXTENSIONS_KEY) === customExtension) + } finally { + // Clean up the system property after the test + System.clearProperty(SQL_EXTENSIONS_KEY) + } + } + test("createHeartBeatUpdater should update heartbeat correctly") { // Mocks val flintSessionUpdater = mock[OpenSearchUpdater]