From 5be9be603fe8f49eb40eca7dc024016e2e3711f2 Mon Sep 17 00:00:00 2001
From: Louis Chu
Date: Tue, 6 Aug 2024 18:08:01 -0700
Subject: [PATCH] Add config on query loop execution frequency (#411)

* Add config on query loop execution frequency

Signed-off-by: Louis Chu

* Fix IT and address comments

Signed-off-by: Louis Chu

---------

Signed-off-by: Louis Chu
---
 .../apache/spark/sql/FlintREPLITSuite.scala   | 121 ++++++++++++++++--
 .../org/apache/spark/sql/CommandContext.scala |   3 +-
 .../org/apache/spark/sql/FlintREPL.scala      |  34 +++--
 .../org/apache/spark/sql/FlintREPLTest.scala  | 120 +++++++++--------
 4 files changed, 203 insertions(+), 75 deletions(-)

diff --git a/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala b/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala
index 921db792a..d2a43a877 100644
--- a/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala
+++ b/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala
@@ -21,6 +21,7 @@ import org.opensearch.flint.data.{FlintStatement, InteractiveSession}
 import org.opensearch.search.sort.SortOrder
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.FlintREPLConfConstants.DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY
 import org.apache.spark.sql.flint.config.FlintSparkConf.{DATA_SOURCE_NAME, EXCLUDE_JOB_IDS, HOST_ENDPOINT, HOST_PORT, JOB_TYPE, REFRESH_POLICY, REPL_INACTIVITY_TIMEOUT_MILLIS, REQUEST_INDEX, SESSION_ID}
 import org.apache.spark.sql.util.MockEnvironment
 import org.apache.spark.util.ThreadUtils
@@ -130,19 +131,20 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest {
 
   def createSession(jobId: String, excludeJobId: String): Unit = {
     val docs = Seq(s"""{
-        | "state": "running",
-        | "lastUpdateTime": 1698796582978,
-        | "applicationId": "00fd777k3k3ls20p",
-        | "error": "",
-        | "sessionId": ${sessionId},
-        | "jobId": \"${jobId}\",
-        | "type": "session",
-        | "excludeJobIds": [\"${excludeJobId}\"]
-        |}""".stripMargin)
+           | "state": "running",
+           | "lastUpdateTime": 1698796582978,
+           | "applicationId": "00fd777k3k3ls20p",
+           | "error": "",
+           | "sessionId": ${sessionId},
+           | "jobId": \"${jobId}\",
+           | "type": "session",
+           | "excludeJobIds": [\"${excludeJobId}\"]
+           |}""".stripMargin)
     index(requestIndex, oneNodeSetting, requestIndexMapping, docs)
   }
 
-  def startREPL(): Future[Unit] = {
+  def startREPL(queryLoopExecutionFrequency: Long = DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY)
+      : Future[Unit] = {
     val prefix = "flint-repl-test"
     val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor(prefix, 1)
     implicit val executionContext = ExecutionContext.fromExecutor(threadPool)
@@ -164,6 +166,10 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest {
     System.setProperty(HOST_PORT.key, String.valueOf(openSearchPort))
     System.setProperty(REFRESH_POLICY.key, "true")
 
+    System.setProperty(
+      "spark.flint.job.queryLoopExecutionFrequency",
+      queryLoopExecutionFrequency.toString)
+
     FlintREPL.envinromentProvider = new MockEnvironment(
       Map("SERVERLESS_EMR_JOB_ID" -> jobRunId, "SERVERLESS_EMR_VIRTUAL_CLUSTER_ID" -> appId))
     FlintREPL.enableHiveSupport = false
@@ -266,7 +272,7 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest {
     val lateSelectQuery = s"SELECT name, age FROM $testTable".stripMargin
     // submitted from last year. We won't pick it up
     val lateSelectStatementId =
-      submitQuery(s"${makeJsonCompliant(selectQuery)}", selectQueryId, 1672101970000L)
+      submitQuery(s"${makeJsonCompliant(lateSelectQuery)}", lateSelectQueryId, 1672101970000L)
 
     // clean up
     val dropStatement =
       s"""DROP TABLE $testTable""".stripMargin
     submitQuery(s"${makeJsonCompliant(dropStatement)}", "999")
@@ -485,6 +491,99 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest {
     }
   }
 
+  test("query loop should exit with inactivity timeout due to large query loop execution frequency") {
+    try {
+      createSession(jobRunId, "")
+      threadLocalFuture.set(startREPL(5000L))
+      val createStatement =
+        s"""
+           | CREATE TABLE $testTable
+           | (
+           |   name STRING,
+           |   age INT
+           | )
+           | USING CSV
+           | OPTIONS (
+           |  header 'false',
+           |  delimiter '\\t'
+           | )
+           |""".stripMargin
+      submitQuery(s"${makeJsonCompliant(createStatement)}", "119")
+
+      val insertStatement =
+        s"""
+           | INSERT INTO $testTable
+           | VALUES ('Hello', 30)
+           | """.stripMargin
+      submitQuery(s"${makeJsonCompliant(insertStatement)}", "120")
+
+      val selectQueryId = "121"
+      val selectQueryStartTime = System.currentTimeMillis()
+      val selectQuery = s"SELECT name, age FROM $testTable".stripMargin
+      val selectStatementId = submitQuery(s"${makeJsonCompliant(selectQuery)}", selectQueryId)
+
+      val lateSelectQueryId = "122"
+      val lateSelectQuery = s"SELECT name, age FROM $testTable".stripMargin
+      // stale query submitted last year; should not be picked up
+      val lateSelectStatementId =
+        submitQuery(s"${makeJsonCompliant(lateSelectQuery)}", lateSelectQueryId, 1672101970000L)
+
+      // clean up
+      val dropStatement =
+        s"""DROP TABLE $testTable""".stripMargin
+      submitQuery(s"${makeJsonCompliant(dropStatement)}", "999")
+
+      val selectQueryValidation: REPLResult => Boolean = result => {
+        assert(
+          result.results.size == 1,
+          s"expected result size is 1, but got ${result.results.size}")
+        val expectedResult = "{'name':'Hello','age':30}"
+        assert(
+          result.results(0).equals(expectedResult),
+          s"expected result is $expectedResult, but got ${result.results(0)}")
+        assert(
+          result.schemas.size == 2,
+          s"expected schema size is 2, but got ${result.schemas.size}")
+        val expectedZerothSchema = "{'column_name':'name','data_type':'string'}"
+        assert(
+          result.schemas(0).equals(expectedZerothSchema),
+          s"expected first field is $expectedZerothSchema, but got ${result.schemas(0)}")
+        val expectedFirstSchema = "{'column_name':'age','data_type':'integer'}"
+        assert(
+          result.schemas(1).equals(expectedFirstSchema),
+          s"expected second field is $expectedFirstSchema, but got ${result.schemas(1)}")
+        commonValidation(result, selectQueryId, selectQuery, selectQueryStartTime)
+        successValidation(result)
+        true
+      }
+      pollForResultAndAssert(selectQueryValidation, selectQueryId)
+      assert(
+        !awaitConditionForStatementOrTimeout(
+          statement => {
+            statement.state == "success"
+          },
+          selectStatementId),
+        s"Failed to verify state for $selectStatementId.")
+
+      assert(
+        awaitConditionForStatementOrTimeout(
+          statement => {
+            statement.state != "waiting"
+          },
+          lateSelectStatementId),
+        s"Failed to verify state for $lateSelectStatementId.")
+    } catch {
+      case e: Exception =>
+        logError("Unexpected exception", e)
+        assert(false, "Unexpected exception")
+    } finally {
+      waitREPLStop(threadLocalFuture.get())
+      threadLocalFuture.remove()
+
+      // The shutdown hook runs after all tests have finished, so we cannot verify in this IT whether the session state was updated correctly.
+    }
+  }
+
   /**
    * JSON does not support raw newlines (\n) in string values. All newlines must be escaped or
    * removed when inside a JSON string. The same goes for tab characters, which should be
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala
index fe2fa5212..048f69ced 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala
@@ -21,4 +21,5 @@ case class CommandContext(
     jobId: String,
     queryExecutionTimeout: Duration,
     inactivityLimitMillis: Long,
-    queryWaitTimeMillis: Long)
+    queryWaitTimeMillis: Long,
+    queryLoopExecutionFrequency: Long)
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
index 8cad8844b..782dd04c2 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
@@ -30,9 +30,20 @@ import org.opensearch.search.sort.SortOrder
 import org.apache.spark.SparkConf
 import org.apache.spark.internal.Logging
 import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
+import org.apache.spark.sql.FlintREPLConfConstants._
 import org.apache.spark.sql.flint.config.FlintSparkConf
 import org.apache.spark.util.ThreadUtils
 
+object FlintREPLConfConstants {
+  val HEARTBEAT_INTERVAL_MILLIS = 60000L
+  val MAPPING_CHECK_TIMEOUT = Duration(1, MINUTES)
+  val DEFAULT_QUERY_EXECUTION_TIMEOUT = Duration(30, MINUTES)
+  val DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS = 10 * 60 * 1000
+  val DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY = 100L
+  val INITIAL_DELAY_MILLIS = 3000L
+  val EARLY_TERMINATION_CHECK_FREQUENCY = 60000L
+}
+
 /**
  * Spark SQL Application entrypoint
  *
@@ -48,13 +59,6 @@ import org.apache.spark.util.ThreadUtils
  */
 object FlintREPL extends Logging with FlintJobExecutor {
 
-  private val HEARTBEAT_INTERVAL_MILLIS = 60000L
-  private val MAPPING_CHECK_TIMEOUT = Duration(1, MINUTES)
-  private val DEFAULT_QUERY_EXECUTION_TIMEOUT = Duration(30, MINUTES)
-  private val DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS = 10 * 60 * 1000
-  val INITIAL_DELAY_MILLIS = 3000L
-  val EARLY_TERMIANTION_CHECK_FREQUENCY = 60000L
-
   @volatile var earlyExitFlag: Boolean = false
 
   def updateSessionIndex(flintStatement: FlintStatement, updater: OpenSearchUpdater): Unit = {
@@ -134,7 +138,10 @@ object FlintREPL extends Logging with FlintJobExecutor {
         SECONDS)
       val queryWaitTimeoutMillis: Long =
         conf.getLong("spark.flint.job.queryWaitTimeoutMillis", DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS)
-
+      val queryLoopExecutionFrequency: Long =
+        conf.getLong(
+          "spark.flint.job.queryLoopExecutionFrequency",
+          DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY)
       val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex.get)
       val sessionTimerContext = getTimerContext(MetricConstants.REPL_PROCESSING_TIME_METRIC)
 
@@ -199,7 +206,8 @@ object FlintREPL extends Logging with FlintJobExecutor {
           jobId,
           queryExecutionTimeoutSecs,
           inactivityLimitMillis,
-          queryWaitTimeoutMillis)
+          queryWaitTimeoutMillis,
+          queryLoopExecutionFrequency)
         exponentialBackoffRetry(maxRetries = 5, initialDelay = 2.seconds) {
           queryLoop(commandContext)
         }
@@ -342,7 +350,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
   }
 
   def queryLoop(commandContext: CommandContext): Unit = {
-    // 1 thread for updating heart beat
+    // 1 thread for async query execution
     val threadPool =
       threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-query", 1)
     implicit val executionContext = ExecutionContext.fromExecutor(threadPool)
@@ -392,7 +400,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
           flintReader.close()
         }
 
-        Thread.sleep(100)
+        Thread.sleep(commandContext.queryLoopExecutionFrequency)
       }
     } finally {
       if (threadPool != null) {
@@ -555,8 +563,8 @@ object FlintREPL extends Logging with FlintJobExecutor {
     while (canProceed) {
      val currentTime = currentTimeProvider.currentEpochMillis()
 
-      // Only call canPickNextStatement if EARLY_TERMIANTION_CHECK_FREQUENCY milliseconds have passed
-      if (currentTime - lastCanPickCheckTime > EARLY_TERMIANTION_CHECK_FREQUENCY) {
+      // Only call canPickNextStatement if EARLY_TERMINATION_CHECK_FREQUENCY milliseconds have passed
+      if (currentTime - lastCanPickCheckTime > EARLY_TERMINATION_CHECK_FREQUENCY) {
        canPickNextStatementResult =
          canPickNextStatement(sessionId, jobId, osClient, sessionIndex)
        lastCanPickCheckTime = currentTime
diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
index d8ddcb665..ef5db02dc 100644
--- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
+++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
@@ -17,19 +17,21 @@ import scala.reflect.runtime.universe.TypeTag
 
 import com.amazonaws.services.glue.model.AccessDeniedException
 import com.codahale.metrics.Timer
-import org.mockito.ArgumentMatchersSugar
-import org.mockito.Mockito._
+import org.mockito.{ArgumentMatchersSugar, Mockito}
+import org.mockito.Mockito.{atLeastOnce, never, times, verify, when}
 import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
 import org.opensearch.action.get.GetResponse
 import org.opensearch.flint.core.storage.{FlintReader, OpenSearchReader, OpenSearchUpdater}
 import org.opensearch.flint.data.FlintStatement
 import org.opensearch.search.sort.SortOrder
+import org.scalatest.prop.TableDrivenPropertyChecks._
 import org.scalatestplus.mockito.MockitoSugar
 
 import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
 import org.apache.spark.scheduler.SparkListenerApplicationEnd
 import org.apache.spark.sql.FlintREPL.PreShutdownListener
+import org.apache.spark.sql.FlintREPLConfConstants.DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY
 import org.apache.spark.sql.SparkConfConstants.{DEFAULT_SQL_EXTENSIONS, SQL_EXTENSIONS_KEY}
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.trees.Origin
@@ -599,7 +601,8 @@ class FlintREPLTest
         jobId,
         Duration(10, MINUTES),
         60,
-        60)
+        60,
+        DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY)
 
     intercept[RuntimeException] {
       FlintREPL.exponentialBackoffRetry(maxRetries, 2.seconds) {
@@ -880,7 +883,8 @@ class FlintREPLTest
         jobId,
         Duration(10, MINUTES),
         shortInactivityLimit,
-        60)
+        60,
+        DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY)
 
       // Mock processCommands to always allow loop continuation
       val getResponse = mock[GetResponse]
@@ -930,7 +934,8 @@ class FlintREPLTest
         jobId,
         Duration(10, MINUTES),
         longInactivityLimit,
-        60)
+        60,
+        DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY)
 
       // Mocking canPickNextStatement to return false
       when(osClient.getDoc(sessionIndex, sessionId)).thenAnswer(_ => {
@@ -986,7 +991,8 @@ class FlintREPLTest
         jobId,
         Duration(10, MINUTES),
         inactivityLimit,
-        60)
+        60,
+        DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY)
 
       try {
         // Mocking ThreadUtils to track the shutdown call
@@ -1036,7 +1042,8 @@ class FlintREPLTest
         jobId,
         Duration(10, MINUTES),
         inactivityLimit,
-        60)
+        60,
+        DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY)
 
       try {
         // Mocking ThreadUtils to track the shutdown call
@@ -1117,7 +1124,8 @@ class FlintREPLTest
         jobId,
         Duration(10, MINUTES),
         inactivityLimit,
-        60)
+        60,
+        DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY)
 
     val startTime = Instant.now().toEpochMilli()
 
@@ -1131,58 +1139,70 @@ class FlintREPLTest
     verify(osClient, times(1)).getIndexMetadata(*)
   }
 
-  test("queryLoop should execute loop without processing any commands") {
-    val mockReader = mock[FlintReader]
-    val osClient = mock[OSClient]
-    when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC)))
-      .thenReturn(mockReader)
-    val getResponse = mock[GetResponse]
-    when(osClient.getDoc(*, *)).thenReturn(getResponse)
-    when(getResponse.isExists()).thenReturn(false)
-
-    // Configure mockReader to always return false, indicating no commands to process
-    when(mockReader.hasNext).thenReturn(false)
+  val testCases = Table(
+    ("inactivityLimit", "queryLoopExecutionFrequency"),
+    (5000, 100L), // 5 seconds, 100 ms
+    (100, 300L) // 100 ms, 300 ms
+  )
 
-    val resultIndex = "testResultIndex"
-    val dataSource = "testDataSource"
-    val sessionIndex = "testSessionIndex"
-    val sessionId = "testSessionId"
-    val jobId = "testJobId"
+  test(
+    "queryLoop should execute loop without processing any commands for different inactivity limits and frequencies") {
+    forAll(testCases) { (inactivityLimit, queryLoopExecutionFrequency) =>
+      val mockReader = mock[FlintReader]
+      val osClient = mock[OSClient]
+      when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC)))
+        .thenReturn(mockReader)
+      val getResponse = mock[GetResponse]
+      when(osClient.getDoc(*, *)).thenReturn(getResponse)
+      when(getResponse.isExists()).thenReturn(false)
+      when(mockReader.hasNext).thenReturn(false)
+
+      val resultIndex = "testResultIndex"
+      val dataSource = "testDataSource"
+      val sessionIndex = "testSessionIndex"
+      val sessionId = "testSessionId"
+      val jobId = "testJobId"
+
+      // Create a SparkSession for testing
+      val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate()
 
-    val inactivityLimit = 5000 // 5 seconds
+      val flintSessionIndexUpdater = mock[OpenSearchUpdater]
 
-    // Create a SparkSession for testing
-    val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate()
+      val commandContext = CommandContext(
+        spark,
+        dataSource,
+        resultIndex,
+        sessionId,
+        flintSessionIndexUpdater,
+        osClient,
+        sessionIndex,
+        jobId,
+        Duration(10, MINUTES),
+        inactivityLimit,
+        60,
+        queryLoopExecutionFrequency)
 
-    val flintSessionIndexUpdater = mock[OpenSearchUpdater]
+      val startTime = Instant.now().toEpochMilli()
 
-    val commandContext = CommandContext(
-      spark,
-      dataSource,
-      resultIndex,
-      sessionId,
-      flintSessionIndexUpdater,
-      osClient,
-      sessionIndex,
-      jobId,
-      Duration(10, MINUTES),
-      inactivityLimit,
-      60)
+      // Running the queryLoop
+      FlintREPL.queryLoop(commandContext)
 
-    val startTime = Instant.now().toEpochMilli()
+      val endTime = Instant.now().toEpochMilli()
 
-    // Running the queryLoop
-    FlintREPL.queryLoop(commandContext)
+      val elapsedTime = endTime - startTime
 
-    val endTime = Instant.now().toEpochMilli()
+      // Assert that the loop ran for at least the duration of the inactivity limit
+      assert(elapsedTime >= inactivityLimit)
 
-    // Assert that the loop ran for at least the duration of the inactivity limit
-    assert(endTime - startTime >= inactivityLimit)
+      // Verify query execution frequency
+      val expectedCalls = Math.ceil(elapsedTime.toDouble / queryLoopExecutionFrequency).toInt
+      verify(mockReader, Mockito.atMost(expectedCalls)).hasNext
 
-    // Verify that no command was actually processed
-    verify(mockReader, never()).next()
+      // Verify that no command was actually processed
+      verify(mockReader, never()).next()
 
-    // Stop the SparkSession
-    spark.stop()
+      // Stop the SparkSession
+      spark.stop()
+    }
   }
 }
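
Notes on the change follow; none of the code below is part of the patch itself. The central addition is the spark.flint.job.queryLoopExecutionFrequency setting, which FlintREPL resolves through SparkConf.getLong with DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY (100 ms) as the fallback. A minimal sketch of that lookup pattern; the QueryLoopConfSketch object and the inline set call are illustrative only:

    import org.apache.spark.SparkConf

    object QueryLoopConfSketch {
      // Mirrors DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY in FlintREPLConfConstants.
      val DefaultQueryLoopExecutionFrequency = 100L

      def main(args: Array[String]): Unit = {
        // In the real job the key arrives from the submission environment; it is
        // set inline here only so the sketch runs standalone.
        val conf = new SparkConf()
          .set("spark.flint.job.queryLoopExecutionFrequency", "5000")

        // Same pattern as the FlintREPL hunk: getLong with a default fallback.
        val queryLoopExecutionFrequency = conf.getLong(
          "spark.flint.job.queryLoopExecutionFrequency",
          DefaultQueryLoopExecutionFrequency)

        println(s"query loop sleeps $queryLoopExecutionFrequency ms between iterations")
      }
    }

The IT suite sets the same key with System.setProperty before FlintREPL.main runs, which works because a default-constructed SparkConf picks up spark.* JVM system properties.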
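
The behavioral change itself is small: queryLoop used to sleep a hard-coded 100 ms between polling iterations and now sleeps commandContext.queryLoopExecutionFrequency, which CommandContext carries down from main. A simplified sketch of the loop's shape, with hypothetical hasPendingStatement/processStatement stand-ins for the real FlintReader-backed logic (heartbeat and early-termination checks omitted):

    object QueryLoopShapeSketch {
      // Skeleton only: poll for work, process it, sleep, and exit once no work
      // has arrived within the inactivity limit.
      def pollLoop(
          inactivityLimitMillis: Long,
          queryLoopExecutionFrequency: Long)(
          hasPendingStatement: () => Boolean,
          processStatement: () => Unit): Unit = {
        var lastActivityTime = System.currentTimeMillis()
        var canProceed = true
        while (canProceed) {
          if (hasPendingStatement()) {
            processStatement()
            lastActivityTime = System.currentTimeMillis()
          } else if (System.currentTimeMillis() - lastActivityTime > inactivityLimitMillis) {
            canProceed = false // inactivity timeout, the exit path the new IT exercises
          }
          if (canProceed) {
            // The knob this patch introduces; previously a hard-coded Thread.sleep(100).
            Thread.sleep(queryLoopExecutionFrequency)
          }
        }
      }

      def main(args: Array[String]): Unit = {
        // Demo: no work ever arrives, so the loop exits after roughly 1 s of inactivity.
        pollLoop(1000L, 100L)(() => false, () => ())
        println("loop exited on inactivity")
      }
    }

With the 5000 ms value the new IT passes to startREPL, the loop wakes only a handful of times before the inactivity limit trips, which is why that test expects the session to end via the inactivity timeout rather than by picking up the stale 2023-timestamped statement.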
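
The reworked unit test replaces a single-scenario test with ScalaTest's table-driven property checks, hence the new org.scalatest.prop.TableDrivenPropertyChecks._ import. A standalone example of the pattern, assuming ScalaTest 3.x on the classpath and a plain AnyFunSuite base rather than the suite's actual parents:

    import org.scalatest.funsuite.AnyFunSuite
    import org.scalatest.prop.TableDrivenPropertyChecks._

    class TableDrivenSketch extends AnyFunSuite {
      // Each row is one scenario; the heading names the columns for failure messages.
      private val cases = Table(
        ("inactivityLimit", "queryLoopExecutionFrequency"),
        (5000, 100L), // 5 s limit, 100 ms polling
        (100, 300L)) // 100 ms limit, 300 ms polling

      test("runs the body once per row") {
        forAll(cases) { (inactivityLimit, queryLoopExecutionFrequency) =>
          // A real test would drive queryLoop here; this only checks the row shape.
          assert(inactivityLimit > 0 && queryLoopExecutionFrequency > 0)
        }
      }
    }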
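
On the frequency assertion: assuming one hasNext poll per iteration followed by one sleep of queryLoopExecutionFrequency, the iteration count over an elapsed window is bounded above by ceil(elapsed / frequency), which is why the test verifies with Mockito.atMost(expectedCalls) rather than an exact times(n). A worked example of the arithmetic (the numbers are illustrative):

    object ExpectedCallsSketch extends App {
      val elapsedTime = 5230L // measured loop duration in ms (example value)
      val queryLoopExecutionFrequency = 300L // sleep per iteration, ms

      // Upper bound on iterations, computed exactly as in the test above.
      val expectedCalls =
        Math.ceil(elapsedTime.toDouble / queryLoopExecutionFrequency).toInt

      assert(expectedCalls == 18) // 5230 / 300 = 17.43..., rounded up
      println(s"hasNext can be invoked at most $expectedCalls times")
    }

An exact count would be flaky, since scheduling jitter changes how many iterations fit in the window; the upper bound keeps the assertion stable.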
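
Finally, a hypothetical launcher fragment showing one way an operator might supply the knob at submission time; whether a deployment sets it via --conf, a JVM system property (as the IT does), or session config is an assumption, not something this patch prescribes:

    import org.apache.spark.sql.SparkSession

    object LauncherSketch {
      def main(args: Array[String]): Unit = {
        // Hypothetical: a 1 s polling interval supplied at session build time.
        val spark = SparkSession
          .builder()
          .master("local")
          .appName("flint-repl-launcher-sketch")
          .config("spark.flint.job.queryLoopExecutionFrequency", "1000")
          .getOrCreate()
        println(spark.conf.get("spark.flint.job.queryLoopExecutionFrequency"))
        spark.stop()
      }
    }

A longer interval trades statement pickup latency for fewer polling wake-ups; the 100 ms default preserves the pre-patch behavior exactly.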