Implement FlintJob to support warmpool
Shri Saran Raj N committed Dec 20, 2024
1 parent 20ef890 commit 044aeea
Showing 8 changed files with 701 additions and 168 deletions.
@@ -135,10 +135,56 @@ public final class MetricConstants {
*/
public static final String STREAMING_HEARTBEAT_FAILED_METRIC = "streaming.heartbeat.failed.count";

/**
* Metric for tracking the count of streaming jobs failed during query execution
*/
public static final String STREAMING_EXECUTION_FAILED_METRIC = "streaming.execution.failed.count";

/**
* Metric for tracking the count of streaming jobs failed during query result write
*/
public static final String STREAMING_RESULT_WRITER_FAILED_METRIC = "streaming.writer.failed.count";

/**
* Metric for tracking the latency of query execution (start to complete query execution) excluding result write.
*/
public static final String QUERY_EXECUTION_TIME_METRIC = "query.execution.processingTime";
public static final String QUERY_EXECUTION_TIME_METRIC = "streaming.query.execution.processingTime";

/**
* Metric for tracking the latency of query result write only (excluding query execution)
*/
public static final String QUERY_RESULT_WRITER_TIME_METRIC = "streaming.result.writer.processingTime";

/**
* Metric for tracking the total latency of query execution, including result write.
*/
public static final String QUERY_TOTAL_TIME_METRIC = "streaming.query.total.processingTime";

/**
* Metric for tracking the latency of query execution (start to complete query execution) excluding result write.
*/
public static final String STATEMENT_QUERY_EXECUTION_TIME_METRIC = "statement.query.execution.processingTime";

/**
* Metric for tracking the latency of query result write only (excluding query execution)
*/
public static final String STATEMENT_RESULT_WRITER_TIME_METRIC = "statement.result.writer.processingTime";

/**
* Metric for tracking the total latency of query execution, including result write.
*/
public static final String STATEMENT_QUERY_TOTAL_TIME_METRIC = "statement.query.total.processingTime";

/**
* Metric for tracking the count of interactive jobs failed during query execution
*/
public static final String STATEMENT_EXECUTION_FAILED_METRIC = "statement.execution.failed.count";

/**
* Metric for tracking the count of interactive jobs failed during query result write
*/
public static final String STATEMENT_RESULT_WRITER_FAILED_METRIC = "statement.writer.failed.count";


/**
* Metric for query count of each query type (DROP/VACUUM/ALTER/REFRESH/CREATE INDEX)
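
The new statement timers split a query's end-to-end latency into execution, result-write, and total components. A minimal sketch of how an interactive (statement) path might record them — the `emitTimerMetric` helper and the timing points are illustrative assumptions; this hunk only defines the constant names:

```scala
import org.opensearch.flint.core.metrics.MetricConstants

// Hypothetical helper; only the metric-name constants come from this commit.
def emitTimerMetric(name: String, millis: Long): Unit =
  println(s"$name=$millis")

val startTime = System.currentTimeMillis()
// ... execute the query ...
val executionDone = System.currentTimeMillis()
// ... write the query result ...
val writeDone = System.currentTimeMillis()

// Execution latency excludes the result write; total latency includes it.
emitTimerMetric(MetricConstants.STATEMENT_QUERY_EXECUTION_TIME_METRIC, executionDone - startTime)
emitTimerMetric(MetricConstants.STATEMENT_RESULT_WRITER_TIME_METRIC, writeDone - executionDone)
emitTimerMetric(MetricConstants.STATEMENT_QUERY_TOTAL_TIME_METRIC, writeDone - startTime)
```
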
@@ -171,6 +171,12 @@ object FlintSparkConf {
.doc("Enable external scheduler for index refresh")
.createWithDefault("false")

val WARMPOOL_ENABLED =
FlintConfig("spark.flint.job.warmpoolEnabled")
.createWithDefault("false")

val MAX_EXECUTORS_COUNT = FlintConfig("spark.dynamicAllocation.maxExecutors").createOptional()

val EXTERNAL_SCHEDULER_INTERVAL_THRESHOLD =
FlintConfig("spark.flint.job.externalScheduler.interval")
.doc("Interval threshold in minutes for external scheduler to trigger index refresh")
@@ -246,6 +252,10 @@
FlintConfig(s"spark.flint.job.requestIndex")
.doc("Request index")
.createOptional()
val RESULT_INDEX =
FlintConfig(s"spark.flint.job.resultIndex")
.doc("Result index")
.createOptional()
val EXCLUDE_JOB_IDS =
FlintConfig(s"spark.flint.deployment.excludeJobs")
.doc("Exclude job ids")
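
A minimal sketch of how a job might opt into warm pool mode with the new flag — the configuration keys below are the ones added in this hunk, read the same way FlintJob reads them; the concrete values are placeholders:

```scala
import org.apache.spark.SparkConf

val conf = new SparkConf()
  // WARMPOOL_ENABLED: defaults to "false" when unset
  .set("spark.flint.job.warmpoolEnabled", "true")
  // RESULT_INDEX: optional result index name (placeholder value)
  .set("spark.flint.job.resultIndex", "query_execution_result")
  // MAX_EXECUTORS_COUNT: optional, surfaces the dynamic-allocation cap to Flint
  .set("spark.dynamicAllocation.maxExecutors", "10")

val warmpoolEnabled = conf.get("spark.flint.job.warmpoolEnabled", "false").toBoolean
```
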
@@ -30,48 +30,57 @@ object FlintJob extends Logging with FlintJobExecutor {
val (queryOption, resultIndexOption) = parseArgs(args)

val conf = createSparkConf()
val jobType = conf.get("spark.flint.job.type", FlintJobType.BATCH)
CustomLogging.logInfo(s"""Job type is: ${jobType}""")
conf.set(FlintSparkConf.JOB_TYPE.key, jobType)

val dataSource = conf.get("spark.flint.datasource.name", "")
val query = queryOption.getOrElse(unescapeQuery(conf.get(FlintSparkConf.QUERY.key, "")))
if (query.isEmpty) {
logAndThrow(s"Query undefined for the ${jobType} job.")
}
val queryId = conf.get(FlintSparkConf.QUERY_ID.key, "")

if (resultIndexOption.isEmpty) {
logAndThrow("resultIndex is not set")
}
// https://github.com/opensearch-project/opensearch-spark/issues/138
/*
* To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`,
* it's necessary to set `spark.sql.defaultCatalog=my_glue1`. This is because AWS Glue uses a single database (default) and table (http_logs_plain),
* and we need to configure Spark to recognize `my_glue1` as a reference to AWS Glue's database and table.
* By doing this, we effectively map `my_glue1` to AWS Glue, allowing Spark to resolve the database and table names correctly.
* Without this setup, Spark would not recognize names in the format `my_glue1.default`.
*/
conf.set("spark.sql.defaultCatalog", dataSource)
configDYNMaxExecutors(conf, jobType)

val sparkSession = createSparkSession(conf)
val applicationId =
environmentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown")
val jobId = environmentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown")
val warmpoolEnabled = conf.get(FlintSparkConf.WARMPOOL_ENABLED.key, "false").toBoolean
logInfo(s"WarmpoolEnabled: ${warmpoolEnabled}")

val streamingRunningCount = new AtomicInteger(0)
val jobOperator =
JobOperator(
applicationId,
jobId,
createSparkSession(conf),
query,
queryId,
dataSource,
resultIndexOption.get,
jobType,
streamingRunningCount)
registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount)
jobOperator.start()
if (!warmpoolEnabled) {
val jobType = sparkSession.conf.get("spark.flint.job.type", FlintJobType.BATCH)
CustomLogging.logInfo(s"""Job type is: ${jobType}""")
sparkSession.conf.set(FlintSparkConf.JOB_TYPE.key, jobType)

val dataSource = conf.get("spark.flint.datasource.name", "")
val query = queryOption.getOrElse(unescapeQuery(conf.get(FlintSparkConf.QUERY.key, "")))
if (query.isEmpty) {
logAndThrow(s"Query undefined for the ${jobType} job.")
}
val queryId = conf.get(FlintSparkConf.QUERY_ID.key, "")

if (resultIndexOption.isEmpty) {
logAndThrow("resultIndex is not set")
}
// https://github.com/opensearch-project/opensearch-spark/issues/138
/*
* To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`,
* it's necessary to set `spark.sql.defaultCatalog=my_glue1`. This is because AWS Glue uses a single database (default) and table (http_logs_plain),
* and we need to configure Spark to recognize `my_glue1` as a reference to AWS Glue's database and table.
* By doing this, we effectively map `my_glue1` to AWS Glue, allowing Spark to resolve the database and table names correctly.
* Without this setup, Spark would not recognize names in the format `my_glue1.default`.
*/
conf.set("spark.sql.defaultCatalog", dataSource)
configDYNMaxExecutors(conf, jobType)

val streamingRunningCount = new AtomicInteger(0)
val jobOperator =
JobOperator(
applicationId,
jobId,
sparkSession,
query,
queryId,
dataSource,
resultIndexOption.get,
jobType,
streamingRunningCount)
registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount)
jobOperator.start()
} else {
// Fetch and execute queries in warm pool mode
val warmpoolJob = WarmpoolJob(conf, sparkSession, resultIndexOption)
warmpoolJob.start()
}
}
}
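
To illustrate the `spark.sql.defaultCatalog` comment above: a standalone sketch reusing the example query and catalog name from that comment. The Glue table is assumed to exist and the Flint SQL extensions are assumed to be enabled on the session:

```scala
import org.apache.spark.sql.SparkSession

// "my_glue1" and the table name come from the example in the comment above.
val spark = SparkSession
  .builder()
  .config("spark.sql.defaultCatalog", "my_glue1")
  .getOrCreate()

// With the default catalog mapped to the Glue catalog, Spark can resolve
// `my_glue1.default.http_logs_plain` to Glue's `default` database and
// `http_logs_plain` table.
spark.sql("""CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain
            |(`@timestamp` VALUE_SET)
            |WITH (auto_refresh = true)""".stripMargin)
```
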
@@ -12,6 +12,7 @@ import com.amazonaws.services.s3.model.AmazonS3Exception
import com.fasterxml.jackson.databind.ObjectMapper
import org.apache.commons.text.StringEscapeUtils.unescapeJava
import org.opensearch.common.Strings
import org.opensearch.flint.common.model.FlintStatement
import org.opensearch.flint.core.IRestHighLevelClient
import org.opensearch.flint.core.logging.{CustomLogging, ExceptionMessages, OperationMessage}
import org.opensearch.flint.core.metrics.MetricConstants
@@ -20,6 +21,7 @@ import play.api.libs.json._

import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.FlintREPL.instantiate
import org.apache.spark.sql.SparkConfConstants.{DEFAULT_SQL_EXTENSIONS, SQL_EXTENSIONS_KEY}
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.exception.UnrecoverableException
@@ -470,6 +472,13 @@ trait FlintJobExecutor {
else getRootCause(t.getCause)
}

def processQueryException(t: Throwable, flintStatement: FlintStatement): String = {
val error = processQueryException(t)
flintStatement.fail()
flintStatement.error = Some(error)
error
}

/**
* This method converts query exception into error string, which then persist to query result
* metadata
@@ -515,6 +524,87 @@
}
}

def handleCommandTimeout(
applicationId: String,
jobId: String,
spark: SparkSession,
dataSource: String,
error: String,
flintStatement: FlintStatement,
sessionId: String,
startTime: Long): DataFrame = {
/*
* https://tinyurl.com/2ezs5xj9
*
* This only interrupts Spark jobs that are actively running. The cancellation causes
* ExecutePlan to throw an error and terminate the job. However, if the query is not
* running a Spark job but is executing code on the Spark driver, the cancellation is a
* no-op and execution keeps running.
*
* In Apache Spark, actions that trigger a distributed computation lead to the creation
* of Spark jobs. In the context of Spark SQL, this typically happens when we perform
* actions that require computed results to be collected or stored.
*/
spark.sparkContext.cancelJobGroup(flintStatement.queryId)
flintStatement.timeout()
flintStatement.error = Some(error)
constructErrorDF(
applicationId,
jobId,
spark,
dataSource,
flintStatement.state,
error,
flintStatement.queryId,
flintStatement.query,
sessionId,
startTime)
}
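
The cancellation in handleCommandTimeout relies on the executing side tagging its Spark jobs with the query id as a job group. A minimal sketch of that pairing, assuming the standard SparkContext job-group APIs (the executing side is not shown in this hunk):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().getOrCreate()
val queryId = "example-query-id" // placeholder

// Executing side: tag all jobs spawned by this statement with the query id.
// interruptOnCancel = true lets a later cancellation interrupt running tasks.
spark.sparkContext.setJobGroup(queryId, s"Flint statement $queryId", interruptOnCancel = true)

// Timeout side (what handleCommandTimeout does): interrupt any jobs still
// running for that query id. Driver-side code outside a Spark job is unaffected.
spark.sparkContext.cancelJobGroup(queryId)
```
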

/**
* Handles the case where a command's execution fails: updates the flintStatement with the
* error and failure status, and then writes the result to the result index. Thus, the error
* is written to both the result index and the statement model in the request index.
*
* @param spark
* spark session
* @param dataSource
* data source
* @param error
* error message
* @param flintStatement
* flint command
* @param sessionId
* session id
* @param startTime
* start time
* @return
* failed data frame
*/
def handleCommandFailureAndGetFailedData(
applicationId: String,
jobId: String,
spark: SparkSession,
dataSource: String,
error: String,
flintStatement: FlintStatement,
sessionId: String,
startTime: Long): DataFrame = {
flintStatement.fail()
flintStatement.error = Some(error)
constructErrorDF(
applicationId,
jobId,
spark,
dataSource,
flintStatement.state,
error,
flintStatement.queryId,
flintStatement.query,
sessionId,
startTime)
}

/**
* Before OS 2.13, there are two arguments from the entry point: query and result index.
* Starting from OS 2.13, query is optional for FlintREPL, and since Flint 0.5, result index is also
@@ -547,6 +637,39 @@
}
}

def getSegmentName(sparkSession: SparkSession): String = {
val maxExecutorsCount =
sparkSession.conf.get(FlintSparkConf.MAX_EXECUTORS_COUNT.key, "unknown")
String.format("%se", maxExecutorsCount)
}
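
A quick check of what getSegmentName produces — only the formatting is exercised here; the SparkSession wiring is assumed. With spark.dynamicAllocation.maxExecutors = 10 it yields "10e", and "unknowne" when the setting is absent:

```scala
// Standalone reimplementation of the formatting step only.
def segmentName(maxExecutorsCount: String): String = String.format("%se", maxExecutorsCount)

assert(segmentName("10") == "10e")
assert(segmentName("unknown") == "unknowne")
```
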

def instantiateSessionManager(
spark: SparkSession,
resultIndexOption: Option[String]): SessionManager = {
instantiate(
new SessionManagerImpl(spark, resultIndexOption),
spark.conf.get(FlintSparkConf.CUSTOM_SESSION_MANAGER.key, ""),
resultIndexOption.getOrElse(""))
}

def instantiateStatementExecutionManager(
commandContext: CommandContext): StatementExecutionManager = {
import commandContext._
instantiate(
new StatementExecutionManagerImpl(commandContext),
spark.conf.get(FlintSparkConf.CUSTOM_STATEMENT_MANAGER.key, ""),
spark,
sessionId)
}

def instantiateQueryResultWriter(
spark: SparkSession,
commandContext: CommandContext): QueryResultWriter = {
instantiate(
new QueryResultWriterImpl(commandContext),
spark.conf.get(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key, ""))
}

def instantiate[T](defaultConstructor: => T, className: String, args: Any*): T = {
if (Strings.isNullOrEmpty(className)) {
defaultConstructor
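
The instantiate[T] helper backs the three factory methods above: when no custom class name is configured, it returns the default. The rest of the method is truncated in this view, so the reflective branch below is an assumption about how a configured class would be constructed with the supplied args. A self-contained sketch with illustrative types:

```scala
// Sketch only: Greeter, DefaultGreeter, and CustomGreeter are illustrative stand-ins
// for SessionManager/StatementExecutionManager/QueryResultWriter and their defaults.
trait Greeter { def greet(): String }
class DefaultGreeter extends Greeter { override def greet(): String = "hello" }
class CustomGreeter(name: String) extends Greeter { override def greet(): String = s"hello, $name" }

def instantiateSketch[T](defaultConstructor: => T, className: String, args: Any*): T =
  if (className == null || className.isEmpty) {
    defaultConstructor
  } else {
    // Assumed behavior of the truncated branch: reflectively build the configured class.
    val ctor = Class.forName(className).getConstructors()(0)
    ctor.newInstance(args.map(_.asInstanceOf[AnyRef]): _*).asInstanceOf[T]
  }

// No custom class configured: the default implementation is used.
val g1: Greeter = instantiateSketch(new DefaultGreeter, "")
// A custom class name (e.g. from FlintSparkConf.CUSTOM_SESSION_MANAGER) is configured:
val g2: Greeter = instantiateSketch(new DefaultGreeter, classOf[CustomGreeter].getName, "warm pool")
```
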
