Add await monitor API and IT
Signed-off-by: Chen Dai <[email protected]>
dai-chen committed May 30, 2024
1 parent c01d389 commit cc0a878
Showing 6 changed files with 162 additions and 11 deletions.
@@ -5,6 +5,7 @@

package org.opensearch.flint.core.metadata.log

import org.opensearch.flint.core.metadata.FlintJsonHelper.buildJson
import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState
import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState.IndexState
import org.opensearch.index.seqno.SequenceNumbers.{UNASSIGNED_PRIMARY_TERM, UNASSIGNED_SEQ_NO}
@@ -53,6 +54,7 @@ case class FlintMetadataLogEntry(

def toJson: String = {
// Implicitly populate latest appId, jobId and timestamp whenever persisted
/*
s"""
|{
| "version": "1.0",
@@ -67,6 +69,21 @@
| "error": "$error"
|}
|""".stripMargin
*/

buildJson(builder => {
builder
.field("version", "1.0")
.field("latestId", id)
.field("type", "flintindexstate")
.field("state", state.toString)
.field("applicationId", sys.env.getOrElse("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"))
.field("jobId", sys.env.getOrElse("SERVERLESS_EMR_JOB_ID", "unknown"))
.field("dataSourceName", dataSource)
.field("jobStartTime", createTime)
.field("lastUpdateTime", System.currentTimeMillis())
.field("error", error)
})
}
}
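
For context, a minimal sketch of what the buildJson helper imported above might look like, assuming OpenSearch's standard XContentBuilder API (the imports and signature here are assumptions; the actual FlintJsonHelper may differ). The switch matters because XContentBuilder escapes quotes, newlines and tabs in field values, which the hand-rolled string template above did not:

import org.opensearch.common.Strings
import org.opensearch.common.xcontent.{XContentBuilder, XContentFactory}

object FlintJsonHelper {
  // Build a JSON object and return it as a string; the builder escapes
  // special characters (", \n, \t) in field values automatically.
  def buildJson(block: XContentBuilder => Unit): String = {
    val builder = XContentFactory.jsonBuilder()
    builder.startObject()
    block(builder)
    builder.endObject()
    Strings.toString(builder)
  }
}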

@@ -0,0 +1,50 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.core.metadata

import com.stephenn.scalatest.jsonassert.JsonMatchers.matchJson
import org.json4s._
import org.json4s.native.JsonMethods._
import org.json4s.native.Serialization
import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry
import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState.FAILED
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

class FlintMetadataLogEntrySuite extends AnyFlatSpec with Matchers {

implicit val formats: Formats = Serialization.formats(NoTypeHints)

"toJson" should "correctly serialize an entry with special characters" in {
val logEntry = FlintMetadataLogEntry(
"id",
1,
2,
1627890000000L,
FAILED,
"dataSource",
s"""Error with quotes ' " and newline: \n and tab: \t characters""")

val actualJson = logEntry.toJson
val lastUpdateTime = (parse(actualJson) \ "lastUpdateTime").extract[Long]
val expectedJson =
s"""
|{
| "version": "1.0",
| "latestId": "id",
| "type": "flintindexstate",
| "state": "failed",
| "applicationId": "unknown",
| "jobId": "unknown",
| "dataSourceName": "dataSource",
| "jobStartTime": 1627890000000,
| "lastUpdateTime": $lastUpdateTime,
| "error": "Error with quotes ' \\" and newline: \\n and tab: \\t characters"
|}
|""".stripMargin
actualJson should matchJson(expectedJson)
}
}
@@ -54,7 +54,7 @@ class FlintSpark(val spark: SparkSession) extends Logging {
spark.conf.getOption("spark.flint.datasource.name").getOrElse("")

/** Flint Spark index monitor */
private val flintIndexMonitor: FlintSparkIndexMonitor =
val flintIndexMonitor: FlintSparkIndexMonitor =
new FlintSparkIndexMonitor(spark, flintClient, dataSourceName)

/**
@@ -82,6 +82,42 @@ class FlintSparkIndexMonitor(
}
}

/**
* Await the termination of a Spark streaming job for the given index name and update Flint
* index state accordingly.
*
* @param indexName
* Optional index name to await. If not provided, the first active streaming query is chosen.
*/
def awaitMonitor(indexName: Option[String] = None): Unit = {
logInfo(s"Awaiting index monitor for $indexName")

// Find streaming query for the given index name, otherwise use the first
val stream = indexName
.flatMap(name => spark.streams.active.find(_.name == name))
.orElse(spark.streams.active.headOption)

if (stream.isDefined) {
val name = stream.get.name // use the streaming job name because indexName may be None
logInfo(s"Awaiting streaming job $name until terminated")
try {
stream.get.awaitTermination()
logInfo(s"Streaming job $name terminated without exception")
} catch {
case e: Throwable =>
// Transition to failed state. TODO: determine the target state based on exception type
logError(s"Streaming job $name terminated with exception", e)
flintClient
.startTransaction(name, dataSourceName)
.initialLog(latest => latest.state == REFRESHING)
.finalLog(latest => latest.copy(state = FAILED, error = extractRootCause(e)))
.commit(_ => {})
}
} else {
logInfo(s"Index monitor for [$indexName] not found")
}
}
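
// Hypothetical usage from a driver (index name illustrative; a sketch,
// not part of this commit):
//   val flint = new FlintSpark(spark)
//   flint.flintIndexMonitor.awaitMonitor(Some("flint_mydb_mytable_skipping_index"))
//   flint.flintIndexMonitor.awaitMonitor() // no name: awaits the first active stream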

/**
* Index monitor task that encapsulates the execution logic with the number of consecutive
* errors tracked.
@@ -106,12 +142,6 @@
.commit(_ => {})
} else {
logError("Streaming job is not active. Cancelling monitor task")
flintClient
.startTransaction(indexName, dataSourceName)
.initialLog(_ => true)
.finalLog(latest => latest.copy(state = FAILED))
.commit(_ => {})

stopMonitor(indexName)
logInfo("Index monitor task is cancelled")
}
@@ -144,6 +174,21 @@
logWarning("Refreshing job not found")
}
}

private def extractRootCause(e: Throwable): String = {
var cause = e
while (cause.getCause != null && cause.getCause != cause) {
cause = cause.getCause
}

if (cause.getLocalizedMessage != null) {
return cause.getLocalizedMessage
}
if (cause.getMessage != null) {
return cause.getMessage
}
cause.toString
}
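
// Illustrative example (messages hypothetical): for
//   new RuntimeException("stream failed",
//     new java.io.IOException("index write blocked"))
// the loop above stops at the IOException and returns "index write blocked".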
}

object FlintSparkIndexMonitor extends Logging {
@@ -12,7 +12,6 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter

import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito.{doAnswer, spy}
import org.opensearch.action.admin.indices.delete.DeleteIndexRequest
import org.opensearch.action.admin.indices.settings.put.UpdateSettingsRequest
import org.opensearch.client.RequestOptions
import org.opensearch.flint.OpenSearchTransactionSuite
@@ -105,8 +104,7 @@ class FlintSparkIndexMonitorITSuite extends OpenSearchTransactionSuite with Matchers {
spark.streams.active.find(_.name == testFlintIndex).get.stop()
waitForMonitorTaskRun()

// Index state transit to failed and task is cancelled
latestLogEntry(testLatestId) should contain("state" -> "failed")
// Monitor task should be cancelled
task.isCancelled shouldBe true
}

@@ -150,6 +148,46 @@ class FlintSparkIndexMonitorITSuite extends OpenSearchTransactionSuite with Matchers {
spark.streams.active.exists(_.name == testFlintIndex) shouldBe false
}

test("await monitor terminated without exception should stay refreshing state") {
// Set up a timer to terminate the streaming job
new Thread(() => {
Thread.sleep(3000L)
spark.streams.active.find(_.name == testFlintIndex).get.stop()
}).start()

// Await until the streaming job terminates
flint.flintIndexMonitor.awaitMonitor()

// Assert index state remains refreshing
val latestLog = latestLogEntry(testLatestId)
latestLog should contain("state" -> "refreshing")
}

test("await monitor terminated with exception should update index state to failed") {
// Simulate an exception in the streaming job by blocking index writes and then inserting data
new Thread(() => {
Thread.sleep(3000L)

val settings = Map("index.blocks.write" -> true)
val request = new UpdateSettingsRequest(testFlintIndex).settings(settings.asJava)
openSearchClient.indices().putSettings(request, RequestOptions.DEFAULT)

sql(s"""
| INSERT INTO $testTable
| PARTITION (year=2023, month=6)
| VALUES ('Test', 35, 'Vancouver')
| """.stripMargin)
}).start()

// Await until the streaming job terminates
flint.flintIndexMonitor.awaitMonitor()

// Assert index state transitioned to failed with the root cause error
val latestLog = latestLogEntry(testLatestId)
latestLog should (contain("state" -> "failed") and contain key "error")
latestLog("error").asInstanceOf[String] should include("OpenSearchException")
}

private def getLatestTimestamp: (Long, Long) = {
val latest = latestLogEntry(testLatestId)
(latest("jobStartTime").asInstanceOf[Long], latest("lastUpdateTime").asInstanceOf[Long])
@@ -14,6 +14,7 @@ import scala.util.{Failure, Success, Try}

import org.opensearch.flint.core.metrics.MetricConstants
import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter
import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndexMonitor}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.flint.config.FlintSparkConf
@@ -92,7 +93,7 @@ case class JobOperator(
// Clean Spark shuffle data after each microBatch.
spark.streams.addListener(new ShuffleCleaner(spark))
// Wait for any child threads to finish before the main thread terminates
spark.streams.awaitAnyTermination()
new FlintSpark(spark).flintIndexMonitor.awaitMonitor()
}
} catch {
case e: Exception => logError("streaming job failed", e)