Add await monitor API and IT
Signed-off-by: Chen Dai <[email protected]>
dai-chen committed May 30, 2024
1 parent c01d389 commit 0721bec
Showing 4 changed files with 96 additions and 12 deletions.
@@ -54,7 +54,7 @@ class FlintSpark(val spark: SparkSession) extends Logging {
spark.conf.getOption("spark.flint.datasource.name").getOrElse("")

/** Flint Spark index monitor */
- private val flintIndexMonitor: FlintSparkIndexMonitor =
+ val flintIndexMonitor: FlintSparkIndexMonitor =
new FlintSparkIndexMonitor(spark, flintClient, dataSourceName)

/**
@@ -82,6 +82,57 @@ class FlintSparkIndexMonitor(
}
}

/**
* Waits for the termination of a Spark streaming job associated with a specified index name and
* updates the Flint index state based on the outcome of the job.
*
* @param indexName
* The name of the index to monitor. If none is specified, the method falls back to the first
* active stream, if any.
*/
def awaitMonitor(indexName: Option[String] = None): Unit = {
logInfo(s"Awaiting index monitor for $indexName")

// Find the streaming job for the given index name; otherwise fall back to the first active one, if any
val job = indexName
.flatMap(name => spark.streams.active.find(_.name == name))
.orElse(spark.streams.active.headOption)

if (job.isDefined) {
val name = job.get.name // use the streaming job name because indexName may be None
logInfo(s"Awaiting streaming job $name until terminated")
try {

/**
* Await termination of the streaming job. Do not transition the index state to ACTIVE
* post-termination to prevent conflicts with ongoing transactions in DROP/ALTER API
* operations. It's generally expected that the job will be terminated through a DROP or
* ALTER operation if no exceptions are thrown.
*/
job.get.awaitTermination()
logInfo(s"Streaming job $name terminated without exception")
} catch {
case e: Throwable =>
/**
* Transition the index state to FAILED upon encountering an exception.
* ```
* TODO:
* 1) Determine the appropriate state code based on the type of exception encountered
* 2) Record and persist the error message of the root cause for further diagnostics.
* ```
*/
logError(s"Streaming job $name terminated with exception", e)
flintClient
.startTransaction(name, dataSourceName)
.initialLog(latest => latest.state == REFRESHING)
.finalLog(latest => latest.copy(state = FAILED))
.commit(_ => {})
}
} else {
logInfo(s"Index monitor for [$indexName] not found")
}
}

/**
* Index monitor task that encapsulates the execution logic with number of consecutive error
* tracked.
@@ -106,12 +157,6 @@ class FlintSparkIndexMonitor(
.commit(_ => {})
} else {
logError("Streaming job is not active. Cancelling monitor task")
- flintClient
- .startTransaction(indexName, dataSourceName)
- .initialLog(_ => true)
- .finalLog(latest => latest.copy(state = FAILED))
- .commit(_ => {})

stopMonitor(indexName)
logInfo("Index monitor task is cancelled")
}
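
For reference, a minimal usage sketch of the new awaitMonitor API. The object name, app name, and index name below are illustrative assumptions, not part of this commit; only FlintSpark, flintIndexMonitor, and awaitMonitor come from the changes above.

```scala
import org.apache.spark.sql.SparkSession
import org.opensearch.flint.spark.FlintSpark

// Hypothetical caller: block the driver until the streaming job behind a
// Flint index terminates, letting the monitor record FAILED on error.
object AwaitMonitorExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("await-monitor-example").getOrCreate()
    val flint = new FlintSpark(spark)

    // Pass an index name to wait on that index's streaming job, or omit the
    // argument to wait on the first active stream (index name is made up).
    flint.flintIndexMonitor.awaitMonitor(Some("flint_mydatasource_default_test_skipping_index"))
  }
}
```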
@@ -12,7 +12,6 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter

import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito.{doAnswer, spy}
- import org.opensearch.action.admin.indices.delete.DeleteIndexRequest
import org.opensearch.action.admin.indices.settings.put.UpdateSettingsRequest
import org.opensearch.client.RequestOptions
import org.opensearch.flint.OpenSearchTransactionSuite
@@ -105,8 +104,7 @@ class FlintSparkIndexMonitorITSuite extends OpenSearchTransactionSuite with Matchers
spark.streams.active.find(_.name == testFlintIndex).get.stop()
waitForMonitorTaskRun()

- // Index state transit to failed and task is cancelled
- latestLogEntry(testLatestId) should contain("state" -> "failed")
+ // Monitor task should be cancelled
task.isCancelled shouldBe true
}

@@ -150,6 +148,45 @@ class FlintSparkIndexMonitorITSuite extends OpenSearchTransactionSuite with Matchers
spark.streams.active.exists(_.name == testFlintIndex) shouldBe false
}

test("await monitor terminated without exception should stay refreshing state") {
// Set up a timer to terminate the streaming job
new Thread(() => {
Thread.sleep(3000L)
spark.streams.active.find(_.name == testFlintIndex).get.stop()
}).start()

// Await until the streaming job terminates
flint.flintIndexMonitor.awaitMonitor()

// Assert index state is still refreshing
val latestLog = latestLogEntry(testLatestId)
latestLog should contain("state" -> "refreshing")
}

test("await monitor terminated with exception should update index state to failed") {
// Simulate an exception in the streaming job
new Thread(() => {
Thread.sleep(3000L)

val settings = Map("index.blocks.write" -> true)
val request = new UpdateSettingsRequest(testFlintIndex).settings(settings.asJava)
openSearchClient.indices().putSettings(request, RequestOptions.DEFAULT)

sql(s"""
| INSERT INTO $testTable
| PARTITION (year=2023, month=6)
| VALUES ('Test', 35, 'Vancouver')
| """.stripMargin)
}).start()

// Await until the streaming job terminates
flint.flintIndexMonitor.awaitMonitor()

// Assert index state has transitioned to failed
val latestLog = latestLogEntry(testLatestId)
latestLog should contain("state" -> "failed")
}
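
Both new tests in FlintSparkIndexMonitorITSuite share the same shape: schedule a delayed side effect on a background thread, block on awaitMonitor, then assert the final index state. A small helper along these lines (purely illustrative, not in this commit) could factor that pattern out of the test bodies:

```scala
// Illustrative helper for the test suite: run `action` on a daemon thread
// after `delayMillis`, mirroring the Thread.sleep + stop/break pattern above.
def runAfter(delayMillis: Long)(action: => Unit): Unit = {
  val t = new Thread(() => {
    Thread.sleep(delayMillis)
    action
  })
  t.setDaemon(true)
  t.start()
}

// Example: stop the streaming job after 3 seconds, then await the monitor
// and assert the index state is still "refreshing".
runAfter(3000L) {
  spark.streams.active.find(_.name == testFlintIndex).foreach(_.stop())
}
flint.flintIndexMonitor.awaitMonitor()
latestLogEntry(testLatestId) should contain("state" -> "refreshing")
```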

private def getLatestTimestamp: (Long, Long) = {
val latest = latestLogEntry(testLatestId)
(latest("jobStartTime").asInstanceOf[Long], latest("lastUpdateTime").asInstanceOf[Long])
@@ -14,6 +14,7 @@ import scala.util.{Failure, Success, Try}

import org.opensearch.flint.core.metrics.MetricConstants
import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter
+ import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndexMonitor}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.flint.config.FlintSparkConf
@@ -91,14 +92,15 @@ case class JobOperator(
if (!exceptionThrown && streaming && spark.streams.active.nonEmpty) {
// Clean Spark shuffle data after each microBatch.
spark.streams.addListener(new ShuffleCleaner(spark))
- // wait if any child thread to finish before the main thread terminates
- spark.streams.awaitAnyTermination()
+ // Await streaming job thread to finish before the main thread terminates
+ new FlintSpark(spark).flintIndexMonitor.awaitMonitor()
}
} catch {
case e: Exception => logError("streaming job failed", e)
}

try {
logInfo("Thread pool is being shut down")
threadPool.shutdown()
logInfo("shut down thread threadpool")
} catch {

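One closing note on the JobOperator change: spark.streams.awaitAnyTermination() rethrows a streaming failure to the surrounding try/catch, whereas the new awaitMonitor() call catches the failure itself, logs it, and marks the index FAILED in the Flint metadata log before returning. A minimal sketch of the resulting wait step (the object and method names are hypothetical; only the awaitMonitor call comes from this commit):

```scala
import org.apache.spark.sql.SparkSession
import org.opensearch.flint.spark.FlintSpark

object StreamingWaitSketch {
  // Block until the streaming job behind the Flint index terminates. A failed
  // query is handled inside awaitMonitor(), which records the FAILED state, so
  // the catch below only guards against unexpected errors in the wait itself.
  def waitForStreamingJob(spark: SparkSession): Unit = {
    try {
      if (spark.streams.active.nonEmpty) {
        new FlintSpark(spark).flintIndexMonitor.awaitMonitor()
      }
    } catch {
      case e: Exception => println(s"streaming job failed: ${e.getMessage}")
    }
  }
}
```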