diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/StatementExecutionManager.scala b/flint-commons/src/main/scala/org/apache/spark/sql/StatementExecutionManager.scala
index acf28c572..35367eb62 100644
--- a/flint-commons/src/main/scala/org/apache/spark/sql/StatementExecutionManager.scala
+++ b/flint-commons/src/main/scala/org/apache/spark/sql/StatementExecutionManager.scala
@@ -38,5 +38,5 @@ trait StatementExecutionManager {
   /**
    * Terminates the statement lifecycle.
    */
-  def terminateStatementsExecution(): Unit
+  def terminateStatementExecution(): Unit
 }
diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala
index c96b71fd9..5b9318f8e 100644
--- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala
+++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala
@@ -211,6 +211,10 @@ object FlintSparkConf {
     FlintConfig("spark.flint.job.query")
       .doc("Flint query for batch and streaming job")
       .createOptional()
+  val QUERY_ID =
+    FlintConfig("spark.flint.job.queryId")
+      .doc("Flint query id for batch and streaming job")
+      .createOptional()
   val JOB_TYPE =
     FlintConfig(s"spark.flint.job.type")
       .doc("Flint job type. Including interactive and streaming")
diff --git a/integ-test/src/integration/scala/org/apache/spark/sql/FlintJobITSuite.scala b/integ-test/src/integration/scala/org/apache/spark/sql/FlintJobITSuite.scala
index 9371557a2..57277440e 100644
--- a/integ-test/src/integration/scala/org/apache/spark/sql/FlintJobITSuite.scala
+++ b/integ-test/src/integration/scala/org/apache/spark/sql/FlintJobITSuite.scala
@@ -37,6 +37,7 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
   val resultIndex = "query_results2"
   val appId = "00feq82b752mbt0p"
   val dataSourceName = "my_glue1"
+  val queryId = "testQueryId"
   var osClient: OSClient = _
   val threadLocalFuture = new ThreadLocal[Future[Unit]]()
 
@@ -91,7 +92,7 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
      * all Spark conf required by Flint code underlying manually.
      */
     spark.conf.set(DATA_SOURCE_NAME.key, dataSourceName)
-    spark.conf.set(JOB_TYPE.key, "streaming")
+    spark.conf.set(JOB_TYPE.key, FlintJobType.STREAMING)
 
     /**
      * FlintJob.main() is not called because we need to manually set these variables within a
@@ -103,9 +104,10 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
       jobRunId,
       spark,
       query,
+      queryId,
       dataSourceName,
       resultIndex,
-      true,
+      FlintJobType.STREAMING,
       streamingRunningCount)
     job.terminateJVM = false
     job.start()
@@ -144,7 +146,6 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
 
       assert(result.status == "SUCCESS", s"expected status is SUCCESS, but got ${result.status}")
       assert(result.error.isEmpty, s"we don't expect error, but got ${result.error}")
-      assert(result.queryId.isEmpty, s"we don't expect query id, but got ${result.queryId}")
 
       commonAssert(result, jobRunId, query, queryStartTime)
       true
@@ -362,7 +363,9 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
       result.queryRunTime < System.currentTimeMillis() - queryStartTime,
       s"expected query run time ${result.queryRunTime} should be less than ${System
           .currentTimeMillis() - queryStartTime}, but it is not")
-    assert(result.queryId.isEmpty, s"we don't expect query id, but got ${result.queryId}")
+    assert(
+      result.queryId == queryId,
+      s"expected query id is ${queryId}, but got ${result.queryId}")
   }
 
   def pollForResultAndAssert(expected: REPLResult => Boolean, jobId: String): Unit = {
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala
index 0d5c062ae..42b1ae2f6 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala
@@ -15,6 +15,7 @@ case class CommandContext(
     jobId: String,
     spark: SparkSession,
     dataSource: String,
+    jobType: String,
     sessionId: String,
     sessionManager: SessionManager,
     queryResultWriter: QueryResultWriter,
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala
index c556e2786..04609cf3d 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala
@@ -11,11 +11,9 @@ import java.util.concurrent.atomic.AtomicInteger
 import org.opensearch.flint.core.logging.CustomLogging
 import org.opensearch.flint.core.metrics.MetricConstants
 import org.opensearch.flint.core.metrics.MetricsUtil.registerGauge
-import play.api.libs.json._
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.flint.config.FlintSparkConf
-import org.apache.spark.sql.types._
 
 /**
  * Spark SQL Application entrypoint
@@ -32,7 +30,7 @@ object FlintJob extends Logging with FlintJobExecutor {
     val (queryOption, resultIndexOption) = parseArgs(args)
 
     val conf = createSparkConf()
-    val jobType = conf.get("spark.flint.job.type", "batch")
+    val jobType = conf.get("spark.flint.job.type", FlintJobType.BATCH)
     CustomLogging.logInfo(s"""Job type is: ${jobType}""")
     conf.set(FlintSparkConf.JOB_TYPE.key, jobType)
 
@@ -41,6 +39,8 @@ object FlintJob extends Logging with FlintJobExecutor {
     if (query.isEmpty) {
       logAndThrow(s"Query undefined for the ${jobType} job.")
     }
+    val queryId = conf.get(FlintSparkConf.QUERY_ID.key, "")
+
     if (resultIndexOption.isEmpty) {
       logAndThrow("resultIndex is not set")
     }
@@ -66,9 +66,10 @@ object FlintJob extends Logging with FlintJobExecutor {
         jobId,
         createSparkSession(conf),
         query,
+        queryId,
         dataSource,
         resultIndexOption.get,
-        jobType.equalsIgnoreCase("streaming"),
+        jobType,
         streamingRunningCount)
     registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount)
     jobOperator.start()
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala
index 95d3ba0f1..24d68fd47 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.flint.config.FlintSparkConf
 import org.apache.spark.sql.flint.config.FlintSparkConf.REFRESH_POLICY
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util._
+import org.apache.spark.util.Utils
 
 object SparkConfConstants {
   val SQL_EXTENSIONS_KEY = "spark.sql.extensions"
@@ -33,6 +34,12 @@ object SparkConfConstants {
     "org.opensearch.flint.spark.FlintPPLSparkExtensions,org.opensearch.flint.spark.FlintSparkExtensions"
 }
 
+object FlintJobType {
+  val INTERACTIVE = "interactive"
+  val BATCH = "batch"
+  val STREAMING = "streaming"
+}
+
 trait FlintJobExecutor {
   this: Logging =>
 
@@ -131,7 +138,7 @@ trait FlintJobExecutor {
    * https://github.com/opensearch-project/opensearch-spark/issues/324
    */
   def configDYNMaxExecutors(conf: SparkConf, jobType: String): Unit = {
-    if (jobType.equalsIgnoreCase("streaming")) {
+    if (jobType.equalsIgnoreCase(FlintJobType.STREAMING)) {
       conf.set(
         "spark.dynamicAllocation.maxExecutors",
         conf
@@ -524,4 +531,25 @@ trait FlintJobExecutor {
     CustomLogging.logError(t)
     throw t
   }
+
+  def instantiate[T](defaultConstructor: => T, className: String, args: Any*): T = {
+    if (className.isEmpty) {
+      defaultConstructor
+    } else {
+      try {
+        val classObject = Utils.classForName(className)
+        val ctor = if (args.isEmpty) {
+          classObject.getDeclaredConstructor()
+        } else {
+          classObject.getDeclaredConstructor(args.map(_.getClass.asInstanceOf[Class[_]]): _*)
+        }
+        ctor.setAccessible(true)
+        ctor.newInstance(args.map(_.asInstanceOf[Object]): _*).asInstanceOf[T]
+      } catch {
+        case e: Exception =>
+          throw new RuntimeException(s"Failed to instantiate provider: $className", e)
+      }
+    }
+  }
+
 }
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
index 340c0656e..635a5226e 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
@@ -26,7 +26,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
 import org.apache.spark.sql.FlintREPLConfConstants._
 import org.apache.spark.sql.SessionUpdateMode._
 import org.apache.spark.sql.flint.config.FlintSparkConf
-import org.apache.spark.util.{ThreadUtils, Utils}
+import org.apache.spark.util.ThreadUtils
 
 object FlintREPLConfConstants {
   val HEARTBEAT_INTERVAL_MILLIS = 60000L
@@ -87,8 +87,9 @@ object FlintREPL extends Logging with FlintJobExecutor {
     conf.set(FlintSparkConf.JOB_TYPE.key, jobType)
 
     val query = getQuery(queryOption, jobType, conf)
+    val queryId = conf.get(FlintSparkConf.QUERY_ID.key, "")
 
-    if (jobType.equalsIgnoreCase("streaming")) {
+    if (jobType.equalsIgnoreCase(FlintJobType.STREAMING)) {
       if (resultIndexOption.isEmpty) {
         logAndThrow("resultIndex is not set")
       }
@@ -100,9 +101,10 @@ object FlintREPL extends Logging with FlintJobExecutor {
         jobId,
         createSparkSession(conf),
         query,
+        queryId,
         dataSource,
         resultIndexOption.get,
-        true,
+        jobType,
         streamingRunningCount)
       registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount)
       jobOperator.start()
@@ -174,6 +176,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
         jobId,
         spark,
         dataSource,
+        jobType,
         sessionId,
         sessionManager,
         queryResultWriter,
@@ -220,7 +223,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
 
   def getQuery(queryOption: Option[String], jobType: String, conf: SparkConf): String = {
     queryOption.getOrElse {
-      if (jobType.equalsIgnoreCase("streaming")) {
+      if (jobType.equalsIgnoreCase(FlintJobType.STREAMING)) {
         val defaultQuery = conf.get(FlintSparkConf.QUERY.key, "")
         if (defaultQuery.isEmpty) {
           logAndThrow("Query undefined for the streaming job.")
@@ -352,7 +355,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
             canPickUpNextStatement = updatedCanPickUpNextStatement
             lastCanPickCheckTime = updatedLastCanPickCheckTime
           } finally {
-            statementsExecutionManager.terminateStatementsExecution()
+            statementsExecutionManager.terminateStatementExecution()
           }
 
           Thread.sleep(commandContext.queryLoopExecutionFrequency)
@@ -975,26 +978,6 @@ object FlintREPL extends Logging with FlintJobExecutor {
     }
   }
 
-  private def instantiate[T](defaultConstructor: => T, className: String, args: Any*): T = {
-    if (className.isEmpty) {
-      defaultConstructor
-    } else {
-      try {
-        val classObject = Utils.classForName(className)
-        val ctor = if (args.isEmpty) {
-          classObject.getDeclaredConstructor()
-        } else {
-          classObject.getDeclaredConstructor(args.map(_.getClass.asInstanceOf[Class[_]]): _*)
-        }
-        ctor.setAccessible(true)
-        ctor.newInstance(args.map(_.asInstanceOf[Object]): _*).asInstanceOf[T]
-      } catch {
-        case e: Exception =>
-          throw new RuntimeException(s"Failed to instantiate provider: $className", e)
-      }
-    }
-  }
-
   private def instantiateSessionManager(
       spark: SparkSession,
       resultIndexOption: Option[String]): SessionManager = {
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
index b49f4a9ed..cb4af86da 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
@@ -12,6 +12,7 @@ import scala.concurrent.{ExecutionContext, Future, TimeoutException}
 import scala.concurrent.duration.{Duration, MINUTES}
 import scala.util.{Failure, Success, Try}
 
+import org.opensearch.flint.common.model.FlintStatement
 import org.opensearch.flint.core.metrics.MetricConstants
 import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter
 import org.opensearch.flint.spark.FlintSpark
@@ -24,11 +25,12 @@ import org.apache.spark.util.ThreadUtils
 case class JobOperator(
     applicationId: String,
     jobId: String,
-    spark: SparkSession,
+    sparkSession: SparkSession,
     query: String,
+    queryId: String,
     dataSource: String,
     resultIndex: String,
-    streaming: Boolean,
+    jobType: String,
     streamingRunningCount: AtomicInteger)
     extends Logging
     with FlintJobExecutor {
@@ -48,43 +50,68 @@ case class JobOperator(
     // osClient needs spark session to be created first to get FlintOptions initialized.
     // Otherwise, we will have connection exception from EMR-S to OS.
     val osClient = new OSClient(FlintSparkConf().flintOptions())
+
+    // TODO: Update FlintJob to Support All Query Types. Track on https://github.com/opensearch-project/opensearch-spark/issues/633
+    val commandContext = CommandContext(
+      applicationId,
+      jobId,
+      sparkSession,
+      dataSource,
+      jobType,
+      "", // FlintJob doesn't have sessionId
+      null, // FlintJob doesn't have SessionManager
+      null, // FlintJob doesn't have QueryResultWriter
+      Duration.Inf, // FlintJob doesn't have queryExecutionTimeout
+      -1, // FlintJob doesn't have inactivityLimitMillis
+      -1, // FlintJob doesn't have queryWaitTimeMillis
+      -1 // FlintJob doesn't have queryLoopExecutionFrequency
+    )
+
+    val statementExecutionManager =
+      instantiateStatementExecutionManager(commandContext, resultIndex, osClient)
+
+    val statement =
+      new FlintStatement("running", query, "", queryId, currentTimeProvider.currentEpochMillis())
+
     var exceptionThrown = true
+    var error: String = null
+
     try {
-      val futureMappingCheck = Future {
-        checkAndCreateIndex(osClient, resultIndex)
+      val futurePrepareQueryExecution = Future {
+        statementExecutionManager.prepareStatementExecution()
       }
-      val data = executeQuery(applicationId, jobId, spark, query, dataSource, "", "", streaming)
-
-      val mappingCheckResult = ThreadUtils.awaitResult(futureMappingCheck, Duration(1, MINUTES))
-      dataToWrite = Some(mappingCheckResult match {
-        case Right(_) => data
-        case Left(error) =>
-          constructErrorDF(
-            applicationId,
-            jobId,
-            spark,
-            dataSource,
-            "FAILED",
-            error,
-            "",
-            query,
-            "",
-            startTime)
-      })
+      val data = statementExecutionManager.executeStatement(statement)
+      dataToWrite = Some(
+        ThreadUtils.awaitResult(futurePrepareQueryExecution, Duration(1, MINUTES)) match {
+          case Right(_) => data
+          case Left(err) =>
+            error = err
+            constructErrorDF(
+              applicationId,
+              jobId,
+              sparkSession,
+              dataSource,
+              "FAILED",
+              err,
+              queryId,
+              query,
+              "",
+              startTime)
+        })
       exceptionThrown = false
     } catch {
       case e: TimeoutException =>
-        val error = s"Getting the mapping of index $resultIndex timed out"
+        error = s"Preparation for query execution timed out"
         logError(error, e)
         dataToWrite = Some(
           constructErrorDF(
             applicationId,
             jobId,
-            spark,
+            sparkSession,
             dataSource,
             "TIMEOUT",
             error,
-            "",
+            queryId,
             query,
             "",
             startTime))
@@ -94,44 +121,46 @@ case class JobOperator(
           constructErrorDF(
             applicationId,
             jobId,
-            spark,
+            sparkSession,
             dataSource,
             "FAILED",
             error,
-            "",
+            queryId,
             query,
             "",
             startTime))
     } finally {
-      cleanUpResources(exceptionThrown, threadPool, dataToWrite, resultIndex, osClient)
-    }
-  }
+      try {
+        dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient))
+      } catch {
+        case e: Exception =>
+          exceptionThrown = true
+          error = s"Failed to write to result index. originalError='${error}'"
+          logError(error, e)
+      }
+      if (exceptionThrown) statement.fail() else statement.complete()
+      statement.error = Some(error)
+      statementExecutionManager.updateStatement(statement)
 
-  def cleanUpResources(
-      exceptionThrown: Boolean,
-      threadPool: ThreadPoolExecutor,
-      dataToWrite: Option[DataFrame],
-      resultIndex: String,
-      osClient: OSClient): Unit = {
-    try {
-      dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient))
-    } catch {
-      case e: Exception => logError("fail to write to result index", e)
+      cleanUpResources(exceptionThrown, threadPool)
     }
+  }
 
+  def cleanUpResources(exceptionThrown: Boolean, threadPool: ThreadPoolExecutor): Unit = {
+    val isStreaming = jobType.equalsIgnoreCase(FlintJobType.STREAMING)
     try {
       // Wait for streaming job complete if no error
-      if (!exceptionThrown && streaming) {
+      if (!exceptionThrown && isStreaming) {
         // Clean Spark shuffle data after each microBatch.
-        spark.streams.addListener(new ShuffleCleaner(spark))
+        sparkSession.streams.addListener(new ShuffleCleaner(sparkSession))
         // Await index monitor before the main thread terminates
-        new FlintSpark(spark).flintIndexMonitor.awaitMonitor()
+        new FlintSpark(sparkSession).flintIndexMonitor.awaitMonitor()
       } else {
         logInfo(s"""
           | Skip streaming job await due to conditions not met:
          |  - exceptionThrown: $exceptionThrown
-         |  - streaming: $streaming
-         |  - activeStreams: ${spark.streams.active.mkString(",")}
+         |  - streaming: $isStreaming
+         |  - activeStreams: ${sparkSession.streams.active.mkString(",")}
          |""".stripMargin)
       }
     } catch {
@@ -163,7 +192,7 @@ case class JobOperator(
   def stop(): Unit = {
     Try {
       logInfo("Stopping Spark session")
-      spark.stop()
+      sparkSession.stop()
       logInfo("Stopped Spark session")
     } match {
       case Success(_) =>
@@ -191,4 +220,15 @@ case class JobOperator(
     }
   }
 
+  private def instantiateStatementExecutionManager(
+      commandContext: CommandContext,
+      resultIndex: String,
+      osClient: OSClient): StatementExecutionManager = {
+    import commandContext._
+    instantiate(
+      new SingleStatementExecutionManager(commandContext, resultIndex, osClient),
+      spark.conf.get(FlintSparkConf.CUSTOM_STATEMENT_MANAGER.key, ""),
+      spark,
+      sessionId)
+  }
 }
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/SingleStatementExecutionManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/SingleStatementExecutionManagerImpl.scala
new file mode 100644
index 000000000..52b8edd1d
--- /dev/null
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/SingleStatementExecutionManagerImpl.scala
@@ -0,0 +1,57 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.apache.spark.sql
+
+import org.opensearch.flint.common.model.FlintStatement
+import org.opensearch.flint.core.storage.OpenSearchUpdater
+import org.opensearch.search.sort.SortOrder
+
+import org.apache.spark.internal.Logging
+
+/**
+ * SingleStatementExecutionManager is an implementation of StatementExecutionManager interface to
+ * run single statement
+ * @param commandContext
+ */
+class SingleStatementExecutionManager(
+    commandContext: CommandContext,
+    resultIndex: String,
+    osClient: OSClient)
+    extends StatementExecutionManager
+    with FlintJobExecutor
+    with Logging {
+
+  override def prepareStatementExecution(): Either[String, Unit] = {
+    checkAndCreateIndex(osClient, resultIndex)
+  }
+
+  override def updateStatement(statement: FlintStatement): Unit = {
+    // TODO: Update FlintJob to Support All Query Types. Track on https://github.com/opensearch-project/opensearch-spark/issues/633
+  }
+
+  override def terminateStatementExecution(): Unit = {
+    // TODO: Update FlintJob to Support All Query Types. Track on https://github.com/opensearch-project/opensearch-spark/issues/633
+  }
+
+  override def getNextStatement(): Option[FlintStatement] = {
+    // TODO: Update FlintJob to Support All Query Types. Track on https://github.com/opensearch-project/opensearch-spark/issues/633
+    None
+  }
+
+  override def executeStatement(statement: FlintStatement): DataFrame = {
+    import commandContext._
+    val isStreaming = jobType.equalsIgnoreCase(FlintJobType.STREAMING)
+    executeQuery(
+      applicationId,
+      jobId,
+      spark,
+      statement.query,
+      dataSource,
+      statement.queryId,
+      sessionId,
+      isStreaming)
+  }
+}
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala
index 0b059f1d3..4e9435f7b 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala
@@ -38,7 +38,7 @@ class StatementExecutionManagerImpl(commandContext: CommandContext)
   override def updateStatement(statement: FlintStatement): Unit = {
     flintSessionIndexUpdater.update(statement.statementId, FlintStatement.serialize(statement))
   }
-  override def terminateStatementsExecution(): Unit = {
+  override def terminateStatementExecution(): Unit = {
     flintReader.close()
   }
 
diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
index 433c2351b..a3f990a59 100644
--- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
+++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
@@ -50,6 +50,7 @@ class FlintREPLTest
 
   private val jobId = "testJobId"
   private val applicationId = "testApplicationId"
+  private val INTERACTIVE_JOB_TYPE = "interactive"
 
   test("parseArgs with no arguments should return (None, None)") {
     val args = Array.empty[String]
@@ -669,6 +670,7 @@ class FlintREPLTest
         jobId,
         spark,
         dataSource,
+        INTERACTIVE_JOB_TYPE,
         sessionId,
         sessionManager,
         queryResultWriter,
@@ -734,6 +736,7 @@ class FlintREPLTest
         jobId,
         mockSparkSession,
         dataSource,
+        INTERACTIVE_JOB_TYPE,
         sessionId,
         sessionManager,
         queryResultWriter,
@@ -808,6 +811,7 @@ class FlintREPLTest
         jobId,
         mockSparkSession,
         dataSource,
+        INTERACTIVE_JOB_TYPE,
         sessionId,
         sessionManager,
         queryResultWriter,
@@ -1059,6 +1063,7 @@ class FlintREPLTest
         jobId,
         spark,
         dataSource,
+        INTERACTIVE_JOB_TYPE,
         sessionId,
         sessionManager,
         queryResultWriter,
@@ -1128,6 +1133,7 @@ class FlintREPLTest
         jobId,
         spark,
         dataSource,
+        INTERACTIVE_JOB_TYPE,
         sessionId,
         sessionManager,
         queryResultWriter,
@@ -1193,6 +1199,7 @@ class FlintREPLTest
         jobId,
         spark,
         dataSource,
+        INTERACTIVE_JOB_TYPE,
         sessionId,
         sessionManager,
         queryResultWriter,
@@ -1263,6 +1270,7 @@ class FlintREPLTest
         jobId,
         spark,
         dataSource,
+        INTERACTIVE_JOB_TYPE,
         sessionId,
         sessionManager,
         queryResultWriter,
@@ -1355,6 +1363,7 @@ class FlintREPLTest
         jobId,
         mockSparkSession,
         dataSource,
+        INTERACTIVE_JOB_TYPE,
         sessionId,
         sessionManager,
         queryResultWriter,
@@ -1430,6 +1439,7 @@ class FlintREPLTest
         jobId,
         spark,
         dataSource,
+        INTERACTIVE_JOB_TYPE,
         sessionId,
         sessionManager,
         queryResultWriter,