From dda7427ad21fe70c71761768f6b28599fce4ee0e Mon Sep 17 00:00:00 2001
From: Louis Chu <clingzhi@amazon.com>
Date: Wed, 25 Sep 2024 16:49:49 -0700
Subject: [PATCH] Refactor query result writer

Remove QueryResultWriter from CommandContext and instantiate it in
FlintREPL instead. Add a reformatDataFrame() hook to the trait and move
result formatting out of StatementExecutionManagerImpl, so formatting
runs in the same thread that executes the query.

Signed-off-by: Louis Chu <clingzhi@amazon.com>
---
 .../apache/spark/sql/QueryResultWriter.scala  |  8 +++++
 .../org/apache/spark/sql/CommandContext.scala |  1 -
 .../apache/spark/sql/FlintJobExecutor.scala   |  3 +-
 .../org/apache/spark/sql/FlintREPL.scala      | 36 ++++++++++++-------
 .../org/apache/spark/sql/JobOperator.scala    |  1 -
 .../spark/sql/QueryResultWriterImpl.scala     | 29 ++++++++++++++-
 .../sql/StatementExecutionManagerImpl.scala   | 15 ++++----
 .../org/apache/spark/sql/FlintREPLTest.scala  | 13 ++-----
 8 files changed, 69 insertions(+), 37 deletions(-)

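Note: in rough outline, the refactored per-statement flow is as follows
(a sketch only, not code from this patch; all names appear in the hunks
below):

    // QueryResultWriter is now created in FlintREPL rather than carried
    // inside CommandContext, and result formatting runs in the same
    // thread that executes the query.
    val queryResultWriter = instantiateQueryResultWriter(spark, commandContext)
    val df = statementsExecutionManager.executeStatement(flintStatement)
    val formatted =
      queryResultWriter.reformatDataFrame(df, flintStatement, startTime)
    queryResultWriter.writeDataFrame(formatted, flintStatement)
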
diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala b/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala
index 49dc8e355..bc76547f6 100644
--- a/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala
+++ b/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala
@@ -17,4 +17,12 @@ trait QueryResultWriter {
    * data storage based on the provided FlintStatement metadata.
    */
   def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit
+
+  /**
+   * Reformat the given DataFrame into the result format expected by the data storage.
+   */
+  def reformatDataFrame(
+      dataFrame: DataFrame,
+      flintStatement: FlintStatement,
+      queryStartTime: Long): DataFrame
 }
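
Note: a minimal sketch of a custom implementation of the extended trait.
The class name and pass-through behavior are illustrative only; the two
overridden methods are the ones the trait now defines.

    import org.apache.spark.internal.Logging
    import org.apache.spark.sql.DataFrame
    import org.opensearch.flint.common.model.FlintStatement

    // Hypothetical custom writer: logs instead of persisting, and returns
    // the DataFrame unchanged from reformatDataFrame.
    class NoopQueryResultWriter extends QueryResultWriter with Logging {

      override def writeDataFrame(df: DataFrame, statement: FlintStatement): Unit =
        logInfo(s"Skipping write for query ${statement.queryId}")

      override def reformatDataFrame(
          dataFrame: DataFrame,
          flintStatement: FlintStatement,
          queryStartTime: Long): DataFrame =
        dataFrame // pass-through; a real writer could add metadata columns here
    }
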
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala
index 42b1ae2f6..56bd9cb00 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala
@@ -18,7 +18,6 @@ case class CommandContext(
     jobType: String,
     sessionId: String,
     sessionManager: SessionManager,
-    queryResultWriter: QueryResultWriter,
     queryExecutionTimeout: Duration,
     inactivityLimitMillis: Long,
     queryWaitTimeMillis: Long,
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala
index 24d68fd47..c076f9974 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala
@@ -12,6 +12,7 @@ import com.amazonaws.services.s3.model.AmazonS3Exception
 import com.fasterxml.jackson.databind.ObjectMapper
 import com.fasterxml.jackson.module.scala.DefaultScalaModule
 import org.apache.commons.text.StringEscapeUtils.unescapeJava
+import org.opensearch.common.Strings
 import org.opensearch.flint.core.IRestHighLevelClient
 import org.opensearch.flint.core.logging.{CustomLogging, ExceptionMessages, OperationMessage}
 import org.opensearch.flint.core.metrics.MetricConstants
@@ -533,7 +534,7 @@ trait FlintJobExecutor {
   }
 
   def instantiate[T](defaultConstructor: => T, className: String, args: Any*): T = {
-    if (className.isEmpty) {
+    if (Strings.isNullOrEmpty(className)) {
       defaultConstructor
     } else {
       try {
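
Note: the new guard also tolerates a null class name, which the old check
did not. A sketch of the difference (comment only, not code from this
patch):

    // With a null className:
    //   className.isEmpty                 -> NullPointerException
    //   Strings.isNullOrEmpty(className)  -> true, so defaultConstructor is used
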
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
index a0516a37a..a57f8127d 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
@@ -169,8 +169,6 @@ object FlintREPL extends Logging with FlintJobExecutor {
           return
         }
 
-        val queryResultWriter =
-          instantiateQueryResultWriter(conf, sessionManager.getSessionContext)
         val commandContext = CommandContext(
           applicationId,
           jobId,
@@ -179,7 +177,6 @@ object FlintREPL extends Logging with FlintJobExecutor {
           jobType,
           sessionId,
           sessionManager,
-          queryResultWriter,
           queryExecutionTimeoutSecs,
           inactivityLimitMillis,
           queryWaitTimeoutMillis,
@@ -316,7 +313,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
     // 1 thread for async query execution
     val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-query", 1)
     implicit val executionContext = ExecutionContext.fromExecutor(threadPool)
-
+    val queryResultWriter = instantiateQueryResultWriter(spark, commandContext)
     var futurePrepareQueryExecution: Future[Either[String, Unit]] = null
     try {
       logInfo(s"""Executing session with sessionId: ${sessionId}""")
@@ -342,7 +339,11 @@ object FlintREPL extends Logging with FlintJobExecutor {
             executionContext,
             lastCanPickCheckTime)
           val result: (Long, VerificationResult, Boolean, Long) =
-            processCommands(statementsExecutionManager, commandContext, commandState)
+            processCommands(
+              statementsExecutionManager,
+              queryResultWriter,
+              commandContext,
+              commandState)
 
           val (
             updatedLastActivityTime,
@@ -491,6 +492,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
 
   private def processCommands(
       statementExecutionManager: StatementExecutionManager,
+      queryResultWriter: QueryResultWriter,
       context: CommandContext,
       state: CommandState): (Long, VerificationResult, Boolean, Long) = {
     import context._
@@ -525,6 +527,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
             val (dataToWrite, returnedVerificationResult) =
               processStatementOnVerification(
                 statementExecutionManager,
+                queryResultWriter,
                 flintStatement,
                 state,
                 context)
@@ -532,7 +535,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
             verificationResult = returnedVerificationResult
             finalizeCommand(
               statementExecutionManager,
-              context,
+              queryResultWriter,
               dataToWrite,
               flintStatement,
               statementTimerContext)
@@ -558,11 +561,10 @@ object FlintREPL extends Logging with FlintJobExecutor {
    */
   private def finalizeCommand(
       statementExecutionManager: StatementExecutionManager,
-      commandContext: CommandContext,
+      queryResultWriter: QueryResultWriter,
       dataToWrite: Option[DataFrame],
       flintStatement: FlintStatement,
       statementTimerContext: Timer.Context): Unit = {
-    import commandContext._
     try {
       dataToWrite.foreach(df => queryResultWriter.writeDataFrame(df, flintStatement))
       if (flintStatement.isRunning || flintStatement.isWaiting) {
@@ -626,6 +628,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
       spark: SparkSession,
       flintStatement: FlintStatement,
       statementExecutionManager: StatementExecutionManager,
+      queryResultWriter: QueryResultWriter,
       dataSource: String,
       sessionId: String,
       executionContext: ExecutionContextExecutor,
@@ -640,6 +643,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
           spark,
           flintStatement,
           statementExecutionManager,
+          queryResultWriter,
           dataSource,
           sessionId,
           executionContext,
@@ -677,6 +681,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
 
   private def processStatementOnVerification(
       statementExecutionManager: StatementExecutionManager,
+      queryResultWriter: QueryResultWriter,
       flintStatement: FlintStatement,
       commandState: CommandState,
       commandContext: CommandContext) = {
@@ -698,6 +703,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
                 spark,
                 flintStatement,
                 statementExecutionManager,
+                queryResultWriter,
                 dataSource,
                 sessionId,
                 executionContext,
@@ -764,6 +770,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
           spark,
           flintStatement,
           statementExecutionManager,
+          queryResultWriter,
           dataSource,
           sessionId,
           executionContext,
@@ -782,6 +789,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
       spark: SparkSession,
       flintStatement: FlintStatement,
       statementsExecutionManager: StatementExecutionManager,
+      queryResultWriter: QueryResultWriter,
       dataSource: String,
       sessionId: String,
       executionContext: ExecutionContextExecutor,
@@ -801,7 +809,9 @@ object FlintREPL extends Logging with FlintJobExecutor {
         startTime)
     } else {
       val futureQueryExecution = Future {
-        statementsExecutionManager.executeStatement(flintStatement)
+        val startTime = System.currentTimeMillis()
+        val df = statementsExecutionManager.executeStatement(flintStatement)
+        queryResultWriter.reformatDataFrame(df, flintStatement, startTime)
       }(executionContext)
       // time out after 10 minutes
       ThreadUtils.awaitResult(futureQueryExecution, queryExecutionTimeOut)
@@ -998,11 +1008,11 @@ object FlintREPL extends Logging with FlintJobExecutor {
   }
 
   private def instantiateQueryResultWriter(
-      sparkConf: SparkConf,
-      context: Map[String, Any]): QueryResultWriter = {
+      spark: SparkSession,
+      commandContext: CommandContext): QueryResultWriter = {
     instantiate(
-      new QueryResultWriterImpl(context),
-      sparkConf.get(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key, ""))
+      new QueryResultWriterImpl(commandContext),
+      spark.conf.get(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key, ""))
   }
 
   private def recordSessionSuccess(sessionTimerContext: Timer.Context): Unit = {
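
Note: a hedged sketch of wiring in a custom writer through the Spark
session config. The class name is a placeholder; since this call passes
no extra args to instantiate, the custom class is presumably constructed
reflectively via a matching constructor.

    // Illustrative: select a custom writer before the session starts.
    spark.conf.set(
      FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key,
      "com.example.MyQueryResultWriter") // placeholder class name
    // Unset or empty -> the default QueryResultWriterImpl(commandContext).
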
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
index deee6eb1d..58d868a2e 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
@@ -61,7 +61,6 @@ case class JobOperator(
       jobType,
       "", // FlintJob doesn't have sessionId
       null, // FlintJob doesn't have SessionManager
-      null, // FlintJob doesn't have QueryResultWriter
       Duration.Inf, // FlintJob doesn't have queryExecutionTimeout
       -1, // FlintJob doesn't have inactivityLimitMillis
       -1, // FlintJob doesn't have queryWaitTimeMillis
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala
index 23d7f42a1..c64712621 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala
@@ -10,9 +10,14 @@ import org.opensearch.flint.common.model.FlintStatement
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.FlintJob.writeDataFrameToOpensearch
 import org.apache.spark.sql.flint.config.FlintSparkConf
+import org.apache.spark.sql.util.CleanerFactory
 
-class QueryResultWriterImpl(context: Map[String, Any]) extends QueryResultWriter with Logging {
+class QueryResultWriterImpl(commandContext: CommandContext)
+    extends QueryResultWriter
+    with FlintJobExecutor
+    with Logging {
 
+  private val context = commandContext.sessionManager.getSessionContext
   private val resultIndex = context("resultIndex").asInstanceOf[String]
   // Initialize OSClient with Flint options because custom session manager implementation should not have it in the context
   private val osClient = new OSClient(FlintSparkConf().flintOptions())
@@ -20,4 +25,26 @@ class QueryResultWriterImpl(context: Map[String, Any]) extends QueryResultWriter
   override def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit = {
     writeDataFrameToOpensearch(dataFrame, resultIndex, osClient)
   }
+
+  /**
+   * Reformat the given DataFrame into the result format expected by the data storage.
+   */
+  override def reformatDataFrame(
+      dataFrame: DataFrame,
+      statement: FlintStatement,
+      queryStartTime: Long): DataFrame = {
+    import commandContext._
+    getFormattedData(
+      applicationId,
+      jobId,
+      dataFrame,
+      spark,
+      dataSource,
+      statement.queryId,
+      statement.query,
+      sessionId,
+      queryStartTime,
+      currentTimeProvider,
+      CleanerFactory.cleaner(false))
+  }
 }
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala
index 4e9435f7b..432d6df11 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala
@@ -54,16 +54,13 @@ class StatementExecutionManagerImpl(commandContext: CommandContext)
   }
 
   override def executeStatement(statement: FlintStatement): DataFrame = {
-    import commandContext._
-    executeQuery(
-      applicationId,
-      jobId,
-      spark,
-      statement.query,
-      dataSource,
+    import commandContext.spark
+    // Per the Spark docs, the job group must be set in the same thread that runs the query.
+    spark.sparkContext.setJobGroup(
       statement.queryId,
-      sessionId,
-      false)
+      "Job group for " + statement.queryId,
+      interruptOnCancel = true)
+    spark.sql(statement.query)
   }
 
   private def createOpenSearchQueryReader() = {
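
Note: setting the job group with interruptOnCancel = true is what lets a
later cancellation by queryId interrupt the running statement. A sketch,
assuming a handle to the same SparkSession:

    // Cancels all Spark jobs started under this statement's job group and,
    // because interruptOnCancel = true, interrupts their running tasks.
    spark.sparkContext.cancelJobGroup(statement.queryId)
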
diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
index 355bd9ede..5eeccce73 100644
--- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
+++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
@@ -675,7 +675,6 @@ class FlintREPLTest
         INTERACTIVE_JOB_TYPE,
         sessionId,
         sessionManager,
-        queryResultWriter,
         Duration(10, MINUTES),
         60,
         60,
@@ -748,7 +747,6 @@ class FlintREPLTest
         INTERACTIVE_JOB_TYPE,
         sessionId,
         sessionManager,
-        queryResultWriter,
         Duration(10, MINUTES),
         60,
         60,
@@ -761,6 +759,7 @@ class FlintREPLTest
         mockSparkSession,
         flintStatement,
         statementExecutionManager,
+        queryResultWriter,
         dataSource,
         sessionId,
         executionContext,
@@ -809,7 +808,6 @@ class FlintREPLTest
       when(mockSparkSession.sparkContext).thenReturn(sparkContext)
 
       // Assume handleQueryException logs the error and returns an error message string
-      val mockErrorString = "Error due to syntax"
       when(mockSparkSession.createDataFrame(any[Seq[Product]])(any[TypeTag[Product]]))
         .thenReturn(expectedDataFrame)
       when(expectedDataFrame.toDF(any[Seq[String]]: _*)).thenReturn(expectedDataFrame)
@@ -824,7 +822,6 @@ class FlintREPLTest
         INTERACTIVE_JOB_TYPE,
         sessionId,
         sessionManager,
-        queryResultWriter,
         Duration(10, MINUTES),
         60,
         60,
@@ -837,6 +834,7 @@ class FlintREPLTest
         mockSparkSession,
         flintStatement,
         statementExecutionManager,
+        queryResultWriter,
         dataSource,
         sessionId,
         executionContext,
@@ -1076,7 +1074,6 @@ class FlintREPLTest
       INTERACTIVE_JOB_TYPE,
       sessionId,
       sessionManager,
-      queryResultWriter,
       Duration(10, MINUTES),
       shortInactivityLimit,
       60,
@@ -1146,7 +1143,6 @@ class FlintREPLTest
       INTERACTIVE_JOB_TYPE,
       sessionId,
       sessionManager,
-      queryResultWriter,
       Duration(10, MINUTES),
       longInactivityLimit,
       60,
@@ -1212,7 +1208,6 @@ class FlintREPLTest
       INTERACTIVE_JOB_TYPE,
       sessionId,
       sessionManager,
-      queryResultWriter,
       Duration(10, MINUTES),
       inactivityLimit,
       60,
@@ -1283,7 +1278,6 @@ class FlintREPLTest
       INTERACTIVE_JOB_TYPE,
       sessionId,
       sessionManager,
-      queryResultWriter,
       Duration(10, MINUTES),
       inactivityLimit,
       60,
@@ -1367,7 +1361,6 @@ class FlintREPLTest
       override val osClient: OSClient = mockOSClient
       override lazy val flintSessionIndexUpdater: OpenSearchUpdater = mockOpenSearchUpdater
     }
-    val queryResultWriter = mock[QueryResultWriter]
 
     val commandContext = CommandContext(
       applicationId,
@@ -1377,7 +1370,6 @@ class FlintREPLTest
       INTERACTIVE_JOB_TYPE,
       sessionId,
       sessionManager,
-      queryResultWriter,
       Duration(10, MINUTES),
       inactivityLimit,
       60,
@@ -1453,7 +1445,6 @@ class FlintREPLTest
         INTERACTIVE_JOB_TYPE,
         sessionId,
         sessionManager,
-        queryResultWriter,
         Duration(10, MINUTES),
         inactivityLimit,
         60,