Skip to content

Commit

Permalink
Revert "[REFACTOR] Move DF reformat from StatementExecutionManagerImp…
Browse files Browse the repository at this point in the history
…l to QueryResultWriterImpl (opensearch-project#701) (opensearch-project#715)"

This reverts commit 74397d8.
  • Loading branch information
noCharger committed Oct 2, 2024
1 parent 346d9a5 commit eb7116e
Show file tree
Hide file tree
Showing 9 changed files with 40 additions and 101 deletions.
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ lazy val flintCore = (project in file("flint-core"))
exclude ("com.fasterxml.jackson.core", "jackson-databind"),
"com.amazonaws" % "aws-java-sdk-cloudwatch" % "1.12.593"
exclude("com.fasterxml.jackson.core", "jackson-databind"),
"software.amazon.awssdk" % "auth-crt" % "2.28.10" % "provided",
"software.amazon.awssdk" % "auth-crt" % "2.25.23",
"org.scalactic" %% "scalactic" % "3.2.15" % "test",
"org.scalatest" %% "scalatest" % "3.2.15" % "test",
"org.scalatest" %% "scalatest-flatspec" % "3.2.15" % "test",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,39 +13,8 @@ import org.opensearch.flint.common.model.FlintStatement
trait QueryResultWriter {

/**
* Writes the given DataFrame to an external data storage based on the FlintStatement metadata.
* This method is responsible for persisting the query results.
*
* Note: This method typically involves I/O operations and may trigger Spark actions to
* materialize the DataFrame if it hasn't been processed yet.
*
* @param dataFrame
* The DataFrame containing the query results to be written.
* @param flintStatement
* The FlintStatement containing metadata that guides the writing process.
* Writes the given DataFrame, which represents the result of a query execution, to an external
* data storage based on the provided FlintStatement metadata.
*/
def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit

/**
* Defines transformations on the given DataFrame and triggers an action to process it. This
* method applies necessary transformations based on the FlintStatement metadata and executes an
* action to compute the result.
*
* Note: Calling this method will trigger the actual data processing in Spark. If the Spark SQL
* thread is waiting for the result of a query, termination on the same thread will be blocked
* until the action completes.
*
* @param dataFrame
* The DataFrame to be processed.
* @param flintStatement
* The FlintStatement containing statement metadata.
* @param queryStartTime
* The start time of the query execution.
* @return
* The processed DataFrame after applying transformations and executing an action.
*/
def processDataFrame(
dataFrame: DataFrame,
flintStatement: FlintStatement,
queryStartTime: Long): DataFrame
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ case class CommandContext(
jobType: String,
sessionId: String,
sessionManager: SessionManager,
queryResultWriter: QueryResultWriter,
queryExecutionTimeout: Duration,
inactivityLimitMillis: Long,
queryWaitTimeMillis: Long,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import com.amazonaws.services.s3.model.AmazonS3Exception
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.scala.DefaultScalaModule
import org.apache.commons.text.StringEscapeUtils.unescapeJava
import org.opensearch.common.Strings
import org.opensearch.flint.core.IRestHighLevelClient
import org.opensearch.flint.core.logging.{CustomLogging, ExceptionMessages, OperationMessage}
import org.opensearch.flint.core.metrics.MetricConstants
Expand Down Expand Up @@ -534,7 +533,7 @@ trait FlintJobExecutor {
}

def instantiate[T](defaultConstructor: => T, className: String, args: Any*): T = {
if (Strings.isNullOrEmpty(className)) {
if (className.isEmpty) {
defaultConstructor
} else {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ object FlintREPL extends Logging with FlintJobExecutor {
return
}

val queryResultWriter =
instantiateQueryResultWriter(conf, sessionManager.getSessionContext)
val commandContext = CommandContext(
applicationId,
jobId,
Expand All @@ -177,6 +179,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
jobType,
sessionId,
sessionManager,
queryResultWriter,
queryExecutionTimeoutSecs,
inactivityLimitMillis,
queryWaitTimeoutMillis,
Expand Down Expand Up @@ -313,7 +316,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
// 1 thread for async query execution
val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-query", 1)
implicit val executionContext = ExecutionContext.fromExecutor(threadPool)
val queryResultWriter = instantiateQueryResultWriter(spark, commandContext)

var futurePrepareQueryExecution: Future[Either[String, Unit]] = null
try {
logInfo(s"""Executing session with sessionId: ${sessionId}""")
Expand All @@ -339,11 +342,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
executionContext,
lastCanPickCheckTime)
val result: (Long, VerificationResult, Boolean, Long) =
processCommands(
statementsExecutionManager,
queryResultWriter,
commandContext,
commandState)
processCommands(statementsExecutionManager, commandContext, commandState)

val (
updatedLastActivityTime,
Expand Down Expand Up @@ -492,7 +491,6 @@ object FlintREPL extends Logging with FlintJobExecutor {

private def processCommands(
statementExecutionManager: StatementExecutionManager,
queryResultWriter: QueryResultWriter,
context: CommandContext,
state: CommandState): (Long, VerificationResult, Boolean, Long) = {
import context._
Expand Down Expand Up @@ -527,15 +525,14 @@ object FlintREPL extends Logging with FlintJobExecutor {
val (dataToWrite, returnedVerificationResult) =
processStatementOnVerification(
statementExecutionManager,
queryResultWriter,
flintStatement,
state,
context)

verificationResult = returnedVerificationResult
finalizeCommand(
statementExecutionManager,
queryResultWriter,
context,
dataToWrite,
flintStatement,
statementTimerContext)
Expand All @@ -561,10 +558,11 @@ object FlintREPL extends Logging with FlintJobExecutor {
*/
private def finalizeCommand(
statementExecutionManager: StatementExecutionManager,
queryResultWriter: QueryResultWriter,
commandContext: CommandContext,
dataToWrite: Option[DataFrame],
flintStatement: FlintStatement,
statementTimerContext: Timer.Context): Unit = {
import commandContext._
try {
dataToWrite.foreach(df => queryResultWriter.writeDataFrame(df, flintStatement))
if (flintStatement.isRunning || flintStatement.isWaiting) {
Expand Down Expand Up @@ -628,7 +626,6 @@ object FlintREPL extends Logging with FlintJobExecutor {
spark: SparkSession,
flintStatement: FlintStatement,
statementExecutionManager: StatementExecutionManager,
queryResultWriter: QueryResultWriter,
dataSource: String,
sessionId: String,
executionContext: ExecutionContextExecutor,
Expand All @@ -643,7 +640,6 @@ object FlintREPL extends Logging with FlintJobExecutor {
spark,
flintStatement,
statementExecutionManager,
queryResultWriter,
dataSource,
sessionId,
executionContext,
Expand Down Expand Up @@ -681,7 +677,6 @@ object FlintREPL extends Logging with FlintJobExecutor {

private def processStatementOnVerification(
statementExecutionManager: StatementExecutionManager,
queryResultWriter: QueryResultWriter,
flintStatement: FlintStatement,
commandState: CommandState,
commandContext: CommandContext) = {
Expand All @@ -703,7 +698,6 @@ object FlintREPL extends Logging with FlintJobExecutor {
spark,
flintStatement,
statementExecutionManager,
queryResultWriter,
dataSource,
sessionId,
executionContext,
Expand Down Expand Up @@ -770,7 +764,6 @@ object FlintREPL extends Logging with FlintJobExecutor {
spark,
flintStatement,
statementExecutionManager,
queryResultWriter,
dataSource,
sessionId,
executionContext,
Expand All @@ -789,7 +782,6 @@ object FlintREPL extends Logging with FlintJobExecutor {
spark: SparkSession,
flintStatement: FlintStatement,
statementsExecutionManager: StatementExecutionManager,
queryResultWriter: QueryResultWriter,
dataSource: String,
sessionId: String,
executionContext: ExecutionContextExecutor,
Expand All @@ -809,14 +801,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
startTime)
} else {
val futureQueryExecution = Future {
val startTime = System.currentTimeMillis()
// Execute the statement and get the resulting DataFrame
// This step may involve Spark transformations, but not necessarily actions
val df = statementsExecutionManager.executeStatement(flintStatement)
// Process the DataFrame, applying any necessary transformations
// and triggering Spark actions to materialize the results
// This is where the actual data processing occurs
queryResultWriter.processDataFrame(df, flintStatement, startTime)
statementsExecutionManager.executeStatement(flintStatement)
}(executionContext)
// time out after 10 minutes
ThreadUtils.awaitResult(futureQueryExecution, queryExecutionTimeOut)
Expand Down Expand Up @@ -1013,11 +998,11 @@ object FlintREPL extends Logging with FlintJobExecutor {
}

private def instantiateQueryResultWriter(
spark: SparkSession,
commandContext: CommandContext): QueryResultWriter = {
sparkConf: SparkConf,
context: Map[String, Any]): QueryResultWriter = {
instantiate(
new QueryResultWriterImpl(commandContext),
spark.conf.get(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key, ""))
new QueryResultWriterImpl(context),
sparkConf.get(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key, ""))
}

private def recordSessionSuccess(sessionTimerContext: Timer.Context): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ case class JobOperator(
jobType,
"", // FlintJob doesn't have sessionId
null, // FlintJob doesn't have SessionManager
null, // FlintJob doesn't have QueryResultWriter
Duration.Inf, // FlintJob doesn't have queryExecutionTimeout
-1, // FlintJob doesn't have inactivityLimitMillis
-1, // FlintJob doesn't have queryWaitTimeMillis
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,42 +10,14 @@ import org.opensearch.flint.common.model.FlintStatement
import org.apache.spark.internal.Logging
import org.apache.spark.sql.FlintJob.writeDataFrameToOpensearch
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.util.CleanerFactory

class QueryResultWriterImpl(commandContext: CommandContext)
extends QueryResultWriter
with FlintJobExecutor
with Logging {
class QueryResultWriterImpl(context: Map[String, Any]) extends QueryResultWriter with Logging {

private val context = commandContext.sessionManager.getSessionContext
private val resultIndex = context("resultIndex").asInstanceOf[String]
// Initialize OSClient with Flint options because custom session manager implementation should not have it in the context
private val osClient = new OSClient(FlintSparkConf().flintOptions())

override def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit = {
writeDataFrameToOpensearch(dataFrame, resultIndex, osClient)
}

override def processDataFrame(
dataFrame: DataFrame,
statement: FlintStatement,
queryStartTime: Long): DataFrame = {
import commandContext._

/**
* Reformat the given DataFrame to the desired format for OpenSearch storage.
*/
getFormattedData(
applicationId,
jobId,
dataFrame,
spark,
dataSource,
statement.queryId,
statement.query,
sessionId,
queryStartTime,
currentTimeProvider,
CleanerFactory.cleaner(false))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,16 @@ class StatementExecutionManagerImpl(commandContext: CommandContext)
}

override def executeStatement(statement: FlintStatement): DataFrame = {
import commandContext.spark
// we have to set job group in the same thread that started the query according to spark doc
spark.sparkContext.setJobGroup(
import commandContext._
executeQuery(
applicationId,
jobId,
spark,
statement.query,
dataSource,
statement.queryId,
"Job group for " + statement.queryId,
interruptOnCancel = true)
spark.sql(statement.query)
sessionId,
false)
}

private def createOpenSearchQueryReader() = {
Expand Down
Loading

0 comments on commit eb7116e

Please sign in to comment.