Skip to content

Commit

Permalink
Transition Flint index state to Failed upon refresh job termination (#…
Browse files Browse the repository at this point in the history
…362)

* Add await monitor API and IT

Signed-off-by: Chen Dai <[email protected]>

* Fix FlintJob IT

Signed-off-by: Chen Dai <[email protected]>

* Add new FlintJob IT

Signed-off-by: Chen Dai <[email protected]>

* Fix failed index cannot be deleted issue

Signed-off-by: Chen Dai <[email protected]>

* Fix FlintJob IT due to multi threading

Signed-off-by: Chen Dai <[email protected]>

* Address PR comments

Signed-off-by: Chen Dai <[email protected]>

---------

Signed-off-by: Chen Dai <[email protected]>
  • Loading branch information
dai-chen authored Jun 1, 2024
1 parent c01d389 commit 689788a
Show file tree
Hide file tree
Showing 5 changed files with 226 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class FlintSpark(val spark: SparkSession) extends Logging {
spark.conf.getOption("spark.flint.datasource.name").getOrElse("")

/** Flint Spark index monitor */
private val flintIndexMonitor: FlintSparkIndexMonitor =
val flintIndexMonitor: FlintSparkIndexMonitor =
new FlintSparkIndexMonitor(spark, flintClient, dataSourceName)

/**
Expand Down Expand Up @@ -250,7 +250,8 @@ class FlintSpark(val spark: SparkSession) extends Logging {
try {
flintClient
.startTransaction(indexName, dataSourceName)
.initialLog(latest => latest.state == ACTIVE || latest.state == REFRESHING)
.initialLog(latest =>
latest.state == ACTIVE || latest.state == REFRESHING || latest.state == FAILED)
.transientLog(latest => latest.copy(state = DELETING))
.finalLog(latest => latest.copy(state = DELETED))
.commit(_ => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,17 @@

package org.opensearch.flint.spark

import java.time.Duration
import java.time.temporal.ChronoUnit.SECONDS
import java.util.Collections.singletonList
import java.util.concurrent.{ScheduledExecutorService, ScheduledFuture, TimeUnit}

import scala.collection.concurrent.{Map, TrieMap}
import scala.sys.addShutdownHook

import dev.failsafe.{Failsafe, RetryPolicy}
import dev.failsafe.event.ExecutionAttemptedEvent
import dev.failsafe.function.CheckedRunnable
import org.opensearch.flint.core.FlintClient
import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState.{FAILED, REFRESHING}
import org.opensearch.flint.core.metrics.{MetricConstants, MetricsUtil}
Expand Down Expand Up @@ -82,6 +88,60 @@ class FlintSparkIndexMonitor(
}
}

/**
* Waits for the termination of a Spark streaming job associated with a specified index name and
* updates the Flint index state based on the outcome of the job.
*
* @param indexName
* The name of the index to monitor. If none is specified, the method will default to any
* active stream.
*/
def awaitMonitor(indexName: Option[String] = None): Unit = {
logInfo(s"Awaiting index monitor for $indexName")

// Find streaming job for the given index name, otherwise use the first if any
val job = indexName
.flatMap(name => spark.streams.active.find(_.name == name))
.orElse(spark.streams.active.headOption)

if (job.isDefined) { // must be present after DataFrameWriter.start() called in refreshIndex API
val name = job.get.name // use streaming job name because indexName maybe None
logInfo(s"Awaiting streaming job $name until terminated")
try {

/**
* Await termination of the streaming job. Do not transition the index state to ACTIVE
* post-termination to prevent conflicts with ongoing transactions in DROP/ALTER API
* operations. It's generally expected that the job will be terminated through a DROP or
* ALTER operation if no exceptions are thrown.
*/
job.get.awaitTermination()
logInfo(s"Streaming job $name terminated without exception")
} catch {
case e: Throwable =>
/**
* Transition the index state to FAILED upon encountering an exception. Retry in case
* conflicts with final transaction in scheduled task.
* ```
* TODO:
* 1) Determine the appropriate state code based on the type of exception encountered
* 2) Record and persist the error message of the root cause for further diagnostics.
* ```
*/
logError(s"Streaming job $name terminated with exception", e)
retry {
flintClient
.startTransaction(name, dataSourceName)
.initialLog(latest => latest.state == REFRESHING)
.finalLog(latest => latest.copy(state = FAILED))
.commit(_ => {})
}
}
} else {
logInfo(s"Index monitor for [$indexName] not found")
}
}

/**
* Index monitor task that encapsulates the execution logic with number of consecutive error
* tracked.
Expand All @@ -106,12 +166,6 @@ class FlintSparkIndexMonitor(
.commit(_ => {})
} else {
logError("Streaming job is not active. Cancelling monitor task")
flintClient
.startTransaction(indexName, dataSourceName)
.initialLog(_ => true)
.finalLog(latest => latest.copy(state = FAILED))
.commit(_ => {})

stopMonitor(indexName)
logInfo("Index monitor task is cancelled")
}
Expand Down Expand Up @@ -144,6 +198,26 @@ class FlintSparkIndexMonitor(
logWarning("Refreshing job not found")
}
}

private def retry(operation: => Unit): Unit = {
// Retry policy for 3 times every 1 second
val retryPolicy = RetryPolicy
.builder[Unit]()
.handle(classOf[Throwable])
.withBackoff(1, 30, SECONDS)
.withJitter(Duration.ofMillis(100))
.withMaxRetries(3)
.onFailedAttempt((event: ExecutionAttemptedEvent[Unit]) =>
logError("Attempt to update index state failed: " + event))
.build()

// Use the retry policy with Failsafe
Failsafe
.`with`(singletonList(retryPolicy))
.run(new CheckedRunnable {
override def run(): Unit = operation
})
}
}

object FlintSparkIndexMonitor extends Logging {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,25 @@

package org.apache.spark.sql

import java.util.{Base64, Collections}
import java.util.concurrent.atomic.AtomicInteger

import scala.collection.JavaConverters._
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.{Duration, MINUTES}
import scala.util.{Failure, Success}
import scala.util.control.Breaks.{break, breakable}

import org.opensearch.action.admin.indices.settings.put.UpdateSettingsRequest
import org.opensearch.action.get.GetRequest
import org.opensearch.client.RequestOptions
import org.opensearch.flint.core.FlintOptions
import org.opensearch.flint.spark.FlintSparkSuite
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName
import org.scalatest.matchers.must.Matchers.{defined, have}
import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, the}
import org.scalatest.matchers.must.Matchers.{contain, defined}
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper

import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE
import org.apache.spark.sql.flint.config.FlintSparkConf._
import org.apache.spark.sql.util.MockEnvironment
import org.apache.spark.util.ThreadUtils

Expand All @@ -38,15 +42,18 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
super.beforeAll()
// initialized after the container is started
osClient = new OSClient(new FlintOptions(openSearchOptions.asJava))
}

protected override def beforeEach(): Unit = {
super.beforeEach()
createPartitionedMultiRowAddressTable(testTable)
}

protected override def afterEach(): Unit = {
super.afterEach()

deleteTestIndex(testIndex)

waitJobStop(threadLocalFuture.get())
sql(s"DROP TABLE $testTable")

threadLocalFuture.remove()
}
Expand All @@ -72,6 +79,17 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
val streamingRunningCount = new AtomicInteger(0)

val futureResult = Future {
/*
* Because we cannot test from FlintJob.main() for the reason below, we have to configure
* all Spark conf required by Flint code underlying manually.
*/
spark.conf.set(DATA_SOURCE_NAME.key, dataSourceName)
spark.conf.set(JOB_TYPE.key, "streaming")

/**
* FlintJob.main() is not called because we need to manually set these variables within a
* JobOperator instance to accommodate specific runtime requirements.
*/
val job =
JobOperator(spark, query, dataSourceName, resultIndex, true, streamingRunningCount)
job.envinromentProvider = new MockEnvironment(
Expand Down Expand Up @@ -130,6 +148,53 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
indexData.count() shouldBe 2
}

test("create skipping index with auto refresh and streaming job failure") {
val query =
s"""
| CREATE SKIPPING INDEX ON $testTable
| ( year PARTITION )
| WITH (auto_refresh = true)
| """.stripMargin
val jobRunId = "00ff4o3b5091080q"
threadLocalFuture.set(startJob(query, jobRunId))

// Waiting from streaming job start and complete current batch in Future thread in startJob
// Otherwise, active job will be None here
Thread.sleep(5000L)
pollForResultAndAssert(_ => true, jobRunId)
val activeJob = spark.streams.active.find(_.name == testIndex)
activeJob shouldBe defined
awaitStreamingComplete(activeJob.get.id.toString)

// Wait in case JobOperator has not reached condition check before awaitTermination
Thread.sleep(5000L)
try {
// Set Flint index readonly to simulate streaming job exception
setFlintIndexReadOnly(true)

// Trigger a new micro batch execution
sql(s"""
| INSERT INTO $testTable
| PARTITION (year=2023, month=6)
| SELECT *
| FROM VALUES ('Test', 35, 'Seattle')
|""".stripMargin)
try {
awaitStreamingComplete(activeJob.get.id.toString)
} catch {
case _: Exception => // expected
}

// Assert Flint index transitioned to FAILED state after waiting seconds
Thread.sleep(2000L)
val latestId = Base64.getEncoder.encodeToString(testIndex.getBytes)
latestLogEntry(latestId) should contain("state" -> "failed")
} finally {
// Reset so Flint index can be cleaned up in afterEach
setFlintIndexReadOnly(false)
}
}

test("create skipping index with non-existent table") {
val query =
s"""
Expand Down Expand Up @@ -260,4 +325,23 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
streamingTimeout.toMillis,
resultIndex)
}

private def setFlintIndexReadOnly(readonly: Boolean): Unit = {
logInfo(s"Updating index $testIndex setting with readonly [$readonly]")
openSearchClient
.indices()
.putSettings(
new UpdateSettingsRequest(testIndex).settings(
Map("index.blocks.write" -> readonly).asJava),
RequestOptions.DEFAULT)
}

private def latestLogEntry(latestId: String): Map[String, AnyRef] = {
val response = openSearchClient
.get(
new GetRequest(s".query_execution_request_$dataSourceName", latestId),
RequestOptions.DEFAULT)

Option(response.getSourceAsMap).getOrElse(Collections.emptyMap()).asScala.toMap
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter

import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito.{doAnswer, spy}
import org.opensearch.action.admin.indices.delete.DeleteIndexRequest
import org.opensearch.action.admin.indices.settings.put.UpdateSettingsRequest
import org.opensearch.client.RequestOptions
import org.opensearch.flint.OpenSearchTransactionSuite
Expand Down Expand Up @@ -105,8 +104,7 @@ class FlintSparkIndexMonitorITSuite extends OpenSearchTransactionSuite with Matc
spark.streams.active.find(_.name == testFlintIndex).get.stop()
waitForMonitorTaskRun()

// Index state transit to failed and task is cancelled
latestLogEntry(testLatestId) should contain("state" -> "failed")
// Monitor task should be cancelled
task.isCancelled shouldBe true
}

Expand Down Expand Up @@ -150,6 +148,46 @@ class FlintSparkIndexMonitorITSuite extends OpenSearchTransactionSuite with Matc
spark.streams.active.exists(_.name == testFlintIndex) shouldBe false
}

test("await monitor terminated without exception should stay refreshing state") {
// Setup a timer to terminate the streaming job
new Thread(() => {
Thread.sleep(3000L)
spark.streams.active.find(_.name == testFlintIndex).get.stop()
}).start()

// Await until streaming job terminated
flint.flintIndexMonitor.awaitMonitor()

// Assert index state is active now
val latestLog = latestLogEntry(testLatestId)
latestLog should contain("state" -> "refreshing")
}

test("await monitor terminated with exception should update index state to failed") {
new Thread(() => {
Thread.sleep(3000L)

// Set Flint index readonly to simulate streaming job exception
val settings = Map("index.blocks.write" -> true)
val request = new UpdateSettingsRequest(testFlintIndex).settings(settings.asJava)
openSearchClient.indices().putSettings(request, RequestOptions.DEFAULT)

// Trigger a new micro batch execution
sql(s"""
| INSERT INTO $testTable
| PARTITION (year=2023, month=6)
| VALUES ('Test', 35, 'Vancouver')
| """.stripMargin)
}).start()

// Await until streaming job terminated
flint.flintIndexMonitor.awaitMonitor()

// Assert index state is active now
val latestLog = latestLogEntry(testLatestId)
latestLog should contain("state" -> "failed")
}

private def getLatestTimestamp: (Long, Long) = {
val latest = latestLogEntry(testLatestId)
(latest("jobStartTime").asInstanceOf[Long], latest("lastUpdateTime").asInstanceOf[Long])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import scala.util.{Failure, Success, Try}

import org.opensearch.flint.core.metrics.MetricConstants
import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter
import org.opensearch.flint.spark.FlintSpark

import org.apache.spark.internal.Logging
import org.apache.spark.sql.flint.config.FlintSparkConf
Expand Down Expand Up @@ -91,14 +92,22 @@ case class JobOperator(
if (!exceptionThrown && streaming && spark.streams.active.nonEmpty) {
// Clean Spark shuffle data after each microBatch.
spark.streams.addListener(new ShuffleCleaner(spark))
// wait if any child thread to finish before the main thread terminates
spark.streams.awaitAnyTermination()
// Await streaming job thread to finish before the main thread terminates
new FlintSpark(spark).flintIndexMonitor.awaitMonitor()
} else {
logInfo(s"""
| Skip streaming job await due to conditions not met:
| - exceptionThrown: $exceptionThrown
| - streaming: $streaming
| - activeStreams: ${spark.streams.active.mkString(",")}
|""".stripMargin)
}
} catch {
case e: Exception => logError("streaming job failed", e)
}

try {
logInfo("Thread pool is being shut down")
threadPool.shutdown()
logInfo("shut down thread threadpool")
} catch {
Expand All @@ -121,8 +130,9 @@ case class JobOperator(

def stop(): Unit = {
Try {
logInfo("Stopping Spark session")
spark.stop()
logInfo("stopped spark session")
logInfo("Stopped Spark session")
} match {
case Success(_) =>
case Failure(e) => logError("unexpected error while stopping spark session", e)
Expand Down

0 comments on commit 689788a

Please sign in to comment.