From 40d2c08a49b8db5cd0282119df1c75eb604ae8c7 Mon Sep 17 00:00:00 2001 From: Louis Chu Date: Wed, 3 Jul 2024 17:14:10 -0700 Subject: [PATCH] Add config on query loop execution frequency Signed-off-by: Louis Chu --- .../apache/spark/sql/FlintREPLITSuite.scala | 425 +++++++++--------- .../org/apache/spark/sql/CommandContext.scala | 3 +- .../org/apache/spark/sql/FlintREPL.scala | 13 +- .../org/apache/spark/sql/FlintREPLTest.scala | 119 ++--- 4 files changed, 301 insertions(+), 259 deletions(-) diff --git a/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala b/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala index 921db792a..72a58303e 100644 --- a/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala +++ b/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala @@ -18,7 +18,7 @@ import org.opensearch.flint.OpenSearchSuite import org.opensearch.flint.core.{FlintClient, FlintOptions} import org.opensearch.flint.core.storage.{FlintOpenSearchClient, FlintReader, OpenSearchUpdater} import org.opensearch.flint.data.{FlintStatement, InteractiveSession} -import org.opensearch.search.sort.SortOrder +import org.scalatest.prop.TableDrivenPropertyChecks._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.flint.config.FlintSparkConf.{DATA_SOURCE_NAME, EXCLUDE_JOB_IDS, HOST_ENDPOINT, HOST_PORT, JOB_TYPE, REFRESH_POLICY, REPL_INACTIVITY_TIMEOUT_MILLIS, REQUEST_INDEX, SESSION_ID} @@ -142,7 +142,7 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { index(requestIndex, oneNodeSetting, requestIndexMapping, docs) } - def startREPL(): Future[Unit] = { + def startREPL(queryLoopExecutionFrequency: Long): Future[Unit] = { val prefix = "flint-repl-test" val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor(prefix, 1) implicit val executionContext = ExecutionContext.fromExecutor(threadPool) @@ -164,6 +164,11 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { System.setProperty(HOST_PORT.key, String.valueOf(openSearchPort)) System.setProperty(REFRESH_POLICY.key, "true") + // Set the query loop execution frequency + System.setProperty( + "spark.flint.job.queryLoopExecutionFrequency", + queryLoopExecutionFrequency.toString) + FlintREPL.envinromentProvider = new MockEnvironment( Map("SERVERLESS_EMR_JOB_ID" -> jobRunId, "SERVERLESS_EMR_VIRTUAL_CLUSTER_ID" -> appId)) FlintREPL.enableHiveSupport = false @@ -212,220 +217,232 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { statementId } - test("sanity") { - try { - createSession(jobRunId, "") - threadLocalFuture.set(startREPL()) - - val createStatement = - s""" - | CREATE TABLE $testTable - | ( - | name STRING, - | age INT - | ) - | USING CSV - | OPTIONS ( - | header 'false', - | delimiter '\\t' - | ) - |""".stripMargin - submitQuery(s"${makeJsonCompliant(createStatement)}", "99") - - val insertStatement = - s""" - | INSERT INTO $testTable - | VALUES ('Hello', 30) - | """.stripMargin - submitQuery(s"${makeJsonCompliant(insertStatement)}", "100") - - val selectQueryId = "101" - val selectQueryStartTime = System.currentTimeMillis() - val selectQuery = s"SELECT name, age FROM $testTable".stripMargin - val selectStatementId = submitQuery(s"${makeJsonCompliant(selectQuery)}", selectQueryId) - - val describeStatement = s"DESC $testTable".stripMargin - val descQueryId = "102" - val descStartTime = System.currentTimeMillis() - val descStatementId = submitQuery(s"${makeJsonCompliant(describeStatement)}", descQueryId) - - val showTableStatement = - s"SHOW TABLES IN " + dataSourceName + ".default LIKE 'flint_sql_test'" - val showQueryId = "103" - val showStartTime = System.currentTimeMillis() - val showTableStatementId = - submitQuery(s"${makeJsonCompliant(showTableStatement)}", showQueryId) - - val wrongSelectQueryId = "104" - val wrongSelectQueryStartTime = System.currentTimeMillis() - val wrongSelectQuery = s"SELECT name, age FROM testTable".stripMargin - val wrongSelectStatementId = - submitQuery(s"${makeJsonCompliant(wrongSelectQuery)}", wrongSelectQueryId) - - val lateSelectQueryId = "105" - val lateSelectQuery = s"SELECT name, age FROM $testTable".stripMargin - // submitted from last year. We won't pick it up - val lateSelectStatementId = - submitQuery(s"${makeJsonCompliant(selectQuery)}", selectQueryId, 1672101970000L) - - // clean up - val dropStatement = - s"""DROP TABLE $testTable""".stripMargin - submitQuery(s"${makeJsonCompliant(dropStatement)}", "999") - - val selectQueryValidation: REPLResult => Boolean = result => { - assert( - result.results.size == 1, - s"expected result size is 1, but got ${result.results.size}") - val expectedResult = "{'name':'Hello','age':30}" - assert( - result.results(0).equals(expectedResult), - s"expected result is $expectedResult, but got ${result.results(0)}") - assert( - result.schemas.size == 2, - s"expected schema size is 2, but got ${result.schemas.size}") - val expectedZerothSchema = "{'column_name':'name','data_type':'string'}" - assert( - result.schemas(0).equals(expectedZerothSchema), - s"expected first field is $expectedZerothSchema, but got ${result.schemas(0)}") - val expectedFirstSchema = "{'column_name':'age','data_type':'integer'}" - assert( - result.schemas(1).equals(expectedFirstSchema), - s"expected second field is $expectedFirstSchema, but got ${result.schemas(1)}") - commonValidation(result, selectQueryId, selectQuery, selectQueryStartTime) - successValidation(result) - true - } - pollForResultAndAssert(selectQueryValidation, selectQueryId) - assert( - !awaitConditionForStatementOrTimeout( - statement => { - statement.state == "success" - }, - selectStatementId), - s"Fail to verify for $selectStatementId.") - - val descValidation: REPLResult => Boolean = result => { - assert( - result.results.size == 2, - s"expected result size is 2, but got ${result.results.size}") - val expectedResult0 = "{'col_name':'name','data_type':'string'}" - assert( - result.results(0).equals(expectedResult0), - s"expected result is $expectedResult0, but got ${result.results(0)}") - val expectedResult1 = "{'col_name':'age','data_type':'int'}" - assert( - result.results(1).equals(expectedResult1), - s"expected result is $expectedResult1, but got ${result.results(1)}") - assert( - result.schemas.size == 3, - s"expected schema size is 3, but got ${result.schemas.size}") - val expectedZerothSchema = "{'column_name':'col_name','data_type':'string'}" - assert( - result.schemas(0).equals(expectedZerothSchema), - s"expected first field is $expectedZerothSchema, but got ${result.schemas(0)}") - val expectedFirstSchema = "{'column_name':'data_type','data_type':'string'}" - assert( - result.schemas(1).equals(expectedFirstSchema), - s"expected second field is $expectedFirstSchema, but got ${result.schemas(1)}") - val expectedSecondSchema = "{'column_name':'comment','data_type':'string'}" - assert( - result.schemas(2).equals(expectedSecondSchema), - s"expected third field is $expectedSecondSchema, but got ${result.schemas(2)}") - commonValidation(result, descQueryId, describeStatement, descStartTime) - successValidation(result) - true - } - pollForResultAndAssert(descValidation, descQueryId) - assert( - !awaitConditionForStatementOrTimeout( - statement => { - statement.state == "success" - }, - descStatementId), - s"Fail to verify for $descStatementId.") - - val showValidation: REPLResult => Boolean = result => { - assert( - result.results.size == 1, - s"expected result size is 1, but got ${result.results.size}") - val expectedResult = - "{'namespace':'default','tableName':'flint_sql_test','isTemporary':false}" - assert( - result.results(0).equals(expectedResult), - s"expected result is $expectedResult, but got ${result.results(0)}") - assert( - result.schemas.size == 3, - s"expected schema size is 3, but got ${result.schemas.size}") - val expectedZerothSchema = "{'column_name':'namespace','data_type':'string'}" + // Define the test parameters + val testParams = Table( + ("testName", "queryLoopExecutionFrequency"), + ("Sanity with 100ms frequency", 100), + ("Sanity with 1000ms frequency", 1000)) + + forAll(testParams) { (testName: String, queryLoopExecutionFrequency: Int) => + test(testName) { + try { + createSession(jobRunId, "") + threadLocalFuture.set(startREPL(queryLoopExecutionFrequency)) + + val createStatement = + s""" + | CREATE TABLE $testTable + | ( + | name STRING, + | age INT + | ) + | USING CSV + | OPTIONS ( + | header 'false', + | delimiter '\\t' + | ) + |""".stripMargin + submitQuery(s"${makeJsonCompliant(createStatement)}", "99") + + val insertStatement = + s""" + | INSERT INTO $testTable + | VALUES ('Hello', 30) + | """.stripMargin + submitQuery(s"${makeJsonCompliant(insertStatement)}", "100") + + val selectQueryId = "101" + val selectQueryStartTime = System.currentTimeMillis() + val selectQuery = s"SELECT name, age FROM $testTable".stripMargin + val selectStatementId = submitQuery(s"${makeJsonCompliant(selectQuery)}", selectQueryId) + + val describeStatement = s"DESC $testTable".stripMargin + val descQueryId = "102" + val descStartTime = System.currentTimeMillis() + val descStatementId = submitQuery(s"${makeJsonCompliant(describeStatement)}", descQueryId) + + val showTableStatement = + s"SHOW TABLES IN " + dataSourceName + ".default LIKE 'flint_sql_test'" + val showQueryId = "103" + val showStartTime = System.currentTimeMillis() + val showTableStatementId = + submitQuery(s"${makeJsonCompliant(showTableStatement)}", showQueryId) + + val wrongSelectQueryId = "104" + val wrongSelectQueryStartTime = System.currentTimeMillis() + val wrongSelectQuery = s"SELECT name, age FROM testTable".stripMargin + val wrongSelectStatementId = + submitQuery(s"${makeJsonCompliant(wrongSelectQuery)}", wrongSelectQueryId) + + val lateSelectQueryId = "105" + val lateSelectQuery = s"SELECT name, age FROM $testTable".stripMargin + // submitted from last year. We won't pick it up + val lateSelectStatementId = + submitQuery(s"${makeJsonCompliant(selectQuery)}", selectQueryId, 1672101970000L) + + // clean up + val dropStatement = + s"""DROP TABLE $testTable""".stripMargin + submitQuery(s"${makeJsonCompliant(dropStatement)}", "999") + + val selectQueryValidation: REPLResult => Boolean = result => { + assert( + result.results.size == 1, + s"expected result size is 1, but got ${result.results.size}") + val expectedResult = "{'name':'Hello','age':30}" + assert( + result.results(0).equals(expectedResult), + s"expected result is $expectedResult, but got ${result.results(0)}") + assert( + result.schemas.size == 2, + s"expected schema size is 2, but got ${result.schemas.size}") + val expectedZerothSchema = "{'column_name':'name','data_type':'string'}" + assert( + result.schemas(0).equals(expectedZerothSchema), + s"expected first field is $expectedZerothSchema, but got ${result.schemas(0)}") + val expectedFirstSchema = "{'column_name':'age','data_type':'integer'}" + assert( + result.schemas(1).equals(expectedFirstSchema), + s"expected second field is $expectedFirstSchema, but got ${result.schemas(1)}") + commonValidation(result, selectQueryId, selectQuery, selectQueryStartTime) + successValidation(result) + true + } + pollForResultAndAssert(selectQueryValidation, selectQueryId) assert( - result.schemas(0).equals(expectedZerothSchema), - s"expected first field is $expectedZerothSchema, but got ${result.schemas(0)}") - val expectedFirstSchema = "{'column_name':'tableName','data_type':'string'}" + !awaitConditionForStatementOrTimeout( + statement => { + statement.state == "success" + }, + selectStatementId), + s"Fail to verify for $selectStatementId.") + + val descValidation: REPLResult => Boolean = result => { + assert( + result.results.size == 2, + s"expected result size is 2, but got ${result.results.size}") + val expectedResult0 = "{'col_name':'name','data_type':'string'}" + assert( + result.results(0).equals(expectedResult0), + s"expected result is $expectedResult0, but got ${result.results(0)}") + val expectedResult1 = "{'col_name':'age','data_type':'int'}" + assert( + result.results(1).equals(expectedResult1), + s"expected result is $expectedResult1, but got ${result.results(1)}") + assert( + result.schemas.size == 3, + s"expected schema size is 3, but got ${result.schemas.size}") + val expectedZerothSchema = "{'column_name':'col_name','data_type':'string'}" + assert( + result.schemas(0).equals(expectedZerothSchema), + s"expected first field is $expectedZerothSchema, but got ${result.schemas(0)}") + val expectedFirstSchema = "{'column_name':'data_type','data_type':'string'}" + assert( + result.schemas(1).equals(expectedFirstSchema), + s"expected second field is $expectedFirstSchema, but got ${result.schemas(1)}") + val expectedSecondSchema = "{'column_name':'comment','data_type':'string'}" + assert( + result.schemas(2).equals(expectedSecondSchema), + s"expected third field is $expectedSecondSchema, but got ${result.schemas(2)}") + commonValidation(result, descQueryId, describeStatement, descStartTime) + successValidation(result) + true + } + pollForResultAndAssert(descValidation, descQueryId) assert( - result.schemas(1).equals(expectedFirstSchema), - s"expected second field is $expectedFirstSchema, but got ${result.schemas(1)}") - val expectedSecondSchema = "{'column_name':'isTemporary','data_type':'boolean'}" + !awaitConditionForStatementOrTimeout( + statement => { + statement.state == "success" + }, + descStatementId), + s"Fail to verify for $descStatementId.") + + val showValidation: REPLResult => Boolean = result => { + assert( + result.results.size == 1, + s"expected result size is 1, but got ${result.results.size}") + val expectedResult = + "{'namespace':'default','tableName':'flint_sql_test','isTemporary':false}" + assert( + result.results(0).equals(expectedResult), + s"expected result is $expectedResult, but got ${result.results(0)}") + assert( + result.schemas.size == 3, + s"expected schema size is 3, but got ${result.schemas.size}") + val expectedZerothSchema = "{'column_name':'namespace','data_type':'string'}" + assert( + result.schemas(0).equals(expectedZerothSchema), + s"expected first field is $expectedZerothSchema, but got ${result.schemas(0)}") + val expectedFirstSchema = "{'column_name':'tableName','data_type':'string'}" + assert( + result.schemas(1).equals(expectedFirstSchema), + s"expected second field is $expectedFirstSchema, but got ${result.schemas(1)}") + val expectedSecondSchema = "{'column_name':'isTemporary','data_type':'boolean'}" + assert( + result.schemas(2).equals(expectedSecondSchema), + s"expected third field is $expectedSecondSchema, but got ${result.schemas(2)}") + commonValidation(result, showQueryId, showTableStatement, showStartTime) + successValidation(result) + true + } + pollForResultAndAssert(showValidation, showQueryId) assert( - result.schemas(2).equals(expectedSecondSchema), - s"expected third field is $expectedSecondSchema, but got ${result.schemas(2)}") - commonValidation(result, showQueryId, showTableStatement, showStartTime) - successValidation(result) - true - } - pollForResultAndAssert(showValidation, showQueryId) - assert( - !awaitConditionForStatementOrTimeout( - statement => { - statement.state == "success" - }, - showTableStatementId), - s"Fail to verify for $showTableStatementId.") - - val wrongSelectQueryValidation: REPLResult => Boolean = result => { + !awaitConditionForStatementOrTimeout( + statement => { + statement.state == "success" + }, + showTableStatementId), + s"Fail to verify for $showTableStatementId.") + + val wrongSelectQueryValidation: REPLResult => Boolean = result => { + assert( + result.results.size == 0, + s"expected result size is 0, but got ${result.results.size}") + assert( + result.schemas.size == 0, + s"expected schema size is 0, but got ${result.schemas.size}") + commonValidation( + result, + wrongSelectQueryId, + wrongSelectQuery, + wrongSelectQueryStartTime) + failureValidation(result) + true + } + pollForResultAndAssert(wrongSelectQueryValidation, wrongSelectQueryId) assert( - result.results.size == 0, - s"expected result size is 0, but got ${result.results.size}") + !awaitConditionForStatementOrTimeout( + statement => { + statement.state == "failed" + }, + wrongSelectStatementId), + s"Fail to verify for $wrongSelectStatementId.") + + // expect time out as this statement should not be picked up assert( - result.schemas.size == 0, - s"expected schema size is 0, but got ${result.schemas.size}") - commonValidation(result, wrongSelectQueryId, wrongSelectQuery, wrongSelectQueryStartTime) - failureValidation(result) - true + awaitConditionForStatementOrTimeout( + statement => { + statement.state != "waiting" + }, + lateSelectStatementId), + s"Fail to verify for $lateSelectStatementId.") + } catch { + case e: Exception => + logError("Unexpected exception", e) + assert(false, "Unexpected exception") + } finally { + waitREPLStop(threadLocalFuture.get()) + threadLocalFuture.remove() + + // shutdown hook is called after all tests have finished. We cannot verify if session has correctly been set in IT. } - pollForResultAndAssert(wrongSelectQueryValidation, wrongSelectQueryId) - assert( - !awaitConditionForStatementOrTimeout( - statement => { - statement.state == "failed" - }, - wrongSelectStatementId), - s"Fail to verify for $wrongSelectStatementId.") - - // expect time out as this statement should not be picked up - assert( - awaitConditionForStatementOrTimeout( - statement => { - statement.state != "waiting" - }, - lateSelectStatementId), - s"Fail to verify for $lateSelectStatementId.") - } catch { - case e: Exception => - logError("Unexpected exception", e) - assert(false, "Unexpected exception") - } finally { - waitREPLStop(threadLocalFuture.get()) - threadLocalFuture.remove() - - // shutdown hook is called after all tests have finished. We cannot verify if session has correctly been set in IT. } } test("create table with dummy location should fail with excepted error message") { try { createSession(jobRunId, "") - threadLocalFuture.set(startREPL()) + threadLocalFuture.set(startREPL(100L)) val dummyLocation = "s3://path/to/dummy/location" val testQueryId = "110" diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala index fe2fa5212..048f69ced 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala @@ -21,4 +21,5 @@ case class CommandContext( jobId: String, queryExecutionTimeout: Duration, inactivityLimitMillis: Long, - queryWaitTimeMillis: Long) + queryWaitTimeMillis: Long, + queryLoopExecutionFrequency: Long) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 8cad8844b..e0ab37acd 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -52,6 +52,7 @@ object FlintREPL extends Logging with FlintJobExecutor { private val MAPPING_CHECK_TIMEOUT = Duration(1, MINUTES) private val DEFAULT_QUERY_EXECUTION_TIMEOUT = Duration(30, MINUTES) private val DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS = 10 * 60 * 1000 + private val DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY = 100L val INITIAL_DELAY_MILLIS = 3000L val EARLY_TERMIANTION_CHECK_FREQUENCY = 60000L @@ -134,7 +135,10 @@ object FlintREPL extends Logging with FlintJobExecutor { SECONDS) val queryWaitTimeoutMillis: Long = conf.getLong("spark.flint.job.queryWaitTimeoutMillis", DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS) - + val queryLoopExecutionFrequency: Long = + conf.getLong( + "spark.flint.job.queryLoopExecutionFrequency", + DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex.get) val sessionTimerContext = getTimerContext(MetricConstants.REPL_PROCESSING_TIME_METRIC) @@ -199,7 +203,8 @@ object FlintREPL extends Logging with FlintJobExecutor { jobId, queryExecutionTimeoutSecs, inactivityLimitMillis, - queryWaitTimeoutMillis) + queryWaitTimeoutMillis, + queryLoopExecutionFrequency) exponentialBackoffRetry(maxRetries = 5, initialDelay = 2.seconds) { queryLoop(commandContext) } @@ -342,7 +347,7 @@ object FlintREPL extends Logging with FlintJobExecutor { } def queryLoop(commandContext: CommandContext): Unit = { - // 1 thread for updating heart beat + // 1 thread for async query execution val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-query", 1) implicit val executionContext = ExecutionContext.fromExecutor(threadPool) @@ -392,7 +397,7 @@ object FlintREPL extends Logging with FlintJobExecutor { flintReader.close() } - Thread.sleep(100) + Thread.sleep(commandContext.queryLoopExecutionFrequency) } } finally { if (threadPool != null) { diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index d8ddcb665..3e38e6e35 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -17,14 +17,15 @@ import scala.reflect.runtime.universe.TypeTag import com.amazonaws.services.glue.model.AccessDeniedException import com.codahale.metrics.Timer -import org.mockito.ArgumentMatchersSugar -import org.mockito.Mockito._ +import org.mockito.{ArgumentMatchersSugar, Mockito} +import org.mockito.Mockito.{atLeastOnce, never, times, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.opensearch.action.get.GetResponse import org.opensearch.flint.core.storage.{FlintReader, OpenSearchReader, OpenSearchUpdater} import org.opensearch.flint.data.FlintStatement import org.opensearch.search.sort.SortOrder +import org.scalatest.prop.TableDrivenPropertyChecks._ import org.scalatestplus.mockito.MockitoSugar import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} @@ -599,7 +600,8 @@ class FlintREPLTest jobId, Duration(10, MINUTES), 60, - 60) + 60, + 100) intercept[RuntimeException] { FlintREPL.exponentialBackoffRetry(maxRetries, 2.seconds) { @@ -880,7 +882,8 @@ class FlintREPLTest jobId, Duration(10, MINUTES), shortInactivityLimit, - 60) + 60, + 100) // Mock processCommands to always allow loop continuation val getResponse = mock[GetResponse] @@ -930,7 +933,8 @@ class FlintREPLTest jobId, Duration(10, MINUTES), longInactivityLimit, - 60) + 60, + 100) // Mocking canPickNextStatement to return false when(osClient.getDoc(sessionIndex, sessionId)).thenAnswer(_ => { @@ -986,7 +990,8 @@ class FlintREPLTest jobId, Duration(10, MINUTES), inactivityLimit, - 60) + 60, + 100) try { // Mocking ThreadUtils to track the shutdown call @@ -1036,7 +1041,8 @@ class FlintREPLTest jobId, Duration(10, MINUTES), inactivityLimit, - 60) + 60, + 100) try { // Mocking ThreadUtils to track the shutdown call @@ -1117,7 +1123,8 @@ class FlintREPLTest jobId, Duration(10, MINUTES), inactivityLimit, - 60) + 60, + 100) val startTime = Instant.now().toEpochMilli() @@ -1131,58 +1138,70 @@ class FlintREPLTest verify(osClient, times(1)).getIndexMetadata(*) } - test("queryLoop should execute loop without processing any commands") { - val mockReader = mock[FlintReader] - val osClient = mock[OSClient] - when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) - .thenReturn(mockReader) - val getResponse = mock[GetResponse] - when(osClient.getDoc(*, *)).thenReturn(getResponse) - when(getResponse.isExists()).thenReturn(false) - - // Configure mockReader to always return false, indicating no commands to process - when(mockReader.hasNext).thenReturn(false) + val testCases = Table( + ("inactivityLimit", "queryLoopExecutionFrequency"), + (5000, 100), // 5 seconds, 100 ms + (100, 300) // 100 ms, 300 ms + ) - val resultIndex = "testResultIndex" - val dataSource = "testDataSource" - val sessionIndex = "testSessionIndex" - val sessionId = "testSessionId" - val jobId = "testJobId" + test( + "queryLoop should execute loop without processing any commands for different inactivity limits and frequencies") { + forAll(testCases) { (inactivityLimit, queryLoopExecutionFrequency) => + val mockReader = mock[FlintReader] + val osClient = mock[OSClient] + when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) + .thenReturn(mockReader) + val getResponse = mock[GetResponse] + when(osClient.getDoc(*, *)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(false) + when(mockReader.hasNext).thenReturn(false) + + val resultIndex = "testResultIndex" + val dataSource = "testDataSource" + val sessionIndex = "testSessionIndex" + val sessionId = "testSessionId" + val jobId = "testJobId" + + // Create a SparkSession for testing + val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() - val inactivityLimit = 5000 // 5 seconds + val flintSessionIndexUpdater = mock[OpenSearchUpdater] - // Create a SparkSession for testing - val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() + val commandContext = CommandContext( + spark, + dataSource, + resultIndex, + sessionId, + flintSessionIndexUpdater, + osClient, + sessionIndex, + jobId, + Duration(10, MINUTES), + inactivityLimit, + 60, + queryLoopExecutionFrequency) - val flintSessionIndexUpdater = mock[OpenSearchUpdater] + val startTime = Instant.now().toEpochMilli() - val commandContext = CommandContext( - spark, - dataSource, - resultIndex, - sessionId, - flintSessionIndexUpdater, - osClient, - sessionIndex, - jobId, - Duration(10, MINUTES), - inactivityLimit, - 60) + // Running the queryLoop + FlintREPL.queryLoop(commandContext) - val startTime = Instant.now().toEpochMilli() + val endTime = Instant.now().toEpochMilli() - // Running the queryLoop - FlintREPL.queryLoop(commandContext) + val elapsedTime = endTime - startTime - val endTime = Instant.now().toEpochMilli() + // Assert that the loop ran for at least the duration of the inactivity limit + assert(elapsedTime >= inactivityLimit) - // Assert that the loop ran for at least the duration of the inactivity limit - assert(endTime - startTime >= inactivityLimit) + // Verify query execution frequency + val expectedCalls = Math.ceil(elapsedTime.toDouble / queryLoopExecutionFrequency).toInt + verify(mockReader, Mockito.atMost(expectedCalls)).hasNext - // Verify that no command was actually processed - verify(mockReader, never()).next() + // Verify that no command was actually processed + verify(mockReader, never()).next() - // Stop the SparkSession - spark.stop() + // Stop the SparkSession + spark.stop() + } } }