Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix index state stuck in refreshing when streaming job exits early #370

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ class FlintSparkIndexMonitor(
*/
def stopMonitor(indexName: String): Unit = {
logInfo(s"Cancelling scheduled task for index $indexName")
val task = FlintSparkIndexMonitor.indexMonitorTracker.remove(indexName)
// Hack: Don't remove because awaitMonitor API requires Flint index name.
val task = FlintSparkIndexMonitor.indexMonitorTracker.get(indexName)
if (task.isDefined) {
task.get.cancel(true)
} else {
Expand Down Expand Up @@ -119,26 +120,25 @@ class FlintSparkIndexMonitor(
logInfo(s"Streaming job $name terminated without exception")
} catch {
case e: Throwable =>
/**
* Transition the index state to FAILED upon encountering an exception. Retry in case
* conflicts with final transaction in scheduled task.
* ```
* TODO:
* 1) Determine the appropriate state code based on the type of exception encountered
* 2) Record and persist the error message of the root cause for further diagnostics.
* ```
*/
logError(s"Streaming job $name terminated with exception", e)
retry {
flintClient
.startTransaction(name, dataSourceName)
.initialLog(latest => latest.state == REFRESHING)
.finalLog(latest => latest.copy(state = FAILED))
.commit(_ => {})
}
logError(s"Streaming job $name terminated with exception: ${e.getMessage}")
retryUpdateIndexStateToFailed(name)
}
} else {
logInfo(s"Index monitor for [$indexName] not found")
logInfo(s"Index monitor for [$indexName] not found.")

/*
* Streaming job exits early. Try to find Flint index name in monitor list.
* Assuming: 1) there are at most 1 entry in the list, otherwise index name
* must be given upon this method call; 2) this await API must be called for
* auto refresh index, otherwise index state will be updated mistakenly.
*/
val name = FlintSparkIndexMonitor.indexMonitorTracker.keys.headOption
if (name.isDefined) {
logInfo(s"Found index name in index monitor task list: ${name.get}")
retryUpdateIndexStateToFailed(name.get)
} else {
logInfo(s"Index monitor task list is empty")
}
}
}

Expand Down Expand Up @@ -199,6 +199,26 @@ class FlintSparkIndexMonitor(
}
}

/**
* Transition the index state to FAILED upon encountering an exception. Retry in case conflicts
* with final transaction in scheduled task.
* ```
* TODO:
* 1) Determine the appropriate state code based on the type of exception encountered
* 2) Record and persist the error message of the root cause for further diagnostics.
* ```
*/
private def retryUpdateIndexStateToFailed(indexName: String): Unit = {
logInfo(s"Updating index state to failed for $indexName")
retry {
flintClient
.startTransaction(indexName, dataSourceName)
.initialLog(latest => latest.state == REFRESHING)
.finalLog(latest => latest.copy(state = FAILED))
.commit(_ => {})
}
}

private def retry(operation: => Unit): Unit = {
// Retry policy for 3 times every 1 second
val retryPolicy = RetryPolicy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@ import org.opensearch.action.admin.indices.settings.put.UpdateSettingsRequest
import org.opensearch.action.get.GetRequest
import org.opensearch.client.RequestOptions
import org.opensearch.flint.core.FlintOptions
import org.opensearch.flint.spark.FlintSparkSuite
import org.opensearch.flint.spark.{FlintSparkIndexMonitor, FlintSparkSuite}
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName
import org.scalatest.matchers.must.Matchers.{contain, defined}
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper

import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE
import org.apache.spark.sql.flint.config.FlintSparkConf._
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.streaming.StreamingQueryListener._
import org.apache.spark.sql.util.MockEnvironment
import org.apache.spark.util.ThreadUtils

Expand All @@ -46,6 +48,11 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {

protected override def beforeEach(): Unit = {
super.beforeEach()

// Clear up because awaitMonitor will assume single name in tracker
FlintSparkIndexMonitor.indexMonitorTracker.values.foreach(_.cancel(true))
FlintSparkIndexMonitor.indexMonitorTracker.clear()

createPartitionedMultiRowAddressTable(testTable)
}

Expand Down Expand Up @@ -195,6 +202,42 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
}
}

test("create skipping index with auto refresh and streaming job early exit") {
// Custom listener to force streaming job to fail at the beginning
val listener = new StreamingQueryListener {
override def onQueryStarted(event: QueryStartedEvent): Unit = {
logInfo("Stopping streaming job intentionally")
spark.streams.active.find(_.name == event.name).get.stop()
}
override def onQueryProgress(event: QueryProgressEvent): Unit = {}
override def onQueryTerminated(event: QueryTerminatedEvent): Unit = {}
}

try {
spark.streams.addListener(listener)
val query =
s"""
| CREATE SKIPPING INDEX ON $testTable
| (name VALUE_SET)
| WITH (auto_refresh = true)
| """.stripMargin
val jobRunId = "00ff4o3b5091080q"
threadLocalFuture.set(startJob(query, jobRunId))

// Assert streaming job must exit
Thread.sleep(5000)
pollForResultAndAssert(_ => true, jobRunId)
spark.streams.active.exists(_.name == testIndex) shouldBe false

// Assert Flint index transitioned to FAILED state after waiting seconds
Thread.sleep(2000L)
val latestId = Base64.getEncoder.encodeToString(testIndex.getBytes)
latestLogEntry(latestId) should contain("state" -> "failed")
} finally {
spark.streams.removeListener(listener)
}
}

test("create skipping index with non-existent table") {
val query =
s"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,19 @@ class FlintSparkIndexMonitorITSuite extends OpenSearchTransactionSuite with Matc
latestLog should contain("state" -> "failed")
}

test(
"await monitor terminated with streaming job exit early should update index state to failed") {
// Terminate streaming job intentionally before await
spark.streams.active.find(_.name == testFlintIndex).get.stop()

// Await until streaming job terminated
flint.flintIndexMonitor.awaitMonitor()

// Assert index state is active now
val latestLog = latestLogEntry(testLatestId)
latestLog should contain("state" -> "failed")
}

private def getLatestTimestamp: (Long, Long) = {
val latest = latestLogEntry(testLatestId)
(latest("jobStartTime").asInstanceOf[Long], latest("lastUpdateTime").asInstanceOf[Long])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,11 @@ case class JobOperator(
}

try {
// Wait for streaming job complete if no error and there is streaming job running
if (!exceptionThrown && streaming && spark.streams.active.nonEmpty) {
// Wait for streaming job complete if no error
if (!exceptionThrown && streaming) {
// Clean Spark shuffle data after each microBatch.
spark.streams.addListener(new ShuffleCleaner(spark))
// Await streaming job thread to finish before the main thread terminates
// Await index monitor before the main thread terminates
new FlintSpark(spark).flintIndexMonitor.awaitMonitor()
} else {
logInfo(s"""
Expand Down
Loading