Add more logging and IT on FlintSpark API layer
Signed-off-by: Chen Dai <[email protected]>
dai-chen committed Oct 31, 2023
1 parent 1ac46b1 commit 2af637f
Showing 2 changed files with 93 additions and 3 deletions.
FlintSpark.scala
@@ -113,8 +113,10 @@ class FlintSpark(val spark: SparkSession) extends Logging {
if (latest == null) { // in case transaction capability is disabled
flintClient.createIndex(indexName, metadata)
} else {
logInfo(s"Creating index with metadata log entry ID ${latest.id}")
flintClient.createIndex(indexName, metadata.copy(latestId = Some(latest.id)))
})
logInfo("Create index complete")
} catch {
case e: Exception =>
logError("Failed to create Flint index", e)
@@ -146,9 +148,11 @@ class FlintSpark(val spark: SparkSession) extends Logging {
.finalLog(latest => {
// Change state to active if full, otherwise update index state regularly
if (mode == FULL) {
logInfo("Updating index state to active")
latest.copy(state = ACTIVE)
} else {
// Schedule regular update and return log entry as refreshing state
logInfo("Scheduling index state updater")
scheduleIndexStateUpdate(indexName)
latest
}
@@ -240,7 +244,7 @@ class FlintSpark(val spark: SparkSession) extends Logging {
* @param indexName
* index name
*/
def recoverIndex(indexName: String): Unit = {
def recoverIndex(indexName: String): Boolean = {
logInfo(s"Recovering Flint index $indexName")
val index = describeIndex(indexName)
if (index.exists(_.options.autoRefresh())) {
@@ -253,13 +257,17 @@ class FlintSpark(val spark: SparkSession) extends Logging {
latest.copy(state = REFRESHING)
})
.commit(_ => doRefreshIndex(index.get, indexName, INCREMENTAL))

logInfo("Recovery complete")
true
} catch {
case e: Exception =>
logError("Failed to recover Flint index", e)
throw new IllegalStateException("Failed to recover Flint index")
}
} else {
logInfo("Index to be recovered either doesn't exist or auto refreshed")
logInfo("Index to be recovered either doesn't exist or not auto refreshed")
false
}
}
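
With the signature change from Unit to Boolean above, a caller can now tell whether recovery actually restarted a streaming refresh job. A minimal caller sketch under assumptions: the index name is hypothetical, a SparkSession named spark is already in scope, and spark.streams.awaitAnyTermination() is the standard Spark API for blocking on active streaming queries.

import org.opensearch.flint.spark.FlintSpark

// Hypothetical caller sketch (not part of this commit).
val flint = new FlintSpark(spark)
if (flint.recoverIndex("flint_spark_catalog_default_test_skipping_index")) {
  // true: an incremental refresh job was restarted, so keep the session alive for it
  spark.streams.awaitAnyTermination()
} else {
  // false: the index doesn't exist or is not auto refreshed, so there is nothing to wait on
  println("No Flint index recovered")
}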

@@ -304,6 +312,7 @@ class FlintSpark(val spark: SparkSession) extends Logging {
index: FlintSparkIndex,
indexName: String,
mode: RefreshMode): Option[String] = {
logInfo(s"Refreshing index $indexName in $mode mode")
val options = index.options
val tableName = index.metadata().source

@@ -318,17 +327,19 @@ class FlintSpark(val spark: SparkSession) extends Logging {
.save(indexName)
}

mode match {
val jobId = mode match {
case FULL if isIncrementalRefreshing(indexName) =>
throw new IllegalStateException(
s"Index $indexName is incremental refreshing and cannot be manual refreshed")

case FULL =>
logInfo("Start refreshing index in batch style")
batchRefresh()
None

// Flint index has specialized logic and capability for incremental refresh
case INCREMENTAL if index.isInstanceOf[StreamingRefresh] =>
logInfo("Start refreshing index in streaming style")
val job =
index
.asInstanceOf[StreamingRefresh]
@@ -343,6 +354,7 @@ class FlintSpark(val spark: SparkSession) extends Logging {

// Otherwise, fall back to foreachBatch + batch refresh
case INCREMENTAL =>
logInfo("Start refreshing index in foreach streaming style")
val job = spark.readStream
.options(options.extraSourceOptions(tableName))
.table(tableName)
@@ -355,6 +367,9 @@ class FlintSpark(val spark: SparkSession) extends Logging {
.start()
Some(job.id.toString)
}

logInfo("Refresh index complete")
jobId
}

private def stopRefreshingJob(indexName: String): Unit = {
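The last hunk above only threads the streaming job id through a local jobId value so that the method still returns the existing Option[String] after the added logging. As a hedged illustration of consuming such a returned id with standard Spark APIs (the helper name and scenario are hypothetical, not part of this commit):

import org.apache.spark.sql.SparkSession

// Hypothetical consumer of the Option[String] job id produced by the refresh path above.
def describeRefreshJob(spark: SparkSession, jobId: Option[String]): Unit =
  jobId match {
    case Some(id) =>
      // StreamingQueryManager.get accepts the query id in its UUID string form
      val query = spark.streams.get(id)
      println(s"Incremental refresh running: isActive=${query.isActive}, name=${query.name}")
    case None =>
      println("Batch (FULL) refresh already finished; no streaming job to track")
  }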
FlintSparkIndexJobITSuite.scala (new file)
@@ -0,0 +1,75 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark

import java.util.Base64

import scala.collection.JavaConverters.mapAsJavaMapConverter

import org.opensearch.flint.OpenSearchTransactionSuite
import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry
import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState.FAILED
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName
import org.scalatest.matchers.should.Matchers

class FlintSparkIndexJobITSuite extends OpenSearchTransactionSuite with Matchers {

/** Test table and index name */
private val testTable = "spark_catalog.default.test"
private val testIndex = getSkippingIndexName(testTable)

override def beforeAll(): Unit = {
super.beforeAll()
createPartitionedTable(testTable)
}

override def afterEach(): Unit = {
super.afterEach()
flint.deleteIndex(testIndex)
}

test("recover should exit if index doesn't exist") {
flint.recoverIndex("non_exist_index") shouldBe false
}

test("recover should exit if index is not auto refreshed") {
flint
.skippingIndex()
.onTable(testTable)
.addPartitions("year")
.create()
flint.recoverIndex(testIndex) shouldBe false
}

test("recover should succeed if index exists and is auto refreshed") {
flint
.skippingIndex()
.onTable(testTable)
.addPartitions("year")
.options(FlintSparkIndexOptions(Map("auto_refresh" -> "true")))
.create()

flint.recoverIndex(testIndex) shouldBe true
spark.streams.active.exists(_.name == testIndex) shouldBe true
}

test("recover should succeed even if index is in failed state") {
flint
.skippingIndex()
.onTable(testTable)
.addPartitions("year")
.options(FlintSparkIndexOptions(Map("auto_refresh" -> "true")))
.create()

val latestId = Base64.getEncoder.encodeToString(testIndex.getBytes)
updateLatestLogEntry(
new FlintMetadataLogEntry(latestId, 1, 1, latestLogEntry(latestId).asJava),
FAILED)

flint.recoverIndex(testIndex) shouldBe true
spark.streams.active.exists(_.name == testIndex) shouldBe true
}
}
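
The failed-state test derives the metadata log entry id by Base64-encoding the index name, which appears to be the same latest.id convention the new createIndex log line surfaces. A small stand-alone round-trip sketch using only java.util.Base64 (the index name here is hypothetical):

import java.nio.charset.StandardCharsets
import java.util.Base64

val indexName = "flint_spark_catalog_default_test_skipping_index" // hypothetical
val latestId = Base64.getEncoder.encodeToString(indexName.getBytes(StandardCharsets.UTF_8))

// Decoding a log entry id recovers the index name it refers to
val decoded = new String(Base64.getDecoder.decode(latestId), StandardCharsets.UTF_8)
assert(decoded == indexName)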
