diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/metadata/log/FlintMetadataLogEntry.scala b/flint-core/src/main/scala/org/opensearch/flint/core/metadata/log/FlintMetadataLogEntry.scala
index 5f229d412..efe20e2bb 100644
--- a/flint-core/src/main/scala/org/opensearch/flint/core/metadata/log/FlintMetadataLogEntry.scala
+++ b/flint-core/src/main/scala/org/opensearch/flint/core/metadata/log/FlintMetadataLogEntry.scala
@@ -5,6 +5,7 @@
 
 package org.opensearch.flint.core.metadata.log
 
+import org.opensearch.flint.core.metadata.FlintJsonHelper.buildJson
 import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState
 import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState.IndexState
 import org.opensearch.index.seqno.SequenceNumbers.{UNASSIGNED_PRIMARY_TERM, UNASSIGNED_SEQ_NO}
@@ -53,6 +54,7 @@ case class FlintMetadataLogEntry(
 
   def toJson: String = {
     // Implicitly populate latest appId, jobId and timestamp whenever persist
+    /*
     s"""
        |{
        |  "version": "1.0",
@@ -67,6 +69,21 @@ case class FlintMetadataLogEntry(
        |  "error": "$error"
        |}
        |""".stripMargin
+    */
+
+    buildJson(builder => {
+      builder
+        .field("version", "1.0")
+        .field("latestId", id)
+        .field("type", "flintindexstate")
+        .field("state", state.toString)
+        .field("applicationId", sys.env.getOrElse("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"))
+        .field("jobId", sys.env.getOrElse("SERVERLESS_EMR_JOB_ID", "unknown"))
+        .field("dataSourceName", dataSource)
+        .field("jobStartTime", createTime)
+        .field("lastUpdateTime", System.currentTimeMillis())
+        .field("error", error)
+    })
   }
 }
diff --git a/flint-core/src/test/scala/org/opensearch/flint/core/metadata/FlintMetadataLogEntrySuite.scala b/flint-core/src/test/scala/org/opensearch/flint/core/metadata/FlintMetadataLogEntrySuite.scala
new file mode 100644
index 000000000..3dfc1b022
--- /dev/null
+++ b/flint-core/src/test/scala/org/opensearch/flint/core/metadata/FlintMetadataLogEntrySuite.scala
@@ -0,0 +1,50 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.core.metadata
+
+import com.stephenn.scalatest.jsonassert.JsonMatchers.matchJson
+import org.json4s._
+import org.json4s.native.JsonMethods._
+import org.json4s.native.Serialization
+import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry
+import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState.FAILED
+import org.scalatest.flatspec.AnyFlatSpec
+import org.scalatest.matchers.should.Matchers
+
+class FlintMetadataLogEntrySuite extends AnyFlatSpec with Matchers {
+
+  implicit val formats: Formats = Serialization.formats(NoTypeHints)
+
+  "toJson" should "correctly serialize an entry with special characters" in {
+    val logEntry = FlintMetadataLogEntry(
+      "id",
+      1,
+      2,
+      1627890000000L,
+      FAILED,
+      "dataSource",
+      s"""Error with quotes ' " and newline: \n and tab: \t characters""")
+
+    val actualJson = logEntry.toJson
+    val lastUpdateTime = (parse(actualJson) \ "lastUpdateTime").extract[Long]
+    val expectedJson =
+      s"""
+         |{
+         |  "version": "1.0",
+         |  "latestId": "id",
+         |  "type": "flintindexstate",
+         |  "state": "failed",
+         |  "applicationId": "unknown",
+         |  "jobId": "unknown",
+         |  "dataSourceName": "dataSource",
+         |  "jobStartTime": 1627890000000,
+         |  "lastUpdateTime": $lastUpdateTime,
+         |  "error": "Error with quotes ' \\" and newline: \\n and tab: \\t characters"
+         |}
+         |""".stripMargin
    actualJson should matchJson(expectedJson)
  }
}
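Note on the toJson change above: the old interpolated string copied the error field into the JSON document verbatim, so a quote, newline, or tab in an error message produced invalid JSON; buildJson leaves escaping to the JSON builder, which is exactly what the new FlintMetadataLogEntrySuite exercises. A stand-alone sketch of the failure mode, with a hypothetical escape helper (not the FlintJsonHelper implementation):

    // Hypothetical sketch: why interpolating an error message into JSON breaks.
    object JsonEscapeSketch extends App {
      // Minimal escaper covering the characters exercised by FlintMetadataLogEntrySuite
      def escape(s: String): String = s.flatMap {
        case '"'  => "\\\""
        case '\\' => "\\\\"
        case '\n' => "\\n"
        case '\t' => "\\t"
        case c    => c.toString
      }

      val error = "Error with quotes ' \" and newline: \n and tab: \t characters"
      println(s"""{"error": "$error"}""")           // invalid JSON: raw quote and newline
      println(s"""{"error": "${escape(error)}"}""") // valid JSON
    }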
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala
index 9cd5f60a7..083aa5d16 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala
@@ -54,7 +54,7 @@ class FlintSpark(val spark: SparkSession) extends Logging {
     spark.conf.getOption("spark.flint.datasource.name").getOrElse("")
 
   /** Flint Spark index monitor */
-  private val flintIndexMonitor: FlintSparkIndexMonitor =
+  val flintIndexMonitor: FlintSparkIndexMonitor =
     new FlintSparkIndexMonitor(spark, flintClient, dataSourceName)
 
   /**
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala
index 02cbfd7b1..f999b372d 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala
@@ -82,6 +82,42 @@ class FlintSparkIndexMonitor(
     }
   }
 
+  /**
+   * Await the termination of the Spark streaming job for the given index name and update the
+   * Flint index state accordingly.
+   *
+   * @param indexName
+   *   optional Flint index name to await. If not provided, any active stream is chosen.
+   */
+  def awaitMonitor(indexName: Option[String] = None): Unit = {
+    logInfo(s"Awaiting index monitor for $indexName")
+
+    // Find the streaming job for the given index name, otherwise use the first active one
+    val stream = indexName
+      .flatMap(name => spark.streams.active.find(_.name == name))
+      .orElse(spark.streams.active.headOption)
+
+    if (stream.isDefined) {
+      val name = stream.get.name // use the streaming job name because indexName may be None
+      logInfo(s"Awaiting streaming job $name until terminated")
+      try {
+        stream.get.awaitTermination()
+        logInfo(s"Streaming job $name terminated without exception")
+      } catch {
+        case e: Throwable =>
+          // Transition to failed state. TODO: determine state by exception type
+          logError(s"Streaming job $name terminated with exception", e)
+          flintClient
+            .startTransaction(name, dataSourceName)
+            .initialLog(latest => latest.state == REFRESHING)
+            .finalLog(latest => latest.copy(state = FAILED, error = extractRootCause(e)))
+            .commit(_ => {})
+      }
+    } else {
+      logInfo(s"Index monitor for [$indexName] not found")
+    }
+  }
+
   /**
    * Index monitor task that encapsulates the execution logic with number of consecutive error
    * tracked.
@@ -106,12 +142,6 @@ class FlintSparkIndexMonitor(
             .commit(_ => {})
         } else {
           logError("Streaming job is not active. Cancelling monitor task")
-          flintClient
-            .startTransaction(indexName, dataSourceName)
-            .initialLog(_ => true)
-            .finalLog(latest => latest.copy(state = FAILED))
-            .commit(_ => {})
-
           stopMonitor(indexName)
           logInfo("Index monitor task is cancelled")
         }
@@ -144,6 +174,21 @@ class FlintSparkIndexMonitor(
       logWarning("Refreshing job not found")
     }
   }
+
+  private def extractRootCause(e: Throwable): String = {
+    var cause = e
+    while (cause.getCause != null && cause.getCause != cause) {
+      cause = cause.getCause
+    }
+
+    if (cause.getLocalizedMessage != null) {
+      return cause.getLocalizedMessage
+    }
+    if (cause.getMessage != null) {
+      return cause.getMessage
+    }
+    cause.toString
+  }
 }
 
 object FlintSparkIndexMonitor extends Logging {
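The extractRootCause helper added above unwraps the exception chain to its innermost cause and prefers the localized message when recording it in the FAILED log entry. A stand-alone sketch of the same walk, using Option chaining in place of the early returns (behavior unchanged):

    // Sketch mirroring extractRootCause; the cause-equality check guards against cycles.
    object RootCauseSketch extends App {
      def extractRootCause(e: Throwable): String = {
        var cause: Throwable = e
        while (cause.getCause != null && cause.getCause != cause) {
          cause = cause.getCause // walk to the innermost cause
        }
        Option(cause.getLocalizedMessage)
          .orElse(Option(cause.getMessage))
          .getOrElse(cause.toString)
      }

      val nested = new RuntimeException(
        "streaming job failed",
        new IllegalStateException("index blocked: index.blocks.write"))
      println(extractRootCause(nested)) // prints "index blocked: index.blocks.write"
    }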
diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala
index 01dc63e00..1ef84b3c0 100644
--- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala
+++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala
@@ -12,7 +12,6 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter
 
 import org.mockito.ArgumentMatchers.any
 import org.mockito.Mockito.{doAnswer, spy}
-import org.opensearch.action.admin.indices.delete.DeleteIndexRequest
 import org.opensearch.action.admin.indices.settings.put.UpdateSettingsRequest
 import org.opensearch.client.RequestOptions
 import org.opensearch.flint.OpenSearchTransactionSuite
@@ -105,8 +104,7 @@ class FlintSparkIndexMonitorITSuite extends OpenSearchTransactionSuite with Matc
     spark.streams.active.find(_.name == testFlintIndex).get.stop()
     waitForMonitorTaskRun()
 
-    // Index state transit to failed and task is cancelled
-    latestLogEntry(testLatestId) should contain("state" -> "failed")
+    // Monitor task should be cancelled
     task.isCancelled shouldBe true
   }
 
@@ -150,6 +148,46 @@ class FlintSparkIndexMonitorITSuite extends OpenSearchTransactionSuite with Matc
     spark.streams.active.exists(_.name == testFlintIndex) shouldBe false
   }
 
+  test("await monitor terminated without exception should stay in refreshing state") {
+    // Set up a timer to terminate the streaming job
+    new Thread(() => {
+      Thread.sleep(3000L)
+      spark.streams.active.find(_.name == testFlintIndex).get.stop()
+    }).start()
+
+    // Await until the streaming job terminates
+    flint.flintIndexMonitor.awaitMonitor()
+
+    // Assert index state is still refreshing
+    val latestLog = latestLogEntry(testLatestId)
+    latestLog should contain("state" -> "refreshing")
+  }
+
+  test("await monitor terminated with exception should update index state to failed") {
+    // Simulate an exception in the streaming job
+    new Thread(() => {
+      Thread.sleep(3000L)
+
+      val settings = Map("index.blocks.write" -> true)
+      val request = new UpdateSettingsRequest(testFlintIndex).settings(settings.asJava)
+      openSearchClient.indices().putSettings(request, RequestOptions.DEFAULT)
+
+      sql(s"""
+           | INSERT INTO $testTable
+           | PARTITION (year=2023, month=6)
+           | VALUES ('Test', 35, 'Vancouver')
+           | """.stripMargin)
+    }).start()
+
+    // Await until the streaming job terminates
+    flint.flintIndexMonitor.awaitMonitor()
+
+    // Assert index state transitioned to failed with the root cause recorded
+    val latestLog = latestLogEntry(testLatestId)
+    latestLog should (contain("state" -> "failed") and contain key "error")
+    latestLog("error").asInstanceOf[String] should include("OpenSearchException")
+  }
+
   private def getLatestTimestamp: (Long, Long) = {
     val latest = latestLogEntry(testLatestId)
     (latest("jobStartTime").asInstanceOf[Long], latest("lastUpdateTime").asInstanceOf[Long])
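Both new tests hinge on the transaction contract awaitMonitor relies on: the FAILED final log is committed only if the initial log is still in REFRESHING state, which is why a clean stop leaves the state untouched. A minimal in-memory sketch of that contract with hypothetical local types (not the FlintClient API):

    // Hypothetical model of the startTransaction/initialLog/finalLog/commit flow.
    object TransactionSketch extends App {
      case class LogEntry(state: String, error: String = "")

      class Transaction(initial: LogEntry) {
        var latest: LogEntry = initial
        private var precondition: LogEntry => Boolean = _ => true
        private var finalTransform: LogEntry => LogEntry = identity

        def initialLog(cond: LogEntry => Boolean): this.type = { precondition = cond; this }
        def finalLog(f: LogEntry => LogEntry): this.type = { finalTransform = f; this }
        def commit(op: LogEntry => Unit): Unit =
          if (precondition(latest)) { op(latest); latest = finalTransform(latest) }
      }

      val tx = new Transaction(LogEntry("refreshing"))
      tx.initialLog(_.state == "refreshing")
        .finalLog(_.copy(state = "failed", error = "root cause message"))
        .commit(_ => ())
      println(tx.latest) // LogEntry(failed,root cause message)
    }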
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
index 6421c7d57..8c06593d2 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
@@ -14,6 +14,7 @@ import scala.util.{Failure, Success, Try}
 
 import org.opensearch.flint.core.metrics.MetricConstants
 import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter
+import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndexMonitor}
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.flint.config.FlintSparkConf
@@ -92,7 +93,7 @@ case class JobOperator(
         // Clean Spark shuffle data after each microBatch.
         spark.streams.addListener(new ShuffleCleaner(spark))
         // wait if any child thread to finish before the main thread terminates
-        spark.streams.awaitAnyTermination()
+        new FlintSpark(spark).flintIndexMonitor.awaitMonitor()
       }
     } catch {
       case e: Exception => logError("streaming job failed", e)
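The swap above keeps the old blocking behavior: called with no index name, awaitMonitor falls back to any active streaming query, just as awaitAnyTermination blocked on any stream, while additionally transitioning the Flint index state to failed on error. A dependency-free sketch of the selection fallback (stream names are illustrative):

    // Sketch of the stream-selection fallback awaitMonitor applies.
    object StreamSelectionSketch extends App {
      case class StreamingQuery(name: String)
      val active =
        Seq(StreamingQuery("flint_skipping_index"), StreamingQuery("flint_covering_index"))

      def select(indexName: Option[String]): Option[StreamingQuery] =
        indexName
          .flatMap(name => active.find(_.name == name)) // prefer the named stream
          .orElse(active.headOption)                    // otherwise any active one

      println(select(Some("flint_covering_index"))) // Some(StreamingQuery(flint_covering_index))
      println(select(None))                         // Some(StreamingQuery(flint_skipping_index))
    }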