[REFACTOR] Move DF reformat from StatementExecutionManagerImpl to QueryResultWriterImpl (#701)

* Refactor query result writer

Signed-off-by: Louis Chu <[email protected]>

* Add more scala doc and update sbt

Signed-off-by: Louis Chu <[email protected]>

---------

Signed-off-by: Louis Chu <[email protected]>
(cherry picked from commit d76e0cd)
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
github-actions[bot] committed Sep 30, 2024
1 parent 7423de7 commit 60ca296
Showing 9 changed files with 101 additions and 40 deletions.
2 changes: 1 addition & 1 deletion build.sbt
@@ -88,7 +88,7 @@ lazy val flintCore = (project in file("flint-core"))
exclude ("com.fasterxml.jackson.core", "jackson-databind"),
"com.amazonaws" % "aws-java-sdk-cloudwatch" % "1.12.593"
exclude("com.fasterxml.jackson.core", "jackson-databind"),
"software.amazon.awssdk" % "auth-crt" % "2.25.23",
"software.amazon.awssdk" % "auth-crt" % "2.28.10" % "provided",
"org.scalactic" %% "scalactic" % "3.2.15" % "test",
"org.scalatest" %% "scalatest" % "3.2.15" % "test",
"org.scalatest" %% "scalatest-flatspec" % "3.2.15" % "test",
QueryResultWriter.scala
@@ -13,8 +13,39 @@ import org.opensearch.flint.common.model.FlintStatement
trait QueryResultWriter {

/**
* Writes the given DataFrame, which represents the result of a query execution, to an external
* data storage based on the provided FlintStatement metadata.
* Writes the given DataFrame to an external data storage based on the FlintStatement metadata.
* This method is responsible for persisting the query results.
*
* Note: This method typically involves I/O operations and may trigger Spark actions to
* materialize the DataFrame if it hasn't been processed yet.
*
* @param dataFrame
* The DataFrame containing the query results to be written.
* @param flintStatement
* The FlintStatement containing metadata that guides the writing process.
*/
def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit

/**
* Defines transformations on the given DataFrame and triggers an action to process it. This
* method applies necessary transformations based on the FlintStatement metadata and executes an
* action to compute the result.
*
* Note: Calling this method will trigger the actual data processing in Spark. If the Spark SQL
* thread is waiting for the result of a query, termination on the same thread will be blocked
* until the action completes.
*
* @param dataFrame
* The DataFrame to be processed.
* @param flintStatement
* The FlintStatement containing statement metadata.
* @param queryStartTime
* The start time of the query execution.
* @return
* The processed DataFrame after applying transformations and executing an action.
*/
def processDataFrame(
dataFrame: DataFrame,
flintStatement: FlintStatement,
queryStartTime: Long): DataFrame
}
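To make the contract concrete, here is a minimal, hypothetical implementation of the trait: processDataFrame defines the reshaping and deliberately triggers an action so failures surface on the query-execution thread, while writeDataFrame does only the persistence. The class name, output path, and added columns below are illustrative and are not part of this commit.

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.lit
import org.opensearch.flint.common.model.FlintStatement

// Hypothetical writer that tags results with the query id and saves them as Parquet.
class ParquetQueryResultWriter(outputPath: String) extends QueryResultWriter {

  // Persist the already-processed DataFrame; pure I/O, no reshaping here.
  override def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit = {
    dataFrame.write.mode("append").parquet(s"$outputPath/${flintStatement.queryId}")
  }

  // Reshape the raw result and force materialization so the Spark action runs
  // on the query-execution thread rather than at write time.
  override def processDataFrame(
      dataFrame: DataFrame,
      flintStatement: FlintStatement,
      queryStartTime: Long): DataFrame = {
    val enriched = dataFrame
      .withColumn("queryId", lit(flintStatement.queryId))
      .withColumn("queryRunTimeMillis", lit(System.currentTimeMillis() - queryStartTime))
    enriched.cache()
    enriched.count() // Spark action: triggers the actual computation
    enriched
  }
}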
CommandContext.scala
@@ -18,7 +18,6 @@ case class CommandContext(
jobType: String,
sessionId: String,
sessionManager: SessionManager,
queryResultWriter: QueryResultWriter,
queryExecutionTimeout: Duration,
inactivityLimitMillis: Long,
queryWaitTimeMillis: Long,
FlintJobExecutor.scala
@@ -12,6 +12,7 @@ import com.amazonaws.services.s3.model.AmazonS3Exception
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.scala.DefaultScalaModule
import org.apache.commons.text.StringEscapeUtils.unescapeJava
import org.opensearch.common.Strings
import org.opensearch.flint.core.IRestHighLevelClient
import org.opensearch.flint.core.logging.{CustomLogging, ExceptionMessages, OperationMessage}
import org.opensearch.flint.core.metrics.MetricConstants
@@ -533,7 +534,7 @@ trait FlintJobExecutor {
}

def instantiate[T](defaultConstructor: => T, className: String, args: Any*): T = {
if (className.isEmpty) {
if (Strings.isNullOrEmpty(className)) {
defaultConstructor
} else {
try {
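The functional difference in instantiate is null-safety: className.isEmpty throws a NullPointerException when the class name is absent, whereas Strings.isNullOrEmpty treats null and the empty string alike and falls through to the default constructor. A standalone sketch of the predicate (the class name below is a placeholder, not a value from this commit):

import org.opensearch.common.Strings

val absent: String = null
// absent.isEmpty                               // old check: NullPointerException
Strings.isNullOrEmpty(absent)                   // true  -> default constructor path
Strings.isNullOrEmpty("")                       // true  -> default constructor path
Strings.isNullOrEmpty("com.example.MyWriter")   // false -> reflective instantiation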
FlintREPL.scala
@@ -169,8 +169,6 @@ object FlintREPL extends Logging with FlintJobExecutor {
return
}

val queryResultWriter =
instantiateQueryResultWriter(conf, sessionManager.getSessionContext)
val commandContext = CommandContext(
applicationId,
jobId,
@@ -179,7 +177,6 @@ object FlintREPL extends Logging with FlintJobExecutor {
jobType,
sessionId,
sessionManager,
queryResultWriter,
queryExecutionTimeoutSecs,
inactivityLimitMillis,
queryWaitTimeoutMillis,
@@ -316,7 +313,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
// 1 thread for async query execution
val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-query", 1)
implicit val executionContext = ExecutionContext.fromExecutor(threadPool)

val queryResultWriter = instantiateQueryResultWriter(spark, commandContext)
var futurePrepareQueryExecution: Future[Either[String, Unit]] = null
try {
logInfo(s"""Executing session with sessionId: ${sessionId}""")
@@ -342,7 +339,11 @@ object FlintREPL extends Logging with FlintJobExecutor {
executionContext,
lastCanPickCheckTime)
val result: (Long, VerificationResult, Boolean, Long) =
processCommands(statementsExecutionManager, commandContext, commandState)
processCommands(
statementsExecutionManager,
queryResultWriter,
commandContext,
commandState)

val (
updatedLastActivityTime,
@@ -491,6 +492,7 @@ object FlintREPL extends Logging with FlintJobExecutor {

private def processCommands(
statementExecutionManager: StatementExecutionManager,
queryResultWriter: QueryResultWriter,
context: CommandContext,
state: CommandState): (Long, VerificationResult, Boolean, Long) = {
import context._
@@ -525,14 +527,15 @@ object FlintREPL extends Logging with FlintJobExecutor {
val (dataToWrite, returnedVerificationResult) =
processStatementOnVerification(
statementExecutionManager,
queryResultWriter,
flintStatement,
state,
context)

verificationResult = returnedVerificationResult
finalizeCommand(
statementExecutionManager,
context,
queryResultWriter,
dataToWrite,
flintStatement,
statementTimerContext)
@@ -558,11 +561,10 @@ object FlintREPL extends Logging with FlintJobExecutor {
*/
private def finalizeCommand(
statementExecutionManager: StatementExecutionManager,
commandContext: CommandContext,
queryResultWriter: QueryResultWriter,
dataToWrite: Option[DataFrame],
flintStatement: FlintStatement,
statementTimerContext: Timer.Context): Unit = {
import commandContext._
try {
dataToWrite.foreach(df => queryResultWriter.writeDataFrame(df, flintStatement))
if (flintStatement.isRunning || flintStatement.isWaiting) {
@@ -626,6 +628,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
spark: SparkSession,
flintStatement: FlintStatement,
statementExecutionManager: StatementExecutionManager,
queryResultWriter: QueryResultWriter,
dataSource: String,
sessionId: String,
executionContext: ExecutionContextExecutor,
@@ -640,6 +643,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
spark,
flintStatement,
statementExecutionManager,
queryResultWriter,
dataSource,
sessionId,
executionContext,
@@ -677,6 +681,7 @@ object FlintREPL extends Logging with FlintJobExecutor {

private def processStatementOnVerification(
statementExecutionManager: StatementExecutionManager,
queryResultWriter: QueryResultWriter,
flintStatement: FlintStatement,
commandState: CommandState,
commandContext: CommandContext) = {
@@ -698,6 +703,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
spark,
flintStatement,
statementExecutionManager,
queryResultWriter,
dataSource,
sessionId,
executionContext,
@@ -764,6 +770,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
spark,
flintStatement,
statementExecutionManager,
queryResultWriter,
dataSource,
sessionId,
executionContext,
@@ -782,6 +789,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
spark: SparkSession,
flintStatement: FlintStatement,
statementsExecutionManager: StatementExecutionManager,
queryResultWriter: QueryResultWriter,
dataSource: String,
sessionId: String,
executionContext: ExecutionContextExecutor,
@@ -801,7 +809,14 @@ object FlintREPL extends Logging with FlintJobExecutor {
startTime)
} else {
val futureQueryExecution = Future {
statementsExecutionManager.executeStatement(flintStatement)
val startTime = System.currentTimeMillis()
// Execute the statement and get the resulting DataFrame
// This step may involve Spark transformations, but not necessarily actions
val df = statementsExecutionManager.executeStatement(flintStatement)
// Process the DataFrame, applying any necessary transformations
// and triggering Spark actions to materialize the results
// This is where the actual data processing occurs
queryResultWriter.processDataFrame(df, flintStatement, startTime)
}(executionContext)
// time out after 10 minutes
ThreadUtils.awaitResult(futureQueryExecution, queryExecutionTimeOut)
@@ -998,11 +1013,11 @@ object FlintREPL extends Logging with FlintJobExecutor {
}

private def instantiateQueryResultWriter(
sparkConf: SparkConf,
context: Map[String, Any]): QueryResultWriter = {
spark: SparkSession,
commandContext: CommandContext): QueryResultWriter = {
instantiate(
new QueryResultWriterImpl(context),
sparkConf.get(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key, ""))
new QueryResultWriterImpl(commandContext),
spark.conf.get(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key, ""))
}

private def recordSessionSuccess(sessionTimerContext: Timer.Context): Unit = {
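Because instantiateQueryResultWriter now resolves the class name from the SparkSession's conf via instantiate, the writer stays pluggable: if the conf behind FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER names a class it is loaded reflectively, otherwise the REPL falls back to QueryResultWriterImpl(commandContext). A hedged sketch of how a deployment might select a custom writer (the class name is a placeholder):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.flint.config.FlintSparkConf

// Hypothetical wiring: name a custom writer class under the dedicated conf key.
// Leaving the key unset (or empty) selects the default QueryResultWriterImpl.
val spark = SparkSession.builder()
  .appName("flint-repl")
  .config(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key, "com.example.MyQueryResultWriter")
  .getOrCreate()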
JobOperator.scala
@@ -61,7 +61,6 @@ case class JobOperator(
jobType,
"", // FlintJob doesn't have sessionId
null, // FlintJob doesn't have SessionManager
null, // FlintJob doesn't have QueryResultWriter
Duration.Inf, // FlintJob doesn't have queryExecutionTimeout
-1, // FlintJob doesn't have inactivityLimitMillis
-1, // FlintJob doesn't have queryWaitTimeMillis
QueryResultWriterImpl.scala
@@ -10,14 +10,42 @@ import org.opensearch.flint.common.model.FlintStatement
import org.apache.spark.internal.Logging
import org.apache.spark.sql.FlintJob.writeDataFrameToOpensearch
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.util.CleanerFactory

class QueryResultWriterImpl(context: Map[String, Any]) extends QueryResultWriter with Logging {
class QueryResultWriterImpl(commandContext: CommandContext)
extends QueryResultWriter
with FlintJobExecutor
with Logging {

private val context = commandContext.sessionManager.getSessionContext
private val resultIndex = context("resultIndex").asInstanceOf[String]
// Initialize OSClient with Flint options because custom session manager implementation should not have it in the context
private val osClient = new OSClient(FlintSparkConf().flintOptions())

override def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit = {
writeDataFrameToOpensearch(dataFrame, resultIndex, osClient)
}

override def processDataFrame(
dataFrame: DataFrame,
statement: FlintStatement,
queryStartTime: Long): DataFrame = {
import commandContext._

/**
* Reformat the given DataFrame to the desired format for OpenSearch storage.
*/
getFormattedData(
applicationId,
jobId,
dataFrame,
spark,
dataSource,
statement.queryId,
statement.query,
sessionId,
queryStartTime,
currentTimeProvider,
CleanerFactory.cleaner(false))
}
}
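Read together with the FlintREPL changes above, the refactor splits responsibilities like this: the query thread runs executeStatement and then processDataFrame (the DataFrame reformat plus the Spark action), and finalizeCommand later calls writeDataFrame to persist to the configured result index. A condensed, illustrative view of that flow (glue code only, not literal lines from the commit):

// Names come from this diff; the sequencing below is illustrative.
val writer = new QueryResultWriterImpl(commandContext)

val startTime = System.currentTimeMillis()
val raw = statementExecutionManager.executeStatement(flintStatement)    // lazy Spark plan
val formatted = writer.processDataFrame(raw, flintStatement, startTime) // reformat + Spark action

// Later, in finalizeCommand:
writer.writeDataFrame(formatted, flintStatement)                        // index into resultIndex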
StatementExecutionManagerImpl.scala
@@ -54,16 +54,13 @@ class StatementExecutionManagerImpl(commandContext: CommandContext)
}

override def executeStatement(statement: FlintStatement): DataFrame = {
import commandContext._
executeQuery(
applicationId,
jobId,
spark,
statement.query,
dataSource,
import commandContext.spark
// we have to set job group in the same thread that started the query according to spark doc
spark.sparkContext.setJobGroup(
statement.queryId,
sessionId,
false)
"Job group for " + statement.queryId,
interruptOnCancel = true)
spark.sql(statement.query)
}

private def createOpenSearchQueryReader() = {
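executeStatement no longer routes through the shared executeQuery helper; it tags the statement with a Spark job group before running the SQL. setJobGroup is thread-local, which is why it must be called on the thread that triggers the action, and grouping by queryId makes the statement cancellable from elsewhere. A small, hypothetical illustration of that cancellation path:

import org.apache.spark.sql.SparkSession

// Hypothetical cancellation helper: because executeStatement uses the query id
// as the job group, another thread can abort a runaway statement. With
// interruptOnCancel = true in setJobGroup, executor task threads are interrupted too.
def cancelStatement(spark: SparkSession, queryId: String): Unit = {
  spark.sparkContext.cancelJobGroup(queryId)
}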