From 4b8b98a72d8235f24c96b26027a0919d59523a1f Mon Sep 17 00:00:00 2001
From: Kaituo Li
Date: Thu, 16 Nov 2023 13:39:24 -0800
Subject: [PATCH] address Chen's comments

Signed-off-by: Kaituo Li
---
 .../scala/org/apache/spark/sql/FlintJob.scala |  88 ++------------
 .../apache/spark/sql/FlintJobExecutor.scala   |   8 +-
 .../org/apache/spark/sql/FlintREPL.scala      |  11 +-
 .../org/apache/spark/sql/JobOperator.scala    | 111 ++++++++++++++++++
 .../org/apache/spark/sql/FlintREPLTest.scala  |   2 +-
 5 files changed, 140 insertions(+), 80 deletions(-)
 create mode 100644 spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala

diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala
index 2c4fe01af..42a14643d 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala
@@ -7,10 +7,6 @@
 package org.apache.spark.sql
 
 import java.util.Locale
-import java.util.concurrent.ThreadPoolExecutor
-
-import scala.concurrent.{ExecutionContext, Future, TimeoutException}
-import scala.concurrent.duration.{Duration, MINUTES}
 
 import org.opensearch.client.{RequestOptions, RestHighLevelClient}
 import org.opensearch.cluster.metadata.MappingMetadata
@@ -23,9 +19,7 @@ import play.api.libs.json._
 import org.apache.spark.SparkConf
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.parser.ParseException
-import org.apache.spark.sql.flint.config.FlintSparkConf
 import org.apache.spark.sql.types.{StructField, _}
-import org.apache.spark.util.ThreadUtils
 
 /**
  * Spark SQL Application entrypoint
@@ -49,77 +43,19 @@ object FlintJob extends Logging with FlintJobExecutor {
     val conf = createSparkConf()
     val wait = conf.get("spark.flint.job.type", "continue")
     val dataSource = conf.get("spark.flint.datasource.name", "")
+    // https://github.com/opensearch-project/opensearch-spark/issues/138
+    /*
+     * To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`,
+     * it's necessary to set `spark.sql.defaultCatalog=my_glue1`. This is because AWS Glue uses a single database (default) and table (http_logs_plain),
+     * and we need to configure Spark to recognize `my_glue1` as a reference to AWS Glue's database and table.
+     * By doing this, we effectively map `my_glue1` to AWS Glue, allowing Spark to resolve the database and table names correctly.
+     * Without this setup, Spark would not recognize names in the format `my_glue1.default`.
+     */
+    conf.set("spark.sql.defaultCatalog", dataSource)
 
     val spark = createSparkSession(conf)
-    val threadPool = ThreadUtils.newDaemonFixedThreadPool(1, "check-create-index")
-    implicit val executionContext = ExecutionContext.fromExecutor(threadPool)
-
-    var dataToWrite: Option[DataFrame] = None
-    val startTime = System.currentTimeMillis()
-    // osClient needs spark session to be created first to get FlintOptions initialized.
-    // Otherwise, we will have connection exception from EMR-S to OS.
-    val osClient = new OSClient(FlintSparkConf().flintOptions())
-    var exceptionThrown = true
-    try {
-      val futureMappingCheck = Future {
-        checkAndCreateIndex(osClient, resultIndex)
-      }
-      val data = executeQuery(spark, query, dataSource, "", "")
-
-      val mappingCheckResult = ThreadUtils.awaitResult(futureMappingCheck, Duration(1, MINUTES))
-      dataToWrite = Some(mappingCheckResult match {
-        case Right(_) => data
-        case Left(error) =>
-          getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider)
-      })
-      exceptionThrown = false
-    } catch {
-      case e: TimeoutException =>
-        val error = s"Getting the mapping of index $resultIndex timed out"
-        logError(error, e)
-        dataToWrite = Some(
-          getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider))
-      case e: Exception =>
-        val error = processQueryException(e, spark, dataSource, query, "", "")
-        dataToWrite = Some(
-          getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider))
-    } finally {
-      cleanUpResources(
-        spark,
-        exceptionThrown,
-        wait,
-        threadPool,
-        dataToWrite,
-        resultIndex,
-        osClient)
-    }
-  }
-
-  def cleanUpResources(
-      spark: SparkSession,
-      exceptionThrown: Boolean,
-      wait: String,
-      threadPool: ThreadPoolExecutor,
-      dataToWrite: Option[DataFrame],
-      resultIndex: String,
-      osClient: OSClient): Unit = {
-    try {
-      dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient))
-    } catch {
-      case e: Exception => logError("fail to write to result index", e)
-    }
-    try {
-      // Stop SparkSession if streaming job succeeds
-      if (!exceptionThrown && wait.equalsIgnoreCase("streaming")) {
-        // wait if any child thread to finish before the main thread terminates
-        spark.streams.awaitAnyTermination()
-      } else {
-        spark.stop()
-      }
-    } catch {
-      case e: Exception => logError("fail to close spark session", e)
-    } finally {
-      threadPool.shutdown()
-    }
+    val jobOperator =
+      JobOperator(conf, query, dataSource, resultIndex, wait.equalsIgnoreCase("streaming"))
+    jobOperator.start()
   }
 }
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala
index 6e7dbb926..903bcaa09 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala
@@ -6,6 +6,10 @@
 package org.apache.spark.sql
 
 import java.util.Locale
+import java.util.concurrent.ThreadPoolExecutor
+
+import scala.concurrent.{ExecutionContext, Future, TimeoutException}
+import scala.concurrent.duration.{Duration, MINUTES}
 
 import com.amazonaws.services.s3.model.AmazonS3Exception
 import org.opensearch.flint.core.FlintClient
@@ -14,11 +18,13 @@ import play.api.libs.json.{JsArray, JsBoolean, JsObject, Json, JsString, JsValue
 
 import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.FlintJob.{createIndex, getFormattedData, isSuperset, logError, logInfo}
+import org.apache.spark.sql.FlintJob.{checkAndCreateIndex, createIndex, currentTimeProvider, executeQuery, getFailedData, getFormattedData, isSuperset, logError, logInfo, processQueryException, writeDataFrameToOpensearch}
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.execution.datasources.DataSource
+import org.apache.spark.sql.flint.config.FlintSparkConf
 import org.apache.spark.sql.types.{ArrayType, LongType, StringType, StructField, StructType}
 import org.apache.spark.sql.util.{DefaultThreadPoolFactory, RealTimeProvider, ThreadPoolFactory, TimeProvider}
+import org.apache.spark.util.ThreadUtils
 
 trait FlintJobExecutor {
   this: Logging =>
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
index 969f14002..538430cd9 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
@@ -63,6 +63,13 @@ object FlintREPL extends Logging with FlintJobExecutor {
     val conf: SparkConf = createSparkConf()
     val dataSource = conf.get("spark.flint.datasource.name", "unknown")
     // https://github.com/opensearch-project/opensearch-spark/issues/138
+    /*
+     * To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`,
+     * it's necessary to set `spark.sql.defaultCatalog=my_glue1`. This is because AWS Glue uses a single database (default) and table (http_logs_plain),
+     * and we need to configure Spark to recognize `my_glue1` as a reference to AWS Glue's database and table.
+     * By doing this, we effectively map `my_glue1` to AWS Glue, allowing Spark to resolve the database and table names correctly.
+     * Without this setup, Spark would not recognize names in the format `my_glue1.default`.
+     */
     conf.set("spark.sql.defaultCatalog", dataSource)
     val wait = conf.get("spark.flint.job.type", "continue")
     // we don't allow default value for sessionIndex and sessionId. Throw exception if key not found.
@@ -99,7 +106,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
       conf.getLong("spark.flint.job.queryWaitTimeoutMillis", DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS)
 
     val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex.get)
-    createShutdownHook(flintSessionIndexUpdater, osClient, sessionIndex.get, sessionId.get)
+    addShutdownHook(flintSessionIndexUpdater, osClient, sessionIndex.get, sessionId.get)
     // 1 thread for updating heart beat
     val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor("flint-repl-heartbeat", 1)
     val jobStartTime = currentTimeProvider.currentEpochMillis()
@@ -767,7 +774,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
     flintReader
   }
 
-  def createShutdownHook(
+  def addShutdownHook(
       flintSessionIndexUpdater: OpenSearchUpdater,
       osClient: OSClient,
       sessionIndex: String,
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
new file mode 100644
index 000000000..43b31945b
--- /dev/null
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
@@ -0,0 +1,111 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.apache.spark.sql
+
+import java.util.concurrent.ThreadPoolExecutor
+
+import scala.concurrent.{ExecutionContext, Future, TimeoutException}
+import scala.concurrent.duration.{Duration, MINUTES}
+import scala.util.{Failure, Success, Try}
+
+import org.opensearch.flint.core.storage.OpenSearchUpdater
+
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.FlintJob.createSparkSession
+import org.apache.spark.sql.FlintREPL.{executeQuery, logInfo, updateFlintInstanceBeforeShutdown}
+import org.apache.spark.sql.flint.config.FlintSparkConf
+import org.apache.spark.util.ThreadUtils
+
+case class JobOperator(
+    sparkConf: SparkConf,
+    query: String,
+    dataSource: String,
+    resultIndex: String,
+    streaming: Boolean)
+    extends Logging
+    with FlintJobExecutor {
+  private val spark = createSparkSession(sparkConf)
+
+  // jvm shutdown hook
+  sys.addShutdownHook(stop())
+
+  def start(): Unit = {
+    val threadPool = ThreadUtils.newDaemonFixedThreadPool(1, "check-create-index")
+    implicit val executionContext = ExecutionContext.fromExecutor(threadPool)
+
+    var dataToWrite: Option[DataFrame] = None
+    val startTime = System.currentTimeMillis()
+    // osClient needs spark session to be created first to get FlintOptions initialized.
+    // Otherwise, we will have connection exception from EMR-S to OS.
+    val osClient = new OSClient(FlintSparkConf().flintOptions())
+    var exceptionThrown = true
+    try {
+      val futureMappingCheck = Future {
+        checkAndCreateIndex(osClient, resultIndex)
+      }
+      val data = executeQuery(spark, query, dataSource, "", "")
+
+      val mappingCheckResult = ThreadUtils.awaitResult(futureMappingCheck, Duration(1, MINUTES))
+      dataToWrite = Some(mappingCheckResult match {
+        case Right(_) => data
+        case Left(error) =>
+          getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider)
+      })
+      exceptionThrown = false
+    } catch {
+      case e: TimeoutException =>
+        val error = s"Getting the mapping of index $resultIndex timed out"
+        logError(error, e)
+        dataToWrite = Some(
+          getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider))
+      case e: Exception =>
+        val error = processQueryException(e, spark, dataSource, query, "", "")
+        dataToWrite = Some(
+          getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider))
+    } finally {
+      cleanUpResources(exceptionThrown, threadPool, dataToWrite, resultIndex, osClient)
+    }
+  }
+
+  def cleanUpResources(
+      exceptionThrown: Boolean,
+      threadPool: ThreadPoolExecutor,
+      dataToWrite: Option[DataFrame],
+      resultIndex: String,
+      osClient: OSClient): Unit = {
+    try {
+      dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient))
+    } catch {
+      case e: Exception => logError("fail to write to result index", e)
+    }
+
+    try {
+      // Stop SparkSession if streaming job succeeds
+      if (!exceptionThrown && streaming) {
+        // wait if any child thread to finish before the main thread terminates
+        spark.streams.awaitAnyTermination()
+      }
+    } catch {
+      case e: Exception => logError("streaming job failed", e)
+    }
+
+    try {
+      threadPool.shutdown()
+    } catch {
+      case e: Exception => logError("Fail to close threadpool", e)
+    }
+  }
+
+  def stop(): Unit = {
+    Try {
+      spark.stop()
+    } match {
+      case Success(_) =>
+      case Failure(e) => logError("unexpected error while shutdown", e)
+    }
+  }
+}
diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
index 704045e8a..8335f2a72 100644
--- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
+++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
@@ -117,7 +117,7 @@ class FlintREPLTest
     }
 
     // Here, we're injecting our mockShutdownHookManager into the method
-    FlintREPL.createShutdownHook(
+    FlintREPL.addShutdownHook(
       flintSessionIndexUpdater,
       osClient,
       sessionIndex,
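
A minimal sketch of how the new JobOperator entry point is driven, mirroring FlintJob.main in the patch above. The query, result index, and data source name below are placeholders for illustration only, not values taken from this commit.

    import org.apache.spark.SparkConf
    import org.apache.spark.sql.JobOperator

    object JobOperatorExample {
      def main(args: Array[String]): Unit = {
        // Placeholder inputs; FlintJob.main normally receives the query and result index as job arguments.
        val query = "SELECT 1"
        val resultIndex = "query_results"

        val conf = new SparkConf()
          .set("spark.flint.datasource.name", "my_glue1")
          .set("spark.flint.job.type", "streaming")
        val dataSource = conf.get("spark.flint.datasource.name", "")
        // Map the data source name to the default catalog, as FlintJob.main does above (issue 138).
        conf.set("spark.sql.defaultCatalog", dataSource)

        // "streaming" keeps the job alive via spark.streams.awaitAnyTermination() in cleanUpResources.
        val streaming = conf.get("spark.flint.job.type", "continue").equalsIgnoreCase("streaming")
        val jobOperator = JobOperator(conf, query, dataSource, resultIndex, streaming)
        jobOperator.start()
      }
    }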