From 40d2c08a49b8db5cd0282119df1c75eb604ae8c7 Mon Sep 17 00:00:00 2001 From: Louis Chu Date: Wed, 3 Jul 2024 17:14:10 -0700 Subject: [PATCH 1/2] Add config on query loop execution frequency Signed-off-by: Louis Chu --- .../apache/spark/sql/FlintREPLITSuite.scala | 425 +++++++++--------- .../org/apache/spark/sql/CommandContext.scala | 3 +- .../org/apache/spark/sql/FlintREPL.scala | 13 +- .../org/apache/spark/sql/FlintREPLTest.scala | 119 ++--- 4 files changed, 301 insertions(+), 259 deletions(-) diff --git a/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala b/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala index 921db792a..72a58303e 100644 --- a/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala +++ b/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala @@ -18,7 +18,7 @@ import org.opensearch.flint.OpenSearchSuite import org.opensearch.flint.core.{FlintClient, FlintOptions} import org.opensearch.flint.core.storage.{FlintOpenSearchClient, FlintReader, OpenSearchUpdater} import org.opensearch.flint.data.{FlintStatement, InteractiveSession} -import org.opensearch.search.sort.SortOrder +import org.scalatest.prop.TableDrivenPropertyChecks._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.flint.config.FlintSparkConf.{DATA_SOURCE_NAME, EXCLUDE_JOB_IDS, HOST_ENDPOINT, HOST_PORT, JOB_TYPE, REFRESH_POLICY, REPL_INACTIVITY_TIMEOUT_MILLIS, REQUEST_INDEX, SESSION_ID} @@ -142,7 +142,7 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { index(requestIndex, oneNodeSetting, requestIndexMapping, docs) } - def startREPL(): Future[Unit] = { + def startREPL(queryLoopExecutionFrequency: Long): Future[Unit] = { val prefix = "flint-repl-test" val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor(prefix, 1) implicit val executionContext = ExecutionContext.fromExecutor(threadPool) @@ -164,6 +164,11 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { System.setProperty(HOST_PORT.key, String.valueOf(openSearchPort)) System.setProperty(REFRESH_POLICY.key, "true") + // Set the query loop execution frequency + System.setProperty( + "spark.flint.job.queryLoopExecutionFrequency", + queryLoopExecutionFrequency.toString) + FlintREPL.envinromentProvider = new MockEnvironment( Map("SERVERLESS_EMR_JOB_ID" -> jobRunId, "SERVERLESS_EMR_VIRTUAL_CLUSTER_ID" -> appId)) FlintREPL.enableHiveSupport = false @@ -212,220 +217,232 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { statementId } - test("sanity") { - try { - createSession(jobRunId, "") - threadLocalFuture.set(startREPL()) - - val createStatement = - s""" - | CREATE TABLE $testTable - | ( - | name STRING, - | age INT - | ) - | USING CSV - | OPTIONS ( - | header 'false', - | delimiter '\\t' - | ) - |""".stripMargin - submitQuery(s"${makeJsonCompliant(createStatement)}", "99") - - val insertStatement = - s""" - | INSERT INTO $testTable - | VALUES ('Hello', 30) - | """.stripMargin - submitQuery(s"${makeJsonCompliant(insertStatement)}", "100") - - val selectQueryId = "101" - val selectQueryStartTime = System.currentTimeMillis() - val selectQuery = s"SELECT name, age FROM $testTable".stripMargin - val selectStatementId = submitQuery(s"${makeJsonCompliant(selectQuery)}", selectQueryId) - - val describeStatement = s"DESC $testTable".stripMargin - val descQueryId = "102" - val descStartTime = System.currentTimeMillis() - val descStatementId = submitQuery(s"${makeJsonCompliant(describeStatement)}", descQueryId) - - val showTableStatement = - s"SHOW TABLES IN " + dataSourceName + ".default LIKE 'flint_sql_test'" - val showQueryId = "103" - val showStartTime = System.currentTimeMillis() - val showTableStatementId = - submitQuery(s"${makeJsonCompliant(showTableStatement)}", showQueryId) - - val wrongSelectQueryId = "104" - val wrongSelectQueryStartTime = System.currentTimeMillis() - val wrongSelectQuery = s"SELECT name, age FROM testTable".stripMargin - val wrongSelectStatementId = - submitQuery(s"${makeJsonCompliant(wrongSelectQuery)}", wrongSelectQueryId) - - val lateSelectQueryId = "105" - val lateSelectQuery = s"SELECT name, age FROM $testTable".stripMargin - // submitted from last year. We won't pick it up - val lateSelectStatementId = - submitQuery(s"${makeJsonCompliant(selectQuery)}", selectQueryId, 1672101970000L) - - // clean up - val dropStatement = - s"""DROP TABLE $testTable""".stripMargin - submitQuery(s"${makeJsonCompliant(dropStatement)}", "999") - - val selectQueryValidation: REPLResult => Boolean = result => { - assert( - result.results.size == 1, - s"expected result size is 1, but got ${result.results.size}") - val expectedResult = "{'name':'Hello','age':30}" - assert( - result.results(0).equals(expectedResult), - s"expected result is $expectedResult, but got ${result.results(0)}") - assert( - result.schemas.size == 2, - s"expected schema size is 2, but got ${result.schemas.size}") - val expectedZerothSchema = "{'column_name':'name','data_type':'string'}" - assert( - result.schemas(0).equals(expectedZerothSchema), - s"expected first field is $expectedZerothSchema, but got ${result.schemas(0)}") - val expectedFirstSchema = "{'column_name':'age','data_type':'integer'}" - assert( - result.schemas(1).equals(expectedFirstSchema), - s"expected second field is $expectedFirstSchema, but got ${result.schemas(1)}") - commonValidation(result, selectQueryId, selectQuery, selectQueryStartTime) - successValidation(result) - true - } - pollForResultAndAssert(selectQueryValidation, selectQueryId) - assert( - !awaitConditionForStatementOrTimeout( - statement => { - statement.state == "success" - }, - selectStatementId), - s"Fail to verify for $selectStatementId.") - - val descValidation: REPLResult => Boolean = result => { - assert( - result.results.size == 2, - s"expected result size is 2, but got ${result.results.size}") - val expectedResult0 = "{'col_name':'name','data_type':'string'}" - assert( - result.results(0).equals(expectedResult0), - s"expected result is $expectedResult0, but got ${result.results(0)}") - val expectedResult1 = "{'col_name':'age','data_type':'int'}" - assert( - result.results(1).equals(expectedResult1), - s"expected result is $expectedResult1, but got ${result.results(1)}") - assert( - result.schemas.size == 3, - s"expected schema size is 3, but got ${result.schemas.size}") - val expectedZerothSchema = "{'column_name':'col_name','data_type':'string'}" - assert( - result.schemas(0).equals(expectedZerothSchema), - s"expected first field is $expectedZerothSchema, but got ${result.schemas(0)}") - val expectedFirstSchema = "{'column_name':'data_type','data_type':'string'}" - assert( - result.schemas(1).equals(expectedFirstSchema), - s"expected second field is $expectedFirstSchema, but got ${result.schemas(1)}") - val expectedSecondSchema = "{'column_name':'comment','data_type':'string'}" - assert( - result.schemas(2).equals(expectedSecondSchema), - s"expected third field is $expectedSecondSchema, but got ${result.schemas(2)}") - commonValidation(result, descQueryId, describeStatement, descStartTime) - successValidation(result) - true - } - pollForResultAndAssert(descValidation, descQueryId) - assert( - !awaitConditionForStatementOrTimeout( - statement => { - statement.state == "success" - }, - descStatementId), - s"Fail to verify for $descStatementId.") - - val showValidation: REPLResult => Boolean = result => { - assert( - result.results.size == 1, - s"expected result size is 1, but got ${result.results.size}") - val expectedResult = - "{'namespace':'default','tableName':'flint_sql_test','isTemporary':false}" - assert( - result.results(0).equals(expectedResult), - s"expected result is $expectedResult, but got ${result.results(0)}") - assert( - result.schemas.size == 3, - s"expected schema size is 3, but got ${result.schemas.size}") - val expectedZerothSchema = "{'column_name':'namespace','data_type':'string'}" + // Define the test parameters + val testParams = Table( + ("testName", "queryLoopExecutionFrequency"), + ("Sanity with 100ms frequency", 100), + ("Sanity with 1000ms frequency", 1000)) + + forAll(testParams) { (testName: String, queryLoopExecutionFrequency: Int) => + test(testName) { + try { + createSession(jobRunId, "") + threadLocalFuture.set(startREPL(queryLoopExecutionFrequency)) + + val createStatement = + s""" + | CREATE TABLE $testTable + | ( + | name STRING, + | age INT + | ) + | USING CSV + | OPTIONS ( + | header 'false', + | delimiter '\\t' + | ) + |""".stripMargin + submitQuery(s"${makeJsonCompliant(createStatement)}", "99") + + val insertStatement = + s""" + | INSERT INTO $testTable + | VALUES ('Hello', 30) + | """.stripMargin + submitQuery(s"${makeJsonCompliant(insertStatement)}", "100") + + val selectQueryId = "101" + val selectQueryStartTime = System.currentTimeMillis() + val selectQuery = s"SELECT name, age FROM $testTable".stripMargin + val selectStatementId = submitQuery(s"${makeJsonCompliant(selectQuery)}", selectQueryId) + + val describeStatement = s"DESC $testTable".stripMargin + val descQueryId = "102" + val descStartTime = System.currentTimeMillis() + val descStatementId = submitQuery(s"${makeJsonCompliant(describeStatement)}", descQueryId) + + val showTableStatement = + s"SHOW TABLES IN " + dataSourceName + ".default LIKE 'flint_sql_test'" + val showQueryId = "103" + val showStartTime = System.currentTimeMillis() + val showTableStatementId = + submitQuery(s"${makeJsonCompliant(showTableStatement)}", showQueryId) + + val wrongSelectQueryId = "104" + val wrongSelectQueryStartTime = System.currentTimeMillis() + val wrongSelectQuery = s"SELECT name, age FROM testTable".stripMargin + val wrongSelectStatementId = + submitQuery(s"${makeJsonCompliant(wrongSelectQuery)}", wrongSelectQueryId) + + val lateSelectQueryId = "105" + val lateSelectQuery = s"SELECT name, age FROM $testTable".stripMargin + // submitted from last year. We won't pick it up + val lateSelectStatementId = + submitQuery(s"${makeJsonCompliant(selectQuery)}", selectQueryId, 1672101970000L) + + // clean up + val dropStatement = + s"""DROP TABLE $testTable""".stripMargin + submitQuery(s"${makeJsonCompliant(dropStatement)}", "999") + + val selectQueryValidation: REPLResult => Boolean = result => { + assert( + result.results.size == 1, + s"expected result size is 1, but got ${result.results.size}") + val expectedResult = "{'name':'Hello','age':30}" + assert( + result.results(0).equals(expectedResult), + s"expected result is $expectedResult, but got ${result.results(0)}") + assert( + result.schemas.size == 2, + s"expected schema size is 2, but got ${result.schemas.size}") + val expectedZerothSchema = "{'column_name':'name','data_type':'string'}" + assert( + result.schemas(0).equals(expectedZerothSchema), + s"expected first field is $expectedZerothSchema, but got ${result.schemas(0)}") + val expectedFirstSchema = "{'column_name':'age','data_type':'integer'}" + assert( + result.schemas(1).equals(expectedFirstSchema), + s"expected second field is $expectedFirstSchema, but got ${result.schemas(1)}") + commonValidation(result, selectQueryId, selectQuery, selectQueryStartTime) + successValidation(result) + true + } + pollForResultAndAssert(selectQueryValidation, selectQueryId) assert( - result.schemas(0).equals(expectedZerothSchema), - s"expected first field is $expectedZerothSchema, but got ${result.schemas(0)}") - val expectedFirstSchema = "{'column_name':'tableName','data_type':'string'}" + !awaitConditionForStatementOrTimeout( + statement => { + statement.state == "success" + }, + selectStatementId), + s"Fail to verify for $selectStatementId.") + + val descValidation: REPLResult => Boolean = result => { + assert( + result.results.size == 2, + s"expected result size is 2, but got ${result.results.size}") + val expectedResult0 = "{'col_name':'name','data_type':'string'}" + assert( + result.results(0).equals(expectedResult0), + s"expected result is $expectedResult0, but got ${result.results(0)}") + val expectedResult1 = "{'col_name':'age','data_type':'int'}" + assert( + result.results(1).equals(expectedResult1), + s"expected result is $expectedResult1, but got ${result.results(1)}") + assert( + result.schemas.size == 3, + s"expected schema size is 3, but got ${result.schemas.size}") + val expectedZerothSchema = "{'column_name':'col_name','data_type':'string'}" + assert( + result.schemas(0).equals(expectedZerothSchema), + s"expected first field is $expectedZerothSchema, but got ${result.schemas(0)}") + val expectedFirstSchema = "{'column_name':'data_type','data_type':'string'}" + assert( + result.schemas(1).equals(expectedFirstSchema), + s"expected second field is $expectedFirstSchema, but got ${result.schemas(1)}") + val expectedSecondSchema = "{'column_name':'comment','data_type':'string'}" + assert( + result.schemas(2).equals(expectedSecondSchema), + s"expected third field is $expectedSecondSchema, but got ${result.schemas(2)}") + commonValidation(result, descQueryId, describeStatement, descStartTime) + successValidation(result) + true + } + pollForResultAndAssert(descValidation, descQueryId) assert( - result.schemas(1).equals(expectedFirstSchema), - s"expected second field is $expectedFirstSchema, but got ${result.schemas(1)}") - val expectedSecondSchema = "{'column_name':'isTemporary','data_type':'boolean'}" + !awaitConditionForStatementOrTimeout( + statement => { + statement.state == "success" + }, + descStatementId), + s"Fail to verify for $descStatementId.") + + val showValidation: REPLResult => Boolean = result => { + assert( + result.results.size == 1, + s"expected result size is 1, but got ${result.results.size}") + val expectedResult = + "{'namespace':'default','tableName':'flint_sql_test','isTemporary':false}" + assert( + result.results(0).equals(expectedResult), + s"expected result is $expectedResult, but got ${result.results(0)}") + assert( + result.schemas.size == 3, + s"expected schema size is 3, but got ${result.schemas.size}") + val expectedZerothSchema = "{'column_name':'namespace','data_type':'string'}" + assert( + result.schemas(0).equals(expectedZerothSchema), + s"expected first field is $expectedZerothSchema, but got ${result.schemas(0)}") + val expectedFirstSchema = "{'column_name':'tableName','data_type':'string'}" + assert( + result.schemas(1).equals(expectedFirstSchema), + s"expected second field is $expectedFirstSchema, but got ${result.schemas(1)}") + val expectedSecondSchema = "{'column_name':'isTemporary','data_type':'boolean'}" + assert( + result.schemas(2).equals(expectedSecondSchema), + s"expected third field is $expectedSecondSchema, but got ${result.schemas(2)}") + commonValidation(result, showQueryId, showTableStatement, showStartTime) + successValidation(result) + true + } + pollForResultAndAssert(showValidation, showQueryId) assert( - result.schemas(2).equals(expectedSecondSchema), - s"expected third field is $expectedSecondSchema, but got ${result.schemas(2)}") - commonValidation(result, showQueryId, showTableStatement, showStartTime) - successValidation(result) - true - } - pollForResultAndAssert(showValidation, showQueryId) - assert( - !awaitConditionForStatementOrTimeout( - statement => { - statement.state == "success" - }, - showTableStatementId), - s"Fail to verify for $showTableStatementId.") - - val wrongSelectQueryValidation: REPLResult => Boolean = result => { + !awaitConditionForStatementOrTimeout( + statement => { + statement.state == "success" + }, + showTableStatementId), + s"Fail to verify for $showTableStatementId.") + + val wrongSelectQueryValidation: REPLResult => Boolean = result => { + assert( + result.results.size == 0, + s"expected result size is 0, but got ${result.results.size}") + assert( + result.schemas.size == 0, + s"expected schema size is 0, but got ${result.schemas.size}") + commonValidation( + result, + wrongSelectQueryId, + wrongSelectQuery, + wrongSelectQueryStartTime) + failureValidation(result) + true + } + pollForResultAndAssert(wrongSelectQueryValidation, wrongSelectQueryId) assert( - result.results.size == 0, - s"expected result size is 0, but got ${result.results.size}") + !awaitConditionForStatementOrTimeout( + statement => { + statement.state == "failed" + }, + wrongSelectStatementId), + s"Fail to verify for $wrongSelectStatementId.") + + // expect time out as this statement should not be picked up assert( - result.schemas.size == 0, - s"expected schema size is 0, but got ${result.schemas.size}") - commonValidation(result, wrongSelectQueryId, wrongSelectQuery, wrongSelectQueryStartTime) - failureValidation(result) - true + awaitConditionForStatementOrTimeout( + statement => { + statement.state != "waiting" + }, + lateSelectStatementId), + s"Fail to verify for $lateSelectStatementId.") + } catch { + case e: Exception => + logError("Unexpected exception", e) + assert(false, "Unexpected exception") + } finally { + waitREPLStop(threadLocalFuture.get()) + threadLocalFuture.remove() + + // shutdown hook is called after all tests have finished. We cannot verify if session has correctly been set in IT. } - pollForResultAndAssert(wrongSelectQueryValidation, wrongSelectQueryId) - assert( - !awaitConditionForStatementOrTimeout( - statement => { - statement.state == "failed" - }, - wrongSelectStatementId), - s"Fail to verify for $wrongSelectStatementId.") - - // expect time out as this statement should not be picked up - assert( - awaitConditionForStatementOrTimeout( - statement => { - statement.state != "waiting" - }, - lateSelectStatementId), - s"Fail to verify for $lateSelectStatementId.") - } catch { - case e: Exception => - logError("Unexpected exception", e) - assert(false, "Unexpected exception") - } finally { - waitREPLStop(threadLocalFuture.get()) - threadLocalFuture.remove() - - // shutdown hook is called after all tests have finished. We cannot verify if session has correctly been set in IT. } } test("create table with dummy location should fail with excepted error message") { try { createSession(jobRunId, "") - threadLocalFuture.set(startREPL()) + threadLocalFuture.set(startREPL(100L)) val dummyLocation = "s3://path/to/dummy/location" val testQueryId = "110" diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala index fe2fa5212..048f69ced 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala @@ -21,4 +21,5 @@ case class CommandContext( jobId: String, queryExecutionTimeout: Duration, inactivityLimitMillis: Long, - queryWaitTimeMillis: Long) + queryWaitTimeMillis: Long, + queryLoopExecutionFrequency: Long) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 8cad8844b..e0ab37acd 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -52,6 +52,7 @@ object FlintREPL extends Logging with FlintJobExecutor { private val MAPPING_CHECK_TIMEOUT = Duration(1, MINUTES) private val DEFAULT_QUERY_EXECUTION_TIMEOUT = Duration(30, MINUTES) private val DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS = 10 * 60 * 1000 + private val DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY = 100L val INITIAL_DELAY_MILLIS = 3000L val EARLY_TERMIANTION_CHECK_FREQUENCY = 60000L @@ -134,7 +135,10 @@ object FlintREPL extends Logging with FlintJobExecutor { SECONDS) val queryWaitTimeoutMillis: Long = conf.getLong("spark.flint.job.queryWaitTimeoutMillis", DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS) - + val queryLoopExecutionFrequency: Long = + conf.getLong( + "spark.flint.job.queryLoopExecutionFrequency", + DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex.get) val sessionTimerContext = getTimerContext(MetricConstants.REPL_PROCESSING_TIME_METRIC) @@ -199,7 +203,8 @@ object FlintREPL extends Logging with FlintJobExecutor { jobId, queryExecutionTimeoutSecs, inactivityLimitMillis, - queryWaitTimeoutMillis) + queryWaitTimeoutMillis, + queryLoopExecutionFrequency) exponentialBackoffRetry(maxRetries = 5, initialDelay = 2.seconds) { queryLoop(commandContext) } @@ -342,7 +347,7 @@ object FlintREPL extends Logging with FlintJobExecutor { } def queryLoop(commandContext: CommandContext): Unit = { - // 1 thread for updating heart beat + // 1 thread for async query execution val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-query", 1) implicit val executionContext = ExecutionContext.fromExecutor(threadPool) @@ -392,7 +397,7 @@ object FlintREPL extends Logging with FlintJobExecutor { flintReader.close() } - Thread.sleep(100) + Thread.sleep(commandContext.queryLoopExecutionFrequency) } } finally { if (threadPool != null) { diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index d8ddcb665..3e38e6e35 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -17,14 +17,15 @@ import scala.reflect.runtime.universe.TypeTag import com.amazonaws.services.glue.model.AccessDeniedException import com.codahale.metrics.Timer -import org.mockito.ArgumentMatchersSugar -import org.mockito.Mockito._ +import org.mockito.{ArgumentMatchersSugar, Mockito} +import org.mockito.Mockito.{atLeastOnce, never, times, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.opensearch.action.get.GetResponse import org.opensearch.flint.core.storage.{FlintReader, OpenSearchReader, OpenSearchUpdater} import org.opensearch.flint.data.FlintStatement import org.opensearch.search.sort.SortOrder +import org.scalatest.prop.TableDrivenPropertyChecks._ import org.scalatestplus.mockito.MockitoSugar import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} @@ -599,7 +600,8 @@ class FlintREPLTest jobId, Duration(10, MINUTES), 60, - 60) + 60, + 100) intercept[RuntimeException] { FlintREPL.exponentialBackoffRetry(maxRetries, 2.seconds) { @@ -880,7 +882,8 @@ class FlintREPLTest jobId, Duration(10, MINUTES), shortInactivityLimit, - 60) + 60, + 100) // Mock processCommands to always allow loop continuation val getResponse = mock[GetResponse] @@ -930,7 +933,8 @@ class FlintREPLTest jobId, Duration(10, MINUTES), longInactivityLimit, - 60) + 60, + 100) // Mocking canPickNextStatement to return false when(osClient.getDoc(sessionIndex, sessionId)).thenAnswer(_ => { @@ -986,7 +990,8 @@ class FlintREPLTest jobId, Duration(10, MINUTES), inactivityLimit, - 60) + 60, + 100) try { // Mocking ThreadUtils to track the shutdown call @@ -1036,7 +1041,8 @@ class FlintREPLTest jobId, Duration(10, MINUTES), inactivityLimit, - 60) + 60, + 100) try { // Mocking ThreadUtils to track the shutdown call @@ -1117,7 +1123,8 @@ class FlintREPLTest jobId, Duration(10, MINUTES), inactivityLimit, - 60) + 60, + 100) val startTime = Instant.now().toEpochMilli() @@ -1131,58 +1138,70 @@ class FlintREPLTest verify(osClient, times(1)).getIndexMetadata(*) } - test("queryLoop should execute loop without processing any commands") { - val mockReader = mock[FlintReader] - val osClient = mock[OSClient] - when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) - .thenReturn(mockReader) - val getResponse = mock[GetResponse] - when(osClient.getDoc(*, *)).thenReturn(getResponse) - when(getResponse.isExists()).thenReturn(false) - - // Configure mockReader to always return false, indicating no commands to process - when(mockReader.hasNext).thenReturn(false) + val testCases = Table( + ("inactivityLimit", "queryLoopExecutionFrequency"), + (5000, 100), // 5 seconds, 100 ms + (100, 300) // 100 ms, 300 ms + ) - val resultIndex = "testResultIndex" - val dataSource = "testDataSource" - val sessionIndex = "testSessionIndex" - val sessionId = "testSessionId" - val jobId = "testJobId" + test( + "queryLoop should execute loop without processing any commands for different inactivity limits and frequencies") { + forAll(testCases) { (inactivityLimit, queryLoopExecutionFrequency) => + val mockReader = mock[FlintReader] + val osClient = mock[OSClient] + when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) + .thenReturn(mockReader) + val getResponse = mock[GetResponse] + when(osClient.getDoc(*, *)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(false) + when(mockReader.hasNext).thenReturn(false) + + val resultIndex = "testResultIndex" + val dataSource = "testDataSource" + val sessionIndex = "testSessionIndex" + val sessionId = "testSessionId" + val jobId = "testJobId" + + // Create a SparkSession for testing + val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() - val inactivityLimit = 5000 // 5 seconds + val flintSessionIndexUpdater = mock[OpenSearchUpdater] - // Create a SparkSession for testing - val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() + val commandContext = CommandContext( + spark, + dataSource, + resultIndex, + sessionId, + flintSessionIndexUpdater, + osClient, + sessionIndex, + jobId, + Duration(10, MINUTES), + inactivityLimit, + 60, + queryLoopExecutionFrequency) - val flintSessionIndexUpdater = mock[OpenSearchUpdater] + val startTime = Instant.now().toEpochMilli() - val commandContext = CommandContext( - spark, - dataSource, - resultIndex, - sessionId, - flintSessionIndexUpdater, - osClient, - sessionIndex, - jobId, - Duration(10, MINUTES), - inactivityLimit, - 60) + // Running the queryLoop + FlintREPL.queryLoop(commandContext) - val startTime = Instant.now().toEpochMilli() + val endTime = Instant.now().toEpochMilli() - // Running the queryLoop - FlintREPL.queryLoop(commandContext) + val elapsedTime = endTime - startTime - val endTime = Instant.now().toEpochMilli() + // Assert that the loop ran for at least the duration of the inactivity limit + assert(elapsedTime >= inactivityLimit) - // Assert that the loop ran for at least the duration of the inactivity limit - assert(endTime - startTime >= inactivityLimit) + // Verify query execution frequency + val expectedCalls = Math.ceil(elapsedTime.toDouble / queryLoopExecutionFrequency).toInt + verify(mockReader, Mockito.atMost(expectedCalls)).hasNext - // Verify that no command was actually processed - verify(mockReader, never()).next() + // Verify that no command was actually processed + verify(mockReader, never()).next() - // Stop the SparkSession - spark.stop() + // Stop the SparkSession + spark.stop() + } } } From 4345b16a540a83960f1afcb55aabc2f5af529ed0 Mon Sep 17 00:00:00 2001 From: Louis Chu Date: Tue, 6 Aug 2024 15:42:57 -0700 Subject: [PATCH 2/2] Fix IT and address comments Signed-off-by: Louis Chu --- .../apache/spark/sql/FlintREPLITSuite.scala | 534 ++++++++++-------- .../org/apache/spark/sql/FlintREPL.scala | 23 +- .../org/apache/spark/sql/FlintREPLTest.scala | 17 +- 3 files changed, 330 insertions(+), 244 deletions(-) diff --git a/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala b/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala index 72a58303e..d2a43a877 100644 --- a/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala +++ b/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala @@ -18,9 +18,10 @@ import org.opensearch.flint.OpenSearchSuite import org.opensearch.flint.core.{FlintClient, FlintOptions} import org.opensearch.flint.core.storage.{FlintOpenSearchClient, FlintReader, OpenSearchUpdater} import org.opensearch.flint.data.{FlintStatement, InteractiveSession} -import org.scalatest.prop.TableDrivenPropertyChecks._ +import org.opensearch.search.sort.SortOrder import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.FlintREPLConfConstants.DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY import org.apache.spark.sql.flint.config.FlintSparkConf.{DATA_SOURCE_NAME, EXCLUDE_JOB_IDS, HOST_ENDPOINT, HOST_PORT, JOB_TYPE, REFRESH_POLICY, REPL_INACTIVITY_TIMEOUT_MILLIS, REQUEST_INDEX, SESSION_ID} import org.apache.spark.sql.util.MockEnvironment import org.apache.spark.util.ThreadUtils @@ -130,19 +131,20 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { def createSession(jobId: String, excludeJobId: String): Unit = { val docs = Seq(s"""{ - | "state": "running", - | "lastUpdateTime": 1698796582978, - | "applicationId": "00fd777k3k3ls20p", - | "error": "", - | "sessionId": ${sessionId}, - | "jobId": \"${jobId}\", - | "type": "session", - | "excludeJobIds": [\"${excludeJobId}\"] - |}""".stripMargin) + | "state": "running", + | "lastUpdateTime": 1698796582978, + | "applicationId": "00fd777k3k3ls20p", + | "error": "", + | "sessionId": ${sessionId}, + | "jobId": \"${jobId}\", + | "type": "session", + | "excludeJobIds": [\"${excludeJobId}\"] + |}""".stripMargin) index(requestIndex, oneNodeSetting, requestIndexMapping, docs) } - def startREPL(queryLoopExecutionFrequency: Long): Future[Unit] = { + def startREPL(queryLoopExecutionFrequency: Long = DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) + : Future[Unit] = { val prefix = "flint-repl-test" val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor(prefix, 1) implicit val executionContext = ExecutionContext.fromExecutor(threadPool) @@ -164,7 +166,6 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { System.setProperty(HOST_PORT.key, String.valueOf(openSearchPort)) System.setProperty(REFRESH_POLICY.key, "true") - // Set the query loop execution frequency System.setProperty( "spark.flint.job.queryLoopExecutionFrequency", queryLoopExecutionFrequency.toString) @@ -217,232 +218,220 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { statementId } - // Define the test parameters - val testParams = Table( - ("testName", "queryLoopExecutionFrequency"), - ("Sanity with 100ms frequency", 100), - ("Sanity with 1000ms frequency", 1000)) - - forAll(testParams) { (testName: String, queryLoopExecutionFrequency: Int) => - test(testName) { - try { - createSession(jobRunId, "") - threadLocalFuture.set(startREPL(queryLoopExecutionFrequency)) - - val createStatement = - s""" - | CREATE TABLE $testTable - | ( - | name STRING, - | age INT - | ) - | USING CSV - | OPTIONS ( - | header 'false', - | delimiter '\\t' - | ) - |""".stripMargin - submitQuery(s"${makeJsonCompliant(createStatement)}", "99") - - val insertStatement = - s""" - | INSERT INTO $testTable - | VALUES ('Hello', 30) - | """.stripMargin - submitQuery(s"${makeJsonCompliant(insertStatement)}", "100") - - val selectQueryId = "101" - val selectQueryStartTime = System.currentTimeMillis() - val selectQuery = s"SELECT name, age FROM $testTable".stripMargin - val selectStatementId = submitQuery(s"${makeJsonCompliant(selectQuery)}", selectQueryId) - - val describeStatement = s"DESC $testTable".stripMargin - val descQueryId = "102" - val descStartTime = System.currentTimeMillis() - val descStatementId = submitQuery(s"${makeJsonCompliant(describeStatement)}", descQueryId) - - val showTableStatement = - s"SHOW TABLES IN " + dataSourceName + ".default LIKE 'flint_sql_test'" - val showQueryId = "103" - val showStartTime = System.currentTimeMillis() - val showTableStatementId = - submitQuery(s"${makeJsonCompliant(showTableStatement)}", showQueryId) - - val wrongSelectQueryId = "104" - val wrongSelectQueryStartTime = System.currentTimeMillis() - val wrongSelectQuery = s"SELECT name, age FROM testTable".stripMargin - val wrongSelectStatementId = - submitQuery(s"${makeJsonCompliant(wrongSelectQuery)}", wrongSelectQueryId) - - val lateSelectQueryId = "105" - val lateSelectQuery = s"SELECT name, age FROM $testTable".stripMargin - // submitted from last year. We won't pick it up - val lateSelectStatementId = - submitQuery(s"${makeJsonCompliant(selectQuery)}", selectQueryId, 1672101970000L) - - // clean up - val dropStatement = - s"""DROP TABLE $testTable""".stripMargin - submitQuery(s"${makeJsonCompliant(dropStatement)}", "999") - - val selectQueryValidation: REPLResult => Boolean = result => { - assert( - result.results.size == 1, - s"expected result size is 1, but got ${result.results.size}") - val expectedResult = "{'name':'Hello','age':30}" - assert( - result.results(0).equals(expectedResult), - s"expected result is $expectedResult, but got ${result.results(0)}") - assert( - result.schemas.size == 2, - s"expected schema size is 2, but got ${result.schemas.size}") - val expectedZerothSchema = "{'column_name':'name','data_type':'string'}" - assert( - result.schemas(0).equals(expectedZerothSchema), - s"expected first field is $expectedZerothSchema, but got ${result.schemas(0)}") - val expectedFirstSchema = "{'column_name':'age','data_type':'integer'}" - assert( - result.schemas(1).equals(expectedFirstSchema), - s"expected second field is $expectedFirstSchema, but got ${result.schemas(1)}") - commonValidation(result, selectQueryId, selectQuery, selectQueryStartTime) - successValidation(result) - true - } - pollForResultAndAssert(selectQueryValidation, selectQueryId) + test("sanity") { + try { + createSession(jobRunId, "") + threadLocalFuture.set(startREPL()) + + val createStatement = + s""" + | CREATE TABLE $testTable + | ( + | name STRING, + | age INT + | ) + | USING CSV + | OPTIONS ( + | header 'false', + | delimiter '\\t' + | ) + |""".stripMargin + submitQuery(s"${makeJsonCompliant(createStatement)}", "99") + + val insertStatement = + s""" + | INSERT INTO $testTable + | VALUES ('Hello', 30) + | """.stripMargin + submitQuery(s"${makeJsonCompliant(insertStatement)}", "100") + + val selectQueryId = "101" + val selectQueryStartTime = System.currentTimeMillis() + val selectQuery = s"SELECT name, age FROM $testTable".stripMargin + val selectStatementId = submitQuery(s"${makeJsonCompliant(selectQuery)}", selectQueryId) + + val describeStatement = s"DESC $testTable".stripMargin + val descQueryId = "102" + val descStartTime = System.currentTimeMillis() + val descStatementId = submitQuery(s"${makeJsonCompliant(describeStatement)}", descQueryId) + + val showTableStatement = + s"SHOW TABLES IN " + dataSourceName + ".default LIKE 'flint_sql_test'" + val showQueryId = "103" + val showStartTime = System.currentTimeMillis() + val showTableStatementId = + submitQuery(s"${makeJsonCompliant(showTableStatement)}", showQueryId) + + val wrongSelectQueryId = "104" + val wrongSelectQueryStartTime = System.currentTimeMillis() + val wrongSelectQuery = s"SELECT name, age FROM testTable".stripMargin + val wrongSelectStatementId = + submitQuery(s"${makeJsonCompliant(wrongSelectQuery)}", wrongSelectQueryId) + + val lateSelectQueryId = "105" + val lateSelectQuery = s"SELECT name, age FROM $testTable".stripMargin + // submitted from last year. We won't pick it up + val lateSelectStatementId = + submitQuery(s"${makeJsonCompliant(lateSelectQuery)}", lateSelectQueryId, 1672101970000L) + + // clean up + val dropStatement = + s"""DROP TABLE $testTable""".stripMargin + submitQuery(s"${makeJsonCompliant(dropStatement)}", "999") + + val selectQueryValidation: REPLResult => Boolean = result => { assert( - !awaitConditionForStatementOrTimeout( - statement => { - statement.state == "success" - }, - selectStatementId), - s"Fail to verify for $selectStatementId.") - - val descValidation: REPLResult => Boolean = result => { - assert( - result.results.size == 2, - s"expected result size is 2, but got ${result.results.size}") - val expectedResult0 = "{'col_name':'name','data_type':'string'}" - assert( - result.results(0).equals(expectedResult0), - s"expected result is $expectedResult0, but got ${result.results(0)}") - val expectedResult1 = "{'col_name':'age','data_type':'int'}" - assert( - result.results(1).equals(expectedResult1), - s"expected result is $expectedResult1, but got ${result.results(1)}") - assert( - result.schemas.size == 3, - s"expected schema size is 3, but got ${result.schemas.size}") - val expectedZerothSchema = "{'column_name':'col_name','data_type':'string'}" - assert( - result.schemas(0).equals(expectedZerothSchema), - s"expected first field is $expectedZerothSchema, but got ${result.schemas(0)}") - val expectedFirstSchema = "{'column_name':'data_type','data_type':'string'}" - assert( - result.schemas(1).equals(expectedFirstSchema), - s"expected second field is $expectedFirstSchema, but got ${result.schemas(1)}") - val expectedSecondSchema = "{'column_name':'comment','data_type':'string'}" - assert( - result.schemas(2).equals(expectedSecondSchema), - s"expected third field is $expectedSecondSchema, but got ${result.schemas(2)}") - commonValidation(result, descQueryId, describeStatement, descStartTime) - successValidation(result) - true - } - pollForResultAndAssert(descValidation, descQueryId) + result.results.size == 1, + s"expected result size is 1, but got ${result.results.size}") + val expectedResult = "{'name':'Hello','age':30}" assert( - !awaitConditionForStatementOrTimeout( - statement => { - statement.state == "success" - }, - descStatementId), - s"Fail to verify for $descStatementId.") - - val showValidation: REPLResult => Boolean = result => { - assert( - result.results.size == 1, - s"expected result size is 1, but got ${result.results.size}") - val expectedResult = - "{'namespace':'default','tableName':'flint_sql_test','isTemporary':false}" - assert( - result.results(0).equals(expectedResult), - s"expected result is $expectedResult, but got ${result.results(0)}") - assert( - result.schemas.size == 3, - s"expected schema size is 3, but got ${result.schemas.size}") - val expectedZerothSchema = "{'column_name':'namespace','data_type':'string'}" - assert( - result.schemas(0).equals(expectedZerothSchema), - s"expected first field is $expectedZerothSchema, but got ${result.schemas(0)}") - val expectedFirstSchema = "{'column_name':'tableName','data_type':'string'}" - assert( - result.schemas(1).equals(expectedFirstSchema), - s"expected second field is $expectedFirstSchema, but got ${result.schemas(1)}") - val expectedSecondSchema = "{'column_name':'isTemporary','data_type':'boolean'}" - assert( - result.schemas(2).equals(expectedSecondSchema), - s"expected third field is $expectedSecondSchema, but got ${result.schemas(2)}") - commonValidation(result, showQueryId, showTableStatement, showStartTime) - successValidation(result) - true - } - pollForResultAndAssert(showValidation, showQueryId) + result.results(0).equals(expectedResult), + s"expected result is $expectedResult, but got ${result.results(0)}") assert( - !awaitConditionForStatementOrTimeout( - statement => { - statement.state == "success" - }, - showTableStatementId), - s"Fail to verify for $showTableStatementId.") - - val wrongSelectQueryValidation: REPLResult => Boolean = result => { - assert( - result.results.size == 0, - s"expected result size is 0, but got ${result.results.size}") - assert( - result.schemas.size == 0, - s"expected schema size is 0, but got ${result.schemas.size}") - commonValidation( - result, - wrongSelectQueryId, - wrongSelectQuery, - wrongSelectQueryStartTime) - failureValidation(result) - true - } - pollForResultAndAssert(wrongSelectQueryValidation, wrongSelectQueryId) + result.schemas.size == 2, + s"expected schema size is 2, but got ${result.schemas.size}") + val expectedZerothSchema = "{'column_name':'name','data_type':'string'}" assert( - !awaitConditionForStatementOrTimeout( - statement => { - statement.state == "failed" - }, - wrongSelectStatementId), - s"Fail to verify for $wrongSelectStatementId.") - - // expect time out as this statement should not be picked up + result.schemas(0).equals(expectedZerothSchema), + s"expected first field is $expectedZerothSchema, but got ${result.schemas(0)}") + val expectedFirstSchema = "{'column_name':'age','data_type':'integer'}" assert( - awaitConditionForStatementOrTimeout( - statement => { - statement.state != "waiting" - }, - lateSelectStatementId), - s"Fail to verify for $lateSelectStatementId.") - } catch { - case e: Exception => - logError("Unexpected exception", e) - assert(false, "Unexpected exception") - } finally { - waitREPLStop(threadLocalFuture.get()) - threadLocalFuture.remove() - - // shutdown hook is called after all tests have finished. We cannot verify if session has correctly been set in IT. + result.schemas(1).equals(expectedFirstSchema), + s"expected second field is $expectedFirstSchema, but got ${result.schemas(1)}") + commonValidation(result, selectQueryId, selectQuery, selectQueryStartTime) + successValidation(result) + true + } + pollForResultAndAssert(selectQueryValidation, selectQueryId) + assert( + !awaitConditionForStatementOrTimeout( + statement => { + statement.state == "success" + }, + selectStatementId), + s"Fail to verify for $selectStatementId.") + + val descValidation: REPLResult => Boolean = result => { + assert( + result.results.size == 2, + s"expected result size is 2, but got ${result.results.size}") + val expectedResult0 = "{'col_name':'name','data_type':'string'}" + assert( + result.results(0).equals(expectedResult0), + s"expected result is $expectedResult0, but got ${result.results(0)}") + val expectedResult1 = "{'col_name':'age','data_type':'int'}" + assert( + result.results(1).equals(expectedResult1), + s"expected result is $expectedResult1, but got ${result.results(1)}") + assert( + result.schemas.size == 3, + s"expected schema size is 3, but got ${result.schemas.size}") + val expectedZerothSchema = "{'column_name':'col_name','data_type':'string'}" + assert( + result.schemas(0).equals(expectedZerothSchema), + s"expected first field is $expectedZerothSchema, but got ${result.schemas(0)}") + val expectedFirstSchema = "{'column_name':'data_type','data_type':'string'}" + assert( + result.schemas(1).equals(expectedFirstSchema), + s"expected second field is $expectedFirstSchema, but got ${result.schemas(1)}") + val expectedSecondSchema = "{'column_name':'comment','data_type':'string'}" + assert( + result.schemas(2).equals(expectedSecondSchema), + s"expected third field is $expectedSecondSchema, but got ${result.schemas(2)}") + commonValidation(result, descQueryId, describeStatement, descStartTime) + successValidation(result) + true + } + pollForResultAndAssert(descValidation, descQueryId) + assert( + !awaitConditionForStatementOrTimeout( + statement => { + statement.state == "success" + }, + descStatementId), + s"Fail to verify for $descStatementId.") + + val showValidation: REPLResult => Boolean = result => { + assert( + result.results.size == 1, + s"expected result size is 1, but got ${result.results.size}") + val expectedResult = + "{'namespace':'default','tableName':'flint_sql_test','isTemporary':false}" + assert( + result.results(0).equals(expectedResult), + s"expected result is $expectedResult, but got ${result.results(0)}") + assert( + result.schemas.size == 3, + s"expected schema size is 3, but got ${result.schemas.size}") + val expectedZerothSchema = "{'column_name':'namespace','data_type':'string'}" + assert( + result.schemas(0).equals(expectedZerothSchema), + s"expected first field is $expectedZerothSchema, but got ${result.schemas(0)}") + val expectedFirstSchema = "{'column_name':'tableName','data_type':'string'}" + assert( + result.schemas(1).equals(expectedFirstSchema), + s"expected second field is $expectedFirstSchema, but got ${result.schemas(1)}") + val expectedSecondSchema = "{'column_name':'isTemporary','data_type':'boolean'}" + assert( + result.schemas(2).equals(expectedSecondSchema), + s"expected third field is $expectedSecondSchema, but got ${result.schemas(2)}") + commonValidation(result, showQueryId, showTableStatement, showStartTime) + successValidation(result) + true + } + pollForResultAndAssert(showValidation, showQueryId) + assert( + !awaitConditionForStatementOrTimeout( + statement => { + statement.state == "success" + }, + showTableStatementId), + s"Fail to verify for $showTableStatementId.") + + val wrongSelectQueryValidation: REPLResult => Boolean = result => { + assert( + result.results.size == 0, + s"expected result size is 0, but got ${result.results.size}") + assert( + result.schemas.size == 0, + s"expected schema size is 0, but got ${result.schemas.size}") + commonValidation(result, wrongSelectQueryId, wrongSelectQuery, wrongSelectQueryStartTime) + failureValidation(result) + true } + pollForResultAndAssert(wrongSelectQueryValidation, wrongSelectQueryId) + assert( + !awaitConditionForStatementOrTimeout( + statement => { + statement.state == "failed" + }, + wrongSelectStatementId), + s"Fail to verify for $wrongSelectStatementId.") + + // expect time out as this statement should not be picked up + assert( + awaitConditionForStatementOrTimeout( + statement => { + statement.state != "waiting" + }, + lateSelectStatementId), + s"Fail to verify for $lateSelectStatementId.") + } catch { + case e: Exception => + logError("Unexpected exception", e) + assert(false, "Unexpected exception") + } finally { + waitREPLStop(threadLocalFuture.get()) + threadLocalFuture.remove() + + // shutdown hook is called after all tests have finished. We cannot verify if session has correctly been set in IT. } } test("create table with dummy location should fail with excepted error message") { try { createSession(jobRunId, "") - threadLocalFuture.set(startREPL(100L)) + threadLocalFuture.set(startREPL()) val dummyLocation = "s3://path/to/dummy/location" val testQueryId = "110" @@ -502,6 +491,99 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { } } + test("query loop should exit with inactivity timeout due to large query loop freq") { + try { + createSession(jobRunId, "") + threadLocalFuture.set(startREPL(5000L)) + val createStatement = + s""" + | CREATE TABLE $testTable + | ( + | name STRING, + | age INT + | ) + | USING CSV + | OPTIONS ( + | header 'false', + | delimiter '\\t' + | ) + |""".stripMargin + submitQuery(s"${makeJsonCompliant(createStatement)}", "119") + + val insertStatement = + s""" + | INSERT INTO $testTable + | VALUES ('Hello', 30) + | """.stripMargin + submitQuery(s"${makeJsonCompliant(insertStatement)}", "120") + + val selectQueryId = "121" + val selectQueryStartTime = System.currentTimeMillis() + val selectQuery = s"SELECT name, age FROM $testTable".stripMargin + val selectStatementId = submitQuery(s"${makeJsonCompliant(selectQuery)}", selectQueryId) + + val lateSelectQueryId = "122" + val lateSelectQuery = s"SELECT name, age FROM $testTable".stripMargin + // old query + val lateSelectStatementId = + submitQuery(s"${makeJsonCompliant(lateSelectQuery)}", lateSelectQueryId, 1672101970000L) + + // clean up + val dropStatement = + s"""DROP TABLE $testTable""".stripMargin + submitQuery(s"${makeJsonCompliant(dropStatement)}", "999") + + val selectQueryValidation: REPLResult => Boolean = result => { + assert( + result.results.size == 1, + s"expected result size is 1, but got ${result.results.size}") + val expectedResult = "{'name':'Hello','age':30}" + assert( + result.results(0).equals(expectedResult), + s"expected result is $expectedResult, but got ${result.results(0)}") + assert( + result.schemas.size == 2, + s"expected schema size is 2, but got ${result.schemas.size}") + val expectedZerothSchema = "{'column_name':'name','data_type':'string'}" + assert( + result.schemas(0).equals(expectedZerothSchema), + s"expected first field is $expectedZerothSchema, but got ${result.schemas(0)}") + val expectedFirstSchema = "{'column_name':'age','data_type':'integer'}" + assert( + result.schemas(1).equals(expectedFirstSchema), + s"expected second field is $expectedFirstSchema, but got ${result.schemas(1)}") + commonValidation(result, selectQueryId, selectQuery, selectQueryStartTime) + successValidation(result) + true + } + pollForResultAndAssert(selectQueryValidation, selectQueryId) + assert( + !awaitConditionForStatementOrTimeout( + statement => { + statement.state == "success" + }, + selectStatementId), + s"Fail to verify for $selectStatementId.") + + assert( + awaitConditionForStatementOrTimeout( + statement => { + statement.state != "waiting" + }, + lateSelectStatementId), + s"Fail to verify for $lateSelectStatementId.") + } catch { + case e: Exception => + logError("Unexpected exception", e) + assert(false, "Unexpected exception") + } finally { + waitREPLStop(threadLocalFuture.get()) + threadLocalFuture.remove() + + // shutdown hook is called after all tests have finished. We cannot verify if session has correctly been set in IT. + } + } + /** * JSON does not support raw newlines (\n) in string values. All newlines must be escaped or * removed when inside a JSON string. The same goes for tab characters, which should be diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index e0ab37acd..782dd04c2 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -30,9 +30,20 @@ import org.opensearch.search.sort.SortOrder import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} +import org.apache.spark.sql.FlintREPLConfConstants._ import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.util.ThreadUtils +object FlintREPLConfConstants { + val HEARTBEAT_INTERVAL_MILLIS = 60000L + val MAPPING_CHECK_TIMEOUT = Duration(1, MINUTES) + val DEFAULT_QUERY_EXECUTION_TIMEOUT = Duration(30, MINUTES) + val DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS = 10 * 60 * 1000 + val DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY = 100L + val INITIAL_DELAY_MILLIS = 3000L + val EARLY_TERMINATION_CHECK_FREQUENCY = 60000L +} + /** * Spark SQL Application entrypoint * @@ -48,14 +59,6 @@ import org.apache.spark.util.ThreadUtils */ object FlintREPL extends Logging with FlintJobExecutor { - private val HEARTBEAT_INTERVAL_MILLIS = 60000L - private val MAPPING_CHECK_TIMEOUT = Duration(1, MINUTES) - private val DEFAULT_QUERY_EXECUTION_TIMEOUT = Duration(30, MINUTES) - private val DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS = 10 * 60 * 1000 - private val DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY = 100L - val INITIAL_DELAY_MILLIS = 3000L - val EARLY_TERMIANTION_CHECK_FREQUENCY = 60000L - @volatile var earlyExitFlag: Boolean = false def updateSessionIndex(flintStatement: FlintStatement, updater: OpenSearchUpdater): Unit = { @@ -560,8 +563,8 @@ object FlintREPL extends Logging with FlintJobExecutor { while (canProceed) { val currentTime = currentTimeProvider.currentEpochMillis() - // Only call canPickNextStatement if EARLY_TERMIANTION_CHECK_FREQUENCY milliseconds have passed - if (currentTime - lastCanPickCheckTime > EARLY_TERMIANTION_CHECK_FREQUENCY) { + // Only call canPickNextStatement if EARLY_TERMINATION_CHECK_FREQUENCY milliseconds have passed + if (currentTime - lastCanPickCheckTime > EARLY_TERMINATION_CHECK_FREQUENCY) { canPickNextStatementResult = canPickNextStatement(sessionId, jobId, osClient, sessionIndex) lastCanPickCheckTime = currentTime diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index 3e38e6e35..ef5db02dc 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -31,6 +31,7 @@ import org.scalatestplus.mockito.MockitoSugar import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.scheduler.SparkListenerApplicationEnd import org.apache.spark.sql.FlintREPL.PreShutdownListener +import org.apache.spark.sql.FlintREPLConfConstants.DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY import org.apache.spark.sql.SparkConfConstants.{DEFAULT_SQL_EXTENSIONS, SQL_EXTENSIONS_KEY} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.trees.Origin @@ -601,7 +602,7 @@ class FlintREPLTest Duration(10, MINUTES), 60, 60, - 100) + DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) intercept[RuntimeException] { FlintREPL.exponentialBackoffRetry(maxRetries, 2.seconds) { @@ -883,7 +884,7 @@ class FlintREPLTest Duration(10, MINUTES), shortInactivityLimit, 60, - 100) + DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) // Mock processCommands to always allow loop continuation val getResponse = mock[GetResponse] @@ -934,7 +935,7 @@ class FlintREPLTest Duration(10, MINUTES), longInactivityLimit, 60, - 100) + DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) // Mocking canPickNextStatement to return false when(osClient.getDoc(sessionIndex, sessionId)).thenAnswer(_ => { @@ -991,7 +992,7 @@ class FlintREPLTest Duration(10, MINUTES), inactivityLimit, 60, - 100) + DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) try { // Mocking ThreadUtils to track the shutdown call @@ -1042,7 +1043,7 @@ class FlintREPLTest Duration(10, MINUTES), inactivityLimit, 60, - 100) + DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) try { // Mocking ThreadUtils to track the shutdown call @@ -1124,7 +1125,7 @@ class FlintREPLTest Duration(10, MINUTES), inactivityLimit, 60, - 100) + DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) val startTime = Instant.now().toEpochMilli() @@ -1140,8 +1141,8 @@ class FlintREPLTest val testCases = Table( ("inactivityLimit", "queryLoopExecutionFrequency"), - (5000, 100), // 5 seconds, 100 ms - (100, 300) // 100 ms, 300 ms + (5000, 100L), // 5 seconds, 100 ms + (100, 300L) // 100 ms, 300 ms ) test(