diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
index cdeebe663..51a9373d8 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
@@ -314,7 +314,13 @@ object FlintREPL extends Logging with FlintJobExecutor {
     val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-query", 1)
     implicit val executionContext = ExecutionContext.fromExecutor(threadPool)
     val queryResultWriter = instantiateQueryResultWriter(spark, commandContext)
-    var futurePrepareQueryExecution: Future[Either[String, Unit]] = null
+
+    val statementsExecutionManager =
+      instantiateStatementExecutionManager(commandContext)
+
+    var futurePrepareQueryExecution: Future[Either[String, Unit]] = Future {
+      statementsExecutionManager.prepareStatementExecution()
+    }
 
     try {
       logInfo(s"""Executing session with sessionId: ${sessionId}""")
@@ -324,12 +330,6 @@ object FlintREPL extends Logging with FlintJobExecutor {
       var lastCanPickCheckTime = 0L
       while (currentTimeProvider
           .currentEpochMillis() - lastActivityTime <= commandContext.inactivityLimitMillis && canPickUpNextStatement) {
-        val statementsExecutionManager =
-          instantiateStatementExecutionManager(commandContext)
-
-        futurePrepareQueryExecution = Future {
-          statementsExecutionManager.prepareStatementExecution()
-        }
 
         try {
           val commandState = CommandState(
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala
index 432d6df11..09a1b3c1e 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala
@@ -6,7 +6,7 @@
 package org.apache.spark.sql
 
 import org.opensearch.flint.common.model.FlintStatement
-import org.opensearch.flint.core.storage.OpenSearchUpdater
+import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater}
 import org.opensearch.search.sort.SortOrder
 
 import org.apache.spark.internal.Logging
@@ -29,8 +29,8 @@ class StatementExecutionManagerImpl(commandContext: CommandContext)
     context("flintSessionIndexUpdater").asInstanceOf[OpenSearchUpdater]
 
   // Using one reader client within same session will cause concurrency issue.
-  // To resolve this move the reader creation and getNextStatement method to mirco-batch level
-  private val flintReader = createOpenSearchQueryReader()
+  // To resolve this, move the reader creation to the getNextStatement method at the micro-batch level.
+  private var currentReader: Option[FlintReader] = None
 
   override def prepareStatementExecution(): Either[String, Unit] = {
     checkAndCreateIndex(osClient, resultIndex)
@@ -39,12 +39,17 @@ class StatementExecutionManagerImpl(commandContext: CommandContext)
     flintSessionIndexUpdater.update(statement.statementId, FlintStatement.serialize(statement))
   }
 
   override def terminateStatementExecution(): Unit = {
-    flintReader.close()
+    currentReader.foreach(_.close())
+    currentReader = None
   }
 
   override def getNextStatement(): Option[FlintStatement] = {
-    if (flintReader.hasNext) {
-      val rawStatement = flintReader.next()
+    if (currentReader.isEmpty) {
+      currentReader = Some(createOpenSearchQueryReader())
+    }
+
+    if (currentReader.get.hasNext) {
+      val rawStatement = currentReader.get.next()
       val flintStatement = FlintStatement.deserialize(rawStatement)
       logInfo(s"Next statement to execute: $flintStatement")
       Some(flintStatement)
@@ -100,7 +105,6 @@ class StatementExecutionManagerImpl(commandContext: CommandContext)
        |      ]
        |    }
        |}""".stripMargin
-    val flintReader = osClient.createQueryReader(sessionIndex, dsl, "submitTime", SortOrder.ASC)
-    flintReader
+    osClient.createQueryReader(sessionIndex, dsl, "submitTime", SortOrder.ASC)
   }
 }
diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
index 5eeccce73..07ed94bdc 100644
--- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
+++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
@@ -1387,7 +1387,8 @@ class FlintREPLTest
     val expectedCalls =
       Math.ceil(inactivityLimit.toDouble / DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY).toInt
 
-    verify(mockOSClient, Mockito.atMost(expectedCalls)).getIndexMetadata(*)
+    verify(mockOSClient, times(1)).getIndexMetadata(*)
+    verify(mockOSClient, Mockito.atMost(expectedCalls)).createQueryReader(*, *, *, *)
   }
 
   val testCases = Table(
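Note: below is a minimal, self-contained Scala sketch of the reader lifecycle this change introduces, for review purposes only. StubReader and StatementManager are hypothetical stand-ins, not the actual Flint classes; only the Option-based create-lazily/close-per-batch cycle mirrors the diff. Combined with the FlintREPL change that hoists prepareStatementExecution out of the query loop, index preparation now runs once per session while each micro-batch opens (and closes) its own reader, which is what the updated test asserts: getIndexMetadata exactly once, createQueryReader at most once per loop iteration.

// Illustrative sketch only (assumed names; compiles standalone).
object ReaderLifecycleSketch {

  // Mirrors the hasNext/next/close surface the diff relies on.
  trait FlintReader {
    def hasNext: Boolean
    def next(): String
    def close(): Unit
  }

  // Stub reader serving one fixed batch of raw statements.
  final class StubReader(batch: Iterator[String]) extends FlintReader {
    def hasNext: Boolean = batch.hasNext
    def next(): String = batch.next()
    def close(): Unit = println("reader closed")
  }

  final class StatementManager(openReader: () => FlintReader) {
    // No reader is held at session scope; one is opened lazily per micro-batch,
    // so a single reader client is never shared across concurrent use.
    private var currentReader: Option[FlintReader] = None

    def getNextStatement(): Option[String] = {
      if (currentReader.isEmpty) {
        currentReader = Some(openReader()) // fresh reader for this batch
      }
      if (currentReader.get.hasNext) Some(currentReader.get.next()) else None
    }

    def terminateStatementExecution(): Unit = {
      currentReader.foreach(_.close()) // no-op when no reader was opened
      currentReader = None // the next micro-batch re-opens
    }
  }

  def main(args: Array[String]): Unit = {
    val manager = new StatementManager(() => new StubReader(Iterator("q1", "q2")))
    // One micro-batch: drain the available statements, then tear the reader down.
    Iterator
      .continually(manager.getNextStatement())
      .takeWhile(_.isDefined)
      .flatten
      .foreach(println)
    manager.terminateStatementExecution() // prints "reader closed"
  }
}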