From 60684fd53a0d614e6de802bbcc7c00fe47800ae3 Mon Sep 17 00:00:00 2001
From: Chen Dai
Date: Wed, 7 Feb 2024 09:34:35 -0800
Subject: [PATCH] Address PR comments

Signed-off-by: Chen Dai
---
 .../skipping/FlintSparkSkippingStrategy.scala |  8 ++--
 .../minmax/MinMaxSkippingStrategy.scala       | 40 ++++++++++++-------
 .../partition/PartitionSkippingStrategy.scala |  8 ++--
 .../valueset/ValueSetSkippingStrategy.scala   |  8 ++--
 .../FlintSparkSkippingIndexITSuite.scala      |  2 +-
 5 files changed, 40 insertions(+), 26 deletions(-)

diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala
index 822df478b..06b6daa13 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala
@@ -9,7 +9,9 @@ import org.json4s.CustomSerializer
 import org.json4s.JsonAST.JString
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.SkippingKind
 
+import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GetStructField}
+import org.apache.spark.sql.functions.col
 
 /**
  * Skipping index strategy that defines skipping data structure building and reading logic.
@@ -89,12 +91,12 @@ object FlintSparkSkippingStrategy {
    * @param indexColName
    *   indexed column name
    */
-  case class IndexExpressionMatcher(indexColName: String) {
+  case class IndexColumnExtractor(indexColName: String) {
 
-    def unapply(expr: Expression): Option[String] = {
+    def unapply(expr: Expression): Option[Column] = {
       val colName = extractColumnName(expr).mkString(".")
       if (colName == indexColName) {
-        Some(indexColName)
+        Some(col(indexColName))
       } else {
         None
       }
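Note for reviewers: the rename above also changes what the extractor returns, a ready-to-use Column instead of the matched column name, which is what lets each strategy below build its rewritten predicate directly from the match. A minimal, self-contained sketch of the idiom follows; the Attribute-only match and the "age" usage are illustrative assumptions, since the real unapply delegates to extractColumnName, which also handles nested struct fields:

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
    import org.apache.spark.sql.functions.col

    // Sketch only: yields the index Column when the expression is a plain
    // attribute reference carrying the indexed column's name.
    case class IndexColumnExtractor(indexColName: String) {
      def unapply(expr: Expression): Option[Column] = expr match {
        case attr: Attribute if attr.name == indexColName => Some(col(indexColName))
        case _ => None
      }
    }

    // Usage: bind to a capitalized stable identifier so it can appear in patterns.
    // val IndexColumn = IndexColumnExtractor("age")
    // expr match { case IndexColumn(indexCol) => ... }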
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/minmax/MinMaxSkippingStrategy.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/minmax/MinMaxSkippingStrategy.scala
index 2e7ab8057..edcc24c26 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/minmax/MinMaxSkippingStrategy.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/minmax/MinMaxSkippingStrategy.scala
@@ -6,9 +6,10 @@
 package org.opensearch.flint.spark.skipping.minmax
 
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
-import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.IndexExpressionMatcher
+import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.IndexColumnExtractor
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, SkippingKind}
 
+import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual, Literal}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{Max, Min}
 import org.apache.spark.sql.catalyst.util.TypeUtils
@@ -37,19 +38,19 @@ case class MinMaxSkippingStrategy(
   }
 
   override def rewritePredicate(predicate: Expression): Option[Expression] = {
-    val IndexExpression = IndexExpressionMatcher(columnName)
+    val IndexColumn = MinMaxIndexColumnExtractor(IndexColumnExtractor(columnName))
     predicate match {
-      case EqualTo(IndexExpression(_), value: Literal) =>
-        Some((col(minColName) <= value && col(maxColName) >= value).expr)
-      case LessThan(IndexExpression(_), value: Literal) =>
-        Some((col(minColName) < value).expr)
-      case LessThanOrEqual(IndexExpression(_), value: Literal) =>
-        Some((col(minColName) <= value).expr)
-      case GreaterThan(IndexExpression(_), value: Literal) =>
-        Some((col(maxColName) > value).expr)
-      case GreaterThanOrEqual(IndexExpression(_), value: Literal) =>
-        Some((col(maxColName) >= value).expr)
-      case In(column @ IndexExpression(_), AllLiterals(literals)) =>
+      case EqualTo(IndexColumn(minIndexCol, maxIndexCol), value: Literal) =>
+        Some((minIndexCol <= value && maxIndexCol >= value).expr)
+      case LessThan(IndexColumn(minIndexCol, _), value: Literal) =>
+        Some((minIndexCol < value).expr)
+      case LessThanOrEqual(IndexColumn(minIndexCol, _), value: Literal) =>
+        Some((minIndexCol <= value).expr)
+      case GreaterThan(IndexColumn(_, maxIndexCol), value: Literal) =>
+        Some((maxIndexCol > value).expr)
+      case GreaterThanOrEqual(IndexColumn(_, maxIndexCol), value: Literal) =>
+        Some((maxIndexCol >= value).expr)
+      case In(column @ IndexColumn(_), AllLiterals(literals)) =>
         /*
          * First, convert IN to approximate range check: min(in_list) <= col <= max(in_list)
          * to avoid long and maybe unnecessary comparison expressions.
         */
@@ -66,8 +67,19 @@ case class MinMaxSkippingStrategy(
     }
   }
 
+  /** Extractor that returns MinMax index column if the given expression matched */
+  private case class MinMaxIndexColumnExtractor(IndexColumn: IndexColumnExtractor) {
+
+    def unapply(expr: Expression): Option[(Column, Column)] = {
+      expr match {
+        case IndexColumn(_) => Some((col(minColName), col(maxColName)))
+        case _ => None
+      }
+    }
+  }
+
   /** Need this because Scala pattern match doesn't work for generic type like Seq[Literal] */
-  object AllLiterals {
+  private object AllLiterals {
     def unapply(values: Seq[Expression]): Option[Seq[Literal]] = {
       if (values.forall(_.isInstanceOf[Literal])) {
         Some(values.asInstanceOf[Seq[Literal]])
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategy.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategy.scala
index 67f3ccb3f..21d6dc836 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategy.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategy.scala
@@ -6,7 +6,7 @@
 package org.opensearch.flint.spark.skipping.partition
 
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
-import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.IndexExpressionMatcher
+import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.IndexColumnExtractor
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{PARTITION, SkippingKind}
 
 import org.apache.spark.sql.catalyst.expressions.{EqualTo, Expression, Literal}
@@ -31,11 +31,11 @@ case class PartitionSkippingStrategy(
   }
 
   override def rewritePredicate(predicate: Expression): Option[Expression] = {
-    val IndexExpression = IndexExpressionMatcher(columnName)
+    val IndexColumn = IndexColumnExtractor(columnName)
     predicate match {
       // Column has the same name in index data, so just rewrite to the same equation
-      case EqualTo(IndexExpression(_), value: Literal) =>
-        Some((col(columnName) === value).expr)
+      case EqualTo(IndexColumn(indexCol), value: Literal) =>
+        Some((indexCol === value).expr)
       case _ => None
     }
   }
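Note for reviewers: the nested MinMaxIndexColumnExtractor above composes one extractor inside another so that a single pattern match yields both skipping-index columns. A hedged, standalone sketch of the same composition follows; passing minColName and maxColName as constructor parameters is an assumption for illustration, since in the patch the extractor is an inner case class that reads them from the enclosing strategy, and IndexColumnExtractor is as sketched earlier:

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.catalyst.expressions.Expression
    import org.apache.spark.sql.functions.col

    // Sketch only: delegates the name check to the inner extractor, then
    // returns the (min, max) Column pair for the matched index column.
    case class MinMaxIndexColumnExtractor(
        IndexColumn: IndexColumnExtractor,
        minColName: String,
        maxColName: String) {

      def unapply(expr: Expression): Option[(Column, Column)] = expr match {
        case IndexColumn(_) => Some((col(minColName), col(maxColName)))
        case _ => None
      }
    }

This is why each comparison case above collapses to a direct range check, e.g. Some((minIndexCol <= value && maxIndexCol >= value).expr), with no repeated col(...) construction in every branch.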
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategy.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategy.scala
index d095c797b..18f573949 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategy.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategy.scala
@@ -6,7 +6,7 @@
 package org.opensearch.flint.spark.skipping.valueset
 
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
-import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.IndexExpressionMatcher
+import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.IndexColumnExtractor
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{SkippingKind, VALUE_SET}
 import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy.{DEFAULT_VALUE_SET_MAX_SIZE, VALUE_SET_MAX_SIZE_KEY}
 
@@ -50,11 +50,11 @@ case class ValueSetSkippingStrategy(
      * This is supposed to be rewritten to ARRAY_CONTAINS(columnName, value).
      * However, due to push-down limitations in Spark, we keep the equation.
      */
-    val IndexExpression = IndexExpressionMatcher(columnName)
+    val IndexColumn = IndexColumnExtractor(columnName)
     predicate match {
-      case EqualTo(IndexExpression(_), value: Literal) =>
+      case EqualTo(IndexColumn(indexCol), value: Literal) =>
         // Value set may be null due to the maximum size limit
-        Some((isnull(col(columnName)) || col(columnName) === value).expr)
+        Some((isnull(indexCol) || indexCol === value).expr)
       case _ => None
     }
   }
diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala
index f9366f7bc..58400aa81 100644
--- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala
+++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala
@@ -732,7 +732,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
       MatchResult(
         hasExpectedFilter,
         s"FlintSparkSkippingFileIndex does not have expected filter: ${fileIndex.indexFilter}",
-        "FlintSparkSkippingFileIndex has expected filter")
+        s"FlintSparkSkippingFileIndex has expected filter: ${fileIndex.indexFilter}")
     }
   }
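Note for reviewers: the final hunk mirrors the actual filter into the negated failure message, so a failing "should not" assertion is as diagnosable as a failing "should". A minimal ScalaTest sketch of the same two-message pattern; the haveIndexFilter name and the Option[String]-based matcher are illustrative assumptions, not code from the suite:

    import org.scalatest.matchers.{MatchResult, Matcher}

    // Sketch only: both messages include the actual value, so whichever
    // direction the assertion fails in, the output shows what was compared.
    def haveIndexFilter(expected: String): Matcher[Option[String]] =
      new Matcher[Option[String]] {
        override def apply(actual: Option[String]): MatchResult =
          MatchResult(
            actual.contains(expected),
            s"index filter $actual does not contain expected filter: $expected",
            s"index filter $actual contains expected filter: $expected")
      }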