From f6b7dc17d71ebc51ae7730c82a2310f01429059a Mon Sep 17 00:00:00 2001 From: Louis Chu Date: Wed, 12 Jun 2024 13:56:45 -0700 Subject: [PATCH] Refactor REPL mode refactor statement lifecycle [WIP] Fix tests fix some tests --- .../apache/spark/sql/QueryResultWriter.scala | 16 + .../org/apache/spark/sql/SessionManager.scala | 48 ++ .../spark/sql/StatementExecutionContext.scala | 16 +- .../spark/sql/StatementLifecycleManager.scala | 29 + .../flint/data/FlintStatement.scala | 3 +- .../flint/data/InteractiveSession.scala | 28 +- .../sql/flint/config/FlintSparkConf.scala | 12 + .../scala/org/apache/spark/sql/FlintJob.scala | 2 +- .../apache/spark/sql/FlintJobExecutor.scala | 76 +- .../org/apache/spark/sql/FlintREPL.scala | 779 +++++++----------- ...cala => InMemoryQueryExecutionState.scala} | 9 +- .../org/apache/spark/sql/JobOperator.scala | 8 +- .../spark/sql/QueryResultWriterImpl.scala | 40 + .../apache/spark/sql/SessionManagerImpl.scala | 175 ++++ .../sql/StatementLifecycleManagerImpl.scala | 54 ++ .../org/apache/spark/sql/FlintJobTest.scala | 1 - .../org/apache/spark/sql/FlintREPLTest.scala | 556 +++++++------ 17 files changed, 1033 insertions(+), 819 deletions(-) create mode 100644 flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala create mode 100644 flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala rename spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala => flint-commons/src/main/scala/org/apache/spark/sql/StatementExecutionContext.scala (55%) create mode 100644 flint-commons/src/main/scala/org/apache/spark/sql/StatementLifecycleManager.scala rename spark-sql-application/src/main/scala/org/apache/spark/sql/{CommandState.scala => InMemoryQueryExecutionState.scala} (61%) create mode 100644 spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala create mode 100644 spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala create mode 100644 spark-sql-application/src/main/scala/org/apache/spark/sql/StatementLifecycleManagerImpl.scala diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala b/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala new file mode 100644 index 000000000..f676a3519 --- /dev/null +++ b/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala @@ -0,0 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.opensearch.flint.data.FlintStatement + +trait QueryResultWriter { + def reformatQueryResult( + dataFrame: DataFrame, + flintStatement: FlintStatement, + queryExecutionContext: StatementExecutionContext): DataFrame + def persistQueryResult(dataFrame: DataFrame, flintStatement: FlintStatement): Unit +} diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala b/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala new file mode 100644 index 000000000..91a68ead3 --- /dev/null +++ b/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala @@ -0,0 +1,48 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.opensearch.flint.data.{FlintStatement, InteractiveSession} + +import org.apache.spark.sql.SessionUpdateMode.SessionUpdateMode + +/** + * Trait defining the interface for managing interactive sessions. 
+ */ +trait SessionManager { + + /** + * Retrieves metadata about the session manager. + */ + def getSessionManagerMetadata: Map[String, Any] + + /** + * Fetches the details of a specific session. + */ + def getSessionDetails(sessionId: String): Option[InteractiveSession] + + /** + * Updates the details of a specific session. + */ + def updateSessionDetails( + sessionDetails: InteractiveSession, + updateMode: SessionUpdateMode): Unit + + /** + * Retrieves the next statement to be executed in a specific session. + */ + def getNextStatement(sessionId: String): Option[FlintStatement] + + /** + * Records a heartbeat for a specific session to indicate it is still active. + */ + def recordHeartbeat(sessionId: String): Unit +} + +object SessionUpdateMode extends Enumeration { + type SessionUpdateMode = Value + val UPDATE, UPSERT, UPDATE_IF = Value +} diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala b/flint-commons/src/main/scala/org/apache/spark/sql/StatementExecutionContext.scala similarity index 55% rename from spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala rename to flint-commons/src/main/scala/org/apache/spark/sql/StatementExecutionContext.scala index fe2fa5212..91358d6ec 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala +++ b/flint-commons/src/main/scala/org/apache/spark/sql/StatementExecutionContext.scala @@ -5,20 +5,16 @@ package org.apache.spark.sql -import scala.concurrent.{ExecutionContextExecutor, Future} import scala.concurrent.duration.Duration -import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} - -case class CommandContext( +case class StatementExecutionContext( spark: SparkSession, - dataSource: String, - resultIndex: String, - sessionId: String, - flintSessionIndexUpdater: OpenSearchUpdater, - osClient: OSClient, - sessionIndex: String, jobId: String, + sessionId: String, + sessionManager: SessionManager, + statementLifecycleManager: StatementLifecycleManager, + queryResultWriter: QueryResultWriter, + dataSource: String, queryExecutionTimeout: Duration, inactivityLimitMillis: Long, queryWaitTimeMillis: Long) diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/StatementLifecycleManager.scala b/flint-commons/src/main/scala/org/apache/spark/sql/StatementLifecycleManager.scala new file mode 100644 index 000000000..5a890f5ed --- /dev/null +++ b/flint-commons/src/main/scala/org/apache/spark/sql/StatementLifecycleManager.scala @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.opensearch.flint.data.FlintStatement + +/** + * Trait defining the interface for managing the lifecycle of statements. + */ +trait StatementLifecycleManager { + + /** + * Prepares the statement lifecycle. + */ + def prepareStatementLifecycle(): Either[String, Unit] + + /** + * Updates a specific statement. + */ + def updateStatement(statement: FlintStatement): Unit + + /** + * Terminates the statement lifecycle. 
+ */ + def terminateStatementLifecycle(): Unit +} diff --git a/flint-commons/src/main/scala/org/opensearch/flint/data/FlintStatement.scala b/flint-commons/src/main/scala/org/opensearch/flint/data/FlintStatement.scala index dbe73e9a5..80c22df82 100644 --- a/flint-commons/src/main/scala/org/opensearch/flint/data/FlintStatement.scala +++ b/flint-commons/src/main/scala/org/opensearch/flint/data/FlintStatement.scala @@ -42,6 +42,7 @@ class FlintStatement( val statementId: String, val queryId: String, val submitTime: Long, + var queryStartTime: Option[Long] = Some(-1L), var error: Option[String] = None, statementContext: Map[String, Any] = Map.empty[String, Any]) extends ContextualDataStore { @@ -76,7 +77,7 @@ object FlintStatement { case _ => None } - new FlintStatement(state, query, statementId, queryId, submitTime, maybeError) + new FlintStatement(state, query, statementId, queryId, submitTime, error = maybeError) } def serialize(flintStatement: FlintStatement): String = { diff --git a/flint-commons/src/main/scala/org/opensearch/flint/data/InteractiveSession.scala b/flint-commons/src/main/scala/org/opensearch/flint/data/InteractiveSession.scala index c5eaee4f1..3727d14ee 100644 --- a/flint-commons/src/main/scala/org/opensearch/flint/data/InteractiveSession.scala +++ b/flint-commons/src/main/scala/org/opensearch/flint/data/InteractiveSession.scala @@ -5,7 +5,7 @@ package org.opensearch.flint.data -import java.util.{Map => JavaMap} +import java.util.{List => JavaList, Map => JavaMap} import scala.collection.JavaConverters._ @@ -16,9 +16,9 @@ import org.json4s.native.Serialization object SessionStates { val RUNNING = "running" - val COMPLETE = "complete" - val FAILED = "failed" - val WAITING = "waiting" + val DEAD = "dead" + val FAIL = "fail" + val NOT_STARTED = "not_started" } /** @@ -57,9 +57,9 @@ class InteractiveSession( context = sessionContext // Initialize the context from the constructor def isRunning: Boolean = state == SessionStates.RUNNING - def isComplete: Boolean = state == SessionStates.COMPLETE - def isFailed: Boolean = state == SessionStates.FAILED - def isWaiting: Boolean = state == SessionStates.WAITING + def isDead: Boolean = state == SessionStates.DEAD + def isFail: Boolean = state == SessionStates.FAIL + def isNotStarted: Boolean = state == SessionStates.NOT_STARTED override def toString: String = { val excludedJobIdsStr = excludedJobIds.mkString("[", ", ", "]") @@ -129,10 +129,7 @@ object InteractiveSession { } // We safely handle the possibility of excludeJobIds being absent or not a list. - val excludeJobIds: Seq[String] = scalaSource.get("excludeJobIds") match { - case Some(lst: java.util.List[_]) => lst.asScala.toList.map(_.asInstanceOf[String]) - case _ => Seq.empty[String] - } + val excludeJobIds: Seq[String] = parseExcludedJobIds(scalaSource.get("excludeJobIds")) // Handle error similarly, ensuring we get an Option[String]. 
val maybeError: Option[String] = scalaSource.get("error") match { @@ -201,4 +198,13 @@ object InteractiveSession { def serializeWithoutJobId(job: InteractiveSession, currentTime: Long): String = { serialize(job, currentTime, includeJobId = false) } + private def parseExcludedJobIds(source: Option[Any]): Seq[String] = { + source match { + case Some(s: String) => Seq(s) + case Some(list: JavaList[_]) => list.asScala.toList.collect { case str: String => str } + case None => Seq.empty[String] + case _ => + Seq.empty + } + } } diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala index c6638c0b2..a70f3630b 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala @@ -203,6 +203,15 @@ object FlintSparkConf { FlintConfig("spark.metadata.accessAWSCredentialsProvider") .doc("AWS credentials provider for metadata access permission") .createOptional() + val CUSTOM_SESSION_MANAGER = + FlintConfig("spark.flint.job.customSessionManager") + .createOptional() + val CUSTOM_STATEMENT_MANAGER = + FlintConfig("spark.flint.job.customStatementManager") + .createOptional() + val CUSTOM_QUERY_RESULT_WRITER = + FlintConfig("spark.flint.job.customQueryResultWriter") + .createOptional() } /** @@ -277,6 +286,9 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable SESSION_ID, REQUEST_INDEX, METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER, + CUSTOM_SESSION_MANAGER, + CUSTOM_STATEMENT_MANAGER, + CUSTOM_QUERY_RESULT_WRITER, EXCLUDE_JOB_IDS) .map(conf => (conf.optionKey, conf.readFrom(reader))) .flatMap { diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala index bba999110..d6bf611fe 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala @@ -58,7 +58,7 @@ object FlintJob extends Logging with FlintJobExecutor { createSparkSession(conf), query, dataSource, - resultIndex, + resultIndex.get, jobType.equalsIgnoreCase("streaming"), streamingRunningCount) registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala index 665ec5a27..60f2a5f8b 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -89,6 +89,24 @@ trait FlintJobExecutor { } }""".stripMargin + // Define the data schema + val schema = StructType( + Seq( + StructField("result", ArrayType(StringType, containsNull = true), nullable = true), + StructField("schema", ArrayType(StringType, containsNull = true), nullable = true), + StructField("jobRunId", StringType, nullable = true), + StructField("applicationId", StringType, nullable = true), + StructField("dataSourceName", StringType, nullable = true), + StructField("status", StringType, nullable = true), + StructField("error", StringType, nullable = true), + StructField("queryId", StringType, nullable = true), + StructField("queryText", StringType, nullable = true), + 
StructField("sessionId", StringType, nullable = true), + StructField("jobType", StringType, nullable = true), + // number is not nullable + StructField("updateTime", LongType, nullable = false), + StructField("queryRunTime", LongType, nullable = true))) + def createSparkConf(): SparkConf = { new SparkConf() .setAppName(getClass.getSimpleName) @@ -175,7 +193,6 @@ trait FlintJobExecutor { query: String, sessionId: String, startTime: Long, - timeProvider: TimeProvider, cleaner: Cleaner): DataFrame = { // Create the schema dataframe val schemaRows = result.schema.fields.map { field => @@ -188,29 +205,11 @@ trait FlintJobExecutor { StructField("column_name", StringType, nullable = false), StructField("data_type", StringType, nullable = false)))) - // Define the data schema - val schema = StructType( - Seq( - StructField("result", ArrayType(StringType, containsNull = true), nullable = true), - StructField("schema", ArrayType(StringType, containsNull = true), nullable = true), - StructField("jobRunId", StringType, nullable = true), - StructField("applicationId", StringType, nullable = true), - StructField("dataSourceName", StringType, nullable = true), - StructField("status", StringType, nullable = true), - StructField("error", StringType, nullable = true), - StructField("queryId", StringType, nullable = true), - StructField("queryText", StringType, nullable = true), - StructField("sessionId", StringType, nullable = true), - StructField("jobType", StringType, nullable = true), - // number is not nullable - StructField("updateTime", LongType, nullable = false), - StructField("queryRunTime", LongType, nullable = true))) - val resultToSave = result.toJSON.collect.toList .map(_.replaceAll("'", "\\\\'").replaceAll("\"", "'")) val resultSchemaToSave = resultSchema.toJSON.collect.toList.map(_.replaceAll("\"", "'")) - val endTime = timeProvider.currentEpochMillis() + val endTime = currentTimeProvider.currentEpochMillis() // https://github.com/opensearch-project/opensearch-spark/issues/302. Clean shuffle data // after consumed the query result. 
Streaming query shuffle data is cleaned after each @@ -245,28 +244,9 @@ trait FlintJobExecutor { queryId: String, query: String, sessionId: String, - startTime: Long, - timeProvider: TimeProvider): DataFrame = { - - // Define the data schema - val schema = StructType( - Seq( - StructField("result", ArrayType(StringType, containsNull = true), nullable = true), - StructField("schema", ArrayType(StringType, containsNull = true), nullable = true), - StructField("jobRunId", StringType, nullable = true), - StructField("applicationId", StringType, nullable = true), - StructField("dataSourceName", StringType, nullable = true), - StructField("status", StringType, nullable = true), - StructField("error", StringType, nullable = true), - StructField("queryId", StringType, nullable = true), - StructField("queryText", StringType, nullable = true), - StructField("sessionId", StringType, nullable = true), - StructField("jobType", StringType, nullable = true), - // number is not nullable - StructField("updateTime", LongType, nullable = false), - StructField("queryRunTime", LongType, nullable = true))) - - val endTime = timeProvider.currentEpochMillis() + startTime: Long): DataFrame = { + + val endTime = currentTimeProvider.currentEpochMillis() // Create the data rows val rows = Seq( @@ -419,7 +399,6 @@ trait FlintJobExecutor { query, sessionId, startTime, - currentTimeProvider, CleanerFactory.cleaner(streaming)) } @@ -485,16 +464,19 @@ trait FlintJobExecutor { } } - def parseArgs(args: Array[String]): (Option[String], String) = { + def parseArgs(args: Array[String]): (Option[String], Option[String]) = { args match { + case Array() => + (None, None) case Array(resultIndex) => - (None, resultIndex) // Starting from OS 2.13, resultIndex is the only argument + (None, Some(resultIndex)) // Starting from OS 2.13, resultIndex is the only argument case Array(query, resultIndex) => ( Some(query), - resultIndex + Some(resultIndex) ) // Before OS 2.13, there are two arguments, the second one is resultIndex - case _ => logAndThrow("Unsupported number of arguments. Expected 1 or 2 arguments.") + case _ => + logAndThrow("Unsupported number of arguments. 
Expected no more than two arguments.") } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 8cad8844b..6b2f8bd25 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -15,23 +15,18 @@ import scala.util.{Failure, Success, Try} import scala.util.control.NonFatal import com.codahale.metrics.Timer -import org.json4s.native.Serialization -import org.opensearch.action.get.GetResponse -import org.opensearch.common.Strings import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.logging.CustomLogging import org.opensearch.flint.core.metrics.MetricConstants import org.opensearch.flint.core.metrics.MetricsUtil.{getTimerContext, incrementCounter, registerGauge, stopTimer} -import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} -import org.opensearch.flint.data.{FlintStatement, InteractiveSession} -import org.opensearch.flint.data.InteractiveSession.formats -import org.opensearch.search.sort.SortOrder +import org.opensearch.flint.data.{FlintStatement, InteractiveSession, SessionStates} import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} +import org.apache.spark.sql.SessionUpdateMode.UPDATE_IF import org.apache.spark.sql.flint.config.FlintSparkConf -import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.{ThreadUtils, Utils} /** * Spark SQL Application entrypoint @@ -49,28 +44,20 @@ import org.apache.spark.util.ThreadUtils object FlintREPL extends Logging with FlintJobExecutor { private val HEARTBEAT_INTERVAL_MILLIS = 60000L - private val MAPPING_CHECK_TIMEOUT = Duration(1, MINUTES) + private val PREPARE_QUERY_EXEC_TIMEOUT = Duration(1, MINUTES) private val DEFAULT_QUERY_EXECUTION_TIMEOUT = Duration(30, MINUTES) private val DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS = 10 * 60 * 1000 - val INITIAL_DELAY_MILLIS = 3000L - val EARLY_TERMIANTION_CHECK_FREQUENCY = 60000L + private val INITIAL_DELAY_MILLIS = 3000L + private val EARLY_TERMINATION_CHECK_FREQUENCY = 60000L @volatile var earlyExitFlag: Boolean = false - def updateSessionIndex(flintStatement: FlintStatement, updater: OpenSearchUpdater): Unit = { - updater.update(flintStatement.statementId, FlintStatement.serialize(flintStatement)) - } - private val sessionRunningCount = new AtomicInteger(0) private val statementRunningCount = new AtomicInteger(0) def main(args: Array[String]) { val (queryOption, resultIndex) = parseArgs(args) - if (Strings.isNullOrEmpty(resultIndex)) { - logAndThrow("resultIndex is not set") - } - // init SparkContext val conf: SparkConf = createSparkConf() val dataSource = conf.get(FlintSparkConf.DATA_SOURCE_NAME.key, "unknown") @@ -91,37 +78,36 @@ object FlintREPL extends Logging with FlintJobExecutor { val query = getQuery(queryOption, jobType, conf) if (jobType.equalsIgnoreCase("streaming")) { - logInfo(s"""streaming query ${query}""") - configDYNMaxExecutors(conf, jobType) - val streamingRunningCount = new AtomicInteger(0) - val jobOperator = - JobOperator( - createSparkSession(conf), - query, - dataSource, - resultIndex, - true, - streamingRunningCount) - registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount) - jobOperator.start() + logInfo(s"streaming query $query") + resultIndex match { + case Some(index) => + 
configDYNMaxExecutors(conf, jobType) + val streamingRunningCount = new AtomicInteger(0) + val jobOperator = JobOperator( + createSparkSession(conf), + query, + dataSource, + index, // Ensure the correct Option type is passed + true, + streamingRunningCount) + registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount) + jobOperator.start() + + case None => logAndThrow("resultIndex is not set") + } } else { // we don't allow default value for sessionIndex and sessionId. Throw exception if key not found. - val sessionIndex: Option[String] = Option(conf.get(FlintSparkConf.REQUEST_INDEX.key, null)) val sessionId: Option[String] = Option(conf.get(FlintSparkConf.SESSION_ID.key, null)) - if (sessionIndex.isEmpty) { - logAndThrow(FlintSparkConf.REQUEST_INDEX.key + " is not set") - } if (sessionId.isEmpty) { logAndThrow(FlintSparkConf.SESSION_ID.key + " is not set") } val spark = createSparkSession(conf) - val osClient = new OSClient(FlintSparkConf().flintOptions()) + val jobId = envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown") val applicationId = envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown") - // Read the values from the Spark configuration or fall back to the default values val inactivityLimitMillis: Long = conf.getLong( @@ -135,7 +121,7 @@ object FlintREPL extends Logging with FlintJobExecutor { val queryWaitTimeoutMillis: Long = conf.getLong("spark.flint.job.queryWaitTimeoutMillis", DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS) - val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex.get) + val sessionManager = instantiateSessionManager(spark, resultIndex) val sessionTimerContext = getTimerContext(MetricConstants.REPL_PROCESSING_TIME_METRIC) /** @@ -148,12 +134,7 @@ object FlintREPL extends Logging with FlintJobExecutor { * https://github.com/opensearch-project/opensearch-spark/issues/320 */ spark.sparkContext.addSparkListener( - new PreShutdownListener( - flintSessionIndexUpdater, - osClient, - sessionIndex.get, - sessionId.get, - sessionTimerContext)) + new PreShutdownListener(sessionId.get, sessionManager, sessionTimerContext)) // 1 thread for updating heart beat val threadPool = @@ -161,61 +142,59 @@ object FlintREPL extends Logging with FlintJobExecutor { registerGauge(MetricConstants.REPL_RUNNING_METRIC, sessionRunningCount) registerGauge(MetricConstants.STATEMENT_RUNNING_METRIC, statementRunningCount) + val jobStartTime = currentTimeProvider.currentEpochMillis() // update heart beat every 30 seconds // OpenSearch triggers recovery after 1 minute outdated heart beat var heartBeatFuture: ScheduledFuture[_] = null try { heartBeatFuture = createHeartBeatUpdater( - HEARTBEAT_INTERVAL_MILLIS, - flintSessionIndexUpdater, sessionId.get, - threadPool, - osClient, - sessionIndex.get, - INITIAL_DELAY_MILLIS) + sessionManager, + HEARTBEAT_INTERVAL_MILLIS, + INITIAL_DELAY_MILLIS, + threadPool) if (setupFlintJobWithExclusionCheck( conf, - sessionIndex, - sessionId, - osClient, - jobId, applicationId, - flintSessionIndexUpdater, + jobId, + sessionId.get, + sessionManager, jobStartTime)) { earlyExitFlag = true return } - val commandContext = CommandContext( + val queryExecutionManager = + instantiateQueryExecutionManager(spark, sessionManager.getSessionManagerMetadata) + val queryResultWriter = + instantiateQueryResultWriter(spark, sessionManager.getSessionManagerMetadata) + val queryLoopContext = StatementExecutionContext( spark, - dataSource, - resultIndex, - sessionId.get, - flintSessionIndexUpdater, - osClient, - sessionIndex.get, 
jobId, + sessionId.get, + sessionManager, + queryExecutionManager, + queryResultWriter, + dataSource, queryExecutionTimeoutSecs, inactivityLimitMillis, queryWaitTimeoutMillis) exponentialBackoffRetry(maxRetries = 5, initialDelay = 2.seconds) { - queryLoop(commandContext) + queryLoop(queryLoopContext) } recordSessionSuccess(sessionTimerContext) } catch { case e: Exception => handleSessionError( - e, applicationId, jobId, sessionId.get, + sessionManager, + sessionTimerContext, jobStartTime, - flintSessionIndexUpdater, - osClient, - sessionIndex.get, - sessionTimerContext) + e) } finally { if (threadPool != null) { heartBeatFuture.cancel(true) // Pass `true` to interrupt if running @@ -287,69 +266,55 @@ object FlintREPL extends Logging with FlintJobExecutor { */ def setupFlintJobWithExclusionCheck( conf: SparkConf, - sessionIndex: Option[String], - sessionId: Option[String], - osClient: OSClient, - jobId: String, applicationId: String, - flintSessionIndexUpdater: OpenSearchUpdater, + jobId: String, + sessionId: String, + sessionManager: SessionManager, jobStartTime: Long): Boolean = { val confExcludeJobsOpt = conf.getOption(FlintSparkConf.EXCLUDE_JOB_IDS.key) confExcludeJobsOpt match { case None => // If confExcludeJobs is None, pass null or an empty sequence as per your setupFlintJob method's signature - setupFlintJob( - applicationId, - jobId, - sessionId.get, - flintSessionIndexUpdater, - sessionIndex.get, - jobStartTime) + setupFlintJob(applicationId, jobId, sessionId, sessionManager, jobStartTime) case Some(confExcludeJobs) => // example: --conf spark.flint.deployment.excludeJobs=job-1,job-2 - val excludeJobIds = confExcludeJobs.split(",").toList // Convert Array to Lis + val excludedJobIds = confExcludeJobs.split(",").toList // Convert Array to Lis - if (excludeJobIds.contains(jobId)) { + if (excludedJobIds.contains(jobId)) { logInfo(s"current job is excluded, exit the application.") return true } - val getResponse = osClient.getDoc(sessionIndex.get, sessionId.get) - if (getResponse.isExists()) { - val source = getResponse.getSourceAsMap - if (source != null) { - val existingExcludedJobIds = parseExcludedJobIds(source) - if (excludeJobIds.sorted == existingExcludedJobIds.sorted) { - logInfo("duplicate job running, exit the application.") - return true - } - } + val sessionDetails = sessionManager.getSessionDetails(sessionId) + val existingExcludedJobIds = sessionDetails.get.excludedJobIds + if (excludedJobIds.sorted == existingExcludedJobIds.sorted) { + logInfo("duplicate job running, exit the application.") + return true } // If none of the edge cases are met, proceed with setup setupFlintJob( applicationId, jobId, - sessionId.get, - flintSessionIndexUpdater, - sessionIndex.get, + sessionId, + sessionManager, jobStartTime, - excludeJobIds) + excludedJobIds = excludedJobIds) } false } - def queryLoop(commandContext: CommandContext): Unit = { - // 1 thread for updating heart beat + def queryLoop(context: StatementExecutionContext): Unit = { + import context._ + // 1 thread for query execution preparation val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-query", 1) - implicit val executionContext = ExecutionContext.fromExecutor(threadPool) - - var futureMappingCheck: Future[Either[String, Unit]] = null + implicit val futureExecutor = ExecutionContext.fromExecutor(threadPool) + var futurePrepareQueryExecution: Future[Either[String, Unit]] = null try { - futureMappingCheck = Future { - checkAndCreateIndex(commandContext.osClient, commandContext.resultIndex) 
+ futurePrepareQueryExecution = Future { + statementLifecycleManager.prepareStatementLifecycle() } var lastActivityTime = currentTimeProvider.currentEpochMillis() @@ -357,26 +322,18 @@ object FlintREPL extends Logging with FlintJobExecutor { var canPickUpNextStatement = true var lastCanPickCheckTime = 0L while (currentTimeProvider - .currentEpochMillis() - lastActivityTime <= commandContext.inactivityLimitMillis && canPickUpNextStatement) { - logInfo( - s"""read from ${commandContext.sessionIndex}, sessionId: ${commandContext.sessionId}""") - val flintReader: FlintReader = - createQueryReader( - commandContext.osClient, - commandContext.sessionId, - commandContext.sessionIndex, - commandContext.dataSource) + .currentEpochMillis() - lastActivityTime <= context.inactivityLimitMillis && canPickUpNextStatement) { + logInfo(s"""sessionId: ${context.sessionId}""") try { - val commandState = CommandState( + val inMemoryQueryExecutionState = InMemoryQueryExecutionState( lastActivityTime, + lastCanPickCheckTime, verificationResult, - flintReader, - futureMappingCheck, - executionContext, - lastCanPickCheckTime) + futurePrepareQueryExecution, + futureExecutor) val result: (Long, VerificationResult, Boolean, Long) = - processCommands(commandContext, commandState) + processStatements(context, inMemoryQueryExecutionState) val ( updatedLastActivityTime, @@ -389,7 +346,7 @@ object FlintREPL extends Logging with FlintJobExecutor { canPickUpNextStatement = updatedCanPickUpNextStatement lastCanPickCheckTime = updatedLastCanPickCheckTime } finally { - flintReader.close() + statementLifecycleManager.terminateStatementLifecycle() } Thread.sleep(100) @@ -401,96 +358,70 @@ object FlintREPL extends Logging with FlintJobExecutor { } } + private def refreshSessionState( + applicationId: String, + jobId: String, + sessionId: String, + sessionManager: SessionManager, + jobStartTime: Long, + state: String, + error: Option[String] = None, + excludedJobIds: Seq[String] = Seq.empty[String]): InteractiveSession = { + + val sessionDetails = sessionManager + .getSessionDetails(sessionId) + .getOrElse( + new InteractiveSession( + applicationId, + jobId, + sessionId, + state, + currentTimeProvider.currentEpochMillis(), + jobStartTime, + error = error, + excludedJobIds = excludedJobIds)) + sessionDetails.state = state + sessionManager.updateSessionDetails(sessionDetails, updateMode = SessionUpdateMode.UPSERT) + sessionDetails + } + private def setupFlintJob( applicationId: String, jobId: String, sessionId: String, - flintSessionIndexUpdater: OpenSearchUpdater, - sessionIndex: String, + sessionManager: SessionManager, jobStartTime: Long, - excludeJobIds: Seq[String] = Seq.empty[String]): Unit = { - val includeJobId = !excludeJobIds.isEmpty && !excludeJobIds.contains(jobId) - val currentTime = currentTimeProvider.currentEpochMillis() - val flintJob = new InteractiveSession( + excludedJobIds: Seq[String] = Seq.empty[String]): Unit = { + refreshSessionState( applicationId, jobId, sessionId, - "running", - currentTime, + sessionManager, jobStartTime, - excludeJobIds) - - val serializedFlintInstance = if (includeJobId) { - InteractiveSession.serialize(flintJob, currentTime, true) - } else { - InteractiveSession.serializeWithoutJobId(flintJob, currentTime) - } - flintSessionIndexUpdater.upsert(sessionId, serializedFlintInstance) - logInfo( - s"""Updated job: {"jobid": ${flintJob.jobId}, "sessionId": ${flintJob.sessionId}} from $sessionIndex""") + SessionStates.RUNNING, + excludedJobIds = excludedJobIds) 
sessionRunningCount.incrementAndGet() } - def handleSessionError( - e: Exception, + private def handleSessionError( applicationId: String, jobId: String, sessionId: String, + sessionManager: SessionManager, + sessionTimerContext: Timer.Context, jobStartTime: Long, - flintSessionIndexUpdater: OpenSearchUpdater, - osClient: OSClient, - sessionIndex: String, - sessionTimerContext: Timer.Context): Unit = { + e: Exception): Unit = { val error = s"Session error: ${e.getMessage}" CustomLogging.logError(error, e) - - val flintInstance = getExistingFlintInstance(osClient, sessionIndex, sessionId) - .getOrElse(createFailedFlintInstance(applicationId, jobId, sessionId, jobStartTime, error)) - - updateFlintInstance(flintInstance, flintSessionIndexUpdater, sessionId) - if (flintInstance.state.equals("fail")) { - recordSessionFailed(sessionTimerContext) - } - } - - private def getExistingFlintInstance( - osClient: OSClient, - sessionIndex: String, - sessionId: String): Option[InteractiveSession] = Try( - osClient.getDoc(sessionIndex, sessionId)) match { - case Success(getResponse) if getResponse.isExists() => - Option(getResponse.getSourceAsMap) - .map(InteractiveSession.deserializeFromMap) - case Failure(exception) => - CustomLogging.logError( - s"Failed to retrieve existing FlintInstance: ${exception.getMessage}", - exception) - None - case _ => None - } - - private def createFailedFlintInstance( - applicationId: String, - jobId: String, - sessionId: String, - jobStartTime: Long, - errorMessage: String): InteractiveSession = new InteractiveSession( - applicationId, - jobId, - sessionId, - "fail", - currentTimeProvider.currentEpochMillis(), - jobStartTime, - error = Some(errorMessage)) - - private def updateFlintInstance( - flintInstance: InteractiveSession, - flintSessionIndexUpdater: OpenSearchUpdater, - sessionId: String): Unit = { - val currentTime = currentTimeProvider.currentEpochMillis() - flintSessionIndexUpdater.upsert( + refreshSessionState( + applicationId, + jobId, sessionId, - InteractiveSession.serializeWithoutJobId(flintInstance, currentTime)) + sessionManager, + jobStartTime, + SessionStates.FAIL, + Some(e.getMessage)) + recordSessionFailed(sessionTimerContext) } /** @@ -513,13 +444,12 @@ object FlintREPL extends Logging with FlintJobExecutor { * @return * failed data frame */ - def handleCommandFailureAndGetFailedData( + def handleStatementFailureAndGetFailedData( spark: SparkSession, dataSource: String, error: String, flintStatement: FlintStatement, - sessionId: String, - startTime: Long): DataFrame = { + sessionId: String): DataFrame = { flintStatement.fail() flintStatement.error = Some(error) super.getFailedData( @@ -529,8 +459,7 @@ object FlintREPL extends Logging with FlintJobExecutor { flintStatement.queryId, flintStatement.query, sessionId, - startTime, - currentTimeProvider) + flintStatement.queryStartTime.get) } def processQueryException(ex: Exception, flintStatement: FlintStatement): String = { @@ -540,9 +469,9 @@ object FlintREPL extends Logging with FlintJobExecutor { error } - private def processCommands( - context: CommandContext, - state: CommandState): (Long, VerificationResult, Boolean, Long) = { + private def processStatements( + context: StatementExecutionContext, + state: InMemoryQueryExecutionState): (Long, VerificationResult, Boolean, Long) = { import context._ import state._ @@ -555,45 +484,32 @@ object FlintREPL extends Logging with FlintJobExecutor { while (canProceed) { val currentTime = currentTimeProvider.currentEpochMillis() - // Only call 
canPickNextStatement if EARLY_TERMIANTION_CHECK_FREQUENCY milliseconds have passed - if (currentTime - lastCanPickCheckTime > EARLY_TERMIANTION_CHECK_FREQUENCY) { - canPickNextStatementResult = - canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + // Only call canPickNextStatement if EARLY_TERMINATION_CHECK_FREQUENCY milliseconds have passed + if (currentTime - lastCanPickCheckTime > EARLY_TERMINATION_CHECK_FREQUENCY) { + canPickNextStatementResult = canPickNextStatement(sessionManager, sessionId, jobId) lastCanPickCheckTime = currentTime } if (!canPickNextStatementResult) { earlyExitFlag = true canProceed = false - } else if (!flintReader.hasNext) { - canProceed = false } else { - val statementTimerContext = getTimerContext( - MetricConstants.STATEMENT_PROCESSING_TIME_METRIC) - val flintStatement = processCommandInitiation(flintReader, flintSessionIndexUpdater) + sessionManager.getNextStatement(sessionId) match { + case Some(flintStatement) => + val statementTimerContext = getTimerContext( + MetricConstants.STATEMENT_PROCESSING_TIME_METRIC) - val (dataToWrite, returnedVerificationResult) = processStatementOnVerification( - recordedVerificationResult, - spark, - flintStatement, - dataSource, - sessionId, - executionContext, - futureMappingCheck, - resultIndex, - queryExecutionTimeout, - queryWaitTimeMillis) - - verificationResult = returnedVerificationResult - finalizeCommand( - dataToWrite, - flintStatement, - resultIndex, - flintSessionIndexUpdater, - osClient, - statementTimerContext) - // last query finish time is last activity time - lastActivityTime = currentTimeProvider.currentEpochMillis() + val (dataToWrite, returnedVerificationResult) = + processStatementOnVerification(flintStatement, state, context) + + verificationResult = returnedVerificationResult + finalizeStatement(context, dataToWrite, flintStatement, statementTimerContext) + // last query finish time is last activity time + lastActivityTime = currentTimeProvider.currentEpochMillis() + + case None => + canProceed = false + } } } @@ -613,21 +529,19 @@ object FlintREPL extends Logging with FlintJobExecutor { * @param flintSessionIndexUpdater * flint session index updater */ - private def finalizeCommand( + private def finalizeStatement( + context: StatementExecutionContext, dataToWrite: Option[DataFrame], flintStatement: FlintStatement, - resultIndex: String, - flintSessionIndexUpdater: OpenSearchUpdater, - osClient: OSClient, statementTimerContext: Timer.Context): Unit = { + import context._ + try { - dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient)) + dataToWrite.foreach(df => queryResultWriter.persistQueryResult(df, flintStatement)) if (flintStatement.isRunning || flintStatement.isWaiting) { // we have set failed state in exception handling flintStatement.complete() } - updateSessionIndex(flintStatement, flintSessionIndexUpdater) - recordStatementStateChange(flintStatement, statementTimerContext) } catch { // e.g., maybe due to authentication service connection issue // or invalid catalog (e.g., we are operating on data not defined in provided data source) @@ -635,18 +549,18 @@ object FlintREPL extends Logging with FlintJobExecutor { val error = s"""Fail to write result of ${flintStatement}, cause: ${e.getMessage}""" CustomLogging.logError(error, e) flintStatement.fail() - updateSessionIndex(flintStatement, flintSessionIndexUpdater) - recordStatementStateChange(flintStatement, statementTimerContext) + } finally { + statementLifecycleManager.updateStatement(flintStatement) + 
recordStatementStateChange(flintStatement, statementTimerContext) } } - private def handleCommandTimeout( + private def handleStatementTimeout( spark: SparkSession, dataSource: String, error: String, flintStatement: FlintStatement, - sessionId: String, - startTime: Long): Option[DataFrame] = { + sessionId: String) = { /* * https://tinyurl.com/2ezs5xj9 * @@ -661,131 +575,89 @@ object FlintREPL extends Logging with FlintJobExecutor { */ spark.sparkContext.cancelJobGroup(flintStatement.queryId) Some( - handleCommandFailureAndGetFailedData( - spark, - dataSource, - error, - flintStatement, - sessionId, - startTime)) + handleStatementFailureAndGetFailedData(spark, dataSource, error, flintStatement, sessionId)) } def executeAndHandle( - spark: SparkSession, flintStatement: FlintStatement, - dataSource: String, - sessionId: String, - executionContext: ExecutionContextExecutor, - startTime: Long, - queryExecuitonTimeOut: Duration, - queryWaitTimeMillis: Long): Option[DataFrame] = { + state: InMemoryQueryExecutionState, + context: StatementExecutionContext): Option[DataFrame] = { + import context._ + try { - Some( - executeQueryAsync( - spark, - flintStatement, - dataSource, - sessionId, - executionContext, - startTime, - queryExecuitonTimeOut, - queryWaitTimeMillis)) + Some(executeQueryAsync(flintStatement, state, context)) } catch { case e: TimeoutException => val error = s"Executing ${flintStatement.query} timed out" CustomLogging.logError(error, e) - handleCommandTimeout(spark, dataSource, error, flintStatement, sessionId, startTime) + handleStatementTimeout(spark, dataSource, error, flintStatement, sessionId) case e: Exception => val error = processQueryException(e, flintStatement) Some( - handleCommandFailureAndGetFailedData( + handleStatementFailureAndGetFailedData( spark, dataSource, error, flintStatement, - sessionId, - startTime)) + sessionId)) } } private def processStatementOnVerification( - recordedVerificationResult: VerificationResult, - spark: SparkSession, flintStatement: FlintStatement, - dataSource: String, - sessionId: String, - executionContext: ExecutionContextExecutor, - futureMappingCheck: Future[Either[String, Unit]], - resultIndex: String, - queryExecutionTimeout: Duration, - queryWaitTimeMillis: Long) = { - val startTime: Long = currentTimeProvider.currentEpochMillis() + state: InMemoryQueryExecutionState, + context: StatementExecutionContext): (Option[DataFrame], VerificationResult) = { + import context._ + import state._ + + flintStatement.queryStartTime = Some(currentTimeProvider.currentEpochMillis()) var verificationResult = recordedVerificationResult var dataToWrite: Option[DataFrame] = None verificationResult match { case NotVerified => try { - ThreadUtils.awaitResult(futureMappingCheck, MAPPING_CHECK_TIMEOUT) match { + ThreadUtils.awaitResult(futurePrepareQueryExecution, PREPARE_QUERY_EXEC_TIMEOUT) match { case Right(_) => - dataToWrite = executeAndHandle( - spark, - flintStatement, - dataSource, - sessionId, - executionContext, - startTime, - queryExecutionTimeout, - queryWaitTimeMillis) + dataToWrite = executeAndHandle(flintStatement, state, context) verificationResult = VerifiedWithoutError case Left(error) => verificationResult = VerifiedWithError(error) dataToWrite = Some( - handleCommandFailureAndGetFailedData( + handleStatementFailureAndGetFailedData( spark, dataSource, error, flintStatement, - sessionId, - startTime)) + sessionId)) } } catch { case e: TimeoutException => - val error = s"Getting the mapping of index $resultIndex timed out" + val error = 
s"Query execution preparation timed out" CustomLogging.logError(error, e) dataToWrite = - handleCommandTimeout(spark, dataSource, error, flintStatement, sessionId, startTime) + handleStatementTimeout(spark, dataSource, error, flintStatement, sessionId) case NonFatal(e) => val error = s"An unexpected error occurred: ${e.getMessage}" CustomLogging.logError(error, e) dataToWrite = Some( - handleCommandFailureAndGetFailedData( + handleStatementFailureAndGetFailedData( spark, dataSource, error, flintStatement, - sessionId, - startTime)) + sessionId)) } case VerifiedWithError(err) => dataToWrite = Some( - handleCommandFailureAndGetFailedData( + handleStatementFailureAndGetFailedData( spark, dataSource, err, flintStatement, - sessionId, - startTime)) + sessionId)) case VerifiedWithoutError => - dataToWrite = executeAndHandle( - spark, - flintStatement, - dataSource, - sessionId, - executionContext, - startTime, - queryExecutionTimeout, - queryWaitTimeMillis) + dataToWrite = executeAndHandle(flintStatement, state, context) } logInfo(s"command complete: $flintStatement") @@ -793,102 +665,49 @@ object FlintREPL extends Logging with FlintJobExecutor { } def executeQueryAsync( - spark: SparkSession, flintStatement: FlintStatement, - dataSource: String, - sessionId: String, - executionContext: ExecutionContextExecutor, - startTime: Long, - queryExecutionTimeOut: Duration, - queryWaitTimeMillis: Long): DataFrame = { + state: InMemoryQueryExecutionState, + context: StatementExecutionContext): DataFrame = { + import context._ + import state._ + if (currentTimeProvider .currentEpochMillis() - flintStatement.submitTime > queryWaitTimeMillis) { - handleCommandFailureAndGetFailedData( + handleStatementFailureAndGetFailedData( spark, dataSource, "wait timeout", flintStatement, - sessionId, - startTime) + sessionId) } else { val futureQueryExecution = Future { - executeQuery( - spark, - flintStatement.query, - dataSource, - flintStatement.queryId, - sessionId, - false) - }(executionContext) + executeQuery(flintStatement, context) + }(futureExecutor) // time out after 10 minutes - ThreadUtils.awaitResult(futureQueryExecution, queryExecutionTimeOut) + ThreadUtils.awaitResult(futureQueryExecution, queryExecutionTimeout) } } - private def processCommandInitiation( - flintReader: FlintReader, - flintSessionIndexUpdater: OpenSearchUpdater): FlintStatement = { - val command = flintReader.next() - logDebug(s"raw command: $command") - val flintStatement = FlintStatement.deserialize(command) - logDebug(s"command: $flintStatement") - flintStatement.running() - logDebug(s"command running: $flintStatement") - updateSessionIndex(flintStatement, flintSessionIndexUpdater) - statementRunningCount.incrementAndGet() - flintStatement - } - private def createQueryReader( - osClient: OSClient, - sessionId: String, - sessionIndex: String, - dataSource: String) = { - // all state in index are in lower case - // we only search for statement submitted in the last hour in case of unexpected bugs causing infinite loop in the - // same doc - val dsl = - s"""{ - | "bool": { - | "must": [ - | { - | "term": { - | "type": "statement" - | } - | }, - | { - | "term": { - | "state": "waiting" - | } - | }, - | { - | "term": { - | "sessionId": "$sessionId" - | } - | }, - | { - | "term": { - | "dataSourceName": "$dataSource" - | } - | }, - | { - | "range": { - | "submitTime": { "gte": "now-1h" } - | } - | } - | ] - | } - |}""".stripMargin - - val flintReader = osClient.createQueryReader(sessionIndex, dsl, "submitTime", SortOrder.ASC) - flintReader + 
def executeQuery( + flintStatement: FlintStatement, + context: StatementExecutionContext): DataFrame = { + import context._ + // reset start time + flintStatement.queryStartTime = Some(System.currentTimeMillis()) + // we have to set job group in the same thread that started the query according to spark doc + spark.sparkContext.setJobGroup( + flintStatement.queryId, + "Job group for " + flintStatement.queryId, + interruptOnCancel = true) + // Execute SQL query + val result: DataFrame = spark.sql(flintStatement.query) + queryResultWriter.reformatQueryResult(result, flintStatement, context) } class PreShutdownListener( - flintSessionIndexUpdater: OpenSearchUpdater, - osClient: OSClient, - sessionIndex: String, sessionId: String, + sessionManager: SessionManager, sessionTimerContext: Timer.Context) extends SparkListener with Logging { @@ -896,54 +715,29 @@ object FlintREPL extends Logging with FlintJobExecutor { override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { logInfo("Shutting down REPL") logInfo("earlyExitFlag: " + earlyExitFlag) - val getResponse = osClient.getDoc(sessionIndex, sessionId) - if (!getResponse.isExists()) { - return - } - val source = getResponse.getSourceAsMap - if (source == null) { - return - } - - val state = Option(source.get("state")).map(_.asInstanceOf[String]) - // It's essential to check the earlyExitFlag before marking the session state as 'dead'. When this flag is true, - // it indicates that the control plane has already initiated a new session to handle remaining requests for the - // current session. In our SQL setup, marking a session as 'dead' automatically triggers the creation of a new - // session. However, the newly created session (initiated by the control plane) will enter a spin-wait state, - // where it inefficiently waits for certain conditions to be met, leading to unproductive resource consumption - // and eventual timeout. To avoid this issue and prevent the creation of redundant sessions by SQL, we ensure - // the session state is not set to 'dead' when earlyExitFlag is true, thereby preventing unnecessary duplicate - // processing. - if (!earlyExitFlag && state.isDefined && state.get != "dead" && state.get != "fail") { - updateFlintInstanceBeforeShutdown( - source, - getResponse, - flintSessionIndexUpdater, - sessionId, - sessionTimerContext) + try { + sessionManager.getSessionDetails(sessionId).foreach { sessionDetails => + // It's essential to check the earlyExitFlag before marking the session state as 'dead'. When this flag is true, + // it indicates that the control plane has already initiated a new session to handle remaining requests for the + // current session. In our SQL setup, marking a session as 'dead' automatically triggers the creation of a new + // session. However, the newly created session (initiated by the control plane) will enter a spin-wait state, + // where it inefficiently waits for certain conditions to be met, leading to unproductive resource consumption + // and eventual timeout. To avoid this issue and prevent the creation of redundant sessions by SQL, we ensure + // the session state is not set to 'dead' when earlyExitFlag is true, thereby preventing unnecessary duplicate + // processing. 
+ if (!earlyExitFlag && !sessionDetails.isDead && !sessionDetails.isFail) { + sessionDetails.state = SessionStates.DEAD + sessionManager.updateSessionDetails(sessionDetails, UPDATE_IF) + recordSessionSuccess(sessionTimerContext) + } + } + } catch { + case e: Exception => logError(s"Failed to update session state for $sessionId", e) } } } - private def updateFlintInstanceBeforeShutdown( - source: java.util.Map[String, AnyRef], - getResponse: GetResponse, - flintSessionIndexUpdater: OpenSearchUpdater, - sessionId: String, - sessionTimerContext: Timer.Context): Unit = { - val flintInstance = InteractiveSession.deserializeFromMap(source) - flintInstance.state = "dead" - flintSessionIndexUpdater.updateIf( - sessionId, - InteractiveSession.serializeWithoutJobId( - flintInstance, - currentTimeProvider.currentEpochMillis()), - getResponse.getSeqNo, - getResponse.getPrimaryTerm) - recordSessionSuccess(sessionTimerContext) - } - /** * Create a new thread to update the last update time of the flint instance. * @param currentInterval @@ -960,13 +754,11 @@ object FlintREPL extends Logging with FlintJobExecutor { * the intial delay to start heartbeat */ def createHeartBeatUpdater( - currentInterval: Long, - flintSessionUpdater: OpenSearchUpdater, sessionId: String, - threadPool: ScheduledExecutorService, - osClient: OSClient, - sessionIndex: String, - initialDelayMillis: Long): ScheduledFuture[_] = { + sessionManager: SessionManager, + currentInterval: Long, + initialDelayMillis: Long, + threadPool: ScheduledExecutorService): ScheduledFuture[_] = { threadPool.scheduleAtFixedRate( new Runnable { @@ -978,12 +770,7 @@ object FlintREPL extends Logging with FlintJobExecutor { return // Exit the run method if the thread is interrupted } - flintSessionUpdater.upsert( - sessionId, - Serialization.write( - Map( - "lastUpdateTime" -> currentTimeProvider.currentEpochMillis(), - "state" -> "running"))) + sessionManager.recordHeartbeat(sessionId) } catch { case ie: InterruptedException => // Preserve the interrupt status @@ -1020,62 +807,36 @@ object FlintREPL extends Logging with FlintJobExecutor { * whether we can start fetching next statement or not */ def canPickNextStatement( + sessionManager: SessionManager, sessionId: String, - jobId: String, - osClient: OSClient, - sessionIndex: String): Boolean = { + jobId: String): Boolean = { try { - val getResponse = osClient.getDoc(sessionIndex, sessionId) - if (getResponse.isExists()) { - val source = getResponse.getSourceAsMap - if (source == null) { - logError(s"""Session id ${sessionId} is empty""") - // still proceed since we are not sure what happened (e.g., OpenSearch cluster may be unresponsive) - return true - } - - val runJobId = Option(source.get("jobId")).map(_.asInstanceOf[String]).orNull - val excludeJobIds: Seq[String] = parseExcludedJobIds(source) - - if (runJobId != null && jobId != runJobId) { - logInfo(s"""the current job ID ${jobId} is not the running job ID ${runJobId}""") - return false - } - if (excludeJobIds != null && excludeJobIds.contains(jobId)) { - logInfo(s"""${jobId} is in the list of excluded jobs""") - return false - } - true - } else { - // still proceed since we are not sure what happened (e.g., session doc may not be available yet) - logError(s"""Fail to find id ${sessionId} from session index""") - true + sessionManager.getSessionDetails(sessionId) match { + case Some(sessionDetails) => + val runJobId = sessionDetails.jobId + val excludeJobIds = sessionDetails.excludedJobIds + + if (!runJobId.isEmpty && jobId != runJobId) { + 
logInfo(s"the current job ID $jobId is not the running job ID ${runJobId}") + return false + } + if (excludeJobIds.contains(jobId)) { + logInfo(s"$jobId is in the list of excluded jobs") + return false + } + true + case None => + logError(s"Failed to fetch sessionDetails by sessionId: $sessionId.") + true } } catch { - // still proceed since we are not sure what happened (e.g., OpenSearch cluster may be unresponsive) + // still proceed with exception case e: Exception => CustomLogging.logError(s"""Fail to find id ${sessionId} from session index.""", e) true } } - private def parseExcludedJobIds(source: java.util.Map[String, AnyRef]): Seq[String] = { - - val rawExcludeJobIds = source.get("excludeJobIds") - Option(rawExcludeJobIds) - .map { - case s: String => Seq(s) - case list: java.util.List[_] @unchecked => - import scala.collection.JavaConverters._ - list.asScala.toList - .collect { case str: String => str } // Collect only strings from the list - case other => - logInfo(s"Unexpected type: ${other.getClass.getName}") - Seq.empty - } - .getOrElse(Seq.empty[String]) // In case of null, return an empty Seq - } - def exponentialBackoffRetry[T](maxRetries: Int, initialDelay: FiniteDuration)( block: => T): T = { var retries = 0 @@ -1113,6 +874,46 @@ object FlintREPL extends Logging with FlintJobExecutor { result.getOrElse(throw new RuntimeException("Failed after retries")) } + private def instantiateProvider[T](defaultProvider: => T, providerClassName: String): T = { + if (providerClassName.isEmpty) { + defaultProvider + } else { + try { + val providerClass = Utils.classForName(providerClassName) + val ctor = providerClass.getDeclaredConstructor() + ctor.setAccessible(true) + ctor.newInstance().asInstanceOf[T] + } catch { + case e: Exception => + throw new RuntimeException(s"Failed to instantiate provider: $providerClassName", e) + } + } + } + + private def instantiateSessionManager( + spark: SparkSession, + resultIndex: Option[String]): SessionManager = { + instantiateProvider( + new SessionManagerImpl(spark, resultIndex), + spark.conf.get(FlintSparkConf.CUSTOM_SESSION_MANAGER.key)) + } + + private def instantiateQueryExecutionManager( + spark: SparkSession, + context: Map[String, Any]): StatementLifecycleManager = { + instantiateProvider( + new StatementLifecycleManagerImpl(context), + spark.conf.get(FlintSparkConf.CUSTOM_STATEMENT_MANAGER.key)) + } + + private def instantiateQueryResultWriter( + spark: SparkSession, + context: Map[String, Any]): QueryResultWriter = { + instantiateProvider( + new QueryResultWriterImpl(context), + spark.conf.get(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key)) + } + private def recordSessionSuccess(sessionTimerContext: Timer.Context): Unit = { logInfo("Session Success") stopTimer(sessionTimerContext) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/InMemoryQueryExecutionState.scala similarity index 61% rename from spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala rename to spark-sql-application/src/main/scala/org/apache/spark/sql/InMemoryQueryExecutionState.scala index ad49201f0..4f87dd643 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/InMemoryQueryExecutionState.scala @@ -9,10 +9,9 @@ import scala.concurrent.{ExecutionContextExecutor, Future} import org.opensearch.flint.core.storage.FlintReader -case class CommandState( +case 
class InMemoryQueryExecutionState( recordedLastActivityTime: Long, + recordedLastCanPickCheckTime: Long, recordedVerificationResult: VerificationResult, - flintReader: FlintReader, - futureMappingCheck: Future[Either[String, Unit]], - executionContext: ExecutionContextExecutor, - recordedLastCanPickCheckTime: Long) + futurePrepareQueryExecution: Future[Either[String, Unit]], + futureExecutor: ExecutionContextExecutor) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala index f315dc836..d386a5c31 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala @@ -57,19 +57,17 @@ case class JobOperator( dataToWrite = Some(mappingCheckResult match { case Right(_) => data case Left(error) => - getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider) + getFailedData(spark, dataSource, error, "", query, "", startTime) }) exceptionThrown = false } catch { case e: TimeoutException => val error = s"Getting the mapping of index $resultIndex timed out" logError(error, e) - dataToWrite = Some( - getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider)) + dataToWrite = Some(getFailedData(spark, dataSource, error, "", query, "", startTime)) case e: Exception => val error = processQueryException(e) - dataToWrite = Some( - getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider)) + dataToWrite = Some(getFailedData(spark, dataSource, error, "", query, "", startTime)) } finally { cleanUpResources(exceptionThrown, threadPool, dataToWrite, resultIndex, osClient) } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala new file mode 100644 index 000000000..830fbae68 --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.opensearch.flint.data.FlintStatement + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.util.CleanerFactory + +class QueryResultWriterImpl(context: Map[String, Any]) + extends QueryResultWriter + with FlintJobExecutor + with Logging { + + val resultIndex = context("resultIndex").asInstanceOf[String] + val osClient = context("osClient").asInstanceOf[OSClient] + + override def reformatQueryResult( + dataFrame: DataFrame, + flintStatement: FlintStatement, + queryExecutionContext: StatementExecutionContext): DataFrame = { + import queryExecutionContext._ + getFormattedData( + dataFrame, + spark, + dataSource, + flintStatement.queryId, + flintStatement.query, + sessionId, + flintStatement.queryStartTime.get, + CleanerFactory.cleaner(false)) + } + + override def persistQueryResult(dataFrame: DataFrame, flintStatement: FlintStatement): Unit = { + writeDataFrameToOpensearch(dataFrame, resultIndex, osClient) + } +} diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala new file mode 100644 index 000000000..6775c9067 --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala @@ -0,0 
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.apache.spark.sql
+
+import scala.util.{Failure, Success, Try}
+
+import org.json4s.native.Serialization
+import org.opensearch.flint.core.logging.CustomLogging
+import org.opensearch.flint.core.storage.FlintReader
+import org.opensearch.flint.data.{FlintStatement, InteractiveSession}
+import org.opensearch.flint.data.InteractiveSession.formats
+import org.opensearch.search.sort.SortOrder
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SessionUpdateMode.SessionUpdateMode
+import org.apache.spark.sql.flint.config.FlintSparkConf
+
+class SessionManagerImpl(spark: SparkSession, resultIndex: Option[String])
+    extends SessionManager
+    with FlintJobExecutor
+    with Logging {
+
+  // No default values are allowed for sessionIndex, sessionId, and dataSource; throw if a key is not found.
+  val sessionIndex: String = spark.conf.get(FlintSparkConf.REQUEST_INDEX.key)
+  val sessionId: String = spark.conf.get(FlintSparkConf.SESSION_ID.key)
+  val dataSource: String = spark.conf.get(FlintSparkConf.DATA_SOURCE_NAME.key)
+
+  if (sessionIndex.isEmpty) {
+    logAndThrow(FlintSparkConf.REQUEST_INDEX.key + " is not set")
+  }
+  if (resultIndex.isEmpty) {
+    logAndThrow("resultIndex is not set")
+  }
+  if (sessionId.isEmpty) {
+    logAndThrow(FlintSparkConf.SESSION_ID.key + " is not set")
+  }
+  if (dataSource.isEmpty) {
+    logAndThrow(FlintSparkConf.DATA_SOURCE_NAME.key + " is not set")
+  }
+
+  val osClient = new OSClient(FlintSparkConf().flintOptions())
+  val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex)
+  val flintReader: FlintReader = createOpenSearchQueryReader()
+
+  override def getSessionManagerMetadata: Map[String, Any] = {
+    Map(
+      "resultIndex" -> resultIndex.get,
+      "osClient" -> osClient,
+      "flintSessionIndexUpdater" -> flintSessionIndexUpdater,
+      "flintReader" -> flintReader)
+  }
+
+  override def getSessionDetails(sessionId: String): Option[InteractiveSession] = {
+    Try(osClient.getDoc(sessionIndex, sessionId)) match {
+      case Success(getResponse) if getResponse.isExists =>
+        Option(getResponse.getSourceAsMap)
+          .map(InteractiveSession.deserializeFromMap)
+      case Failure(exception) =>
+        CustomLogging.logError(
+          s"Failed to retrieve existing InteractiveSession: ${exception.getMessage}",
+          exception)
+        None
+      case _ => None
+    }
+  }
+
+  override def updateSessionDetails(
+      sessionDetails: InteractiveSession,
+      sessionUpdateMode: SessionUpdateMode): Unit = {
+    sessionUpdateMode match {
+      case SessionUpdateMode.UPDATE =>
+        flintSessionIndexUpdater.update(
+          sessionDetails.sessionId,
+          InteractiveSession.serialize(sessionDetails, currentTimeProvider.currentEpochMillis()))
+      case SessionUpdateMode.UPSERT =>
+        val includeJobId =
+          !sessionDetails.excludedJobIds.isEmpty && !sessionDetails.excludedJobIds.contains(
+            sessionDetails.jobId)
+        val serializedSession = if (includeJobId) {
+          InteractiveSession.serialize(
+            sessionDetails,
+            currentTimeProvider.currentEpochMillis(),
+            true)
+        } else {
+          InteractiveSession.serializeWithoutJobId(
+            sessionDetails,
+            currentTimeProvider.currentEpochMillis())
+        }
+        flintSessionIndexUpdater.upsert(sessionDetails.sessionId, serializedSession)
+      case SessionUpdateMode.UPDATE_IF =>
+        val seqNo = sessionDetails
+          .getContextValue("_seq_no")
+          .getOrElse(throw new IllegalArgumentException("Missing _seq_no for conditional update"))
+          .asInstanceOf[Long]
+        val primaryTerm = sessionDetails
+          .getContextValue("_primary_term")
+          .getOrElse(
+ throw new IllegalArgumentException("Missing _primary_term for conditional update")) + .asInstanceOf[Long] + flintSessionIndexUpdater.updateIf( + sessionDetails.sessionId, + InteractiveSession.serializeWithoutJobId( + sessionDetails, + currentTimeProvider.currentEpochMillis()), + seqNo, + primaryTerm) + } + + logInfo( + s"""Updated job: {"jobid": ${sessionDetails.jobId}, "sessionId": ${sessionDetails.sessionId}} from $sessionIndex""") + } + + override def getNextStatement(sessionId: String): Option[FlintStatement] = { + if (flintReader.hasNext) { + val rawStatement = flintReader.next() + logDebug(s"raw statement: $rawStatement") + val flintStatement = FlintStatement.deserialize(rawStatement) + logDebug(s"statement: $flintStatement") + Some(flintStatement) + } else { + None + } + } + + override def recordHeartbeat(sessionId: String): Unit = { + flintSessionIndexUpdater.upsert( + sessionId, + Serialization.write( + Map("lastUpdateTime" -> currentTimeProvider.currentEpochMillis(), "state" -> "running"))) + } + + private def createOpenSearchQueryReader() = { + // all state in index are in lower case + // we only search for statement submitted in the last hour in case of unexpected bugs causing infinite loop in the + // same doc + val dsl = + s"""{ + | "bool": { + | "must": [ + | { + | "term": { + | "type": "statement" + | } + | }, + | { + | "term": { + | "state": "waiting" + | } + | }, + | { + | "term": { + | "sessionId": "$sessionId" + | } + | }, + | { + | "term": { + | "dataSourceName": "$dataSource" + | } + | }, + | { + | "range": { + | "submitTime": { "gte": "now-1h" } + | } + | } + | ] + | } + |}""".stripMargin + + val flintReader = osClient.createQueryReader(sessionIndex, dsl, "submitTime", SortOrder.ASC) + flintReader + } +} diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementLifecycleManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementLifecycleManagerImpl.scala new file mode 100644 index 000000000..8e6e8a644 --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementLifecycleManagerImpl.scala @@ -0,0 +1,54 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} +import org.opensearch.flint.data.FlintStatement + +import org.apache.spark.internal.Logging + +class StatementLifecycleManagerImpl(context: Map[String, Any]) + extends StatementLifecycleManager + with FlintJobExecutor + with Logging { + + val resultIndex = context("resultIndex").asInstanceOf[String] + val osClient = context("osClient").asInstanceOf[OSClient] + val flintSessionIndexUpdater = + context("flintSessionIndexUpdater").asInstanceOf[OpenSearchUpdater] + val flintReader = context("flintReader").asInstanceOf[FlintReader] + + override def prepareStatementLifecycle(): Either[String, Unit] = { + try { + val existingSchema = osClient.getIndexMetadata(resultIndex) + if (!isSuperset(existingSchema, resultIndexMapping)) { + Left(s"The mapping of $resultIndex is incorrect.") + } else { + Right(()) + } + } catch { + case e: IllegalStateException + if e.getCause != null && + e.getCause.getMessage.contains("index_not_found_exception") => + createResultIndex(osClient, resultIndex, resultIndexMapping) + case e: InterruptedException => + val error = s"Interrupted by the main thread: ${e.getMessage}" + Thread.currentThread().interrupt() // Preserve the interrupt status + logError(error, e) + Left(error) + 
+      case e: Exception =>
+        val error = s"Failed to verify existing mapping: ${e.getMessage}"
+        logError(error, e)
+        Left(error)
+    }
+  }
+  override def updateStatement(statement: FlintStatement): Unit = {
+    flintSessionIndexUpdater.update(statement.statementId, FlintStatement.serialize(statement))
+  }
+  override def terminateStatementLifecycle(): Unit = {
+    flintReader.close()
+  }
+}
diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala
index 19f596e31..a35cb6590 100644
--- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala
+++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala
@@ -79,7 +79,6 @@ class FlintJobTest extends SparkFunSuite with JobMatchers {
       "select 1",
       "20",
       currentTime - queryRunTime,
-      new MockTimeProvider(currentTime),
       CleanerFactory.cleaner(false))
     assertEqualDataframe(expected, result)
   }
diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
index 546cd8e97..0918af849 100644
--- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
+++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
@@ -23,7 +23,7 @@ import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
 import org.opensearch.action.get.GetResponse
 import org.opensearch.flint.core.storage.{FlintReader, OpenSearchReader, OpenSearchUpdater}
-import org.opensearch.flint.data.FlintStatement
+import org.opensearch.flint.data.{FlintStatement, InteractiveSession, SessionStates}
 import org.opensearch.search.sort.SortOrder
 import org.scalatestplus.mockito.MockitoSugar
 
@@ -50,7 +50,7 @@ class FlintREPLTest
     val args = Array("resultIndexName")
     val (queryOption, resultIndex) = FlintREPL.parseArgs(args)
     queryOption shouldBe None
-    resultIndex shouldBe "resultIndexName"
+    resultIndex shouldBe Some("resultIndexName")
   }
 
   test(
@@ -58,16 +58,15 @@
     val args = Array("SELECT * FROM table", "resultIndexName")
     val (queryOption, resultIndex) = FlintREPL.parseArgs(args)
     queryOption shouldBe Some("SELECT * FROM table")
-    resultIndex shouldBe "resultIndexName"
+    resultIndex shouldBe Some("resultIndexName")
   }
 
   test(
     "parseArgs with no arguments should throw IllegalArgumentException with specific message") {
     val args = Array.empty[String]
-    val exception = intercept[IllegalArgumentException] {
-      FlintREPL.parseArgs(args)
-    }
-    exception.getMessage shouldBe "Unsupported number of arguments. Expected 1 or 2 arguments."
+    val (queryOption, resultIndex) = FlintREPL.parseArgs(args)
+    queryOption shouldBe None
+    resultIndex shouldBe None
   }
 
   test(
@@ -76,7 +75,7 @@
     val exception = intercept[IllegalArgumentException] {
       FlintREPL.parseArgs(args)
     }
-    exception.getMessage shouldBe "Unsupported number of arguments. Expected 1 or 2 arguments."
+    exception.getMessage shouldBe "Unsupported number of arguments. Expected no more than two arguments."
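    // Taken together, the updated parseArgs tests encode the new contract: zero args -> (None, None),
    // one arg -> (None, Some(resultIndex)), two args -> (Some(query), Some(resultIndex)),
    // and three or more args -> IllegalArgumentException.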
} test("getQuery should return query from queryOption if present") { @@ -131,19 +130,19 @@ class FlintREPLTest test("createHeartBeatUpdater should update heartbeat correctly") { // Mocks - val flintSessionUpdater = mock[OpenSearchUpdater] - val osClient = mock[OSClient] val threadPool = mock[ScheduledExecutorService] - val getResponse = mock[GetResponse] val scheduledFutureRaw = mock[ScheduledFuture[_]] - + val sessionManager = mock[SessionManager] + val sessionId = "session1" + val currentInterval = 1000L + val initialDelayMillis = 0L // when scheduled task is scheduled, execute the runnable immediately only once and become no-op afterwards. when( threadPool.scheduleAtFixedRate( any[Runnable], - eqTo(0), - *, - eqTo(java.util.concurrent.TimeUnit.MILLISECONDS))) + eqTo(initialDelayMillis), + eqTo(currentInterval), + eqTo(TimeUnit.MILLISECONDS))) .thenAnswer((invocation: InvocationOnMock) => { val runnable = invocation.getArgument[Runnable](0) runnable.run() @@ -152,54 +151,42 @@ class FlintREPLTest // Invoke the method FlintREPL.createHeartBeatUpdater( - 1000L, - flintSessionUpdater, - "session1", - threadPool, - osClient, - "sessionIndex", - 0) - + sessionId, + sessionManager, + currentInterval, + initialDelayMillis, + threadPool) // Verifications - verify(flintSessionUpdater, atLeastOnce()).upsert(eqTo("session1"), *) + verify(sessionManager).recordHeartbeat(sessionId) } test("PreShutdownListener updates FlintInstance if conditions are met") { // Mock dependencies - val osClient = mock[OSClient] - val getResponse = mock[GetResponse] - val flintSessionIndexUpdater = mock[OpenSearchUpdater] - val sessionIndex = "testIndex" val sessionId = "testSessionId" val timerContext = mock[Timer.Context] + val sessionManager = mock[SessionManager] - // Setup the getDoc to return a document indicating the session is running - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) - when(getResponse.isExists()).thenReturn(true) - when(getResponse.getSourceAsMap).thenReturn( - Map[String, Object]( - "applicationId" -> "app1", - "jobId" -> "job1", - "sessionId" -> "session1", - "state" -> "running", - "lastUpdateTime" -> java.lang.Long.valueOf(12345L), - "error" -> "someError", - "state" -> "running", - "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava) + val interactiveSession = new InteractiveSession( + "app123", + "job123", + sessionId, + SessionStates.RUNNING, + System.currentTimeMillis(), + System.currentTimeMillis() - 10000 + ) + when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) // Instantiate the listener - val listener = new PreShutdownListener( - flintSessionIndexUpdater, - osClient, - sessionIndex, - sessionId, - timerContext) + val listener = new PreShutdownListener(sessionId, sessionManager, timerContext) // Simulate application end listener.onApplicationEnd(SparkListenerApplicationEnd(System.currentTimeMillis())) - // Verify the update is called with the correct arguments - verify(flintSessionIndexUpdater).updateIf(*, *, *, *) + verify(sessionManager).updateSessionDetails( + interactiveSession, + SessionUpdateMode.UPDATE_IF + ) + interactiveSession.state shouldBe SessionStates.DEAD } test("Test getFailedData method") { @@ -249,16 +236,16 @@ class FlintREPLTest try { FlintREPL.currentTimeProvider = new MockTimeProvider(currentTime) - + // Simulate setting query start time + flintStatement.queryStartTime = Some(currentTime - queryRunTime) // Compare the result val result = - FlintREPL.handleCommandFailureAndGetFailedData( + 
FlintREPL.handleStatementFailureAndGetFailedData( spark, dataSourceName, error, flintStatement, - "20", - currentTime - queryRunTime) + "20") assertEqualDataframe(expected, result) assert("failed" == flintStatement.state) assert(error == flintStatement.error.get) @@ -269,21 +256,22 @@ class FlintREPLTest } } - test("test canPickNextStatement: Doc Exists and Valid JobId") { + test("test canPickNextStatement: Valid jobId") { val sessionId = "session123" val jobId = "jobABC" - val osClient = mock[OSClient] - val sessionIndex = "sessionIndex" + val sessionManager = mock[SessionManager] - val getResponse = mock[GetResponse] - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) - when(getResponse.isExists()).thenReturn(true) - - val sourceMap = new java.util.HashMap[String, Object]() - sourceMap.put("jobId", jobId.asInstanceOf[Object]) - when(getResponse.getSourceAsMap).thenReturn(sourceMap) + val interactiveSession = new InteractiveSession( + "app123", + jobId, + sessionId, + SessionStates.RUNNING, + System.currentTimeMillis(), + System.currentTimeMillis() - 10000 + ) + when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionManager, sessionId, jobId) assert(result) } @@ -292,19 +280,20 @@ class FlintREPLTest val sessionId = "session123" val jobId = "jobABC" val differentJobId = "jobXYZ" - val osClient = mock[OSClient] - val sessionIndex = "sessionIndex" - - val getResponse = mock[GetResponse] - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) - when(getResponse.isExists()).thenReturn(true) + val sessionManager = mock[SessionManager] - val sourceMap = new java.util.HashMap[String, Object]() - sourceMap.put("jobId", differentJobId.asInstanceOf[Object]) - when(getResponse.getSourceAsMap).thenReturn(sourceMap) + val interactiveSession = new InteractiveSession( + "app123", + jobId, + sessionId, + SessionStates.RUNNING, + System.currentTimeMillis(), + System.currentTimeMillis() - 10000 + ) + when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) // Execute the method under test - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionManager, sessionId, differentJobId) // Assertions assert(!result) // The function should return false @@ -315,6 +304,7 @@ class FlintREPLTest val jobId = "jobABC" val osClient = mock[OSClient] val sessionIndex = "sessionIndex" + val sessionManager = mock[SessionManager] val getResponse = mock[GetResponse] when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) @@ -328,8 +318,14 @@ class FlintREPLTest sourceMap.put("excludeJobIds", excludeJobIdsList) // But jobId is in the exclude list when(getResponse.getSourceAsMap).thenReturn(sourceMap) + // Mock the InteractiveSession + val interactiveSession = mock[InteractiveSession] + when(interactiveSession.jobId).thenReturn("jobABC") + when(interactiveSession.excludedJobIds).thenReturn(Seq(jobId)) + + when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) // Execute the method under test - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionManager, sessionId, jobId) // Assertions assert(!result) // The function should return false because jobId is excluded @@ -340,6 +336,7 
@@ class FlintREPLTest val jobId = "jobABC" val osClient = mock[OSClient] val sessionIndex = "sessionIndex" + val sessionManager = mock[SessionManager] // Mock the getDoc response val getResponse = mock[GetResponse] @@ -348,7 +345,7 @@ class FlintREPLTest when(getResponse.getSourceAsMap).thenReturn(null) // Simulate the source being null // Execute the method under test - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionManager, sessionId, jobId) // Assertions assert(result) // The function should return true despite the null source @@ -359,7 +356,7 @@ class FlintREPLTest val jobId = "jobABC" val osClient = mock[OSClient] val sessionIndex = "sessionIndex" - + val sessionManager = mock[SessionManager] val getResponse = mock[GetResponse] when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) @@ -373,7 +370,7 @@ class FlintREPLTest when(getResponse.getSourceAsMap).thenReturn(sourceMap) - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionManager, sessionId, jobId) assert(result) // The function should return true } @@ -383,6 +380,7 @@ class FlintREPLTest val jobId = "jobABC" val osClient = mock[OSClient] val sessionIndex = "sessionIndex" + val sessionManager = mock[SessionManager] // Set up the mock GetResponse val getResponse = mock[GetResponse] @@ -390,7 +388,7 @@ class FlintREPLTest when(getResponse.isExists()).thenReturn(false) // Simulate the document does not exist // Execute the function under test - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionManager, sessionId, jobId) // Assert the function returns true assert(result) @@ -401,13 +399,14 @@ class FlintREPLTest val jobId = "jobABC" val osClient = mock[OSClient] val sessionIndex = "sessionIndex" + val sessionManager = mock[SessionManager] // Set up the mock OSClient to throw an exception when(osClient.getDoc(sessionIndex, sessionId)) .thenThrow(new RuntimeException("OpenSearch cluster unresponsive")) // Execute the method under test and expect true, since the method is designed to return true even in case of an exception - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionManager, sessionId, jobId) // Verify the result is true despite the exception assert(result) @@ -420,6 +419,7 @@ class FlintREPLTest val nonMatchingExcludeJobId = "jobXYZ" // This ID does not match the jobId val osClient = mock[OSClient] val sessionIndex = "sessionIndex" + val sessionManager = mock[SessionManager] val getResponse = mock[GetResponse] when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) @@ -433,7 +433,7 @@ class FlintREPLTest when(getResponse.getSourceAsMap).thenReturn(sourceMap) // Execute the method under test - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionManager, sessionId, jobId) // The function should return true since jobId is not excluded assert(result) @@ -449,11 +449,9 @@ class FlintREPLTest exception.setServiceName("AWSGlue") val mockFlintStatement = mock[FlintStatement] - val expectedError = ( - """{"Message":"Fail to read data from Glue. Cause: Access denied in AWS Glue service. Please check permissions. 
(Service: AWSGlue; """ + - """Status Code: 400; Error Code: AccessDeniedException; Request ID: null; Proxy: null)",""" + - """"ErrorSource":"AWSGlue","StatusCode":"400"}""" - ) + val expectedError = """{"Message":"Fail to read data from Glue. Cause: Access denied in AWS Glue service. Please check permissions. (Service: AWSGlue; """ + + """Status Code: 400; Error Code: AccessDeniedException; Request ID: null; Proxy: null)",""" + + """"ErrorSource":"AWSGlue","StatusCode":"400"}""" val result = FlintREPL.processQueryException(exception, mockFlintStatement) @@ -469,7 +467,7 @@ class FlintREPLTest val jobId = "jobABC" val osClient = mock[OSClient] val sessionIndex = "sessionIndex" - val handleSessionError = mock[Function1[String, Unit]] + val sessionManager = mock[SessionManager] val getResponse = mock[GetResponse] when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) @@ -486,8 +484,20 @@ class FlintREPLTest when(getResponse.getSourceAsMap).thenReturn(sourceMap) + // Mock the InteractiveSession + val interactiveSession = new InteractiveSession( + applicationId = "app123", + sessionId = sessionId, + state = "active", + lastUpdateTime = System.currentTimeMillis(), + jobId = "jobABC", + excludedJobIds = Seq(jobId) + ) + + // Mock the sessionManager to return the mocked interactiveSession + when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) // Execute the method under test - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionManager, sessionId, jobId) // The function should return false since jobId is excluded assert(!result) @@ -498,6 +508,7 @@ class FlintREPLTest val jobId = "jobABC" val osClient = mock[OSClient] val sessionIndex = "sessionIndex" + val sessionManager = mock[SessionManager] val getResponse = mock[GetResponse] when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) @@ -515,7 +526,7 @@ class FlintREPLTest when(getResponse.getSourceAsMap).thenReturn(sourceMap) // Execute the method under test - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionManager, sessionId, jobId) // The function should return true since the jobId is not in the excludeJobIds list assert(result) @@ -544,19 +555,20 @@ class FlintREPLTest val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() try { val flintSessionIndexUpdater = mock[OpenSearchUpdater] - - val commandContext = CommandContext( + val sessionManager = mock[SessionManager] + val statementLifecycleManager = mock[StatementLifecycleManager] + val queryResultWriter = mock[QueryResultWriter] + val commandContext = StatementExecutionContext( spark, - dataSource, - resultIndex, - sessionId, - flintSessionIndexUpdater, - osClient, - sessionIndex, jobId, + sessionId, + sessionManager, + statementLifecycleManager, + queryResultWriter, + dataSource, Duration(10, MINUTES), - 60, - 60) + 60L, + 60L) intercept[RuntimeException] { FlintREPL.exponentialBackoffRetry(maxRetries, 2.seconds) { @@ -587,6 +599,9 @@ class FlintREPLTest val sessionId = "someSessionId" val startTime = System.currentTimeMillis() val expectedDataFrame = mock[DataFrame] + val flintStatement = mock[FlintStatement] + val state = mock[InMemoryQueryExecutionState] + val context = mock[StatementExecutionContext] when(mockFlintStatement.query).thenReturn("SELECT 1") 
when(mockFlintStatement.submitTime).thenReturn(Instant.now().toEpochMilli()) @@ -608,16 +623,7 @@ class FlintREPLTest val sparkContext = mock[SparkContext] when(mockSparkSession.sparkContext).thenReturn(sparkContext) - val result = FlintREPL.executeAndHandle( - mockSparkSession, - mockFlintStatement, - dataSource, - sessionId, - executionContext, - startTime, - // make sure it times out before mockSparkSession.sql can return, which takes 60 seconds - Duration(1, SECONDS), - 600000) + val result = FlintREPL.executeAndHandle(flintStatement, state, context) verify(mockSparkSession, times(1)).sql(any[String]) verify(sparkContext, times(1)).cancelJobGroup(any[String]) @@ -646,6 +652,8 @@ class FlintREPLTest val sessionId = "someSessionId" val startTime = System.currentTimeMillis() val expectedDataFrame = mock[DataFrame] + val state = mock[InMemoryQueryExecutionState] + val context = mock[StatementExecutionContext] // sql method can only throw RuntimeException when(mockSparkSession.sql(any[String])).thenThrow( @@ -659,15 +667,7 @@ class FlintREPLTest .thenReturn(expectedDataFrame) when(expectedDataFrame.toDF(any[Seq[String]]: _*)).thenReturn(expectedDataFrame) - val result = FlintREPL.executeAndHandle( - mockSparkSession, - flintStatement, - dataSource, - sessionId, - executionContext, - startTime, - Duration.Inf, // Use Duration.Inf or a large enough duration to avoid a timeout, - 600000) + val result = FlintREPL.executeAndHandle(flintStatement, state, context) // Verify that ParseException was caught and handled result should not be None // or result.isDefined shouldBe true @@ -680,6 +680,13 @@ class FlintREPLTest test("setupFlintJobWithExclusionCheck should proceed normally when no jobs are excluded") { val osClient = mock[OSClient] val getResponse = mock[GetResponse] + val applicationId = "app1" + val jobId = "job1" + val sessionId = "session1" + val jobStartTime = System.currentTimeMillis() + val sessionManager = mock[SessionManager] + val conf = new SparkConf().set("spark.flint.deployment.excludeJobs", "") + when(osClient.getDoc(*, *)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) when(getResponse.getSourceAsMap).thenReturn( @@ -694,47 +701,73 @@ class FlintREPLTest when(getResponse.getSeqNo).thenReturn(0L) when(getResponse.getPrimaryTerm).thenReturn(0L) - val flintSessionIndexUpdater = mock[OpenSearchUpdater] - val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "") + val interactiveSession = new InteractiveSession( + "app123", + jobId, + sessionId, + SessionStates.RUNNING, + lastUpdateTime = java.lang.Long.valueOf(12345L) + ) + + when(sessionManager.getSessionDetails(sessionId)) + .thenReturn(Some(interactiveSession)) // other mock objects like osClient, flintSessionIndexUpdater with necessary mocking val result = FlintREPL.setupFlintJobWithExclusionCheck( - mockConf, - Some("sessionIndex"), - Some("sessionId"), - osClient, - "jobId", - "appId", - flintSessionIndexUpdater, - System.currentTimeMillis()) + conf, + applicationId, + jobId, + sessionId, + sessionManager, + jobStartTime) assert(!result) // Expecting false as the job should proceed normally } test("setupFlintJobWithExclusionCheck should exit early if current job is excluded") { val osClient = mock[OSClient] val getResponse = mock[GetResponse] + val applicationId = "app1" + val jobId = "job1" + val sessionId = "session1" + val jobStartTime = System.currentTimeMillis() + val sessionManager = mock[SessionManager] + when(osClient.getDoc(*, *)).thenReturn(getResponse) 
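    // The osClient/getResponse stubs above are retained from the pre-refactor version of this test;
    // the exclusion decision itself now comes from the InteractiveSession returned by sessionManager below.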
when(getResponse.isExists()).thenReturn(true) // Mock the rest of the GetResponse as needed - val flintSessionIndexUpdater = mock[OpenSearchUpdater] - val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "jobId") + val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "job1") + val interactiveSession = new InteractiveSession( + "app123", + jobId, + sessionId, + SessionStates.RUNNING, + excludedJobIds = Seq(jobId), + lastUpdateTime = java.lang.Long.valueOf(12345L) + ) + + when(sessionManager.getSessionDetails(sessionId)) + .thenReturn(Some(interactiveSession)) val result = FlintREPL.setupFlintJobWithExclusionCheck( mockConf, - Some("sessionIndex"), - Some("sessionId"), - osClient, - "jobId", - "appId", - flintSessionIndexUpdater, - System.currentTimeMillis()) + applicationId, + jobId, + sessionId, + sessionManager, + jobStartTime) assert(result) // Expecting true as the job should exit early } test("setupFlintJobWithExclusionCheck should exit early if a duplicate job is running") { val osClient = mock[OSClient] val getResponse = mock[GetResponse] + val applicationId = "app1" + val jobId = "job1" + val sessionId = "session1" + val jobStartTime = System.currentTimeMillis() + val sessionManager = mock[SessionManager] + when(osClient.getDoc(*, *)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) // Mock the GetResponse to simulate a scenario of a duplicate job @@ -751,71 +784,89 @@ class FlintREPLTest .asList("job-2", "job-1") // Include this inside the Map ).asJava) - val flintSessionIndexUpdater = mock[OpenSearchUpdater] val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "job-1,job-2") + val interactiveSession = new InteractiveSession( + "app1", + jobId, + sessionId, + SessionStates.RUNNING, + excludedJobIds = Seq("job-1", "job-2"), + lastUpdateTime = java.lang.Long.valueOf(12345L) + ) + // Mock sessionManager to return sessionDetails + when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) + val result = FlintREPL.setupFlintJobWithExclusionCheck( mockConf, - Some("sessionIndex"), - Some("sessionId"), - osClient, - "jobId", - "appId", - flintSessionIndexUpdater, - System.currentTimeMillis()) + applicationId, + jobId, + sessionId, + sessionManager, + jobStartTime) assert(result) // Expecting true for early exit due to duplicate job } test("setupFlintJobWithExclusionCheck should setup job normally when conditions are met") { val osClient = mock[OSClient] val getResponse = mock[GetResponse] + val applicationId = "app1" + val jobId = "job1" + val sessionId = "session1" + val jobStartTime = System.currentTimeMillis() + val sessionManager = mock[SessionManager] + when(osClient.getDoc(*, *)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) - val flintSessionIndexUpdater = mock[OpenSearchUpdater] val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "job-3,job-4") + val interactiveSession = new InteractiveSession( + "app1", + jobId, + sessionId, + SessionStates.RUNNING, + excludedJobIds = Seq("job-5", "job-6"), + lastUpdateTime = java.lang.Long.valueOf(12345L) + ) + // Mock sessionManager to return sessionDetails + when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) val result = FlintREPL.setupFlintJobWithExclusionCheck( mockConf, - Some("sessionIndex"), - Some("sessionId"), - osClient, - "jobId", - "appId", - flintSessionIndexUpdater, - System.currentTimeMillis()) + applicationId, + jobId, + sessionId, + 
sessionManager, + jobStartTime) assert(!result) // Expecting false as the job proceeds normally } test( "setupFlintJobWithExclusionCheck should throw NoSuchElementException if sessionIndex or sessionId is missing") { - val osClient = mock[OSClient] - val flintSessionIndexUpdater = mock[OpenSearchUpdater] val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "") + val applicationId = "app1" + val jobId = "job1" + val sessionId = "session1" + val jobStartTime = System.currentTimeMillis() + val sessionManager = mock[SessionManager] + + // Mock sessionManager to return None for session details + when(sessionManager.getSessionDetails(sessionId)).thenReturn(None) assertThrows[NoSuchElementException] { FlintREPL.setupFlintJobWithExclusionCheck( mockConf, - None, // No sessionIndex provided - None, // No sessionId provided - osClient, - "jobId", - "appId", - flintSessionIndexUpdater, - System.currentTimeMillis()) + applicationId, + jobId, + sessionId, + sessionManager, + jobStartTime) } } test("queryLoop continue until inactivity limit is reached") { - val mockReader = mock[FlintReader] - val osClient = mock[OSClient] - when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) - .thenReturn(mockReader) - when(mockReader.hasNext).thenReturn(false) - val resultIndex = "testResultIndex" val dataSource = "testDataSource" - val sessionIndex = "testSessionIndex" val sessionId = "testSessionId" val jobId = "testJobId" @@ -824,25 +875,24 @@ class FlintREPLTest // Create a SparkSession for testing val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() - val flintSessionIndexUpdater = mock[OpenSearchUpdater] + val sessionManager = mock[SessionManager] + val statementLifecycleManager = mock[StatementLifecycleManager] + val queryResultWriter = mock[QueryResultWriter] - val commandContext = CommandContext( + val commandContext = StatementExecutionContext( spark, - dataSource, - resultIndex, - sessionId, - flintSessionIndexUpdater, - osClient, - sessionIndex, jobId, + sessionId, + sessionManager, + statementLifecycleManager, + queryResultWriter, + dataSource, Duration(10, MINUTES), shortInactivityLimit, - 60) + 60L) // Mock processCommands to always allow loop continuation - val getResponse = mock[GetResponse] - when(osClient.getDoc(*, *)).thenReturn(getResponse) - when(getResponse.isExists()).thenReturn(false) + when(sessionManager.getNextStatement(sessionId)).thenReturn(None) val startTime = System.currentTimeMillis() @@ -864,30 +914,29 @@ class FlintREPLTest .thenReturn(mockReader) when(mockReader.hasNext).thenReturn(true) - val resultIndex = "testResultIndex" val dataSource = "testDataSource" val sessionIndex = "testSessionIndex" val sessionId = "testSessionId" val jobId = "testJobId" val longInactivityLimit = 10000 // 10 seconds + val sessionManager = mock[SessionManager] + val statementLifecycleManager = mock[StatementLifecycleManager] + val queryResultWriter = mock[QueryResultWriter] // Create a SparkSession for testing val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() - val flintSessionIndexUpdater = mock[OpenSearchUpdater] - - val commandContext = CommandContext( + val commandContext = StatementExecutionContext( spark, - dataSource, - resultIndex, - sessionId, - flintSessionIndexUpdater, - osClient, - sessionIndex, jobId, + sessionId, + sessionManager, + statementLifecycleManager, + queryResultWriter, + dataSource, Duration(10, MINUTES), - longInactivityLimit, - 60) + 60L, + 
60L) // Mocking canPickNextStatement to return false when(osClient.getDoc(sessionIndex, sessionId)).thenAnswer(_ => { @@ -899,6 +948,9 @@ class FlintREPLTest mockGetResponse }) + // Mock getNextStatement to return None, simulating the end of statements + when(sessionManager.getNextStatement(sessionId)).thenReturn(None) + val startTime = System.currentTimeMillis() FlintREPL.queryLoop(commandContext) @@ -913,15 +965,7 @@ class FlintREPLTest } test("queryLoop should properly shut down the thread pool after execution") { - val mockReader = mock[FlintReader] - val osClient = mock[OSClient] - when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) - .thenReturn(mockReader) - when(mockReader.hasNext).thenReturn(false) - - val resultIndex = "testResultIndex" val dataSource = "testDataSource" - val sessionIndex = "testSessionIndex" val sessionId = "testSessionId" val jobId = "testJobId" @@ -929,21 +973,22 @@ class FlintREPLTest // Create a SparkSession for testing val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() + val sessionManager = mock[SessionManager] + val statementLifecycleManager = mock[StatementLifecycleManager] + val queryResultWriter = mock[QueryResultWriter] + when(sessionManager.getNextStatement(sessionId)).thenReturn(None) - val flintSessionIndexUpdater = mock[OpenSearchUpdater] - - val commandContext = CommandContext( + val commandContext = StatementExecutionContext( spark, - dataSource, - resultIndex, - sessionId, - flintSessionIndexUpdater, - osClient, - sessionIndex, jobId, + sessionId, + sessionManager, + statementLifecycleManager, + queryResultWriter, + dataSource, Duration(10, MINUTES), inactivityLimit, - 60) + 60L) try { // Mocking ThreadUtils to track the shutdown call @@ -962,16 +1007,7 @@ class FlintREPLTest } test("queryLoop handle exceptions within the loop gracefully") { - val mockReader = mock[FlintReader] - val osClient = mock[OSClient] - when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) - .thenReturn(mockReader) - // Simulate an exception thrown when hasNext is called - when(mockReader.hasNext).thenThrow(new RuntimeException("Test exception")) - - val resultIndex = "testResultIndex" val dataSource = "testDataSource" - val sessionIndex = "testSessionIndex" val sessionId = "testSessionId" val jobId = "testJobId" @@ -979,21 +1015,23 @@ class FlintREPLTest // Create a SparkSession for testing val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() + val sessionManager = mock[SessionManager] + val statementLifecycleManager = mock[StatementLifecycleManager] + val queryResultWriter = mock[QueryResultWriter] val flintSessionIndexUpdater = mock[OpenSearchUpdater] - val commandContext = CommandContext( + val commandContext = StatementExecutionContext( spark, - dataSource, - resultIndex, - sessionId, - flintSessionIndexUpdater, - osClient, - sessionIndex, jobId, + sessionId, + sessionManager, + statementLifecycleManager, + queryResultWriter, + dataSource, Duration(10, MINUTES), - inactivityLimit, - 60) + 60L, + 60L) try { // Mocking ThreadUtils to track the shutdown call @@ -1062,19 +1100,37 @@ class FlintREPLTest when(expectedDataFrame.toDF(any[Seq[String]]: _*)).thenReturn(expectedDataFrame) val flintSessionIndexUpdater = mock[OpenSearchUpdater] - - val commandContext = CommandContext( - mockSparkSession, - dataSource, - resultIndex, - sessionId, - flintSessionIndexUpdater, - osClient, - sessionIndex, + val spark = 
SparkSession.builder().master("local").appName("FlintREPLTest").config("spark.flint.job.type", "some_job_type").getOrCreate() + val sessionManager = mock[SessionManager] + val statementLifecycleManager = mock[StatementLifecycleManager] + val queryResultWriter = mock[QueryResultWriter] + +// val interactiveSession = new InteractiveSession( +// "app123", +// jobId, +// sessionId, +// SessionStates.RUNNING, +// System.currentTimeMillis(), +// System.currentTimeMillis() - 10000 +// ) +// when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) + + // Mock getNextStatement to return a valid statement initially and then None + val mockStatement = mock[FlintStatement] + when(mockStatement.queryStartTime).thenReturn(Some(System.currentTimeMillis())) + when(sessionManager.getNextStatement(sessionId)).thenReturn(Some(mockStatement)).thenReturn(None) + + val commandContext = StatementExecutionContext( + spark, jobId, + sessionId, + sessionManager, + statementLifecycleManager, + queryResultWriter, + dataSource, Duration(10, MINUTES), inactivityLimit, - 60) + 60L) val startTime = Instant.now().toEpochMilli() @@ -1100,9 +1156,7 @@ class FlintREPLTest // Configure mockReader to always return false, indicating no commands to process when(mockReader.hasNext).thenReturn(false) - val resultIndex = "testResultIndex" val dataSource = "testDataSource" - val sessionIndex = "testSessionIndex" val sessionId = "testSessionId" val jobId = "testJobId" @@ -1112,19 +1166,23 @@ class FlintREPLTest val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() val flintSessionIndexUpdater = mock[OpenSearchUpdater] + val sessionManager = mock[SessionManager] + val statementLifecycleManager = mock[StatementLifecycleManager] + val queryResultWriter = mock[QueryResultWriter] - val commandContext = CommandContext( + val commandContext = StatementExecutionContext( spark, - dataSource, - resultIndex, - sessionId, - flintSessionIndexUpdater, - osClient, - sessionIndex, jobId, + sessionId, + sessionManager, + statementLifecycleManager, + queryResultWriter, + dataSource, Duration(10, MINUTES), inactivityLimit, - 60) + 60L) + + when(sessionManager.getNextStatement(sessionId)).thenReturn(None) val startTime = Instant.now().toEpochMilli()
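For reference, here is a minimal sketch of how the components introduced by this patch are intended to compose. It is illustrative only and is not code from the patch: the object name, job and session identifiers, result index, and data source values are hypothetical, and it assumes the FlintSparkConf.REQUEST_INDEX, FlintSparkConf.SESSION_ID, and FlintSparkConf.DATA_SOURCE_NAME settings read by SessionManagerImpl are already present in the Spark conf.

    package org.apache.spark.sql

    import scala.concurrent.duration.{Duration, MINUTES}

    // Illustrative wiring sketch; not part of this patch.
    object FlintReplWiringSketch {
      def main(args: Array[String]): Unit = {
        // SessionManagerImpl validates the Flint conf entries at construction time.
        val spark = SparkSession.builder().appName("flint-repl-wiring-sketch").getOrCreate()
        val sessionManager = new SessionManagerImpl(spark, resultIndex = Some("query_results"))

        // SessionManagerImpl owns the OpenSearch client, session index updater, and statement
        // reader; the other components reuse them through its metadata map.
        val metadata = sessionManager.getSessionManagerMetadata
        val statementLifecycleManager = new StatementLifecycleManagerImpl(metadata)
        val queryResultWriter = new QueryResultWriterImpl(metadata)

        val context = StatementExecutionContext(
          spark,
          "job-1", // jobId (hypothetical)
          "session-1", // sessionId (hypothetical)
          sessionManager,
          statementLifecycleManager,
          queryResultWriter,
          "my_glue", // dataSource (hypothetical)
          Duration(10, MINUTES), // queryExecutionTimeout
          10 * 60 * 1000L, // inactivityLimitMillis
          60 * 1000L) // queryWaitTimeMillis

        FlintREPL.queryLoop(context)
      }
    }

The tests in this file build the same StatementExecutionContext shape, but with mocked SessionManager, StatementLifecycleManager, and QueryResultWriter instances in place of the implementations above.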