[Backport 0.4] Enhance index monitor to terminate streaming job on consecutive errors #347

Merged (1 commit) on May 18, 2024

3 changes: 3 additions & 0 deletions docs/index.md
@@ -519,6 +519,9 @@ In the index mapping, the `_meta` and `properties` field stores meta and schema info
- `spark.flint.index.hybridscan.enabled`: default is false.
- `spark.flint.index.checkpoint.mandatory`: default is true.
- `spark.datasource.flint.socket_timeout_millis`: default value is 60000.
- `spark.flint.monitor.initialDelaySeconds`: Initial delay in seconds before starting the monitoring task. Default value is 15.
- `spark.flint.monitor.intervalSeconds`: Interval in seconds for scheduling the monitoring task. Default value is 60.
- `spark.flint.monitor.maxErrorCount`: Maximum number of consecutive errors allowed before stopping the monitoring task. Default value is 5.
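
The three monitor settings above can be supplied like any other Spark configuration. A minimal sketch, assuming the session is created directly in code (the application name and values below are illustrative, not part of this change):

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical session setup overriding the Flint index monitor defaults.
val spark = SparkSession
  .builder()
  .appName("flint-monitor-example") // illustrative application name
  .config("spark.flint.monitor.initialDelaySeconds", "15")
  .config("spark.flint.monitor.intervalSeconds", "60")
  .config("spark.flint.monitor.maxErrorCount", "5")
  .getOrCreate()
```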

#### Data Type Mapping

@@ -154,6 +154,18 @@ object FlintSparkConf {
.doc("Checkpoint location for incremental refresh index will be mandatory if enabled")
.createWithDefault("true")

val MONITOR_INITIAL_DELAY_SECONDS = FlintConfig("spark.flint.monitor.initialDelaySeconds")
.doc("Initial delay in seconds before starting the monitoring task")
.createWithDefault("15")

val MONITOR_INTERVAL_SECONDS = FlintConfig("spark.flint.monitor.intervalSeconds")
.doc("Interval in seconds for scheduling the monitoring task")
.createWithDefault("60")

val MONITOR_MAX_ERROR_COUNT = FlintConfig("spark.flint.monitor.maxErrorCount")
.doc("Maximum number of consecutive errors allowed in index monitor")
.createWithDefault("5")

val SOCKET_TIMEOUT_MILLIS =
FlintConfig(s"spark.datasource.flint.${FlintOptions.SOCKET_TIMEOUT_MILLIS}")
.datasourceOption()
@@ -223,6 +235,12 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable

def isCheckpointMandatory: Boolean = CHECKPOINT_MANDATORY.readFrom(reader).toBoolean

def monitorInitialDelaySeconds(): Int = MONITOR_INITIAL_DELAY_SECONDS.readFrom(reader).toInt

def monitorIntervalSeconds(): Int = MONITOR_INTERVAL_SECONDS.readFrom(reader).toInt

def monitorMaxErrorCount(): Int = MONITOR_MAX_ERROR_COUNT.readFrom(reader).toInt
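
A minimal usage sketch of these accessors, modeled on the unit test further down (it assumes an active SparkSession named `spark`; that `FlintSparkConf()` picks up the session configuration is exactly what the test verifies):

```scala
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.flint.config.FlintSparkConf.MONITOR_MAX_ERROR_COUNT

// Assumes `spark` is an active SparkSession.
spark.conf.set(MONITOR_MAX_ERROR_COUNT.key, "10")

val flintConf = FlintSparkConf() // reads from the current session configuration
assert(flintConf.monitorMaxErrorCount() == 10)
```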

/**
* spark.sql.session.timeZone
*/
@@ -16,6 +16,7 @@ import org.opensearch.flint.core.metrics.{MetricConstants, MetricsUtil}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.flint.newDaemonThreadPoolScheduledExecutor

/**
@@ -34,43 +35,32 @@ class FlintSparkIndexMonitor(
dataSourceName: String)
extends Logging {

/** Task execution initial delay in seconds */
private val INITIAL_DELAY_SECONDS = FlintSparkConf().monitorInitialDelaySeconds()

/** Task execution interval in seconds */
private val INTERVAL_SECONDS = FlintSparkConf().monitorIntervalSeconds()

/** Max error count allowed */
private val MAX_ERROR_COUNT = FlintSparkConf().monitorMaxErrorCount()

/**
* Start monitoring task on the given Flint index.
*
* @param indexName
* Flint index name
*/
def startMonitor(indexName: String): Unit = {
-    val task = FlintSparkIndexMonitor.executor.scheduleWithFixedDelay(
-      () => {
-        logInfo(s"Scheduler trigger index monitor task for $indexName")
-        try {
-          if (isStreamingJobActive(indexName)) {
-            logInfo("Streaming job is still active")
-            flintClient
-              .startTransaction(indexName, dataSourceName)
-              .initialLog(latest => latest.state == REFRESHING)
-              .finalLog(latest => latest) // timestamp will update automatically
-              .commit(_ => {})
-          } else {
-            logError("Streaming job is not active. Cancelling monitor task")
-            flintClient
-              .startTransaction(indexName, dataSourceName)
-              .initialLog(_ => true)
-              .finalLog(latest => latest.copy(state = FAILED))
-              .commit(_ => {})
+    logInfo(s"""Starting index monitor for $indexName with configuration:
+         | - Initial delay: $INITIAL_DELAY_SECONDS seconds
+         | - Interval: $INTERVAL_SECONDS seconds
+         | - Max error count: $MAX_ERROR_COUNT
+         |""".stripMargin)

-            stopMonitor(indexName)
-            logInfo("Index monitor task is cancelled")
-          }
-        } catch {
-          case e: Throwable =>
-            logError("Failed to update index log entry", e)
-            MetricsUtil.incrementCounter(MetricConstants.STREAMING_HEARTBEAT_FAILED_METRIC)
-        }
-      },
-      15, // Delay to ensure final logging is complete first, otherwise version conflicts
-      60, // TODO: make interval configurable
+    val task = FlintSparkIndexMonitor.executor.scheduleWithFixedDelay(
+      new FlintSparkIndexMonitorTask(indexName),
+      INITIAL_DELAY_SECONDS, // Delay to ensure final logging is complete first, otherwise version conflicts
+      INTERVAL_SECONDS,
TimeUnit.SECONDS)

FlintSparkIndexMonitor.indexMonitorTracker.put(indexName, task)
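
For readers less familiar with fixed-delay scheduling, here is a tiny stand-alone sketch of the behavior relied on above: each run starts only after the previous run has finished plus the configured delay, and cancelling the returned handle stops future runs. The executor and task here are illustrative, not the monitor's shared daemon pool:

```scala
import java.util.concurrent.{Executors, TimeUnit}

val executor = Executors.newSingleThreadScheduledExecutor()

val handle = executor.scheduleWithFixedDelay(
  () => println("monitor heartbeat"), // stand-in for the real monitor task
  15, // initial delay in seconds
  60, // delay between the end of one run and the start of the next
  TimeUnit.SECONDS)

// Cancelling the handle stops future runs; the monitor keeps its handle in
// indexMonitorTracker so it can be cancelled the same way.
handle.cancel(true)
```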
@@ -92,8 +82,68 @@
}
}

/**
   * Index monitor task that encapsulates the execution logic and tracks the number of
   * consecutive errors.
*
* @param indexName
* Flint index name
*/
private class FlintSparkIndexMonitorTask(indexName: String) extends Runnable {

    /** The number of consecutive errors */
private var errorCnt = 0

override def run(): Unit = {
logInfo(s"Scheduler trigger index monitor task for $indexName")
try {
if (isStreamingJobActive(indexName)) {
logInfo("Streaming job is still active")
flintClient
.startTransaction(indexName, dataSourceName)
.initialLog(latest => latest.state == REFRESHING)
.finalLog(latest => latest) // timestamp will update automatically
.commit(_ => {})
} else {
logError("Streaming job is not active. Cancelling monitor task")
flintClient
.startTransaction(indexName, dataSourceName)
.initialLog(_ => true)
.finalLog(latest => latest.copy(state = FAILED))
.commit(_ => {})

stopMonitor(indexName)
logInfo("Index monitor task is cancelled")
}
errorCnt = 0 // Reset counter if no error
} catch {
case e: Throwable =>
errorCnt += 1
logError(s"Failed to update index log entry, consecutive errors: $errorCnt", e)
MetricsUtil.incrementCounter(MetricConstants.STREAMING_HEARTBEAT_FAILED_METRIC)

// Stop streaming job and its monitor if max retry limit reached
if (errorCnt >= MAX_ERROR_COUNT) {
logInfo(s"Terminating streaming job and index monitor for $indexName")
stopStreamingJob(indexName)
stopMonitor(indexName)
logInfo(s"Streaming job and index monitor terminated")
}
}
}
}

private def isStreamingJobActive(indexName: String): Boolean =
spark.streams.active.exists(_.name == indexName)

private def stopStreamingJob(indexName: String): Unit = {
val job = spark.streams.active.find(_.name == indexName)
if (job.isDefined) {
job.get.stop()
} else {
logWarning("Refreshing job not found")
}
}
}
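
The `stopStreamingJob` helper above relies on Spark's StreamingQueryManager. A small sketch of the same lookup in isolation, with a hypothetical query name:

```scala
// Assumes `spark` is an active SparkSession running a streaming query named "my_flint_index".
spark.streams.active.find(_.name == "my_flint_index") match {
  case Some(query) => query.stop() // stops the query and waits for its execution threads to terminate
  case None => println("No active streaming query with that name")
}
```

Flint names each refresh query after its index, which is why both the monitor and this helper can look the job up by `indexName`.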

object FlintSparkIndexMonitor extends Logging {
@@ -13,6 +13,7 @@ import org.opensearch.flint.core.http.FlintRetryOptions._
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper

import org.apache.spark.FlintSuite
import org.apache.spark.sql.flint.config.FlintSparkConf.{MONITOR_INITIAL_DELAY_SECONDS, MONITOR_INTERVAL_SECONDS, MONITOR_MAX_ERROR_COUNT}

class FlintSparkConfSuite extends FlintSuite {
test("test spark conf") {
@@ -84,6 +85,24 @@ class FlintSparkConfSuite extends FlintSuite {
overrideConf.flintOptions().getBatchBytes shouldBe 4 * 1024 * 1024
}

test("test index monitor options") {
val defaultConf = FlintSparkConf()
defaultConf.monitorInitialDelaySeconds() shouldBe 15
defaultConf.monitorIntervalSeconds() shouldBe 60
defaultConf.monitorMaxErrorCount() shouldBe 5

    withSparkConf(MONITOR_INITIAL_DELAY_SECONDS.key, MONITOR_INTERVAL_SECONDS.key, MONITOR_MAX_ERROR_COUNT.key) {
setFlintSparkConf(MONITOR_INITIAL_DELAY_SECONDS, 5)
setFlintSparkConf(MONITOR_INTERVAL_SECONDS, 30)
setFlintSparkConf(MONITOR_MAX_ERROR_COUNT, 10)

val overrideConf = FlintSparkConf()
      overrideConf.monitorInitialDelaySeconds() shouldBe 5
overrideConf.monitorIntervalSeconds() shouldBe 30
overrideConf.monitorMaxErrorCount() shouldBe 10
}
}

/**
* Delete index `indexNames` after calling `f`.
*/
@@ -19,6 +19,7 @@ import org.opensearch.flint.OpenSearchTransactionSuite
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName
import org.scalatest.matchers.should.Matchers

import org.apache.spark.sql.flint.config.FlintSparkConf.MONITOR_MAX_ERROR_COUNT
import org.apache.spark.sql.flint.newDaemonThreadPoolScheduledExecutor

class FlintSparkIndexMonitorITSuite extends OpenSearchTransactionSuite with Matchers {
@@ -40,6 +41,9 @@ class FlintSparkIndexMonitorITSuite extends OpenSearchTransactionSuite with Matchers {
realExecutor.scheduleWithFixedDelay(invocation.getArgument(0), 5, 1, TimeUnit.SECONDS)
}).when(FlintSparkIndexMonitor.executor)
.scheduleWithFixedDelay(any[Runnable], any[Long], any[Long], any[TimeUnit])

// Set max error count higher to avoid impact on transient error test case
setFlintSparkConf(MONITOR_MAX_ERROR_COUNT, 10)
}

override def beforeEach(): Unit = {
@@ -128,6 +132,24 @@ class FlintSparkIndexMonitorITSuite extends OpenSearchTransactionSuite with Matchers {
}
}

test("monitor task and streaming job should terminate if exception occurred consistently") {
val task = FlintSparkIndexMonitor.indexMonitorTracker(testFlintIndex)

// Block write on metadata log index
setWriteBlockOnMetadataLogIndex(true)
waitForMonitorTaskRun()

    // Both monitor task and streaming job should stop after 10 consecutive errors
10 times { (_, _) =>
{
        // Assert nothing; just wait for enough task executions
}
}

task.isCancelled shouldBe true
spark.streams.active.exists(_.name == testFlintIndex) shouldBe false
}

private def getLatestTimestamp: (Long, Long) = {
val latest = latestLogEntry(testLatestId)
(latest("jobStartTime").asInstanceOf[Long], latest("lastUpdateTime").asInstanceOf[Long])