Add index refresh validation
Signed-off-by: Chen Dai <[email protected]>
dai-chen committed Mar 26, 2024
1 parent 94fc2f5 commit 1328586
Showing 5 changed files with 122 additions and 6 deletions.
@@ -104,6 +104,11 @@ class FlintSpark(val spark: SparkSession) extends Logging {
} else {
val metadata = index.metadata()
try {
// Validate the index beforehand to avoid leaving behind an orphaned OpenSearch index
// if the streaming job later fails to start because the index definition is invalid
index.validate(spark)

// Start transaction only if index validation passed
flintClient
.startTransaction(indexName, dataSourceName, true)
.initialLog(latest => latest.state == EMPTY || latest.state == DELETED)
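To illustrate the intent of the early validation, here is a minimal caller-side sketch. It assumes FlintSpark.createIndex(index) is the entry point containing the hunk above, and relies on the require-based checks added below throwing IllegalArgumentException:

    import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndex}
    import org.apache.spark.sql.SparkSession

    // Hypothetical helper: create an index, surfacing validation failures to the caller.
    def createIndexSafely(spark: SparkSession, index: FlintSparkIndex): Unit = {
      val flint = new FlintSpark(spark)
      try {
        flint.createIndex(index) // index.validate(spark) runs before any transaction starts
      } catch {
        case e: IllegalArgumentException =>
          // e.g. "requirement failed: Checkpoint location is required by incremental refresh";
          // at this point no metadata log entry and no OpenSearch index have been created
          println(s"Index validation failed: ${e.getMessage}")
      }
    }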
@@ -9,8 +9,13 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter

import org.opensearch.flint.core.metadata.FlintMetadata
import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView
import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.flint.datatype.FlintDataType
import org.apache.spark.sql.types.StructType

@@ -46,6 +51,16 @@ trait FlintSparkIndex {
*/
def metadata(): FlintMetadata

/**
 * Validate the index before it is created. By default, this method validates the index options
 * by delegating to the specific index refresh implementation, since the options mostly serve
 * index refresh. Subclasses can override this method to add extra validation logic.
 */
def validate(spark: SparkSession): Unit = {
val refresh = FlintSparkIndexRefresh.create(name(), this)
refresh.validate(spark) // TODO: why indexName arg necessary?
}

/**
* Build a data frame to represent index data computation logic. Upper level code decides how to
* use this, ex. batch or streaming, fully or incremental refresh.
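As a usage sketch of the default validate() above, a driver could re-run validation on an existing index definition. FlintSpark.describeIndex returning Option[FlintSparkIndex] is an assumption here:

    import org.opensearch.flint.spark.FlintSpark
    import org.apache.spark.sql.SparkSession

    // Hypothetical usage: re-validate an existing index definition on demand.
    def revalidate(spark: SparkSession, indexName: String): Unit = {
      val flint = new FlintSpark(spark)
      flint.describeIndex(indexName) match {
        case Some(index) =>
          index.validate(spark) // delegates to FlintSparkIndexRefresh.create(index.name(), index)
          println(s"$indexName passed validation")
        case None =>
          println(s"$indexName does not exist")
      }
    }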
@@ -5,8 +5,11 @@

package org.opensearch.flint.spark.refresh

import java.util.Collections

import org.opensearch.flint.spark.{FlintSparkIndex, FlintSparkIndexOptions}
import org.opensearch.flint.spark.FlintSparkIndex.{quotedTableName, StreamingRefresh}
import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.{isCheckpointLocationAccessible, isSourceTableNonHive}
import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.RefreshMode.{AUTO, RefreshMode}

import org.apache.spark.sql.{DataFrame, Row, SparkSession}
@@ -27,6 +30,35 @@ class AutoIndexRefresh(indexName: String, index: FlintSparkIndex) extends FlintS

override def refreshMode: RefreshMode = AUTO

override def validate(spark: SparkSession): Unit = {
// Incremental refresh cannot be enabled at the same time as auto refresh
val options = index.options
require(
!options.incrementalRefresh(),
"Incremental refresh cannot be enabled if auto refresh is enabled")

// Non-Hive table is required for auto refresh
require(
isSourceTableNonHive(spark, index),
"Flint index auto refresh doesn't support Hive table")

// Checkpoint location is required if mandatory option set
val flintSparkConf = new FlintSparkConf(Collections.emptyMap)
val checkpointLocation = index.options.checkpointLocation()
if (flintSparkConf.isCheckpointMandatory) {
require(
checkpointLocation.isDefined,
s"Checkpoint location is required if ${CHECKPOINT_MANDATORY.key} option enabled")
}

// If a checkpoint location is given, it must exist and be accessible
if (checkpointLocation.isDefined) {
require(
isCheckpointLocationAccessible(spark, checkpointLocation.get),
s"Checkpoint location ${checkpointLocation.get} doesn't exist or no permission to access")
}
}

override def start(spark: SparkSession, flintSparkConf: FlintSparkConf): Option[String] = {
val options = index.options
val tableName = index.metadata().source
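All of the checks above use Scala's built-in require, so an invalid option combination surfaces as an IllegalArgumentException before any streaming query is started. A standalone illustration of that failure mode (the option values are made up):

    // Standalone sketch of the failure mode produced by the require-based checks above.
    def checkOptions(autoRefresh: Boolean, incrementalRefresh: Boolean): Unit = {
      if (autoRefresh) {
        require(
          !incrementalRefresh,
          "Incremental refresh cannot be enabled if auto refresh is enabled")
      }
    }

    checkOptions(autoRefresh = true, incrementalRefresh = true)
    // java.lang.IllegalArgumentException: requirement failed:
    //   Incremental refresh cannot be enabled if auto refresh is enabled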
@@ -5,11 +5,21 @@

package org.opensearch.flint.spark.refresh

import java.io.IOException

import org.apache.hadoop.fs.Path
import org.opensearch.flint.spark.FlintSparkIndex
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView
import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.RefreshMode.RefreshMode
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.streaming.CheckpointFileManager
import org.apache.spark.sql.flint.{loadTable, parseTableName, qualifyTableName}
import org.apache.spark.sql.flint.config.FlintSparkConf

/**
Expand All @@ -24,6 +34,11 @@ trait FlintSparkIndexRefresh extends Logging {
*/
def refreshMode: RefreshMode

/**
 * Validate the index refresh options and environment before the refresh starts. The default
 * implementation performs no validation.
 */
def validate(spark: SparkSession): Unit = {}

/**
* Start refreshing the index.
*
@@ -37,7 +52,7 @@ trait FlintSparkIndexRefresh extends Logging {
def start(spark: SparkSession, flintSparkConf: FlintSparkConf): Option[String]
}

object FlintSparkIndexRefresh {
object FlintSparkIndexRefresh extends Logging {

/** Index refresh mode */
object RefreshMode extends Enumeration {
@@ -65,4 +80,42 @@ object FlintSparkIndexRefresh {
new FullIndexRefresh(indexName, index)
}
}

/**
 * Determine whether none of the given index's source tables is a Hive table. Skipping and
 * covering indexes have a single source table; a materialized view may have several, extracted
 * from the logical plan of its defining query.
 *
 * @return true if no source table uses the Hive provider, false otherwise
 */
def isSourceTableNonHive(spark: SparkSession, index: FlintSparkIndex): Boolean = {
// Extract source table name (possibly more than 1 for MV source query)
val tableNames = index match {
case skipping: FlintSparkSkippingIndex => Seq(skipping.tableName)
case covering: FlintSparkCoveringIndex => Seq(covering.tableName)
case mv: FlintSparkMaterializedView =>
spark.sessionState.sqlParser
.parsePlan(mv.query)
.collect { case LogicalRelation(_, _, Some(table), _) =>
qualifyTableName(spark, table.identifier.table)
}
}

// Validate that none of the source tables is a Hive table
tableNames.forall { tableName =>
val (catalog, ident) = parseTableName(spark, tableName)
val table = loadTable(catalog, ident).get
!DDLUtils.isHiveTable(Option(table.properties().get("provider")))
}
}

/**
 * Determine whether the given checkpoint location exists and is accessible. An IOException
 * during the check (e.g. missing permission) is logged as a warning and reported as not
 * accessible.
 *
 * @return true if the checkpoint location exists and can be accessed, false otherwise
 */
def isCheckpointLocationAccessible(spark: SparkSession, checkpointLocation: String): Boolean = {
val checkpointPath = new Path(checkpointLocation)
val checkpointManager =
CheckpointFileManager.create(checkpointPath, spark.sessionState.newHadoopConf())
try {
checkpointManager.exists(checkpointPath)
} catch {
case e: IOException =>
logWarning(s"Failed to check if checkpoint location $checkpointLocation exists", e)
false
}
}
}
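For example, a pre-flight check against a candidate checkpoint directory could look like the sketch below (the S3 path is made up; any filesystem supported by CheckpointFileManager works):

    import org.apache.spark.sql.SparkSession
    import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh

    // Hypothetical pre-flight check before creating an index with a checkpoint location.
    val spark = SparkSession.builder().appName("checkpoint-preflight").getOrCreate()
    val location = "s3a://my-bucket/flint/checkpoints/my_index" // made-up path
    if (!FlintSparkIndexRefresh.isCheckpointLocationAccessible(spark, location)) {
      println(s"Checkpoint location $location doesn't exist or no permission to access")
    }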
@@ -6,6 +6,7 @@
package org.opensearch.flint.spark.refresh

import org.opensearch.flint.spark.FlintSparkIndex
import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.{isCheckpointLocationAccessible, isSourceTableNonHive}
import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.RefreshMode.{INCREMENTAL, RefreshMode}

import org.apache.spark.sql.SparkSession
@@ -24,14 +25,24 @@ class IncrementalIndexRefresh(indexName: String, index: FlintSparkIndex)

override def refreshMode: RefreshMode = INCREMENTAL

override def validate(spark: SparkSession): Unit = {
// Non-Hive table is required for incremental refresh
require(
isSourceTableNonHive(spark, index),
"Flint index incremental refresh doesn't support Hive table")

// Checkpoint location is required regardless of mandatory option
val options = index.options
val checkpointLocation = options.checkpointLocation()
require(
options.checkpointLocation().nonEmpty,
"Checkpoint location is required by incremental refresh")
require(
isCheckpointLocationAccessible(spark, checkpointLocation.get),
s"Checkpoint location ${checkpointLocation.get} doesn't exist or no permission to access")
}

override def start(spark: SparkSession, flintSparkConf: FlintSparkConf): Option[String] = {
logInfo(s"Start refreshing index $indexName in incremental mode")

// TODO: move this to validation method together in future
if (index.options.checkpointLocation().isEmpty) {
throw new IllegalStateException("Checkpoint location is required by incremental refresh")
}

// Reuse auto refresh which uses AvailableNow trigger and will stop once complete
val jobId =
new AutoIndexRefresh(indexName, index)
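For reference, a sketch of index options that would satisfy the incremental-refresh checks above; the option keys, the S3 path, and the FlintSparkIndexOptions constructor shape are assumptions for illustration:

    import org.opensearch.flint.spark.FlintSparkIndexOptions

    // Hypothetical option map for an incrementally refreshed index; keys are assumed.
    val options = FlintSparkIndexOptions(Map(
      "incremental_refresh" -> "true",
      "checkpoint_location" -> "s3a://my-bucket/flint/checkpoints/http_logs_index"))

    assert(options.incrementalRefresh())           // passes the incremental refresh check
    assert(options.checkpointLocation().isDefined) // checkpoint location is provided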
