Skip to content

Commit

Permalink
Refactor query result writer
Browse files Browse the repository at this point in the history
Signed-off-by: Louis Chu <[email protected]>
  • Loading branch information
noCharger committed Sep 25, 2024
1 parent 0b248ec commit dda7427
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,12 @@ trait QueryResultWriter {
* data storage based on the provided FlintStatement metadata.
*/
def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit

/**
* Reformat the given DataFrame to the desired format.
*/
def reformatDataFrame(
dataFrame: DataFrame,
flintStatement: FlintStatement,
queryStartTime: Long): DataFrame
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ case class CommandContext(
jobType: String,
sessionId: String,
sessionManager: SessionManager,
queryResultWriter: QueryResultWriter,
queryExecutionTimeout: Duration,
inactivityLimitMillis: Long,
queryWaitTimeMillis: Long,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import com.amazonaws.services.s3.model.AmazonS3Exception
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.scala.DefaultScalaModule
import org.apache.commons.text.StringEscapeUtils.unescapeJava
import org.opensearch.common.Strings
import org.opensearch.flint.core.IRestHighLevelClient
import org.opensearch.flint.core.logging.{CustomLogging, ExceptionMessages, OperationMessage}
import org.opensearch.flint.core.metrics.MetricConstants
Expand Down Expand Up @@ -533,7 +534,7 @@ trait FlintJobExecutor {
}

def instantiate[T](defaultConstructor: => T, className: String, args: Any*): T = {
if (className.isEmpty) {
if (Strings.isNullOrEmpty(className)) {
defaultConstructor
} else {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,6 @@ object FlintREPL extends Logging with FlintJobExecutor {
return
}

val queryResultWriter =
instantiateQueryResultWriter(conf, sessionManager.getSessionContext)
val commandContext = CommandContext(
applicationId,
jobId,
Expand All @@ -179,7 +177,6 @@ object FlintREPL extends Logging with FlintJobExecutor {
jobType,
sessionId,
sessionManager,
queryResultWriter,
queryExecutionTimeoutSecs,
inactivityLimitMillis,
queryWaitTimeoutMillis,
Expand Down Expand Up @@ -316,7 +313,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
// 1 thread for async query execution
val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-query", 1)
implicit val executionContext = ExecutionContext.fromExecutor(threadPool)

val queryResultWriter = instantiateQueryResultWriter(spark, commandContext)
var futurePrepareQueryExecution: Future[Either[String, Unit]] = null
try {
logInfo(s"""Executing session with sessionId: ${sessionId}""")
Expand All @@ -342,7 +339,11 @@ object FlintREPL extends Logging with FlintJobExecutor {
executionContext,
lastCanPickCheckTime)
val result: (Long, VerificationResult, Boolean, Long) =
processCommands(statementsExecutionManager, commandContext, commandState)
processCommands(
statementsExecutionManager,
queryResultWriter,
commandContext,
commandState)

val (
updatedLastActivityTime,
Expand Down Expand Up @@ -491,6 +492,7 @@ object FlintREPL extends Logging with FlintJobExecutor {

private def processCommands(
statementExecutionManager: StatementExecutionManager,
queryResultWriter: QueryResultWriter,
context: CommandContext,
state: CommandState): (Long, VerificationResult, Boolean, Long) = {
import context._
Expand Down Expand Up @@ -525,14 +527,15 @@ object FlintREPL extends Logging with FlintJobExecutor {
val (dataToWrite, returnedVerificationResult) =
processStatementOnVerification(
statementExecutionManager,
queryResultWriter,
flintStatement,
state,
context)

verificationResult = returnedVerificationResult
finalizeCommand(
statementExecutionManager,
context,
queryResultWriter,
dataToWrite,
flintStatement,
statementTimerContext)
Expand All @@ -558,11 +561,10 @@ object FlintREPL extends Logging with FlintJobExecutor {
*/
private def finalizeCommand(
statementExecutionManager: StatementExecutionManager,
commandContext: CommandContext,
queryResultWriter: QueryResultWriter,
dataToWrite: Option[DataFrame],
flintStatement: FlintStatement,
statementTimerContext: Timer.Context): Unit = {
import commandContext._
try {
dataToWrite.foreach(df => queryResultWriter.writeDataFrame(df, flintStatement))
if (flintStatement.isRunning || flintStatement.isWaiting) {
Expand Down Expand Up @@ -626,6 +628,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
spark: SparkSession,
flintStatement: FlintStatement,
statementExecutionManager: StatementExecutionManager,
queryResultWriter: QueryResultWriter,
dataSource: String,
sessionId: String,
executionContext: ExecutionContextExecutor,
Expand All @@ -640,6 +643,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
spark,
flintStatement,
statementExecutionManager,
queryResultWriter,
dataSource,
sessionId,
executionContext,
Expand Down Expand Up @@ -677,6 +681,7 @@ object FlintREPL extends Logging with FlintJobExecutor {

private def processStatementOnVerification(
statementExecutionManager: StatementExecutionManager,
queryResultWriter: QueryResultWriter,
flintStatement: FlintStatement,
commandState: CommandState,
commandContext: CommandContext) = {
Expand All @@ -698,6 +703,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
spark,
flintStatement,
statementExecutionManager,
queryResultWriter,
dataSource,
sessionId,
executionContext,
Expand Down Expand Up @@ -764,6 +770,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
spark,
flintStatement,
statementExecutionManager,
queryResultWriter,
dataSource,
sessionId,
executionContext,
Expand All @@ -782,6 +789,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
spark: SparkSession,
flintStatement: FlintStatement,
statementsExecutionManager: StatementExecutionManager,
queryResultWriter: QueryResultWriter,
dataSource: String,
sessionId: String,
executionContext: ExecutionContextExecutor,
Expand All @@ -801,7 +809,9 @@ object FlintREPL extends Logging with FlintJobExecutor {
startTime)
} else {
val futureQueryExecution = Future {
statementsExecutionManager.executeStatement(flintStatement)
val startTime = System.currentTimeMillis()
val df = statementsExecutionManager.executeStatement(flintStatement)
queryResultWriter.reformatDataFrame(df, flintStatement, startTime)
}(executionContext)
// time out after 10 minutes
ThreadUtils.awaitResult(futureQueryExecution, queryExecutionTimeOut)
Expand Down Expand Up @@ -998,11 +1008,11 @@ object FlintREPL extends Logging with FlintJobExecutor {
}

private def instantiateQueryResultWriter(
sparkConf: SparkConf,
context: Map[String, Any]): QueryResultWriter = {
spark: SparkSession,
commandContext: CommandContext): QueryResultWriter = {
instantiate(
new QueryResultWriterImpl(context),
sparkConf.get(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key, ""))
new QueryResultWriterImpl(commandContext),
spark.conf.get(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key, ""))
}

private def recordSessionSuccess(sessionTimerContext: Timer.Context): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ case class JobOperator(
jobType,
"", // FlintJob doesn't have sessionId
null, // FlintJob doesn't have SessionManager
null, // FlintJob doesn't have QueryResultWriter
Duration.Inf, // FlintJob doesn't have queryExecutionTimeout
-1, // FlintJob doesn't have inactivityLimitMillis
-1, // FlintJob doesn't have queryWaitTimeMillis
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,41 @@ import org.opensearch.flint.common.model.FlintStatement
import org.apache.spark.internal.Logging
import org.apache.spark.sql.FlintJob.writeDataFrameToOpensearch
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.util.CleanerFactory

class QueryResultWriterImpl(context: Map[String, Any]) extends QueryResultWriter with Logging {
class QueryResultWriterImpl(commandContext: CommandContext)
extends QueryResultWriter
with FlintJobExecutor
with Logging {

private val context = commandContext.sessionManager.getSessionContext
private val resultIndex = context("resultIndex").asInstanceOf[String]
// Initialize OSClient with Flint options because custom session manager implementation should not have it in the context
private val osClient = new OSClient(FlintSparkConf().flintOptions())

override def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit = {
writeDataFrameToOpensearch(dataFrame, resultIndex, osClient)
}

/**
* Reformat the given DataFrame to the desired format.
*/
override def reformatDataFrame(
dataFrame: DataFrame,
statement: FlintStatement,
queryStartTime: Long): DataFrame = {
import commandContext._
getFormattedData(
applicationId,
jobId,
dataFrame,
spark,
dataSource,
statement.queryId,
statement.query,
sessionId,
queryStartTime,
currentTimeProvider,
CleanerFactory.cleaner(false))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,13 @@ class StatementExecutionManagerImpl(commandContext: CommandContext)
}

override def executeStatement(statement: FlintStatement): DataFrame = {
import commandContext._
executeQuery(
applicationId,
jobId,
spark,
statement.query,
dataSource,
import commandContext.spark
// we have to set job group in the same thread that started the query according to spark doc
spark.sparkContext.setJobGroup(
statement.queryId,
sessionId,
false)
"Job group for " + statement.queryId,
interruptOnCancel = true)
spark.sql(statement.query)
}

private def createOpenSearchQueryReader() = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,6 @@ class FlintREPLTest
INTERACTIVE_JOB_TYPE,
sessionId,
sessionManager,
queryResultWriter,
Duration(10, MINUTES),
60,
60,
Expand Down Expand Up @@ -748,7 +747,6 @@ class FlintREPLTest
INTERACTIVE_JOB_TYPE,
sessionId,
sessionManager,
queryResultWriter,
Duration(10, MINUTES),
60,
60,
Expand All @@ -761,6 +759,7 @@ class FlintREPLTest
mockSparkSession,
flintStatement,
statementExecutionManager,
queryResultWriter,
dataSource,
sessionId,
executionContext,
Expand Down Expand Up @@ -809,7 +808,6 @@ class FlintREPLTest
when(mockSparkSession.sparkContext).thenReturn(sparkContext)

// Assume handleQueryException logs the error and returns an error message string
val mockErrorString = "Error due to syntax"
when(mockSparkSession.createDataFrame(any[Seq[Product]])(any[TypeTag[Product]]))
.thenReturn(expectedDataFrame)
when(expectedDataFrame.toDF(any[Seq[String]]: _*)).thenReturn(expectedDataFrame)
Expand All @@ -824,7 +822,6 @@ class FlintREPLTest
INTERACTIVE_JOB_TYPE,
sessionId,
sessionManager,
queryResultWriter,
Duration(10, MINUTES),
60,
60,
Expand All @@ -837,6 +834,7 @@ class FlintREPLTest
mockSparkSession,
flintStatement,
statementExecutionManager,
queryResultWriter,
dataSource,
sessionId,
executionContext,
Expand Down Expand Up @@ -1076,7 +1074,6 @@ class FlintREPLTest
INTERACTIVE_JOB_TYPE,
sessionId,
sessionManager,
queryResultWriter,
Duration(10, MINUTES),
shortInactivityLimit,
60,
Expand Down Expand Up @@ -1146,7 +1143,6 @@ class FlintREPLTest
INTERACTIVE_JOB_TYPE,
sessionId,
sessionManager,
queryResultWriter,
Duration(10, MINUTES),
longInactivityLimit,
60,
Expand Down Expand Up @@ -1212,7 +1208,6 @@ class FlintREPLTest
INTERACTIVE_JOB_TYPE,
sessionId,
sessionManager,
queryResultWriter,
Duration(10, MINUTES),
inactivityLimit,
60,
Expand Down Expand Up @@ -1283,7 +1278,6 @@ class FlintREPLTest
INTERACTIVE_JOB_TYPE,
sessionId,
sessionManager,
queryResultWriter,
Duration(10, MINUTES),
inactivityLimit,
60,
Expand Down Expand Up @@ -1367,7 +1361,6 @@ class FlintREPLTest
override val osClient: OSClient = mockOSClient
override lazy val flintSessionIndexUpdater: OpenSearchUpdater = mockOpenSearchUpdater
}
val queryResultWriter = mock[QueryResultWriter]

val commandContext = CommandContext(
applicationId,
Expand All @@ -1377,7 +1370,6 @@ class FlintREPLTest
INTERACTIVE_JOB_TYPE,
sessionId,
sessionManager,
queryResultWriter,
Duration(10, MINUTES),
inactivityLimit,
60,
Expand Down Expand Up @@ -1453,7 +1445,6 @@ class FlintREPLTest
INTERACTIVE_JOB_TYPE,
sessionId,
sessionManager,
queryResultWriter,
Duration(10, MINUTES),
inactivityLimit,
60,
Expand Down

0 comments on commit dda7427

Please sign in to comment.