From 0721becc71108ba589d56d4cfa9c7d23807261f4 Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Thu, 30 May 2024 09:57:38 -0700 Subject: [PATCH] Add await monitor API and IT Signed-off-by: Chen Dai --- .../opensearch/flint/spark/FlintSpark.scala | 2 +- .../flint/spark/FlintSparkIndexMonitor.scala | 57 +++++++++++++++++-- .../spark/FlintSparkIndexMonitorITSuite.scala | 43 +++++++++++++- .../org/apache/spark/sql/JobOperator.scala | 6 +- 4 files changed, 96 insertions(+), 12 deletions(-) diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala index 9cd5f60a7..083aa5d16 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala @@ -54,7 +54,7 @@ class FlintSpark(val spark: SparkSession) extends Logging { spark.conf.getOption("spark.flint.datasource.name").getOrElse("") /** Flint Spark index monitor */ - private val flintIndexMonitor: FlintSparkIndexMonitor = + val flintIndexMonitor: FlintSparkIndexMonitor = new FlintSparkIndexMonitor(spark, flintClient, dataSourceName) /** diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala index 02cbfd7b1..0a325f7b0 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala @@ -82,6 +82,57 @@ class FlintSparkIndexMonitor( } } + /** + * Waits for the termination of a Spark streaming job associated with a specified index name and + * updates the Flint index state based on the outcome of the job. + * + * @param indexName + * The name of the index to monitor. If none is specified, the method will default to any + * active stream. + */ + def awaitMonitor(indexName: Option[String] = None): Unit = { + logInfo(s"Awaiting index monitor for $indexName") + + // Find streaming job for the given index name, otherwise use the first if any + val job = indexName + .flatMap(name => spark.streams.active.find(_.name == name)) + .orElse(spark.streams.active.headOption) + + if (job.isDefined) { + val name = job.get.name // use streaming job name because indexName maybe None + logInfo(s"Awaiting streaming job $name until terminated") + try { + + /** + * Await termination of the streaming job. Do not transition the index state to ACTIVE + * post-termination to prevent conflicts with ongoing transactions in DROP/ALTER API + * operations. It's generally expected that the job will be terminated through a DROP or + * ALTER operation if no exceptions are thrown. + */ + job.get.awaitTermination() + logInfo(s"Streaming job $name terminated without exception") + } catch { + case e: Throwable => + /** + * Transition the index state to FAILED upon encountering an exception. + * ``` + * TODO: + * 1) Determine the appropriate state code based on the type of exception encountered + * 2) Record and persist the error message of the root cause for further diagnostics. + * ``` + */ + logError(s"Streaming job $name terminated with exception", e) + flintClient + .startTransaction(name, dataSourceName) + .initialLog(latest => latest.state == REFRESHING) + .finalLog(latest => latest.copy(state = FAILED)) + .commit(_ => {}) + } + } else { + logInfo(s"Index monitor for [$indexName] not found") + } + } + /** * Index monitor task that encapsulates the execution logic with number of consecutive error * tracked. @@ -106,12 +157,6 @@ class FlintSparkIndexMonitor( .commit(_ => {}) } else { logError("Streaming job is not active. Cancelling monitor task") - flintClient - .startTransaction(indexName, dataSourceName) - .initialLog(_ => true) - .finalLog(latest => latest.copy(state = FAILED)) - .commit(_ => {}) - stopMonitor(indexName) logInfo("Index monitor task is cancelled") } diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala index 01dc63e00..4bfaaf785 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala @@ -12,7 +12,6 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter import org.mockito.ArgumentMatchers.any import org.mockito.Mockito.{doAnswer, spy} -import org.opensearch.action.admin.indices.delete.DeleteIndexRequest import org.opensearch.action.admin.indices.settings.put.UpdateSettingsRequest import org.opensearch.client.RequestOptions import org.opensearch.flint.OpenSearchTransactionSuite @@ -105,8 +104,7 @@ class FlintSparkIndexMonitorITSuite extends OpenSearchTransactionSuite with Matc spark.streams.active.find(_.name == testFlintIndex).get.stop() waitForMonitorTaskRun() - // Index state transit to failed and task is cancelled - latestLogEntry(testLatestId) should contain("state" -> "failed") + // Monitor task should be cancelled task.isCancelled shouldBe true } @@ -150,6 +148,45 @@ class FlintSparkIndexMonitorITSuite extends OpenSearchTransactionSuite with Matc spark.streams.active.exists(_.name == testFlintIndex) shouldBe false } + test("await monitor terminated without exception should stay refreshing state") { + // Setup a timer to terminate the streaming job + new Thread(() => { + Thread.sleep(3000L) + spark.streams.active.find(_.name == testFlintIndex).get.stop() + }).start() + + // Await until streaming job terminated + flint.flintIndexMonitor.awaitMonitor() + + // Assert index state is active now + val latestLog = latestLogEntry(testLatestId) + latestLog should contain("state" -> "refreshing") + } + + test("await monitor terminated with exception should update index state to failed") { + // Simulate an exception in the streaming job + new Thread(() => { + Thread.sleep(3000L) + + val settings = Map("index.blocks.write" -> true) + val request = new UpdateSettingsRequest(testFlintIndex).settings(settings.asJava) + openSearchClient.indices().putSettings(request, RequestOptions.DEFAULT) + + sql(s""" + | INSERT INTO $testTable + | PARTITION (year=2023, month=6) + | VALUES ('Test', 35, 'Vancouver') + | """.stripMargin) + }).start() + + // Await until streaming job terminated + flint.flintIndexMonitor.awaitMonitor() + + // Assert index state is active now + val latestLog = latestLogEntry(testLatestId) + latestLog should contain("state" -> "failed") + } + private def getLatestTimestamp: (Long, Long) = { val latest = latestLogEntry(testLatestId) (latest("jobStartTime").asInstanceOf[Long], latest("lastUpdateTime").asInstanceOf[Long]) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala index 6421c7d57..bffcc8b60 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala @@ -14,6 +14,7 @@ import scala.util.{Failure, Success, Try} import org.opensearch.flint.core.metrics.MetricConstants import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter +import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndexMonitor} import org.apache.spark.internal.Logging import org.apache.spark.sql.flint.config.FlintSparkConf @@ -91,14 +92,15 @@ case class JobOperator( if (!exceptionThrown && streaming && spark.streams.active.nonEmpty) { // Clean Spark shuffle data after each microBatch. spark.streams.addListener(new ShuffleCleaner(spark)) - // wait if any child thread to finish before the main thread terminates - spark.streams.awaitAnyTermination() + // Await streaming job thread to finish before the main thread terminates + new FlintSpark(spark).flintIndexMonitor.awaitMonitor() } } catch { case e: Exception => logError("streaming job failed", e) } try { + logInfo("Thread pool is being shut down") threadPool.shutdown() logInfo("shut down thread threadpool") } catch {