From 8b6fa08b91b9e50bedbce107d821855373028d54 Mon Sep 17 00:00:00 2001 From: Louis Chu Date: Thu, 13 Jun 2024 21:16:58 -0700 Subject: [PATCH] refactor statement lifecycle --- .../apache/spark/sql/QueryResultWriter.scala | 1 + .../org/apache/spark/sql/SessionManager.scala | 4 +- .../spark/sql/StatementLifecycleManager.scala | 15 + .../apache/spark/sql/StatementManager.scala | 15 - .../apache/spark/sql/FlintJobExecutor.scala | 63 +-- .../org/apache/spark/sql/FlintREPL.scala | 435 +++++++++--------- .../spark/sql/QueryExecutionContext.scala | 3 +- .../spark/sql/QueryResultWriterImpl.scala | 29 ++ .../apache/spark/sql/SessionManagerImpl.scala | 123 ++--- .../sql/StatementLifecycleManagerImpl.scala | 57 +++ ...cala => inMemoryQueryExecutionState.scala} | 9 +- .../org/apache/spark/sql/FlintJobTest.scala | 11 +- .../org/apache/spark/sql/FlintREPLTest.scala | 7 +- 13 files changed, 404 insertions(+), 368 deletions(-) create mode 100644 flint-commons/src/main/scala/org/apache/spark/sql/StatementLifecycleManager.scala delete mode 100644 flint-commons/src/main/scala/org/apache/spark/sql/StatementManager.scala create mode 100644 spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala create mode 100644 spark-sql-application/src/main/scala/org/apache/spark/sql/StatementLifecycleManagerImpl.scala rename spark-sql-application/src/main/scala/org/apache/spark/sql/{CommandState.scala => inMemoryQueryExecutionState.scala} (61%) diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala b/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala index d69fbc30f..15d971faa 100644 --- a/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala +++ b/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala @@ -8,5 +8,6 @@ package org.apache.spark.sql import org.opensearch.flint.data.FlintStatement trait QueryResultWriter { + def reformat(dataFrame: DataFrame, flintStatement: FlintStatement): DataFrame def write(dataFrame: DataFrame, flintStatement: FlintStatement): Unit } diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala b/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala index 345f97619..9b53240c8 100644 --- a/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala +++ b/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala @@ -15,11 +15,11 @@ trait SessionManager { def updateSessionDetails( sessionDetails: InteractiveSession, updateMode: SessionUpdateMode): Unit - def hasPendingStatement(sessionId: String): Boolean + def getNextStatement(sessionId: String): Option[FlintStatement] def recordHeartbeat(sessionId: String): Unit } object SessionUpdateMode extends Enumeration { type SessionUpdateMode = Value - val Update, Upsert, UpdateIf = Value + val UPDATE, UPSERT, UPDATE_IF = Value } diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/StatementLifecycleManager.scala b/flint-commons/src/main/scala/org/apache/spark/sql/StatementLifecycleManager.scala new file mode 100644 index 000000000..c1408d3d5 --- /dev/null +++ b/flint-commons/src/main/scala/org/apache/spark/sql/StatementLifecycleManager.scala @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.opensearch.flint.data.FlintStatement + +trait StatementLifecycleManager { + def initializeStatementLifecycle(): Either[String, Unit] + def updateStatement(statement: 
FlintStatement): Unit + def setLocalProperty(statement: FlintStatement): Unit + def terminateStatementLifecycle(): Unit +} diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/StatementManager.scala b/flint-commons/src/main/scala/org/apache/spark/sql/StatementManager.scala deleted file mode 100644 index c0a24ab33..000000000 --- a/flint-commons/src/main/scala/org/apache/spark/sql/StatementManager.scala +++ /dev/null @@ -1,15 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.apache.spark.sql - -import org.opensearch.flint.data.FlintStatement - -trait StatementManager { - def prepareCommandLifecycle(): Either[String, Unit] - def initCommandLifecycle(sessionId: String): FlintStatement - def closeCommandLifecycle(): Unit - def updateCommandDetails(commandDetails: FlintStatement): Unit -} diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala index 665ec5a27..673f81614 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -167,16 +167,7 @@ trait FlintJobExecutor { * @return * dataframe with result, schema and emr step id */ - def getFormattedData( - result: DataFrame, - spark: SparkSession, - dataSource: String, - queryId: String, - query: String, - sessionId: String, - startTime: Long, - timeProvider: TimeProvider, - cleaner: Cleaner): DataFrame = { + def getFormattedData(result: DataFrame, spark: SparkSession, dataSource: String, queryId: String, query: String, sessionId: String, startTime: Long, cleaner: Cleaner): DataFrame = { // Create the schema dataframe val schemaRows = result.schema.fields.map { field => Row(field.name, field.dataType.typeName) @@ -210,7 +201,7 @@ trait FlintJobExecutor { .map(_.replaceAll("'", "\\\\'").replaceAll("\"", "'")) val resultSchemaToSave = resultSchema.toJSON.collect.toList.map(_.replaceAll("\"", "'")) - val endTime = timeProvider.currentEpochMillis() + val endTime = currentTimeProvider.currentEpochMillis() // https://github.com/opensearch-project/opensearch-spark/issues/302. Clean shuffle data // after consumed the query result. 
Streaming query shuffle data is cleaned after each @@ -348,31 +339,6 @@ trait FlintJobExecutor { compareJson(inputJson, mappingJson) || compareJson(mappingJson, inputJson) } - def checkAndCreateIndex(osClient: OSClient, resultIndex: String): Either[String, Unit] = { - try { - val existingSchema = osClient.getIndexMetadata(resultIndex) - if (!isSuperset(existingSchema, resultIndexMapping)) { - Left(s"The mapping of $resultIndex is incorrect.") - } else { - Right(()) - } - } catch { - case e: IllegalStateException - if e.getCause != null && - e.getCause.getMessage.contains("index_not_found_exception") => - createResultIndex(osClient, resultIndex, resultIndexMapping) - case e: InterruptedException => - val error = s"Interrupted by the main thread: ${e.getMessage}" - Thread.currentThread().interrupt() // Preserve the interrupt status - logError(error, e) - Left(error) - case e: Exception => - val error = s"Failed to verify existing mapping: ${e.getMessage}" - logError(error, e) - Left(error) - } - } - def createResultIndex( osClient: OSClient, resultIndex: String, @@ -411,16 +377,7 @@ trait FlintJobExecutor { spark.sparkContext.setJobGroup(queryId, "Job group for " + queryId, interruptOnCancel = true) val result: DataFrame = spark.sql(query) // Get Data - getFormattedData( - result, - spark, - dataSource, - queryId, - query, - sessionId, - startTime, - currentTimeProvider, - CleanerFactory.cleaner(streaming)) + getFormattedData(result, spark, dataSource, queryId, query, sessionId, startTime, CleanerFactory.cleaner(streaming)) } private def handleQueryException( @@ -485,16 +442,16 @@ trait FlintJobExecutor { } } - def parseArgs(args: Array[String]): (Option[String], String) = { + def parseArgs(args: Array[String]): (Option[String], Option[String]) = { args match { + case Array() => + (None, None) case Array(resultIndex) => - (None, resultIndex) // Starting from OS 2.13, resultIndex is the only argument + (None, Some(resultIndex)) // Starting from OS 2.13, resultIndex is the only argument case Array(query, resultIndex) => - ( - Some(query), - resultIndex - ) // Before OS 2.13, there are two arguments, the second one is resultIndex - case _ => logAndThrow("Unsupported number of arguments. Expected 1 or 2 arguments.") + (Some(query), Some(resultIndex)) // Before OS 2.13, there are two arguments, the second one is resultIndex + case _ => + logAndThrow("Unsupported number of arguments. 
Expected no more than two arguments.") } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 9b0b66d28..bbd25c5e5 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -22,12 +22,13 @@ import org.opensearch.flint.core.logging.CustomLogging import org.opensearch.flint.core.metrics.MetricConstants import org.opensearch.flint.core.metrics.MetricsUtil.{getTimerContext, incrementCounter, registerGauge, stopTimer} import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} -import org.opensearch.flint.data.{FlintStatement, InteractiveSession} +import org.opensearch.flint.data.{FlintStatement, InteractiveSession, SessionStates} import org.opensearch.search.sort.SortOrder import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} +import org.apache.spark.sql.SessionUpdateMode.UPDATE_IF import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.util.{ThreadUtils, Utils} @@ -47,11 +48,11 @@ import org.apache.spark.util.{ThreadUtils, Utils} object FlintREPL extends Logging with FlintJobExecutor { private val HEARTBEAT_INTERVAL_MILLIS = 60000L - private val MAPPING_CHECK_TIMEOUT = Duration(1, MINUTES) + private val PREPARE_QUERY_EXEC_TIMEOUT = Duration(1, MINUTES) private val DEFAULT_QUERY_EXECUTION_TIMEOUT = Duration(30, MINUTES) private val DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS = 10 * 60 * 1000 - val INITIAL_DELAY_MILLIS = 3000L - val EARLY_TERMIANTION_CHECK_FREQUENCY = 60000L + private val INITIAL_DELAY_MILLIS = 3000L + private val EARLY_TERMINATION_CHECK_FREQUENCY = 60000L @volatile var earlyExitFlag: Boolean = false @@ -65,10 +66,6 @@ object FlintREPL extends Logging with FlintJobExecutor { def main(args: Array[String]) { val (queryOption, resultIndex) = parseArgs(args) - if (Strings.isNullOrEmpty(resultIndex)) { - logAndThrow("resultIndex is not set") - } - // init SparkContext val conf: SparkConf = createSparkConf() val dataSource = conf.get(FlintSparkConf.DATA_SOURCE_NAME.key, "unknown") @@ -89,19 +86,24 @@ object FlintREPL extends Logging with FlintJobExecutor { val query = getQuery(queryOption, jobType, conf) if (jobType.equalsIgnoreCase("streaming")) { - logInfo(s"""streaming query ${query}""") - configDYNMaxExecutors(conf, jobType) - val streamingRunningCount = new AtomicInteger(0) - val jobOperator = - JobOperator( - createSparkSession(conf), - query, - dataSource, - resultIndex, - true, - streamingRunningCount) - registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount) - jobOperator.start() + logInfo(s"streaming query $query") + resultIndex match { + case Some(index) => + configDYNMaxExecutors(conf, jobType) + val streamingRunningCount = new AtomicInteger(0) + val jobOperator = JobOperator( + createSparkSession(conf), + query, + dataSource, + index, // Ensure the correct Option type is passed + true, + streamingRunningCount + ) + registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount) + jobOperator.start() + + case None => logAndThrow("resultIndex is not set") + } } else { // we don't allow default value for sessionIndex and sessionId. Throw exception if key not found. 
val sessionId: Option[String] = Option(conf.get(FlintSparkConf.SESSION_ID.key, null)) @@ -111,10 +113,10 @@ object FlintREPL extends Logging with FlintJobExecutor { } val spark = createSparkSession(conf) + val jobId = envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown") val applicationId = envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown") - // Read the values from the Spark configuration or fall back to the default values val inactivityLimitMillis: Long = conf.getLong( @@ -128,7 +130,7 @@ object FlintREPL extends Logging with FlintJobExecutor { val queryWaitTimeoutMillis: Long = conf.getLong("spark.flint.job.queryWaitTimeoutMillis", DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS) - val sessionManager = instantiateSessionManager() + val sessionManager = instantiateSessionManager(spark, resultIndex) val sessionTimerContext = getTimerContext(MetricConstants.REPL_PROCESSING_TIME_METRIC) /** @@ -141,7 +143,7 @@ object FlintREPL extends Logging with FlintJobExecutor { * https://github.com/opensearch-project/opensearch-spark/issues/320 */ spark.sparkContext.addSparkListener( - new PreShutdownListener(sessionManager, sessionId.get, sessionTimerContext)) + new PreShutdownListener(sessionId.get, sessionManager, sessionTimerContext)) // 1 thread for updating heart beat val threadPool = @@ -173,23 +175,35 @@ object FlintREPL extends Logging with FlintJobExecutor { return } - val commandContext = QueryExecutionContext( + val queryExecutionManager = instantiateQueryExecutionManager( + sessionManager.getSessionManagerMetadata) + val queryResultWriter = instantiateQueryResultWriter( + sessionManager.getSessionManagerMetadata) + val queryExecutionContext = QueryExecutionContext( spark, jobId, sessionId.get, sessionManager, + queryExecutionManager, + queryResultWriter, dataSource, - resultIndex, queryExecutionTimeoutSecs, inactivityLimitMillis, queryWaitTimeoutMillis) exponentialBackoffRetry(maxRetries = 5, initialDelay = 2.seconds) { - queryLoop(commandContext) + queryLoop(queryExecutionContext) } recordSessionSuccess(sessionTimerContext) } catch { case e: Exception => - handleSessionError(sessionTimerContext = sessionTimerContext, e = e) + handleSessionError( + applicationId, + jobId, + sessionId.get, + sessionManager, + sessionTimerContext, + jobStartTime, + e) } finally { if (threadPool != null) { heartBeatFuture.cancel(true) // Pass `true` to interrupt if running @@ -275,16 +289,16 @@ object FlintREPL extends Logging with FlintJobExecutor { case Some(confExcludeJobs) => // example: --conf spark.flint.deployment.excludeJobs=job-1,job-2 - val excludeJobIds = confExcludeJobs.split(",").toList // Convert Array to Lis + val excludedJobIds = confExcludeJobs.split(",").toList // Convert Array to Lis - if (excludeJobIds.contains(jobId)) { + if (excludedJobIds.contains(jobId)) { logInfo(s"current job is excluded, exit the application.") return true } val sessionDetails = sessionManager.getSessionDetails(sessionId) val existingExcludedJobIds = sessionDetails.get.excludedJobIds - if (excludeJobIds.sorted == existingExcludedJobIds.sorted) { + if (excludedJobIds.sorted == existingExcludedJobIds.sorted) { logInfo("duplicate job running, exit the application.") return true } @@ -296,20 +310,20 @@ object FlintREPL extends Logging with FlintJobExecutor { sessionId, sessionManager, jobStartTime, - excludeJobIds) + excludedJobIds = excludedJobIds) } false } def queryLoop(queryExecutionContext: QueryExecutionContext): Unit = { - // 1 thread for updating heart beat + val 
statementLifecycleManager = queryExecutionContext.statementLifecycleManager + // 1 thread for query execution preparation val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-query", 1) - implicit val executionContext = ExecutionContext.fromExecutor(threadPool) - - var futureMappingCheck: Future[Either[String, Unit]] = null + implicit val futureExecutor = ExecutionContext.fromExecutor(threadPool) + var futurePrepareQueryExecution: Future[Either[String, Unit]] = null try { - futureMappingCheck = Future { - checkAndCreateIndex(queryExecutionContext.osClient, queryExecutionContext.resultIndex) + futurePrepareQueryExecution = Future { + statementLifecycleManager.initializeStatementLifecycle() } var lastActivityTime = currentTimeProvider.currentEpochMillis() @@ -318,19 +332,17 @@ object FlintREPL extends Logging with FlintJobExecutor { var lastCanPickCheckTime = 0L while (currentTimeProvider .currentEpochMillis() - lastActivityTime <= queryExecutionContext.inactivityLimitMillis && canPickUpNextStatement) { - logInfo( - s"""read from ${queryExecutionContext.sessionIndex}, sessionId: ${queryExecutionContext.sessionId}""") + logInfo(s"""sessionId: ${queryExecutionContext.sessionId}""") try { - val commandState = CommandState( + val queryExecutionState = inMemoryQueryExecutionState( lastActivityTime, + lastCanPickCheckTime, verificationResult, - flintReader, - futureMappingCheck, - executionContext, - lastCanPickCheckTime) + futurePrepareQueryExecution, + futureExecutor) val result: (Long, VerificationResult, Boolean, Long) = - processCommands(queryExecutionContext, commandState) + processCommands(queryExecutionContext, queryExecutionState) val ( updatedLastActivityTime, @@ -343,7 +355,7 @@ object FlintREPL extends Logging with FlintJobExecutor { canPickUpNextStatement = updatedCanPickUpNextStatement lastCanPickCheckTime = updatedLastCanPickCheckTime } finally { - flintReader.close() + statementLifecycleManager.terminateStatementLifecycle() } Thread.sleep(100) @@ -355,37 +367,52 @@ object FlintREPL extends Logging with FlintJobExecutor { } } - // TODO: Refactor this with getDetails + private def refreshSessionState( + applicationId: String, + jobId: String, + sessionId: String, + sessionManager: SessionManager, + jobStartTime: Long, + state: String, + error: Option[String] = None, + excludedJobIds: Seq[String] = Seq.empty[String]): InteractiveSession = { + + val sessionDetails = sessionManager + .getSessionDetails(sessionId) + .getOrElse( + new InteractiveSession( + applicationId, + jobId, + sessionId, + state, + currentTimeProvider.currentEpochMillis(), + jobStartTime, + error = error, + excludedJobIds = excludedJobIds)) + sessionDetails.state = state + sessionManager.updateSessionDetails(sessionDetails, updateMode = SessionUpdateMode.UPSERT) + sessionDetails + } + private def setupFlintJob( applicationId: String, jobId: String, sessionId: String, sessionManager: SessionManager, jobStartTime: Long, - excludeJobIds: Seq[String] = Seq.empty[String]): Unit = { - val includeJobId = !excludeJobIds.isEmpty && !excludeJobIds.contains(jobId) - val currentTime = currentTimeProvider.currentEpochMillis() - val flintJob = new InteractiveSession( + excludedJobIds: Seq[String] = Seq.empty[String]): Unit = { + refreshSessionState( applicationId, jobId, sessionId, - "running", - currentTime, + sessionManager, jobStartTime, - excludeJobIds) - - // TODO: serialize need to be refactored to be more flexible - val serializedFlintInstance = if (includeJobId) { 
InteractiveSession.serialize(flintJob, currentTime, true) - } else { - InteractiveSession.serializeWithoutJobId(flintJob, currentTime) - } - flintSessionIndexUpdater.upsert(sessionId, serializedFlintInstance) - logInfo(s"""Updated job: {"jobid": ${flintJob.jobId}, "sessionId": ${flintJob.sessionId}}""") + SessionStates.RUNNING, + excludedJobIds = excludedJobIds) sessionRunningCount.incrementAndGet() } - def handleSessionError( + private def handleSessionError( applicationId: String, jobId: String, sessionId: String, @@ -395,39 +422,15 @@ object FlintREPL extends Logging with FlintJobExecutor { e: Exception): Unit = { val error = s"Session error: ${e.getMessage}" CustomLogging.logError(error, e) - - val sessionDetails = sessionManager - .getSessionDetails(sessionId) - .getOrElse(createFailedFlintInstance(applicationId, jobId, sessionId, jobStartTime, error)) - - updateFlintInstance(sessionDetails, flintSessionIndexUpdater, sessionId) - if (sessionDetails.isFail) { - recordSessionFailed(sessionTimerContext) - } - } - - private def createFailedFlintInstance( - applicationId: String, - jobId: String, - sessionId: String, - jobStartTime: Long, - errorMessage: String): InteractiveSession = new InteractiveSession( - applicationId, - jobId, - sessionId, - "fail", - currentTimeProvider.currentEpochMillis(), - jobStartTime, - error = Some(errorMessage)) - - private def updateFlintInstance( - flintInstance: InteractiveSession, - flintSessionIndexUpdater: OpenSearchUpdater, - sessionId: String): Unit = { - val currentTime = currentTimeProvider.currentEpochMillis() - flintSessionIndexUpdater.upsert( + refreshSessionState( + applicationId, + jobId, sessionId, - InteractiveSession.serializeWithoutJobId(flintInstance, currentTime)) + sessionManager, + jobStartTime, + SessionStates.FAIL, + Some(e.getMessage)) + recordSessionFailed(sessionTimerContext) } /** @@ -479,7 +482,7 @@ object FlintREPL extends Logging with FlintJobExecutor { private def processCommands( context: QueryExecutionContext, - state: CommandState): (Long, VerificationResult, Boolean, Long) = { + state: inMemoryQueryExecutionState): (Long, VerificationResult, Boolean, Long) = { import context._ import state._ @@ -492,8 +495,8 @@ object FlintREPL extends Logging with FlintJobExecutor { while (canProceed) { val currentTime = currentTimeProvider.currentEpochMillis() - // Only call canPickNextStatement if EARLY_TERMIANTION_CHECK_FREQUENCY milliseconds have passed - if (currentTime - lastCanPickCheckTime > EARLY_TERMIANTION_CHECK_FREQUENCY) { + // Only call canPickNextStatement if EARLY_TERMINATION_CHECK_FREQUENCY milliseconds have passed + if (currentTime - lastCanPickCheckTime > EARLY_TERMINATION_CHECK_FREQUENCY) { canPickNextStatementResult = canPickNextStatement(sessionManager, sessionId, jobId) lastCanPickCheckTime = currentTime } @@ -501,35 +504,23 @@ object FlintREPL extends Logging with FlintJobExecutor { if (!canPickNextStatementResult) { earlyExitFlag = true canProceed = false - } else if (!flintReader.hasNext) { - canProceed = false } else { - val statementTimerContext = getTimerContext( - MetricConstants.STATEMENT_PROCESSING_TIME_METRIC) - val flintStatement = processCommandInitiation(flintReader, flintSessionIndexUpdater) + sessionManager.getNextStatement(sessionId) match { + case Some(flintStatement) => + val statementTimerContext = getTimerContext( + MetricConstants.STATEMENT_PROCESSING_TIME_METRIC) - val (dataToWrite, returnedVerificationResult) = processStatementOnVerification( - recordedVerificationResult, - spark, - 
flintStatement, - dataSource, - sessionId, - executionContext, - futureMappingCheck, - resultIndex, - queryExecutionTimeout, - queryWaitTimeMillis) - - verificationResult = returnedVerificationResult - finalizeCommand( - dataToWrite, - flintStatement, - resultIndex, - flintSessionIndexUpdater, - osClient, - statementTimerContext) - // last query finish time is last activity time - lastActivityTime = currentTimeProvider.currentEpochMillis() + val (dataToWrite, returnedVerificationResult) = + processStatementOnVerification(context, state, flintStatement) + + verificationResult = returnedVerificationResult + finalizeCommand(context, dataToWrite, flintStatement, statementTimerContext) + // last query finish time is last activity time + lastActivityTime = currentTimeProvider.currentEpochMillis() + + case None => + canProceed = false + } } } @@ -550,19 +541,20 @@ object FlintREPL extends Logging with FlintJobExecutor { * flint session index updater */ private def finalizeCommand( + context: QueryExecutionContext, dataToWrite: Option[DataFrame], flintStatement: FlintStatement, - resultIndex: String, - flintSessionIndexUpdater: OpenSearchUpdater, - osClient: OSClient, statementTimerContext: Timer.Context): Unit = { + import context._ + try { - dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient)) + dataToWrite.foreach(df => queryResultWriter.write(df, flintStatement)) if (flintStatement.isRunning || flintStatement.isWaiting) { // we have set failed state in exception handling flintStatement.complete() } - updateSessionIndex(flintStatement, flintSessionIndexUpdater) + statementLifecycleManager.updateStatement(flintStatement) + recordStatementStateChange(flintStatement, statementTimerContext) } catch { // e.g., maybe due to authentication service connection issue @@ -571,7 +563,7 @@ object FlintREPL extends Logging with FlintJobExecutor { val error = s"""Fail to write result of ${flintStatement}, cause: ${e.getMessage}""" CustomLogging.logError(error, e) flintStatement.fail() - updateSessionIndex(flintStatement, flintSessionIndexUpdater) + statementLifecycleManager.updateStatement(flintStatement) recordStatementStateChange(flintStatement, statementTimerContext) } } @@ -607,25 +599,15 @@ object FlintREPL extends Logging with FlintJobExecutor { } def executeAndHandle( - spark: SparkSession, + context: QueryExecutionContext, + state: inMemoryQueryExecutionState, flintStatement: FlintStatement, - dataSource: String, - sessionId: String, - executionContext: ExecutionContextExecutor, - startTime: Long, - queryExecuitonTimeOut: Duration, - queryWaitTimeMillis: Long): Option[DataFrame] = { + startTime: Long): Option[DataFrame] = { + import context._ + import state._ + try { - Some( - executeQueryAsync( - spark, - flintStatement, - dataSource, - sessionId, - executionContext, - startTime, - queryExecuitonTimeOut, - queryWaitTimeMillis)) + Some(executeQueryAsync(context, state, flintStatement, startTime)) } catch { case e: TimeoutException => val error = s"Executing ${flintStatement.query} timed out" @@ -645,16 +627,12 @@ object FlintREPL extends Logging with FlintJobExecutor { } private def processStatementOnVerification( - recordedVerificationResult: VerificationResult, - spark: SparkSession, - flintStatement: FlintStatement, - dataSource: String, - sessionId: String, - executionContext: ExecutionContextExecutor, - futureMappingCheck: Future[Either[String, Unit]], - resultIndex: String, - queryExecutionTimeout: Duration, - queryWaitTimeMillis: Long) = { + context: 
QueryExecutionContext, + state: inMemoryQueryExecutionState, + flintStatement: FlintStatement) = { + import context._ + import state._ + val startTime: Long = currentTimeProvider.currentEpochMillis() var verificationResult = recordedVerificationResult var dataToWrite: Option[DataFrame] = None @@ -662,17 +640,9 @@ object FlintREPL extends Logging with FlintJobExecutor { verificationResult match { case NotVerified => try { - ThreadUtils.awaitResult(futureMappingCheck, MAPPING_CHECK_TIMEOUT) match { + ThreadUtils.awaitResult(futurePrepareQueryExecution, PREPARE_QUERY_EXEC_TIMEOUT) match { case Right(_) => - dataToWrite = executeAndHandle( - spark, - flintStatement, - dataSource, - sessionId, - executionContext, - startTime, - queryExecutionTimeout, - queryWaitTimeMillis) + dataToWrite = executeAndHandle(context, state, flintStatement, startTime) verificationResult = VerifiedWithoutError case Left(error) => verificationResult = VerifiedWithError(error) @@ -713,15 +683,7 @@ object FlintREPL extends Logging with FlintJobExecutor { sessionId, startTime)) case VerifiedWithoutError => - dataToWrite = executeAndHandle( - spark, - flintStatement, - dataSource, - sessionId, - executionContext, - startTime, - queryExecutionTimeout, - queryWaitTimeMillis) + dataToWrite = executeAndHandle(context, state, flintStatement, startTime) } logInfo(s"command complete: $flintStatement") @@ -729,14 +691,13 @@ object FlintREPL extends Logging with FlintJobExecutor { } def executeQueryAsync( - spark: SparkSession, + context: QueryExecutionContext, + state: inMemoryQueryExecutionState, flintStatement: FlintStatement, - dataSource: String, - sessionId: String, - executionContext: ExecutionContextExecutor, - startTime: Long, - queryExecutionTimeOut: Duration, - queryWaitTimeMillis: Long): DataFrame = { + startTime: Long): DataFrame = { + import context._ + import state._ + if (currentTimeProvider .currentEpochMillis() - flintStatement.submitTime > queryWaitTimeMillis) { handleCommandFailureAndGetFailedData( spark, dataSource, error, flintStatement, sessionId, startTime) } else { val futureQueryExecution = Future { - executeQuery( + val dataFrame = executeQuery( spark, flintStatement.query, dataSource, flintStatement.queryId, sessionId, false) - }(executionContext) + queryResultWriter.reformat(dataFrame, flintStatement) 
+ }(futureExecutor) // time out after 10 minutes - ThreadUtils.awaitResult(futureQueryExecution, queryExecutionTimeOut) + ThreadUtils.awaitResult(futureQueryExecution, queryExecutionTimeout) } } + + override def executeQuery( + spark: SparkSession, + query: String, + dataSource: String, + queryId: String, + sessionId: String, + streaming: Boolean): DataFrame = { + // Execute SQL query + val startTime = System.currentTimeMillis() + // we have to set job group in the same thread that started the query according to spark doc + spark.sparkContext.setJobGroup(queryId, "Job group for " + queryId, interruptOnCancel = true) + val result: DataFrame = spark.sql(query) + result + } + private def processCommandInitiation( flintReader: FlintReader, flintSessionIndexUpdater: OpenSearchUpdater): FlintStatement = { @@ -820,36 +798,31 @@ object FlintREPL extends Logging with FlintJobExecutor { flintReader } - class PreShutdownListener( - sessionManager: SessionManager, - sessionId: String, - sessionTimerContext: Timer.Context) + class PreShutdownListener(sessionId: String, sessionManager: SessionManager, sessionTimerContext: Timer.Context) extends SparkListener with Logging { - // TODO: Refactor update - override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { logInfo("Shutting down REPL") logInfo("earlyExitFlag: " + earlyExitFlag) - sessionManager.getSessionDetails(sessionId).foreach { sessionDetails => - // It's essential to check the earlyExitFlag before marking the session state as 'dead'. When this flag is true, - // it indicates that the control plane has already initiated a new session to handle remaining requests for the - // current session. In our SQL setup, marking a session as 'dead' automatically triggers the creation of a new - // session. However, the newly created session (initiated by the control plane) will enter a spin-wait state, - // where it inefficiently waits for certain conditions to be met, leading to unproductive resource consumption - // and eventual timeout. To avoid this issue and prevent the creation of redundant sessions by SQL, we ensure - // the session state is not set to 'dead' when earlyExitFlag is true, thereby preventing unnecessary duplicate - // processing. - if (!earlyExitFlag && !sessionDetails.isDead && !sessionDetails.isFail) { - updateFlintInstanceBeforeShutdown( - source, - getResponse, - flintSessionIndexUpdater, - sessionId, - sessionTimerContext) + try { + sessionManager.getSessionDetails(sessionId).foreach { sessionDetails => + // It's essential to check the earlyExitFlag before marking the session state as 'dead'. When this flag is true, + // it indicates that the control plane has already initiated a new session to handle remaining requests for the + // current session. In our SQL setup, marking a session as 'dead' automatically triggers the creation of a new + // session. However, the newly created session (initiated by the control plane) will enter a spin-wait state, + // where it inefficiently waits for certain conditions to be met, leading to unproductive resource consumption + // and eventual timeout. To avoid this issue and prevent the creation of redundant sessions by SQL, we ensure + // the session state is not set to 'dead' when earlyExitFlag is true, thereby preventing unnecessary duplicate + // processing. 
+ if (!earlyExitFlag && !sessionDetails.isDead && !sessionDetails.isFail) { + sessionDetails.state = SessionStates.DEAD + sessionManager.updateSessionDetails(sessionDetails, UPDATE_IF) + } } + } catch { + case e: Exception => logError(s"Failed to update session state for $sessionId", e) } } } @@ -1008,25 +981,47 @@ object FlintREPL extends Logging with FlintJobExecutor { result.getOrElse(throw new RuntimeException("Failed after retries")) } - private def instantiateSessionManager(): SessionManager = { - val options = FlintSparkConf().flintOptions() - val className = options.getCustomSessionManager() - - if (className.isEmpty) { - new SessionManagerImpl(options) + private def instantiateProvider[T](defaultProvider: => T, providerClassName: String): T = { + if (providerClassName.isEmpty) { + defaultProvider } else { try { - val providerClass = Utils.classForName(className) + val providerClass = Utils.classForName(providerClassName) val ctor = providerClass.getDeclaredConstructor() ctor.setAccessible(true) - ctor.newInstance().asInstanceOf[SessionManager] + ctor.newInstance().asInstanceOf[T] } catch { case e: Exception => - throw new RuntimeException(s"Failed to instantiate provider: $className", e) + throw new RuntimeException(s"Failed to instantiate provider: $providerClassName", e) } } } + private def instantiateManager[T]( + defaultInstance: => T, + customClassNameOption: () => String): T = { + val customClassName = customClassNameOption() + instantiateProvider(defaultInstance, customClassName) + } + + private def instantiateSessionManager(spark: SparkSession, resultIndex: Option[String]): SessionManager = { + val options = FlintSparkConf().flintOptions() + instantiateManager(new SessionManagerImpl(spark, options, resultIndex), options.getCustomSessionManager) + } + + private def instantiateQueryExecutionManager( + context: Map[String, Any]): StatementLifecycleManager = { + val options = FlintSparkConf().flintOptions() + instantiateManager( + new StatementLifecycleManagerImpl(context), + options.getCustomStatementManager) + } + + private def instantiateQueryResultWriter(context: Map[String, Any]): QueryResultWriter = { + val options = FlintSparkConf().flintOptions() + instantiateManager(new QueryResultWriterImpl(context), options.getCustomQueryResultWriter) + } + private def recordSessionSuccess(sessionTimerContext: Timer.Context): Unit = { logInfo("Session Success") stopTimer(sessionTimerContext) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryExecutionContext.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryExecutionContext.scala index 5108371ef..1d01e2b38 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryExecutionContext.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryExecutionContext.scala @@ -12,8 +12,9 @@ case class QueryExecutionContext( jobId: String, sessionId: String, sessionManager: SessionManager, + statementLifecycleManager: StatementLifecycleManager, + queryResultWriter: QueryResultWriter, dataSource: String, - resultIndex: String, queryExecutionTimeout: Duration, inactivityLimitMillis: Long, queryWaitTimeMillis: Long) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala new file mode 100644 index 000000000..cf23c4111 --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala @@ -0,0 +1,29 @@ +/* 
* Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.opensearch.flint.data.FlintStatement + +import org.apache.spark.internal.Logging + +class QueryResultWriterImpl(context: Map[String, Any]) + extends QueryResultWriter + with FlintJobExecutor + with Logging { + + val spark = context("sparkSession").asInstanceOf[SparkSession] + val dataSource = context("dataSource").asInstanceOf[String] + val resultIndex = context("resultIndex").asInstanceOf[String] + val osClient = context("osClient").asInstanceOf[OSClient] + // sessionId is assumed to be exposed through the session manager metadata (see SessionManagerImpl) + val sessionId = context("sessionId").asInstanceOf[String] + + override def reformat(dataFrame: DataFrame, flintStatement: FlintStatement): DataFrame = { + // Reshape the raw result with the shared formatter; the statement's submit time stands in + // for the query start time, and CleanerFactory is assumed to be in scope here. + getFormattedData( + dataFrame, + spark, + dataSource, + flintStatement.queryId, + flintStatement.query, + sessionId, + flintStatement.submitTime, + CleanerFactory.cleaner(false)) + } + + override def write(dataFrame: DataFrame, flintStatement: FlintStatement): Unit = { + writeDataFrameToOpensearch(dataFrame, resultIndex, osClient) + } +} diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala index 29f70ddbf..a1d7a1ea1 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala @@ -6,20 +6,19 @@ package org.apache.spark.sql import scala.util.{Failure, Success, Try} - import org.json4s.native.Serialization import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.logging.CustomLogging import org.opensearch.flint.core.storage.FlintReader -import org.opensearch.flint.data.InteractiveSession +import org.opensearch.flint.data.{FlintStatement, InteractiveSession} import org.opensearch.flint.data.InteractiveSession.formats import org.opensearch.search.sort.SortOrder - import org.apache.spark.internal.Logging +import org.apache.spark.sql.FlintREPL.logAndThrow import org.apache.spark.sql.SessionUpdateMode.SessionUpdateMode import org.apache.spark.sql.flint.config.FlintSparkConf -class SessionManagerImpl(flintOptions: FlintOptions) +class SessionManagerImpl(spark: SparkSession, flintOptions: FlintOptions, resultIndex: Option[String]) extends SessionManager with FlintJobExecutor with Logging { @@ -32,6 +31,9 @@ class SessionManagerImpl(flintOptions: FlintOptions) if (sessionIndex.isEmpty) { logAndThrow(FlintSparkConf.REQUEST_INDEX.key + " is not set") } + if (resultIndex.isEmpty) { + logAndThrow("resultIndex is not set") + } if (sessionId.isEmpty) { logAndThrow(FlintSparkConf.SESSION_ID.key + " is not set") } @@ -45,7 +47,10 @@ class SessionManagerImpl(flintOptions: FlintOptions) override def getSessionManagerMetadata: Map[String, Any] = { Map( + "sparkSession" -> spark, + "dataSource" -> dataSource, + "sessionId" -> sessionId.get, "sessionIndex" -> sessionIndex, + "resultIndex" -> resultIndex.get, "osClient" -> osClient, "flintSessionIndexUpdater" -> flintSessionIndexUpdater, "flintReader" -> flintReader) @@ -65,6 +70,64 @@ class SessionManagerImpl(flintOptions: FlintOptions) } } + override def updateSessionDetails( + sessionDetails: InteractiveSession, + sessionUpdateMode: SessionUpdateMode): Unit = { + sessionUpdateMode match { + case SessionUpdateMode.UPDATE => + flintSessionIndexUpdater.update( + sessionDetails.sessionId, + InteractiveSession.serialize(sessionDetails, currentTimeProvider.currentEpochMillis())) + case SessionUpdateMode.UPSERT => + val includeJobId = + !sessionDetails.excludedJobIds.isEmpty && !sessionDetails.excludedJobIds.contains( + sessionDetails.jobId) + val serializedSession = if (includeJobId) { + InteractiveSession.serialize( 
sessionDetails, + currentTimeProvider.currentEpochMillis(), + true) + } else { + InteractiveSession.serializeWithoutJobId( + sessionDetails, + currentTimeProvider.currentEpochMillis()) + } + flintSessionIndexUpdater.upsert(sessionDetails.sessionId, serializedSession) + case SessionUpdateMode.UPDATE_IF => + val seqNo = sessionDetails + .getContextValue("_seq_no") + .getOrElse(throw new IllegalArgumentException("Missing _seq_no for conditional update")) + .asInstanceOf[Long] + val primaryTerm = sessionDetails + .getContextValue("_primary_term") + .getOrElse( + throw new IllegalArgumentException("Missing _primary_term for conditional update")) + .asInstanceOf[Long] + flintSessionIndexUpdater.updateIf( + sessionDetails.sessionId, + InteractiveSession.serializeWithoutJobId( + sessionDetails, + currentTimeProvider.currentEpochMillis()), + seqNo, + primaryTerm) + } + + logInfo( + s"""Updated job: {"jobid": ${sessionDetails.jobId}, "sessionId": ${sessionDetails.sessionId}} from $sessionIndex""") + } + + override def getNextStatement(sessionId: String): Option[FlintStatement] = { + if (flintReader.hasNext) { + val rawStatement = flintReader.next() + logDebug(s"raw statement: $rawStatement") + val flintStatement = FlintStatement.deserialize(rawStatement) + logDebug(s"statement: $flintStatement") + Some(flintStatement) + } else { + None + } + } + override def recordHeartbeat(sessionId: String): Unit = { flintSessionIndexUpdater.upsert( sessionId, @@ -72,10 +135,6 @@ class SessionManagerImpl(flintOptions: FlintOptions) Map("lastUpdateTime" -> currentTimeProvider.currentEpochMillis(), "state" -> "running"))) } - override def hasPendingStatement(sessionId: String): Boolean = { - flintReader.hasNext - } - private def createQueryReader(sessionId: String, sessionIndex: String, dataSource: String) = { // all state in index are in lower case // we only search for statement submitted in the last hour in case of unexpected bugs causing infinite loop in the @@ -116,52 +175,4 @@ class SessionManagerImpl(flintOptions: FlintOptions) val flintReader = osClient.createQueryReader(sessionIndex, dsl, "submitTime", SortOrder.ASC) flintReader } - - override def updateSessionDetails( - sessionDetails: InteractiveSession, - sessionUpdateMode: SessionUpdateMode): Unit = { - sessionUpdateMode match { - case SessionUpdateMode.Update => - flintSessionIndexUpdater.update( - sessionDetails.sessionId, - InteractiveSession.serialize(sessionDetails, currentTimeProvider.currentEpochMillis())) - case SessionUpdateMode.Upsert => - val includeJobId = - !sessionDetails.excludedJobIds.isEmpty && !sessionDetails.excludedJobIds.contains( - sessionDetails.jobId) - val serializedSession = if (includeJobId) { - InteractiveSession.serialize( - sessionDetails, - currentTimeProvider.currentEpochMillis(), - true) - } else { - InteractiveSession.serializeWithoutJobId( - sessionDetails, - currentTimeProvider.currentEpochMillis()) - } - flintSessionIndexUpdater.upsert(sessionDetails.sessionId, serializedSession) - case SessionUpdateMode.UpdateIf => - val executionContext = sessionDetails.executionContext.getOrElse( - throw new IllegalArgumentException("Missing executionContext for conditional update")) - val seqNo = executionContext - .get("_seq_no") - .getOrElse(throw new IllegalArgumentException("Missing _seq_no for conditional update")) - .asInstanceOf[Long] - val primaryTerm = executionContext - .get("_primary_term") - .getOrElse( - throw new IllegalArgumentException("Missing _primary_term for conditional update")) - .asInstanceOf[Long] - 
flintSessionIndexUpdater.updateIf( - sessionDetails.sessionId, - InteractiveSession.serializeWithoutJobId( - sessionDetails, - currentTimeProvider.currentEpochMillis()), - seqNo, - primaryTerm) - } - - logInfo( - s"""Updated job: {"jobid": ${sessionDetails.jobId}, "sessionId": ${sessionDetails.sessionId}} from $sessionIndex""") - } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementLifecycleManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementLifecycleManagerImpl.scala new file mode 100644 index 000000000..80401a6af --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementLifecycleManagerImpl.scala @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} +import org.opensearch.flint.data.FlintStatement + +import org.apache.spark.internal.Logging + +class StatementLifecycleManagerImpl(context: Map[String, Any]) + extends StatementLifecycleManager + with FlintJobExecutor + with Logging { + + val resultIndex = context("resultIndex").asInstanceOf[String] + val osClient = context("osClient").asInstanceOf[OSClient] + val flintSessionIndexUpdater = + context("flintSessionIndexUpdater").asInstanceOf[OpenSearchUpdater] + val flintReader = context("flintReader").asInstanceOf[FlintReader] + + override def initializeStatementLifecycle(): Either[String, Unit] = { + try { + val existingSchema = osClient.getIndexMetadata(resultIndex) + if (!isSuperset(existingSchema, resultIndexMapping)) { + Left(s"The mapping of $resultIndex is incorrect.") + } else { + Right(()) + } + } catch { + case e: IllegalStateException + if e.getCause != null && + e.getCause.getMessage.contains("index_not_found_exception") => + createResultIndex(osClient, resultIndex, resultIndexMapping) + case e: InterruptedException => + val error = s"Interrupted by the main thread: ${e.getMessage}" + Thread.currentThread().interrupt() // Preserve the interrupt status + logError(error, e) + Left(error) + case e: Exception => + val error = s"Failed to verify existing mapping: ${e.getMessage}" + logError(error, e) + Left(error) + } + } + override def updateStatement(statement: FlintStatement): Unit = { + flintSessionIndexUpdater.update(statement.statementId, FlintStatement.serialize(statement)) + } + override def terminateStatementLifecycle(): Unit = { + flintReader.close() + } + override def setLocalProperty(statement: FlintStatement): Unit = { + + } +} diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/inMemoryQueryExecutionState.scala similarity index 61% rename from spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala rename to spark-sql-application/src/main/scala/org/apache/spark/sql/inMemoryQueryExecutionState.scala index ad49201f0..c09ec823b 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/inMemoryQueryExecutionState.scala @@ -9,10 +9,9 @@ import scala.concurrent.{ExecutionContextExecutor, Future} import org.opensearch.flint.core.storage.FlintReader -case class CommandState( +case class inMemoryQueryExecutionState( recordedLastActivityTime: Long, + recordedLastCanPickCheckTime: Long, recordedVerificationResult: VerificationResult, - flintReader: FlintReader, - 
futureMappingCheck: Future[Either[String, Unit]], - executionContext: ExecutionContextExecutor, - recordedLastCanPickCheckTime: Long) + futurePrepareQueryExecution: Future[Either[String, Unit]], + futureExecutor: ExecutionContextExecutor) diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala index 19f596e31..5e2e7d6f2 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala @@ -71,16 +71,7 @@ class FlintJobTest extends SparkFunSuite with JobMatchers { // Compare the result val result = - FlintJob.getFormattedData( - input, - spark, - dataSourceName, - "10", - "select 1", - "20", - currentTime - queryRunTime, - new MockTimeProvider(currentTime), - CleanerFactory.cleaner(false)) + FlintJob.getFormattedData(input, spark, dataSourceName, "10", "select 1", "20", currentTime - queryRunTime, CleanerFactory.cleaner(false)) assertEqualDataframe(expected, result) } diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index 45ec7b2cc..15c47a352 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -181,12 +181,7 @@ class FlintREPLTest "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava) // Instantiate the listener - val listener = new PreShutdownListener( - flintSessionIndexUpdater, - osClient, - sessionIndex, - sessionId, - timerContext) + // sessionManager is assumed to be a mocked SessionManager set up earlier in this test + val listener = new PreShutdownListener(sessionId, sessionManager, timerContext) // Simulate application end listener.onApplicationEnd(SparkListenerApplicationEnd(System.currentTimeMillis()))