diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala
index fe2fa5212..048f69ced 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala
@@ -21,4 +21,5 @@ case class CommandContext(
     jobId: String,
     queryExecutionTimeout: Duration,
     inactivityLimitMillis: Long,
-    queryWaitTimeMillis: Long)
+    queryWaitTimeMillis: Long,
+    queryLoopExecutionFrequency: Long)
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
index 8cad8844b..9177d9f1e 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
@@ -52,6 +52,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
   private val MAPPING_CHECK_TIMEOUT = Duration(1, MINUTES)
   private val DEFAULT_QUERY_EXECUTION_TIMEOUT = Duration(30, MINUTES)
   private val DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS = 10 * 60 * 1000
+  private val DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY = 100

   val INITIAL_DELAY_MILLIS = 3000L
   val EARLY_TERMIANTION_CHECK_FREQUENCY = 60000L
@@ -134,7 +135,10 @@ object FlintREPL extends Logging with FlintJobExecutor {
         SECONDS)
     val queryWaitTimeoutMillis: Long =
       conf.getLong("spark.flint.job.queryWaitTimeoutMillis", DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS)
-
+    val queryLoopExecutionFrequency: Long =
+      conf.getLong(
+        "spark.flint.job.queryLoopExecutionFrequency",
+        DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY)
     val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex.get)
     val sessionTimerContext = getTimerContext(MetricConstants.REPL_PROCESSING_TIME_METRIC)

@@ -199,7 +203,8 @@ object FlintREPL extends Logging with FlintJobExecutor {
       jobId,
       queryExecutionTimeoutSecs,
       inactivityLimitMillis,
-      queryWaitTimeoutMillis)
+      queryWaitTimeoutMillis,
+      queryLoopExecutionFrequency)
     exponentialBackoffRetry(maxRetries = 5, initialDelay = 2.seconds) {
       queryLoop(commandContext)
     }
@@ -342,7 +347,7 @@ object FlintREPL extends Logging with FlintJobExecutor {

   def queryLoop(commandContext: CommandContext): Unit = {
-    // 1 thread for updating heart beat
+    // 1 thread for async query execution
     val threadPool =
       threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-query", 1)
     implicit val executionContext = ExecutionContext.fromExecutor(threadPool)

@@ -392,7 +397,7 @@
           flintReader.close()
         }

-        Thread.sleep(100)
+        Thread.sleep(commandContext.queryLoopExecutionFrequency)
       }
     } finally {
       if (threadPool != null) {
diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
index d8ddcb665..a0b81efd8 100644
--- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
+++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
@@ -599,7 +599,8 @@ class FlintREPLTest
       jobId,
       Duration(10, MINUTES),
       60,
-      60)
+      60,
+      100)

     intercept[RuntimeException] {
       FlintREPL.exponentialBackoffRetry(maxRetries, 2.seconds) {
@@ -880,7 +881,8 @@ class FlintREPLTest
       jobId,
       Duration(10, MINUTES),
       shortInactivityLimit,
-      60)
+      60,
+      100)

     // Mock processCommands to always allow loop continuation
     val getResponse = mock[GetResponse]
@@ -900,6 +902,57 @@ class FlintREPLTest
     spark.stop()
   }

+  // TODO: Figure out how to test this more gracefully
+  test("queryLoop frequency should be configurable") {
+    val mockReader = mock[FlintReader]
+    val osClient = mock[OSClient]
+    when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC)))
+      .thenReturn(mockReader)
+    when(mockReader.hasNext).thenReturn(false)
+
+    val resultIndex = "testResultIndex"
+    val dataSource = "testDataSource"
+    val sessionIndex = "testSessionIndex"
+    val sessionId = "testSessionId"
+    val jobId = "testJobId"
+
+    val shortInactivityLimit = 100 // 100 milliseconds
+    val queryLoopExecutionFrequency = 900 // 900 milliseconds
+    // Create a SparkSession for testing
+    val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate()
+    val flintSessionIndexUpdater = mock[OpenSearchUpdater]
+
+    val commandContext = CommandContext(
+      spark,
+      dataSource,
+      resultIndex,
+      sessionId,
+      flintSessionIndexUpdater,
+      osClient,
+      sessionIndex,
+      jobId,
+      Duration(10, MINUTES),
+      shortInactivityLimit,
+      60,
+      queryLoopExecutionFrequency)
+
+    // Mock the session doc lookup so canPickNextStatement allows loop continuation
+    val getResponse = mock[GetResponse]
+    when(osClient.getDoc(*, *)).thenReturn(getResponse)
+    when(getResponse.isExists()).thenReturn(false)
+
+    val startTime = System.currentTimeMillis()
+    FlintREPL.queryLoop(commandContext)
+    val endTime = System.currentTimeMillis()
+
+    assert(endTime - startTime >= shortInactivityLimit + 100)
+    assert(endTime - startTime >= shortInactivityLimit + queryLoopExecutionFrequency)
+
+    // Stop the SparkSession and restore the default thread pool factory
+    spark.stop()
+    FlintREPL.threadPoolFactory = new DefaultThreadPoolFactory()
+  }
+
   test("queryLoop should stop when canPickUpNextStatement is false") {
     val mockReader = mock[FlintReader]
     val osClient = mock[OSClient]
@@ -930,7 +983,8 @@ class FlintREPLTest
       jobId,
       Duration(10, MINUTES),
       longInactivityLimit,
-      60)
+      60,
+      100)

     // Mocking canPickNextStatement to return false
     when(osClient.getDoc(sessionIndex, sessionId)).thenAnswer(_ => {
@@ -986,7 +1040,8 @@ class FlintREPLTest
       jobId,
       Duration(10, MINUTES),
       inactivityLimit,
-      60)
+      60,
+      100)

     try {
       // Mocking ThreadUtils to track the shutdown call
@@ -1036,7 +1091,8 @@ class FlintREPLTest
       jobId,
       Duration(10, MINUTES),
       inactivityLimit,
-      60)
+      60,
+      100)

     try {
       // Mocking ThreadUtils to track the shutdown call
@@ -1117,7 +1173,8 @@ class FlintREPLTest
       jobId,
       Duration(10, MINUTES),
       inactivityLimit,
-      60)
+      60,
+      100)

     val startTime = Instant.now().toEpochMilli()
@@ -1167,7 +1224,8 @@ class FlintREPLTest
       jobId,
       Duration(10, MINUTES),
       inactivityLimit,
-      60)
+      60,
+      100)

     val startTime = Instant.now().toEpochMilli()
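
How the new knob is consumed: FlintREPL reads spark.flint.job.queryLoopExecutionFrequency once at startup and sleeps that many milliseconds between query-loop iterations, defaulting to 100 ms. A minimal sketch of setting it when bootstrapping a session follows; only the property key and its 100 ms default come from this change, while the app name, master, and the 250 ms value are illustrative assumptions.

    import org.apache.spark.sql.SparkSession

    object QueryLoopFrequencyExample {
      def main(args: Array[String]): Unit = {
        // Hypothetical bootstrap: the property key matches the diff above;
        // 250 ms is an arbitrary example value, not a recommendation.
        val spark = SparkSession
          .builder()
          .master("local")
          .appName("QueryLoopFrequencyExample")
          .config("spark.flint.job.queryLoopExecutionFrequency", 250L)
          .getOrCreate()

        // Mirrors how FlintREPL resolves the setting: conf.getLong(...) with a
        // fallback to DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY (100 ms).
        val intervalMillis = spark.sparkContext.getConf
          .getLong("spark.flint.job.queryLoopExecutionFrequency", 100L)
        println(s"queryLoop will sleep $intervalMillis ms between iterations")

        spark.stop()
      }
    }

The same property can also be passed at submit time, e.g. --conf spark.flint.job.queryLoopExecutionFrequency=250 on spark-submit, which is how a deployed Flint job would normally pick it up.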