Implement FlintJob to handle all query types in warmpool mode #979

Open · wants to merge 1 commit into base: main
@@ -135,10 +135,56 @@ public final class MetricConstants {
*/
public static final String STREAMING_HEARTBEAT_FAILED_METRIC = "streaming.heartbeat.failed.count";

/**
* Metric for tracking the count of streaming jobs failed during query execution
*/
public static final String STREAMING_EXECUTION_FAILED_METRIC = "streaming.execution.failed.count";

/**
* Metric for tracking the count of streaming jobs failed during query result write
*/
public static final String STREAMING_RESULT_WRITER_FAILED_METRIC = "streaming.writer.failed.count";

/**
* Metric for tracking the latency of query execution (start to complete query execution) excluding result write.
*/
public static final String QUERY_EXECUTION_TIME_METRIC = "query.execution.processingTime";
public static final String QUERY_EXECUTION_TIME_METRIC = "streaming.query.execution.processingTime";

/**
* Metric for tracking the latency of query result write only (excluding query execution)
*/
public static final String QUERY_RESULT_WRITER_TIME_METRIC = "streaming.result.writer.processingTime";

/**
* Metric for tracking the latency of query total execution including result write.
*/
public static final String QUERY_TOTAL_TIME_METRIC = "streaming.query.total.processingTime";

/**
* Metric for tracking the latency of query execution (start to complete query execution) excluding result write.
*/
public static final String STATEMENT_QUERY_EXECUTION_TIME_METRIC = "statement.query.execution.processingTime";

/**
* Metric for tracking the latency of query result write only (excluding query execution)
*/
public static final String STATEMENT_RESULT_WRITER_TIME_METRIC = "statement.result.writer.processingTime";

/**
* Metric for tracking the latency of query total execution including result write.
*/
public static final String STATEMENT_QUERY_TOTAL_TIME_METRIC = "statement.query.total.processingTime";

/**
* Metric for tracking the count of interactive jobs failed during query execution
*/
public static final String STATEMENT_EXECUTION_FAILED_METRIC = "statement.execution.failed.count";

/**
* Metric for tracking the count of interactive jobs failed during query result write
*/
public static final String STATEMENT_RESULT_WRITER_FAILED_METRIC = "statement.writer.failed.count";


/**
* Metric for query count of each query type (DROP/VACUUM/ALTER/REFRESH/CREATE INDEX)
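
For illustration only (not part of this diff): one way the new statement-level timing and failure constants could be wired around query execution and result writing. The `emitTime`/`emitCount` helpers below are hypothetical stand-ins for Flint's real metric emitter.

```scala
// Sketch only: emitTime/emitCount are hypothetical stand-ins, not Flint's metric API.
import org.opensearch.flint.core.metrics.MetricConstants

object StatementMetricsSketch {
  private def emitTime(name: String, millis: Long): Unit = println(s"$name=${millis}ms")
  private def emitCount(name: String): Unit = println(s"$name +1")

  // Run a block and record its wall-clock duration under the given metric name.
  private def timed[T](metric: String)(block: => T): T = {
    val start = System.currentTimeMillis()
    try block
    finally emitTime(metric, System.currentTimeMillis() - start)
  }

  // Interactive (statement) path: time execution and result write separately,
  // plus the end-to-end total, and bump the matching failure counter on error.
  def runStatement(execute: () => Unit, writeResult: () => Unit): Unit =
    timed(MetricConstants.STATEMENT_QUERY_TOTAL_TIME_METRIC) {
      try timed(MetricConstants.STATEMENT_QUERY_EXECUTION_TIME_METRIC)(execute())
      catch {
        case e: Throwable =>
          emitCount(MetricConstants.STATEMENT_EXECUTION_FAILED_METRIC)
          throw e
      }
      try timed(MetricConstants.STATEMENT_RESULT_WRITER_TIME_METRIC)(writeResult())
      catch {
        case e: Throwable =>
          emitCount(MetricConstants.STATEMENT_RESULT_WRITER_FAILED_METRIC)
          throw e
      }
    }
}
```

The streaming-prefixed constants above would be used the same way on the streaming path.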
@@ -171,6 +171,12 @@ object FlintSparkConf {
.doc("Enable external scheduler for index refresh")
.createWithDefault("false")

val WARMPOOL_ENABLED =
FlintConfig("spark.flint.job.warmpoolEnabled")
.createWithDefault("false")

val MAX_EXECUTORS_COUNT = FlintConfig("spark.dynamicAllocation.maxExecutors").createOptional()

val EXTERNAL_SCHEDULER_INTERVAL_THRESHOLD =
FlintConfig("spark.flint.job.externalScheduler.interval")
.doc("Interval threshold in minutes for external scheduler to trigger index refresh")
@@ -246,6 +252,10 @@ object FlintSparkConf {
FlintConfig(s"spark.flint.job.requestIndex")
.doc("Request index")
.createOptional()
val RESULT_INDEX =
FlintConfig(s"spark.flint.job.resultIndex")
.doc("Result index")
.createOptional()
val EXCLUDE_JOB_IDS =
FlintConfig(s"spark.flint.deployment.excludeJobs")
.doc("Exclude job ids")
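
For illustration (not from the diff), the two new keys added above can be set like any other Spark configuration and read back with their defaults; the index name below is made up.

```scala
// Sketch only: shows how the WARMPOOL_ENABLED and RESULT_INDEX keys added above
// could be supplied and read; "query_results" is an illustrative index name.
import org.apache.spark.SparkConf

object WarmpoolConfSketch extends App {
  val conf = new SparkConf()
    .set("spark.flint.job.warmpoolEnabled", "true") // WARMPOOL_ENABLED, defaults to "false"
    .set("spark.flint.job.resultIndex", "query_results") // RESULT_INDEX, optional

  val warmpoolEnabled = conf.get("spark.flint.job.warmpoolEnabled", "false").toBoolean
  val resultIndex = conf.getOption("spark.flint.job.resultIndex")
  println(s"warmpoolEnabled=$warmpoolEnabled, resultIndex=$resultIndex")
}
```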
@@ -30,48 +30,57 @@ object FlintJob extends Logging with FlintJobExecutor {
val (queryOption, resultIndexOption) = parseArgs(args)

val conf = createSparkConf()
val jobType = conf.get("spark.flint.job.type", FlintJobType.BATCH)
CustomLogging.logInfo(s"""Job type is: ${jobType}""")
conf.set(FlintSparkConf.JOB_TYPE.key, jobType)

val dataSource = conf.get("spark.flint.datasource.name", "")
val query = queryOption.getOrElse(unescapeQuery(conf.get(FlintSparkConf.QUERY.key, "")))
if (query.isEmpty) {
logAndThrow(s"Query undefined for the ${jobType} job.")
}
val queryId = conf.get(FlintSparkConf.QUERY_ID.key, "")

if (resultIndexOption.isEmpty) {
logAndThrow("resultIndex is not set")
}
// https://github.com/opensearch-project/opensearch-spark/issues/138
/*
* To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`,
* it's necessary to set `spark.sql.defaultCatalog=my_glue1`. This is because AWS Glue uses a single database (default) and table (http_logs_plain),
* and we need to configure Spark to recognize `my_glue1` as a reference to AWS Glue's database and table.
* By doing this, we effectively map `my_glue1` to AWS Glue, allowing Spark to resolve the database and table names correctly.
* Without this setup, Spark would not recognize names in the format `my_glue1.default`.
*/
conf.set("spark.sql.defaultCatalog", dataSource)
configDYNMaxExecutors(conf, jobType)

val sparkSession = createSparkSession(conf)
val applicationId =
environmentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown")
val jobId = environmentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown")
val warmpoolEnabled = conf.get(FlintSparkConf.WARMPOOL_ENABLED.key, "false").toBoolean
logInfo(s"WarmpoolEnabled: ${warmpoolEnabled}")

val streamingRunningCount = new AtomicInteger(0)
val jobOperator =
JobOperator(
applicationId,
jobId,
createSparkSession(conf),
query,
queryId,
dataSource,
resultIndexOption.get,
jobType,
streamingRunningCount)
registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount)
jobOperator.start()
if (!warmpoolEnabled) {
Collaborator: This introduces a huge if/else block, which reduces readability and maintainability a lot. Can you split the class for warmpool and the original interactive job?

Collaborator: +1 on this, let's abstract the common interface and move from there.

Author: Ack

val jobType = sparkSession.conf.get("spark.flint.job.type", FlintJobType.BATCH)
Contributor: Any particular reason to have the conf key hard-coded here? We could probably use FlintSparkConf.JOB_TYPE.key, similar to FlintSparkConf.WARMPOOL_ENABLED.key above.

Author: The existing code was doing the same; I just wrapped it in an if block. I can modify this if needed.

CustomLogging.logInfo(s"""Job type is: ${jobType}""")
sparkSession.conf.set(FlintSparkConf.JOB_TYPE.key, jobType)

val dataSource = conf.get("spark.flint.datasource.name", "")
Contributor: Same for DATA_SOURCE_NAME.

Author: The existing code was doing the same; I just wrapped it in an if block. I can modify this if needed.

val query = queryOption.getOrElse(unescapeQuery(conf.get(FlintSparkConf.QUERY.key, "")))
if (query.isEmpty) {
logAndThrow(s"Query undefined for the ${jobType} job.")
}
val queryId = conf.get(FlintSparkConf.QUERY_ID.key, "")

if (resultIndexOption.isEmpty) {
logAndThrow("resultIndex is not set")
}
// https://github.com/opensearch-project/opensearch-spark/issues/138
/*
* To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`,
* it's necessary to set `spark.sql.defaultCatalog=my_glue1`. This is because AWS Glue uses a single database (default) and table (http_logs_plain),
* and we need to configure Spark to recognize `my_glue1` as a reference to AWS Glue's database and table.
* By doing this, we effectively map `my_glue1` to AWS Glue, allowing Spark to resolve the database and table names correctly.
* Without this setup, Spark would not recognize names in the format `my_glue1.default`.
*/
conf.set("spark.sql.defaultCatalog", dataSource)
configDYNMaxExecutors(conf, jobType)

val streamingRunningCount = new AtomicInteger(0)
val jobOperator =
JobOperator(
applicationId,
jobId,
sparkSession,
query,
queryId,
dataSource,
resultIndexOption.get,
jobType,
streamingRunningCount)
registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount)
jobOperator.start()
} else {
// Fetch and execute queries in warm pool mode
val warmpoolJob = WarmpoolJob(conf, sparkSession, resultIndexOption)
warmpoolJob.start()
}
}
}
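
In the spirit of the review comments above (split the warmpool path from the original batch/streaming path behind a common interface), one possible shape of that abstraction is sketched below; the trait and class names are hypothetical and not part of this PR.

```scala
// Hypothetical sketch of the abstraction suggested in the review discussion above;
// FlintQueryJob, BatchQueryJob and WarmpoolQueryJob are illustrative names only.
import org.apache.spark.sql.SparkSession

trait FlintQueryJob {
  def start(): Unit
}

// Original path: run the single query passed on the command line
// (the real code delegates to JobOperator).
class BatchQueryJob(spark: SparkSession, query: String, resultIndex: String)
    extends FlintQueryJob {
  override def start(): Unit = println(s"running '$query', writing to $resultIndex")
}

// Warmpool path: fetch queries and execute them until told to stop
// (the real code delegates to WarmpoolJob).
class WarmpoolQueryJob(spark: SparkSession, resultIndexOption: Option[String])
    extends FlintQueryJob {
  override def start(): Unit = println("polling for queries in warmpool mode")
}

object FlintJobEntrySketch {
  // main() would only pick the runner; everything query-specific lives behind the trait.
  def runnerFor(
      warmpoolEnabled: Boolean,
      spark: SparkSession,
      query: String,
      resultIndexOption: Option[String]): FlintQueryJob =
    if (warmpoolEnabled) new WarmpoolQueryJob(spark, resultIndexOption)
    else new BatchQueryJob(spark, query, resultIndexOption.getOrElse(""))
}
```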
@@ -12,6 +12,7 @@ import com.amazonaws.services.s3.model.AmazonS3Exception
import com.fasterxml.jackson.databind.ObjectMapper
import org.apache.commons.text.StringEscapeUtils.unescapeJava
import org.opensearch.common.Strings
import org.opensearch.flint.common.model.FlintStatement
import org.opensearch.flint.core.IRestHighLevelClient
import org.opensearch.flint.core.logging.{CustomLogging, ExceptionMessages, OperationMessage}
import org.opensearch.flint.core.metrics.MetricConstants
@@ -20,6 +21,7 @@ import play.api.libs.json._

import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.FlintREPL.instantiate
import org.apache.spark.sql.SparkConfConstants.{DEFAULT_SQL_EXTENSIONS, SQL_EXTENSIONS_KEY}
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.exception.UnrecoverableException
@@ -470,6 +472,13 @@ trait FlintJobExecutor {
else getRootCause(t.getCause)
}

def processQueryException(t: Throwable, flintStatement: FlintStatement): String = {
val error = processQueryException(t)
flintStatement.fail()
flintStatement.error = Some(error)
error
}

/**
* This method converts a query exception into an error string, which is then persisted to the query result
* metadata
@@ -515,6 +524,87 @@
}
}

def handleCommandTimeout(
applicationId: String,
jobId: String,
spark: SparkSession,
dataSource: String,
error: String,
flintStatement: FlintStatement,
sessionId: String,
startTime: Long): DataFrame = {
/*
* https://tinyurl.com/2ezs5xj9
*
* This only interrupts active Spark jobs that are actively running.
* This would then throw the error from ExecutePlan and terminate it.
* But if the query is not running a Spark job, but executing code on Spark driver, this
* would be a noop and the execution will keep running.
*
* In Apache Spark, actions that trigger a distributed computation can lead to the creation
* of Spark jobs. In the context of Spark SQL, this typically happens when we perform
* actions that require the computation of results that need to be collected or stored.
*/
spark.sparkContext.cancelJobGroup(flintStatement.queryId)
flintStatement.timeout()
flintStatement.error = Some(error)
constructErrorDF(
applicationId,
jobId,
spark,
dataSource,
flintStatement.state,
error,
flintStatement.queryId,
flintStatement.query,
sessionId,
startTime)
}

/**
* Handles the case where a command's execution fails: updates the flintStatement with the
* error and failure status, and then writes the result to the result index. Thus, the error is
* written to both the result index and the statement model in the request index.
*
* @param spark
* spark session
* @param dataSource
* data source
* @param error
* error message
* @param flintStatement
* flint command
* @param sessionId
* session id
* @param startTime
* start time
* @return
* failed data frame
*/
def handleCommandFailureAndGetFailedData(
applicationId: String,
jobId: String,
spark: SparkSession,
dataSource: String,
error: String,
flintStatement: FlintStatement,
sessionId: String,
startTime: Long): DataFrame = {
flintStatement.fail()
flintStatement.error = Some(error)
constructErrorDF(
applicationId,
jobId,
spark,
dataSource,
flintStatement.state,
error,
flintStatement.queryId,
flintStatement.query,
sessionId,
startTime)
}

/**
* Before OS 2.13, there are two arguments from the entry point: query and result index. Starting
* from OS 2.13, query is optional for FlintREPL. And since Flint 0.5, result index is also
@@ -547,6 +637,39 @@ trait FlintJobExecutor {
}
}

def getSegmentName(sparkSession: SparkSession): String = {
val maxExecutorsCount =
sparkSession.conf.get(FlintSparkConf.MAX_EXECUTORS_COUNT.key, "unknown")
String.format("%se", maxExecutorsCount)
}
Collaborator (on lines +640 to +644): This segmentName is specific to warmpool logic; let us create abstractions on warmpool and record metrics via AOP.
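
As an aside on the thread above: given the getSegmentName shown here, a warmpool metric could be namespaced by segment roughly as below; the prefixing convention itself is hypothetical, not something defined in this PR.

```scala
// Sketch only: illustrates how getSegmentName's output (e.g. "10e" for 10 max executors)
// might be used to namespace a metric; the prefixing is a hypothetical convention.
import org.apache.spark.sql.SparkSession

object SegmentMetricSketch {
  def segmentedMetricName(spark: SparkSession, metric: String): String = {
    val maxExecutorsCount = spark.conf.get("spark.dynamicAllocation.maxExecutors", "unknown")
    val segmentName = String.format("%se", maxExecutorsCount) // same shape as getSegmentName
    s"$segmentName.$metric" // e.g. "10e.statement.query.execution.processingTime"
  }
}
```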


def instantiateSessionManager(
spark: SparkSession,
resultIndexOption: Option[String]): SessionManager = {
instantiate(
new SessionManagerImpl(spark, resultIndexOption),
spark.conf.get(FlintSparkConf.CUSTOM_SESSION_MANAGER.key, ""),
resultIndexOption.getOrElse(""))
}

def instantiateStatementExecutionManager(
commandContext: CommandContext): StatementExecutionManager = {
import commandContext._
instantiate(
new StatementExecutionManagerImpl(commandContext),
spark.conf.get(FlintSparkConf.CUSTOM_STATEMENT_MANAGER.key, ""),
spark,
sessionId)
}

def instantiateQueryResultWriter(
spark: SparkSession,
commandContext: CommandContext): QueryResultWriter = {
instantiate(
new QueryResultWriterImpl(commandContext),
spark.conf.get(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key, ""))
}

def instantiate[T](defaultConstructor: => T, className: String, args: Any*): T = {
if (Strings.isNullOrEmpty(className)) {
defaultConstructor
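
The comment inside handleCommandTimeout above notes that cancelJobGroup only interrupts work that is actually running as a Spark job, while driver-side code keeps going. A self-contained sketch of the setJobGroup/cancelJobGroup pairing it relies on (the group id, delay, and workload are illustrative):

```scala
// Sketch of the setJobGroup / cancelJobGroup pairing described in the
// handleCommandTimeout comment; the group id, delay and workload are illustrative.
import java.util.concurrent.{Executors, TimeUnit}

import org.apache.spark.sql.SparkSession

object CancelJobGroupSketch extends App {
  val spark = SparkSession.builder().master("local[*]").appName("cancel-sketch").getOrCreate()
  val sc = spark.sparkContext

  val queryId = "query-123" // would be flintStatement.queryId in FlintJobExecutor
  sc.setJobGroup(queryId, "long-running statement", interruptOnCancel = true)

  // Cancel the group after 2 seconds. Only stages already submitted as Spark jobs
  // are interrupted; pure driver-side code keeps running, as the comment points out.
  val scheduler = Executors.newSingleThreadScheduledExecutor()
  scheduler.schedule(
    new Runnable { def run(): Unit = sc.cancelJobGroup(queryId) },
    2,
    TimeUnit.SECONDS)

  try {
    // An action that spawns a Spark job inside the group; cancellation surfaces as an exception.
    spark.range(0, 1000000000L, 1, 200).selectExpr("sum(id)").collect()
  } catch {
    case e: Exception => println(s"query $queryId cancelled or failed: ${e.getMessage}")
  } finally {
    scheduler.shutdown()
    spark.stop()
  }
}
```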