diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala
index d3f3ff0ee..594f99b02 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala
@@ -80,7 +80,8 @@ class FlintSparkIndexMonitor(
    */
   def stopMonitor(indexName: String): Unit = {
     logInfo(s"Cancelling scheduled task for index $indexName")
-    val task = FlintSparkIndexMonitor.indexMonitorTracker.remove(indexName)
+    // Hack: Don't remove because awaitMonitor API requires Flint index name.
+    val task = FlintSparkIndexMonitor.indexMonitorTracker.get(indexName)
     if (task.isDefined) {
       task.get.cancel(true)
     } else {
@@ -119,26 +120,25 @@ class FlintSparkIndexMonitor(
         logInfo(s"Streaming job $name terminated without exception")
       } catch {
         case e: Throwable =>
-          /**
-           * Transition the index state to FAILED upon encountering an exception. Retry in case
-           * conflicts with final transaction in scheduled task.
-           * ```
-           * TODO:
-           * 1) Determine the appropriate state code based on the type of exception encountered
-           * 2) Record and persist the error message of the root cause for further diagnostics.
-           * ```
-           */
-          logError(s"Streaming job $name terminated with exception", e)
-          retry {
-            flintClient
-              .startTransaction(name, dataSourceName)
-              .initialLog(latest => latest.state == REFRESHING)
-              .finalLog(latest => latest.copy(state = FAILED))
-              .commit(_ => {})
-          }
+          logError(s"Streaming job $name terminated with exception: ${e.getMessage}")
+          retryUpdateIndexStateToFailed(name)
       }
     } else {
-      logInfo(s"Index monitor for [$indexName] not found")
+      logInfo(s"Index monitor for [$indexName] not found.")
+
+      /*
+       * Streaming job exits early. Try to find the Flint index name in the monitor list.
+       * Assuming: 1) there is at most one entry in the list, otherwise the index name
+       * must be given upon this method call; 2) this await API must only be called for
+       * an auto refresh index, otherwise the index state will be updated mistakenly.
+       */
+      val name = FlintSparkIndexMonitor.indexMonitorTracker.keys.headOption
+      if (name.isDefined) {
+        logInfo(s"Found index name in index monitor task list: ${name.get}")
+        retryUpdateIndexStateToFailed(name.get)
+      } else {
+        logInfo(s"Index monitor task list is empty")
+      }
     }
   }
 
@@ -199,6 +199,26 @@ class FlintSparkIndexMonitor(
     }
   }
 
+  /**
+   * Transition the index state to FAILED upon encountering an exception. Retry in case it
+   * conflicts with the final transaction in the scheduled task.
+   * ```
+   * TODO:
+   * 1) Determine the appropriate state code based on the type of exception encountered
+   * 2) Record and persist the error message of the root cause for further diagnostics.
+   * ```
+   */
+  private def retryUpdateIndexStateToFailed(indexName: String): Unit = {
+    logInfo(s"Updating index state to failed for $indexName")
+    retry {
+      flintClient
+        .startTransaction(indexName, dataSourceName)
+        .initialLog(latest => latest.state == REFRESHING)
+        .finalLog(latest => latest.copy(state = FAILED))
+        .commit(_ => {})
+    }
+  }
+
   private def retry(operation: => Unit): Unit = {
     // Retry policy for 3 times every 1 second
     val retryPolicy = RetryPolicy
diff --git a/integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala b/integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala
index 990b9e449..7318e5c7c 100644
--- a/integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala
+++ b/integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala
@@ -17,13 +17,15 @@ import org.opensearch.action.admin.indices.settings.put.UpdateSettingsRequest
 import org.opensearch.action.get.GetRequest
 import org.opensearch.client.RequestOptions
 import org.opensearch.flint.core.FlintOptions
-import org.opensearch.flint.spark.FlintSparkSuite
+import org.opensearch.flint.spark.{FlintSparkIndexMonitor, FlintSparkSuite}
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName
 import org.scalatest.matchers.must.Matchers.{contain, defined}
 import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
 
 import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE
 import org.apache.spark.sql.flint.config.FlintSparkConf._
+import org.apache.spark.sql.streaming.StreamingQueryListener
+import org.apache.spark.sql.streaming.StreamingQueryListener._
 import org.apache.spark.sql.util.MockEnvironment
 import org.apache.spark.util.ThreadUtils
 
@@ -46,6 +48,11 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
 
   protected override def beforeEach(): Unit = {
     super.beforeEach()
+
+    // Clear up because awaitMonitor assumes a single index name in the tracker
+    FlintSparkIndexMonitor.indexMonitorTracker.values.foreach(_.cancel(true))
+    FlintSparkIndexMonitor.indexMonitorTracker.clear()
+
     createPartitionedMultiRowAddressTable(testTable)
   }
 
@@ -195,6 +202,42 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
     }
   }
 
+  test("create skipping index with auto refresh and streaming job early exit") {
+    // Custom listener to force streaming job to fail at the beginning
+    val listener = new StreamingQueryListener {
+      override def onQueryStarted(event: QueryStartedEvent): Unit = {
+        logInfo("Stopping streaming job intentionally")
+        spark.streams.active.find(_.name == event.name).get.stop()
+      }
+      override def onQueryProgress(event: QueryProgressEvent): Unit = {}
+      override def onQueryTerminated(event: QueryTerminatedEvent): Unit = {}
+    }
+
+    try {
+      spark.streams.addListener(listener)
+      val query =
+        s"""
+           | CREATE SKIPPING INDEX ON $testTable
+           | (name VALUE_SET)
+           | WITH (auto_refresh = true)
+           | """.stripMargin
+      val jobRunId = "00ff4o3b5091080q"
+      threadLocalFuture.set(startJob(query, jobRunId))
+
+      // Assert streaming job must exit
+      Thread.sleep(5000)
+      pollForResultAndAssert(_ => true, jobRunId)
+      spark.streams.active.exists(_.name == testIndex) shouldBe false
+
+      // Assert Flint index transitioned to FAILED state after waiting a few seconds
+      Thread.sleep(2000L)
+      val latestId = Base64.getEncoder.encodeToString(testIndex.getBytes)
+      latestLogEntry(latestId) should contain("state" -> "failed")
+    } finally {
+      spark.streams.removeListener(listener)
+    }
+  }
+
   test("create skipping index with non-existent table") {
     val query =
       s"""
diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala
index 2627ed964..1e2d68b8e 100644
--- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala
+++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala
@@ -188,6 +188,19 @@ class FlintSparkIndexMonitorITSuite extends OpenSearchTransactionSuite with Matc
     latestLog should contain("state" -> "failed")
   }
 
+  test(
+    "await monitor terminated when streaming job exits early should update index state to failed") {
+    // Terminate streaming job intentionally before await
+    spark.streams.active.find(_.name == testFlintIndex).get.stop()
+
+    // Await until the streaming job has terminated
+    flint.flintIndexMonitor.awaitMonitor()
+
+    // Assert index state is failed now
+    val latestLog = latestLogEntry(testLatestId)
+    latestLog should contain("state" -> "failed")
+  }
+
   private def getLatestTimestamp: (Long, Long) = {
     val latest = latestLogEntry(testLatestId)
     (latest("jobStartTime").asInstanceOf[Long], latest("lastUpdateTime").asInstanceOf[Long])
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
index 3582bcf09..f315dc836 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
@@ -88,11 +88,11 @@ case class JobOperator(
     }
 
     try {
-      // Wait for streaming job complete if no error and there is streaming job running
-      if (!exceptionThrown && streaming && spark.streams.active.nonEmpty) {
+      // Wait for streaming job to complete if no error
+      if (!exceptionThrown && streaming) {
        // Clean Spark shuffle data after each microBatch.
        spark.streams.addListener(new ShuffleCleaner(spark))
-        // Await streaming job thread to finish before the main thread terminates
+        // Await index monitor before the main thread terminates
        new FlintSpark(spark).flintIndexMonitor.awaitMonitor()
      } else {
        logInfo(s"""
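For reference (not part of the patch), the early-exit fallback that awaitMonitor performs after this change can be summarized by the simplified sketch below. The object name, the handleEarlyExit helper, and the placeholder index name are illustrative only; indexMonitorTracker and retryUpdateIndexStateToFailed mirror the names in the hunks above, and the REFRESHING-to-FAILED transition that really goes through flintClient.startTransaction is stubbed with a println.

import scala.collection.concurrent.TrieMap

object AwaitMonitorFallbackSketch extends App {

  // Stand-in for FlintSparkIndexMonitor.indexMonitorTracker (Flint index name -> scheduled task).
  // Entries are no longer removed by stopMonitor, so the name survives an early streaming job exit.
  val indexMonitorTracker = TrieMap[String, String]("flint_test_index" -> "heartbeat-task")

  // Stand-in for retryUpdateIndexStateToFailed: the real code retries a
  // REFRESHING -> FAILED transaction via flintClient.startTransaction.
  def retryUpdateIndexStateToFailed(indexName: String): Unit =
    println(s"Updating index state to failed for $indexName")

  // Fallback path when no active streaming job matches the given index name:
  // assume at most one tracked index and mark it FAILED, otherwise do nothing.
  def handleEarlyExit(): Unit =
    indexMonitorTracker.keys.headOption match {
      case Some(name) =>
        println(s"Found index name in index monitor task list: $name")
        retryUpdateIndexStateToFailed(name)
      case None =>
        println("Index monitor task list is empty")
    }

  handleEarlyExit()
}

Because this fallback assumes at most one tracked index, FlintJobITSuite clears indexMonitorTracker in beforeEach before each test runs.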