Skip to content

Commit

Permalink
Refactor FlintJob with FlintStatement and StatementExecutionManager
Browse files Browse the repository at this point in the history
Signed-off-by: Louis Chu <[email protected]>
  • Loading branch information
noCharger committed Sep 9, 2024
1 parent 88ad15f commit 223c619
Show file tree
Hide file tree
Showing 11 changed files with 195 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,5 @@ trait StatementExecutionManager {
/**
* Terminates the statement lifecycle.
*/
def terminateStatementsExecution(): Unit
def terminateStatementExecution(): Unit
}
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,10 @@ object FlintSparkConf {
FlintConfig("spark.flint.job.query")
.doc("Flint query for batch and streaming job")
.createOptional()
// Config entry registered under the key "spark.flint.job.queryId"; it carries the id of the
// query that a batch or streaming job runs.
// NOTE(review): declared with createOptional(), so presumably no default value is registered
// for this key — confirm against FlintConfig.
val QUERY_ID =
FlintConfig("spark.flint.job.queryId")
.doc("Flint query id for batch and streaming job")
.createOptional()
val JOB_TYPE =
FlintConfig(s"spark.flint.job.type")
.doc("Flint job type. Including interactive and streaming")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
val resultIndex = "query_results2"
val appId = "00feq82b752mbt0p"
val dataSourceName = "my_glue1"
val queryId = "testQueryId"
var osClient: OSClient = _
val threadLocalFuture = new ThreadLocal[Future[Unit]]()

Expand Down Expand Up @@ -103,9 +104,10 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
jobRunId,
spark,
query,
queryId,
dataSourceName,
resultIndex,
true,
"streaming",
streamingRunningCount)
job.terminateJVM = false
job.start()
Expand Down Expand Up @@ -144,7 +146,6 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {

assert(result.status == "SUCCESS", s"expected status is SUCCESS, but got ${result.status}")
assert(result.error.isEmpty, s"we don't expect error, but got ${result.error}")
assert(result.queryId.isEmpty, s"we don't expect query id, but got ${result.queryId}")

commonAssert(result, jobRunId, query, queryStartTime)
true
Expand Down Expand Up @@ -362,7 +363,7 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
result.queryRunTime < System.currentTimeMillis() - queryStartTime,
s"expected query run time ${result.queryRunTime} should be less than ${System
.currentTimeMillis() - queryStartTime}, but it is not")
assert(result.queryId.isEmpty, s"we don't expect query id, but got ${result.queryId}")
assert(result.queryId == queryId, s"expected query id is ${queryId}, but got ${result.queryId}")
}

def pollForResultAndAssert(expected: REPLResult => Boolean, jobId: String): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ case class CommandContext(
jobId: String,
spark: SparkSession,
dataSource: String,
jobType: String,
sessionId: String,
sessionManager: SessionManager,
queryResultWriter: QueryResultWriter,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@ import java.util.concurrent.atomic.AtomicInteger
import org.opensearch.flint.core.logging.CustomLogging
import org.opensearch.flint.core.metrics.MetricConstants
import org.opensearch.flint.core.metrics.MetricsUtil.registerGauge
import play.api.libs.json._

import org.apache.spark.internal.Logging
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.types._

/**
* Spark SQL Application entrypoint
Expand All @@ -41,6 +39,8 @@ object FlintJob extends Logging with FlintJobExecutor {
if (query.isEmpty) {
logAndThrow(s"Query undefined for the ${jobType} job.")
}
val queryId = conf.get(FlintSparkConf.QUERY_ID.key, "")

if (resultIndexOption.isEmpty) {
logAndThrow("resultIndex is not set")
}
Expand All @@ -66,9 +66,10 @@ object FlintJob extends Logging with FlintJobExecutor {
jobId,
createSparkSession(conf),
query,
queryId,
dataSource,
resultIndexOption.get,
jobType.equalsIgnoreCase("streaming"),
jobType,
streamingRunningCount)
registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount)
jobOperator.start()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.flint.config.FlintSparkConf.REFRESH_POLICY
import org.apache.spark.sql.types._
import org.apache.spark.sql.util._
import org.apache.spark.util.Utils

object SparkConfConstants {
val SQL_EXTENSIONS_KEY = "spark.sql.extensions"
Expand Down Expand Up @@ -524,4 +525,25 @@ trait FlintJobExecutor {
CustomLogging.logError(t)
throw t
}

/**
 * Creates an instance of `T`, either from the supplied default or reflectively by class name.
 *
 * When `className` is empty the by-name `defaultConstructor` is evaluated and returned.
 * Otherwise the class is loaded and one of its declared constructors (of any visibility) is
 * invoked with `args`. The constructor whose parameter types equal the runtime classes of
 * `args` is preferred; if there is none, the first declared constructor whose parameters can
 * accept the arguments is used. The fallback covers parameters declared as a supertype or
 * interface of the argument (for example `Option[String]` receiving a `Some`), primitive
 * parameters receiving boxed values, and `null` arguments.
 *
 * @param defaultConstructor
 *   expression evaluated only when `className` is empty
 * @param className
 *   fully qualified name of the class to instantiate; empty selects the default
 * @param args
 *   constructor arguments, matched positionally
 * @tparam T
 *   type the created instance is cast to (the cast is unchecked)
 * @return
 *   the default instance, or a new instance of `className`
 * @throws RuntimeException
 *   wrapping any exception raised while loading the class, resolving the constructor or
 *   invoking it
 */
def instantiate[T](defaultConstructor: => T, className: String, args: Any*): T = {
  // Picks the constructor to invoke: exact runtime-class match first, assignable match second.
  def resolveConstructor(
      classObject: Class[_],
      ctorArgs: Seq[Object]): java.lang.reflect.Constructor[_] = {
    // Arguments passed as Any are always boxed, so a primitive parameter type has to be
    // compared through its wrapper class.
    val wrapperOf = Map[Class[_], Class[_]](
      classOf[Boolean] -> classOf[java.lang.Boolean],
      classOf[Byte] -> classOf[java.lang.Byte],
      classOf[Char] -> classOf[java.lang.Character],
      classOf[Short] -> classOf[java.lang.Short],
      classOf[Int] -> classOf[java.lang.Integer],
      classOf[Long] -> classOf[java.lang.Long],
      classOf[Float] -> classOf[java.lang.Float],
      classOf[Double] -> classOf[java.lang.Double])

    def accepts(paramType: Class[_], arg: Object): Boolean =
      if (arg == null) !paramType.isPrimitive
      else wrapperOf.getOrElse(paramType, paramType).isInstance(arg)

    // Same lookup the method has always performed; kept first so that every call that
    // resolved before still resolves to the same constructor.
    val exactMatch: Option[java.lang.reflect.Constructor[_]] =
      if (ctorArgs.contains(null)) {
        None // a null argument has no runtime class to look up
      } else {
        try {
          val runtimeTypes = ctorArgs.map(_.getClass.asInstanceOf[Class[_]])
          Some(classObject.getDeclaredConstructor(runtimeTypes: _*))
        } catch {
          case _: NoSuchMethodException => None
        }
      }

    exactMatch
      .orElse(classObject.getDeclaredConstructors.find { candidate =>
        val paramTypes = candidate.getParameterTypes
        paramTypes.length == ctorArgs.length &&
        paramTypes.indices.forall(i => accepts(paramTypes(i), ctorArgs(i)))
      })
      .getOrElse(throw new NoSuchMethodException(
        s"No constructor of $className accepts the ${ctorArgs.length} supplied argument(s)"))
  }

  if (className.isEmpty) {
    defaultConstructor
  } else {
    try {
      val classObject = Utils.classForName(className)
      val ctorArgs = args.map(_.asInstanceOf[Object])
      val ctor = resolveConstructor(classObject, ctorArgs)
      // Allows non-public constructors to be invoked.
      ctor.setAccessible(true)
      ctor.newInstance(ctorArgs: _*).asInstanceOf[T]
    } catch {
      case e: Exception =>
        throw new RuntimeException(s"Failed to instantiate provider: $className", e)
    }
  }
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.FlintREPLConfConstants._
import org.apache.spark.sql.SessionUpdateMode._
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.util.{ThreadUtils, Utils}
import org.apache.spark.util.ThreadUtils

object FlintREPLConfConstants {
val HEARTBEAT_INTERVAL_MILLIS = 60000L
Expand Down Expand Up @@ -87,6 +87,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
conf.set(FlintSparkConf.JOB_TYPE.key, jobType)

val query = getQuery(queryOption, jobType, conf)
val queryId = conf.get(FlintSparkConf.QUERY_ID.key, "")

if (jobType.equalsIgnoreCase("streaming")) {
if (resultIndexOption.isEmpty) {
Expand All @@ -100,9 +101,10 @@ object FlintREPL extends Logging with FlintJobExecutor {
jobId,
createSparkSession(conf),
query,
queryId,
dataSource,
resultIndexOption.get,
true,
jobType,
streamingRunningCount)
registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount)
jobOperator.start()
Expand Down Expand Up @@ -174,6 +176,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
jobId,
spark,
dataSource,
jobType,
sessionId,
sessionManager,
queryResultWriter,
Expand Down Expand Up @@ -352,7 +355,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
canPickUpNextStatement = updatedCanPickUpNextStatement
lastCanPickCheckTime = updatedLastCanPickCheckTime
} finally {
statementsExecutionManager.terminateStatementsExecution()
statementsExecutionManager.terminateStatementExecution()
}

Thread.sleep(commandContext.queryLoopExecutionFrequency)
Expand Down Expand Up @@ -975,26 +978,6 @@ object FlintREPL extends Logging with FlintJobExecutor {
}
}

/**
 * Creates an instance of `T`, either from the supplied default or reflectively by class name.
 *
 * When `className` is empty the by-name `defaultConstructor` is evaluated and returned.
 * Otherwise the class is loaded and one of its declared constructors (of any visibility) is
 * invoked with `args`. The constructor whose parameter types equal the runtime classes of
 * `args` is preferred; if there is none, the first declared constructor whose parameters can
 * accept the arguments is used. The fallback covers parameters declared as a supertype or
 * interface of the argument (for example `Option[String]` receiving a `Some`), primitive
 * parameters receiving boxed values, and `null` arguments.
 *
 * @param defaultConstructor
 *   expression evaluated only when `className` is empty
 * @param className
 *   fully qualified name of the class to instantiate; empty selects the default
 * @param args
 *   constructor arguments, matched positionally
 * @tparam T
 *   type the created instance is cast to (the cast is unchecked)
 * @return
 *   the default instance, or a new instance of `className`
 * @throws RuntimeException
 *   wrapping any exception raised while loading the class, resolving the constructor or
 *   invoking it
 */
private def instantiate[T](defaultConstructor: => T, className: String, args: Any*): T = {
  // Picks the constructor to invoke: exact runtime-class match first, assignable match second.
  def resolveConstructor(
      classObject: Class[_],
      ctorArgs: Seq[Object]): java.lang.reflect.Constructor[_] = {
    // Arguments passed as Any are always boxed, so a primitive parameter type has to be
    // compared through its wrapper class.
    val wrapperOf = Map[Class[_], Class[_]](
      classOf[Boolean] -> classOf[java.lang.Boolean],
      classOf[Byte] -> classOf[java.lang.Byte],
      classOf[Char] -> classOf[java.lang.Character],
      classOf[Short] -> classOf[java.lang.Short],
      classOf[Int] -> classOf[java.lang.Integer],
      classOf[Long] -> classOf[java.lang.Long],
      classOf[Float] -> classOf[java.lang.Float],
      classOf[Double] -> classOf[java.lang.Double])

    def accepts(paramType: Class[_], arg: Object): Boolean =
      if (arg == null) !paramType.isPrimitive
      else wrapperOf.getOrElse(paramType, paramType).isInstance(arg)

    // Same lookup the method has always performed; kept first so that every call that
    // resolved before still resolves to the same constructor.
    val exactMatch: Option[java.lang.reflect.Constructor[_]] =
      if (ctorArgs.contains(null)) {
        None // a null argument has no runtime class to look up
      } else {
        try {
          val runtimeTypes = ctorArgs.map(_.getClass.asInstanceOf[Class[_]])
          Some(classObject.getDeclaredConstructor(runtimeTypes: _*))
        } catch {
          case _: NoSuchMethodException => None
        }
      }

    exactMatch
      .orElse(classObject.getDeclaredConstructors.find { candidate =>
        val paramTypes = candidate.getParameterTypes
        paramTypes.length == ctorArgs.length &&
        paramTypes.indices.forall(i => accepts(paramTypes(i), ctorArgs(i)))
      })
      .getOrElse(throw new NoSuchMethodException(
        s"No constructor of $className accepts the ${ctorArgs.length} supplied argument(s)"))
  }

  if (className.isEmpty) {
    defaultConstructor
  } else {
    try {
      val classObject = Utils.classForName(className)
      val ctorArgs = args.map(_.asInstanceOf[Object])
      val ctor = resolveConstructor(classObject, ctorArgs)
      // Allows non-public constructors to be invoked.
      ctor.setAccessible(true)
      ctor.newInstance(ctorArgs: _*).asInstanceOf[T]
    } catch {
      case e: Exception =>
        throw new RuntimeException(s"Failed to instantiate provider: $className", e)
    }
  }
}

private def instantiateSessionManager(
spark: SparkSession,
resultIndexOption: Option[String]): SessionManager = {
Expand Down
Loading

0 comments on commit 223c619

Please sign in to comment.