Add covering index based query rewriter rule (#318)
* Add query rewriting rule for covering index

Signed-off-by: Chen Dai <[email protected]>
dai-chen authored Apr 30, 2024
1 parent c877e09 commit a8a376f
Showing 8 changed files with 457 additions and 30 deletions.
1 change: 1 addition & 0 deletions build.sbt
@@ -134,6 +134,7 @@ lazy val flintSparkIntegration = (project in file("flint-spark-integration"))
"org.scalatest" %% "scalatest" % "3.2.15" % "test",
"org.scalatest" %% "scalatest-flatspec" % "3.2.15" % "test",
"org.scalatestplus" %% "mockito-4-6" % "3.2.15.0" % "test",
"org.mockito" % "mockito-inline" % "4.6.0" % "test",
"com.stephenn" %% "scalatest-json-jsonassert" % "0.2.5" % "test",
"com.github.sbt" % "junit-interface" % "0.13.3" % "test"),
libraryDependencies ++= deps(sparkVersion),
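The new `mockito-inline` test dependency is typically added so tests can mock classes that Mockito's default mock maker cannot handle (for example final classes). A hedged sketch of the kind of test it enables; the suite name, the stubbed call, and the index pattern string below are illustrative, not taken from this commit:

```scala
import org.mockito.Mockito.{mock, when}
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

import org.opensearch.flint.spark.FlintSpark

// Illustrative test only: it just demonstrates stubbing the FlintSpark API,
// which mockito-inline makes possible even for types the default mock maker rejects.
class CoveringIndexRuleMockSketch extends AnyFlatSpec with Matchers {

  "describeIndexes" should "be stubbable on a mocked FlintSpark" in {
    val flint = mock(classOf[FlintSpark])
    when(flint.describeIndexes("flint_*")).thenReturn(Seq.empty)

    flint.describeIndexes("flint_*") shouldBe empty
  }
}
```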
3 changes: 2 additions & 1 deletion docs/index.md
@@ -514,7 +514,8 @@ In the index mapping, the `_meta` and `properties`field stores meta and schema i
- `spark.datasource.flint.retry.max_retries`: max retries on failed HTTP request. default value is 3. Use 0 to disable retry.
- `spark.datasource.flint.retry.http_status_codes`: retryable HTTP response status code list. default value is "429,502" (429 Too Many Requests and 502 Bad Gateway).
- `spark.datasource.flint.retry.exception_class_names`: retryable exception class name list. by default no retry on any exception thrown.
- `spark.flint.optimizer.enabled`: default is true.
- `spark.flint.optimizer.enabled`: default is true. Enables the Flint optimizer to improve query performance.
- `spark.flint.optimizer.covering.enabled`: default is true. Enables the Flint covering index optimizer to improve query performance.
- `spark.flint.index.hybridscan.enabled`: default is false.
- `spark.flint.index.checkpoint.mandatory`: default is true.
- `spark.datasource.flint.socket_timeout_millis`: default value is 60000.
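Both optimizer flags can be toggled like any other Spark configuration. A minimal sketch, assuming a running `SparkSession` named `spark` with the Flint extension loaded, that keeps the overall optimizer on but opts out of the new covering index rewrite:

```scala
// Assumes an existing SparkSession `spark`; the keys are the Flint options
// documented above, set here at runtime instead of via --conf at submit time.
spark.conf.set("spark.flint.optimizer.enabled", "true")           // overall Flint optimizer (default: true)
spark.conf.set("spark.flint.optimizer.covering.enabled", "false") // disable only the covering index rewrite
```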
@@ -133,6 +133,11 @@ object FlintSparkConf {
.doc("Enable Flint optimizer rule for query rewrite with Flint index")
.createWithDefault("true")

val OPTIMIZER_RULE_COVERING_INDEX_ENABLED =
FlintConfig("spark.flint.optimizer.covering.enabled")
.doc("Enable Flint optimizer rule for query rewrite with Flint covering index")
.createWithDefault("true")

val HYBRID_SCAN_ENABLED = FlintConfig("spark.flint.index.hybridscan.enabled")
.doc("Enable hybrid scan to include latest source data not refreshed to index yet")
.createWithDefault("false")
@@ -200,6 +205,9 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable

def isOptimizerEnabled: Boolean = OPTIMIZER_RULE_ENABLED.readFrom(reader).toBoolean

def isCoveringIndexOptimizerEnabled: Boolean =
OPTIMIZER_RULE_COVERING_INDEX_ENABLED.readFrom(reader).toBoolean

def isHybridScanEnabled: Boolean = HYBRID_SCAN_ENABLED.readFrom(reader).toBoolean

def isCheckpointMandatory: Boolean = CHECKPOINT_MANDATORY.readFrom(reader).toBoolean
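Each `FlintConfig` declared above is resolved from the Spark configuration via `readFrom(reader)` and falls back to its string default. For the new flag, the net effect is roughly equivalent to the sketch below (not the actual `FlintConfig` implementation):

```scala
import org.apache.spark.sql.SparkSession

// Rough equivalent of isCoveringIndexOptimizerEnabled: read the key from the
// session configuration and fall back to the declared default of "true".
def coveringIndexRewriteEnabled(spark: SparkSession): Boolean =
  spark.conf.get("spark.flint.optimizer.covering.enabled", "true").toBoolean
```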
@@ -5,6 +5,7 @@

package org.opensearch.flint.spark

import org.opensearch.flint.spark.covering.ApplyFlintSparkCoveringIndex
import org.opensearch.flint.spark.skipping.ApplyFlintSparkSkippingIndex

import org.apache.spark.sql.SparkSession
@@ -22,18 +23,30 @@ class FlintSparkOptimizer(spark: SparkSession) extends Rule[LogicalPlan] {
/** Flint Spark API */
private val flint: FlintSpark = new FlintSpark(spark)

/** Only one Flint optimizer rule for now. Need to estimate cost if more than one in future. */
private val rule = new ApplyFlintSparkSkippingIndex(flint)
/** Skipping index rewrite rule */
private val skippingIndexRule = new ApplyFlintSparkSkippingIndex(flint)

/** Covering index rewrite rule */
private val coveringIndexRule = new ApplyFlintSparkCoveringIndex(flint)

override def apply(plan: LogicalPlan): LogicalPlan = {
if (isOptimizerEnabled) {
rule.apply(plan)
if (isFlintOptimizerEnabled) {
if (isCoveringIndexOptimizerEnabled) {
// Apply covering index rule first
skippingIndexRule.apply(coveringIndexRule.apply(plan))
} else {
skippingIndexRule.apply(plan)
}
} else {
plan
}
}

private def isOptimizerEnabled: Boolean = {
private def isFlintOptimizerEnabled: Boolean = {
FlintSparkConf().isOptimizerEnabled
}

private def isCoveringIndexOptimizerEnabled: Boolean = {
FlintSparkConf().isCoveringIndexOptimizerEnabled
}
}
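`FlintSparkOptimizer` remains a single Catalyst `Rule[LogicalPlan]`: when both flags are enabled, the covering index rewrite runs on the plan first and the skipping index rule then runs on its output. For context, a hedged sketch of how such a rule is typically registered through Spark session extensions; the extension class name here is illustrative, not the one in this repository:

```scala
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.opensearch.flint.spark.FlintSparkOptimizer

// Illustrative wiring only: inject the combined Flint optimizer rule
// (covering index rewrite followed by skipping index rewrite) into Spark.
class FlintOptimizerExtensionSketch extends (SparkSessionExtensions => Unit) {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    extensions.injectOptimizerRule { spark: SparkSession =>
      new FlintSparkOptimizer(spark)
    }
  }
}
```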
@@ -0,0 +1,107 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.covering

import java.util

import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState.DELETED
import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndex}
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.getFlintIndexName

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, V2WriteCommand}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.flint.{qualifyTableName, FlintDataSourceV2}
import org.apache.spark.sql.util.CaseInsensitiveStringMap

/**
* Flint Spark covering index apply rule that replaces an applicable query's table scan operator
* with a covering index scan to accelerate the query.
*
* @param flint
* Flint Spark API
*/
class ApplyFlintSparkCoveringIndex(flint: FlintSpark) extends Rule[LogicalPlan] {

override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case relation @ LogicalRelation(_, _, Some(table), false)
if !plan.isInstanceOf[V2WriteCommand] => // TODO: make sure only SELECT queries are intercepted
val relationCols = collectRelationColumnsInQueryPlan(relation, plan)

// Choose the first covering index that meets all applicability criteria
findAllCoveringIndexesOnTable(table.qualifiedName)
.sortBy(_.name())
.collectFirst {
case index: FlintSparkCoveringIndex if isCoveringIndexApplicable(index, relationCols) =>
replaceTableRelationWithIndexRelation(index, relation)
}
.getOrElse(relation) // If no index found, return the original relation
}

private def collectRelationColumnsInQueryPlan(
relation: LogicalRelation,
plan: LogicalPlan): Set[String] = {
/*
* Collect all columns of the relation that appear in the query plan, excluding the relation itself.
* Because this rule runs before push-down optimization, the relation still outputs all columns.
*/
val relationColsById = relation.output.map(attr => (attr.exprId, attr)).toMap
plan
.collect {
case _: LogicalRelation => Set.empty
case other =>
other.expressions
.flatMap(_.references)
.flatMap(ref =>
relationColsById.get(ref.exprId)) // Ignore attributes that don't belong to the target relation
.map(attr => attr.name)
}
.flatten
.toSet
}

private def findAllCoveringIndexesOnTable(tableName: String): Seq[FlintSparkIndex] = {
val qualifiedTableName = qualifyTableName(flint.spark, tableName)
val indexPattern = getFlintIndexName("*", qualifiedTableName)
flint.describeIndexes(indexPattern)
}

private def isCoveringIndexApplicable(
index: FlintSparkCoveringIndex,
relationCols: Set[String]): Boolean = {
index.latestLogEntry.exists(_.state != DELETED) &&
index.filterCondition.isEmpty && // TODO: support partial covering index later
relationCols.subsetOf(index.indexedColumns.keySet)
}

private def replaceTableRelationWithIndexRelation(
index: FlintSparkCoveringIndex,
relation: LogicalRelation): LogicalPlan = {
// Make use of data source relation to avoid Spark looking for OpenSearch index in catalog
val ds = new FlintDataSourceV2
val options = new CaseInsensitiveStringMap(util.Map.of("path", index.name()))
val inferredSchema = ds.inferSchema(options)
val flintTable = ds.getTable(inferredSchema, Array.empty, options)

// Reuse original attribute's exprId because it's already analyzed and referenced
// by the other parts of the query plan.
val allRelationCols = relation.output.map(attr => (attr.name, attr)).toMap
val outputAttributes =
flintTable
.schema()
.map(field => {
val relationCol = allRelationCols(field.name) // index column must exist in relation
AttributeReference(field.name, field.dataType, field.nullable, field.metadata)(
relationCol.exprId,
relationCol.qualifier)
})

// Create the DataSourceV2 scan with corrected attributes
DataSourceV2Relation(flintTable, outputAttributes, None, None, options)
}
}
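To see the rule end to end: if a covering index exists on the queried table, indexes every column the query references, has no filter condition, and is not in DELETED state, the table relation is swapped for a `DataSourceV2Relation` over the Flint index, reusing the original attributes' `exprId`s so the rest of the analyzed plan stays valid. A hedged usage sketch, assuming a `SparkSession` named `spark` with Flint loaded; the table and column names are illustrative:

```scala
// Illustrative only: table and columns are made up. With
// spark.flint.optimizer.covering.enabled=true (the default) and a covering
// index over (clientip, status) on this table, the optimized plan should show
// a DataSourceV2Relation over the Flint index instead of the source table scan.
val df = spark.sql(
  """
    |SELECT clientip, status
    |FROM spark_catalog.default.http_logs
    |WHERE status = 404
    |""".stripMargin)

df.explain(true) // inspect the optimized plan for the rewritten relation
```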