diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala
index ad49201f0..2e285e3d9 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala
@@ -14,5 +14,4 @@ case class CommandState(
     recordedVerificationResult: VerificationResult,
     flintReader: FlintReader,
     futureMappingCheck: Future[Either[String, Unit]],
-    executionContext: ExecutionContextExecutor,
-    recordedLastCanPickCheckTime: Long)
+    executionContext: ExecutionContextExecutor)
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
index 99085185c..4a1676634 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
@@ -6,6 +6,8 @@
 package org.apache.spark.sql
 
 import java.net.ConnectException
+import java.time.Instant
+import java.util.Map
 import java.util.concurrent.{ScheduledExecutorService, ScheduledFuture}
 
 import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future, TimeoutException}
@@ -13,11 +15,9 @@ import scala.concurrent.duration.{Duration, MINUTES, _}
 import scala.util.{Failure, Success, Try}
 import scala.util.control.NonFatal
 
-import org.json4s.native.Serialization
 import org.opensearch.action.get.GetResponse
 import org.opensearch.common.Strings
 import org.opensearch.flint.app.{FlintCommand, FlintInstance}
-import org.opensearch.flint.app.FlintInstance.formats
 import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater}
 
 import org.apache.spark.SparkConf
@@ -47,7 +47,6 @@ object FlintREPL extends Logging with FlintJobExecutor {
   private val DEFAULT_QUERY_EXECUTION_TIMEOUT = Duration(30, MINUTES)
   private val DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS = 10 * 60 * 1000
   val INITIAL_DELAY_MILLIS = 3000L
-  val EARLY_TERMIANTION_CHECK_FREQUENCY = 60000L
 
   def update(flintCommand: FlintCommand, updater: OpenSearchUpdater): Unit = {
     updater.update(flintCommand.statementId, FlintCommand.serialize(flintCommand))
@@ -293,11 +292,10 @@ object FlintREPL extends Logging with FlintJobExecutor {
       var lastActivityTime = currentTimeProvider.currentEpochMillis()
       var verificationResult: VerificationResult = NotVerified
       var canPickUpNextStatement = true
-      var lastCanPickCheckTime = 0L
       while (currentTimeProvider
           .currentEpochMillis() - lastActivityTime <= commandContext.inactivityLimitMillis && canPickUpNextStatement) {
         logInfo(
           s"""read from ${commandContext.sessionIndex}, sessionId: ${commandContext.sessionId}""")
         val flintReader: FlintReader =
           createQueryReader(
             commandContext.osClient,
@@ -311,21 +309,18 @@ object FlintREPL extends Logging with FlintJobExecutor {
             verificationResult,
             flintReader,
             futureMappingCheck,
-            executionContext,
-            lastCanPickCheckTime)
+            executionContext)
-          val result: (Long, VerificationResult, Boolean, Long) =
+          val result: (Long, VerificationResult, Boolean) =
             processCommands(commandContext, commandState)
           val (
             updatedLastActivityTime,
             updatedVerificationResult,
-            updatedCanPickUpNextStatement,
-            updatedLastCanPickCheckTime) = result
+            updatedCanPickUpNextStatement) = result
 
           lastActivityTime = updatedLastActivityTime
           verificationResult = updatedVerificationResult
           canPickUpNextStatement = updatedCanPickUpNextStatement
-          lastCanPickCheckTime = updatedLastCanPickCheckTime
         } finally {
           flintReader.close()
         }
@@ -486,7 +481,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
 
   private def processCommands(
       context: CommandContext,
-      state: CommandState): (Long, VerificationResult, Boolean, Long) = {
+      state: CommandState): (Long, VerificationResult, Boolean) = {
     import context._
     import state._
 
@@ -494,19 +489,10 @@ object FlintREPL extends Logging with FlintJobExecutor {
     var verificationResult = recordedVerificationResult
     var canProceed = true
    var canPickNextStatementResult = true // Add this line to keep track of canPickNextStatement
-    var lastCanPickCheckTime = recordedLastCanPickCheckTime
 
     while (canProceed) {
-      val currentTime = currentTimeProvider.currentEpochMillis()
-
-      // Only call canPickNextStatement if EARLY_TERMIANTION_CHECK_FREQUENCY milliseconds have passed
-      if (currentTime - lastCanPickCheckTime > EARLY_TERMIANTION_CHECK_FREQUENCY) {
-        canPickNextStatementResult =
-          canPickNextStatement(sessionId, jobId, osClient, sessionIndex)
-        lastCanPickCheckTime = currentTime
-      }
-
-      if (!canPickNextStatementResult) {
+      if (!canPickNextStatement(sessionId, jobId, osClient, sessionIndex)) {
+        canPickNextStatementResult = false
         canProceed = false
       } else if (!flintReader.hasNext) {
         canProceed = false
@@ -538,7 +524,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
     }
 
     // return tuple indicating if still active and mapping verification result
-    (lastActivityTime, verificationResult, canPickNextStatementResult, lastCanPickCheckTime)
+    (lastActivityTime, verificationResult, canPickNextStatementResult)
   }
 
   /**
@@ -902,12 +888,20 @@ object FlintREPL extends Logging with FlintJobExecutor {
              return // Exit the run method if the thread is interrupted
            }
 
-            flintSessionUpdater.upsert(
-              sessionId,
-              Serialization.write(
-                Map(
-                  "lastUpdateTime" -> currentTimeProvider.currentEpochMillis(),
-                  "state" -> "running")))
+            val getResponse = osClient.getDoc(sessionIndex, sessionId)
+            if (getResponse.isExists()) {
+              val source = getResponse.getSourceAsMap
+              val flintInstance = FlintInstance.deserializeFromMap(source)
+              flintInstance.state = "running"
+              flintSessionUpdater.updateIf(
+                sessionId,
+                FlintInstance.serializeWithoutJobId(
+                  flintInstance,
+                  currentTimeProvider.currentEpochMillis()),
+                getResponse.getSeqNo,
+                getResponse.getPrimaryTerm)
+            }
+            // do nothing if the session doc does not exist
          } catch {
            case ie: InterruptedException =>
              // Preserve the interrupt status
diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
index c3d027102..7b9fcc140 100644
--- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
+++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
@@ -47,6 +47,20 @@ class FlintREPLTest
     val getResponse = mock[GetResponse]
     val scheduledFutureRaw = mock[ScheduledFuture[_]]
 
+    // Mock behaviors
+    when(osClient.getDoc(*, *)).thenReturn(getResponse)
+    when(getResponse.isExists()).thenReturn(true)
+    when(getResponse.getSourceAsMap).thenReturn(
+      Map[String, Object](
+        "applicationId" -> "app1",
+        "jobId" -> "job1",
+        "sessionId" -> "session1",
+        "lastUpdateTime" -> java.lang.Long.valueOf(12345L),
+        "error" -> "someError",
+        "state" -> "running",
+        "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava)
+    when(getResponse.getSeqNo).thenReturn(0L)
+    when(getResponse.getPrimaryTerm).thenReturn(0L)
     // when scheduled task is scheduled, execute the runnable immediately only once and become no-op afterwards.
     when(
       threadPool.scheduleAtFixedRate(
@@ -71,7 +85,8 @@ class FlintREPLTest
       0)
 
     // Verifications
-    verify(flintSessionUpdater, atLeastOnce()).upsert(eqTo("session1"), *)
+    verify(osClient, atLeastOnce()).getDoc("sessionIndex", "session1")
+    verify(flintSessionUpdater, atLeastOnce()).updateIf(eqTo("session1"), *, eqTo(0L), eqTo(0L))
   }
 
   test("createShutdownHook add shutdown hook and update FlintInstance if conditions are met") {