Skip to content

Commit

Permalink
Fix bound comparison bug and refactor UT
Browse files Browse the repository at this point in the history
Signed-off-by: Chen Dai <[email protected]>
  • Loading branch information
dai-chen committed Jul 3, 2024
1 parent 22b236f commit b5c958a
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,87 +5,164 @@

package org.opensearch.flint.spark

import org.apache.spark.sql.catalyst.expressions.{And, Attribute, BinaryComparison, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal}
import org.apache.spark.sql.catalyst.expressions.{And, Attribute, BinaryComparison, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Or}

/**
* Query rewrite helper that provides common utilities for query rewrite rule of various Flint
* indexes.
*/
trait FlintSparkQueryRewriteHelper {

/**
* Determines if the given filter expression consists solely of AND operations and no OR
* operations, implying that it's a conjunction of conditions.
*
* @param filter
* The filter expression to check.
* @return
* True if the filter contains only AND operations, False if any OR operations are found.
*/
def isConjunction(filter: Expression): Boolean = {
filter.collectFirst { case Or(_, _) =>
true
}.isEmpty
}

/**
* Determines if the conditions in an index filter can subsume those in a query filter. This is
* essential to verify if all outputs that satisfy the index filter also satisfy the query
* filter, indicating that the index can potentially optimize the query.
*
* @param indexFilter
* The filter expression defined from the index
* The filter expression defined from the index, required to be a conjunction.
* @param queryFilter
* The filter expression present in the user query
* The filter expression present in the user query, required to be a conjunction.
* @return
* True if the index filter can subsume the query filter, otherwise False
* True if the index filter can subsume the query filter, otherwise False.
*/
def subsume(indexFilter: Expression, queryFilter: Expression): Boolean = {
require(isConjunction(indexFilter), "Index filter is not a conjunction")
require(isConjunction(queryFilter), "Query filter is not a conjunction")

// Flatten a potentially nested conjunction into a sequence of individual conditions
def flattenConditions(filter: Expression): Seq[Expression] = filter match {
case And(left, right) => flattenConditions(left) ++ flattenConditions(right)
case other => Seq(other)
}

val indexConditions = flattenConditions(indexFilter)
val queryConditions = flattenConditions(queryFilter)

// Each index condition must subsume in a query condition
// otherwise it means index data cannot "cover" query condition
// Ensures that every condition in the index filter is subsumed by at least one condition
// in the query filter
indexConditions.forall { indexCondition =>
queryConditions.exists { queryCondition =>
(indexCondition, queryCondition) match {
case (
i @ BinaryComparison(indexCol: Attribute, _),
q @ BinaryComparison(queryCol: Attribute, _)) if indexCol.name == queryCol.name =>
Range(i).subsume(Range(q))
indexComparison @ BinaryComparison(indexCol: Attribute, _),
queryComparison @ BinaryComparison(queryCol: Attribute, _))
if indexCol.name == queryCol.name =>
Range(indexComparison).subsume(Range(queryComparison))
case _ => false
}
}
}
}

case class Bound(value: Literal, inclusive: Boolean) {

def lessThanOrEqualTo(other: Bound): Boolean = {
val cmp = value.value.asInstanceOf[Comparable[Any]].compareTo(other.value.value)
cmp < 0 || (cmp == 0 && inclusive && other.inclusive)
}
}

/**
* Represents a range with optional lower and upper bounds.
*
* @param lower
* The optional lower bound
* @param upper
* The optional upper bound
*/
case class Range(lower: Option[Bound], upper: Option[Bound]) {

/**
* Determines if this range subsumes (completely covers) another range. A range is considered
* to subsume another if its lower bound is less restrictive and its upper bound is more
* restrictive than those of the other range.
*
* @param other
* The other range to compare against.
* @return
* True if this range subsumes the other, otherwise false.
*/
def subsume(other: Range): Boolean = {
val isLowerSubsumed = (lower, other.lower) match {
case (Some(a), Some(b)) => a.lessThanOrEqualTo(b)
case (None, _) => true // `bound1` is unbounded and thus can subsume anything
case (_, None) => false // `bound2` is unbounded and thus cannot be subsumed
case (None, None) => true
}
val isUpperSubsumed = (upper, other.upper) match {
case (Some(a), Some(b)) => b.lessThanOrEqualTo(a)
case (None, _) => true // `bound1` is unbounded and thus can subsume anything
case (_, None) => false // `bound2` is unbounded and thus cannot be subsumed
case (None, None) => true
}
isLowerSubsumed && isUpperSubsumed
// Subsumption check helper for lower and upper bound
def subsume(
thisBound: Option[Bound],
otherBound: Option[Bound],
comp: (Bound, Bound) => Boolean): Boolean =
(thisBound, otherBound) match {
case (Some(a), Some(b)) => comp(a, b)
case (None, _) => true // this is unbounded and thus can subsume any other bound
case (_, None) => false // other is unbounded and thus cannot be subsumed by any
}
subsume(lower, other.lower, _.lessThanOrEqualTo(_)) &&
subsume(upper, other.upper, _.greaterThanOrEqualTo(_))
}
}

object Range {

/**
* Constructs a Range object from a binary comparison expression, translating comparison
* operators into bounds with appropriate inclusivity.
*
* @param condition
* The binary comparison
*/
def apply(condition: BinaryComparison): Range = condition match {
case GreaterThan(_, value: Literal) => Range(Some(Bound(value, inclusive = false)), None)
case GreaterThan(_, value: Literal) =>
Range(Some(Bound(value, inclusive = false)), None)
case GreaterThanOrEqual(_, value: Literal) =>
Range(Some(Bound(value, inclusive = true)), None)
case LessThan(_, value: Literal) => Range(None, Some(Bound(value, inclusive = false)))
case LessThanOrEqual(_, value: Literal) => Range(None, Some(Bound(value, inclusive = true)))
case LessThan(_, value: Literal) =>
Range(None, Some(Bound(value, inclusive = false)))
case LessThanOrEqual(_, value: Literal) =>
Range(None, Some(Bound(value, inclusive = true)))
case EqualTo(_, value: Literal) =>
Range(Some(Bound(value, inclusive = true)), Some(Bound(value, inclusive = true)))
case _ => Range(None, None) // For unsupported or complex conditions
}
}

/**
* Represents a bound (lower or upper) in a range, defined by a literal value and its
* inclusiveness.
*
* @param value
* The literal value defining the bound.
* @param inclusive
* Indicates whether the bound is inclusive.
*/
case class Bound(value: Literal, inclusive: Boolean) {

/**
* Checks if this bound is less than or equal to another bound, considering inclusiveness.
*
* @param other
* The bound to compare against.
* @return
* True if this bound is less than or equal to the other bound.
*/
def lessThanOrEqualTo(other: Bound): Boolean = {
val cmp = value.value.asInstanceOf[Comparable[Any]].compareTo(other.value.value)
cmp < 0 || (cmp == 0 && (inclusive || !other.inclusive))
}

/**
* Checks if this bound is greater than or equal to another bound, considering inclusiveness.
*
* @param other
* The bound to compare against.
* @return
* True if this bound is greater than or equal to the other bound.
*/
def greaterThanOrEqualTo(other: Bound): Boolean = {
val cmp = value.value.asInstanceOf[Comparable[Any]].compareTo(other.value.value)
cmp > 0 || (cmp == 0 && (inclusive || !other.inclusive))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class ApplyFlintSparkCoveringIndex(flint: FlintSpark)
} else {
// Iterate each sub plan tree in the given plan
plan transform {
case subPlan @ Filter(condition, ExtractRelation(relation)) =>
case subPlan @ Filter(condition, ExtractRelation(relation)) if isConjunction(condition) =>
doApply(plan, relation, Some(condition))
.map(newRelation => subPlan.copy(child = newRelation))
.getOrElse(subPlan)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ class ApplyFlintSparkCoveringIndexSuite extends FlintSuite with Matchers {
super.beforeAll()
sql(s"CREATE TABLE $testTable (name STRING, age INT) USING JSON")
sql(s"CREATE TABLE $testTable2 (name STRING) USING JSON")
sql(s"""
| INSERT INTO $testTable
| VALUES
| ('A', 10), ('B', 15), ('C', 20), ('D', 25), ('E', 30),
| ('F', 35), ('G', 40), ('H', 45), ('I', 50), ('J', 55)
| """.stripMargin)

// Mock static create method in FlintClientBuilder used by Flint data source
clientBuilder
Expand All @@ -63,50 +69,56 @@ class ApplyFlintSparkCoveringIndexSuite extends FlintSuite with Matchers {
.assertIndexNotUsed(testTable)
}

Seq(
("age = 30", "age = 20", false),
("age = 30", "age < 20", false),
("age = 30", "age > 50", false),
("age > 30 AND age < 60", "age > 20 AND age < 50", false),
("age > 30", "age >= 30", false),
("age <= 30", "age <= 20", false),
("age < 50", "age = 49", false),
("age <= 50", "age = 50", false),
("age > 30 AND age < 60", "age > 40 AND age < 50", false),
(null, "age > 30", false), // no query filter
("age = 30", "age = 30", true),
("age = 30", "age <= 30", true),
("age = 30", "age >= 30", true),
("age = 30", "age > 20 AND age < 50", true),
("age > 30 AND age < 40", "age > 20 AND age < 50", true),
("age >= 30", "age > 29", true),
("age <= 30", "age < 31", true),
("age > 30", null, true) // no index filter
).foreach { case (queryFilter, indexFilter, expectedResult) =>
test(
s"apply partial covering index with [$indexFilter] to query filter [$queryFilter]: $expectedResult") {
val query = if (queryFilter == null) {
s"SELECT name FROM $testTable"
} else {
s"SELECT name FROM $testTable WHERE $queryFilter"
}

val assertion = assertFlintQueryRewriter
.withQuery(query)
.withIndex(
new FlintSparkCoveringIndex(
indexName = "partial",
tableName = testTable,
indexedColumns = Map("name" -> "string", "age" -> "int"),
filterCondition = Option(indexFilter)))

if (expectedResult) {
assertion.assertIndexUsed(getFlintIndexName("partial", testTable))
} else {
assertion.assertIndexNotUsed(testTable)
// Comprehensive test by cartesian product of the following condition
private val conditions = Seq(
null,
"age = 20",
"age > 20",
"age >= 20",
"age < 20",
"age <= 20",
"age = 50",
"age > 50",
"age >= 50",
"age < 50",
"age <= 50",
"age > 20 AND age < 50",
"age >= 20 AND age < 50",
"age > 20 AND age < 50",
"age >=20 AND age <= 50")
(for {
indexFilter <- conditions
queryFilter <- conditions
} yield (indexFilter, queryFilter)).distinct
.foreach { case (indexFilter, queryFilter) =>
test(s"apply partial covering index with [$indexFilter] to query filter [$queryFilter]") {
def queryWithFilter(condition: String): String =
Option(condition) match {
case None => s"SELECT name FROM $testTable"
case Some(cond) => s"SELECT name FROM $testTable WHERE $cond"
}

// Expect index applied if query result is subset of index data (index filter result)
val queryData = sql(queryWithFilter(queryFilter)).collect().toSet
val indexData = sql(queryWithFilter(indexFilter)).collect().toSet
val expectedResult = queryData.subsetOf(indexData)

val assertion = assertFlintQueryRewriter
.withQuery(queryWithFilter(queryFilter))
.withIndex(
new FlintSparkCoveringIndex(
indexName = "partial",
tableName = testTable,
indexedColumns = Map("name" -> "string", "age" -> "int"),
filterCondition = Option(indexFilter)))

if (expectedResult) {
assertion.assertIndexUsed(getFlintIndexName("partial", testTable))
} else {
assertion.assertIndexNotUsed(testTable)
}
}
}
}

test("should not apply if covering index is logically deleted") {
assertFlintQueryRewriter
Expand Down

0 comments on commit b5c958a

Please sign in to comment.