Support custom extension conf (#438)
* Support extension as conf

Signed-off-by: Louis Chu <[email protected]>

* Fix UT and add constants

Signed-off-by: Louis Chu <[email protected]>

---------

Signed-off-by: Louis Chu <[email protected]>
noCharger authored Jul 22, 2024
1 parent 26759b9 commit 153b48f
Showing 2 changed files with 42 additions and 5 deletions.
@@ -20,12 +20,19 @@ import play.api.libs.json._
 
 import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkConfConstants.{DEFAULT_SQL_EXTENSIONS, SQL_EXTENSIONS_KEY}
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.flint.config.FlintSparkConf
 import org.apache.spark.sql.flint.config.FlintSparkConf.REFRESH_POLICY
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util._
 
+object SparkConfConstants {
+  val SQL_EXTENSIONS_KEY = "spark.sql.extensions"
+  val DEFAULT_SQL_EXTENSIONS =
+    "org.opensearch.flint.spark.FlintPPLSparkExtensions,org.opensearch.flint.spark.FlintSparkExtensions"
+}
+
 trait FlintJobExecutor {
   this: Logging =>
 
@@ -90,11 +97,15 @@ trait FlintJobExecutor {
     }""".stripMargin
 
   def createSparkConf(): SparkConf = {
-    new SparkConf()
-      .setAppName(getClass.getSimpleName)
-      .set(
-        "spark.sql.extensions",
-        "org.opensearch.flint.spark.FlintPPLSparkExtensions,org.opensearch.flint.spark.FlintSparkExtensions")
+    val conf = new SparkConf().setAppName(getClass.getSimpleName)
+
+    if (!conf.contains(SQL_EXTENSIONS_KEY)) {
+      conf.set(SQL_EXTENSIONS_KEY, DEFAULT_SQL_EXTENSIONS)
+    }
+
+    logInfo(s"Value of $SQL_EXTENSIONS_KEY: ${conf.get(SQL_EXTENSIONS_KEY)}")
+
+    conf
   }
 
   /*
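With this change, the Flint extension list becomes a fallback: createSparkConf() applies DEFAULT_SQL_EXTENSIONS only when spark.sql.extensions has not already been supplied, so a deployment can pass its own extension list (for example via `spark-submit --conf spark.sql.extensions=...`). A minimal sketch of the resulting precedence, assuming the SparkConfConstants object above is on the classpath; `my.custom.Extension` is a placeholder class name, not a real extension:

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkConfConstants.{DEFAULT_SQL_EXTENSIONS, SQL_EXTENSIONS_KEY}

// A conf that already carries a caller-supplied extension list.
val conf = new SparkConf()
  .setAppName("example")
  .set(SQL_EXTENSIONS_KEY, "my.custom.Extension") // placeholder class name

// Same guard as createSparkConf(): the defaults apply only when the key is
// absent, so the caller-supplied value above wins.
if (!conf.contains(SQL_EXTENSIONS_KEY)) {
  conf.set(SQL_EXTENSIONS_KEY, DEFAULT_SQL_EXTENSIONS)
}

assert(conf.get(SQL_EXTENSIONS_KEY) == "my.custom.Extension")

The second changed file, the FlintREPL test suite, adds coverage for both behaviors: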
@@ -30,6 +30,7 @@ import org.scalatestplus.mockito.MockitoSugar
 import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
 import org.apache.spark.scheduler.SparkListenerApplicationEnd
 import org.apache.spark.sql.FlintREPL.PreShutdownListener
+import org.apache.spark.sql.SparkConfConstants.{DEFAULT_SQL_EXTENSIONS, SQL_EXTENSIONS_KEY}
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.trees.Origin
 import org.apache.spark.sql.flint.config.FlintSparkConf
@@ -129,6 +130,31 @@ class FlintREPLTest
     query shouldBe ""
   }
 
+  test("createSparkConf should set the app name and default SQL extensions") {
+    val conf = FlintREPL.createSparkConf()
+
+    // Assert that the app name is set correctly
+    assert(conf.get("spark.app.name") === "FlintREPL$")
+
+    // Assert that the default SQL extensions are set correctly
+    assert(conf.get(SQL_EXTENSIONS_KEY) === DEFAULT_SQL_EXTENSIONS)
+  }
+
+  test(
+    "createSparkConf should not use defaultExtensions if spark.sql.extensions is already set") {
+    val customExtension = "my.custom.extension"
+    // Set the spark.sql.extensions property before calling createSparkConf
+    System.setProperty(SQL_EXTENSIONS_KEY, customExtension)
+
+    try {
+      val conf = FlintREPL.createSparkConf()
+      assert(conf.get(SQL_EXTENSIONS_KEY) === customExtension)
+    } finally {
+      // Clean up the system property after the test
+      System.clearProperty(SQL_EXTENSIONS_KEY)
+    }
+  }
+
   test("createHeartBeatUpdater should update heartbeat correctly") {
     // Mocks
     val flintSessionUpdater = mock[OpenSearchUpdater]
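The override test above works because SparkConf's no-argument constructor loads defaults, copying every JVM system property that begins with "spark." into the conf, so the value set with System.setProperty is already present when createSparkConf() checks for the key. A small standalone sketch of that mechanism, independent of FlintREPL:

import org.apache.spark.SparkConf

// The no-arg SparkConf() picks up any "spark."-prefixed system properties,
// which is how the test injects a custom spark.sql.extensions value.
System.setProperty("spark.sql.extensions", "my.custom.extension")
try {
  val conf = new SparkConf() // picks up the property set above
  assert(conf.contains("spark.sql.extensions"))
  assert(conf.get("spark.sql.extensions") == "my.custom.extension")
} finally {
  System.clearProperty("spark.sql.extensions") // keep other tests isolated
}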
