diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala index cb1f5c1ca..750e228ef 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala @@ -8,9 +8,6 @@ package org.apache.spark.sql import java.util.Locale -import scala.concurrent.{ExecutionContext, Future, TimeoutException} -import scala.concurrent.duration.{Duration, MINUTES} - import org.opensearch.client.{RequestOptions, RestHighLevelClient} import org.opensearch.cluster.metadata.MappingMetadata import org.opensearch.common.settings.Settings @@ -22,9 +19,7 @@ import play.api.libs.json._ import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.types.{StructField, _} -import org.apache.spark.util.ThreadUtils /** * Spark SQL Application entrypoint @@ -48,51 +43,18 @@ object FlintJob extends Logging with FlintJobExecutor { val conf = createSparkConf() val wait = conf.get("spark.flint.job.type", "continue") val dataSource = conf.get("spark.flint.datasource.name", "") - val spark = createSparkSession(conf) - - val threadPool = ThreadUtils.newDaemonFixedThreadPool(1, "check-create-index") - implicit val executionContext = ExecutionContext.fromExecutor(threadPool) - - var dataToWrite: Option[DataFrame] = None - val startTime = System.currentTimeMillis() - // osClient needs spark session to be created first to get FlintOptions initialized. - // Otherwise, we will have connection exception from EMR-S to OS. - val osClient = new OSClient(FlintSparkConf().flintOptions()) - var exceptionThrown = true - try { - val futureMappingCheck = Future { - checkAndCreateIndex(osClient, resultIndex) - } - val data = executeQuery(spark, query, dataSource, "", "") - - val mappingCheckResult = ThreadUtils.awaitResult(futureMappingCheck, Duration(1, MINUTES)) - dataToWrite = Some(mappingCheckResult match { - case Right(_) => data - case Left(error) => - getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider) - }) - exceptionThrown = false - } catch { - case e: TimeoutException => - val error = s"Getting the mapping of index $resultIndex timed out" - logError(error, e) - dataToWrite = Some( - getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider)) - case e: Exception => - val error = processQueryException(e, spark, dataSource, query, "", "") - dataToWrite = Some( - getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider)) - } finally { - dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient)) - // Stop SparkSession if streaming job succeeds - if (!exceptionThrown && wait.equalsIgnoreCase("streaming")) { - // wait if any child thread to finish before the main thread terminates - spark.streams.awaitAnyTermination() - } else { - spark.stop() - } - - threadPool.shutdown() - } + // https://github.com/opensearch-project/opensearch-spark/issues/138 + /* + * To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`, + * it's necessary to set `spark.sql.defaultCatalog=my_glue1`. This is because AWS Glue uses a single database (default) and table (http_logs_plain), + * and we need to configure Spark to recognize `my_glue1` as a reference to AWS Glue's database and table. + * By doing this, we effectively map `my_glue1` to AWS Glue, allowing Spark to resolve the database and table names correctly. + * Without this setup, Spark would not recognize names in the format `my_glue1.default`. + */ + conf.set("spark.sql.defaultCatalog", dataSource) + + val jobOperator = + JobOperator(conf, query, dataSource, resultIndex, wait.equalsIgnoreCase("streaming")) + jobOperator.start() } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala index 6e7dbb926..903bcaa09 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -6,6 +6,10 @@ package org.apache.spark.sql import java.util.Locale +import java.util.concurrent.ThreadPoolExecutor + +import scala.concurrent.{ExecutionContext, Future, TimeoutException} +import scala.concurrent.duration.{Duration, MINUTES} import com.amazonaws.services.s3.model.AmazonS3Exception import org.opensearch.flint.core.FlintClient @@ -14,11 +18,13 @@ import play.api.libs.json.{JsArray, JsBoolean, JsObject, Json, JsString, JsValue import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging -import org.apache.spark.sql.FlintJob.{createIndex, getFormattedData, isSuperset, logError, logInfo} +import org.apache.spark.sql.FlintJob.{checkAndCreateIndex, createIndex, currentTimeProvider, executeQuery, getFailedData, getFormattedData, isSuperset, logError, logInfo, processQueryException, writeDataFrameToOpensearch} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.types.{ArrayType, LongType, StringType, StructField, StructType} import org.apache.spark.sql.util.{DefaultThreadPoolFactory, RealTimeProvider, ThreadPoolFactory, TimeProvider} +import org.apache.spark.util.ThreadUtils trait FlintJobExecutor { this: Logging => diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index edf7d62e6..f21a01d53 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -44,7 +44,7 @@ import org.apache.spark.util.ThreadUtils object FlintREPL extends Logging with FlintJobExecutor { private val HEARTBEAT_INTERVAL_MILLIS = 60000L - private val DEFAULT_INACTIVITY_LIMIT_MILLIS = 30 * 60 * 1000 + private val DEFAULT_INACTIVITY_LIMIT_MILLIS = 10 * 60 * 1000 private val MAPPING_CHECK_TIMEOUT = Duration(1, MINUTES) private val DEFAULT_QUERY_EXECUTION_TIMEOUT = Duration(10, MINUTES) private val DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS = 10 * 60 * 1000 @@ -63,30 +63,38 @@ object FlintREPL extends Logging with FlintJobExecutor { val conf: SparkConf = createSparkConf() val dataSource = conf.get("spark.flint.datasource.name", "unknown") // https://github.com/opensearch-project/opensearch-spark/issues/138 + /* + * To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`, + * it's necessary to set `spark.sql.defaultCatalog=my_glue1`. This is because AWS Glue uses a single database (default) and table (http_logs_plain), + * and we need to configure Spark to recognize `my_glue1` as a reference to AWS Glue's database and table. + * By doing this, we effectively map `my_glue1` to AWS Glue, allowing Spark to resolve the database and table names correctly. + * Without this setup, Spark would not recognize names in the format `my_glue1.default`. + */ conf.set("spark.sql.defaultCatalog", dataSource) val wait = conf.get("spark.flint.job.type", "continue") - // we don't allow default value for sessionIndex and sessionId. Throw exception if key not found. - val sessionIndex: Option[String] = Option(conf.get("spark.flint.job.requestIndex", null)) - val sessionId: Option[String] = Option(conf.get("spark.flint.job.sessionId", null)) - - if (sessionIndex.isEmpty) { - throw new IllegalArgumentException("spark.flint.job.requestIndex is not set") - } - if (sessionId.isEmpty) { - throw new IllegalArgumentException("spark.flint.job.sessionId is not set") - } - - val spark = createSparkSession(conf) - val osClient = new OSClient(FlintSparkConf().flintOptions()) - val jobId = sys.env.getOrElse("SERVERLESS_EMR_JOB_ID", "unknown") - val applicationId = sys.env.getOrElse("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown") if (wait.equalsIgnoreCase("streaming")) { logInfo(s"""streaming query ${query}""") - val result = executeQuery(spark, query, dataSource, "", "") - writeDataFrameToOpensearch(result, resultIndex, osClient) - spark.streams.awaitAnyTermination() + val jobOperator = + JobOperator(conf, query, dataSource, resultIndex, true) + jobOperator.start() } else { + // we don't allow default value for sessionIndex and sessionId. Throw exception if key not found. + val sessionIndex: Option[String] = Option(conf.get("spark.flint.job.requestIndex", null)) + val sessionId: Option[String] = Option(conf.get("spark.flint.job.sessionId", null)) + + if (sessionIndex.isEmpty) { + throw new IllegalArgumentException("spark.flint.job.requestIndex is not set") + } + if (sessionId.isEmpty) { + throw new IllegalArgumentException("spark.flint.job.sessionId is not set") + } + + val spark = createSparkSession(conf) + val osClient = new OSClient(FlintSparkConf().flintOptions()) + val jobId = sys.env.getOrElse("SERVERLESS_EMR_JOB_ID", "unknown") + val applicationId = sys.env.getOrElse("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown") + // Read the values from the Spark configuration or fall back to the default values val inactivityLimitMillis: Long = conf.getLong("spark.flint.job.inactivityLimitMillis", DEFAULT_INACTIVITY_LIMIT_MILLIS) @@ -99,7 +107,7 @@ object FlintREPL extends Logging with FlintJobExecutor { conf.getLong("spark.flint.job.queryWaitTimeoutMillis", DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS) val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex.get) - createShutdownHook(flintSessionIndexUpdater, osClient, sessionIndex.get, sessionId.get) + addShutdownHook(flintSessionIndexUpdater, osClient, sessionIndex.get, sessionId.get) // 1 thread for updating heart beat val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor("flint-repl-heartbeat", 1) val jobStartTime = currentTimeProvider.currentEpochMillis() @@ -767,7 +775,7 @@ object FlintREPL extends Logging with FlintJobExecutor { flintReader } - def createShutdownHook( + def addShutdownHook( flintSessionIndexUpdater: OpenSearchUpdater, osClient: OSClient, sessionIndex: String, diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala new file mode 100644 index 000000000..c60d250ea --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala @@ -0,0 +1,113 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import java.util.concurrent.ThreadPoolExecutor + +import scala.concurrent.{ExecutionContext, Future, TimeoutException} +import scala.concurrent.duration.{Duration, MINUTES} +import scala.util.{Failure, Success, Try} + +import org.opensearch.flint.core.storage.OpenSearchUpdater + +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging +import org.apache.spark.sql.FlintJob.createSparkSession +import org.apache.spark.sql.FlintREPL.{executeQuery, logInfo, updateFlintInstanceBeforeShutdown} +import org.apache.spark.sql.flint.config.FlintSparkConf +import org.apache.spark.util.ThreadUtils + +case class JobOperator( + sparkConf: SparkConf, + query: String, + dataSource: String, + resultIndex: String, + streaming: Boolean) + extends Logging + with FlintJobExecutor { + private val spark = createSparkSession(sparkConf) + + // jvm shutdown hook + sys.addShutdownHook(stop()) + + def start(): Unit = { + val threadPool = ThreadUtils.newDaemonFixedThreadPool(1, "check-create-index") + implicit val executionContext = ExecutionContext.fromExecutor(threadPool) + + var dataToWrite: Option[DataFrame] = None + val startTime = System.currentTimeMillis() + // osClient needs spark session to be created first to get FlintOptions initialized. + // Otherwise, we will have connection exception from EMR-S to OS. + val osClient = new OSClient(FlintSparkConf().flintOptions()) + var exceptionThrown = true + try { + val futureMappingCheck = Future { + checkAndCreateIndex(osClient, resultIndex) + } + val data = executeQuery(spark, query, dataSource, "", "") + + val mappingCheckResult = ThreadUtils.awaitResult(futureMappingCheck, Duration(1, MINUTES)) + dataToWrite = Some(mappingCheckResult match { + case Right(_) => data + case Left(error) => + getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider) + }) + exceptionThrown = false + } catch { + case e: TimeoutException => + val error = s"Getting the mapping of index $resultIndex timed out" + logError(error, e) + dataToWrite = Some( + getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider)) + case e: Exception => + val error = processQueryException(e, spark, dataSource, query, "", "") + dataToWrite = Some( + getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider)) + } finally { + cleanUpResources(exceptionThrown, threadPool, dataToWrite, resultIndex, osClient) + } + } + + def cleanUpResources( + exceptionThrown: Boolean, + threadPool: ThreadPoolExecutor, + dataToWrite: Option[DataFrame], + resultIndex: String, + osClient: OSClient): Unit = { + try { + dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient)) + } catch { + case e: Exception => logError("fail to write to result index", e) + } + + try { + // Stop SparkSession if streaming job succeeds + if (!exceptionThrown && streaming) { + // wait if any child thread to finish before the main thread terminates + spark.streams.awaitAnyTermination() + } + } catch { + case e: Exception => logError("streaming job failed", e) + } + + try { + threadPool.shutdown() + logInfo("shut down thread threadpool") + } catch { + case e: Exception => logError("Fail to close threadpool", e) + } + } + + def stop(): Unit = { + Try { + spark.stop() + logInfo("stopped spark session") + } match { + case Success(_) => + case Failure(e) => logError("unexpected error while stopping spark session", e) + } + } +} diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index 704045e8a..8335f2a72 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -117,7 +117,7 @@ class FlintREPLTest } // Here, we're injecting our mockShutdownHookManager into the method - FlintREPL.createShutdownHook( + FlintREPL.addShutdownHook( flintSessionIndexUpdater, osClient, sessionIndex,