
Improve pre-validation for Flint index refresh options #297

Merged: 12 commits, Apr 17, 2024
@@ -407,9 +407,6 @@ class FlintSpark(val spark: SparkSession) extends Logging {
case (true, false) => AUTO
case (false, false) => FULL
case (false, true) => INCREMENTAL
case (true, true) =>
throw new IllegalArgumentException(
"auto_refresh and incremental_refresh options cannot both be true")
}

// validate allowed options depending on refresh mode
@@ -8,6 +8,7 @@ package org.opensearch.flint.spark
import scala.collection.JavaConverters.mapAsJavaMapConverter

import org.opensearch.flint.spark.FlintSparkIndexOptions.empty
import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh

import org.apache.spark.sql.catalog.Column
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
@@ -59,7 +60,7 @@ abstract class FlintSparkIndexBuilder(flint: FlintSpark) {
* ignore existing index
*/
def create(ignoreIfExists: Boolean = false): Unit =
flint.createIndex(buildIndex(), ignoreIfExists)
flint.createIndex(validateIndex(buildIndex()), ignoreIfExists)

/**
* Copy Flint index with updated options.
@@ -80,7 +81,24 @@ abstract class FlintSparkIndexBuilder(flint: FlintSpark) {
val updatedMetadata = index
.metadata()
.copy(options = updatedOptions.options.mapValues(_.asInstanceOf[AnyRef]).asJava)
FlintSparkIndexFactory.create(updatedMetadata).get
validateIndex(FlintSparkIndexFactory.create(updatedMetadata).get)
}

/**
* Pre-validates the index to ensure its validity. By default, this method validates index options
* by delegating to the specific index refresh (index options mostly serve index refresh).
* Subclasses can extend this method to include additional validation logic.
*
* @param index
* Flint index to be validated
* @return
the validated index; an exception is thrown if validation fails
*/
protected def validateIndex(index: FlintSparkIndex): FlintSparkIndex = {
FlintSparkIndexRefresh
.create(index.name(), index) // TODO: remove first argument?
.validate(flint.spark)
index
}
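
As an aside, not part of this diff: a concrete builder could layer extra checks on top of the default behavior by overriding this hook; the extra check below is purely hypothetical:

// Hypothetical override in a subclass of FlintSparkIndexBuilder: run an extra
// sanity check, then delegate to the default refresh-option validation.
override protected def validateIndex(index: FlintSparkIndex): FlintSparkIndex = {
  require(index.metadata().source != null, "source table must be defined")
  super.validateIndex(index)
}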

/**
@@ -0,0 +1,86 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark

import java.io.IOException

import org.apache.hadoop.fs.Path
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.streaming.CheckpointFileManager
import org.apache.spark.sql.flint.{loadTable, parseTableName, qualifyTableName}

/**
* Flint Spark validation helper.
*/
trait FlintSparkValidationHelper extends Logging {

/**
* Determines whether any source table of the given Flint index uses an unsupported provider (currently only Hive is unsupported).
*
* @param spark
* Spark session
* @param index
* Flint index
* @return
true if any source table is unsupported (currently Hive), otherwise false
*/
def isTableProviderSupported(spark: SparkSession, index: FlintSparkIndex): Boolean = {
// Extract source table name (possibly more than one for MV query)
val tableNames = index match {
case skipping: FlintSparkSkippingIndex => Seq(skipping.tableName)
case covering: FlintSparkCoveringIndex => Seq(covering.tableName)
case mv: FlintSparkMaterializedView =>
spark.sessionState.sqlParser
.parsePlan(mv.query)
.collect { case relation: UnresolvedRelation =>
qualifyTableName(spark, relation.tableName)
}
}

// Validate if any source table is not supported (currently Hive only)
tableNames.exists { tableName =>
val (catalog, ident) = parseTableName(spark, tableName)
val table = loadTable(catalog, ident).get

// TODO: add allowed table provider list
DDLUtils.isHiveTable(Option(table.properties().get("provider")))
}
}

/**
* Checks whether a specified checkpoint location is accessible. Accessibility, in this context,
* means that the folder exists and the current Spark session has the necessary permissions to
* access it.
*
* @param spark
* Spark session
* @param checkpointLocation
* checkpoint location
* @return
* true if accessible, otherwise false
*/
def isCheckpointLocationAccessible(spark: SparkSession, checkpointLocation: String): Boolean = {
try {
val checkpointManager =
CheckpointFileManager.create(
new Path(checkpointLocation),
spark.sessionState.newHadoopConf())

checkpointManager.exists(new Path(checkpointLocation))
} catch {
case e: IOException =>
logWarning(s"Failed to check if checkpoint location $checkpointLocation exists", e)
false
}
}
}
@@ -5,7 +5,9 @@

package org.opensearch.flint.spark.refresh

import org.opensearch.flint.spark.{FlintSparkIndex, FlintSparkIndexOptions}
import java.util.Collections

import org.opensearch.flint.spark.{FlintSparkIndex, FlintSparkIndexOptions, FlintSparkValidationHelper}
import org.opensearch.flint.spark.FlintSparkIndex.{quotedTableName, StreamingRefresh}
import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.RefreshMode.{AUTO, RefreshMode}

@@ -23,10 +25,41 @@ import org.apache.spark.sql.streaming.{DataStreamWriter, Trigger}
* @param index
* Flint index
*/
class AutoIndexRefresh(indexName: String, index: FlintSparkIndex) extends FlintSparkIndexRefresh {
class AutoIndexRefresh(indexName: String, index: FlintSparkIndex)
extends FlintSparkIndexRefresh
with FlintSparkValidationHelper {

override def refreshMode: RefreshMode = AUTO

override def validate(spark: SparkSession): Unit = {
// Incremental refresh cannot be enabled at the same time
val options = index.options
require(
!options.incrementalRefresh(),
"Incremental refresh cannot be enabled if auto refresh is enabled")

// Hive table doesn't support auto refresh
require(
!isTableProviderSupported(spark, index),
"Index auto refresh doesn't support Hive table")

// Checkpoint location is required if mandatory option set
val flintSparkConf = new FlintSparkConf(Collections.emptyMap[String, String])
val checkpointLocation = options.checkpointLocation()
if (flintSparkConf.isCheckpointMandatory) {
require(
checkpointLocation.isDefined,
s"Checkpoint location is required if ${CHECKPOINT_MANDATORY.key} option enabled")
}

// Checkpoint location must be accessible
if (checkpointLocation.isDefined) {
require(
isCheckpointLocationAccessible(spark, checkpointLocation.get),
s"Checkpoint location ${checkpointLocation.get} doesn't exist or no permission to access")
}
}
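
For illustration only (not part of this diff, and assuming a session with the Flint SQL extension enabled), an option combination like the one below is now rejected up front by this validate method instead of surfacing later at refresh time; the table and index names are made up:

// Hypothetical example: enabling both flags is now caught at index creation time.
spark.sql("""
  | CREATE INDEX sample_idx ON sample_table (name, age)
  | WITH (auto_refresh = true, incremental_refresh = true)
  |""".stripMargin)
// => IllegalArgumentException: Incremental refresh cannot be enabled if auto refresh is enabled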

override def start(spark: SparkSession, flintSparkConf: FlintSparkConf): Option[String] = {
val options = index.options
val tableName = index.metadata().source
@@ -24,6 +24,20 @@ trait FlintSparkIndexRefresh extends Logging {
*/
def refreshMode: RefreshMode

/**
* Validates the current index refresh settings before the actual execution begins. This method
* checks for the integrity of the index refresh configurations and ensures that all options set
* for the current refresh mode are valid. This preemptive validation helps in identifying
* configuration issues before the refresh operation is initiated, minimizing runtime errors and
* potential inconsistencies.
*
* @param spark
* Spark session
* @throws IllegalArgumentException
if any invalid or inapplicable config is identified
*/
def validate(spark: SparkSession): Unit
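
For context, the builder change earlier in this diff calls this method before persisting an index; a condensed restatement of that flow (no new API assumed):

// Condensed restatement of the pre-validation flow in FlintSparkIndexBuilder:
// resolve the refresh strategy for the index options, validate, then persist.
FlintSparkIndexRefresh
  .create(index.name(), index)
  .validate(flint.spark) // throws IllegalArgumentException on invalid config
flint.createIndex(index, ignoreIfExists)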

/**
* Start refreshing the index.
*
@@ -31,6 +31,11 @@ class FullIndexRefresh(

override def refreshMode: RefreshMode = FULL

override def validate(spark: SparkSession): Unit = {
// Full refresh validates nothing for now, including Hive table validation.
// This allows users to continue using their existing Hive table with full refresh only.
}

override def start(spark: SparkSession, flintSparkConf: FlintSparkConf): Option[String] = {
logInfo(s"Start refreshing index $indexName in full mode")
index
@@ -5,7 +5,7 @@

package org.opensearch.flint.spark.refresh

import org.opensearch.flint.spark.FlintSparkIndex
import org.opensearch.flint.spark.{FlintSparkIndex, FlintSparkValidationHelper}
import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.RefreshMode.{INCREMENTAL, RefreshMode}

import org.apache.spark.sql.SparkSession
@@ -20,18 +20,31 @@ import org.apache.spark.sql.flint.config.FlintSparkConf
* Flint index
*/
class IncrementalIndexRefresh(indexName: String, index: FlintSparkIndex)
extends FlintSparkIndexRefresh {
extends FlintSparkIndexRefresh
with FlintSparkValidationHelper {

override def refreshMode: RefreshMode = INCREMENTAL

override def validate(spark: SparkSession): Unit = {
// Non-Hive table is required for incremental refresh
require(
!isTableProviderSupported(spark, index),
"Index incremental refresh doesn't support Hive table")

// Checkpoint location is required regardless of mandatory option
val options = index.options
val checkpointLocation = options.checkpointLocation()
require(
options.checkpointLocation().nonEmpty,
"Checkpoint location is required by incremental refresh")
require(
isCheckpointLocationAccessible(spark, checkpointLocation.get),
s"Checkpoint location ${checkpointLocation.get} doesn't exist or no permission to access")
}

override def start(spark: SparkSession, flintSparkConf: FlintSparkConf): Option[String] = {
logInfo(s"Start refreshing index $indexName in incremental mode")

// TODO: move this to validation method together in future
if (index.options.checkpointLocation().isEmpty) {
throw new IllegalStateException("Checkpoint location is required by incremental refresh")
}

// Reuse auto refresh which uses AvailableNow trigger and will stop once complete
val jobId =
new AutoIndexRefresh(indexName, index)
@@ -0,0 +1,44 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.hive.HiveSessionStateBuilder
import org.apache.spark.sql.internal.{SessionState, StaticSQLConf}
import org.apache.spark.sql.test.{SharedSparkSession, TestSparkSession}

/**
* Flint Spark base suite with Hive support enabled. Enabling Hive support in the Spark
* configuration alone is not adequate, because [[TestSparkSession]] disregards it and always
* creates its own instance of [[org.apache.spark.sql.test.TestSQLSessionStateBuilder]]. We
* therefore need to override its session state with Hive's in the meantime.
*
* Note that we need to extend [[SharedSparkSession]] to call super.sparkConf() method.
*/
trait SparkHiveSupportSuite extends SharedSparkSession {

override protected def sparkConf: SparkConf = {
super.sparkConf
// Enable Hive support
.set(StaticSQLConf.CATALOG_IMPLEMENTATION.key, "hive")
// Use in-memory Derby as Hive metastore so no need to clean up metastore_db folder after test
.set("javax.jdo.option.ConnectionURL", "jdbc:derby:memory:metastore_db;create=true")
.set("hive.metastore.uris", "")
}

override protected def createSparkSession: TestSparkSession = {
SparkSession.cleanupAnyExistingSession()
new FlintTestSparkSession(sparkConf)
}

class FlintTestSparkSession(sparkConf: SparkConf) extends TestSparkSession(sparkConf) { self =>

override lazy val sessionState: SessionState = {
// Override to replace [[TestSQLSessionStateBuilder]] with Hive session state
new HiveSessionStateBuilder(spark, None).build()
}
}
}
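
As a rough sketch of how this trait might be exercised, assuming FlintSparkSuite and SparkHiveSupportSuite compose cleanly; the suite name, table, and SQL below are illustrative only and not part of this PR:

// Hypothetical integration test built on the Hive-enabled session.
class FlintSparkHiveTableITSuite extends FlintSparkSuite with SparkHiveSupportSuite {

  test("auto refresh on a Hive table should be rejected") {
    sql("CREATE TABLE sample_hive_tbl (name STRING, age INT) USING HIVE")
    the[IllegalArgumentException] thrownBy {
      sql("CREATE INDEX sample_idx ON sample_hive_tbl (name, age) WITH (auto_refresh = true)")
    }
  }
}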
@@ -125,7 +125,7 @@ class FlintSparkCoveringIndexSqlITSuite extends FlintSparkSuite {
test("create skipping index with auto refresh should fail if mandatory checkpoint enabled") {
setFlintSparkConf(CHECKPOINT_MANDATORY, "true")
try {
the[IllegalStateException] thrownBy {
the[IllegalArgumentException] thrownBy {
sql(s"""
| CREATE INDEX $testIndex ON $testTable
| (name, age)