
[Backport 0.4] Transition Flint index state to Failed upon refresh job termination #364

Merged · 1 commit · Jun 1, 2024
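In short, this backport makes the index monitor await streaming job termination and transition the Flint index state to FAILED when the refresh job dies with an exception, instead of relying on the scheduled monitor task to do so. Below is a condensed sketch of that transition, using only API names that appear in this diff; the helper name markIndexFailed is hypothetical:

```
import org.opensearch.flint.core.FlintClient
import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState.{FAILED, REFRESHING}

// Once the refresh (streaming) job terminates with an exception, move the index
// state from REFRESHING to FAILED through the metadata log transaction API.
def markIndexFailed(flintClient: FlintClient, indexName: String, dataSourceName: String): Unit =
  flintClient
    .startTransaction(indexName, dataSourceName)
    .initialLog(latest => latest.state == REFRESHING) // only touch an index that was refreshing
    .finalLog(latest => latest.copy(state = FAILED))  // record the failure in the metadata log
    .commit(_ => {})
```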
FlintSpark.scala
@@ -54,7 +54,7 @@ class FlintSpark(val spark: SparkSession) extends Logging {
spark.conf.getOption("spark.flint.datasource.name").getOrElse("")

/** Flint Spark index monitor */
private val flintIndexMonitor: FlintSparkIndexMonitor =
val flintIndexMonitor: FlintSparkIndexMonitor =
new FlintSparkIndexMonitor(spark, flintClient, dataSourceName)

/**
@@ -250,7 +250,8 @@ class FlintSpark(val spark: SparkSession) extends Logging {
try {
flintClient
.startTransaction(indexName, dataSourceName)
.initialLog(latest => latest.state == ACTIVE || latest.state == REFRESHING)
.initialLog(latest =>
latest.state == ACTIVE || latest.state == REFRESHING || latest.state == FAILED)
.transientLog(latest => latest.copy(state = DELETING))
.finalLog(latest => latest.copy(state = DELETED))
.commit(_ => {
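A note on the second hunk above: the initialLog precondition is widened so that an index left in FAILED by a dead refresh job can still be dropped. A minimal sketch of the acceptance check, reusing the log entry type imported elsewhere in this diff; the predicate deletable is hypothetical:

```
import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry
import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState.{ACTIVE, FAILED, REFRESHING}

// An index may now enter the DELETING/DELETED flow from ACTIVE, REFRESHING, or FAILED,
// so a failed auto-refresh index remains deletable.
def deletable(latest: FlintMetadataLogEntry): Boolean =
  latest.state == ACTIVE || latest.state == REFRESHING || latest.state == FAILED
```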
FlintSparkIndexMonitor.scala
@@ -5,11 +5,17 @@

package org.opensearch.flint.spark

import java.time.Duration
import java.time.temporal.ChronoUnit.SECONDS
import java.util.Collections.singletonList
import java.util.concurrent.{ScheduledExecutorService, ScheduledFuture, TimeUnit}

import scala.collection.concurrent.{Map, TrieMap}
import scala.sys.addShutdownHook

import dev.failsafe.{Failsafe, RetryPolicy}
import dev.failsafe.event.ExecutionAttemptedEvent
import dev.failsafe.function.CheckedRunnable
import org.opensearch.flint.core.FlintClient
import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState.{FAILED, REFRESHING}
import org.opensearch.flint.core.metrics.{MetricConstants, MetricsUtil}
@@ -82,6 +88,60 @@ class FlintSparkIndexMonitor(
}
}

/**
* Waits for the termination of a Spark streaming job associated with a specified index name and
* updates the Flint index state based on the outcome of the job.
*
* @param indexName
The name of the index to monitor. If not specified, the method falls back to the first
active streaming query, if any.
*/
def awaitMonitor(indexName: Option[String] = None): Unit = {
logInfo(s"Awaiting index monitor for $indexName")

// Find streaming job for the given index name, otherwise use the first if any
val job = indexName
.flatMap(name => spark.streams.active.find(_.name == name))
.orElse(spark.streams.active.headOption)

if (job.isDefined) { // must be present after DataFrameWriter.start() called in refreshIndex API
val name = job.get.name // use streaming job name because indexName may be None
logInfo(s"Awaiting streaming job $name until terminated")
try {

/**
* Await termination of the streaming job. Do not transition the index state to ACTIVE
* post-termination to prevent conflicts with ongoing transactions in DROP/ALTER API
* operations. It's generally expected that the job will be terminated through a DROP or
* ALTER operation if no exceptions are thrown.
*/
job.get.awaitTermination()
logInfo(s"Streaming job $name terminated without exception")
} catch {
case e: Throwable =>
/**
* Transition the index state to FAILED upon encountering an exception. Retry in case of
* conflicts with the final transaction in the scheduled task.
* ```
* TODO:
* 1) Determine the appropriate state code based on the type of exception encountered
* 2) Record and persist the error message of the root cause for further diagnostics.
* ```
*/
logError(s"Streaming job $name terminated with exception", e)
retry {
flintClient
.startTransaction(name, dataSourceName)
.initialLog(latest => latest.state == REFRESHING)
.finalLog(latest => latest.copy(state = FAILED))
.commit(_ => {})
}
}
} else {
logInfo(s"Index monitor for [$indexName] not found")
}
}

/**
* Index monitor task that encapsulates the execution logic with number of consecutive error
* tracked.
@@ -106,12 +166,6 @@ class FlintSparkIndexMonitor(
.commit(_ => {})
} else {
logError("Streaming job is not active. Cancelling monitor task")
flintClient
.startTransaction(indexName, dataSourceName)
.initialLog(_ => true)
.finalLog(latest => latest.copy(state = FAILED))
.commit(_ => {})

stopMonitor(indexName)
logInfo("Index monitor task is cancelled")
}
@@ -144,6 +198,26 @@ class FlintSparkIndexMonitor(
logWarning("Refreshing job not found")
}
}

private def retry(operation: => Unit): Unit = {
// Retry policy: up to 3 retries with exponential backoff from 1 to 30 seconds plus jitter
val retryPolicy = RetryPolicy
.builder[Unit]()
.handle(classOf[Throwable])
.withBackoff(1, 30, SECONDS)
.withJitter(Duration.ofMillis(100))
.withMaxRetries(3)
.onFailedAttempt((event: ExecutionAttemptedEvent[Unit]) =>
logError("Attempt to update index state failed: " + event))
.build()

// Use the retry policy with Failsafe
Failsafe
.`with`(singletonList(retryPolicy))
.run(new CheckedRunnable {
override def run(): Unit = operation
})
}
}

object FlintSparkIndexMonitor extends Logging {
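For context, the awaitMonitor API added above is what JobOperator (at the end of this diff) now calls in place of spark.streams.awaitAnyTermination(). A minimal usage sketch, assuming only the constructor and the now-public flintIndexMonitor field shown in this diff; the wrapper awaitRefreshJob is hypothetical:

```
import org.apache.spark.sql.SparkSession
import org.opensearch.flint.spark.FlintSpark

// Block until the streaming job terminates; on an exception the monitor flips
// the index state to FAILED (with retry), otherwise the state is left untouched.
def awaitRefreshJob(spark: SparkSession): Unit = {
  val flint = new FlintSpark(spark)
  // With no index name given, awaitMonitor falls back to the first active stream.
  flint.flintIndexMonitor.awaitMonitor()
}
```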
FlintJobITSuite.scala
@@ -5,21 +5,25 @@

package org.apache.spark.sql

import java.util.{Base64, Collections}
import java.util.concurrent.atomic.AtomicInteger

import scala.collection.JavaConverters._
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.{Duration, MINUTES}
import scala.util.{Failure, Success}
import scala.util.control.Breaks.{break, breakable}

import org.opensearch.action.admin.indices.settings.put.UpdateSettingsRequest
import org.opensearch.action.get.GetRequest
import org.opensearch.client.RequestOptions
import org.opensearch.flint.core.FlintOptions
import org.opensearch.flint.spark.FlintSparkSuite
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName
import org.scalatest.matchers.must.Matchers.{defined, have}
import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, the}
import org.scalatest.matchers.must.Matchers.{contain, defined}
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper

import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE
import org.apache.spark.sql.flint.config.FlintSparkConf._
import org.apache.spark.sql.util.MockEnvironment
import org.apache.spark.util.ThreadUtils

@@ -38,15 +42,18 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
super.beforeAll()
// initialized after the container is started
osClient = new OSClient(new FlintOptions(openSearchOptions.asJava))
}

protected override def beforeEach(): Unit = {
super.beforeEach()
createPartitionedMultiRowAddressTable(testTable)
}

protected override def afterEach(): Unit = {
super.afterEach()

deleteTestIndex(testIndex)

waitJobStop(threadLocalFuture.get())
sql(s"DROP TABLE $testTable")

threadLocalFuture.remove()
}
@@ -72,6 +79,17 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
val streamingRunningCount = new AtomicInteger(0)

val futureResult = Future {
/*
* Because we cannot test through FlintJob.main() for the reason noted below, we have to
* manually configure all Spark conf required by the underlying Flint code.
*/
spark.conf.set(DATA_SOURCE_NAME.key, dataSourceName)
spark.conf.set(JOB_TYPE.key, "streaming")

/**
* FlintJob.main() is not called because we need to manually set these variables within a
* JobOperator instance to accommodate specific runtime requirements.
*/
val job =
JobOperator(spark, query, dataSourceName, resultIndex, true, streamingRunningCount)
job.envinromentProvider = new MockEnvironment(
@@ -130,6 +148,53 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
indexData.count() shouldBe 2
}

test("create skipping index with auto refresh and streaming job failure") {
val query =
s"""
| CREATE SKIPPING INDEX ON $testTable
| ( year PARTITION )
| WITH (auto_refresh = true)
| """.stripMargin
val jobRunId = "00ff4o3b5091080q"
threadLocalFuture.set(startJob(query, jobRunId))

// Wait for the streaming job to start and complete its current batch in the Future thread in startJob
// Otherwise, active job will be None here
Thread.sleep(5000L)
pollForResultAndAssert(_ => true, jobRunId)
val activeJob = spark.streams.active.find(_.name == testIndex)
activeJob shouldBe defined
awaitStreamingComplete(activeJob.get.id.toString)

// Wait in case JobOperator has not reached condition check before awaitTermination
Thread.sleep(5000L)
try {
// Set Flint index readonly to simulate streaming job exception
setFlintIndexReadOnly(true)

// Trigger a new micro batch execution
sql(s"""
| INSERT INTO $testTable
| PARTITION (year=2023, month=6)
| SELECT *
| FROM VALUES ('Test', 35, 'Seattle')
|""".stripMargin)
try {
awaitStreamingComplete(activeJob.get.id.toString)
} catch {
case _: Exception => // expected
}

// Assert Flint index transitioned to FAILED state after waiting a few seconds
Thread.sleep(2000L)
val latestId = Base64.getEncoder.encodeToString(testIndex.getBytes)
latestLogEntry(latestId) should contain("state" -> "failed")
} finally {
// Reset so Flint index can be cleaned up in afterEach
setFlintIndexReadOnly(false)
}
}

test("create skipping index with non-existent table") {
val query =
s"""
@@ -260,4 +325,23 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
streamingTimeout.toMillis,
resultIndex)
}

private def setFlintIndexReadOnly(readonly: Boolean): Unit = {
logInfo(s"Updating index $testIndex setting with readonly [$readonly]")
openSearchClient
.indices()
.putSettings(
new UpdateSettingsRequest(testIndex).settings(
Map("index.blocks.write" -> readonly).asJava),
RequestOptions.DEFAULT)
}

private def latestLogEntry(latestId: String): Map[String, AnyRef] = {
val response = openSearchClient
.get(
new GetRequest(s".query_execution_request_$dataSourceName", latestId),
RequestOptions.DEFAULT)

Option(response.getSourceAsMap).getOrElse(Collections.emptyMap()).asScala.toMap
}
}
FlintSparkIndexMonitorITSuite.scala
@@ -12,7 +12,6 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter

import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito.{doAnswer, spy}
import org.opensearch.action.admin.indices.delete.DeleteIndexRequest
import org.opensearch.action.admin.indices.settings.put.UpdateSettingsRequest
import org.opensearch.client.RequestOptions
import org.opensearch.flint.OpenSearchTransactionSuite
@@ -105,8 +104,7 @@ class FlintSparkIndexMonitorITSuite extends OpenSearchTransactionSuite with Matc
spark.streams.active.find(_.name == testFlintIndex).get.stop()
waitForMonitorTaskRun()

// Index state transit to failed and task is cancelled
latestLogEntry(testLatestId) should contain("state" -> "failed")
// Monitor task should be cancelled
task.isCancelled shouldBe true
}

@@ -150,6 +148,46 @@ class FlintSparkIndexMonitorITSuite extends OpenSearchTransactionSuite with Matc
spark.streams.active.exists(_.name == testFlintIndex) shouldBe false
}

test("await monitor terminated without exception should stay refreshing state") {
// Set up a timer to terminate the streaming job
new Thread(() => {
Thread.sleep(3000L)
spark.streams.active.find(_.name == testFlintIndex).get.stop()
}).start()

// Await until streaming job terminated
flint.flintIndexMonitor.awaitMonitor()

// Assert index state is still refreshing
val latestLog = latestLogEntry(testLatestId)
latestLog should contain("state" -> "refreshing")
}

test("await monitor terminated with exception should update index state to failed") {
new Thread(() => {
Thread.sleep(3000L)

// Set Flint index readonly to simulate streaming job exception
val settings = Map("index.blocks.write" -> true)
val request = new UpdateSettingsRequest(testFlintIndex).settings(settings.asJava)
openSearchClient.indices().putSettings(request, RequestOptions.DEFAULT)

// Trigger a new micro batch execution
sql(s"""
| INSERT INTO $testTable
| PARTITION (year=2023, month=6)
| VALUES ('Test', 35, 'Vancouver')
| """.stripMargin)
}).start()

// Await until streaming job terminated
flint.flintIndexMonitor.awaitMonitor()

// Assert index state has transitioned to failed
val latestLog = latestLogEntry(testLatestId)
latestLog should contain("state" -> "failed")
}

private def getLatestTimestamp: (Long, Long) = {
val latest = latestLogEntry(testLatestId)
(latest("jobStartTime").asInstanceOf[Long], latest("lastUpdateTime").asInstanceOf[Long])
JobOperator.scala
@@ -14,6 +14,7 @@ import scala.util.{Failure, Success, Try}

import org.opensearch.flint.core.metrics.MetricConstants
import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter
import org.opensearch.flint.spark.FlintSpark

import org.apache.spark.internal.Logging
import org.apache.spark.sql.flint.config.FlintSparkConf
@@ -91,14 +92,22 @@ case class JobOperator(
if (!exceptionThrown && streaming && spark.streams.active.nonEmpty) {
// Clean Spark shuffle data after each microBatch.
spark.streams.addListener(new ShuffleCleaner(spark))
// wait if any child thread to finish before the main thread terminates
spark.streams.awaitAnyTermination()
// Await streaming job thread to finish before the main thread terminates
new FlintSpark(spark).flintIndexMonitor.awaitMonitor()
} else {
logInfo(s"""
| Skip streaming job await due to conditions not met:
| - exceptionThrown: $exceptionThrown
| - streaming: $streaming
| - activeStreams: ${spark.streams.active.mkString(",")}
|""".stripMargin)
}
} catch {
case e: Exception => logError("streaming job failed", e)
}

try {
logInfo("Thread pool is being shut down")
threadPool.shutdown()
logInfo("shut down thread threadpool")
} catch {
@@ -121,8 +130,9 @@

def stop(): Unit = {
Try {
logInfo("Stopping Spark session")
spark.stop()
logInfo("stopped spark session")
logInfo("Stopped Spark session")
} match {
case Success(_) =>
case Failure(e) => logError("unexpected error while stopping spark session", e)