Implement index refresh options validation
Signed-off-by: Chen Dai <[email protected]>
dai-chen committed Mar 29, 2024
1 parent 5b1a49a commit 3380401
Showing 10 changed files with 133 additions and 54 deletions.
@@ -123,6 +123,9 @@ class FlintSpark(val spark: SparkSession) extends Logging {
})
logInfo("Create index complete")
} catch {
case e: FlintSparkException =>
logError("Failed to create Flint index", e)
throw new IllegalStateException("Failed to create Flint index: " + e.getMessage)
case e: Exception =>
logError("Failed to create Flint index", e)
throw new IllegalStateException("Failed to create Flint index")
@@ -0,0 +1,32 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark

/**
* Flint Spark exception.
*/
abstract class FlintSparkException(message: String, cause: Option[Throwable])
extends Throwable(message) {}

object FlintSparkException {

def requireValidation(requirement: Boolean, message: => Any): Unit = {
if (!requirement) {
throw new FlintSparkValidationException(message.toString)
}
}
}

/**
* Flint Spark validation exception.
*
* @param message
* error message
* @param cause
* exception causing the error
*/
class FlintSparkValidationException(message: String, cause: Option[Throwable] = None)
extends FlintSparkException(message, cause)
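
The helper above mirrors Scala's built-in `require`, but throws the domain-specific `FlintSparkValidationException` instead of `IllegalArgumentException`, which lets callers higher up distinguish validation failures from unexpected errors. A minimal usage sketch (not part of the commit; the predicate and message are illustrative):

```scala
import org.opensearch.flint.spark.FlintSparkException.requireValidation

object RequireValidationExample {

  // Illustrative check in the style of AutoIndexRefresh.validate:
  // if the predicate is false, requireValidation throws a
  // FlintSparkValidationException carrying the given message.
  def checkRefreshOptions(autoRefresh: Boolean, incrementalRefresh: Boolean): Unit = {
    requireValidation(
      !(autoRefresh && incrementalRefresh),
      "Incremental refresh cannot be enabled if auto refresh is enabled")
  }
}
```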
@@ -9,13 +9,9 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter

import org.opensearch.flint.core.metadata.FlintMetadata
import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView
import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.flint.datatype.FlintDataType
import org.apache.spark.sql.types.StructType

@@ -57,8 +53,10 @@ trait FlintSparkIndex {
* Subclasses can extend this method to include additional validation logic.
*/
def validate(spark: SparkSession): Unit = {
val refresh = FlintSparkIndexRefresh.create(name(), this)
refresh.validate(spark) // TODO: why indexName arg necessary?
// Validate whether index options are valid for refresh
FlintSparkIndexRefresh
.create(name(), this)
.validate(spark) // TODO: why indexName arg necessary?
}
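
As the scaladoc above notes, subclasses can extend `validate` with rules of their own on top of the shared refresh-option check. A hypothetical sketch (the trait name and the extra checkpoint rule are invented for illustration, not taken from the codebase):

```scala
import org.apache.spark.sql.SparkSession
import org.opensearch.flint.spark.FlintSparkIndex
import org.opensearch.flint.spark.FlintSparkException.requireValidation

trait CheckpointRequiredIndex extends FlintSparkIndex {

  override def validate(spark: SparkSession): Unit = {
    // Run the shared refresh-option validation first
    super.validate(spark)

    // Then layer on an index-specific rule (illustrative only)
    requireValidation(
      options.checkpointLocation().isDefined,
      "This index type always requires a checkpoint location")
  }
}
```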

/**
@@ -8,8 +8,8 @@ package org.opensearch.flint.spark.refresh
import java.util.Collections

import org.opensearch.flint.spark.{FlintSparkIndex, FlintSparkIndexOptions}
import org.opensearch.flint.spark.FlintSparkException.requireValidation
import org.opensearch.flint.spark.FlintSparkIndex.{quotedTableName, StreamingRefresh}
import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.{isCheckpointLocationAccessible, isSourceTableNonHive}
import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.RefreshMode.{AUTO, RefreshMode}

import org.apache.spark.sql.{DataFrame, Row, SparkSession}
@@ -33,27 +33,27 @@ class AutoIndexRefresh(indexName: String, index: FlintSparkIndex) extends FlintS
override def validate(spark: SparkSession): Unit = {
// Incremental refresh cannot be enabled at the same time
val options = index.options
require(
requireValidation(
!options.incrementalRefresh(),
"Incremental refresh cannot be enabled if auto refresh is enabled")

// Non-Hive table is required for auto refresh
require(
requireValidation(
isSourceTableNonHive(spark, index),
"Flint index auto refresh doesn't support Hive table")
"Index auto refresh doesn't support Hive table")

// Checkpoint location is required if mandatory option set
val flintSparkConf = new FlintSparkConf(Collections.emptyMap)
val flintSparkConf = new FlintSparkConf(Collections.emptyMap[String, String])
val checkpointLocation = index.options.checkpointLocation()
if (flintSparkConf.isCheckpointMandatory) {
require(
requireValidation(
checkpointLocation.isDefined,
s"Checkpoint location is required if ${CHECKPOINT_MANDATORY.key} option enabled")
}

// Given checkpoint location is accessible
if (checkpointLocation.isDefined) {
require(
requireValidation(
isCheckpointLocationAccessible(spark, checkpointLocation.get),
s"Checkpoint location ${checkpointLocation.get} doesn't exist or no permission to access")
}
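
With `require` swapped for `requireValidation`, a failed check in `AutoIndexRefresh.validate` now surfaces as a `FlintSparkValidationException`, which the new `case e: FlintSparkException` branch in `FlintSpark` (first hunk of this commit) rewraps as an `IllegalStateException` that preserves the validation message. A self-contained sketch of that flow (the hard-coded `false` stands in for a real predicate such as the Hive-table check):

```scala
import org.opensearch.flint.spark.FlintSparkException.requireValidation
import org.opensearch.flint.spark.FlintSparkValidationException

object ValidationFlowExample extends App {
  try {
    // Stand-in for a real predicate such as isSourceTableNonHive(spark, index)
    requireValidation(false, "Index auto refresh doesn't support Hive table")
  } catch {
    case e: FlintSparkValidationException =>
      // FlintSpark catches FlintSparkException and rethrows it as
      // IllegalStateException("Failed to create Flint index: " + e.getMessage)
      println(s"Failed to create Flint index: ${e.getMessage}")
  }
}
```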
@@ -50,36 +50,6 @@ trait FlintSparkIndexRefresh extends Logging {
* optional Spark job ID
*/
def start(spark: SparkSession, flintSparkConf: FlintSparkConf): Option[String]
}

object FlintSparkIndexRefresh extends Logging {

/** Index refresh mode */
object RefreshMode extends Enumeration {
type RefreshMode = Value
val AUTO, FULL, INCREMENTAL = Value
}

/**
* Create concrete index refresh implementation for the given index.
*
* @param indexName
* Flint index name
* @param index
* Flint index
* @return
* index refresh
*/
def create(indexName: String, index: FlintSparkIndex): FlintSparkIndexRefresh = {
val options = index.options
if (options.autoRefresh()) {
new AutoIndexRefresh(indexName, index)
} else if (options.incrementalRefresh()) {
new IncrementalIndexRefresh(indexName, index)
} else {
new FullIndexRefresh(indexName, index)
}
}

/**
* Validate that the source table(s) of the given Flint index are not Hive tables.
@@ -91,7 +61,7 @@ object FlintSparkIndexRefresh extends Logging {
* @return
* true if all non Hive, otherwise false
*/
def isSourceTableNonHive(spark: SparkSession, index: FlintSparkIndex): Boolean = {
protected def isSourceTableNonHive(spark: SparkSession, index: FlintSparkIndex): Boolean = {
// Extract source table name (possibly more than 1 for MV source query)
val tableNames = index match {
case skipping: FlintSparkSkippingIndex => Seq(skipping.tableName)
@@ -123,7 +93,9 @@ object FlintSparkIndexRefresh extends Logging {
* @return
* true if accessible, otherwise false
*/
def isCheckpointLocationAccessible(spark: SparkSession, checkpointLocation: String): Boolean = {
protected def isCheckpointLocationAccessible(
spark: SparkSession,
checkpointLocation: String): Boolean = {
val checkpointPath = new Path(checkpointLocation)
val checkpointManager =
CheckpointFileManager.create(checkpointPath, spark.sessionState.newHadoopConf())
@@ -140,3 +112,33 @@ }
}
}
}

object FlintSparkIndexRefresh {

/** Index refresh mode */
object RefreshMode extends Enumeration {
type RefreshMode = Value
val AUTO, FULL, INCREMENTAL = Value
}

/**
* Create concrete index refresh implementation for the given index.
*
* @param indexName
* Flint index name
* @param index
* Flint index
* @return
* index refresh
*/
def create(indexName: String, index: FlintSparkIndex): FlintSparkIndexRefresh = {
val options = index.options
if (options.autoRefresh()) {
new AutoIndexRefresh(indexName, index)
} else if (options.incrementalRefresh()) {
new IncrementalIndexRefresh(indexName, index)
} else {
new FullIndexRefresh(indexName, index)
}
}
}
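
The companion object moves below the trait with its behavior unchanged: `create` still picks the refresher from the index options, with auto refresh taking precedence over incremental refresh and full refresh as the fallback. A small self-contained sketch of that selection order (`selectMode` is an illustrative helper, not part of the codebase):

```scala
import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.RefreshMode
import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.RefreshMode.{AUTO, FULL, INCREMENTAL}

object RefreshModeSelectionExample extends App {

  // Mirrors the precedence in FlintSparkIndexRefresh.create:
  // auto_refresh wins over incremental_refresh; full refresh is the default.
  def selectMode(autoRefresh: Boolean, incrementalRefresh: Boolean): RefreshMode.RefreshMode =
    if (autoRefresh) AUTO
    else if (incrementalRefresh) INCREMENTAL
    else FULL

  println(selectMode(autoRefresh = true, incrementalRefresh = false)) // AUTO
  println(selectMode(autoRefresh = false, incrementalRefresh = true)) // INCREMENTAL
  println(selectMode(autoRefresh = false, incrementalRefresh = false)) // FULL
}
```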
@@ -5,8 +5,8 @@

package org.opensearch.flint.spark.refresh

import org.opensearch.flint.spark.FlintSparkException.requireValidation
import org.opensearch.flint.spark.FlintSparkIndex
import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.{isCheckpointLocationAccessible, isSourceTableNonHive}
import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.RefreshMode.{INCREMENTAL, RefreshMode}

import org.apache.spark.sql.SparkSession
@@ -27,17 +27,17 @@ class IncrementalIndexRefresh(indexName: String, index: FlintSparkIndex)

override def validate(spark: SparkSession): Unit = {
// Non-Hive table is required for incremental refresh
require(
!isSourceTableNonHive(spark, index),
"Flint index incremental refresh doesn't support Hive table")
requireValidation(
isSourceTableNonHive(spark, index),
"Index incremental refresh doesn't support Hive table")

// Checkpoint location is required regardless of mandatory option
val options = index.options
val checkpointLocation = options.checkpointLocation()
require(
requireValidation(
options.checkpointLocation().nonEmpty,
"Checkpoint location is required by incremental refresh")
require(
requireValidation(
isCheckpointLocationAccessible(spark, checkpointLocation.get),
s"Checkpoint location ${checkpointLocation.get} doesn't exist or no permission to access")
}
@@ -7,12 +7,13 @@ package org.apache.spark

import org.opensearch.flint.spark.FlintSparkExtensions

import org.apache.spark.sql.{FlintTestSparkSession, SparkSession}
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
import org.apache.spark.sql.flint.config.FlintConfigEntry
import org.apache.spark.sql.flint.config.FlintSparkConf.HYBRID_SCAN_ENABLED
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
import org.apache.spark.sql.test.{SharedSparkSession, TestSparkSession}

trait FlintSuite extends SharedSparkSession {
override protected def sparkConf = {
@@ -26,9 +27,15 @@ trait FlintSuite extends SharedSparkSession {
// ConstantPropagation etc.
.set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName)
.set("spark.sql.extensions", classOf[FlintSparkExtensions].getName)
.set(StaticSQLConf.CATALOG_IMPLEMENTATION.key, "hive")
conf
}

override protected def createSparkSession: TestSparkSession = {
SparkSession.cleanupAnyExistingSession()
new FlintTestSparkSession(sparkConf)
}

/**
* Set Flint Spark configuration. (Generic "value: T" has problem with FlintConfigEntry[Any])
*/
@@ -0,0 +1,18 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.hive.HiveSessionStateBuilder
import org.apache.spark.sql.internal.SessionState
import org.apache.spark.sql.test.TestSparkSession

class FlintTestSparkSession(sparkConf: SparkConf) extends TestSparkSession(sparkConf) {

override lazy val sessionState: SessionState = {
new HiveSessionStateBuilder(this, None).build()
}
}
@@ -15,7 +15,7 @@ import org.json4s.native.Serialization
import org.opensearch.flint.core.FlintOptions
import org.opensearch.flint.core.storage.FlintOpenSearchClient
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName
import org.scalatest.matchers.must.Matchers.defined
import org.scalatest.matchers.must.Matchers.{defined, have}
import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, the}

import org.apache.spark.sql.Row
@@ -237,6 +237,24 @@ class FlintSparkSkippingIndexSqlITSuite extends FlintSparkSuite {
}
}

test("should fail if create auto refresh skipping index on Hive table") {
val hiveTableName = "spark_catalog.default.hive_table"
withTable(hiveTableName) {
sql(s"""
| CREATE TABLE $hiveTableName
| ( name STRING )
|""".stripMargin)

the[IllegalStateException] thrownBy {
sql(s"""
| CREATE SKIPPING INDEX ON $hiveTableName
| ( name VALUE_SET )
| WITH (auto_refresh = true)
| """.stripMargin)
} should have message "Failed to create Flint index: Index auto refresh doesn't support Hive table"
}
}

test("should fail if refresh an auto refresh skipping index") {
sql(s"""
| CREATE SKIPPING INDEX ON $testTable
1 change: 1 addition & 0 deletions project/Dependencies.scala
@@ -12,6 +12,7 @@ object Dependencies {
"org.apache.spark" %% "spark-core" % sparkVersion % "provided" withSources (),
"org.apache.spark" %% "spark-sql" % sparkVersion % "provided" withSources (),
"org.json4s" %% "json4s-native" % "3.7.0-M5" % "test",
"org.apache.spark" %% "spark-hive" % sparkVersion % "test",
"org.apache.spark" %% "spark-catalyst" % sparkVersion % "test" classifier "tests",
"org.apache.spark" %% "spark-core" % sparkVersion % "test" classifier "tests",
"org.apache.spark" %% "spark-sql" % sparkVersion % "test" classifier "tests")
