Implement simple query rewrite and update IT
Signed-off-by: Chen Dai <[email protected]>
dai-chen committed Feb 6, 2024
1 parent dcfdff1 commit d4820cd
Showing 5 changed files with 96 additions and 32 deletions.
@@ -9,7 +9,7 @@ import org.json4s.CustomSerializer
import org.json4s.JsonAST.JString
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.SkippingKind

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GetStructField}

/**
* Skipping index strategy that defines skipping data structure building and reading logic.
@@ -82,4 +82,39 @@ object FlintSparkSkippingStrategy {
{ case kind: SkippingKind =>
JString(kind.toString)
}))

/**
* Extractor that matches the given expression against the indexed column expression of the skipping index.
*
* @param indexColName
* indexed column name
*/
case class IndexExpressionMatcher(indexColName: String) {

def unapply(expr: Expression): Option[String] = {
val colName = extractColumnName(expr).mkString(".")
if (colName == indexColName) {
Some(indexColName)
} else {
None
}
}

/*
* In Spark, after analysis, nested field "a.b.c" becomes:
*   GetStructField(name="c",
*     child=GetStructField(name="b",
*       child=AttributeReference(name="a")))
* TODO: To support arbitrary index expressions, analyze the index expression string
*/
private def extractColumnName(expr: Expression): Seq[String] = {
expr match {
case attr: Attribute =>
Seq(attr.name)
case GetStructField(child, _, Some(name)) =>
extractColumnName(child) :+ name
case _ => Seq.empty
}
}
}
}
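For orientation (not part of the commit), a minimal sketch of how the new IndexExpressionMatcher is intended to match a nested column reference. The struct schema below is made up for illustration and mirrors the integration test at the end of this commit.

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GetStructField}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.IndexExpressionMatcher

// Analyzed form of "struct_col.field1.subfield": the leaf field name sits on the
// outermost GetStructField, the table column on the innermost AttributeReference.
val structType = StructType(Seq(
  StructField("field1", StructType(Seq(StructField("subfield", StringType)))),
  StructField("field2", IntegerType)))
val structCol = AttributeReference("struct_col", structType)()
val nested = GetStructField(GetStructField(structCol, 0, Some("field1")), 0, Some("subfield"))

// The extractor flattens the chain to "struct_col.field1.subfield" and compares it
// against the indexed column name.
val IndexExpression = IndexExpressionMatcher("struct_col.field1.subfield")
nested match {
  case IndexExpression(colName) => println(s"predicate is on indexed column $colName")
  case _ => println("predicate is not on the indexed column")
}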
@@ -6,9 +6,10 @@
package org.opensearch.flint.spark.skipping.minmax

import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.IndexExpressionMatcher
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, SkippingKind}

import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual, Literal}
import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.{Max, Min}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.functions.col
@@ -35,19 +36,20 @@ case class MinMaxSkippingStrategy(
Max(col(columnName).expr).toAggregateExpression())
}

override def rewritePredicate(predicate: Expression): Option[Expression] =
override def rewritePredicate(predicate: Expression): Option[Expression] = {
val IndexExpression = IndexExpressionMatcher(columnName)
predicate match {
case EqualTo(AttributeReference(`columnName`, _, _, _), value: Literal) =>
case EqualTo(IndexExpression(_), value: Literal) =>
Some((col(minColName) <= value && col(maxColName) >= value).expr)
case LessThan(AttributeReference(`columnName`, _, _, _), value: Literal) =>
case LessThan(IndexExpression(_), value: Literal) =>
Some((col(minColName) < value).expr)
case LessThanOrEqual(AttributeReference(`columnName`, _, _, _), value: Literal) =>
case LessThanOrEqual(IndexExpression(_), value: Literal) =>
Some((col(minColName) <= value).expr)
case GreaterThan(AttributeReference(`columnName`, _, _, _), value: Literal) =>
case GreaterThan(IndexExpression(_), value: Literal) =>
Some((col(maxColName) > value).expr)
case GreaterThanOrEqual(AttributeReference(`columnName`, _, _, _), value: Literal) =>
case GreaterThanOrEqual(IndexExpression(_), value: Literal) =>
Some((col(maxColName) >= value).expr)
case In(column @ AttributeReference(`columnName`, _, _, _), AllLiterals(literals)) =>
case In(column @ IndexExpression(_), AllLiterals(literals)) =>
/*
* First, convert IN to an approximate range check: min(in_list) <= col <= max(in_list)
* to avoid long and possibly unnecessary comparison expressions.
@@ -62,6 +64,7 @@ case class MinMaxSkippingStrategy(
rewritePredicate(LessThanOrEqual(column, Literal(maxVal))).get))
case _ => None
}
}

/** Need this because Scala pattern match doesn't work for generic type like Seq[Literal] */
object AllLiterals {
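A small sketch (not part of the diff) of how source predicates are expected to map onto the min/max index columns after this rewrite. The column "age" is hypothetical; the MinMax_<column>_0 (min) / MinMax_<column>_1 (max) naming follows the assertion in the integration test at the end of this commit.

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, In, Literal}
import org.apache.spark.sql.types.IntegerType

// Hypothetical indexed column "age".
val age = AttributeReference("age", IntegerType)()

// age = 30            ->  MinMax_age_0 <= 30 AND MinMax_age_1 >= 30
val equality = EqualTo(age, Literal(30))

// age IN (10, 30, 20) ->  MinMax_age_1 >= 10 AND MinMax_age_0 <= 30
// (approximate range check over min/max of the IN list, as described in the comment above)
val inList = In(age, Seq(Literal(10), Literal(30), Literal(20)))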
@@ -6,9 +6,10 @@
package org.opensearch.flint.spark.skipping.partition

import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.IndexExpressionMatcher
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{PARTITION, SkippingKind}

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal}
import org.apache.spark.sql.catalyst.expressions.{EqualTo, Expression, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.First
import org.apache.spark.sql.functions.col

@@ -29,11 +30,13 @@ case class PartitionSkippingStrategy(
Seq(First(col(columnName).expr, ignoreNulls = true).toAggregateExpression())
}

override def rewritePredicate(predicate: Expression): Option[Expression] =
override def rewritePredicate(predicate: Expression): Option[Expression] = {
val IndexExpression = IndexExpressionMatcher(columnName)
predicate match {
// Column has the same name in index data, so just rewrite to the same equality predicate
case EqualTo(AttributeReference(`columnName`, _, _, _), value: Literal) =>
case EqualTo(IndexExpression(_), value: Literal) =>
Some((col(columnName) === value).expr)
case _ => None
}
}
}
@@ -6,10 +6,11 @@
package org.opensearch.flint.spark.skipping.valueset

import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.IndexExpressionMatcher
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{SkippingKind, VALUE_SET}
import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy.{DEFAULT_VALUE_SET_MAX_SIZE, VALUE_SET_MAX_SIZE_KEY}

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal}
import org.apache.spark.sql.catalyst.expressions.{EqualTo, Expression, Literal}
import org.apache.spark.sql.functions._

/**
@@ -44,17 +45,19 @@ case class ValueSetSkippingStrategy(
Seq(aggregator.expr)
}

override def rewritePredicate(predicate: Expression): Option[Expression] =
override def rewritePredicate(predicate: Expression): Option[Expression] = {
/*
* This is supposed to be rewritten to ARRAY_CONTAINS(columnName, value).
* However, due to predicate pushdown limitations in Spark, we keep the equality predicate.
*/
val IndexExpression = IndexExpressionMatcher(columnName)
predicate match {
case EqualTo(AttributeReference(`columnName`, _, _, _), value: Literal) =>
case EqualTo(IndexExpression(_), value: Literal) =>
// Value set may be null due to the maximum size limit
Some((isnull(col(columnName)) || col(columnName) === value).expr)
case _ => None
}
}
}

object ValueSetSkippingStrategy {
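For completeness, a sketch (not part of the diff) of the rewritten value-set filter, mirroring the hasIndexFilter assertion in the integration test below. The value-set column keeps the source column name in the index data; a null set means the collected values exceeded the configured maximum size, so such files must not be skipped, hence the isnull() escape hatch.

import org.apache.spark.sql.functions.{col, isnull}

// Expected skipping filter for: struct_col.field1.subfield = 'value3'
val skippingFilter =
  isnull(col("struct_col.field1.subfield")) || col("struct_col.field1.subfield") === "value3"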
@@ -647,36 +647,56 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
test("build skipping index for nested field and rewrite applicable query") {
val testTable = "spark_catalog.default.nested_field_table"
val testIndex = getSkippingIndexName(testTable)
sql(s"""
withTable(testTable) {
sql(s"""
| CREATE TABLE $testTable
| (
| int_col INT,
| struct_col STRUCT<field1: STRUCT<subfield:STRING>, field2: INT>
| )
| USING JSON
|""".stripMargin)
sql(s"""
sql(s"""
| INSERT INTO $testTable
| SELECT /*+ COALESCE(1) */ *
| FROM VALUES
| ( 30, STRUCT(STRUCT("subfieldValue1"),123) ),
| ( 40, STRUCT(STRUCT("subfieldValue2"),456) )
| ( 30, STRUCT(STRUCT("value1"),123) ),
| ( 40, STRUCT(STRUCT("value2"),456) )
|""".stripMargin)
sql(s"""
sql(s"""
| INSERT INTO $testTable
| VALUES ( 50, STRUCT(STRUCT("subfieldValue3"),789) )
| VALUES ( 50, STRUCT(STRUCT("value3"),789) )
|""".stripMargin)

flint
.skippingIndex()
.onTable(testTable)
.addMinMax("struct_col.field2")
.addValueSet("struct_col.field1.subfield")
.create()
flint.refreshIndex(testIndex)

// FIXME: add assertion once https://github.com/opensearch-project/opensearch-spark/issues/233 fixed
deleteTestIndex(testIndex)
flint
.skippingIndex()
.onTable(testTable)
.addMinMax("struct_col.field2")
.addValueSet("struct_col.field1.subfield")
.create()
flint.refreshIndex(testIndex)

// FIXME: add assertion on index data once https://github.com/opensearch-project/opensearch-spark/issues/233 is fixed
// Query rewrite on nested field
val query1 =
sql(s"SELECT int_col FROM $testTable WHERE struct_col.field2 = 456".stripMargin)
checkAnswer(query1, Row(40))
query1.queryExecution.executedPlan should
useFlintSparkSkippingFileIndex(
hasIndexFilter(
col("MinMax_struct_col.field2_0") <= 456 && col("MinMax_struct_col.field2_1") >= 456))

// Query rewrite on deeply nested field
val query2 = sql(
s"SELECT int_col FROM $testTable WHERE struct_col.field1.subfield = 'value3'".stripMargin)
checkAnswer(query2, Row(50))
query2.queryExecution.executedPlan should
useFlintSparkSkippingFileIndex(
hasIndexFilter(isnull(col("struct_col.field1.subfield")) ||
col("struct_col.field1.subfield") === "value3"))

deleteTestIndex(testIndex)
}
}

// Custom matcher to check if a SparkPlan uses FlintSparkSkippingFileIndex
@@ -711,7 +731,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {

MatchResult(
hasExpectedFilter,
"FlintSparkSkippingFileIndex does not have expected filter",
s"FlintSparkSkippingFileIndex does not have expected filter: ${fileIndex.indexFilter}",
"FlintSparkSkippingFileIndex has expected filter")
}
}
