Skip to content

Commit

Permalink
address Chen's comments
Browse files Browse the repository at this point in the history
Signed-off-by: Kaituo Li <[email protected]>
  • Loading branch information
kaituo committed Nov 17, 2023
1 parent f0e0135 commit 4b8b98a
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@
package org.apache.spark.sql

import java.util.Locale
import java.util.concurrent.ThreadPoolExecutor

import scala.concurrent.{ExecutionContext, Future, TimeoutException}
import scala.concurrent.duration.{Duration, MINUTES}

import org.opensearch.client.{RequestOptions, RestHighLevelClient}
import org.opensearch.cluster.metadata.MappingMetadata
Expand All @@ -23,9 +19,7 @@ import play.api.libs.json._
import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.types.{StructField, _}
import org.apache.spark.util.ThreadUtils

/**
* Spark SQL Application entrypoint
Expand All @@ -49,77 +43,19 @@ object FlintJob extends Logging with FlintJobExecutor {
val conf = createSparkConf()
val wait = conf.get("spark.flint.job.type", "continue")
val dataSource = conf.get("spark.flint.datasource.name", "")
// https://github.com/opensearch-project/opensearch-spark/issues/138
/*
* To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`,
* it's necessary to set `spark.sql.defaultCatalog=my_glue1`. This is because AWS Glue uses a single database (default) and table (http_logs_plain),
* and we need to configure Spark to recognize `my_glue1` as a reference to AWS Glue's database and table.
* By doing this, we effectively map `my_glue1` to AWS Glue, allowing Spark to resolve the database and table names correctly.
* Without this setup, Spark would not recognize names in the format `my_glue1.default`.
*/
conf.set("spark.sql.defaultCatalog", dataSource)
val spark = createSparkSession(conf)

val threadPool = ThreadUtils.newDaemonFixedThreadPool(1, "check-create-index")
implicit val executionContext = ExecutionContext.fromExecutor(threadPool)

var dataToWrite: Option[DataFrame] = None
val startTime = System.currentTimeMillis()
// osClient needs spark session to be created first to get FlintOptions initialized.
// Otherwise, we will have connection exception from EMR-S to OS.
val osClient = new OSClient(FlintSparkConf().flintOptions())
var exceptionThrown = true
try {
val futureMappingCheck = Future {
checkAndCreateIndex(osClient, resultIndex)
}
val data = executeQuery(spark, query, dataSource, "", "")

val mappingCheckResult = ThreadUtils.awaitResult(futureMappingCheck, Duration(1, MINUTES))
dataToWrite = Some(mappingCheckResult match {
case Right(_) => data
case Left(error) =>
getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider)
})
exceptionThrown = false
} catch {
case e: TimeoutException =>
val error = s"Getting the mapping of index $resultIndex timed out"
logError(error, e)
dataToWrite = Some(
getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider))
case e: Exception =>
val error = processQueryException(e, spark, dataSource, query, "", "")
dataToWrite = Some(
getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider))
} finally {
cleanUpResources(
spark,
exceptionThrown,
wait,
threadPool,
dataToWrite,
resultIndex,
osClient)
}
}

def cleanUpResources(
spark: SparkSession,
exceptionThrown: Boolean,
wait: String,
threadPool: ThreadPoolExecutor,
dataToWrite: Option[DataFrame],
resultIndex: String,
osClient: OSClient): Unit = {
try {
dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient))
} catch {
case e: Exception => logError("fail to write to result index", e)
}
try {
// Stop SparkSession if streaming job succeeds
if (!exceptionThrown && wait.equalsIgnoreCase("streaming")) {
// wait if any child thread to finish before the main thread terminates
spark.streams.awaitAnyTermination()
} else {
spark.stop()
}
} catch {
case e: Exception => logError("fail to close spark session", e)
} finally {
threadPool.shutdown()
}
val jobOperator =
JobOperator(conf, query, dataSource, resultIndex, wait.equalsIgnoreCase("streaming"))
jobOperator.start()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
package org.apache.spark.sql

import java.util.Locale
import java.util.concurrent.ThreadPoolExecutor

import scala.concurrent.{ExecutionContext, Future, TimeoutException}
import scala.concurrent.duration.{Duration, MINUTES}

import com.amazonaws.services.s3.model.AmazonS3Exception
import org.opensearch.flint.core.FlintClient
Expand All @@ -14,11 +18,13 @@ import play.api.libs.json.{JsArray, JsBoolean, JsObject, Json, JsString, JsValue

import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.FlintJob.{createIndex, getFormattedData, isSuperset, logError, logInfo}
import org.apache.spark.sql.FlintJob.{checkAndCreateIndex, createIndex, currentTimeProvider, executeQuery, getFailedData, getFormattedData, isSuperset, logError, logInfo, processQueryException, writeDataFrameToOpensearch}
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.types.{ArrayType, LongType, StringType, StructField, StructType}
import org.apache.spark.sql.util.{DefaultThreadPoolFactory, RealTimeProvider, ThreadPoolFactory, TimeProvider}
import org.apache.spark.util.ThreadUtils

trait FlintJobExecutor {
this: Logging =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ object FlintREPL extends Logging with FlintJobExecutor {
val conf: SparkConf = createSparkConf()
val dataSource = conf.get("spark.flint.datasource.name", "unknown")
// https://github.com/opensearch-project/opensearch-spark/issues/138
/*
* To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`,
* it's necessary to set `spark.sql.defaultCatalog=my_glue1`. This is because AWS Glue uses a single database (default) and table (http_logs_plain),
* and we need to configure Spark to recognize `my_glue1` as a reference to AWS Glue's database and table.
* By doing this, we effectively map `my_glue1` to AWS Glue, allowing Spark to resolve the database and table names correctly.
* Without this setup, Spark would not recognize names in the format `my_glue1.default`.
*/
conf.set("spark.sql.defaultCatalog", dataSource)
val wait = conf.get("spark.flint.job.type", "continue")
// we don't allow default value for sessionIndex and sessionId. Throw exception if key not found.
Expand Down Expand Up @@ -99,7 +106,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
conf.getLong("spark.flint.job.queryWaitTimeoutMillis", DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS)

val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex.get)
createShutdownHook(flintSessionIndexUpdater, osClient, sessionIndex.get, sessionId.get)
addShutdownHook(flintSessionIndexUpdater, osClient, sessionIndex.get, sessionId.get)
// 1 thread for updating heart beat
val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor("flint-repl-heartbeat", 1)
val jobStartTime = currentTimeProvider.currentEpochMillis()
Expand Down Expand Up @@ -767,7 +774,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
flintReader
}

def createShutdownHook(
def addShutdownHook(
flintSessionIndexUpdater: OpenSearchUpdater,
osClient: OSClient,
sessionIndex: String,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql

import java.util.concurrent.ThreadPoolExecutor

import scala.concurrent.{ExecutionContext, Future, TimeoutException}
import scala.concurrent.duration.{Duration, MINUTES}
import scala.util.{Failure, Success, Try}

import org.opensearch.flint.core.storage.OpenSearchUpdater

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.sql.FlintJob.createSparkSession
import org.apache.spark.sql.FlintREPL.{executeQuery, logInfo, updateFlintInstanceBeforeShutdown}
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.util.ThreadUtils

case class JobOperator(
sparkConf: SparkConf,
query: String,
dataSource: String,
resultIndex: String,
streaming: Boolean)
extends Logging
with FlintJobExecutor {
private val spark = createSparkSession(sparkConf)

// jvm shutdown hook
sys.addShutdownHook(stop())

def start(): Unit = {
val threadPool = ThreadUtils.newDaemonFixedThreadPool(1, "check-create-index")
implicit val executionContext = ExecutionContext.fromExecutor(threadPool)

var dataToWrite: Option[DataFrame] = None
val startTime = System.currentTimeMillis()
// osClient needs spark session to be created first to get FlintOptions initialized.
// Otherwise, we will have connection exception from EMR-S to OS.
val osClient = new OSClient(FlintSparkConf().flintOptions())
var exceptionThrown = true
try {
val futureMappingCheck = Future {
checkAndCreateIndex(osClient, resultIndex)
}
val data = executeQuery(spark, query, dataSource, "", "")

val mappingCheckResult = ThreadUtils.awaitResult(futureMappingCheck, Duration(1, MINUTES))
dataToWrite = Some(mappingCheckResult match {
case Right(_) => data
case Left(error) =>
getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider)
})
exceptionThrown = false
} catch {
case e: TimeoutException =>
val error = s"Getting the mapping of index $resultIndex timed out"
logError(error, e)
dataToWrite = Some(
getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider))
case e: Exception =>
val error = processQueryException(e, spark, dataSource, query, "", "")
dataToWrite = Some(
getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider))
} finally {
cleanUpResources(exceptionThrown, threadPool, dataToWrite, resultIndex, osClient)
}
}

def cleanUpResources(
exceptionThrown: Boolean,
threadPool: ThreadPoolExecutor,
dataToWrite: Option[DataFrame],
resultIndex: String,
osClient: OSClient): Unit = {
try {
dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient))
} catch {
case e: Exception => logError("fail to write to result index", e)
}

try {
// Stop SparkSession if streaming job succeeds
if (!exceptionThrown && streaming) {
// wait if any child thread to finish before the main thread terminates
spark.streams.awaitAnyTermination()
}
} catch {
case e: Exception => logError("streaming job failed", e)
}

try {
threadPool.shutdown()
} catch {
case e: Exception => logError("Fail to close threadpool", e)
}
}

def stop(): Unit = {
Try {
spark.stop()
} match {
case Success(_) =>
case Failure(e) => logError("unexpected error while shutdown", e)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class FlintREPLTest
}

// Here, we're injecting our mockShutdownHookManager into the method
FlintREPL.createShutdownHook(
FlintREPL.addShutdownHook(
flintSessionIndexUpdater,
osClient,
sessionIndex,
Expand Down

0 comments on commit 4b8b98a

Please sign in to comment.