Address PR comments
Signed-off-by: Chen Dai <[email protected]>
dai-chen committed Feb 7, 2024
1 parent d4820cd commit 60684fd
Showing 5 changed files with 40 additions and 26 deletions.
FlintSparkSkippingStrategy.scala
@@ -9,7 +9,9 @@ import org.json4s.CustomSerializer
 import org.json4s.JsonAST.JString
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.SkippingKind
 
+import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GetStructField}
+import org.apache.spark.sql.functions.col
 
 /**
  * Skipping index strategy that defines skipping data structure building and reading logic.
@@ -89,12 +91,12 @@ object FlintSparkSkippingStrategy {
    * @param indexColName
    *   indexed column name
    */
-  case class IndexExpressionMatcher(indexColName: String) {
+  case class IndexColumnExtractor(indexColName: String) {
 
-    def unapply(expr: Expression): Option[String] = {
+    def unapply(expr: Expression): Option[Column] = {
       val colName = extractColumnName(expr).mkString(".")
       if (colName == indexColName) {
-        Some(indexColName)
+        Some(col(indexColName))
       } else {
         None
       }
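A note on the pattern (a minimal sketch, not taken from the change itself): IndexColumnExtractor is a plain Scala extractor, so once unapply returns Option[Column], each skipping strategy can bind a ready-made index column in its pattern match instead of rebuilding it with col(...) at every call site. The body below is a simplified stand-in for the real implementation, whose extractColumnName helper also walks nested struct fields via GetStructField:

  import org.apache.spark.sql.Column
  import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
  import org.apache.spark.sql.functions.col

  case class IndexColumnExtractor(indexColName: String) {

    // Return the index Column when the expression refers to the indexed
    // column; the real version also handles GetStructField chains.
    def unapply(expr: Expression): Option[Column] =
      expr match {
        case attr: Attribute if attr.name == indexColName => Some(col(indexColName))
        case _ => None
      }
  }

An instance bound to a stable, capitalized identifier (for example val IndexColumn = IndexColumnExtractor("age")) can then appear directly in patterns such as case EqualTo(IndexColumn(indexCol), value: Literal).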
MinMaxSkippingStrategy.scala
@@ -6,9 +6,10 @@
 package org.opensearch.flint.spark.skipping.minmax
 
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
-import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.IndexExpressionMatcher
+import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.IndexColumnExtractor
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, SkippingKind}
 
+import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual, Literal}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{Max, Min}
 import org.apache.spark.sql.catalyst.util.TypeUtils
@@ -37,19 +38,19 @@ case class MinMaxSkippingStrategy(
   }
 
   override def rewritePredicate(predicate: Expression): Option[Expression] = {
-    val IndexExpression = IndexExpressionMatcher(columnName)
+    val IndexColumn = MinMaxIndexColumnExtractor(IndexColumnExtractor(columnName))
     predicate match {
-      case EqualTo(IndexExpression(_), value: Literal) =>
-        Some((col(minColName) <= value && col(maxColName) >= value).expr)
-      case LessThan(IndexExpression(_), value: Literal) =>
-        Some((col(minColName) < value).expr)
-      case LessThanOrEqual(IndexExpression(_), value: Literal) =>
-        Some((col(minColName) <= value).expr)
-      case GreaterThan(IndexExpression(_), value: Literal) =>
-        Some((col(maxColName) > value).expr)
-      case GreaterThanOrEqual(IndexExpression(_), value: Literal) =>
-        Some((col(maxColName) >= value).expr)
-      case In(column @ IndexExpression(_), AllLiterals(literals)) =>
+      case EqualTo(IndexColumn(minIndexCol, maxIndexCol), value: Literal) =>
+        Some((minIndexCol <= value && maxIndexCol >= value).expr)
+      case LessThan(IndexColumn(minIndexCol, _), value: Literal) =>
+        Some((minIndexCol < value).expr)
+      case LessThanOrEqual(IndexColumn(minIndexCol, _), value: Literal) =>
+        Some((minIndexCol <= value).expr)
+      case GreaterThan(IndexColumn(_, maxIndexCol), value: Literal) =>
+        Some((maxIndexCol > value).expr)
+      case GreaterThanOrEqual(IndexColumn(_, maxIndexCol), value: Literal) =>
+        Some((maxIndexCol >= value).expr)
+      case In(column @ IndexColumn(_), AllLiterals(literals)) =>
         /*
          * First, convert IN to approximate range check: min(in_list) <= col <= max(in_list)
          * to avoid long and maybe unnecessary comparison expressions.
@@ -66,8 +67,19 @@ case class MinMaxSkippingStrategy(
     }
   }
 
+  /** Extractor that returns MinMax index column if the given expression matched */
+  private case class MinMaxIndexColumnExtractor(IndexColumn: IndexColumnExtractor) {
+
+    def unapply(expr: Expression): Option[(Column, Column)] = {
+      expr match {
+        case IndexColumn(_) => Some((col(minColName), col(maxColName)))
+        case _ => None
+      }
+    }
+  }
+
   /** Need this because Scala pattern match doesn't work for generic type like Seq[Literal] */
-  object AllLiterals {
+  private object AllLiterals {
     def unapply(values: Seq[Expression]): Option[Seq[Literal]] = {
       if (values.forall(_.isInstanceOf[Literal])) {
         Some(values.asInstanceOf[Seq[Literal]])
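To illustrate the IN-clause comment above (a sketch with hypothetical index column names min_x/max_x; the real names come from minColName and maxColName): the IN list is first relaxed to a range-overlap test, so a long list costs only two comparisons.

  import org.apache.spark.sql.Column
  import org.apache.spark.sql.functions.{col, lit}

  // x IN (1, 5, 9) can only match a file whose [min_x, max_x] range
  // overlaps [min(in_list), max(in_list)] = [1, 9].
  def approximateIn(inList: Seq[Int]): Column =
    col("min_x") <= lit(inList.max) && col("max_x") >= lit(inList.min)

The AllLiterals helper exists because JVM type erasure makes a pattern like case literals: Seq[Literal] unverifiable at runtime: the extractor checks each element with isInstanceOf[Literal], then casts the whole sequence.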
PartitionSkippingStrategy.scala
@@ -6,7 +6,7 @@
 package org.opensearch.flint.spark.skipping.partition
 
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
-import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.IndexExpressionMatcher
+import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.IndexColumnExtractor
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{PARTITION, SkippingKind}
 
 import org.apache.spark.sql.catalyst.expressions.{EqualTo, Expression, Literal}
@@ -31,11 +31,11 @@ case class PartitionSkippingStrategy(
   }
 
   override def rewritePredicate(predicate: Expression): Option[Expression] = {
-    val IndexExpression = IndexExpressionMatcher(columnName)
+    val IndexColumn = IndexColumnExtractor(columnName)
     predicate match {
       // Column has same name in index data, so just rewrite to the same equation
-      case EqualTo(IndexExpression(_), value: Literal) =>
-        Some((col(columnName) === value).expr)
+      case EqualTo(IndexColumn(indexCol), value: Literal) =>
+        Some((indexCol === value).expr)
       case _ => None
     }
   }
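For context (a sketch; the strategy's constructor arguments are omitted here): the partition column keeps its name in the index data, so an equality predicate on the source table maps to the same equality on the skipping index.

  import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal}
  import org.apache.spark.sql.types.IntegerType

  // A predicate like `year = 2023` as Catalyst sees it:
  val predicate = EqualTo(AttributeReference("year", IntegerType)(), Literal(2023))
  // rewritePredicate would return the same equality against the index data,
  // effectively Some((col("year") === 2023).expr) in the new extractor style.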
ValueSetSkippingStrategy.scala
@@ -6,7 +6,7 @@
 package org.opensearch.flint.spark.skipping.valueset
 
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
-import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.IndexExpressionMatcher
+import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.IndexColumnExtractor
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{SkippingKind, VALUE_SET}
 import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy.{DEFAULT_VALUE_SET_MAX_SIZE, VALUE_SET_MAX_SIZE_KEY}
 
@@ -50,11 +50,11 @@ case class ValueSetSkippingStrategy(
      * This is supposed to be rewritten to ARRAY_CONTAINS(columnName, value).
      * However, due to push down limitation in Spark, we keep the equation.
      */
-    val IndexExpression = IndexExpressionMatcher(columnName)
+    val IndexColumn = IndexColumnExtractor(columnName)
     predicate match {
-      case EqualTo(IndexExpression(_), value: Literal) =>
+      case EqualTo(IndexColumn(indexCol), value: Literal) =>
         // Value set may be null due to maximum size limit restriction
-        Some((isnull(col(columnName)) || col(columnName) === value).expr)
+        Some((isnull(indexCol) || indexCol === value).expr)
       case _ => None
     }
  }
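The null check exists because a file's value set is stored as null once it exceeds the configured maximum size (see VALUE_SET_MAX_SIZE_KEY), meaning "too many distinct values, do not skip this file". A sketch of the resulting filter, with a hypothetical column name:

  import org.apache.spark.sql.functions.{col, isnull, lit}

  // Keep the file if its value set overflowed (null) OR contains the value;
  // only files with a non-null set lacking 'active' can be skipped.
  val keepFile = isnull(col("status")) || col("status") === lit("active")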
FlintSparkSkippingIndexITSuite.scala
@@ -732,7 +732,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
       MatchResult(
         hasExpectedFilter,
         s"FlintSparkSkippingFileIndex does not have expected filter: ${fileIndex.indexFilter}",
-        "FlintSparkSkippingFileIndex has expected filter")
+        s"FlintSparkSkippingFileIndex has expected filter: ${fileIndex.indexFilter}")
     }
   }
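The test change mirrors the actual filter into the negated message. In a ScalaTest custom matcher, the second message passed to MatchResult is the one reported when the matcher runs under `not`, so carrying the actual value there makes that direction debuggable too. A standalone sketch (a hypothetical matcher, not the suite's code):

  import org.scalatest.matchers.{MatchResult, Matcher}

  // Both messages carry the actual value, so failures of `should` and
  // `should not` are equally informative.
  def haveIndexFilter(expected: String): Matcher[Option[String]] = Matcher { actual =>
    MatchResult(
      actual.exists(_.contains(expected)),
      s"index filter $actual does not have expected filter: $expected",
      s"index filter $actual has expected filter: $expected")
  }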
