diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala b/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala index 49dc8e355..bc76547f6 100644 --- a/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala +++ b/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala @@ -17,4 +17,12 @@ trait QueryResultWriter { * data storage based on the provided FlintStatement metadata. */ def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit + + /** + * Reformat the given DataFrame to the desired format. + */ + def reformatDataFrame( + dataFrame: DataFrame, + flintStatement: FlintStatement, + queryStartTime: Long): DataFrame } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala index 42b1ae2f6..56bd9cb00 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala @@ -18,7 +18,6 @@ case class CommandContext( jobType: String, sessionId: String, sessionManager: SessionManager, - queryResultWriter: QueryResultWriter, queryExecutionTimeout: Duration, inactivityLimitMillis: Long, queryWaitTimeMillis: Long, diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala index 24d68fd47..c076f9974 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -12,6 +12,7 @@ import com.amazonaws.services.s3.model.AmazonS3Exception import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.module.scala.DefaultScalaModule import org.apache.commons.text.StringEscapeUtils.unescapeJava +import org.opensearch.common.Strings import org.opensearch.flint.core.IRestHighLevelClient import org.opensearch.flint.core.logging.{CustomLogging, ExceptionMessages, OperationMessage} import org.opensearch.flint.core.metrics.MetricConstants @@ -533,7 +534,7 @@ trait FlintJobExecutor { } def instantiate[T](defaultConstructor: => T, className: String, args: Any*): T = { - if (className.isEmpty) { + if (Strings.isNullOrEmpty(className)) { defaultConstructor } else { try { diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index a0516a37a..a57f8127d 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -169,8 +169,6 @@ object FlintREPL extends Logging with FlintJobExecutor { return } - val queryResultWriter = - instantiateQueryResultWriter(conf, sessionManager.getSessionContext) val commandContext = CommandContext( applicationId, jobId, @@ -179,7 +177,6 @@ object FlintREPL extends Logging with FlintJobExecutor { jobType, sessionId, sessionManager, - queryResultWriter, queryExecutionTimeoutSecs, inactivityLimitMillis, queryWaitTimeoutMillis, @@ -316,7 +313,7 @@ object FlintREPL extends Logging with FlintJobExecutor { // 1 thread for async query execution val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-query", 1) implicit val executionContext = ExecutionContext.fromExecutor(threadPool) - + val queryResultWriter = instantiateQueryResultWriter(spark, commandContext) var futurePrepareQueryExecution: Future[Either[String, Unit]] = null try { logInfo(s"""Executing session with sessionId: ${sessionId}""") @@ -342,7 +339,11 @@ object FlintREPL extends Logging with FlintJobExecutor { executionContext, lastCanPickCheckTime) val result: (Long, VerificationResult, Boolean, Long) = - processCommands(statementsExecutionManager, commandContext, commandState) + processCommands( + statementsExecutionManager, + queryResultWriter, + commandContext, + commandState) val ( updatedLastActivityTime, @@ -491,6 +492,7 @@ object FlintREPL extends Logging with FlintJobExecutor { private def processCommands( statementExecutionManager: StatementExecutionManager, + queryResultWriter: QueryResultWriter, context: CommandContext, state: CommandState): (Long, VerificationResult, Boolean, Long) = { import context._ @@ -525,6 +527,7 @@ object FlintREPL extends Logging with FlintJobExecutor { val (dataToWrite, returnedVerificationResult) = processStatementOnVerification( statementExecutionManager, + queryResultWriter, flintStatement, state, context) @@ -532,7 +535,7 @@ object FlintREPL extends Logging with FlintJobExecutor { verificationResult = returnedVerificationResult finalizeCommand( statementExecutionManager, - context, + queryResultWriter, dataToWrite, flintStatement, statementTimerContext) @@ -558,11 +561,10 @@ object FlintREPL extends Logging with FlintJobExecutor { */ private def finalizeCommand( statementExecutionManager: StatementExecutionManager, - commandContext: CommandContext, + queryResultWriter: QueryResultWriter, dataToWrite: Option[DataFrame], flintStatement: FlintStatement, statementTimerContext: Timer.Context): Unit = { - import commandContext._ try { dataToWrite.foreach(df => queryResultWriter.writeDataFrame(df, flintStatement)) if (flintStatement.isRunning || flintStatement.isWaiting) { @@ -626,6 +628,7 @@ object FlintREPL extends Logging with FlintJobExecutor { spark: SparkSession, flintStatement: FlintStatement, statementExecutionManager: StatementExecutionManager, + queryResultWriter: QueryResultWriter, dataSource: String, sessionId: String, executionContext: ExecutionContextExecutor, @@ -640,6 +643,7 @@ object FlintREPL extends Logging with FlintJobExecutor { spark, flintStatement, statementExecutionManager, + queryResultWriter, dataSource, sessionId, executionContext, @@ -677,6 +681,7 @@ object FlintREPL extends Logging with FlintJobExecutor { private def processStatementOnVerification( statementExecutionManager: StatementExecutionManager, + queryResultWriter: QueryResultWriter, flintStatement: FlintStatement, commandState: CommandState, commandContext: CommandContext) = { @@ -698,6 +703,7 @@ object FlintREPL extends Logging with FlintJobExecutor { spark, flintStatement, statementExecutionManager, + queryResultWriter, dataSource, sessionId, executionContext, @@ -764,6 +770,7 @@ object FlintREPL extends Logging with FlintJobExecutor { spark, flintStatement, statementExecutionManager, + queryResultWriter, dataSource, sessionId, executionContext, @@ -782,6 +789,7 @@ object FlintREPL extends Logging with FlintJobExecutor { spark: SparkSession, flintStatement: FlintStatement, statementsExecutionManager: StatementExecutionManager, + queryResultWriter: QueryResultWriter, dataSource: String, sessionId: String, executionContext: ExecutionContextExecutor, @@ -801,7 +809,9 @@ object FlintREPL extends Logging with FlintJobExecutor { startTime) } else { val futureQueryExecution = Future { - statementsExecutionManager.executeStatement(flintStatement) + val startTime = System.currentTimeMillis() + val df = statementsExecutionManager.executeStatement(flintStatement) + queryResultWriter.reformatDataFrame(df, flintStatement, startTime) }(executionContext) // time out after 10 minutes ThreadUtils.awaitResult(futureQueryExecution, queryExecutionTimeOut) @@ -998,11 +1008,11 @@ object FlintREPL extends Logging with FlintJobExecutor { } private def instantiateQueryResultWriter( - sparkConf: SparkConf, - context: Map[String, Any]): QueryResultWriter = { + spark: SparkSession, + commandContext: CommandContext): QueryResultWriter = { instantiate( - new QueryResultWriterImpl(context), - sparkConf.get(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key, "")) + new QueryResultWriterImpl(commandContext), + spark.conf.get(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key, "")) } private def recordSessionSuccess(sessionTimerContext: Timer.Context): Unit = { diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala index deee6eb1d..58d868a2e 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala @@ -61,7 +61,6 @@ case class JobOperator( jobType, "", // FlintJob doesn't have sessionId null, // FlintJob doesn't have SessionManager - null, // FlintJob doesn't have QueryResultWriter Duration.Inf, // FlintJob doesn't have queryExecutionTimeout -1, // FlintJob doesn't have inactivityLimitMillis -1, // FlintJob doesn't have queryWaitTimeMillis diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala index 23d7f42a1..c64712621 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala @@ -10,9 +10,14 @@ import org.opensearch.flint.common.model.FlintStatement import org.apache.spark.internal.Logging import org.apache.spark.sql.FlintJob.writeDataFrameToOpensearch import org.apache.spark.sql.flint.config.FlintSparkConf +import org.apache.spark.sql.util.CleanerFactory -class QueryResultWriterImpl(context: Map[String, Any]) extends QueryResultWriter with Logging { +class QueryResultWriterImpl(commandContext: CommandContext) + extends QueryResultWriter + with FlintJobExecutor + with Logging { + private val context = commandContext.sessionManager.getSessionContext private val resultIndex = context("resultIndex").asInstanceOf[String] // Initialize OSClient with Flint options because custom session manager implementation should not have it in the context private val osClient = new OSClient(FlintSparkConf().flintOptions()) @@ -20,4 +25,26 @@ class QueryResultWriterImpl(context: Map[String, Any]) extends QueryResultWriter override def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit = { writeDataFrameToOpensearch(dataFrame, resultIndex, osClient) } + + /** + * Reformat the given DataFrame to the desired format. + */ + override def reformatDataFrame( + dataFrame: DataFrame, + statement: FlintStatement, + queryStartTime: Long): DataFrame = { + import commandContext._ + getFormattedData( + applicationId, + jobId, + dataFrame, + spark, + dataSource, + statement.queryId, + statement.query, + sessionId, + queryStartTime, + currentTimeProvider, + CleanerFactory.cleaner(false)) + } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala index 4e9435f7b..432d6df11 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala @@ -54,16 +54,13 @@ class StatementExecutionManagerImpl(commandContext: CommandContext) } override def executeStatement(statement: FlintStatement): DataFrame = { - import commandContext._ - executeQuery( - applicationId, - jobId, - spark, - statement.query, - dataSource, + import commandContext.spark + // we have to set job group in the same thread that started the query according to spark doc + spark.sparkContext.setJobGroup( statement.queryId, - sessionId, - false) + "Job group for " + statement.queryId, + interruptOnCancel = true) + spark.sql(statement.query) } private def createOpenSearchQueryReader() = { diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index 355bd9ede..5eeccce73 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -675,7 +675,6 @@ class FlintREPLTest INTERACTIVE_JOB_TYPE, sessionId, sessionManager, - queryResultWriter, Duration(10, MINUTES), 60, 60, @@ -748,7 +747,6 @@ class FlintREPLTest INTERACTIVE_JOB_TYPE, sessionId, sessionManager, - queryResultWriter, Duration(10, MINUTES), 60, 60, @@ -761,6 +759,7 @@ class FlintREPLTest mockSparkSession, flintStatement, statementExecutionManager, + queryResultWriter, dataSource, sessionId, executionContext, @@ -809,7 +808,6 @@ class FlintREPLTest when(mockSparkSession.sparkContext).thenReturn(sparkContext) // Assume handleQueryException logs the error and returns an error message string - val mockErrorString = "Error due to syntax" when(mockSparkSession.createDataFrame(any[Seq[Product]])(any[TypeTag[Product]])) .thenReturn(expectedDataFrame) when(expectedDataFrame.toDF(any[Seq[String]]: _*)).thenReturn(expectedDataFrame) @@ -824,7 +822,6 @@ class FlintREPLTest INTERACTIVE_JOB_TYPE, sessionId, sessionManager, - queryResultWriter, Duration(10, MINUTES), 60, 60, @@ -837,6 +834,7 @@ class FlintREPLTest mockSparkSession, flintStatement, statementExecutionManager, + queryResultWriter, dataSource, sessionId, executionContext, @@ -1076,7 +1074,6 @@ class FlintREPLTest INTERACTIVE_JOB_TYPE, sessionId, sessionManager, - queryResultWriter, Duration(10, MINUTES), shortInactivityLimit, 60, @@ -1146,7 +1143,6 @@ class FlintREPLTest INTERACTIVE_JOB_TYPE, sessionId, sessionManager, - queryResultWriter, Duration(10, MINUTES), longInactivityLimit, 60, @@ -1212,7 +1208,6 @@ class FlintREPLTest INTERACTIVE_JOB_TYPE, sessionId, sessionManager, - queryResultWriter, Duration(10, MINUTES), inactivityLimit, 60, @@ -1283,7 +1278,6 @@ class FlintREPLTest INTERACTIVE_JOB_TYPE, sessionId, sessionManager, - queryResultWriter, Duration(10, MINUTES), inactivityLimit, 60, @@ -1367,7 +1361,6 @@ class FlintREPLTest override val osClient: OSClient = mockOSClient override lazy val flintSessionIndexUpdater: OpenSearchUpdater = mockOpenSearchUpdater } - val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( applicationId, @@ -1377,7 +1370,6 @@ class FlintREPLTest INTERACTIVE_JOB_TYPE, sessionId, sessionManager, - queryResultWriter, Duration(10, MINUTES), inactivityLimit, 60, @@ -1453,7 +1445,6 @@ class FlintREPLTest INTERACTIVE_JOB_TYPE, sessionId, sessionManager, - queryResultWriter, Duration(10, MINUTES), inactivityLimit, 60,