Enhance Logging for Job Cleanup and Reduce REPL Inactivity Timeout #160

Merged · 3 commits · Nov 17, 2023

Changes from 2 commits

FlintJob.scala
@@ -8,9 +8,6 @@ package org.apache.spark.sql

import java.util.Locale

import scala.concurrent.{ExecutionContext, Future, TimeoutException}
import scala.concurrent.duration.{Duration, MINUTES}

import org.opensearch.client.{RequestOptions, RestHighLevelClient}
import org.opensearch.cluster.metadata.MappingMetadata
import org.opensearch.common.settings.Settings
@@ -22,9 +19,7 @@ import play.api.libs.json._
import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.types.{StructField, _}
import org.apache.spark.util.ThreadUtils

/**
* Spark SQL Application entrypoint
@@ -48,51 +43,19 @@ object FlintJob extends Logging with FlintJobExecutor {
val conf = createSparkConf()
val wait = conf.get("spark.flint.job.type", "continue")
val dataSource = conf.get("spark.flint.datasource.name", "")
// https://github.com/opensearch-project/opensearch-spark/issues/138
/*
* To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`,
* it's necessary to set `spark.sql.defaultCatalog=my_glue1`. This is because AWS Glue uses a single database (default) and table (http_logs_plain),
* and we need to configure Spark to recognize `my_glue1` as a reference to AWS Glue's database and table.
* By doing this, we effectively map `my_glue1` to AWS Glue, allowing Spark to resolve the database and table names correctly.
* Without this setup, Spark would not recognize names in the format `my_glue1.default`.
*/
conf.set("spark.sql.defaultCatalog", dataSource)
val spark = createSparkSession(conf)

val threadPool = ThreadUtils.newDaemonFixedThreadPool(1, "check-create-index")
implicit val executionContext = ExecutionContext.fromExecutor(threadPool)

var dataToWrite: Option[DataFrame] = None
val startTime = System.currentTimeMillis()
// osClient needs spark session to be created first to get FlintOptions initialized.
// Otherwise, we will have connection exception from EMR-S to OS.
val osClient = new OSClient(FlintSparkConf().flintOptions())
var exceptionThrown = true
try {
val futureMappingCheck = Future {
checkAndCreateIndex(osClient, resultIndex)
}
val data = executeQuery(spark, query, dataSource, "", "")

val mappingCheckResult = ThreadUtils.awaitResult(futureMappingCheck, Duration(1, MINUTES))
dataToWrite = Some(mappingCheckResult match {
case Right(_) => data
case Left(error) =>
getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider)
})
exceptionThrown = false
} catch {
case e: TimeoutException =>
val error = s"Getting the mapping of index $resultIndex timed out"
logError(error, e)
dataToWrite = Some(
getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider))
case e: Exception =>
val error = processQueryException(e, spark, dataSource, query, "", "")
dataToWrite = Some(
getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider))
} finally {
dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient))
// Stop SparkSession if streaming job succeeds
if (!exceptionThrown && wait.equalsIgnoreCase("streaming")) {
// wait if any child thread to finish before the main thread terminates
spark.streams.awaitAnyTermination()
} else {
spark.stop()
}

threadPool.shutdown()
}
val jobOperator =
JobOperator(conf, query, dataSource, resultIndex, wait.equalsIgnoreCase("streaming"))
jobOperator.start()
}
}
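
A minimal, self-contained sketch of the catalog setup described in the inline comment in FlintJob above (the same comment appears in FlintREPL below). The catalog name my_glue1 and the CREATE SKIPPING INDEX statement are the examples cited in that comment, and the Flint SQL extensions are assumed to be on the classpath, so treat this as illustrative rather than part of the PR:

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

object DefaultCatalogSketch {
  def main(args: Array[String]): Unit = {
    // In the job this value comes from spark.flint.datasource.name;
    // "my_glue1" is the example catalog name used in the comment above.
    val dataSource = "my_glue1"
    val conf = new SparkConf()
      .setAppName("default-catalog-sketch")
      // Map references like `my_glue1.default.http_logs_plain` to the Glue catalog.
      .set("spark.sql.defaultCatalog", dataSource)

    val spark = SparkSession.builder().config(conf).getOrCreate()

    // With the default catalog set (and the Flint extensions loaded), the statement
    // from the comment resolves its database and table names correctly.
    spark.sql(
      "CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain " +
        "(`@timestamp` VALUE_SET) WITH (auto_refresh = true)")

    spark.stop()
  }
}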

FlintJobExecutor.scala
@@ -6,6 +6,10 @@
package org.apache.spark.sql

import java.util.Locale
import java.util.concurrent.ThreadPoolExecutor

import scala.concurrent.{ExecutionContext, Future, TimeoutException}
import scala.concurrent.duration.{Duration, MINUTES}

import com.amazonaws.services.s3.model.AmazonS3Exception
import org.opensearch.flint.core.FlintClient
@@ -14,11 +18,13 @@ import play.api.libs.json.{JsArray, JsBoolean, JsObject, Json, JsString, JsValue

import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.internal.Logging
- import org.apache.spark.sql.FlintJob.{createIndex, getFormattedData, isSuperset, logError, logInfo}
+ import org.apache.spark.sql.FlintJob.{checkAndCreateIndex, createIndex, currentTimeProvider, executeQuery, getFailedData, getFormattedData, isSuperset, logError, logInfo, processQueryException, writeDataFrameToOpensearch}
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.types.{ArrayType, LongType, StringType, StructField, StructType}
import org.apache.spark.sql.util.{DefaultThreadPoolFactory, RealTimeProvider, ThreadPoolFactory, TimeProvider}
import org.apache.spark.util.ThreadUtils

trait FlintJobExecutor {
this: Logging =>

FlintREPL.scala
@@ -44,7 +44,7 @@ import org.apache.spark.util.ThreadUtils
object FlintREPL extends Logging with FlintJobExecutor {

private val HEARTBEAT_INTERVAL_MILLIS = 60000L
- private val DEFAULT_INACTIVITY_LIMIT_MILLIS = 30 * 60 * 1000
+ private val DEFAULT_INACTIVITY_LIMIT_MILLIS = 10 * 60 * 1000
private val MAPPING_CHECK_TIMEOUT = Duration(1, MINUTES)
private val DEFAULT_QUERY_EXECUTION_TIMEOUT = Duration(10, MINUTES)
private val DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS = 10 * 60 * 1000
@@ -63,6 +63,13 @@ object FlintREPL extends Logging with FlintJobExecutor {
val conf: SparkConf = createSparkConf()
val dataSource = conf.get("spark.flint.datasource.name", "unknown")
// https://github.com/opensearch-project/opensearch-spark/issues/138
/*
* To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`,
* it's necessary to set `spark.sql.defaultCatalog=my_glue1`. This is because AWS Glue uses a single database (default) and table (http_logs_plain),
* and we need to configure Spark to recognize `my_glue1` as a reference to AWS Glue's database and table.
* By doing this, we effectively map `my_glue1` to AWS Glue, allowing Spark to resolve the database and table names correctly.
* Without this setup, Spark would not recognize names in the format `my_glue1.default`.
*/
conf.set("spark.sql.defaultCatalog", dataSource)
val wait = conf.get("spark.flint.job.type", "continue")
// we don't allow default value for sessionIndex and sessionId. Throw exception if key not found.
@@ -99,7 +106,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
conf.getLong("spark.flint.job.queryWaitTimeoutMillis", DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS)

val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex.get)
- createShutdownHook(flintSessionIndexUpdater, osClient, sessionIndex.get, sessionId.get)
+ addShutdownHook(flintSessionIndexUpdater, osClient, sessionIndex.get, sessionId.get)
// 1 thread for updating heart beat
val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor("flint-repl-heartbeat", 1)
val jobStartTime = currentTimeProvider.currentEpochMillis()
@@ -767,7 +774,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
flintReader
}

- def createShutdownHook(
+ def addShutdownHook(
flintSessionIndexUpdater: OpenSearchUpdater,
osClient: OSClient,
sessionIndex: String,
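
The rename from createShutdownHook to addShutdownHook matches Scala's sys.addShutdownHook idiom, which the new JobOperator below also relies on to release resources before the JVM exits. A minimal sketch of that idiom follows; the hook body here is a placeholder, not the PR's session-update logic:

import scala.util.{Failure, Success, Try}

object AddShutdownHookSketch {
  def main(args: Array[String]): Unit = {
    // Registered hooks run on normal JVM exit or SIGTERM; keep the body short.
    sys.addShutdownHook {
      Try {
        // Placeholder for work such as marking the session terminated or stopping Spark.
        println("flushing session state before shutdown")
      } match {
        case Success(_) =>
        case Failure(e) => Console.err.println(s"unexpected error while shutdown: ${e.getMessage}")
      }
    }

    println("main work would run here")
  }
}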

JobOperator.scala (new file)
@@ -0,0 +1,111 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql

import java.util.concurrent.ThreadPoolExecutor

import scala.concurrent.{ExecutionContext, Future, TimeoutException}
import scala.concurrent.duration.{Duration, MINUTES}
import scala.util.{Failure, Success, Try}

import org.opensearch.flint.core.storage.OpenSearchUpdater

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.sql.FlintJob.createSparkSession
import org.apache.spark.sql.FlintREPL.{executeQuery, logInfo, updateFlintInstanceBeforeShutdown}
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.util.ThreadUtils

case class JobOperator(
sparkConf: SparkConf,
query: String,
dataSource: String,
resultIndex: String,
streaming: Boolean)
extends Logging
with FlintJobExecutor {
private val spark = createSparkSession(sparkConf)

// jvm shutdown hook
sys.addShutdownHook(stop())

def start(): Unit = {
val threadPool = ThreadUtils.newDaemonFixedThreadPool(1, "check-create-index")
implicit val executionContext = ExecutionContext.fromExecutor(threadPool)

var dataToWrite: Option[DataFrame] = None
val startTime = System.currentTimeMillis()
// The OSClient needs the SparkSession to be created first so that FlintOptions are initialized.
// Otherwise, we will get a connection exception from EMR-S to OS.
val osClient = new OSClient(FlintSparkConf().flintOptions())
var exceptionThrown = true
try {
val futureMappingCheck = Future {
checkAndCreateIndex(osClient, resultIndex)
}
val data = executeQuery(spark, query, dataSource, "", "")

val mappingCheckResult = ThreadUtils.awaitResult(futureMappingCheck, Duration(1, MINUTES))
dataToWrite = Some(mappingCheckResult match {
case Right(_) => data
case Left(error) =>
getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider)
})
exceptionThrown = false
} catch {
case e: TimeoutException =>
val error = s"Getting the mapping of index $resultIndex timed out"
logError(error, e)
dataToWrite = Some(
getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider))
case e: Exception =>
val error = processQueryException(e, spark, dataSource, query, "", "")
dataToWrite = Some(
getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider))
} finally {
cleanUpResources(exceptionThrown, threadPool, dataToWrite, resultIndex, osClient)
}
}

def cleanUpResources(
exceptionThrown: Boolean,
threadPool: ThreadPoolExecutor,
dataToWrite: Option[DataFrame],
resultIndex: String,
osClient: OSClient): Unit = {
try {
dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient))
} catch {
case e: Exception => logError("fail to write to result index", e)
}

try {
// For a streaming job that started successfully, block until the streams terminate
if (!exceptionThrown && streaming) {
// wait for any child streaming threads to finish before the main thread terminates
spark.streams.awaitAnyTermination()
}
} catch {
case e: Exception => logError("streaming job failed", e)
}

try {
threadPool.shutdown()
} catch {
case e: Exception => logError("Fail to close threadpool", e)
}
}

def stop(): Unit = {
Try {
spark.stop()
} match {
case Success(_) =>
case Failure(e) => logError("unexpected error while shutdown", e)
}
}
}
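
A condensed sketch of the cleanup pattern JobOperator.cleanUpResources uses above: each teardown step is wrapped in its own try/catch so a failure in one step is logged without skipping the remaining steps. The step names and helpers below are illustrative, not part of the PR:

import java.util.concurrent.{Executors, ThreadPoolExecutor}

object CleanupLoggingSketch {
  private def logError(message: String, e: Throwable): Unit =
    Console.err.println(s"$message: ${e.getMessage}")

  // Run every step, logging failures independently instead of aborting the cleanup.
  def cleanUp(steps: Seq[(String, () => Unit)]): Unit =
    steps.foreach { case (name, step) =>
      try step()
      catch { case e: Exception => logError(s"fail to $name", e) }
    }

  def main(args: Array[String]): Unit = {
    val threadPool = Executors.newFixedThreadPool(1).asInstanceOf[ThreadPoolExecutor]
    cleanUp(
      Seq(
        "write to result index" -> (() => ()), // would call writeDataFrameToOpensearch here
        "await streaming termination" -> (() => ()),
        "close threadpool" -> (() => threadPool.shutdown())))
  }
}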

FlintREPLTest.scala
@@ -117,7 +117,7 @@ class FlintREPLTest
}

// Here, we're injecting our mockShutdownHookManager into the method
- FlintREPL.createShutdownHook(
+ FlintREPL.addShutdownHook(
flintSessionIndexUpdater,
osClient,
sessionIndex,