Commit 5460acc

Fix index state stuck in refreshing when streaming job exits early (#370)
* Handle streaming job exit early case

Signed-off-by: Chen Dai <[email protected]>

* Modify IT to simplify

Signed-off-by: Chen Dai <[email protected]>

* Address PR comments

Signed-off-by: Chen Dai <[email protected]>

* Add more IT

Signed-off-by: Chen Dai <[email protected]>

---------

Signed-off-by: Chen Dai <[email protected]>
dai-chen authored Jun 10, 2024
1 parent c5ad7e7 commit 5460acc
Showing 4 changed files with 99 additions and 23 deletions.
FlintSparkIndexMonitor.scala

@@ -80,7 +80,8 @@ class FlintSparkIndexMonitor(
    */
   def stopMonitor(indexName: String): Unit = {
     logInfo(s"Cancelling scheduled task for index $indexName")
-    val task = FlintSparkIndexMonitor.indexMonitorTracker.remove(indexName)
+    // Hack: don't remove the entry, because the awaitMonitor API requires the Flint index name
+    val task = FlintSparkIndexMonitor.indexMonitorTracker.get(indexName)
     if (task.isDefined) {
       task.get.cancel(true)
     } else {
@@ -119,26 +120,25 @@ class FlintSparkIndexMonitor(
logInfo(s"Streaming job $name terminated without exception")
} catch {
case e: Throwable =>
/**
* Transition the index state to FAILED upon encountering an exception. Retry in case
* conflicts with final transaction in scheduled task.
* ```
* TODO:
* 1) Determine the appropriate state code based on the type of exception encountered
* 2) Record and persist the error message of the root cause for further diagnostics.
* ```
*/
logError(s"Streaming job $name terminated with exception", e)
retry {
flintClient
.startTransaction(name, dataSourceName)
.initialLog(latest => latest.state == REFRESHING)
.finalLog(latest => latest.copy(state = FAILED))
.commit(_ => {})
}
logError(s"Streaming job $name terminated with exception: ${e.getMessage}")
retryUpdateIndexStateToFailed(name)
}
} else {
logInfo(s"Index monitor for [$indexName] not found")
logInfo(s"Index monitor for [$indexName] not found.")

/*
* Streaming job exits early. Try to find Flint index name in monitor list.
* Assuming: 1) there are at most 1 entry in the list, otherwise index name
* must be given upon this method call; 2) this await API must be called for
* auto refresh index, otherwise index state will be updated mistakenly.
*/
val name = FlintSparkIndexMonitor.indexMonitorTracker.keys.headOption
if (name.isDefined) {
logInfo(s"Found index name in index monitor task list: ${name.get}")
retryUpdateIndexStateToFailed(name.get)
} else {
logInfo(s"Index monitor task list is empty")
}
}
}

@@ -199,6 +199,26 @@ class FlintSparkIndexMonitor(
     }
   }

+  /**
+   * Transition the index state to FAILED upon encountering an exception. Retry in case of
+   * conflicts with the final transaction in the scheduled task.
+   * ```
+   * TODO:
+   * 1) Determine the appropriate state code based on the type of exception encountered
+   * 2) Record and persist the error message of the root cause for further diagnostics.
+   * ```
+   */
+  private def retryUpdateIndexStateToFailed(indexName: String): Unit = {
+    logInfo(s"Updating index state to failed for $indexName")
+    retry {
+      flintClient
+        .startTransaction(indexName, dataSourceName)
+        .initialLog(latest => latest.state == REFRESHING)
+        .finalLog(latest => latest.copy(state = FAILED))
+        .commit(_ => {})
+    }
+  }
+
   private def retry(operation: => Unit): Unit = {
     // Retry policy: retry 3 times, 1 second apart
     val retryPolicy = RetryPolicy
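The `retry` helper is cut off in this view. For the complete picture, here is a minimal sketch of how such a helper can be written, assuming the `RetryPolicy` builder above is Failsafe's (dev.failsafe), which the call style suggests; the exact policy in the repository may differ:

import java.time.Duration

import dev.failsafe.{Failsafe, RetryPolicy}

// Sketch only: retry the operation up to 3 times, waiting 1 second between
// attempts. If all attempts fail, Failsafe propagates the failure to the caller.
private def retry(operation: => Unit): Unit = {
  val retryPolicy = RetryPolicy
    .builder[Unit]()
    .handle(classOf[Throwable])
    .withMaxRetries(3)
    .withDelay(Duration.ofSeconds(1))
    .build()

  Failsafe.`with`(retryPolicy).run(() => operation)
}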
FlintJobITSuite.scala

@@ -17,13 +17,15 @@ import org.opensearch.action.admin.indices.settings.put.UpdateSettingsRequest
 import org.opensearch.action.get.GetRequest
 import org.opensearch.client.RequestOptions
 import org.opensearch.flint.core.FlintOptions
-import org.opensearch.flint.spark.FlintSparkSuite
+import org.opensearch.flint.spark.{FlintSparkIndexMonitor, FlintSparkSuite}
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName
 import org.scalatest.matchers.must.Matchers.{contain, defined}
 import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper

 import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE
 import org.apache.spark.sql.flint.config.FlintSparkConf._
+import org.apache.spark.sql.streaming.StreamingQueryListener
+import org.apache.spark.sql.streaming.StreamingQueryListener._
 import org.apache.spark.sql.util.MockEnvironment
 import org.apache.spark.util.ThreadUtils
@@ -46,6 +48,11 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {

   protected override def beforeEach(): Unit = {
     super.beforeEach()
+
+    // Clean up first, because awaitMonitor assumes a single name in the tracker
+    FlintSparkIndexMonitor.indexMonitorTracker.values.foreach(_.cancel(true))
+    FlintSparkIndexMonitor.indexMonitorTracker.clear()
+
     createPartitionedMultiRowAddressTable(testTable)
   }

@@ -195,6 +202,42 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
     }
   }

+  test("create skipping index with auto refresh and streaming job early exit") {
+    // Custom listener to force the streaming job to fail at the beginning
+    val listener = new StreamingQueryListener {
+      override def onQueryStarted(event: QueryStartedEvent): Unit = {
+        logInfo("Stopping streaming job intentionally")
+        spark.streams.active.find(_.name == event.name).get.stop()
+      }
+      override def onQueryProgress(event: QueryProgressEvent): Unit = {}
+      override def onQueryTerminated(event: QueryTerminatedEvent): Unit = {}
+    }
+
+    try {
+      spark.streams.addListener(listener)
+      val query =
+        s"""
+           | CREATE SKIPPING INDEX ON $testTable
+           | (name VALUE_SET)
+           | WITH (auto_refresh = true)
+           | """.stripMargin
+      val jobRunId = "00ff4o3b5091080q"
+      threadLocalFuture.set(startJob(query, jobRunId))
+
+      // Assert the streaming job has exited
+      Thread.sleep(5000)
+      pollForResultAndAssert(_ => true, jobRunId)
+      spark.streams.active.exists(_.name == testIndex) shouldBe false
+
+      // Assert the Flint index transitioned to FAILED state after a few seconds
+      Thread.sleep(2000L)
+      val latestId = Base64.getEncoder.encodeToString(testIndex.getBytes)
+      latestLogEntry(latestId) should contain("state" -> "failed")
+    } finally {
+      spark.streams.removeListener(listener)
+    }
+  }
+
test("create skipping index with non-existent table") {
val query =
s"""
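The listener trick in the new test also works outside this suite. As a self-contained illustration, here is a minimal sketch that stops a streaming query as soon as it starts; the `EarlyExitDemo` object, the `rate` source, and the query name are illustrative assumptions, not code from this change:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.streaming.StreamingQueryListener._

object EarlyExitDemo extends App {
  val spark = SparkSession.builder().master("local[2]").appName("early-exit-demo").getOrCreate()

  // Stop every query as soon as it starts, simulating a streaming job that exits early
  spark.streams.addListener(new StreamingQueryListener {
    override def onQueryStarted(event: QueryStartedEvent): Unit = {
      println(s"Stopping query ${event.name} intentionally")
      spark.streams.active.find(_.id == event.id).foreach(_.stop())
    }
    override def onQueryProgress(event: QueryProgressEvent): Unit = {}
    override def onQueryTerminated(event: QueryTerminatedEvent): Unit = {}
  })

  val query = spark.readStream
    .format("rate")
    .load()
    .writeStream
    .format("console")
    .queryName("demo")
    .start()

  query.awaitTermination() // returns quickly because the listener stopped the query
  spark.stop()
}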
FlintSparkIndexMonitorITSuite.scala

@@ -188,6 +188,19 @@ class FlintSparkIndexMonitorITSuite extends OpenSearchTransactionSuite with Matchers {
     latestLog should contain("state" -> "failed")
   }

+  test(
+    "await monitor terminated with streaming job exit early should update index state to failed") {
+    // Terminate the streaming job intentionally before awaiting
+    spark.streams.active.find(_.name == testFlintIndex).get.stop()
+
+    // Await until the streaming job has terminated
+    flint.flintIndexMonitor.awaitMonitor()
+
+    // Assert index state is failed now
+    val latestLog = latestLogEntry(testLatestId)
+    latestLog should contain("state" -> "failed")
+  }
+
   private def getLatestTimestamp: (Long, Long) = {
     val latest = latestLogEntry(testLatestId)
     (latest("jobStartTime").asInstanceOf[Long], latest("lastUpdateTime").asInstanceOf[Long])
JobOperator.scala

@@ -88,11 +88,11 @@ case class JobOperator(
     }

     try {
-      // Wait for streaming job complete if no error and there is streaming job running
-      if (!exceptionThrown && streaming && spark.streams.active.nonEmpty) {
+      // Wait for the streaming job to complete if no error occurred
+      if (!exceptionThrown && streaming) {
         // Clean Spark shuffle data after each microBatch.
         spark.streams.addListener(new ShuffleCleaner(spark))
-        // Await streaming job thread to finish before the main thread terminates
+        // Await the index monitor before the main thread terminates
         new FlintSpark(spark).flintIndexMonitor.awaitMonitor()
       } else {
         logInfo(s"""
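The guard change above is what ties the four files together: when the streaming query exits before JobOperator reaches this point, awaitMonitor() is still invoked, finds no monitor under the given index name, falls back to the tracker entry that stopMonitor now retains, and marks the index FAILED instead of leaving it stuck in REFRESHING. A hypothetical distillation of just the guard change, for illustration only (the shouldAwait* helpers are not from the diff):

object AwaitDecision extends App {
  // Before the fix: a streaming job that had already exited skipped the
  // monitor entirely, so a REFRESHING index state was never corrected.
  def shouldAwaitBefore(exceptionThrown: Boolean, streaming: Boolean, activeStreams: Int): Boolean =
    !exceptionThrown && streaming && activeStreams > 0

  // After the fix: streaming jobs always reach awaitMonitor(), which can
  // fall back to the monitor tracker and mark the index FAILED.
  def shouldAwaitAfter(exceptionThrown: Boolean, streaming: Boolean): Boolean =
    !exceptionThrown && streaming

  // The early-exit case (no active streams) now awaits the monitor
  assert(!shouldAwaitBefore(exceptionThrown = false, streaming = true, activeStreams = 0))
  assert(shouldAwaitAfter(exceptionThrown = false, streaming = true))
  println("early-exit case now reaches awaitMonitor()")
}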
