Refactor FlintJob with FlintStatement and StatementExecutionManager #635

Merged 2 commits on Sep 10, 2024
Changes from all commits
StatementExecutionManager.scala
@@ -38,5 +38,5 @@ trait StatementExecutionManager {
/**
* Terminates the statement lifecycle.
*/
def terminateStatementsExecution(): Unit
def terminateStatementExecution(): Unit
}
FlintSparkConf.scala
@@ -211,6 +211,10 @@ object FlintSparkConf {
FlintConfig("spark.flint.job.query")
.doc("Flint query for batch and streaming job")
.createOptional()
val QUERY_ID =
FlintConfig("spark.flint.job.queryId")
.doc("Flint query id for batch and streaming job")
.createOptional()
val JOB_TYPE =
FlintConfig(s"spark.flint.job.type")
.doc("Flint job type. Including interactive and streaming")
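The new QUERY_ID entry resolves to the key spark.flint.job.queryId and is read back with an empty-string default in FlintJob and FlintREPL further down in this diff. A minimal sketch of setting and reading it; the literal id value is only an illustration:

// Sketch, using the standard Spark conf APIs; the id value is a placeholder.
import org.apache.spark.SparkConf
import org.apache.spark.sql.flint.config.FlintSparkConf

val conf = new SparkConf()
  .set(FlintSparkConf.QUERY_ID.key, "example-query-id") // key: spark.flint.job.queryId

// Mirrors the lookup added in FlintJob/FlintREPL: empty string when no id was provided.
val queryId = conf.get(FlintSparkConf.QUERY_ID.key, "")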
FlintJobITSuite.scala
@@ -37,6 +37,7 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
val resultIndex = "query_results2"
val appId = "00feq82b752mbt0p"
val dataSourceName = "my_glue1"
val queryId = "testQueryId"
var osClient: OSClient = _
val threadLocalFuture = new ThreadLocal[Future[Unit]]()

@@ -91,7 +92,7 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
* all Spark conf required by Flint code underlying manually.
*/
spark.conf.set(DATA_SOURCE_NAME.key, dataSourceName)
spark.conf.set(JOB_TYPE.key, "streaming")
spark.conf.set(JOB_TYPE.key, FlintJobType.STREAMING)

/**
* FlintJob.main() is not called because we need to manually set these variables within a
@@ -103,9 +104,10 @@
jobRunId,
spark,
query,
queryId,
dataSourceName,
resultIndex,
true,
FlintJobType.STREAMING,
streamingRunningCount)
job.terminateJVM = false
job.start()
@@ -144,7 +146,6 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {

assert(result.status == "SUCCESS", s"expected status is SUCCESS, but got ${result.status}")
assert(result.error.isEmpty, s"we don't expect error, but got ${result.error}")
assert(result.queryId.isEmpty, s"we don't expect query id, but got ${result.queryId}")

commonAssert(result, jobRunId, query, queryStartTime)
true
@@ -362,7 +363,9 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
result.queryRunTime < System.currentTimeMillis() - queryStartTime,
s"expected query run time ${result.queryRunTime} should be less than ${System
.currentTimeMillis() - queryStartTime}, but it is not")
assert(result.queryId.isEmpty, s"we don't expect query id, but got ${result.queryId}")
assert(
result.queryId == queryId,
s"expected query id is ${queryId}, but got ${result.queryId}")
}

def pollForResultAndAssert(expected: REPLResult => Boolean, jobId: String): Unit = {
CommandContext.scala
@@ -15,6 +15,7 @@ case class CommandContext(
jobId: String,
spark: SparkSession,
dataSource: String,
jobType: String,
sessionId: String,
sessionManager: SessionManager,
queryResultWriter: QueryResultWriter,
FlintJob.scala
@@ -11,11 +11,9 @@ import java.util.concurrent.atomic.AtomicInteger
import org.opensearch.flint.core.logging.CustomLogging
import org.opensearch.flint.core.metrics.MetricConstants
import org.opensearch.flint.core.metrics.MetricsUtil.registerGauge
import play.api.libs.json._

import org.apache.spark.internal.Logging
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.types._

/**
* Spark SQL Application entrypoint
@@ -32,7 +30,7 @@ object FlintJob extends Logging with FlintJobExecutor {
val (queryOption, resultIndexOption) = parseArgs(args)

val conf = createSparkConf()
val jobType = conf.get("spark.flint.job.type", "batch")
val jobType = conf.get("spark.flint.job.type", FlintJobType.BATCH)
CustomLogging.logInfo(s"""Job type is: ${jobType}""")
conf.set(FlintSparkConf.JOB_TYPE.key, jobType)

@@ -41,6 +39,8 @@
if (query.isEmpty) {
logAndThrow(s"Query undefined for the ${jobType} job.")
}
val queryId = conf.get(FlintSparkConf.QUERY_ID.key, "")

if (resultIndexOption.isEmpty) {
logAndThrow("resultIndex is not set")
}
@@ -66,9 +66,10 @@
jobId,
createSparkSession(conf),
query,
queryId,
dataSource,
resultIndexOption.get,
jobType.equalsIgnoreCase("streaming"),
jobType,
streamingRunningCount)
registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount)
jobOperator.start()
FlintJobExecutor.scala
@@ -26,13 +26,20 @@ import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.flint.config.FlintSparkConf.REFRESH_POLICY
import org.apache.spark.sql.types._
import org.apache.spark.sql.util._
import org.apache.spark.util.Utils

object SparkConfConstants {
val SQL_EXTENSIONS_KEY = "spark.sql.extensions"
val DEFAULT_SQL_EXTENSIONS =
"org.opensearch.flint.spark.FlintPPLSparkExtensions,org.opensearch.flint.spark.FlintSparkExtensions"
}

object FlintJobType {
val INTERACTIVE = "interactive"
val BATCH = "batch"
val STREAMING = "streaming"
}

trait FlintJobExecutor {
this: Logging =>

@@ -131,7 +138,7 @@
* https://github.com/opensearch-project/opensearch-spark/issues/324
*/
def configDYNMaxExecutors(conf: SparkConf, jobType: String): Unit = {
if (jobType.equalsIgnoreCase("streaming")) {
if (jobType.equalsIgnoreCase(FlintJobType.STREAMING)) {
conf.set(
"spark.dynamicAllocation.maxExecutors",
conf
@@ -524,4 +531,25 @@
CustomLogging.logError(t)
throw t
}

def instantiate[T](defaultConstructor: => T, className: String, args: Any*): T = {
if (className.isEmpty) {
defaultConstructor
} else {
try {
val classObject = Utils.classForName(className)
val ctor = if (args.isEmpty) {
classObject.getDeclaredConstructor()
} else {
classObject.getDeclaredConstructor(args.map(_.getClass.asInstanceOf[Class[_]]): _*)
}
ctor.setAccessible(true)
ctor.newInstance(args.map(_.asInstanceOf[Object]): _*).asInstanceOf[T]
} catch {
case e: Exception =>
throw new RuntimeException(s"Failed to instantiate provider: $className", e)
}
}
}

}
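The instantiate helper, removed from FlintREPL later in this diff, now lives in FlintJobExecutor so any job mixing in the trait can load a pluggable provider by class name. A minimal usage sketch; the provider trait, default implementation, and system property below are hypothetical and not part of this change:

// Sketch only: QueryLogger, ConsoleQueryLogger, and the property name are illustrative.
trait QueryLogger { def log(msg: String): Unit }
class ConsoleQueryLogger extends QueryLogger {
  override def log(msg: String): Unit = println(msg)
}

// Inside an object that mixes in FlintJobExecutor, so instantiate is in scope:
val loggerClassName = sys.props.getOrElse("flint.hypothetical.queryLogger", "")
// An empty class name returns the default ConsoleQueryLogger; otherwise the named class
// is loaded reflectively and must expose a constructor matching the argument types.
val logger: QueryLogger = instantiate[QueryLogger](new ConsoleQueryLogger(), loggerClassName)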
FlintREPL.scala
@@ -26,7 +26,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.FlintREPLConfConstants._
import org.apache.spark.sql.SessionUpdateMode._
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.util.{ThreadUtils, Utils}
import org.apache.spark.util.ThreadUtils

object FlintREPLConfConstants {
val HEARTBEAT_INTERVAL_MILLIS = 60000L
@@ -87,8 +87,9 @@ object FlintREPL extends Logging with FlintJobExecutor {
conf.set(FlintSparkConf.JOB_TYPE.key, jobType)

val query = getQuery(queryOption, jobType, conf)
val queryId = conf.get(FlintSparkConf.QUERY_ID.key, "")

if (jobType.equalsIgnoreCase("streaming")) {
if (jobType.equalsIgnoreCase(FlintJobType.STREAMING)) {
if (resultIndexOption.isEmpty) {
logAndThrow("resultIndex is not set")
}
@@ -100,9 +101,10 @@
jobId,
createSparkSession(conf),
query,
queryId,
dataSource,
resultIndexOption.get,
true,
jobType,
streamingRunningCount)
registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount)
jobOperator.start()
@@ -174,6 +176,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
jobId,
spark,
dataSource,
jobType,
sessionId,
sessionManager,
queryResultWriter,
@@ -220,7 +223,7 @@

def getQuery(queryOption: Option[String], jobType: String, conf: SparkConf): String = {
queryOption.getOrElse {
if (jobType.equalsIgnoreCase("streaming")) {
if (jobType.equalsIgnoreCase(FlintJobType.STREAMING)) {
val defaultQuery = conf.get(FlintSparkConf.QUERY.key, "")
if (defaultQuery.isEmpty) {
logAndThrow("Query undefined for the streaming job.")
@@ -352,7 +355,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
canPickUpNextStatement = updatedCanPickUpNextStatement
lastCanPickCheckTime = updatedLastCanPickCheckTime
} finally {
statementsExecutionManager.terminateStatementsExecution()
statementsExecutionManager.terminateStatementExecution()
}

Thread.sleep(commandContext.queryLoopExecutionFrequency)
@@ -975,26 +978,6 @@
}
}

private def instantiate[T](defaultConstructor: => T, className: String, args: Any*): T = {
if (className.isEmpty) {
defaultConstructor
} else {
try {
val classObject = Utils.classForName(className)
val ctor = if (args.isEmpty) {
classObject.getDeclaredConstructor()
} else {
classObject.getDeclaredConstructor(args.map(_.getClass.asInstanceOf[Class[_]]): _*)
}
ctor.setAccessible(true)
ctor.newInstance(args.map(_.asInstanceOf[Object]): _*).asInstanceOf[T]
} catch {
case e: Exception =>
throw new RuntimeException(s"Failed to instantiate provider: $className", e)
}
}
}

private def instantiateSessionManager(
spark: SparkSession,
resultIndexOption: Option[String]): SessionManager = {