diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala index 9cd5f60a7..848bbe61f 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala @@ -54,7 +54,7 @@ class FlintSpark(val spark: SparkSession) extends Logging { spark.conf.getOption("spark.flint.datasource.name").getOrElse("") /** Flint Spark index monitor */ - private val flintIndexMonitor: FlintSparkIndexMonitor = + val flintIndexMonitor: FlintSparkIndexMonitor = new FlintSparkIndexMonitor(spark, flintClient, dataSourceName) /** @@ -250,7 +250,8 @@ class FlintSpark(val spark: SparkSession) extends Logging { try { flintClient .startTransaction(indexName, dataSourceName) - .initialLog(latest => latest.state == ACTIVE || latest.state == REFRESHING) + .initialLog(latest => + latest.state == ACTIVE || latest.state == REFRESHING || latest.state == FAILED) .transientLog(latest => latest.copy(state = DELETING)) .finalLog(latest => latest.copy(state = DELETED)) .commit(_ => { diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala index 02cbfd7b1..d3f3ff0ee 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala @@ -5,11 +5,17 @@ package org.opensearch.flint.spark +import java.time.Duration +import java.time.temporal.ChronoUnit.SECONDS +import java.util.Collections.singletonList import java.util.concurrent.{ScheduledExecutorService, ScheduledFuture, TimeUnit} import scala.collection.concurrent.{Map, TrieMap} import scala.sys.addShutdownHook +import dev.failsafe.{Failsafe, RetryPolicy} +import dev.failsafe.event.ExecutionAttemptedEvent +import dev.failsafe.function.CheckedRunnable import org.opensearch.flint.core.FlintClient import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState.{FAILED, REFRESHING} import org.opensearch.flint.core.metrics.{MetricConstants, MetricsUtil} @@ -82,6 +88,60 @@ class FlintSparkIndexMonitor( } } + /** + * Waits for the termination of a Spark streaming job associated with a specified index name and + * updates the Flint index state based on the outcome of the job. + * + * @param indexName + * The name of the index to monitor. If none is specified, the method will default to any + * active stream. + */ + def awaitMonitor(indexName: Option[String] = None): Unit = { + logInfo(s"Awaiting index monitor for $indexName") + + // Find streaming job for the given index name, otherwise use the first if any + val job = indexName + .flatMap(name => spark.streams.active.find(_.name == name)) + .orElse(spark.streams.active.headOption) + + if (job.isDefined) { // must be present after DataFrameWriter.start() called in refreshIndex API + val name = job.get.name // use streaming job name because indexName maybe None + logInfo(s"Awaiting streaming job $name until terminated") + try { + + /** + * Await termination of the streaming job. Do not transition the index state to ACTIVE + * post-termination to prevent conflicts with ongoing transactions in DROP/ALTER API + * operations. It's generally expected that the job will be terminated through a DROP or + * ALTER operation if no exceptions are thrown. + */ + job.get.awaitTermination() + logInfo(s"Streaming job $name terminated without exception") + } catch { + case e: Throwable => + /** + * Transition the index state to FAILED upon encountering an exception. Retry in case + * conflicts with final transaction in scheduled task. + * ``` + * TODO: + * 1) Determine the appropriate state code based on the type of exception encountered + * 2) Record and persist the error message of the root cause for further diagnostics. + * ``` + */ + logError(s"Streaming job $name terminated with exception", e) + retry { + flintClient + .startTransaction(name, dataSourceName) + .initialLog(latest => latest.state == REFRESHING) + .finalLog(latest => latest.copy(state = FAILED)) + .commit(_ => {}) + } + } + } else { + logInfo(s"Index monitor for [$indexName] not found") + } + } + /** * Index monitor task that encapsulates the execution logic with number of consecutive error * tracked. @@ -106,12 +166,6 @@ class FlintSparkIndexMonitor( .commit(_ => {}) } else { logError("Streaming job is not active. Cancelling monitor task") - flintClient - .startTransaction(indexName, dataSourceName) - .initialLog(_ => true) - .finalLog(latest => latest.copy(state = FAILED)) - .commit(_ => {}) - stopMonitor(indexName) logInfo("Index monitor task is cancelled") } @@ -144,6 +198,26 @@ class FlintSparkIndexMonitor( logWarning("Refreshing job not found") } } + + private def retry(operation: => Unit): Unit = { + // Retry policy for 3 times every 1 second + val retryPolicy = RetryPolicy + .builder[Unit]() + .handle(classOf[Throwable]) + .withBackoff(1, 30, SECONDS) + .withJitter(Duration.ofMillis(100)) + .withMaxRetries(3) + .onFailedAttempt((event: ExecutionAttemptedEvent[Unit]) => + logError("Attempt to update index state failed: " + event)) + .build() + + // Use the retry policy with Failsafe + Failsafe + .`with`(singletonList(retryPolicy)) + .run(new CheckedRunnable { + override def run(): Unit = operation + }) + } } object FlintSparkIndexMonitor extends Logging { diff --git a/integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala b/integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala index 9697588d4..990b9e449 100644 --- a/integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala +++ b/integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala @@ -5,21 +5,25 @@ package org.apache.spark.sql +import java.util.{Base64, Collections} import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConverters._ import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.{Duration, MINUTES} import scala.util.{Failure, Success} -import scala.util.control.Breaks.{break, breakable} +import org.opensearch.action.admin.indices.settings.put.UpdateSettingsRequest +import org.opensearch.action.get.GetRequest +import org.opensearch.client.RequestOptions import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.spark.FlintSparkSuite import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName -import org.scalatest.matchers.must.Matchers.{defined, have} -import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, the} +import org.scalatest.matchers.must.Matchers.{contain, defined} +import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE +import org.apache.spark.sql.flint.config.FlintSparkConf._ import org.apache.spark.sql.util.MockEnvironment import org.apache.spark.util.ThreadUtils @@ -38,6 +42,10 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest { super.beforeAll() // initialized after the container is started osClient = new OSClient(new FlintOptions(openSearchOptions.asJava)) + } + + protected override def beforeEach(): Unit = { + super.beforeEach() createPartitionedMultiRowAddressTable(testTable) } @@ -45,8 +53,7 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest { super.afterEach() deleteTestIndex(testIndex) - - waitJobStop(threadLocalFuture.get()) + sql(s"DROP TABLE $testTable") threadLocalFuture.remove() } @@ -72,6 +79,17 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest { val streamingRunningCount = new AtomicInteger(0) val futureResult = Future { + /* + * Because we cannot test from FlintJob.main() for the reason below, we have to configure + * all Spark conf required by Flint code underlying manually. + */ + spark.conf.set(DATA_SOURCE_NAME.key, dataSourceName) + spark.conf.set(JOB_TYPE.key, "streaming") + + /** + * FlintJob.main() is not called because we need to manually set these variables within a + * JobOperator instance to accommodate specific runtime requirements. + */ val job = JobOperator(spark, query, dataSourceName, resultIndex, true, streamingRunningCount) job.envinromentProvider = new MockEnvironment( @@ -130,6 +148,53 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest { indexData.count() shouldBe 2 } + test("create skipping index with auto refresh and streaming job failure") { + val query = + s""" + | CREATE SKIPPING INDEX ON $testTable + | ( year PARTITION ) + | WITH (auto_refresh = true) + | """.stripMargin + val jobRunId = "00ff4o3b5091080q" + threadLocalFuture.set(startJob(query, jobRunId)) + + // Waiting from streaming job start and complete current batch in Future thread in startJob + // Otherwise, active job will be None here + Thread.sleep(5000L) + pollForResultAndAssert(_ => true, jobRunId) + val activeJob = spark.streams.active.find(_.name == testIndex) + activeJob shouldBe defined + awaitStreamingComplete(activeJob.get.id.toString) + + // Wait in case JobOperator has not reached condition check before awaitTermination + Thread.sleep(5000L) + try { + // Set Flint index readonly to simulate streaming job exception + setFlintIndexReadOnly(true) + + // Trigger a new micro batch execution + sql(s""" + | INSERT INTO $testTable + | PARTITION (year=2023, month=6) + | SELECT * + | FROM VALUES ('Test', 35, 'Seattle') + |""".stripMargin) + try { + awaitStreamingComplete(activeJob.get.id.toString) + } catch { + case _: Exception => // expected + } + + // Assert Flint index transitioned to FAILED state after waiting seconds + Thread.sleep(2000L) + val latestId = Base64.getEncoder.encodeToString(testIndex.getBytes) + latestLogEntry(latestId) should contain("state" -> "failed") + } finally { + // Reset so Flint index can be cleaned up in afterEach + setFlintIndexReadOnly(false) + } + } + test("create skipping index with non-existent table") { val query = s""" @@ -260,4 +325,23 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest { streamingTimeout.toMillis, resultIndex) } + + private def setFlintIndexReadOnly(readonly: Boolean): Unit = { + logInfo(s"Updating index $testIndex setting with readonly [$readonly]") + openSearchClient + .indices() + .putSettings( + new UpdateSettingsRequest(testIndex).settings( + Map("index.blocks.write" -> readonly).asJava), + RequestOptions.DEFAULT) + } + + private def latestLogEntry(latestId: String): Map[String, AnyRef] = { + val response = openSearchClient + .get( + new GetRequest(s".query_execution_request_$dataSourceName", latestId), + RequestOptions.DEFAULT) + + Option(response.getSourceAsMap).getOrElse(Collections.emptyMap()).asScala.toMap + } } diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala index 01dc63e00..2627ed964 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala @@ -12,7 +12,6 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter import org.mockito.ArgumentMatchers.any import org.mockito.Mockito.{doAnswer, spy} -import org.opensearch.action.admin.indices.delete.DeleteIndexRequest import org.opensearch.action.admin.indices.settings.put.UpdateSettingsRequest import org.opensearch.client.RequestOptions import org.opensearch.flint.OpenSearchTransactionSuite @@ -105,8 +104,7 @@ class FlintSparkIndexMonitorITSuite extends OpenSearchTransactionSuite with Matc spark.streams.active.find(_.name == testFlintIndex).get.stop() waitForMonitorTaskRun() - // Index state transit to failed and task is cancelled - latestLogEntry(testLatestId) should contain("state" -> "failed") + // Monitor task should be cancelled task.isCancelled shouldBe true } @@ -150,6 +148,46 @@ class FlintSparkIndexMonitorITSuite extends OpenSearchTransactionSuite with Matc spark.streams.active.exists(_.name == testFlintIndex) shouldBe false } + test("await monitor terminated without exception should stay refreshing state") { + // Setup a timer to terminate the streaming job + new Thread(() => { + Thread.sleep(3000L) + spark.streams.active.find(_.name == testFlintIndex).get.stop() + }).start() + + // Await until streaming job terminated + flint.flintIndexMonitor.awaitMonitor() + + // Assert index state is active now + val latestLog = latestLogEntry(testLatestId) + latestLog should contain("state" -> "refreshing") + } + + test("await monitor terminated with exception should update index state to failed") { + new Thread(() => { + Thread.sleep(3000L) + + // Set Flint index readonly to simulate streaming job exception + val settings = Map("index.blocks.write" -> true) + val request = new UpdateSettingsRequest(testFlintIndex).settings(settings.asJava) + openSearchClient.indices().putSettings(request, RequestOptions.DEFAULT) + + // Trigger a new micro batch execution + sql(s""" + | INSERT INTO $testTable + | PARTITION (year=2023, month=6) + | VALUES ('Test', 35, 'Vancouver') + | """.stripMargin) + }).start() + + // Await until streaming job terminated + flint.flintIndexMonitor.awaitMonitor() + + // Assert index state is active now + val latestLog = latestLogEntry(testLatestId) + latestLog should contain("state" -> "failed") + } + private def getLatestTimestamp: (Long, Long) = { val latest = latestLogEntry(testLatestId) (latest("jobStartTime").asInstanceOf[Long], latest("lastUpdateTime").asInstanceOf[Long]) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala index 6421c7d57..3582bcf09 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala @@ -14,6 +14,7 @@ import scala.util.{Failure, Success, Try} import org.opensearch.flint.core.metrics.MetricConstants import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter +import org.opensearch.flint.spark.FlintSpark import org.apache.spark.internal.Logging import org.apache.spark.sql.flint.config.FlintSparkConf @@ -91,14 +92,22 @@ case class JobOperator( if (!exceptionThrown && streaming && spark.streams.active.nonEmpty) { // Clean Spark shuffle data after each microBatch. spark.streams.addListener(new ShuffleCleaner(spark)) - // wait if any child thread to finish before the main thread terminates - spark.streams.awaitAnyTermination() + // Await streaming job thread to finish before the main thread terminates + new FlintSpark(spark).flintIndexMonitor.awaitMonitor() + } else { + logInfo(s""" + | Skip streaming job await due to conditions not met: + | - exceptionThrown: $exceptionThrown + | - streaming: $streaming + | - activeStreams: ${spark.streams.active.mkString(",")} + |""".stripMargin) } } catch { case e: Exception => logError("streaming job failed", e) } try { + logInfo("Thread pool is being shut down") threadPool.shutdown() logInfo("shut down thread threadpool") } catch { @@ -121,8 +130,9 @@ case class JobOperator( def stop(): Unit = { Try { + logInfo("Stopping Spark session") spark.stop() - logInfo("stopped spark session") + logInfo("Stopped Spark session") } match { case Success(_) => case Failure(e) => logError("unexpected error while stopping spark session", e)