refactors
* Use FlintSparkIndexOptions instead of options Map
* Define merge operation
* Remove UpdateMode
* Have updateIndexOptions return FlintSparkIndex

Signed-off-by: Sean Kao <[email protected]>
seankao-az committed Mar 23, 2024
1 parent 05ebd93 commit 63a7236
Showing 11 changed files with 204 additions and 183 deletions.
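In short: the ALTER statement handlers no longer pass a raw options Map to flint.updateIndex. They parse the WITH clause into FlintSparkIndexOptions, build an updated FlintSparkIndex with updateIndexOptions, then hand that index to updateIndex. A minimal sketch of the new call flow (the index name and option values are illustrative, not taken from this commit):

  // build merged options and a new index object, then apply it
  val options = FlintSparkIndexOptions(Map("auto_refresh" -> "true"))
  val updatedIndex = flint.updateIndexOptions("flint_mydb_mytable_skipping_index", options)
  val jobId = flint.updateIndex(updatedIndex) // Some(jobId) when a refresh job starts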
@@ -32,7 +32,6 @@ import org.apache.spark.sql.flint.config.FlintSparkConf.{DOC_ID_COLUMN_NAME, IGN
  * Flint Spark integration API entrypoint.
  */
 class FlintSpark(val spark: SparkSession) extends Logging {
-  import FlintSpark.UpdateMode._
 
   /** Flint spark configuration */
   private val flintSparkConf: FlintSparkConf =
@@ -215,14 +214,15 @@ class FlintSpark(val spark: SparkSession) extends Logging {
    * @return
    *   refreshing job ID (empty if no job)
    */
-  def updateIndex(index: FlintSparkIndex, updateMode: UpdateMode): Option[String] = {
+  def updateIndex(index: FlintSparkIndex): Option[String] = {
     logInfo(s"Updating Flint index $index")
     val indexName = index.name
     if (flintClient.exists(indexName)) {
       try {
-        updateMode match {
-          case MANUAL_TO_AUTO => updateIndexManualToAuto(index)
-          case AUTO_TO_MANUAL => updateIndexAutoToManual(index)
+        // Relies on validation to forbid auto-to-auto and manual-to-manual updates
+        index.options.autoRefresh() match {
+          case true => updateIndexManualToAuto(index)
+          case false => updateIndexAutoToManual(index)
         }
       } catch {
         case e: Exception =>
@@ -236,64 +236,56 @@ class FlintSpark(val spark: SparkSession) extends Logging {
 
   /**
    * Update Flint index metadata and job associated with index.
+   * TODO: This should probably not be in FlintSpark but be in some Factory or Builder class.
+   * Name for this method is confusing as well.
    *
-   * @param flint
-   *   Flint Spark which has access to Spark Catalog
    * @param indexName
    *   index name
    * @param updateOptions
    *   options to update
+   * @return
+   *   Flint index
    */
-  def updateIndex(indexName: String, updateOptions: Map[String, String]): Option[String] = {
+  def updateIndexOptions(
+      indexName: String,
+      updateOptions: FlintSparkIndexOptions): FlintSparkIndex = {
     val oldIndex = describeIndex(indexName)
       .getOrElse(throw new IllegalStateException(s"Index $indexName doesn't exist"))
 
-    val oldOptions = oldIndex.options.options
-    validateUpdateOptions(oldOptions, updateOptions, oldIndex.kind)
-
-    val mergedOptions = oldOptions ++ updateOptions
-    val newMetadata =
-      oldIndex.metadata().copy(options = mergedOptions.mapValues(_.asInstanceOf[AnyRef]).asJava)
-    val newIndex = FlintSparkIndexFactory.create(newMetadata)
-
-    val updateMode = newIndex.options.autoRefresh() match {
-      case true => MANUAL_TO_AUTO
-      case false => AUTO_TO_MANUAL
-    }
+    val oldOptions = oldIndex.options
+    validateIndexUpdateOptions(oldOptions, updateOptions)
 
-    updateIndex(newIndex, updateMode)
+    val mergedOptions = oldOptions.merge(updateOptions)
+    val newMetadata = oldIndex
+      .metadata()
+      .copy(options = mergedOptions.options.mapValues(_.asInstanceOf[AnyRef]).asJava)
+    FlintSparkIndexFactory.create(newMetadata)
   }
 
   /**
-   * Validate update options. These are rules specific for updating index, validating the update
-   * is allowed. It doesn't check whether the resulting index options will be valid.
+   * Validate index update options. These are rules specific for updating index, validating the
+   * update is allowed. It doesn't check whether the resulting index options will be valid.
    *
    * @param oldOptions
    *   existing options
    * @param updateOptions
    *   options to update
-   * @param indexKind
-   *   index kind
    */
-  private def validateUpdateOptions(
-      oldOptions: Map[String, String],
-      updateOptions: Map[String, String],
-      indexKind: String): Unit = {
-    val mergedOptions = oldOptions ++ updateOptions
-    val newAutoRefresh = mergedOptions.getOrElse(AUTO_REFRESH.toString, "false")
-    val oldAutoRefresh = oldOptions.getOrElse(AUTO_REFRESH.toString, "false")
+  private def validateIndexUpdateOptions(
+      oldOptions: FlintSparkIndexOptions,
+      updateOptions: FlintSparkIndexOptions): Unit = {
+    val mergedOptions = oldOptions.merge(updateOptions)
 
     // auto_refresh must change
-    if (newAutoRefresh == oldAutoRefresh) {
+    if (mergedOptions.autoRefresh() == oldOptions.autoRefresh()) {
       throw new IllegalArgumentException("auto_refresh option must be updated")
     }
 
-    val newIncrementalRefresh = mergedOptions.getOrElse(INCREMENTAL_REFRESH.toString, "false")
-    val refreshMode = (newAutoRefresh, newIncrementalRefresh) match {
-      case ("true", "false") => AUTO
-      case ("false", "false") => FULL
-      case ("false", "true") => INCREMENTAL
-      case ("true", "true") =>
+    val refreshMode = (mergedOptions.autoRefresh(), mergedOptions.incrementalRefresh()) match {
+      case (true, false) => AUTO
+      case (false, false) => FULL
+      case (false, true) => INCREMENTAL
+      case (true, true) =>
         throw new IllegalArgumentException(
           "auto_refresh and incremental_refresh options cannot both be true")
     }
@@ -309,9 +301,9 @@ class FlintSpark(val spark: SparkSession) extends Logging {
         CHECKPOINT_LOCATION,
         WATERMARK_DELAY)
     }
-    if (!updateOptions.keys.forall(allowedOptions.map(_.toString).contains)) {
+    if (!updateOptions.options.keys.forall(allowedOptions.map(_.toString).contains)) {
       throw new IllegalArgumentException(
-        s"Altering ${indexKind} index to ${refreshMode} refresh only allows options: ${allowedOptions}")
+        s"Altering index to ${refreshMode} refresh only allows options: ${allowedOptions}")
     }
   }
 
@@ -503,11 +495,3 @@ class FlintSpark(val spark: SparkSession) extends Logging {
       })
   }
 }
-
-object FlintSpark {
-  object UpdateMode extends Enumeration {
-    type UpdateMode = Value
-    val MANUAL_TO_AUTO, AUTO_TO_MANUAL = Value
-    // TODO: support AUTO_TO_AUTO and MANUAL_TO_MANUAL
-  }
-}
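To make the validation above concrete, here is a hedged sketch of how validateIndexUpdateOptions behaves (it is private, so this is illustration rather than callable client code; the option values are assumptions):

  val old = FlintSparkIndexOptions(Map("auto_refresh" -> "false"))

  // passes the auto_refresh-must-change check: a manual-to-auto update
  validateIndexUpdateOptions(old, FlintSparkIndexOptions(Map("auto_refresh" -> "true")))

  // throws IllegalArgumentException("auto_refresh option must be updated"):
  // merged auto_refresh is still false, so nothing flipped
  validateIndexUpdateOptions(old, FlintSparkIndexOptions(Map("incremental_refresh" -> "true")))

  // throws IllegalArgumentException: auto_refresh and incremental_refresh cannot both be true
  validateIndexUpdateOptions(
    old,
    FlintSparkIndexOptions(Map("auto_refresh" -> "true", "incremental_refresh" -> "true")))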
@@ -118,6 +118,18 @@ case class FlintSparkIndexOptions(options: Map[String, String]) {
     map.result()
   }
 
+  /**
+   * Merge two FlintSparkIndexOptions. If an option exists in both instances, the value from the
+   * other instance overwrites the value from this instance.
+   * @param other
+   *   option to merge
+   * @return
+   *   merged Flint Spark index options
+   */
+  def merge(other: FlintSparkIndexOptions): FlintSparkIndexOptions = {
+    FlintSparkIndexOptions(options ++ other.options)
+  }
+
   private def getOptionValue(name: OptionName): Option[String] = {
     options.get(name.toString)
   }
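The new merge is a right-biased map union, so conflict behavior is easy to pin down. A quick illustrative example (option values assumed, not from this commit):

  val existing = FlintSparkIndexOptions(Map("auto_refresh" -> "false", "refresh_interval" -> "1 Minute"))
  val update = FlintSparkIndexOptions(Map("auto_refresh" -> "true"))
  existing.merge(update).options
  // Map("auto_refresh" -> "true", "refresh_interval" -> "1 Minute"): the update's value wins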
@@ -9,6 +9,7 @@ import java.util.Locale
 
 import scala.collection.JavaConverters.asScalaBufferConverter
 
+import org.opensearch.flint.spark.FlintSparkIndexOptions
 import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser.{PropertyKeyContext, PropertyListContext, PropertyValueContext}
 
 import org.apache.spark.sql.catalyst.parser.ParserUtils.string
@@ -19,16 +20,16 @@ import org.apache.spark.sql.catalyst.parser.ParserUtils.string
  */
 trait SparkSqlAstBuilder extends FlintSparkSqlExtensionsVisitor[AnyRef] {
 
-  override def visitPropertyList(ctx: PropertyListContext): Map[String, String] = {
+  override def visitPropertyList(ctx: PropertyListContext): FlintSparkIndexOptions = {
     if (ctx == null) {
-      Map.empty
+      FlintSparkIndexOptions.empty
     } else {
       val properties = ctx.property.asScala.map { property =>
         val key = visitPropertyKey(property.key)
         val value = visitPropertyValue(property.value)
         key -> value
       }
-      properties.toMap
+      FlintSparkIndexOptions(properties.toMap)
     }
   }
 
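With visitPropertyList now returning FlintSparkIndexOptions directly, a SQL property list flows into the options type without FlintSparkIndexOptions(...) wrapping at each call site. Roughly (a sketch of the correspondence, assuming the usual WITH clause syntax):

  // WITH (auto_refresh = 'true', refresh_interval = '1 Minute')
  // now parses to FlintSparkIndexOptions(Map("auto_refresh" -> "true", "refresh_interval" -> "1 Minute"))
  // and a missing property list parses to FlintSparkIndexOptions.empty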
@@ -7,7 +7,6 @@ package org.opensearch.flint.spark.sql.covering
 
 import org.antlr.v4.runtime.tree.RuleNode
 import org.opensearch.flint.spark.FlintSpark
-import org.opensearch.flint.spark.FlintSparkIndexOptions
 import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex
 import org.opensearch.flint.spark.sql.{FlintSparkSqlCommand, FlintSparkSqlExtensionsVisitor, SparkSqlAstBuilder}
 import org.opensearch.flint.spark.sql.FlintSparkSqlAstBuilder.{getFullTableName, getSqlText}
@@ -45,7 +44,7 @@ trait FlintSparkCoveringIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[A
     }
 
     val ignoreIfExists = ctx.EXISTS() != null
-    val indexOptions = FlintSparkIndexOptions(visitPropertyList(ctx.propertyList()))
+    val indexOptions = visitPropertyList(ctx.propertyList())
     indexBuilder
       .options(indexOptions)
       .create(ignoreIfExists)
@@ -107,8 +106,9 @@ trait FlintSparkCoveringIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[A
       ctx: AlterCoveringIndexStatementContext): Command = {
     FlintSparkSqlCommand() { flint =>
       val indexName = getFlintIndexName(flint, ctx.indexName, ctx.tableName)
-      val indexOptionsMap = visitPropertyList(ctx.propertyList())
-      flint.updateIndex(indexName, indexOptionsMap)
+      val indexOptions = visitPropertyList(ctx.propertyList())
+      val updatedIndex = flint.updateIndexOptions(indexName, indexOptions)
+      flint.updateIndex(updatedIndex)
       Seq.empty
     }
   }
@@ -9,7 +9,6 @@ import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable`
 
 import org.antlr.v4.runtime.tree.RuleNode
 import org.opensearch.flint.spark.FlintSpark
-import org.opensearch.flint.spark.FlintSparkIndexOptions
 import org.opensearch.flint.spark.mv.FlintSparkMaterializedView
 import org.opensearch.flint.spark.sql.{FlintSparkSqlCommand, FlintSparkSqlExtensionsVisitor, SparkSqlAstBuilder}
 import org.opensearch.flint.spark.sql.FlintSparkSqlAstBuilder.{getFullTableName, getSqlText}
@@ -38,7 +37,7 @@ trait FlintSparkMaterializedViewAstBuilder extends FlintSparkSqlExtensionsVisito
       .query(query)
 
     val ignoreIfExists = ctx.EXISTS() != null
-    val indexOptions = FlintSparkIndexOptions(visitPropertyList(ctx.propertyList()))
+    val indexOptions = visitPropertyList(ctx.propertyList())
     mvBuilder
       .options(indexOptions)
       .create(ignoreIfExists)
@@ -104,8 +103,9 @@ trait FlintSparkMaterializedViewAstBuilder extends FlintSparkSqlExtensionsVisito
       ctx: AlterMaterializedViewStatementContext): Command = {
     FlintSparkSqlCommand() { flint =>
       val indexName = getFlintIndexName(flint, ctx.mvName)
-      val indexOptionsMap = visitPropertyList(ctx.propertyList())
-      flint.updateIndex(indexName, indexOptionsMap)
+      val indexOptions = visitPropertyList(ctx.propertyList())
+      val updatedIndex = flint.updateIndexOptions(indexName, indexOptions)
+      flint.updateIndex(updatedIndex)
       Seq.empty
     }
   }
@@ -10,7 +10,6 @@ import scala.collection.JavaConverters.collectionAsScalaIterableConverter
 import org.antlr.v4.runtime.tree.RuleNode
 import org.opensearch.flint.core.field.bloomfilter.BloomFilterFactory._
 import org.opensearch.flint.spark.FlintSpark
-import org.opensearch.flint.spark.FlintSparkIndexOptions
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{BLOOM_FILTER, MIN_MAX, PARTITION, VALUE_SET}
@@ -71,7 +70,7 @@ trait FlintSparkSkippingIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[A
     }
 
     val ignoreIfExists = ctx.EXISTS() != null
-    val indexOptions = FlintSparkIndexOptions(visitPropertyList(ctx.propertyList()))
+    val indexOptions = visitPropertyList(ctx.propertyList())
     indexBuilder
       .options(indexOptions)
       .create(ignoreIfExists)
@@ -115,8 +114,9 @@ trait FlintSparkSkippingIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[A
       ctx: AlterSkippingIndexStatementContext): Command = {
     FlintSparkSqlCommand() { flint =>
       val indexName = getSkippingIndexName(flint, ctx.tableName)
-      val indexOptionsMap = visitPropertyList(ctx.propertyList())
-      flint.updateIndex(indexName, indexOptionsMap)
+      val indexOptions = visitPropertyList(ctx.propertyList())
+      val updatedIndex = flint.updateIndexOptions(indexName, indexOptions)
+      flint.updateIndex(updatedIndex)
       Seq.empty
     }
   }
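Taken together, the three AST builders above wire the same two-step update into ALTER for covering indexes, materialized views, and skipping indexes. For instance (a plausible statement for illustration, not taken from this commit):

  // ALTER SKIPPING INDEX ON mydb.mytable WITH (auto_refresh = 'true')
  // flows as: visitPropertyList -> FlintSparkIndexOptions -> flint.updateIndexOptions -> flint.updateIndex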
@@ -130,7 +130,10 @@ class FlintSparkCoveringIndexITSuite extends FlintSparkSuite {
     checkAnswer(indexData, Seq())
 
     // Update Flint index to auto refresh and wait for complete
-    val jobId = flint.updateIndex(testFlintIndex, Map("auto_refresh" -> "true"))
+    val updatedIndex = flint.updateIndexOptions(
+      testFlintIndex,
+      FlintSparkIndexOptions(Map("auto_refresh" -> "true")))
+    val jobId = flint.updateIndex(updatedIndex)
     jobId shouldBe defined
 
     val job = spark.streams.get(jobId.get)
@@ -203,12 +203,14 @@ class FlintSparkMaterializedViewITSuite extends FlintSparkSuite {
     checkAnswer(indexData, Seq())
 
     // Update Flint index to auto refresh and wait for complete
-    val jobId = flint.updateIndex(
+    val updatedIndex = flint.updateIndexOptions(
       testFlintIndex,
-      Map(
-        "auto_refresh" -> "true",
-        "checkpoint_location" -> checkpointDir.getAbsolutePath,
-        "watermark_delay" -> "1 Minute"))
+      FlintSparkIndexOptions(
+        Map(
+          "auto_refresh" -> "true",
+          "checkpoint_location" -> checkpointDir.getAbsolutePath,
+          "watermark_delay" -> "1 Minute")))
+    val jobId = flint.updateIndex(updatedIndex)
     jobId shouldBe defined
 
     val job = spark.streams.get(jobId.get)
@@ -277,7 +277,9 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
     flint.queryIndex(testIndex).collect().toSet should have size 0
 
     // Update Flint index to auto refresh and wait for complete
-    val jobId = flint.updateIndex(testIndex, Map("auto_refresh" -> "true"))
+    val updatedIndex =
+      flint.updateIndexOptions(testIndex, FlintSparkIndexOptions(Map("auto_refresh" -> "true")))
+    val jobId = flint.updateIndex(updatedIndex)
     jobId shouldBe defined
 
     val job = spark.streams.get(jobId.get)
@@ -128,7 +128,10 @@ class FlintSparkTransactionITSuite extends OpenSearchTransactionSuite with Match
       .addPartitions("year", "month")
       .create()
 
-    flint.updateIndex(testFlintIndex, Map("auto_refresh" -> "true"))
+    val updatedIndex = flint.updateIndexOptions(
+      testFlintIndex,
+      FlintSparkIndexOptions(Map("auto_refresh" -> "true")))
+    flint.updateIndex(updatedIndex)
     val latest = latestLogEntry(testLatestId)
     latest should contain("state" -> "refreshing")
     latest("jobStartTime").asInstanceOf[Number].longValue() should be > 0L
@@ -143,7 +146,10 @@ class FlintSparkTransactionITSuite extends OpenSearchTransactionSuite with Match
       .create()
     flint.refreshIndex(testFlintIndex)
 
-    flint.updateIndex(testFlintIndex, Map("auto_refresh" -> "false"))
+    val updatedIndex = flint.updateIndexOptions(
+      testFlintIndex,
+      FlintSparkIndexOptions(Map("auto_refresh" -> "false")))
+    flint.updateIndex(updatedIndex)
     val latest = latestLogEntry(testLatestId)
     latest should contain("state" -> "active")
   }