Refactor bound
Signed-off-by: Chen Dai <[email protected]>
dai-chen committed Jul 5, 2024
1 parent e4dfb8b commit f58d10e
Showing 2 changed files with 32 additions and 24 deletions.
@@ -5,7 +5,7 @@

package org.opensearch.flint.spark

-import org.apache.spark.sql.catalyst.expressions.{And, Attribute, BinaryComparison, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Or}
+import org.apache.spark.sql.catalyst.expressions._

/**
* Query rewrite helper that provides common utilities for query rewrite rule of various Flint
@@ -55,15 +55,15 @@ trait FlintSparkQueryRewriteHelper {

// Ensures that every condition in the index filter subsumes at least one condition
// on the same column in the query filter
-indexConditions.forall { indexCondition =>
-queryConditions.exists { queryCondition =>
-(indexCondition, queryCondition) match {
+indexConditions.forall { indexCond =>
+queryConditions.exists { queryCond =>
+(indexCond, queryCond) match {
case (
-indexComparison @ BinaryComparison(indexCol: Attribute, _),
-queryComparison @ BinaryComparison(queryCol: Attribute, _))
+indexComp @ BinaryComparison(indexCol: Attribute, _),
+queryComp @ BinaryComparison(queryCol: Attribute, _))
if indexCol.name == queryCol.name =>
-Range(indexComparison).subsume(Range(queryComparison))
-case _ => false
+Range(indexComp).subsume(Range(queryComp))
+case _ => false // consider as not subsumed for unsupported expression
}
}
}
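
(For context: the check above requires each index-filter condition to cover, on the same column, at least one query-filter condition, so the rows the query selects are all present in the filtered index. Below is a minimal standalone sketch of that forall/exists shape, using plain Int intervals instead of Catalyst expressions and ignoring bound inclusiveness; all names are illustrative, not part of the Flint API.)

object SubsumptionSketch {
  // (column, optional lower bound, optional upper bound); None means unbounded
  type Cond = (String, Option[Int], Option[Int])

  // An index condition covers a query condition when it is on the same column
  // and its interval contains the query interval (inclusiveness ignored here).
  def covers(index: Cond, query: Cond): Boolean = {
    val (iCol, iLo, iHi) = index
    val (qCol, qLo, qHi) = query
    iCol == qCol &&
      iLo.forall(il => qLo.exists(_ >= il)) &&
      iHi.forall(ih => qHi.exists(_ <= ih))
  }

  // Mirrors the forall/exists shape above: every index condition must cover
  // some query condition on the same column.
  def indexCoversQuery(indexConds: Seq[Cond], queryConds: Seq[Cond]): Boolean =
    indexConds.forall(ic => queryConds.exists(qc => covers(ic, qc)))

  def main(args: Array[String]): Unit = {
    val indexFilter = Seq(("age", Some(30), None))     // index built with: age > 30
    val queryFilter = Seq(("age", Some(35), Some(50))) // query asks for: 35 <= age <= 50
    println(indexCoversQuery(indexFilter, queryFilter)) // true: the index slice contains the query slice
  }
}
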
@@ -77,7 +77,7 @@
* @param upper
* The optional upper bound
*/
-case class Range(lower: Option[Bound], upper: Option[Bound]) {
+private case class Range(lower: Option[Bound], upper: Option[Bound]) {

/**
* Determines if this range subsumes (completely covers) another range.
@@ -88,8 +88,13 @@
* True if this range subsumes the other, otherwise false.
*/
def subsume(other: Range): Boolean = {
+// Unknown range cannot subsume or be subsumed by any
+if (this == Range.UNKNOWN || other == Range.UNKNOWN) {
+return false
+}
+
// Subsumption check helper for lower and upper bound
-def subsume(
+def subsumeHelper(
thisBound: Option[Bound],
otherBound: Option[Bound],
comp: (Bound, Bound) => Boolean): Boolean =
@@ -98,12 +103,15 @@
case (None, _) => true // this is unbounded and thus can subsume any other bound
case (_, None) => false // other is unbounded and thus cannot be subsumed by any
}
-subsume(lower, other.lower, _.lessThanOrEqualTo(_)) &&
-subsume(upper, other.upper, _.greaterThanOrEqualTo(_))
+subsumeHelper(lower, other.lower, _.lessThanOrEqualTo(_)) &&
+subsumeHelper(upper, other.upper, _.greaterThanOrEqualTo(_))
}
}

-object Range {
+private object Range {

+/** Unknown range for unsupported binary comparison expression */
+private val UNKNOWN: Range = Range(None, None)

/**
* Constructs a Range object from a binary comparison expression, translating comparison
@@ -113,17 +121,17 @@
* The binary comparison
*/
def apply(condition: BinaryComparison): Range = condition match {
-case GreaterThan(_, value: Literal) =>
+case GreaterThan(_, Literal(value: Comparable[Any], _)) =>
Range(Some(Bound(value, inclusive = false)), None)
-case GreaterThanOrEqual(_, value: Literal) =>
+case GreaterThanOrEqual(_, Literal(value: Comparable[Any], _)) =>
Range(Some(Bound(value, inclusive = true)), None)
-case LessThan(_, value: Literal) =>
+case LessThan(_, Literal(value: Comparable[Any], _)) =>
Range(None, Some(Bound(value, inclusive = false)))
-case LessThanOrEqual(_, value: Literal) =>
+case LessThanOrEqual(_, Literal(value: Comparable[Any], _)) =>
Range(None, Some(Bound(value, inclusive = true)))
-case EqualTo(_, value: Literal) =>
+case EqualTo(_, Literal(value: Comparable[Any], _)) =>
Range(Some(Bound(value, inclusive = true)), Some(Bound(value, inclusive = true)))
-case _ => Range(None, None) // Infinity for unsupported or complex conditions
+case _ => UNKNOWN
}
}
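
(As a rough illustration of the construction above: a one-sided comparison becomes a range with a single bound, equality becomes a point range, and anything unsupported becomes the unknown range that never subsumes and is never subsumed. The sketch below is a simplified Int-only analogue, not the private Flint classes.)

object RangeConstructionSketch {
  final case class Bound(value: Int, inclusive: Boolean)
  final case class Range(lower: Option[Bound], upper: Option[Bound])
  val Unknown: Range = Range(None, None)

  // (operator, literal) stand-in for a Catalyst BinaryComparison on a single column
  def toRange(op: String, v: Int): Range = op match {
    case ">"  => Range(Some(Bound(v, inclusive = false)), None)
    case ">=" => Range(Some(Bound(v, inclusive = true)), None)
    case "<"  => Range(None, Some(Bound(v, inclusive = false)))
    case "<=" => Range(None, Some(Bound(v, inclusive = true)))
    case "="  => Range(Some(Bound(v, inclusive = true)), Some(Bound(v, inclusive = true)))
    case _    => Unknown // unsupported comparison: never subsumes, never subsumed
  }

  def main(args: Array[String]): Unit = {
    println(toRange(">", 30))             // (30, +inf): exclusive lower bound, no upper bound
    println(toRange("=", 35))             // point range [35, 35]
    println(toRange("!=", 35) == Unknown) // true
  }
}
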

@@ -136,7 +144,7 @@
* @param inclusive
* Indicates whether the bound is inclusive.
*/
-case class Bound(value: Literal, inclusive: Boolean) {
+private case class Bound(value: Comparable[Any], inclusive: Boolean) {

/**
* Checks if this bound is less than or equal to another bound, considering inclusiveness.
@@ -148,7 +156,7 @@
* smaller, this bound is inclusive, or both bounds are exclusive.
*/
def lessThanOrEqualTo(other: Bound): Boolean = {
-val cmp = value.value.asInstanceOf[Comparable[Any]].compareTo(other.value.value)
+val cmp = value.compareTo(other.value)
cmp < 0 || (cmp == 0 && (inclusive || !other.inclusive))
}

@@ -162,7 +170,7 @@
* greater, this bound is inclusive, or both bounds are exclusive.
*/
def greaterThanOrEqualTo(other: Bound): Boolean = {
-val cmp = value.value.asInstanceOf[Comparable[Any]].compareTo(other.value.value)
+val cmp = value.compareTo(other.value)
cmp > 0 || (cmp == 0 && (inclusive || !other.inclusive))
}
}
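
(A quick worked check of the inclusiveness rule above, using an illustrative Int-only Bound: when two bound values tie, the outer bound still covers the inner one unless the outer is exclusive and the inner is inclusive.)

object BoundSketch {
  final case class Bound(value: Int, inclusive: Boolean) {
    // Same tie-breaking rule as the lower-bound check above
    def lessThanOrEqualTo(other: Bound): Boolean = {
      val cmp = value.compareTo(other.value)
      cmp < 0 || (cmp == 0 && (inclusive || !other.inclusive))
    }
  }

  def main(args: Array[String]): Unit = {
    println(Bound(30, inclusive = true).lessThanOrEqualTo(Bound(30, inclusive = false)))  // true: ">= 30" covers "> 30"
    println(Bound(30, inclusive = false).lessThanOrEqualTo(Bound(30, inclusive = true)))  // false: "> 30" does not cover ">= 30"
  }
}
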
@@ -114,8 +114,8 @@ class ApplyFlintSparkCoveringIndex(flint: FlintSpark)
val indexes =
flint
.describeIndexes(indexPattern)
-.collect { // cast to covering index
-case index: FlintSparkCoveringIndex => index
+.collect { // cast to covering index and double check table name
+case index: FlintSparkCoveringIndex if index.tableName == qualifiedTableName => index
}

val indexNames = indexes.map(_.name()).mkString(",")
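
(The added guard re-checks the source table because the index-name pattern lookup can match more broadly than the table being planned. Below is a minimal illustration of the same collect-with-guard shape; the types and values are stand-ins, not the Flint API.)

object CoveringIndexFilterSketch {
  sealed trait FlintIndex { def name: String }
  final case class CoveringIndex(name: String, tableName: String) extends FlintIndex
  final case class OtherIndex(name: String) extends FlintIndex

  def main(args: Array[String]): Unit = {
    val qualifiedTableName = "glue.default.http_logs"
    val described: Seq[FlintIndex] = Seq(
      CoveringIndex("idx_status", "glue.default.http_logs"),
      CoveringIndex("idx_status", "glue.default.http_logs_archive"), // pattern near-miss
      OtherIndex("skipping_index"))

    // Keep only covering indexes whose recorded source table matches exactly
    val indexes = described.collect {
      case index: CoveringIndex if index.tableName == qualifiedTableName => index
    }
    println(indexes.map(_.name).mkString(",")) // idx_status
  }
}
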
