diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala index 2e285e3d9..ad49201f0 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala @@ -14,4 +14,5 @@ case class CommandState( recordedVerificationResult: VerificationResult, flintReader: FlintReader, futureMappingCheck: Future[Either[String, Unit]], - executionContext: ExecutionContextExecutor) + executionContext: ExecutionContextExecutor, + recordedLastCanPickCheckTime: Long) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 4a1676634..99085185c 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -6,8 +6,6 @@ package org.apache.spark.sql import java.net.ConnectException -import java.time.Instant -import java.util.Map import java.util.concurrent.{ScheduledExecutorService, ScheduledFuture} import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future, TimeoutException} @@ -15,9 +13,11 @@ import scala.concurrent.duration.{Duration, MINUTES, _} import scala.util.{Failure, Success, Try} import scala.util.control.NonFatal +import org.json4s.native.Serialization import org.opensearch.action.get.GetResponse import org.opensearch.common.Strings import org.opensearch.flint.app.{FlintCommand, FlintInstance} +import org.opensearch.flint.app.FlintInstance.formats import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} import org.apache.spark.SparkConf @@ -47,6 +47,7 @@ object FlintREPL extends Logging with FlintJobExecutor { private val DEFAULT_QUERY_EXECUTION_TIMEOUT = Duration(30, MINUTES) private val DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS = 10 * 60 * 1000 val INITIAL_DELAY_MILLIS = 3000L + val EARLY_TERMIANTION_CHECK_FREQUENCY = 60000L def update(flintCommand: FlintCommand, updater: OpenSearchUpdater): Unit = { updater.update(flintCommand.statementId, FlintCommand.serialize(flintCommand)) @@ -292,10 +293,11 @@ object FlintREPL extends Logging with FlintJobExecutor { var lastActivityTime = currentTimeProvider.currentEpochMillis() var verificationResult: VerificationResult = NotVerified var canPickUpNextStatement = true + var lastCanPickCheckTime = 0L while (currentTimeProvider .currentEpochMillis() - lastActivityTime <= commandContext.inactivityLimitMillis && canPickUpNextStatement) { logInfo( - s"""read from ${commandContext.sessionIndex}, sessionId: $commandContext.sessionId""") + s"""read from ${commandContext.sessionIndex}, sessionId: ${commandContext.sessionId}""") val flintReader: FlintReader = createQueryReader( commandContext.osClient, @@ -309,18 +311,21 @@ object FlintREPL extends Logging with FlintJobExecutor { verificationResult, flintReader, futureMappingCheck, - executionContext) - val result: (Long, VerificationResult, Boolean) = + executionContext, + lastCanPickCheckTime) + val result: (Long, VerificationResult, Boolean, Long) = processCommands(commandContext, commandState) val ( updatedLastActivityTime, updatedVerificationResult, - updatedCanPickUpNextStatement) = result + updatedCanPickUpNextStatement, + updatedLastCanPickCheckTime) = result lastActivityTime = updatedLastActivityTime verificationResult = updatedVerificationResult canPickUpNextStatement = updatedCanPickUpNextStatement + lastCanPickCheckTime = updatedLastCanPickCheckTime } finally { flintReader.close() } @@ -481,7 +486,7 @@ object FlintREPL extends Logging with FlintJobExecutor { private def processCommands( context: CommandContext, - state: CommandState): (Long, VerificationResult, Boolean) = { + state: CommandState): (Long, VerificationResult, Boolean, Long) = { import context._ import state._ @@ -489,10 +494,19 @@ object FlintREPL extends Logging with FlintJobExecutor { var verificationResult = recordedVerificationResult var canProceed = true var canPickNextStatementResult = true // Add this line to keep track of canPickNextStatement + var lastCanPickCheckTime = recordedLastCanPickCheckTime while (canProceed) { - if (!canPickNextStatement(sessionId, jobId, osClient, sessionIndex)) { - canPickNextStatementResult = false + val currentTime = currentTimeProvider.currentEpochMillis() + + // Only call canPickNextStatement if EARLY_TERMIANTION_CHECK_FREQUENCY milliseconds have passed + if (currentTime - lastCanPickCheckTime > EARLY_TERMIANTION_CHECK_FREQUENCY) { + canPickNextStatementResult = + canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + lastCanPickCheckTime = currentTime + } + + if (!canPickNextStatementResult) { canProceed = false } else if (!flintReader.hasNext) { canProceed = false @@ -524,7 +538,7 @@ object FlintREPL extends Logging with FlintJobExecutor { } // return tuple indicating if still active and mapping verification result - (lastActivityTime, verificationResult, canPickNextStatementResult) + (lastActivityTime, verificationResult, canPickNextStatementResult, lastCanPickCheckTime) } /** @@ -888,20 +902,12 @@ object FlintREPL extends Logging with FlintJobExecutor { return // Exit the run method if the thread is interrupted } - val getResponse = osClient.getDoc(sessionIndex, sessionId) - if (getResponse.isExists()) { - val source = getResponse.getSourceAsMap - val flintInstance = FlintInstance.deserializeFromMap(source) - flintInstance.state = "running" - flintSessionUpdater.updateIf( - sessionId, - FlintInstance.serializeWithoutJobId( - flintInstance, - currentTimeProvider.currentEpochMillis()), - getResponse.getSeqNo, - getResponse.getPrimaryTerm) - } - // do nothing if the session doc does not exist + flintSessionUpdater.upsert( + sessionId, + Serialization.write( + Map( + "lastUpdateTime" -> currentTimeProvider.currentEpochMillis(), + "state" -> "running"))) } catch { case ie: InterruptedException => // Preserve the interrupt status diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index 7b9fcc140..c3d027102 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -47,20 +47,6 @@ class FlintREPLTest val getResponse = mock[GetResponse] val scheduledFutureRaw = mock[ScheduledFuture[_]] - // Mock behaviors - when(osClient.getDoc(*, *)).thenReturn(getResponse) - when(getResponse.isExists()).thenReturn(true) - when(getResponse.getSourceAsMap).thenReturn( - Map[String, Object]( - "applicationId" -> "app1", - "jobId" -> "job1", - "sessionId" -> "session1", - "lastUpdateTime" -> java.lang.Long.valueOf(12345L), - "error" -> "someError", - "state" -> "running", - "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava) - when(getResponse.getSeqNo).thenReturn(0L) - when(getResponse.getPrimaryTerm).thenReturn(0L) // when scheduled task is scheduled, execute the runnable immediately only once and become no-op afterwards. when( threadPool.scheduleAtFixedRate( @@ -85,8 +71,7 @@ class FlintREPLTest 0) // Verifications - verify(osClient, atLeastOnce()).getDoc("sessionIndex", "session1") - verify(flintSessionUpdater, atLeastOnce()).updateIf(eqTo("session1"), *, eqTo(0L), eqTo(0L)) + verify(flintSessionUpdater, atLeastOnce()).upsert(eqTo("session1"), *) } test("createShutdownHook add shutdown hook and update FlintInstance if conditions are met") {