diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala b/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala new file mode 100644 index 000000000..49dc8e355 --- /dev/null +++ b/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.opensearch.flint.common.model.FlintStatement + +/** + * Trait for writing the result of a query execution to external data storage. + */ +trait QueryResultWriter { + + /** + * Writes the given DataFrame, which represents the result of a query execution, to external + * data storage based on the provided FlintStatement metadata. + */ + def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit +} diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala b/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala new file mode 100644 index 000000000..00d48b20c --- /dev/null +++ b/flint-commons/src/main/scala/org/apache/spark/sql/SessionManager.scala @@ -0,0 +1,43 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession} + +import org.apache.spark.sql.SessionUpdateMode.SessionUpdateMode + +/** + * Trait defining the interface for managing interactive sessions. + */ +trait SessionManager { + + /** + * Retrieves metadata about the session manager. + */ + def getSessionContext: Map[String, Any] + + /** + * Fetches the details of a specific session. + */ + def getSessionDetails(sessionId: String): Option[InteractiveSession] + + /** + * Updates the details of a specific session. + */ + def updateSessionDetails( + sessionDetails: InteractiveSession, + updateMode: SessionUpdateMode): Unit + + /** + * Records a heartbeat for a specific session to indicate it is still active. + */ + def recordHeartbeat(sessionId: String): Unit +} + +object SessionUpdateMode extends Enumeration { + type SessionUpdateMode = Value + val UPDATE, UPSERT, UPDATE_IF = Value +} diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/StatementExecutionManager.scala b/flint-commons/src/main/scala/org/apache/spark/sql/StatementExecutionManager.scala new file mode 100644 index 000000000..acf28c572 --- /dev/null +++ b/flint-commons/src/main/scala/org/apache/spark/sql/StatementExecutionManager.scala @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.opensearch.flint.common.model.FlintStatement + +/** + * Trait defining the interface for managing FlintStatement execution. For example, in FlintREPL, + * multiple FlintStatements are running in a micro-batch within the same session. + * + * This interface can also apply to other Spark entry points like FlintJob. + */ +trait StatementExecutionManager { + + /** + * Prepares execution of each individual statement. + */ + def prepareStatementExecution(): Either[String, Unit] + + /** + * Executes a specific statement and returns the Spark DataFrame. + */ + def executeStatement(statement: FlintStatement): DataFrame + + /** + * Retrieves the next statement to be executed. + */ + def getNextStatement(): Option[FlintStatement] + + /** + * Updates a specific statement. 
+ */ + def updateStatement(statement: FlintStatement): Unit + + /** + * Terminates the statement lifecycle. + */ + def terminateStatementsExecution(): Unit +} diff --git a/flint-commons/src/main/scala/org/opensearch/flint/common/model/FlintStatement.scala b/flint-commons/src/main/scala/org/opensearch/flint/common/model/FlintStatement.scala index bc8b38d9a..00876d46e 100644 --- a/flint-commons/src/main/scala/org/opensearch/flint/common/model/FlintStatement.scala +++ b/flint-commons/src/main/scala/org/opensearch/flint/common/model/FlintStatement.scala @@ -65,7 +65,7 @@ class FlintStatement( // Does not include context, which could contain sensitive information. override def toString: String = - s"FlintStatement(state=$state, query=$query, statementId=$statementId, queryId=$queryId, submitTime=$submitTime, error=$error)" + s"FlintStatement(state=$state, statementId=$statementId, queryId=$queryId, submitTime=$submitTime, error=$error)" } object FlintStatement { diff --git a/flint-commons/src/main/scala/org/opensearch/flint/common/model/InteractiveSession.scala b/flint-commons/src/main/scala/org/opensearch/flint/common/model/InteractiveSession.scala index 9acdeab5f..915d5e229 100644 --- a/flint-commons/src/main/scala/org/opensearch/flint/common/model/InteractiveSession.scala +++ b/flint-commons/src/main/scala/org/opensearch/flint/common/model/InteractiveSession.scala @@ -14,6 +14,8 @@ import org.json4s.JsonAST.{JArray, JString} import org.json4s.native.JsonMethods.parse import org.json4s.native.Serialization +import org.apache.spark.internal.Logging + object SessionStates { val RUNNING = "running" val DEAD = "dead" @@ -52,7 +54,8 @@ class InteractiveSession( val excludedJobIds: Seq[String] = Seq.empty[String], val error: Option[String] = None, sessionContext: Map[String, Any] = Map.empty[String, Any]) - extends ContextualDataStore { + extends ContextualDataStore + with Logging { context = sessionContext // Initialize the context from the constructor def running(): Unit = state = SessionStates.RUNNING @@ -96,6 +99,7 @@ object InteractiveSession { // Replace extractOpt with jsonOption and map val excludeJobIds: Seq[String] = meta \ "excludeJobIds" match { case JArray(lst) => lst.map(_.extract[String]) + case JString(s) => Seq(s) case _ => Seq.empty[String] } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchReader.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchReader.java index d5fb45f99..1440db1f3 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchReader.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchReader.java @@ -22,7 +22,6 @@ * Abstract OpenSearch Reader. */ public abstract class OpenSearchReader implements FlintReader { - @VisibleForTesting /** Search request source builder. 
*/ public final SearchRequest searchRequest; diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala index 0bfaf38e6..c96b71fd9 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala @@ -235,6 +235,15 @@ object FlintSparkConf { FlintConfig("spark.metadata.accessAWSCredentialsProvider") .doc("AWS credentials provider for metadata access permission") .createOptional() + val CUSTOM_SESSION_MANAGER = + FlintConfig("spark.flint.job.customSessionManager") + .createOptional() + val CUSTOM_STATEMENT_MANAGER = + FlintConfig("spark.flint.job.customStatementManager") + .createOptional() + val CUSTOM_QUERY_RESULT_WRITER = + FlintConfig("spark.flint.job.customQueryResultWriter") + .createOptional() } /** diff --git a/integ-test/src/integration/scala/org/apache/spark/sql/FlintJobITSuite.scala b/integ-test/src/integration/scala/org/apache/spark/sql/FlintJobITSuite.scala index 7318e5c7c..9371557a2 100644 --- a/integ-test/src/integration/scala/org/apache/spark/sql/FlintJobITSuite.scala +++ b/integ-test/src/integration/scala/org/apache/spark/sql/FlintJobITSuite.scala @@ -98,9 +98,15 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest { * JobOperator instance to accommodate specific runtime requirements. */ val job = - JobOperator(spark, query, dataSourceName, resultIndex, true, streamingRunningCount) - job.envinromentProvider = new MockEnvironment( - Map("SERVERLESS_EMR_JOB_ID" -> jobRunId, "SERVERLESS_EMR_VIRTUAL_CLUSTER_ID" -> appId)) + JobOperator( + appId, + jobRunId, + spark, + query, + dataSourceName, + resultIndex, + true, + streamingRunningCount) job.terminateJVM = false job.start() } diff --git a/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala b/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala index 1d86a6589..24f3c9a89 100644 --- a/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala +++ b/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala @@ -169,7 +169,7 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { "spark.flint.job.queryLoopExecutionFrequency", queryLoopExecutionFrequency.toString) - FlintREPL.envinromentProvider = new MockEnvironment( + FlintREPL.environmentProvider = new MockEnvironment( Map("SERVERLESS_EMR_JOB_ID" -> jobRunId, "SERVERLESS_EMR_VIRTUAL_CLUSTER_ID" -> appId)) FlintREPL.enableHiveSupport = false FlintREPL.terminateJVM = false diff --git a/integ-test/src/integration/scala/org/opensearch/flint/OpenSearchSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/OpenSearchSuite.scala index e1e967ded..35c700aca 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/OpenSearchSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/OpenSearchSuite.scala @@ -145,7 +145,6 @@ trait OpenSearchSuite extends BeforeAndAfterAll { val response = openSearchClient.bulk(request, RequestOptions.DEFAULT) - assume( !response.hasFailures, s"bulk index docs to $index failed: ${response.buildFailureMessage()}") diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala index 048f69ced..0d5c062ae 100644 --- 
a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala @@ -11,14 +11,13 @@ import scala.concurrent.duration.Duration import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} case class CommandContext( + applicationId: String, + jobId: String, spark: SparkSession, dataSource: String, - resultIndex: String, sessionId: String, - flintSessionIndexUpdater: OpenSearchUpdater, - osClient: OSClient, - sessionIndex: String, - jobId: String, + sessionManager: SessionManager, + queryResultWriter: QueryResultWriter, queryExecutionTimeout: Duration, inactivityLimitMillis: Long, queryWaitTimeMillis: Long, diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala index ad49201f0..45b7e81cc 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandState.scala @@ -12,7 +12,6 @@ import org.opensearch.flint.core.storage.FlintReader case class CommandState( recordedLastActivityTime: Long, recordedVerificationResult: VerificationResult, - flintReader: FlintReader, - futureMappingCheck: Future[Either[String, Unit]], + futurePrepareQueryExecution: Future[Either[String, Unit]], executionContext: ExecutionContextExecutor, recordedLastCanPickCheckTime: Long) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala index bba999110..c556e2786 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.types._ */ object FlintJob extends Logging with FlintJobExecutor { def main(args: Array[String]): Unit = { - val (queryOption, resultIndex) = parseArgs(args) + val (queryOption, resultIndexOption) = parseArgs(args) val conf = createSparkConf() val jobType = conf.get("spark.flint.job.type", "batch") @@ -41,6 +41,9 @@ object FlintJob extends Logging with FlintJobExecutor { if (query.isEmpty) { logAndThrow(s"Query undefined for the ${jobType} job.") } + if (resultIndexOption.isEmpty) { + logAndThrow("resultIndex is not set") + } // https://github.com/opensearch-project/opensearch-spark/issues/138 /* * To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`, @@ -52,13 +55,19 @@ object FlintJob extends Logging with FlintJobExecutor { conf.set("spark.sql.defaultCatalog", dataSource) configDYNMaxExecutors(conf, jobType) + val applicationId = + environmentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown") + val jobId = environmentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown") + val streamingRunningCount = new AtomicInteger(0) val jobOperator = JobOperator( + applicationId, + jobId, createSparkSession(conf), query, dataSource, - resultIndex, + resultIndexOption.get, jobType.equalsIgnoreCase("streaming"), streamingRunningCount) registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala index 37801a9e8..95d3ba0f1 100644 --- 
a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -41,7 +41,7 @@ trait FlintJobExecutor { var currentTimeProvider: TimeProvider = new RealTimeProvider() var threadPoolFactory: ThreadPoolFactory = new DefaultThreadPoolFactory() - var envinromentProvider: EnvironmentProvider = new RealEnvironment() + var environmentProvider: EnvironmentProvider = new RealEnvironment() var enableHiveSupport: Boolean = true // termiante JVM in the presence non-deamon thread before exiting var terminateJVM = true @@ -190,6 +190,7 @@ trait FlintJobExecutor { } } + // scalastyle:off /** * Create a new formatted dataframe with json result, json schema and EMR_STEP_ID. * @@ -201,6 +202,8 @@ trait FlintJobExecutor { * dataframe with result, schema and emr step id */ def getFormattedData( + applicationId: String, + jobId: String, result: DataFrame, spark: SparkSession, dataSource: String, @@ -231,14 +234,13 @@ trait FlintJobExecutor { // after consumed the query result. Streaming query shuffle data is cleaned after each // microBatch execution. cleaner.cleanUp(spark) - // Create the data rows val rows = Seq( ( resultToSave, resultSchemaToSave, - envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown"), - envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"), + jobId, + applicationId, dataSource, "SUCCESS", "", @@ -254,6 +256,8 @@ trait FlintJobExecutor { } def constructErrorDF( + applicationId: String, + jobId: String, spark: SparkSession, dataSource: String, status: String, @@ -270,8 +274,8 @@ trait FlintJobExecutor { ( null, null, - envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown"), - envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"), + jobId, + applicationId, dataSource, status.toUpperCase(Locale.ROOT), error, @@ -396,6 +400,8 @@ trait FlintJobExecutor { } def executeQuery( + applicationId: String, + jobId: String, spark: SparkSession, query: String, dataSource: String, @@ -409,6 +415,8 @@ trait FlintJobExecutor { val result: DataFrame = spark.sql(query) // Get Data getFormattedData( + applicationId, + jobId, result, spark, dataSource, @@ -493,16 +501,21 @@ trait FlintJobExecutor { } } - def parseArgs(args: Array[String]): (Option[String], String) = { + /** + * Before OS 2.13, there are two arguments from the entry point: query and result index. Starting + * from OS 2.13, query is optional for FlintREPL. And since Flint 0.5, result index is also + * optional for non-OpenSearch result persistence. + */ + def parseArgs(args: Array[String]): (Option[String], Option[String]) = { args match { + case Array() => + (None, None) case Array(resultIndex) => - (None, resultIndex) // Starting from OS 2.13, resultIndex is the only argument + (None, Some(resultIndex)) case Array(query, resultIndex) => - ( - Some(query), - resultIndex - ) // Before OS 2.13, there are two arguments, the second one is resultIndex - case _ => logAndThrow("Unsupported number of arguments. Expected 1 or 2 arguments.") + (Some(query), Some(resultIndex)) + case _ => + logAndThrow("Unsupported number of arguments. 
Expected no more than two arguments.") } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index e6b8b11ce..340c0656e 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -11,28 +11,22 @@ import java.util.concurrent.atomic.AtomicInteger import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future, TimeoutException} import scala.concurrent.duration._ -import scala.util.{Failure, Success, Try} import scala.util.control.NonFatal import com.codahale.metrics.Timer -import org.json4s.native.Serialization -import org.opensearch.action.get.GetResponse -import org.opensearch.common.Strings -import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession} -import org.opensearch.flint.common.model.InteractiveSession.formats +import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession, SessionStates} import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.logging.CustomLogging import org.opensearch.flint.core.metrics.MetricConstants import org.opensearch.flint.core.metrics.MetricsUtil.{getTimerContext, incrementCounter, registerGauge, stopTimer} -import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} -import org.opensearch.search.sort.SortOrder import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.FlintREPLConfConstants._ +import org.apache.spark.sql.SessionUpdateMode._ import org.apache.spark.sql.flint.config.FlintSparkConf -import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.{ThreadUtils, Utils} object FlintREPLConfConstants { val HEARTBEAT_INTERVAL_MILLIS = 60000L @@ -61,23 +55,19 @@ object FlintREPL extends Logging with FlintJobExecutor { @volatile var earlyExitFlag: Boolean = false - def updateSessionIndex(flintStatement: FlintStatement, updater: OpenSearchUpdater): Unit = { - updater.update(flintStatement.statementId, FlintStatement.serialize(flintStatement)) - } - private val sessionRunningCount = new AtomicInteger(0) private val statementRunningCount = new AtomicInteger(0) def main(args: Array[String]) { - val (queryOption, resultIndex) = parseArgs(args) - - if (Strings.isNullOrEmpty(resultIndex)) { - logAndThrow("resultIndex is not set") - } + val (queryOption, resultIndexOption) = parseArgs(args) // init SparkContext val conf: SparkConf = createSparkConf() - val dataSource = conf.get(FlintSparkConf.DATA_SOURCE_NAME.key, "unknown") + val dataSource = conf.get(FlintSparkConf.DATA_SOURCE_NAME.key, "") + + if (dataSource.trim.isEmpty) { + logAndThrow(FlintSparkConf.DATA_SOURCE_NAME.key + " is not set or is empty") + } // https://github.com/opensearch-project/opensearch-spark/issues/138 /* * To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`, @@ -88,6 +78,10 @@ object FlintREPL extends Logging with FlintJobExecutor { */ conf.set("spark.sql.defaultCatalog", dataSource) + val applicationId = + environmentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown") + val jobId = environmentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown") + val jobType = conf.get(FlintSparkConf.JOB_TYPE.key, FlintSparkConf.JOB_TYPE.defaultValue.get) 
CustomLogging.logInfo(s"""Job type is: ${FlintSparkConf.JOB_TYPE.defaultValue.get}""") conf.set(FlintSparkConf.JOB_TYPE.key, jobType) @@ -95,36 +89,29 @@ object FlintREPL extends Logging with FlintJobExecutor { val query = getQuery(queryOption, jobType, conf) if (jobType.equalsIgnoreCase("streaming")) { - logInfo(s"""streaming query ${query}""") + if (resultIndexOption.isEmpty) { + logAndThrow("resultIndex is not set") + } configDYNMaxExecutors(conf, jobType) val streamingRunningCount = new AtomicInteger(0) val jobOperator = JobOperator( + applicationId, + jobId, createSparkSession(conf), query, dataSource, - resultIndex, + resultIndexOption.get, true, streamingRunningCount) registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount) jobOperator.start() } else { - // we don't allow default value for sessionIndex and sessionId. Throw exception if key not found. - val sessionIndex: Option[String] = Option(conf.get(FlintSparkConf.REQUEST_INDEX.key, null)) - val sessionId: Option[String] = Option(conf.get(FlintSparkConf.SESSION_ID.key, null)) - - if (sessionIndex.isEmpty) { - logAndThrow(FlintSparkConf.REQUEST_INDEX.key + " is not set") - } - if (sessionId.isEmpty) { - logAndThrow(FlintSparkConf.SESSION_ID.key + " is not set") - } - + // we don't allow default value for sessionId. Throw exception if key not found. + val sessionId = getSessionId(conf) + logInfo(s"sessionId: ${sessionId}") val spark = createSparkSession(conf) - val osClient = new OSClient(FlintSparkConf().flintOptions()) - val jobId = envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown") - val applicationId = - envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown") + val sessionManager = instantiateSessionManager(spark, resultIndexOption) // Read the values from the Spark configuration or fall back to the default values val inactivityLimitMillis: Long = @@ -142,7 +129,6 @@ object FlintREPL extends Logging with FlintJobExecutor { conf.getLong( "spark.flint.job.queryLoopExecutionFrequency", DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) - val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex.get) val sessionTimerContext = getTimerContext(MetricConstants.REPL_PROCESSING_TIME_METRIC) /** @@ -155,12 +141,7 @@ object FlintREPL extends Logging with FlintJobExecutor { * https://github.com/opensearch-project/opensearch-spark/issues/320 */ spark.sparkContext.addSparkListener( - new PreShutdownListener( - flintSessionIndexUpdater, - osClient, - sessionIndex.get, - sessionId.get, - sessionTimerContext)) + new PreShutdownListener(sessionId, sessionManager, sessionTimerContext)) // 1 thread for updating heart beat val threadPool = @@ -173,37 +154,29 @@ object FlintREPL extends Logging with FlintJobExecutor { // OpenSearch triggers recovery after 1 minute outdated heart beat var heartBeatFuture: ScheduledFuture[_] = null try { - heartBeatFuture = createHeartBeatUpdater( - HEARTBEAT_INTERVAL_MILLIS, - flintSessionIndexUpdater, - sessionId.get, - threadPool, - osClient, - sessionIndex.get, - INITIAL_DELAY_MILLIS) + heartBeatFuture = createHeartBeatUpdater(sessionId, sessionManager, threadPool) if (setupFlintJobWithExclusionCheck( conf, - sessionIndex, sessionId, - osClient, jobId, applicationId, - flintSessionIndexUpdater, + sessionManager, jobStartTime)) { earlyExitFlag = true return } + val queryResultWriter = + instantiateQueryResultWriter(conf, sessionManager.getSessionContext) val commandContext = CommandContext( + applicationId, + jobId, spark, dataSource, - resultIndex, - 
sessionId.get, - flintSessionIndexUpdater, - osClient, - sessionIndex.get, - jobId, + sessionId, + sessionManager, + queryResultWriter, queryExecutionTimeoutSecs, inactivityLimitMillis, queryWaitTimeoutMillis, @@ -218,11 +191,9 @@ object FlintREPL extends Logging with FlintJobExecutor { e, applicationId, jobId, - sessionId.get, + sessionId, + sessionManager, jobStartTime, - flintSessionIndexUpdater, - osClient, - sessionIndex.get, sessionTimerContext) } finally { if (threadPool != null) { @@ -295,25 +266,16 @@ object FlintREPL extends Logging with FlintJobExecutor { */ def setupFlintJobWithExclusionCheck( conf: SparkConf, - sessionIndex: Option[String], - sessionId: Option[String], - osClient: OSClient, + sessionId: String, jobId: String, applicationId: String, - flintSessionIndexUpdater: OpenSearchUpdater, + sessionManager: SessionManager, jobStartTime: Long): Boolean = { val confExcludeJobsOpt = conf.getOption(FlintSparkConf.EXCLUDE_JOB_IDS.key) - confExcludeJobsOpt match { case None => // If confExcludeJobs is None, pass null or an empty sequence as per your setupFlintJob method's signature - setupFlintJob( - applicationId, - jobId, - sessionId.get, - flintSessionIndexUpdater, - sessionIndex.get, - jobStartTime) + setupFlintJob(applicationId, jobId, sessionId, sessionManager, jobStartTime) case Some(confExcludeJobs) => // example: --conf spark.flint.deployment.excludeJobs=job-1,job-2 @@ -324,25 +286,22 @@ object FlintREPL extends Logging with FlintJobExecutor { return true } - val getResponse = osClient.getDoc(sessionIndex.get, sessionId.get) - if (getResponse.isExists()) { - val source = getResponse.getSourceAsMap - if (source != null) { - val existingExcludedJobIds = parseExcludedJobIds(source) + sessionManager.getSessionDetails(sessionId) match { + case Some(sessionDetails) => + val existingExcludedJobIds = sessionDetails.excludedJobIds if (excludeJobIds.sorted == existingExcludedJobIds.sorted) { logInfo("duplicate job running, exit the application.") return true } - } + case _ => // Do nothing } // If none of the edge cases are met, proceed with setup setupFlintJob( applicationId, jobId, - sessionId.get, - flintSessionIndexUpdater, - sessionIndex.get, + sessionId, + sessionManager, jobStartTime, excludeJobIds) } @@ -350,15 +309,14 @@ object FlintREPL extends Logging with FlintJobExecutor { } def queryLoop(commandContext: CommandContext): Unit = { + import commandContext._ // 1 thread for async query execution val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-query", 1) implicit val executionContext = ExecutionContext.fromExecutor(threadPool) - var futureMappingCheck: Future[Either[String, Unit]] = null + var futurePrepareQueryExecution: Future[Either[String, Unit]] = null try { - futureMappingCheck = Future { - checkAndCreateIndex(commandContext.osClient, commandContext.resultIndex) - } + logInfo(s"""Executing session with sessionId: ${sessionId}""") var lastActivityTime = currentTimeProvider.currentEpochMillis() var verificationResult: VerificationResult = NotVerified @@ -366,25 +324,22 @@ object FlintREPL extends Logging with FlintJobExecutor { var lastCanPickCheckTime = 0L while (currentTimeProvider .currentEpochMillis() - lastActivityTime <= commandContext.inactivityLimitMillis && canPickUpNextStatement) { - logInfo( - s"""read from ${commandContext.sessionIndex}, sessionId: ${commandContext.sessionId}""") - val flintReader: FlintReader = - createQueryReader( - commandContext.osClient, - commandContext.sessionId, - commandContext.sessionIndex, - 
commandContext.dataSource) + val statementsExecutionManager = + instantiateStatementExecutionManager(commandContext) + + futurePrepareQueryExecution = Future { + statementsExecutionManager.prepareStatementExecution() + } try { val commandState = CommandState( lastActivityTime, verificationResult, - flintReader, - futureMappingCheck, + futurePrepareQueryExecution, executionContext, lastCanPickCheckTime) val result: (Long, VerificationResult, Boolean, Long) = - processCommands(commandContext, commandState) + processCommands(statementsExecutionManager, commandContext, commandState) val ( updatedLastActivityTime, @@ -397,7 +352,7 @@ object FlintREPL extends Logging with FlintJobExecutor { canPickUpNextStatement = updatedCanPickUpNextStatement lastCanPickCheckTime = updatedLastCanPickCheckTime } finally { - flintReader.close() + statementsExecutionManager.terminateStatementsExecution() } Thread.sleep(commandContext.queryLoopExecutionFrequency) @@ -413,92 +368,71 @@ object FlintREPL extends Logging with FlintJobExecutor { applicationId: String, jobId: String, sessionId: String, - flintSessionIndexUpdater: OpenSearchUpdater, - sessionIndex: String, + sessionManager: SessionManager, jobStartTime: Long, excludeJobIds: Seq[String] = Seq.empty[String]): Unit = { - val includeJobId = !excludeJobIds.isEmpty && !excludeJobIds.contains(jobId) - val currentTime = currentTimeProvider.currentEpochMillis() - val flintJob = new InteractiveSession( + refreshSessionState( applicationId, jobId, sessionId, - "running", - currentTime, + sessionManager, jobStartTime, - excludeJobIds) + SessionStates.RUNNING, + excludedJobIds = excludeJobIds) - val serializedFlintInstance = if (includeJobId) { - InteractiveSession.serialize(flintJob, currentTime, true) - } else { - InteractiveSession.serializeWithoutJobId(flintJob, currentTime) - } - flintSessionIndexUpdater.upsert(sessionId, serializedFlintInstance) - logInfo( - s"""Updated job: {"jobid": ${flintJob.jobId}, "sessionId": ${flintJob.sessionId}} from $sessionIndex""") sessionRunningCount.incrementAndGet() } - def handleSessionError( - e: Exception, + private def refreshSessionState( applicationId: String, jobId: String, sessionId: String, + sessionManager: SessionManager, jobStartTime: Long, - flintSessionIndexUpdater: OpenSearchUpdater, - osClient: OSClient, - sessionIndex: String, - sessionTimerContext: Timer.Context): Unit = { - val error = s"Session error: ${e.getMessage}" - CustomLogging.logError(error, e) - - val flintInstance = getExistingFlintInstance(osClient, sessionIndex, sessionId) - .getOrElse(createFailedFlintInstance(applicationId, jobId, sessionId, jobStartTime, error)) - - updateFlintInstance(flintInstance, flintSessionIndexUpdater, sessionId) - if (flintInstance.isFail) { - recordSessionFailed(sessionTimerContext) - } - } - - private def getExistingFlintInstance( - osClient: OSClient, - sessionIndex: String, - sessionId: String): Option[InteractiveSession] = Try( - osClient.getDoc(sessionIndex, sessionId)) match { - case Success(getResponse) if getResponse.isExists() => - Option(getResponse.getSourceAsMap) - .map(InteractiveSession.deserializeFromMap) - case Failure(exception) => - CustomLogging.logError( - s"Failed to retrieve existing FlintInstance: ${exception.getMessage}", - exception) - None - case _ => None + state: String, + error: Option[String] = None, + excludedJobIds: Seq[String] = Seq.empty[String]): InteractiveSession = { + logInfo(s"refreshSessionState: ${jobId}") + val sessionDetails = sessionManager + .getSessionDetails(sessionId) + 
.getOrElse( + new InteractiveSession( + applicationId, + jobId, + sessionId, + state, + currentTimeProvider.currentEpochMillis(), + jobStartTime, + error = error, + excludedJobIds = excludedJobIds)) + logInfo(s"Current session: ${sessionDetails}") + logInfo(s"State is: ${sessionDetails.state}") + sessionDetails.state = state + logInfo(s"State is: ${sessionDetails.state}") + sessionManager.updateSessionDetails(sessionDetails, updateMode = UPSERT) + sessionDetails } - private def createFailedFlintInstance( + def handleSessionError( + e: Exception, applicationId: String, jobId: String, sessionId: String, + sessionManager: SessionManager, jobStartTime: Long, - errorMessage: String): InteractiveSession = new InteractiveSession( - applicationId, - jobId, - sessionId, - "fail", - currentTimeProvider.currentEpochMillis(), - jobStartTime, - error = Some(errorMessage)) - - private def updateFlintInstance( - flintInstance: InteractiveSession, - flintSessionIndexUpdater: OpenSearchUpdater, - sessionId: String): Unit = { - val currentTime = currentTimeProvider.currentEpochMillis() - flintSessionIndexUpdater.upsert( + sessionTimerContext: Timer.Context): Unit = { + val error = s"Session error: ${e.getMessage}" + CustomLogging.logError(error, e) + + refreshSessionState( + applicationId, + jobId, sessionId, - InteractiveSession.serializeWithoutJobId(flintInstance, currentTime)) + sessionManager, + jobStartTime, + SessionStates.FAIL, + Some(e.getMessage)) + recordSessionFailed(sessionTimerContext) } /** @@ -522,6 +456,8 @@ object FlintREPL extends Logging with FlintJobExecutor { * failed data frame */ def handleCommandFailureAndGetFailedData( + applicationId: String, + jobId: String, spark: SparkSession, dataSource: String, error: String, @@ -531,6 +467,8 @@ object FlintREPL extends Logging with FlintJobExecutor { flintStatement.fail() flintStatement.error = Some(error) super.constructErrorDF( + applicationId, + jobId, spark, dataSource, flintStatement.state, @@ -549,6 +487,7 @@ object FlintREPL extends Logging with FlintJobExecutor { } private def processCommands( + statementExecutionManager: StatementExecutionManager, context: CommandContext, state: CommandState): (Long, VerificationResult, Boolean, Long) = { import context._ @@ -562,46 +501,43 @@ object FlintREPL extends Logging with FlintJobExecutor { while (canProceed) { val currentTime = currentTimeProvider.currentEpochMillis() - // Only call canPickNextStatement if EARLY_TERMINATION_CHECK_FREQUENCY milliseconds have passed if (currentTime - lastCanPickCheckTime > EARLY_TERMINATION_CHECK_FREQUENCY) { - canPickNextStatementResult = - canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + canPickNextStatementResult = canPickNextStatement(sessionId, sessionManager, jobId) lastCanPickCheckTime = currentTime } if (!canPickNextStatementResult) { earlyExitFlag = true canProceed = false - } else if (!flintReader.hasNext) { - canProceed = false } else { - val statementTimerContext = getTimerContext( - MetricConstants.STATEMENT_PROCESSING_TIME_METRIC) - val flintStatement = processCommandInitiation(flintReader, flintSessionIndexUpdater) - - val (dataToWrite, returnedVerificationResult) = processStatementOnVerification( - recordedVerificationResult, - spark, - flintStatement, - dataSource, - sessionId, - executionContext, - futureMappingCheck, - resultIndex, - queryExecutionTimeout, - queryWaitTimeMillis) - - verificationResult = returnedVerificationResult - finalizeCommand( - dataToWrite, - flintStatement, - resultIndex, - flintSessionIndexUpdater, - 
osClient, - statementTimerContext) - // last query finish time is last activity time - lastActivityTime = currentTimeProvider.currentEpochMillis() + statementExecutionManager.getNextStatement() match { + case Some(flintStatement) => + flintStatement.running() + statementExecutionManager.updateStatement(flintStatement) + statementRunningCount.incrementAndGet() + + val statementTimerContext = getTimerContext( + MetricConstants.STATEMENT_PROCESSING_TIME_METRIC) + val (dataToWrite, returnedVerificationResult) = + processStatementOnVerification( + statementExecutionManager, + flintStatement, + state, + context) + + verificationResult = returnedVerificationResult + finalizeCommand( + statementExecutionManager, + context, + dataToWrite, + flintStatement, + statementTimerContext) + // last query finish time is last activity time + lastActivityTime = currentTimeProvider.currentEpochMillis() + case _ => + canProceed = false + } } } @@ -610,32 +546,26 @@ object FlintREPL extends Logging with FlintJobExecutor { } /** - * finalize command after processing + * finalize statement after processing * * @param dataToWrite * data to write * @param flintStatement - * flint command - * @param resultIndex - * result index - * @param flintSessionIndexUpdater - * flint session index updater + * flint statement */ private def finalizeCommand( + statementExecutionManager: StatementExecutionManager, + commandContext: CommandContext, dataToWrite: Option[DataFrame], flintStatement: FlintStatement, - resultIndex: String, - flintSessionIndexUpdater: OpenSearchUpdater, - osClient: OSClient, statementTimerContext: Timer.Context): Unit = { + import commandContext._ try { - dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient)) + dataToWrite.foreach(df => queryResultWriter.writeDataFrame(df, flintStatement)) if (flintStatement.isRunning || flintStatement.isWaiting) { // we have set failed state in exception handling flintStatement.complete() } - updateSessionIndex(flintStatement, flintSessionIndexUpdater) - recordStatementStateChange(flintStatement, statementTimerContext) } catch { // e.g., maybe due to authentication service connection issue // or invalid catalog (e.g., we are operating on data not defined in provided data source) @@ -643,18 +573,21 @@ object FlintREPL extends Logging with FlintJobExecutor { val error = s"""Fail to write result of ${flintStatement}, cause: ${e.getMessage}""" CustomLogging.logError(error, e) flintStatement.fail() - updateSessionIndex(flintStatement, flintSessionIndexUpdater) - recordStatementStateChange(flintStatement, statementTimerContext) + } finally { + statementExecutionManager.updateStatement(flintStatement) + recordStatementStateChange(flintStatement, statementTimerContext) } } private def handleCommandTimeout( + applicationId: String, + jobId: String, spark: SparkSession, dataSource: String, error: String, flintStatement: FlintStatement, sessionId: String, - startTime: Long): DataFrame = { + startTime: Long) = { /* * https://tinyurl.com/2ezs5xj9 * @@ -671,6 +604,8 @@ object FlintREPL extends Logging with FlintJobExecutor { flintStatement.timeout() flintStatement.error = Some(error) super.constructErrorDF( + applicationId, + jobId, spark, dataSource, flintStatement.state, @@ -681,9 +616,13 @@ object FlintREPL extends Logging with FlintJobExecutor { startTime) } + // scalastyle:off def executeAndHandle( + applicationId: String, + jobId: String, spark: SparkSession, flintStatement: FlintStatement, + statementExecutionManager: StatementExecutionManager, 
dataSource: String, sessionId: String, executionContext: ExecutionContextExecutor, @@ -693,8 +632,11 @@ object FlintREPL extends Logging with FlintJobExecutor { try { Some( executeQueryAsync( + applicationId, + jobId, spark, flintStatement, + statementExecutionManager, dataSource, sessionId, executionContext, @@ -705,11 +647,22 @@ object FlintREPL extends Logging with FlintJobExecutor { case e: TimeoutException => val error = s"Executing ${flintStatement.query} timed out" CustomLogging.logError(error, e) - Some(handleCommandTimeout(spark, dataSource, error, flintStatement, sessionId, startTime)) + Some( + handleCommandTimeout( + applicationId, + jobId, + spark, + dataSource, + error, + flintStatement, + sessionId, + startTime)) case e: Exception => val error = processQueryException(e, flintStatement) Some( handleCommandFailureAndGetFailedData( + applicationId, + jobId, spark, dataSource, error, @@ -720,16 +673,13 @@ object FlintREPL extends Logging with FlintJobExecutor { } private def processStatementOnVerification( - recordedVerificationResult: VerificationResult, - spark: SparkSession, + statementExecutionManager: StatementExecutionManager, flintStatement: FlintStatement, - dataSource: String, - sessionId: String, - executionContext: ExecutionContextExecutor, - futureMappingCheck: Future[Either[String, Unit]], - resultIndex: String, - queryExecutionTimeout: Duration, - queryWaitTimeMillis: Long) = { + commandState: CommandState, + commandContext: CommandContext) = { + import commandState._ + import commandContext._ + val startTime: Long = currentTimeProvider.currentEpochMillis() var verificationResult = recordedVerificationResult var dataToWrite: Option[DataFrame] = None @@ -737,11 +687,14 @@ object FlintREPL extends Logging with FlintJobExecutor { verificationResult match { case NotVerified => try { - ThreadUtils.awaitResult(futureMappingCheck, MAPPING_CHECK_TIMEOUT) match { + ThreadUtils.awaitResult(futurePrepareQueryExecution, MAPPING_CHECK_TIMEOUT) match { case Right(_) => dataToWrite = executeAndHandle( + applicationId, + jobId, spark, flintStatement, + statementExecutionManager, dataSource, sessionId, executionContext, @@ -753,6 +706,8 @@ object FlintREPL extends Logging with FlintJobExecutor { verificationResult = VerifiedWithError(error) dataToWrite = Some( handleCommandFailureAndGetFailedData( + applicationId, + jobId, spark, dataSource, error, @@ -762,10 +717,12 @@ object FlintREPL extends Logging with FlintJobExecutor { } } catch { case e: TimeoutException => - val error = s"Getting the mapping of index $resultIndex timed out" + val error = s"Query execution preparation timed out" CustomLogging.logError(error, e) dataToWrite = Some( handleCommandTimeout( + applicationId, + jobId, spark, dataSource, error, @@ -777,6 +734,8 @@ object FlintREPL extends Logging with FlintJobExecutor { CustomLogging.logError(error, e) dataToWrite = Some( handleCommandFailureAndGetFailedData( + applicationId, + jobId, spark, dataSource, error, @@ -787,6 +746,8 @@ object FlintREPL extends Logging with FlintJobExecutor { case VerifiedWithError(err) => dataToWrite = Some( handleCommandFailureAndGetFailedData( + applicationId, + jobId, spark, dataSource, err, @@ -795,8 +756,11 @@ object FlintREPL extends Logging with FlintJobExecutor { startTime)) case VerifiedWithoutError => dataToWrite = executeAndHandle( + applicationId, + jobId, spark, flintStatement, + statementExecutionManager, dataSource, sessionId, executionContext, @@ -810,8 +774,11 @@ object FlintREPL extends Logging with FlintJobExecutor { } 
def executeQueryAsync( + applicationId: String, + jobId: String, spark: SparkSession, flintStatement: FlintStatement, + statementsExecutionManager: StatementExecutionManager, dataSource: String, sessionId: String, executionContext: ExecutionContextExecutor, @@ -821,6 +788,8 @@ object FlintREPL extends Logging with FlintJobExecutor { if (currentTimeProvider .currentEpochMillis() - flintStatement.submitTime > queryWaitTimeMillis) { handleCommandFailureAndGetFailedData( + applicationId, + jobId, spark, dataSource, "wait timeout", @@ -829,83 +798,16 @@ object FlintREPL extends Logging with FlintJobExecutor { startTime) } else { val futureQueryExecution = Future { - executeQuery( - spark, - flintStatement.query, - dataSource, - flintStatement.queryId, - sessionId, - false) + statementsExecutionManager.executeStatement(flintStatement) }(executionContext) - // time out after 10 minutes ThreadUtils.awaitResult(futureQueryExecution, queryExecutionTimeOut) } } - private def processCommandInitiation( - flintReader: FlintReader, - flintSessionIndexUpdater: OpenSearchUpdater): FlintStatement = { - val command = flintReader.next() - logDebug(s"raw command: $command") - val flintStatement = FlintStatement.deserialize(command) - logDebug(s"command: $flintStatement") - flintStatement.running() - logDebug(s"command running: $flintStatement") - updateSessionIndex(flintStatement, flintSessionIndexUpdater) - statementRunningCount.incrementAndGet() - flintStatement - } - - private def createQueryReader( - osClient: OSClient, - sessionId: String, - sessionIndex: String, - dataSource: String) = { - // all state in index are in lower case - // we only search for statement submitted in the last hour in case of unexpected bugs causing infinite loop in the - // same doc - val dsl = - s"""{ - | "bool": { - | "must": [ - | { - | "term": { - | "type": "statement" - | } - | }, - | { - | "term": { - | "state": "waiting" - | } - | }, - | { - | "term": { - | "sessionId": "$sessionId" - | } - | }, - | { - | "term": { - | "dataSourceName": "$dataSource" - | } - | }, - | { - | "range": { - | "submitTime": { "gte": "now-1h" } - | } - | } - | ] - | } - |}""".stripMargin - - val flintReader = osClient.createQueryReader(sessionIndex, dsl, "submitTime", SortOrder.ASC) - flintReader - } class PreShutdownListener( - flintSessionIndexUpdater: OpenSearchUpdater, - osClient: OSClient, - sessionIndex: String, sessionId: String, + sessionManager: SessionManager, sessionTimerContext: Timer.Context) extends SparkListener with Logging { @@ -913,77 +815,43 @@ object FlintREPL extends Logging with FlintJobExecutor { override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { logInfo("Shutting down REPL") logInfo("earlyExitFlag: " + earlyExitFlag) - val getResponse = osClient.getDoc(sessionIndex, sessionId) - if (!getResponse.isExists()) { - return - } - - val source = getResponse.getSourceAsMap - if (source == null) { - return - } - - val state = Option(source.get("state")).map(_.asInstanceOf[String]) - // It's essential to check the earlyExitFlag before marking the session state as 'dead'. When this flag is true, - // it indicates that the control plane has already initiated a new session to handle remaining requests for the - // current session. In our SQL setup, marking a session as 'dead' automatically triggers the creation of a new - // session. 
However, the newly created session (initiated by the control plane) will enter a spin-wait state, - // where it inefficiently waits for certain conditions to be met, leading to unproductive resource consumption - // and eventual timeout. To avoid this issue and prevent the creation of redundant sessions by SQL, we ensure - // the session state is not set to 'dead' when earlyExitFlag is true, thereby preventing unnecessary duplicate - // processing. - if (!earlyExitFlag && state.isDefined && state.get != "dead" && state.get != "fail") { - updateFlintInstanceBeforeShutdown( - source, - getResponse, - flintSessionIndexUpdater, - sessionId, - sessionTimerContext) + try { + sessionManager.getSessionDetails(sessionId).foreach { sessionDetails => + // It's essential to check the earlyExitFlag before marking the session state as 'dead'. When this flag is true, + // it indicates that the control plane has already initiated a new session to handle remaining requests for the + // current session. In our SQL setup, marking a session as 'dead' automatically triggers the creation of a new + // session. However, the newly created session (initiated by the control plane) will enter a spin-wait state, + // where it inefficiently waits for certain conditions to be met, leading to unproductive resource consumption + // and eventual timeout. To avoid this issue and prevent the creation of redundant sessions by SQL, we ensure + // the session state is not set to 'dead' when earlyExitFlag is true, thereby preventing unnecessary duplicate + // processing. + if (!earlyExitFlag && !sessionDetails.isComplete && !sessionDetails.isFail) { + sessionDetails.complete() + logInfo(s"jobId before shutting down session: ${sessionDetails.jobId}") + sessionManager.updateSessionDetails(sessionDetails, updateMode = UPDATE_IF) + recordSessionSuccess(sessionTimerContext) + } + } + } catch { + case e: Exception => logError(s"Failed to update session state for $sessionId", e) } } } - private def updateFlintInstanceBeforeShutdown( - source: java.util.Map[String, AnyRef], - getResponse: GetResponse, - flintSessionIndexUpdater: OpenSearchUpdater, - sessionId: String, - sessionTimerContext: Timer.Context): Unit = { - val flintInstance = InteractiveSession.deserializeFromMap(source) - flintInstance.complete() - flintSessionIndexUpdater.updateIf( - sessionId, - InteractiveSession.serializeWithoutJobId( - flintInstance, - currentTimeProvider.currentEpochMillis()), - getResponse.getSeqNo, - getResponse.getPrimaryTerm) - recordSessionSuccess(sessionTimerContext) - } - /** - * Create a new thread to update the last update time of the flint instance. - * @param currentInterval - * the interval of updating the last update time. Unit is millisecond. - * @param flintSessionUpdater - * the updater of the flint instance. + * Create a new thread to update the last update time of the flint interactive session. + * * @param sessionId - * the session id of the flint instance. + * the session id of the flint interactive session. + * @param sessionManager + * the manager of the flint interactive session. * @param threadPool * the thread pool. - * @param osClient - * the OpenSearch client. 
- * @param initialDelayMillis - * the intial delay to start heartbeat */ def createHeartBeatUpdater( - currentInterval: Long, - flintSessionUpdater: OpenSearchUpdater, sessionId: String, - threadPool: ScheduledExecutorService, - osClient: OSClient, - sessionIndex: String, - initialDelayMillis: Long): ScheduledFuture[_] = { + sessionManager: SessionManager, + threadPool: ScheduledExecutorService): ScheduledFuture[_] = { threadPool.scheduleAtFixedRate( new Runnable { @@ -994,13 +862,7 @@ object FlintREPL extends Logging with FlintJobExecutor { logWarning("HeartBeatUpdater has been interrupted. Terminating.") return // Exit the run method if the thread is interrupted } - - flintSessionUpdater.upsert( - sessionId, - Serialization.write( - Map( - "lastUpdateTime" -> currentTimeProvider.currentEpochMillis(), - "state" -> "running"))) + sessionManager.recordHeartbeat(sessionId) } catch { case ie: InterruptedException => // Preserve the interrupt status @@ -1020,8 +882,8 @@ object FlintREPL extends Logging with FlintJobExecutor { } } }, - initialDelayMillis, - currentInterval, + INITIAL_DELAY_MILLIS, + HEARTBEAT_INTERVAL_MILLIS, java.util.concurrent.TimeUnit.MILLISECONDS) } @@ -1038,35 +900,26 @@ object FlintREPL extends Logging with FlintJobExecutor { */ def canPickNextStatement( sessionId: String, - jobId: String, - osClient: OSClient, - sessionIndex: String): Boolean = { + sessionManager: SessionManager, + jobId: String): Boolean = { try { - val getResponse = osClient.getDoc(sessionIndex, sessionId) - if (getResponse.isExists()) { - val source = getResponse.getSourceAsMap - if (source == null) { - logError(s"""Session id ${sessionId} is empty""") - // still proceed since we are not sure what happened (e.g., OpenSearch cluster may be unresponsive) - return true - } - - val runJobId = Option(source.get("jobId")).map(_.asInstanceOf[String]).orNull - val excludeJobIds: Seq[String] = parseExcludedJobIds(source) - - if (runJobId != null && jobId != runJobId) { - logInfo(s"""the current job ID ${jobId} is not the running job ID ${runJobId}""") - return false - } - if (excludeJobIds != null && excludeJobIds.contains(jobId)) { - logInfo(s"""${jobId} is in the list of excluded jobs""") - return false - } - true - } else { - // still proceed since we are not sure what happened (e.g., session doc may not be available yet) - logError(s"""Fail to find id ${sessionId} from session index""") - true + sessionManager.getSessionDetails(sessionId) match { + case Some(sessionDetails) => + val runJobId = sessionDetails.jobId + val excludeJobIds = sessionDetails.excludedJobIds + + if (!runJobId.isEmpty && jobId != runJobId) { + logInfo(s"the current job ID $jobId is not the running job ID ${runJobId}") + return false + } + if (excludeJobIds.contains(jobId)) { + logInfo(s"$jobId is in the list of excluded jobs") + return false + } + true + case None => + logError(s"Failed to fetch sessionDetails by sessionId: $sessionId.") + true } } catch { // still proceed since we are not sure what happened (e.g., OpenSearch cluster may be unresponsive) @@ -1076,23 +929,6 @@ object FlintREPL extends Logging with FlintJobExecutor { } } - private def parseExcludedJobIds(source: java.util.Map[String, AnyRef]): Seq[String] = { - - val rawExcludeJobIds = source.get("excludeJobIds") - Option(rawExcludeJobIds) - .map { - case s: String => Seq(s) - case list: java.util.List[_] @unchecked => - import scala.collection.JavaConverters._ - list.asScala.toList - .collect { case str: String => str } // Collect only strings from the list - case 
other => - logInfo(s"Unexpected type: ${other.getClass.getName}") - Seq.empty - } - .getOrElse(Seq.empty[String]) // In case of null, return an empty Seq - } - def exponentialBackoffRetry[T](maxRetries: Int, initialDelay: FiniteDuration)( block: => T): T = { var retries = 0 @@ -1130,6 +966,61 @@ object FlintREPL extends Logging with FlintJobExecutor { result.getOrElse(throw new RuntimeException("Failed after retries")) } + def getSessionId(conf: SparkConf): String = { + conf.getOption(FlintSparkConf.SESSION_ID.key) match { + case Some(sessionId) if sessionId.nonEmpty => + sessionId + case _ => + logAndThrow(s"${FlintSparkConf.SESSION_ID.key} is not set or is empty") + } + } + + private def instantiate[T](defaultConstructor: => T, className: String, args: Any*): T = { + if (className.isEmpty) { + defaultConstructor + } else { + try { + val classObject = Utils.classForName(className) + val ctor = if (args.isEmpty) { + classObject.getDeclaredConstructor() + } else { + classObject.getDeclaredConstructor(args.map(_.getClass.asInstanceOf[Class[_]]): _*) + } + ctor.setAccessible(true) + ctor.newInstance(args.map(_.asInstanceOf[Object]): _*).asInstanceOf[T] + } catch { + case e: Exception => + throw new RuntimeException(s"Failed to instantiate provider: $className", e) + } + } + } + + private def instantiateSessionManager( + spark: SparkSession, + resultIndexOption: Option[String]): SessionManager = { + instantiate( + new SessionManagerImpl(spark, resultIndexOption), + spark.conf.get(FlintSparkConf.CUSTOM_SESSION_MANAGER.key, "")) + } + + private def instantiateStatementExecutionManager( + commandContext: CommandContext): StatementExecutionManager = { + import commandContext._ + instantiate( + new StatementExecutionManagerImpl(commandContext), + spark.conf.get(FlintSparkConf.CUSTOM_STATEMENT_MANAGER.key, ""), + spark, + sessionId) + } + + private def instantiateQueryResultWriter( + sparkConf: SparkConf, + context: Map[String, Any]): QueryResultWriter = { + instantiate( + new QueryResultWriterImpl(context), + sparkConf.get(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key, "")) + } + private def recordSessionSuccess(sessionTimerContext: Timer.Context): Unit = { logInfo("Session Success") stopTimer(sessionTimerContext) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala index c079b3e96..b49f4a9ed 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala @@ -22,6 +22,8 @@ import org.apache.spark.sql.util.ShuffleCleaner import org.apache.spark.util.ThreadUtils case class JobOperator( + applicationId: String, + jobId: String, spark: SparkSession, query: String, dataSource: String, @@ -51,13 +53,23 @@ case class JobOperator( val futureMappingCheck = Future { checkAndCreateIndex(osClient, resultIndex) } - val data = executeQuery(spark, query, dataSource, "", "", streaming) + val data = executeQuery(applicationId, jobId, spark, query, dataSource, "", "", streaming) val mappingCheckResult = ThreadUtils.awaitResult(futureMappingCheck, Duration(1, MINUTES)) dataToWrite = Some(mappingCheckResult match { case Right(_) => data case Left(error) => - constructErrorDF(spark, dataSource, "FAILED", error, "", query, "", startTime) + constructErrorDF( + applicationId, + jobId, + spark, + dataSource, + "FAILED", + error, + "", + query, + "", + startTime) }) exceptionThrown = false } catch { @@ 
-65,11 +77,31 @@ case class JobOperator( val error = s"Getting the mapping of index $resultIndex timed out" logError(error, e) dataToWrite = Some( - constructErrorDF(spark, dataSource, "TIMEOUT", error, "", query, "", startTime)) + constructErrorDF( + applicationId, + jobId, + spark, + dataSource, + "TIMEOUT", + error, + "", + query, + "", + startTime)) case e: Exception => val error = processQueryException(e) dataToWrite = Some( - constructErrorDF(spark, dataSource, "FAILED", error, "", query, "", startTime)) + constructErrorDF( + applicationId, + jobId, + spark, + dataSource, + "FAILED", + error, + "", + query, + "", + startTime)) } finally { cleanUpResources(exceptionThrown, threadPool, dataToWrite, resultIndex, osClient) } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala new file mode 100644 index 000000000..238f8fa3d --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.opensearch.flint.common.model.FlintStatement + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.FlintJob.writeDataFrameToOpensearch + +class QueryResultWriterImpl(context: Map[String, Any]) extends QueryResultWriter with Logging { + + private val resultIndex = context("resultIndex").asInstanceOf[String] + private val osClient = context("osClient").asInstanceOf[OSClient] + + override def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit = { + writeDataFrameToOpensearch(dataFrame, resultIndex, osClient) + } +} diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala new file mode 100644 index 000000000..2039159e4 --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/SessionManagerImpl.scala @@ -0,0 +1,125 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import scala.util.{Failure, Success, Try} + +import org.json4s.native.Serialization +import org.opensearch.flint.common.model.InteractiveSession +import org.opensearch.flint.common.model.InteractiveSession.formats +import org.opensearch.flint.core.logging.CustomLogging + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SessionUpdateMode.SessionUpdateMode +import org.apache.spark.sql.flint.config.FlintSparkConf + +class SessionManagerImpl(spark: SparkSession, resultIndexOption: Option[String]) + extends SessionManager + with FlintJobExecutor + with Logging { + + if (resultIndexOption.isEmpty) { + logAndThrow("resultIndex is not set") + } + + // we don't allow default value for sessionIndex. Throw exception if key not found. 
+ val sessionIndex: String = spark.conf.get(FlintSparkConf.REQUEST_INDEX.key, "") + + if (sessionIndex.isEmpty) { + logAndThrow(FlintSparkConf.REQUEST_INDEX.key + " is not set") + } + + val osClient = new OSClient(FlintSparkConf().flintOptions()) + lazy val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex) + + override def getSessionContext: Map[String, Any] = { + Map( + "sessionIndex" -> sessionIndex, + "resultIndex" -> resultIndexOption.get, + "osClient" -> osClient, + "flintSessionIndexUpdater" -> flintSessionIndexUpdater) + } + + override def getSessionDetails(sessionId: String): Option[InteractiveSession] = { + Try(osClient.getDoc(sessionIndex, sessionId)) match { + case Success(getResponse) if getResponse.isExists => + // Retrieve the source map and create session + val sessionOption = Option(getResponse.getSourceAsMap) + .map(InteractiveSession.deserializeFromMap) + // Retrieve sequence number and primary term from the response + val seqNo = getResponse.getSeqNo + val primaryTerm = getResponse.getPrimaryTerm + + // Add seqNo and primaryTerm to the session context + sessionOption.foreach { session => + session.setContextValue("seqNo", seqNo) + session.setContextValue("primaryTerm", primaryTerm) + } + + sessionOption + case Failure(exception) => + CustomLogging.logError( + s"Failed to retrieve existing InteractiveSession: ${exception.getMessage}", + exception) + None + + case _ => None + } + } + + override def updateSessionDetails( + sessionDetails: InteractiveSession, + sessionUpdateMode: SessionUpdateMode): Unit = { + sessionUpdateMode match { + case SessionUpdateMode.UPDATE => + flintSessionIndexUpdater.update( + sessionDetails.sessionId, + InteractiveSession.serialize(sessionDetails, currentTimeProvider.currentEpochMillis())) + case SessionUpdateMode.UPSERT => + val includeJobId = + !sessionDetails.excludedJobIds.isEmpty && !sessionDetails.excludedJobIds.contains( + sessionDetails.jobId) + val serializedSession = if (includeJobId) { + InteractiveSession.serialize( + sessionDetails, + currentTimeProvider.currentEpochMillis(), + true) + } else { + InteractiveSession.serializeWithoutJobId( + sessionDetails, + currentTimeProvider.currentEpochMillis()) + } + flintSessionIndexUpdater.upsert(sessionDetails.sessionId, serializedSession) + case SessionUpdateMode.UPDATE_IF => + val seqNo = sessionDetails + .getContextValue("seqNo") + .getOrElse(throw new IllegalArgumentException("Missing seqNo for conditional update")) + .asInstanceOf[Long] + val primaryTerm = sessionDetails + .getContextValue("primaryTerm") + .getOrElse( + throw new IllegalArgumentException("Missing primaryTerm for conditional update")) + .asInstanceOf[Long] + flintSessionIndexUpdater.updateIf( + sessionDetails.sessionId, + InteractiveSession.serializeWithoutJobId( + sessionDetails, + currentTimeProvider.currentEpochMillis()), + seqNo, + primaryTerm) + } + + logInfo( + s"""Updated job: {"jobid": ${sessionDetails.jobId}, "sessionId": ${sessionDetails.sessionId}} from $sessionIndex""") + } + + override def recordHeartbeat(sessionId: String): Unit = { + flintSessionIndexUpdater.upsert( + sessionId, + Serialization.write( + Map("lastUpdateTime" -> currentTimeProvider.currentEpochMillis(), "state" -> "running"))) + } +} diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala new file mode 100644 index 000000000..0b059f1d3 --- /dev/null +++ 
b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala @@ -0,0 +1,109 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.opensearch.flint.common.model.FlintStatement +import org.opensearch.flint.core.storage.OpenSearchUpdater +import org.opensearch.search.sort.SortOrder + +import org.apache.spark.internal.Logging + +/** + * StatementExecutionManagerImpl is a session-based implementation of the StatementExecutionManager + * interface. It uses FlintReader to fetch all pending queries in a micro-batch. + * @param commandContext + */ +class StatementExecutionManagerImpl(commandContext: CommandContext) + extends StatementExecutionManager + with FlintJobExecutor + with Logging { + + private val context = commandContext.sessionManager.getSessionContext + private val sessionIndex = context("sessionIndex").asInstanceOf[String] + private val resultIndex = context("resultIndex").asInstanceOf[String] + private val osClient = context("osClient").asInstanceOf[OSClient] + private val flintSessionIndexUpdater = + context("flintSessionIndexUpdater").asInstanceOf[OpenSearchUpdater] + + // Using one reader client within the same session will cause concurrency issues. + // To resolve this, move the reader creation and the getNextStatement method to the micro-batch level. + private val flintReader = createOpenSearchQueryReader() + + override def prepareStatementExecution(): Either[String, Unit] = { + checkAndCreateIndex(osClient, resultIndex) + } + override def updateStatement(statement: FlintStatement): Unit = { + flintSessionIndexUpdater.update(statement.statementId, FlintStatement.serialize(statement)) + } + override def terminateStatementsExecution(): Unit = { + flintReader.close() + } + + override def getNextStatement(): Option[FlintStatement] = { + if (flintReader.hasNext) { + val rawStatement = flintReader.next() + val flintStatement = FlintStatement.deserialize(rawStatement) + logInfo(s"Next statement to execute: $flintStatement") + Some(flintStatement) + } else { + None + } + } + + override def executeStatement(statement: FlintStatement): DataFrame = { + import commandContext._ + executeQuery( + applicationId, + jobId, + spark, + statement.query, + dataSource, + statement.queryId, + sessionId, + false) + } + + private def createOpenSearchQueryReader() = { + import commandContext._ + // all states in the index are lower case + // only search for statements submitted in the last hour, in case an unexpected bug causes an infinite loop on the + // same doc + val dsl = + s"""{ + | "bool": { + | "must": [ + | { + | "term": { + | "type": "statement" + | } + | }, + | { + | "term": { + | "state": "waiting" + | } + | }, + | { + | "term": { + | "sessionId": "$sessionId" + | } + | }, + | { + | "term": { + | "dataSourceName": "$dataSource" + | } + | }, + | { + | "range": { + | "submitTime": { "gte": "now-1h" } + | } + | } + | ] + | } + |}""".stripMargin + val flintReader = osClient.createQueryReader(sessionIndex, dsl, "submitTime", SortOrder.ASC) + flintReader + } +} diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala index 19f596e31..339c1870d 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala @@ -11,6 +11,8 @@ import org.apache.spark.sql.types._ import
org.apache.spark.sql.util.{CleanerFactory, MockTimeProvider} class FlintJobTest extends SparkFunSuite with JobMatchers { + private val jobId = "testJobId" + private val applicationId = "testApplicationId" val spark = SparkSession.builder().appName("Test").master("local").getOrCreate() @@ -55,8 +57,8 @@ class FlintJobTest extends SparkFunSuite with JobMatchers { Array( "{'column_name':'Letter','data_type':'string'}", "{'column_name':'Number','data_type':'integer'}"), - "unknown", - "unknown", + jobId, + applicationId, dataSourceName, "SUCCESS", "", @@ -72,6 +74,8 @@ class FlintJobTest extends SparkFunSuite with JobMatchers { // Compare the result val result = FlintJob.getFormattedData( + applicationId, + jobId, input, spark, dataSourceName, diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index 9c193fc9a..433c2351b 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -18,11 +18,11 @@ import scala.reflect.runtime.universe.TypeTag import com.amazonaws.services.glue.model.AccessDeniedException import com.codahale.metrics.Timer import org.mockito.{ArgumentMatchersSugar, Mockito} -import org.mockito.Mockito.{atLeastOnce, never, times, verify, when} +import org.mockito.Mockito.{atLeastOnce, doNothing, never, times, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.opensearch.action.get.GetResponse -import org.opensearch.flint.common.model.FlintStatement +import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession, SessionStates} import org.opensearch.flint.core.storage.{FlintReader, OpenSearchReader, OpenSearchUpdater} import org.opensearch.search.sort.SortOrder import org.scalatest.prop.TableDrivenPropertyChecks._ @@ -48,38 +48,59 @@ class FlintREPLTest // By using a type alias and casting, I can bypass the type checking error. 
type AnyScheduledFuture = ScheduledFuture[_] - test( - "parseArgs with one argument should return None for query and the argument as resultIndex") { + private val jobId = "testJobId" + private val applicationId = "testApplicationId" + + test("parseArgs with no arguments should return (None, None)") { + val args = Array.empty[String] + val (queryOption, resultIndexOption) = FlintREPL.parseArgs(args) + queryOption shouldBe None + resultIndexOption shouldBe None + } + + test("parseArgs with one argument should return None for query and Some for resultIndex") { val args = Array("resultIndexName") - val (queryOption, resultIndex) = FlintREPL.parseArgs(args) + val (queryOption, resultIndexOption) = FlintREPL.parseArgs(args) queryOption shouldBe None - resultIndex shouldBe "resultIndexName" + resultIndexOption shouldBe Some("resultIndexName") } - test( - "parseArgs with two arguments should return the first argument as query and the second as resultIndex") { + test("parseArgs with two arguments should return Some for both query and resultIndex") { val args = Array("SELECT * FROM table", "resultIndexName") - val (queryOption, resultIndex) = FlintREPL.parseArgs(args) + val (queryOption, resultIndexOption) = FlintREPL.parseArgs(args) queryOption shouldBe Some("SELECT * FROM table") - resultIndex shouldBe "resultIndexName" + resultIndexOption shouldBe Some("resultIndexName") } test( - "parseArgs with no arguments should throw IllegalArgumentException with specific message") { - val args = Array.empty[String] + "parseArgs with more than two arguments should throw IllegalArgumentException with specific message") { + val args = Array("arg1", "arg2", "arg3") val exception = intercept[IllegalArgumentException] { FlintREPL.parseArgs(args) } - exception.getMessage shouldBe "Unsupported number of arguments. Expected 1 or 2 arguments." + exception.getMessage shouldBe "Unsupported number of arguments. Expected no more than two arguments." } - test( - "parseArgs with more than two arguments should throw IllegalArgumentException with specific message") { - val args = Array("arg1", "arg2", "arg3") + test("getSessionId should throw exception when SESSION_ID is not set") { + val conf = new SparkConf() val exception = intercept[IllegalArgumentException] { - FlintREPL.parseArgs(args) + FlintREPL.getSessionId(conf) + } + assert(exception.getMessage === FlintSparkConf.SESSION_ID.key + " is not set or is empty") + } + + test("getSessionId should return the session ID when it's set") { + val sessionId = "test-session-id" + val conf = new SparkConf().set(FlintSparkConf.SESSION_ID.key, sessionId) + assert(FlintREPL.getSessionId(conf) === sessionId) + } + + test("getSessionId should throw exception when SESSION_ID is set to empty string") { + val conf = new SparkConf().set(FlintSparkConf.SESSION_ID.key, "") + val exception = intercept[IllegalArgumentException] { + FlintREPL.getSessionId(conf) } - exception.getMessage shouldBe "Unsupported number of arguments. Expected 1 or 2 arguments." 
+ assert(exception.getMessage === FlintSparkConf.SESSION_ID.key + " is not set or is empty") } test("getQuery should return query from queryOption if present") { @@ -159,19 +180,13 @@ class FlintREPLTest test("createHeartBeatUpdater should update heartbeat correctly") { // Mocks - val flintSessionUpdater = mock[OpenSearchUpdater] - val osClient = mock[OSClient] val threadPool = mock[ScheduledExecutorService] - val getResponse = mock[GetResponse] val scheduledFutureRaw = mock[ScheduledFuture[_]] - + val sessionManager = mock[SessionManager] + val sessionId = "session1" // when scheduled task is scheduled, execute the runnable immediately only once and become no-op afterwards. - when( - threadPool.scheduleAtFixedRate( - any[Runnable], - eqTo(0), - *, - eqTo(java.util.concurrent.TimeUnit.MILLISECONDS))) + when(threadPool + .scheduleAtFixedRate(any[Runnable], *, *, eqTo(java.util.concurrent.TimeUnit.MILLISECONDS))) .thenAnswer((invocation: InvocationOnMock) => { val runnable = invocation.getArgument[Runnable](0) runnable.run() @@ -179,55 +194,35 @@ class FlintREPLTest }) // Invoke the method - FlintREPL.createHeartBeatUpdater( - 1000L, - flintSessionUpdater, - "session1", - threadPool, - osClient, - "sessionIndex", - 0) + FlintREPL.createHeartBeatUpdater(sessionId, sessionManager, threadPool) // Verifications - verify(flintSessionUpdater, atLeastOnce()).upsert(eqTo("session1"), *) + verify(sessionManager, atLeastOnce()).recordHeartbeat(sessionId) } test("PreShutdownListener updates FlintInstance if conditions are met") { // Mock dependencies - val osClient = mock[OSClient] - val getResponse = mock[GetResponse] - val flintSessionIndexUpdater = mock[OpenSearchUpdater] - val sessionIndex = "testIndex" val sessionId = "testSessionId" val timerContext = mock[Timer.Context] + val sessionManager = mock[SessionManager] - // Setup the getDoc to return a document indicating the session is running - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) - when(getResponse.isExists()).thenReturn(true) - when(getResponse.getSourceAsMap).thenReturn( - Map[String, Object]( - "applicationId" -> "app1", - "jobId" -> "job1", - "sessionId" -> "session1", - "state" -> "running", - "lastUpdateTime" -> java.lang.Long.valueOf(12345L), - "error" -> "someError", - "state" -> "running", - "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava) + val interactiveSession = new InteractiveSession( + "app123", + "job123", + sessionId, + SessionStates.RUNNING, + System.currentTimeMillis(), + System.currentTimeMillis() - 10000) + when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) // Instantiate the listener - val listener = new PreShutdownListener( - flintSessionIndexUpdater, - osClient, - sessionIndex, - sessionId, - timerContext) + val listener = new PreShutdownListener(sessionId, sessionManager, timerContext) // Simulate application end listener.onApplicationEnd(SparkListenerApplicationEnd(System.currentTimeMillis())) - // Verify the update is called with the correct arguments - verify(flintSessionIndexUpdater).updateIf(*, *, *, *) + verify(sessionManager).updateSessionDetails(interactiveSession, SessionUpdateMode.UPDATE_IF) + interactiveSession.state shouldBe SessionStates.DEAD } test("Test super.constructErrorDF should construct dataframe properly") { @@ -257,8 +252,8 @@ class FlintREPLTest Row( null, null, - "unknown", - "unknown", + jobId, + applicationId, dataSourceName, "FAILED", error, @@ -281,6 +276,8 @@ class FlintREPLTest // Compare the result val result = 
FlintREPL.handleCommandFailureAndGetFailedData( + applicationId, + jobId, spark, dataSourceName, error, @@ -300,18 +297,18 @@ class FlintREPLTest test("test canPickNextStatement: Doc Exists and Valid JobId") { val sessionId = "session123" val jobId = "jobABC" - val osClient = mock[OSClient] - val sessionIndex = "sessionIndex" - - val getResponse = mock[GetResponse] - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) - when(getResponse.isExists()).thenReturn(true) + val sessionManager = mock[SessionManager] - val sourceMap = new java.util.HashMap[String, Object]() - sourceMap.put("jobId", jobId.asInstanceOf[Object]) - when(getResponse.getSourceAsMap).thenReturn(sourceMap) + val interactiveSession = new InteractiveSession( + "app123", + jobId, + sessionId, + SessionStates.RUNNING, + System.currentTimeMillis(), + System.currentTimeMillis() - 10000) + when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId) assert(result) } @@ -320,19 +317,20 @@ class FlintREPLTest val sessionId = "session123" val jobId = "jobABC" val differentJobId = "jobXYZ" - val osClient = mock[OSClient] - val sessionIndex = "sessionIndex" + val sessionManager = mock[SessionManager] - val getResponse = mock[GetResponse] - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) - when(getResponse.isExists()).thenReturn(true) + val interactiveSession = new InteractiveSession( + "app123", + jobId, + sessionId, + SessionStates.RUNNING, + System.currentTimeMillis(), + System.currentTimeMillis() - 10000) - val sourceMap = new java.util.HashMap[String, Object]() - sourceMap.put("jobId", differentJobId.asInstanceOf[Object]) - when(getResponse.getSourceAsMap).thenReturn(sourceMap) + when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) // Execute the method under test - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, differentJobId) // Assertions assert(!result) // The function should return false @@ -341,23 +339,21 @@ class FlintREPLTest test("test canPickNextStatement: Doc Exists, JobId Matches, but JobId is Excluded") { val sessionId = "session123" val jobId = "jobABC" - val osClient = mock[OSClient] - val sessionIndex = "sessionIndex" - - val getResponse = mock[GetResponse] - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) - when(getResponse.isExists()).thenReturn(true) + val sessionManager = mock[SessionManager] - val excludeJobIdsList = new java.util.ArrayList[String]() - excludeJobIdsList.add(jobId) // Add the jobId to the list to simulate exclusion - - val sourceMap = new java.util.HashMap[String, Object]() - sourceMap.put("jobId", jobId) // The jobId matches - sourceMap.put("excludeJobIds", excludeJobIdsList) // But jobId is in the exclude list - when(getResponse.getSourceAsMap).thenReturn(sourceMap) + val interactiveSession = new InteractiveSession( + "app123", + jobId, + sessionId, + SessionStates.RUNNING, + System.currentTimeMillis(), + System.currentTimeMillis() - 10000, + Seq(jobId) // Add the jobId to the list to simulate exclusion + ) + when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) // Execute the method under test - val result = FlintREPL.canPickNextStatement(sessionId, jobId, 
osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId) // Assertions assert(!result) // The function should return false because jobId is excluded @@ -366,17 +362,12 @@ class FlintREPLTest test("test canPickNextStatement: Doc Exists but Source is Null") { val sessionId = "session123" val jobId = "jobABC" - val osClient = mock[OSClient] - val sessionIndex = "sessionIndex" + val sessionManager = mock[SessionManager] - // Mock the getDoc response - val getResponse = mock[GetResponse] - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) - when(getResponse.isExists()).thenReturn(true) - when(getResponse.getSourceAsMap).thenReturn(null) // Simulate the source being null + when(sessionManager.getSessionDetails(sessionId)).thenReturn(None) // Execute the method under test - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId) // Assertions assert(result) // The function should return true despite the null source @@ -385,11 +376,21 @@ class FlintREPLTest test("test canPickNextStatement: Doc Exists with Unexpected Type in excludeJobIds") { val sessionId = "session123" val jobId = "jobABC" - val osClient = mock[OSClient] + val mockOSClient = mock[OSClient] val sessionIndex = "sessionIndex" + val mockSparkSession = mock[SparkSession] + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) + + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) { + override val osClient: OSClient = mockOSClient + } + val getResponse = mock[GetResponse] - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) val sourceMap = new java.util.HashMap[String, Object]() @@ -401,7 +402,7 @@ class FlintREPLTest when(getResponse.getSourceAsMap).thenReturn(sourceMap) - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId) assert(result) // The function should return true } @@ -409,16 +410,25 @@ class FlintREPLTest test("test canPickNextStatement: Doc Does Not Exist") { val sessionId = "session123" val jobId = "jobABC" - val osClient = mock[OSClient] + val mockOSClient = mock[OSClient] val sessionIndex = "sessionIndex" + val mockSparkSession = mock[SparkSession] + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) + + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) { + override val osClient: OSClient = mockOSClient + } // Set up the mock GetResponse val getResponse = mock[GetResponse] - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(false) // Simulate the document does not exist // Execute the function under test - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId) // Assert the function returns true assert(result) @@ -427,15 
+437,24 @@ class FlintREPLTest test("test canPickNextStatement: OSClient Throws Exception") { val sessionId = "session123" val jobId = "jobABC" - val osClient = mock[OSClient] + val mockOSClient = mock[OSClient] val sessionIndex = "sessionIndex" + val mockSparkSession = mock[SparkSession] + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) + + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) { + override val osClient: OSClient = mockOSClient + } // Set up the mock OSClient to throw an exception - when(osClient.getDoc(sessionIndex, sessionId)) + when(mockOSClient.getDoc(sessionIndex, sessionId)) .thenThrow(new RuntimeException("OpenSearch cluster unresponsive")) // Execute the method under test and expect true, since the method is designed to return true even in case of an exception - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId) // Verify the result is true despite the exception assert(result) @@ -446,11 +465,20 @@ class FlintREPLTest val sessionId = "session123" val jobId = "jobABC" val nonMatchingExcludeJobId = "jobXYZ" // This ID does not match the jobId - val osClient = mock[OSClient] + val mockOSClient = mock[OSClient] val sessionIndex = "sessionIndex" + val mockSparkSession = mock[SparkSession] + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) + + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) { + override val osClient: OSClient = mockOSClient + } val getResponse = mock[GetResponse] - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) // Create a sourceMap with excludeJobIds as a String that does NOT match jobId @@ -461,7 +489,7 @@ class FlintREPLTest when(getResponse.getSourceAsMap).thenReturn(sourceMap) // Execute the method under test - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId) // The function should return true since jobId is not excluded assert(result) @@ -512,17 +540,30 @@ class FlintREPLTest test("Doc Exists and excludeJobIds is an ArrayList Containing JobId") { val sessionId = "session123" val jobId = "jobABC" - val osClient = mock[OSClient] + val mockOSClient = mock[OSClient] val sessionIndex = "sessionIndex" - val handleSessionError = mock[Function1[String, Unit]] + val lastUpdateTime = System.currentTimeMillis() + val mockSparkSession = mock[SparkSession] + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) + + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) { + override val osClient: OSClient = mockOSClient + } val getResponse = mock[GetResponse] - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) // Create a sourceMap with excludeJobIds as an 
ArrayList containing jobId val sourceMap = new java.util.HashMap[String, Object]() + sourceMap.put("applicationId", applicationId.asInstanceOf[Object]) + sourceMap.put("state", "running".asInstanceOf[Object]) sourceMap.put("jobId", jobId.asInstanceOf[Object]) + sourceMap.put("sessionId", sessionId.asInstanceOf[Object]) + sourceMap.put("lastUpdateTime", lastUpdateTime.asInstanceOf[Object]) // Creating an ArrayList and adding the jobId to it val excludeJobIdsList = new java.util.ArrayList[String]() @@ -530,9 +571,8 @@ class FlintREPLTest sourceMap.put("excludeJobIds", excludeJobIdsList.asInstanceOf[Object]) when(getResponse.getSourceAsMap).thenReturn(sourceMap) - // Execute the method under test - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId) // The function should return false since jobId is excluded assert(!result) @@ -541,16 +581,30 @@ class FlintREPLTest test("Doc Exists and excludeJobIds is an ArrayList Not Containing JobId") { val sessionId = "session123" val jobId = "jobABC" - val osClient = mock[OSClient] + val mockOSClient = mock[OSClient] val sessionIndex = "sessionIndex" + val lastUpdateTime = System.currentTimeMillis() + val mockSparkSession = mock[SparkSession] + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) + + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) { + override val osClient: OSClient = mockOSClient + } val getResponse = mock[GetResponse] - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) // Create a sourceMap with excludeJobIds as an ArrayList not containing jobId val sourceMap = new java.util.HashMap[String, Object]() + sourceMap.put("applicationId", applicationId.asInstanceOf[Object]) + sourceMap.put("state", "running".asInstanceOf[Object]) sourceMap.put("jobId", jobId.asInstanceOf[Object]) + sourceMap.put("sessionId", sessionId.asInstanceOf[Object]) + sourceMap.put("lastUpdateTime", lastUpdateTime.asInstanceOf[Object]) // Creating an ArrayList and adding a different jobId to it val excludeJobIdsList = new java.util.ArrayList[String]() @@ -560,7 +614,7 @@ class FlintREPLTest when(getResponse.getSourceAsMap).thenReturn(sourceMap) // Execute the method under test - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId) // The function should return true since the jobId is not in the excludeJobIds list assert(result) @@ -571,34 +625,53 @@ class FlintREPLTest val exception = new RuntimeException( new ConnectException( "Timeout connecting to [search-foo-1-bar.eu-west-1.es.amazonaws.com:443]")) - val osClient = mock[OSClient] - when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) + val mockOSClient = mock[OSClient] + when( + mockOSClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) .thenReturn(mockReader) when(mockReader.hasNext).thenThrow(exception) + when(mockOSClient.getIndexMetadata(any[String])).thenReturn(FlintREPL.resultIndexMapping) + val maxRetries = 1 var actualRetries = 0 - val resultIndex = "testResultIndex" val dataSource = "testDataSource" - val 
sessionIndex = "testSessionIndex" val sessionId = "testSessionId" val jobId = "testJobId" val applicationId = "testApplicationId" + val sessionIndex = "sessionIndex" + val lastUpdateTime = System.currentTimeMillis() + + val getResponse = mock[GetResponse] + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(true) + + // Create a sourceMap with excludeJobIds as an ArrayList not containing jobId + val sourceMap = new java.util.HashMap[String, Object]() + sourceMap.put("applicationId", applicationId.asInstanceOf[Object]) + sourceMap.put("state", "running".asInstanceOf[Object]) + sourceMap.put("jobId", jobId.asInstanceOf[Object]) + sourceMap.put("sessionId", sessionId.asInstanceOf[Object]) + sourceMap.put("lastUpdateTime", lastUpdateTime.asInstanceOf[Object]) val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() try { - val flintSessionIndexUpdater = mock[OpenSearchUpdater] + spark.conf.set(FlintSparkConf.REQUEST_INDEX.key, sessionIndex) + val sessionManager = new SessionManagerImpl(spark, Some("resultIndex")) { + override val osClient: OSClient = mockOSClient + } + + val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( + applicationId, + jobId, spark, dataSource, - resultIndex, sessionId, - flintSessionIndexUpdater, - osClient, - sessionIndex, - jobId, + sessionManager, + queryResultWriter, Duration(10, MINUTES), 60, 60, @@ -624,6 +697,9 @@ class FlintREPLTest when(mockSparkSession.conf).thenReturn(mockConf) when(mockSparkSession.conf.get(FlintSparkConf.JOB_TYPE.key)) .thenReturn(FlintSparkConf.JOB_TYPE.defaultValue.get) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn("someSessionIndex") + // val mockExecutionContextExecutor: ExecutionContextExecutor = mock[ExecutionContextExecutor] val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor("flint-repl", 1) implicit val executionContext = ExecutionContext.fromExecutor(threadPool) @@ -648,18 +724,35 @@ class FlintREPLTest .thenReturn(expectedDataFrame) when(expectedDataFrame.toDF(any[Seq[String]]: _*)).thenReturn(expectedDataFrame) - val sparkContext = mock[SparkContext] when(mockSparkSession.sparkContext).thenReturn(sparkContext) + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) + val queryResultWriter = mock[QueryResultWriter] + val commandContext = CommandContext( + applicationId, + jobId, + mockSparkSession, + dataSource, + sessionId, + sessionManager, + queryResultWriter, + Duration(10, MINUTES), + 60, + 60, + DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) + val statementExecutionManager = new StatementExecutionManagerImpl(commandContext) + val result = FlintREPL.executeAndHandle( + applicationId, + jobId, mockSparkSession, flintStatement, + statementExecutionManager, dataSource, sessionId, executionContext, startTime, - // make sure it times out before mockSparkSession.sql can return, which takes 60 seconds Duration(1, SECONDS), 600000) @@ -677,6 +770,9 @@ class FlintREPLTest when(mockSparkSession.conf).thenReturn(mockConf) when(mockSparkSession.conf.get(FlintSparkConf.JOB_TYPE.key)) .thenReturn(FlintSparkConf.JOB_TYPE.defaultValue.get) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn("someSessionIndex") + val flintStatement = new FlintStatement( "Running", @@ -705,14 +801,33 @@ class FlintREPLTest .thenReturn(expectedDataFrame) when(expectedDataFrame.toDF(any[Seq[String]]: 
_*)).thenReturn(expectedDataFrame) + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) + val queryResultWriter = mock[QueryResultWriter] + val commandContext = CommandContext( + applicationId, + jobId, + mockSparkSession, + dataSource, + sessionId, + sessionManager, + queryResultWriter, + Duration(10, MINUTES), + 60, + 60, + DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) + val statementExecutionManager = new StatementExecutionManagerImpl(commandContext) + val result = FlintREPL.executeAndHandle( + applicationId, + jobId, mockSparkSession, flintStatement, + statementExecutionManager, dataSource, sessionId, executionContext, startTime, - Duration.Inf, // Use Duration.Inf or a large enough duration to avoid a timeout, + Duration.Inf, 600000) // Verify that ParseException was caught and handled @@ -720,19 +835,38 @@ class FlintREPLTest flintStatement.error should not be None flintStatement.error.get should include("Syntax error:") } finally threadPool.shutdown() - } test("setupFlintJobWithExclusionCheck should proceed normally when no jobs are excluded") { - val osClient = mock[OSClient] + val sessionIndex = "sessionIndex" + val sessionId = "session1" + val mockSparkSession = mock[SparkSession] + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) + + val mockOSClient = mock[OSClient] val getResponse = mock[GetResponse] - when(osClient.getDoc(*, *)).thenReturn(getResponse) + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(true) + + val mockOpenSearchUpdater = mock[OpenSearchUpdater] + doNothing().when(mockOpenSearchUpdater).upsert(any[String], any[String]) + + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) { + override val osClient: OSClient = mockOSClient + override lazy val flintSessionIndexUpdater: OpenSearchUpdater = mockOpenSearchUpdater + } + val conf = new SparkConf().set("spark.flint.deployment.excludeJobs", "") + + when(mockOSClient.getDoc(*, *)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) when(getResponse.getSourceAsMap).thenReturn( Map[String, Object]( "applicationId" -> "app1", "jobId" -> "job1", - "sessionId" -> "session1", + "sessionId" -> sessionId, "lastUpdateTime" -> java.lang.Long.valueOf(12345L), "error" -> "someError", "state" -> "running", @@ -740,49 +874,77 @@ class FlintREPLTest when(getResponse.getSeqNo).thenReturn(0L) when(getResponse.getPrimaryTerm).thenReturn(0L) - val flintSessionIndexUpdater = mock[OpenSearchUpdater] - val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "") - // other mock objects like osClient, flintSessionIndexUpdater with necessary mocking val result = FlintREPL.setupFlintJobWithExclusionCheck( - mockConf, - Some("sessionIndex"), - Some("sessionId"), - osClient, - "jobId", - "appId", - flintSessionIndexUpdater, + conf, + "session1", + jobId, + applicationId, + sessionManager, System.currentTimeMillis()) assert(!result) // Expecting false as the job should proceed normally } test("setupFlintJobWithExclusionCheck should exit early if current job is excluded") { - val osClient = mock[OSClient] + val sessionIndex = "sessionIndex" + val sessionId = "session1" + val mockSparkSession = mock[SparkSession] + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + 
when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) + + val mockOSClient = mock[OSClient] val getResponse = mock[GetResponse] - when(osClient.getDoc(*, *)).thenReturn(getResponse) + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(true) + + val mockOpenSearchUpdater = mock[OpenSearchUpdater] + doNothing().when(mockOpenSearchUpdater).upsert(any[String], any[String]) + + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) { + override val osClient: OSClient = mockOSClient + override lazy val flintSessionIndexUpdater: OpenSearchUpdater = mockOpenSearchUpdater + } + + when(mockOSClient.getDoc(*, *)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) // Mock the rest of the GetResponse as needed - val flintSessionIndexUpdater = mock[OpenSearchUpdater] - val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "jobId") + val conf = new SparkConf().set("spark.flint.deployment.excludeJobs", jobId) val result = FlintREPL.setupFlintJobWithExclusionCheck( - mockConf, - Some("sessionIndex"), - Some("sessionId"), - osClient, - "jobId", - "appId", - flintSessionIndexUpdater, + conf, + sessionId, + jobId, + applicationId, + sessionManager, System.currentTimeMillis()) assert(result) // Expecting true as the job should exit early } test("setupFlintJobWithExclusionCheck should exit early if a duplicate job is running") { - val osClient = mock[OSClient] + val sessionIndex = "sessionIndex" + val sessionId = "session1" + val mockSparkSession = mock[SparkSession] + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) + + val mockOSClient = mock[OSClient] val getResponse = mock[GetResponse] - when(osClient.getDoc(*, *)).thenReturn(getResponse) + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) + + val mockOpenSearchUpdater = mock[OpenSearchUpdater] + doNothing().when(mockOpenSearchUpdater).upsert(any[String], any[String]) + + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) { + override val osClient: OSClient = mockOSClient + override lazy val flintSessionIndexUpdater: OpenSearchUpdater = mockOpenSearchUpdater + } + // Mock the GetResponse to simulate a scenario of a duplicate job when(getResponse.getSourceAsMap).thenReturn( Map[String, Object]( @@ -797,100 +959,114 @@ class FlintREPLTest .asList("job-2", "job-1") // Include this inside the Map ).asJava) - val flintSessionIndexUpdater = mock[OpenSearchUpdater] - val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "job-1,job-2") + val conf = new SparkConf().set("spark.flint.deployment.excludeJobs", "job-1,job-2") val result = FlintREPL.setupFlintJobWithExclusionCheck( - mockConf, - Some("sessionIndex"), - Some("sessionId"), - osClient, - "jobId", - "appId", - flintSessionIndexUpdater, + conf, + sessionId, + jobId, + applicationId, + sessionManager, System.currentTimeMillis()) assert(result) // Expecting true for early exit due to duplicate job } test("setupFlintJobWithExclusionCheck should setup job normally when conditions are met") { - val osClient = mock[OSClient] + val sessionIndex = "sessionIndex" + val sessionId = "session1" + val mockSparkSession = mock[SparkSession] + val mockConf = mock[RuntimeConfig] + 
when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) + + val mockOSClient = mock[OSClient] val getResponse = mock[GetResponse] - when(osClient.getDoc(*, *)).thenReturn(getResponse) + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) + when(getResponse.getSourceAsMap).thenReturn( + Map[String, Object]( + "applicationId" -> applicationId, + "jobId" -> jobId, + "sessionId" -> sessionId, + "lastUpdateTime" -> java.lang.Long.valueOf(12345L), + "error" -> "someError", + "state" -> "running", + "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava) + + val mockOpenSearchUpdater = mock[OpenSearchUpdater] + doNothing().when(mockOpenSearchUpdater).upsert(any[String], any[String]) - val flintSessionIndexUpdater = mock[OpenSearchUpdater] - val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "job-3,job-4") + val sessionManager = new SessionManagerImpl(mockSparkSession, Some("resultIndex")) { + override val osClient: OSClient = mockOSClient + override lazy val flintSessionIndexUpdater: OpenSearchUpdater = mockOpenSearchUpdater + } + + val conf = new SparkConf().set("spark.flint.deployment.excludeJobs", "job-3,job-4") val result = FlintREPL.setupFlintJobWithExclusionCheck( - mockConf, - Some("sessionIndex"), - Some("sessionId"), - osClient, - "jobId", - "appId", - flintSessionIndexUpdater, + conf, + sessionId, + jobId, + applicationId, + sessionManager, System.currentTimeMillis()) assert(!result) // Expecting false as the job proceeds normally } - test( - "setupFlintJobWithExclusionCheck should throw NoSuchElementException if sessionIndex or sessionId is missing") { - val osClient = mock[OSClient] - val flintSessionIndexUpdater = mock[OpenSearchUpdater] - val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "") - - assertThrows[NoSuchElementException] { - FlintREPL.setupFlintJobWithExclusionCheck( - mockConf, - None, // No sessionIndex provided - None, // No sessionId provided - osClient, - "jobId", - "appId", - flintSessionIndexUpdater, - System.currentTimeMillis()) - } - } - test("queryLoop continue until inactivity limit is reached") { - val mockReader = mock[FlintReader] - val osClient = mock[OSClient] - when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) - .thenReturn(mockReader) - when(mockReader.hasNext).thenReturn(false) - val resultIndex = "testResultIndex" val dataSource = "testDataSource" val sessionIndex = "testSessionIndex" val sessionId = "testSessionId" - val jobId = "testJobId" + + val mockReader = mock[FlintReader] + val mockOSClient = mock[OSClient] + val getResponse = mock[GetResponse] + + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(true) + when(getResponse.getSourceAsMap).thenReturn( + Map[String, Object]( + "applicationId" -> applicationId, + "jobId" -> jobId, + "sessionId" -> sessionId, + "lastUpdateTime" -> java.lang.Long.valueOf(12345L), + "error" -> "someError", + "state" -> "running", + "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava) + + when( + mockOSClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) + .thenReturn(mockReader) + when(mockReader.hasNext).thenReturn(false) + when(mockOSClient.doesIndexExist(*)).thenReturn(true) + when(mockOSClient.getIndexMetadata(*)).thenReturn(FlintREPL.resultIndexMapping) val 
shortInactivityLimit = 500 // 500 milliseconds // Create a SparkSession for testing val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() - val flintSessionIndexUpdater = mock[OpenSearchUpdater] + spark.conf.set(FlintSparkConf.REQUEST_INDEX.key, sessionIndex) + val sessionManager = new SessionManagerImpl(spark, Some(resultIndex)) { + override val osClient: OSClient = mockOSClient + } + val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( + applicationId, + jobId, spark, dataSource, - resultIndex, sessionId, - flintSessionIndexUpdater, - osClient, - sessionIndex, - jobId, + sessionManager, + queryResultWriter, Duration(10, MINUTES), shortInactivityLimit, 60, DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) - // Mock processCommands to always allow loop continuation - val getResponse = mock[GetResponse] - when(osClient.getDoc(*, *)).thenReturn(getResponse) - when(getResponse.isExists()).thenReturn(false) - val startTime = System.currentTimeMillis() FlintREPL.queryLoop(commandContext) @@ -905,48 +1081,61 @@ class FlintREPLTest } test("queryLoop should stop when canPickUpNextStatement is false") { - val mockReader = mock[FlintReader] - val osClient = mock[OSClient] - when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) - .thenReturn(mockReader) - when(mockReader.hasNext).thenReturn(true) - val resultIndex = "testResultIndex" val dataSource = "testDataSource" val sessionIndex = "testSessionIndex" val sessionId = "testSessionId" - val jobId = "testJobId" + + val mockReader = mock[FlintReader] + val mockOSClient = mock[OSClient] + + // Mocking canPickNextStatement to return false + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenAnswer(_ => { + val mockGetResponse = mock[GetResponse] + when(mockGetResponse.isExists()).thenReturn(true) + when(mockGetResponse.getSourceAsMap).thenReturn( + Map[String, Object]( + "applicationId" -> applicationId, + "jobId" -> "differentJobId", + "sessionId" -> sessionId, + "lastUpdateTime" -> java.lang.Long.valueOf(12345L), + "error" -> "someError", + "state" -> "running", + "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava) + mockGetResponse + }) + + when( + mockOSClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) + .thenReturn(mockReader) + when(mockReader.hasNext).thenReturn(true) + when(mockOSClient.doesIndexExist(*)).thenReturn(true) + when(mockOSClient.getIndexMetadata(*)).thenReturn(FlintREPL.resultIndexMapping) + val longInactivityLimit = 10000 // 10 seconds // Create a SparkSession for testing val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() - val flintSessionIndexUpdater = mock[OpenSearchUpdater] + spark.conf.set(FlintSparkConf.REQUEST_INDEX.key, sessionIndex) + val sessionManager = new SessionManagerImpl(spark, Some(resultIndex)) { + override val osClient: OSClient = mockOSClient + } + val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( + applicationId, + jobId, spark, dataSource, - resultIndex, sessionId, - flintSessionIndexUpdater, - osClient, - sessionIndex, - jobId, + sessionManager, + queryResultWriter, Duration(10, MINUTES), longInactivityLimit, 60, DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY) - // Mocking canPickNextStatement to return false - when(osClient.getDoc(sessionIndex, sessionId)).thenAnswer(_ => { - val mockGetResponse = mock[GetResponse] - when(mockGetResponse.isExists()).thenReturn(true) - val sourceMap = new 
java.util.HashMap[String, Object]() - sourceMap.put("jobId", "differentJobId") - when(mockGetResponse.getSourceAsMap).thenReturn(sourceMap) - mockGetResponse - }) - val startTime = System.currentTimeMillis() FlintREPL.queryLoop(commandContext) @@ -961,34 +1150,52 @@ class FlintREPLTest } test("queryLoop should properly shut down the thread pool after execution") { - val mockReader = mock[FlintReader] - val osClient = mock[OSClient] - when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) - .thenReturn(mockReader) - when(mockReader.hasNext).thenReturn(false) - val resultIndex = "testResultIndex" val dataSource = "testDataSource" val sessionIndex = "testSessionIndex" val sessionId = "testSessionId" - val jobId = "testJobId" + + val mockReader = mock[FlintReader] + val mockOSClient = mock[OSClient] + when( + mockOSClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) + .thenReturn(mockReader) + when(mockReader.hasNext).thenReturn(false) + when(mockOSClient.doesIndexExist(*)).thenReturn(true) + when(mockOSClient.getIndexMetadata(*)).thenReturn(FlintREPL.resultIndexMapping) + + val getResponse = mock[GetResponse] + when(mockOSClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(true) + when(getResponse.getSourceAsMap).thenReturn( + Map[String, Object]( + "applicationId" -> applicationId, + "jobId" -> jobId, + "sessionId" -> sessionId, + "lastUpdateTime" -> java.lang.Long.valueOf(12345L), + "error" -> "someError", + "state" -> "running", + "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava) val inactivityLimit = 500 // 500 milliseconds // Create a SparkSession for testing val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() - val flintSessionIndexUpdater = mock[OpenSearchUpdater] + spark.conf.set(FlintSparkConf.REQUEST_INDEX.key, sessionIndex) + val sessionManager = new SessionManagerImpl(spark, Some(resultIndex)) { + override val osClient: OSClient = mockOSClient + } + val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( + applicationId, + jobId, spark, dataSource, - resultIndex, sessionId, - flintSessionIndexUpdater, - osClient, - sessionIndex, - jobId, + sessionManager, + queryResultWriter, Duration(10, MINUTES), inactivityLimit, 60, @@ -1011,35 +1218,54 @@ class FlintREPLTest } test("queryLoop handle exceptions within the loop gracefully") { - val mockReader = mock[FlintReader] - val osClient = mock[OSClient] - when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) - .thenReturn(mockReader) - // Simulate an exception thrown when hasNext is called - when(mockReader.hasNext).thenThrow(new RuntimeException("Test exception")) - val resultIndex = "testResultIndex" val dataSource = "testDataSource" val sessionIndex = "testSessionIndex" val sessionId = "testSessionId" - val jobId = "testJobId" + + val mockReader = mock[FlintReader] + val mockOSClient = mock[OSClient] + val getResponse = mock[GetResponse] + + when(mockOSClient.getDoc(*, *)).thenReturn(getResponse) + when(getResponse.isExists()).thenReturn(true) + when(getResponse.getSourceAsMap).thenReturn( + Map[String, Object]( + "applicationId" -> applicationId, + "jobId" -> jobId, + "sessionId" -> sessionId, + "lastUpdateTime" -> java.lang.Long.valueOf(12345L), + "error" -> "someError", + "state" -> "running", + "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava) + + when( + 
mockOSClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) + .thenReturn(mockReader) + // Simulate an exception thrown when hasNext is called + when(mockReader.hasNext).thenThrow(new RuntimeException("Test exception")) + when(mockOSClient.doesIndexExist(*)).thenReturn(true) + when(mockOSClient.getIndexMetadata(*)).thenReturn(FlintREPL.resultIndexMapping) val inactivityLimit = 500 // 500 milliseconds // Create a SparkSession for testing val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() - val flintSessionIndexUpdater = mock[OpenSearchUpdater] + spark.conf.set(FlintSparkConf.REQUEST_INDEX.key, sessionIndex) + val sessionManager = new SessionManagerImpl(spark, Some(resultIndex)) { + override val osClient: OSClient = mockOSClient + } + val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( + applicationId, + jobId, spark, dataSource, - resultIndex, sessionId, - flintSessionIndexUpdater, - osClient, - sessionIndex, - jobId, + sessionManager, + queryResultWriter, Duration(10, MINUTES), inactivityLimit, 60, @@ -1064,15 +1290,24 @@ class FlintREPLTest } test("queryLoop should correctly update loop control variables") { + val resultIndex = "testResultIndex" + val dataSource = "testDataSource" + val sessionIndex = "testSessionIndex" + val sessionId = "testSessionId" + val mockReader = mock[FlintReader] - val osClient = mock[OSClient] - when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) + val mockOSClient = mock[OSClient] + when( + mockOSClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) .thenReturn(mockReader) val getResponse = mock[GetResponse] - when(osClient.getDoc(*, *)).thenReturn(getResponse) + when(mockOSClient.getDoc(*, *)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(false) - when(osClient.doesIndexExist(*)).thenReturn(true) - when(osClient.getIndexMetadata(*)).thenReturn(FlintREPL.resultIndexMapping) + when(mockOSClient.doesIndexExist(*)).thenReturn(true) + when(mockOSClient.getIndexMetadata(*)).thenReturn(FlintREPL.resultIndexMapping) + + val mockOpenSearchUpdater = mock[OpenSearchUpdater] + doNothing().when(mockOpenSearchUpdater).upsert(any[String], any[String]) // Configure mockReader to return true once and then false to exit the loop when(mockReader.hasNext).thenReturn(true).thenReturn(false) @@ -1088,12 +1323,6 @@ class FlintREPLTest """ when(mockReader.next).thenReturn(command) - val resultIndex = "testResultIndex" - val dataSource = "testDataSource" - val sessionIndex = "testSessionIndex" - val sessionId = "testSessionId" - val jobId = "testJobId" - val inactivityLimit = 5000 // 5 seconds // Create a SparkSession for testing\ @@ -1108,20 +1337,27 @@ class FlintREPLTest when(mockSparkSession.conf).thenReturn(mockConf) when(mockSparkSession.conf.get(FlintSparkConf.JOB_TYPE.key)) .thenReturn(FlintSparkConf.JOB_TYPE.defaultValue.get) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) + when(mockSparkSession.conf.get(FlintSparkConf.CUSTOM_STATEMENT_MANAGER.key, "")) + .thenReturn("") when(expectedDataFrame.toDF(any[Seq[String]]: _*)).thenReturn(expectedDataFrame) - val flintSessionIndexUpdater = mock[OpenSearchUpdater] + val sessionManager = new SessionManagerImpl(mockSparkSession, Some(resultIndex)) { + override val osClient: OSClient = mockOSClient + override lazy val flintSessionIndexUpdater: OpenSearchUpdater = mockOpenSearchUpdater + } + 
val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( + applicationId, + jobId, mockSparkSession, dataSource, - resultIndex, sessionId, - flintSessionIndexUpdater, - osClient, - sessionIndex, - jobId, + sessionManager, + queryResultWriter, Duration(10, MINUTES), inactivityLimit, 60, @@ -1136,7 +1372,10 @@ class FlintREPLTest // Assuming processCommands updates the lastActivityTime to the current time assert(endTime - startTime >= inactivityLimit) - verify(osClient, times(1)).getIndexMetadata(*) + + val expectedCalls = + Math.ceil(inactivityLimit.toDouble / DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY).toInt + verify(mockOSClient, Mockito.atMost(expectedCalls)).getIndexMetadata(*) } val testCases = Table( @@ -1148,35 +1387,52 @@ class FlintREPLTest test( "queryLoop should execute loop without processing any commands for different inactivity limits and frequencies") { forAll(testCases) { (inactivityLimit, queryLoopExecutionFrequency) => + val resultIndex = "testResultIndex" + val dataSource = "testDataSource" + val sessionIndex = "testSessionIndex" + val sessionId = "testSessionId" + val mockReader = mock[FlintReader] - val osClient = mock[OSClient] - when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) + val mockOSClient = mock[OSClient] + when( + mockOSClient + .createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) .thenReturn(mockReader) val getResponse = mock[GetResponse] - when(osClient.getDoc(*, *)).thenReturn(getResponse) + when(mockOSClient.getDoc(*, *)).thenReturn(getResponse) + when(mockOSClient.doesIndexExist(*)).thenReturn(true) + when(mockOSClient.getIndexMetadata(*)).thenReturn(FlintREPL.resultIndexMapping) + when(getResponse.isExists()).thenReturn(false) - when(mockReader.hasNext).thenReturn(false) + when(getResponse.getSourceAsMap).thenReturn( + Map[String, Object]( + "applicationId" -> applicationId, + "jobId" -> jobId, + "sessionId" -> sessionId, + "lastUpdateTime" -> java.lang.Long.valueOf(12345L), + "error" -> "someError", + "state" -> "running", + "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava) - val resultIndex = "testResultIndex" - val dataSource = "testDataSource" - val sessionIndex = "testSessionIndex" - val sessionId = "testSessionId" - val jobId = "testJobId" + when(mockReader.hasNext).thenReturn(false) // Create a SparkSession for testing val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() - val flintSessionIndexUpdater = mock[OpenSearchUpdater] + spark.conf.set(FlintSparkConf.REQUEST_INDEX.key, sessionIndex) + val sessionManager = new SessionManagerImpl(spark, Some(resultIndex)) { + override val osClient: OSClient = mockOSClient + } + val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( + applicationId, + jobId, spark, dataSource, - resultIndex, sessionId, - flintSessionIndexUpdater, - osClient, - sessionIndex, - jobId, + sessionManager, + queryResultWriter, Duration(10, MINUTES), inactivityLimit, 60,