Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Flint index refresh mode #228

Merged
merged 6 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ flint.skippingIndex()
.addMinMax("request_processing_time")
.create()

flint.refreshIndex("flint_spark_catalog_default_alb_logs_skipping_index", FULL)
flint.refreshIndex("flint_spark_catalog_default_alb_logs_skipping_index")

// Covering index
flint.coveringIndex()
Expand Down Expand Up @@ -539,10 +539,6 @@ CREATE INDEX Idx_elb ON alb_logs ...

For now, only single or conjunct conditions (conditions connected by AND) in WHERE clause can be optimized by skipping index.

### Index Refresh Job Management

Manual refreshing a table which already has skipping index being auto-refreshed, will be prevented. However, this assumption relies on the condition that the incremental refresh job is actively running in the same Spark cluster, which can be identified when performing the check.

## Integration

### AWS EMR Spark Integration - Using execution role
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import org.json4s.native.Serialization
import org.opensearch.flint.core.{FlintClient, FlintClientBuilder}
import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState._
import org.opensearch.flint.core.metadata.log.OptimisticTransaction.NO_LOG_ENTRY
import org.opensearch.flint.spark.FlintSpark.RefreshMode.{FULL, INCREMENTAL, RefreshMode}
import org.opensearch.flint.spark.FlintSpark.RefreshMode.{AUTO, MANUAL, RefreshMode}
import org.opensearch.flint.spark.FlintSparkIndex.{quotedTableName, ID_COLUMN, StreamingRefresh}
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView
Expand Down Expand Up @@ -135,10 +135,11 @@ class FlintSpark(val spark: SparkSession) extends Logging {
* @return
* refreshing job ID (empty if batch job for now)
*/
def refreshIndex(indexName: String, mode: RefreshMode): Option[String] = {
logInfo(s"Refreshing Flint index $indexName with mode $mode")
def refreshIndex(indexName: String): Option[String] = {
logInfo(s"Refreshing Flint index $indexName")
val index = describeIndex(indexName)
.getOrElse(throw new IllegalStateException(s"Index $indexName doesn't exist"))
val mode = if (index.options.autoRefresh()) AUTO else MANUAL

try {
flintClient
Expand All @@ -148,7 +149,7 @@ class FlintSpark(val spark: SparkSession) extends Logging {
latest.copy(state = REFRESHING, createTime = System.currentTimeMillis()))
.finalLog(latest => {
// Change state to active if full, otherwise update index state regularly
if (mode == FULL) {
if (mode == MANUAL) {
logInfo("Updating index state to active")
latest.copy(state = ACTIVE)
} else {
Expand Down Expand Up @@ -291,7 +292,7 @@ class FlintSpark(val spark: SparkSession) extends Logging {
flintIndexMonitor.startMonitor(indexName)
latest.copy(state = REFRESHING)
})
.commit(_ => doRefreshIndex(index.get, indexName, INCREMENTAL))
.commit(_ => doRefreshIndex(index.get, indexName, AUTO))

logInfo("Recovery complete")
true
Expand All @@ -318,9 +319,6 @@ class FlintSpark(val spark: SparkSession) extends Logging {
spark.read.format(FLINT_DATASOURCE).load(indexName)
}

private def isIncrementalRefreshing(indexName: String): Boolean =
spark.streams.active.exists(_.name == indexName)

// TODO: move to separate class
private def doRefreshIndex(
index: FlintSparkIndex,
Expand All @@ -342,17 +340,13 @@ class FlintSpark(val spark: SparkSession) extends Logging {
}

val jobId = mode match {
case FULL if isIncrementalRefreshing(indexName) =>
throw new IllegalStateException(
s"Index $indexName is incremental refreshing and cannot be manual refreshed")
Comment on lines -345 to -347
Copy link
Collaborator Author

@dai-chen dai-chen Jan 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because refresh mode is decided by index options in refresh index API internally, this case is impossible now.


case FULL =>
case MANUAL =>
logInfo("Start refreshing index in batch style")
batchRefresh()
None

// Flint index has specialized logic and capability for incremental refresh
case INCREMENTAL if index.isInstanceOf[StreamingRefresh] =>
case AUTO if index.isInstanceOf[StreamingRefresh] =>
logInfo("Start refreshing index in streaming style")
val job =
index
Expand All @@ -367,7 +361,7 @@ class FlintSpark(val spark: SparkSession) extends Logging {
Some(job.id.toString)

// Otherwise, fall back to foreachBatch + batch refresh
case INCREMENTAL =>
case AUTO =>
logInfo("Start refreshing index in foreach streaming style")
val job = spark.readStream
.options(options.extraSourceOptions(tableName))
Expand Down Expand Up @@ -437,6 +431,6 @@ object FlintSpark {
*/
object RefreshMode extends Enumeration {
type RefreshMode = Value
val FULL, INCREMENTAL = Value
val MANUAL, AUTO = Value
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ trait FlintSparkCoveringIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[A
// Trigger auto refresh if enabled
if (indexOptions.autoRefresh()) {
val flintIndexName = getFlintIndexName(flint, ctx.indexName, ctx.tableName)
flint.refreshIndex(flintIndexName, RefreshMode.INCREMENTAL)
flint.refreshIndex(flintIndexName)
}
Seq.empty
}
Expand All @@ -63,7 +63,7 @@ trait FlintSparkCoveringIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[A
ctx: RefreshCoveringIndexStatementContext): Command = {
FlintSparkSqlCommand() { flint =>
val flintIndexName = getFlintIndexName(flint, ctx.indexName, ctx.tableName)
flint.refreshIndex(flintIndexName, RefreshMode.FULL)
flint.refreshIndex(flintIndexName)
Seq.empty
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ trait FlintSparkMaterializedViewAstBuilder extends FlintSparkSqlExtensionsVisito
// Trigger auto refresh if enabled
if (indexOptions.autoRefresh()) {
val flintIndexName = getFlintIndexName(flint, ctx.mvName)
flint.refreshIndex(flintIndexName, RefreshMode.INCREMENTAL)
flint.refreshIndex(flintIndexName)
}
Seq.empty
}
Expand All @@ -56,7 +56,7 @@ trait FlintSparkMaterializedViewAstBuilder extends FlintSparkSqlExtensionsVisito
ctx: RefreshMaterializedViewStatementContext): Command = {
FlintSparkSqlCommand() { flint =>
val flintIndexName = getFlintIndexName(flint, ctx.mvName)
flint.refreshIndex(flintIndexName, RefreshMode.FULL)
flint.refreshIndex(flintIndexName)
Seq.empty
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ trait FlintSparkSkippingIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[A
// Trigger auto refresh if enabled
if (indexOptions.autoRefresh()) {
val indexName = getSkippingIndexName(flint, ctx.tableName)
flint.refreshIndex(indexName, RefreshMode.INCREMENTAL)
flint.refreshIndex(indexName)
}
Seq.empty
}
Expand All @@ -74,7 +74,7 @@ trait FlintSparkSkippingIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[A
ctx: RefreshSkippingIndexStatementContext): Command =
FlintSparkSqlCommand() { flint =>
val indexName = getSkippingIndexName(flint, ctx.tableName)
flint.refreshIndex(indexName, RefreshMode.FULL)
flint.refreshIndex(indexName)
Seq.empty
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ package org.opensearch.flint.spark

import com.stephenn.scalatest.jsonassert.JsonMatchers.matchJson
import org.opensearch.flint.core.FlintVersion.current
import org.opensearch.flint.spark.FlintSpark.RefreshMode.{FULL, INCREMENTAL}
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.getFlintIndexName
import org.scalatest.matchers.must.Matchers.defined
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
Expand Down Expand Up @@ -85,7 +84,7 @@ class FlintSparkCoveringIndexITSuite extends FlintSparkSuite {
.addIndexColumns("name", "age")
.create()

flint.refreshIndex(testFlintIndex, FULL)
flint.refreshIndex(testFlintIndex)

val indexData = flint.queryIndex(testFlintIndex)
checkAnswer(indexData, Seq(Row("Hello", 30), Row("World", 25)))
Expand All @@ -97,9 +96,10 @@ class FlintSparkCoveringIndexITSuite extends FlintSparkSuite {
.name(testIndex)
.onTable(testTable)
.addIndexColumns("name", "age")
.options(FlintSparkIndexOptions(Map("auto_refresh" -> "true")))
.create()

val jobId = flint.refreshIndex(testFlintIndex, INCREMENTAL)
val jobId = flint.refreshIndex(testFlintIndex)
jobId shouldBe defined

val job = spark.streams.get(jobId.get)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import org.opensearch.action.admin.indices.delete.DeleteIndexRequest
import org.opensearch.action.admin.indices.settings.put.UpdateSettingsRequest
import org.opensearch.client.RequestOptions
import org.opensearch.flint.OpenSearchTransactionSuite
import org.opensearch.flint.spark.FlintSpark.RefreshMode._
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName
import org.scalatest.matchers.should.Matchers

Expand Down Expand Up @@ -51,7 +50,7 @@ class FlintSparkIndexMonitorITSuite extends OpenSearchTransactionSuite with Matc
.addValueSet("name")
.options(FlintSparkIndexOptions(Map("auto_refresh" -> "true")))
.create()
flint.refreshIndex(testFlintIndex, INCREMENTAL)
flint.refreshIndex(testFlintIndex)

// Wait for refresh complete and another 5 seconds to make sure monitor thread start
val jobId = spark.streams.active.find(_.name == testFlintIndex).get.id.toString
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import java.sql.Timestamp

import com.stephenn.scalatest.jsonassert.JsonMatchers.matchJson
import org.opensearch.flint.core.FlintVersion.current
import org.opensearch.flint.spark.FlintSpark.RefreshMode.{FULL, INCREMENTAL}
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.getFlintIndexName
import org.scalatest.matchers.must.Matchers.defined
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
Expand Down Expand Up @@ -99,7 +98,7 @@ class FlintSparkMaterializedViewITSuite extends FlintSparkSuite {
.query(testQuery)
.create()

flint.refreshIndex(testFlintIndex, FULL)
flint.refreshIndex(testFlintIndex)

val indexData = flint.queryIndex(testFlintIndex)
checkAnswer(
Expand Down Expand Up @@ -209,7 +208,7 @@ class FlintSparkMaterializedViewITSuite extends FlintSparkSuite {
.create()

flint
.refreshIndex(testFlintIndex, INCREMENTAL)
.refreshIndex(testFlintIndex)
.map(awaitStreamingComplete)
.orElse(throw new RuntimeException)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ package org.opensearch.flint.spark
import com.stephenn.scalatest.jsonassert.JsonMatchers.matchJson
import org.json4s.native.JsonMethods._
import org.opensearch.flint.core.FlintVersion.current
import org.opensearch.flint.spark.FlintSpark.RefreshMode.{FULL, INCREMENTAL}
import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN
import org.opensearch.flint.spark.skipping.FlintSparkSkippingFileIndex
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName
Expand Down Expand Up @@ -150,7 +149,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
.onTable(testTable)
.addPartitions("year")
.create()
flint.refreshIndex(testIndex, FULL)
flint.refreshIndex(testIndex)

val indexData = flint.queryIndex(testIndex)
indexData.columns should not contain ID_COLUMN
Expand All @@ -164,7 +163,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
.addPartitions("year", "month")
.create()

val jobId = flint.refreshIndex(testIndex, FULL)
val jobId = flint.refreshIndex(testIndex)
jobId shouldBe empty

val indexData = flint.queryIndex(testIndex).collect().toSet
Expand All @@ -177,9 +176,10 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
.skippingIndex()
.onTable(testTable)
.addPartitions("year", "month")
.options(FlintSparkIndexOptions(Map("auto_refresh" -> "true")))
.create()

val jobId = flint.refreshIndex(testIndex, INCREMENTAL)
val jobId = flint.refreshIndex(testIndex)
jobId shouldBe defined

val job = spark.streams.get(jobId.get)
Expand All @@ -191,24 +191,6 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
indexData should have size 2
}

test("should fail to manual refresh an incremental refreshing index") {
flint
.skippingIndex()
.onTable(testTable)
.addPartitions("year", "month")
.create()

val jobId = flint.refreshIndex(testIndex, INCREMENTAL)
val job = spark.streams.get(jobId.get)
failAfter(streamingTimeout) {
job.processAllAvailable()
}

assertThrows[IllegalStateException] {
flint.refreshIndex(testIndex, FULL)
}
}

test("can have only 1 skipping index on a table") {
flint
.skippingIndex()
Expand Down Expand Up @@ -257,7 +239,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
.onTable(testTable)
.addPartitions("year", "month")
.create()
flint.refreshIndex(testIndex, FULL)
flint.refreshIndex(testIndex)

// Assert index data
checkAnswer(
Expand All @@ -282,7 +264,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
.onTable(testTable)
.addValueSet("address", Map("max_size" -> "2"))
.create()
flint.refreshIndex(testIndex, FULL)
flint.refreshIndex(testIndex)

// Assert index data
checkAnswer(
Expand Down Expand Up @@ -311,7 +293,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
.onTable(testTable)
.addMinMax("age")
.create()
flint.refreshIndex(testIndex, FULL)
flint.refreshIndex(testIndex)

// Assert index data
checkAnswer(
Expand Down Expand Up @@ -384,7 +366,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
.onTable(testTable)
.addPartitions("month")
.create()
flint.refreshIndex(testIndex, FULL)
flint.refreshIndex(testIndex)

// Generate a new source file which is not in index data
sql(s"""
Expand Down Expand Up @@ -643,7 +625,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
.addValueSet("varchar_col")
.addValueSet("char_col")
.create()
flint.refreshIndex(testIndex, FULL)
flint.refreshIndex(testIndex)

val query = sql(s"""
| SELECT varchar_col, char_col
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,18 @@ class FlintSparkSkippingIndexSqlITSuite extends FlintSparkSuite {
indexData.count() shouldBe 2
}

test("should fail if refresh an auto refresh skipping index") {
sql(s"""
| CREATE SKIPPING INDEX ON $testTable
| ( year PARTITION )
| WITH (auto_refresh = true)
| """.stripMargin)

assertThrows[IllegalStateException] {
sql(s"REFRESH SKIPPING INDEX ON $testTable")
}
}

test("create skipping index if not exists") {
sql(s"""
| CREATE SKIPPING INDEX
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import org.opensearch.action.get.GetRequest
import org.opensearch.client.RequestOptions
import org.opensearch.client.indices.GetIndexRequest
import org.opensearch.flint.OpenSearchTransactionSuite
import org.opensearch.flint.spark.FlintSpark.RefreshMode.{FULL, INCREMENTAL}
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName
import org.scalatest.matchers.should.Matchers

Expand Down Expand Up @@ -71,7 +70,7 @@ class FlintSparkTransactionITSuite extends OpenSearchTransactionSuite with Match
.onTable(testTable)
.addPartitions("year", "month")
.create()
flint.refreshIndex(testFlintIndex, FULL)
flint.refreshIndex(testFlintIndex)

val latest = latestLogEntry(testLatestId)
latest should contain("state" -> "active")
Expand All @@ -85,7 +84,7 @@ class FlintSparkTransactionITSuite extends OpenSearchTransactionSuite with Match
.addPartitions("year", "month")
.options(FlintSparkIndexOptions(Map("auto_refresh" -> "true")))
.create()
flint.refreshIndex(testFlintIndex, INCREMENTAL)
flint.refreshIndex(testFlintIndex)

// Job start time should be assigned
var latest = latestLogEntry(testLatestId)
Expand Down
Loading