diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkQueryRewriteHelper.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkQueryRewriteHelper.scala index d04072fd3..fede39a8c 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkQueryRewriteHelper.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkQueryRewriteHelper.scala @@ -5,7 +5,7 @@ package org.opensearch.flint.spark -import org.apache.spark.sql.catalyst.expressions.{And, Attribute, BinaryComparison, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Or} +import org.apache.spark.sql.catalyst.expressions._ /** * Query rewrite helper that provides common utilities for query rewrite rule of various Flint @@ -55,15 +55,15 @@ trait FlintSparkQueryRewriteHelper { // Ensures that every condition in the index filter is subsumed by at least one condition // on the same column in the query filter - indexConditions.forall { indexCondition => - queryConditions.exists { queryCondition => - (indexCondition, queryCondition) match { + indexConditions.forall { indexCond => + queryConditions.exists { queryCond => + (indexCond, queryCond) match { case ( - indexComparison @ BinaryComparison(indexCol: Attribute, _), - queryComparison @ BinaryComparison(queryCol: Attribute, _)) + indexComp @ BinaryComparison(indexCol: Attribute, _), + queryComp @ BinaryComparison(queryCol: Attribute, _)) if indexCol.name == queryCol.name => - Range(indexComparison).subsume(Range(queryComparison)) - case _ => false + Range(indexComp).subsume(Range(queryComp)) + case _ => false // consider as not subsumed for unsupported expression } } } @@ -77,7 +77,7 @@ trait FlintSparkQueryRewriteHelper { * @param upper * The optional upper bound */ - case class Range(lower: Option[Bound], upper: Option[Bound]) { + private case class Range(lower: Option[Bound], upper: Option[Bound]) { /** * Determines if this range subsumes (completely covers) another range. @@ -88,8 +88,13 @@ trait FlintSparkQueryRewriteHelper { * True if this range subsumes the other, otherwise false. */ def subsume(other: Range): Boolean = { + // Unknown range cannot subsume or be subsumed by any + if (this == Range.UNKNOWN || other == Range.UNKNOWN) { + return false + } + // Subsumption check helper for lower and upper bound - def subsume( + def subsumeHelper( thisBound: Option[Bound], otherBound: Option[Bound], comp: (Bound, Bound) => Boolean): Boolean = @@ -98,12 +103,15 @@ trait FlintSparkQueryRewriteHelper { case (None, _) => true // this is unbounded and thus can subsume any other bound case (_, None) => false // other is unbounded and thus cannot be subsumed by any } - subsume(lower, other.lower, _.lessThanOrEqualTo(_)) && - subsume(upper, other.upper, _.greaterThanOrEqualTo(_)) + subsumeHelper(lower, other.lower, _.lessThanOrEqualTo(_)) && + subsumeHelper(upper, other.upper, _.greaterThanOrEqualTo(_)) } } - object Range { + private object Range { + + /** Unknown range for unsupported binary comparison expression */ + private val UNKNOWN: Range = Range(None, None) /** * Constructs a Range object from a binary comparison expression, translating comparison @@ -113,17 +121,17 @@ trait FlintSparkQueryRewriteHelper { * The binary comparison */ def apply(condition: BinaryComparison): Range = condition match { - case GreaterThan(_, value: Literal) => + case GreaterThan(_, Literal(value: Comparable[Any], _)) => Range(Some(Bound(value, inclusive = false)), None) - case GreaterThanOrEqual(_, value: Literal) => + case GreaterThanOrEqual(_, Literal(value: Comparable[Any], _)) => Range(Some(Bound(value, inclusive = true)), None) - case LessThan(_, value: Literal) => + case LessThan(_, Literal(value: Comparable[Any], _)) => Range(None, Some(Bound(value, inclusive = false))) - case LessThanOrEqual(_, value: Literal) => + case LessThanOrEqual(_, Literal(value: Comparable[Any], _)) => Range(None, Some(Bound(value, inclusive = true))) - case EqualTo(_, value: Literal) => + case EqualTo(_, Literal(value: Comparable[Any], _)) => Range(Some(Bound(value, inclusive = true)), Some(Bound(value, inclusive = true))) - case _ => Range(None, None) // Infinity for unsupported or complex conditions + case _ => UNKNOWN } } @@ -136,7 +144,7 @@ trait FlintSparkQueryRewriteHelper { * @param inclusive * Indicates whether the bound is inclusive. */ - case class Bound(value: Literal, inclusive: Boolean) { + private case class Bound(value: Comparable[Any], inclusive: Boolean) { /** * Checks if this bound is less than or equal to another bound, considering inclusiveness. @@ -148,7 +156,7 @@ trait FlintSparkQueryRewriteHelper { * smaller, this bound is inclusive, or both bound are exclusive. */ def lessThanOrEqualTo(other: Bound): Boolean = { - val cmp = value.value.asInstanceOf[Comparable[Any]].compareTo(other.value.value) + val cmp = value.compareTo(other.value) cmp < 0 || (cmp == 0 && (inclusive || !other.inclusive)) } @@ -162,7 +170,7 @@ trait FlintSparkQueryRewriteHelper { * greater, this bound is inclusive, or both bound are exclusive. */ def greaterThanOrEqualTo(other: Bound): Boolean = { - val cmp = value.value.asInstanceOf[Comparable[Any]].compareTo(other.value.value) + val cmp = value.compareTo(other.value) cmp > 0 || (cmp == 0 && (inclusive || !other.inclusive)) } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndex.scala index 636092c7c..3bdac3c3d 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndex.scala @@ -114,8 +114,8 @@ class ApplyFlintSparkCoveringIndex(flint: FlintSpark) val indexes = flint .describeIndexes(indexPattern) - .collect { // cast to covering index - case index: FlintSparkCoveringIndex => index + .collect { // cast to covering index and double check table name + case index: FlintSparkCoveringIndex if index.tableName == qualifiedTableName => index } val indexNames = indexes.map(_.name()).mkString(",")