From d4820cde9d85f62133e70a256d0849fcd793b006 Mon Sep 17 00:00:00 2001
From: Chen Dai
Date: Tue, 6 Feb 2024 11:21:56 -0800
Subject: [PATCH] Implement simple query rewrite and update IT

Signed-off-by: Chen Dai
---
 .../skipping/FlintSparkSkippingStrategy.scala | 37 ++++++++++++-
 .../minmax/MinMaxSkippingStrategy.scala       | 19 ++++---
 .../partition/PartitionSkippingStrategy.scala |  9 ++--
 .../valueset/ValueSetSkippingStrategy.scala   |  9 ++--
 .../FlintSparkSkippingIndexITSuite.scala      | 54 +++++++++++++------
 5 files changed, 96 insertions(+), 32 deletions(-)

diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala
index 2569f06fa..822df478b 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala
@@ -9,7 +9,7 @@ import org.json4s.CustomSerializer
 import org.json4s.JsonAST.JString
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.SkippingKind
 
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GetStructField}
 
 /**
  * Skipping index strategy that defines skipping data structure building and reading logic.
@@ -82,4 +82,39 @@ object FlintSparkSkippingStrategy {
       { case kind: SkippingKind =>
         JString(kind.toString)
       }))
+
+  /**
+   * Extractor that matches the given expression against the index expression in the skipping index.
+   *
+   * @param indexColName
+   *   indexed column name
+   */
+  case class IndexExpressionMatcher(indexColName: String) {
+
+    def unapply(expr: Expression): Option[String] = {
+      val colName = extractColumnName(expr).mkString(".")
+      if (colName == indexColName) {
+        Some(indexColName)
+      } else {
+        None
+      }
+    }
+
+    /*
+     * In Spark, after analysis, nested field "a.b.c" becomes:
+     *   GetStructField(name="c",
+     *     child=GetStructField(name="b",
+     *       child=AttributeReference(name="a")))
+     * TODO: To support any index expression, analyze index expression string
+     */
+    private def extractColumnName(expr: Expression): Seq[String] = {
+      expr match {
+        case attr: Attribute =>
+          Seq(attr.name)
+        case GetStructField(child, _, Some(name)) =>
+          extractColumnName(child) :+ name
+        case _ => Seq.empty
+      }
+    }
+  }
 }
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/minmax/MinMaxSkippingStrategy.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/minmax/MinMaxSkippingStrategy.scala
index f7745e7a8..2e7ab8057 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/minmax/MinMaxSkippingStrategy.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/minmax/MinMaxSkippingStrategy.scala
@@ -6,9 +6,10 @@
 package org.opensearch.flint.spark.skipping.minmax
 
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
+import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.IndexExpressionMatcher
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, SkippingKind}
 
-import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual, Literal}
+import
org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual, Literal} import org.apache.spark.sql.catalyst.expressions.aggregate.{Max, Min} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.functions.col @@ -35,19 +36,20 @@ case class MinMaxSkippingStrategy( Max(col(columnName).expr).toAggregateExpression()) } - override def rewritePredicate(predicate: Expression): Option[Expression] = + override def rewritePredicate(predicate: Expression): Option[Expression] = { + val IndexExpression = IndexExpressionMatcher(columnName) predicate match { - case EqualTo(AttributeReference(`columnName`, _, _, _), value: Literal) => + case EqualTo(IndexExpression(_), value: Literal) => Some((col(minColName) <= value && col(maxColName) >= value).expr) - case LessThan(AttributeReference(`columnName`, _, _, _), value: Literal) => + case LessThan(IndexExpression(_), value: Literal) => Some((col(minColName) < value).expr) - case LessThanOrEqual(AttributeReference(`columnName`, _, _, _), value: Literal) => + case LessThanOrEqual(IndexExpression(_), value: Literal) => Some((col(minColName) <= value).expr) - case GreaterThan(AttributeReference(`columnName`, _, _, _), value: Literal) => + case GreaterThan(IndexExpression(_), value: Literal) => Some((col(maxColName) > value).expr) - case GreaterThanOrEqual(AttributeReference(`columnName`, _, _, _), value: Literal) => + case GreaterThanOrEqual(IndexExpression(_), value: Literal) => Some((col(maxColName) >= value).expr) - case In(column @ AttributeReference(`columnName`, _, _, _), AllLiterals(literals)) => + case In(column @ IndexExpression(_), AllLiterals(literals)) => /* * First, convert IN to approximate range check: min(in_list) <= col <= max(in_list) * to avoid long and maybe unnecessary comparison expressions. 
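For reference, here is a minimal sketch (not part of this patch) of how the new IndexExpressionMatcher resolves a nested field into a dotted column name, assuming the analyzed GetStructField chain described in the comment above; the object name and hand-built schema are illustrative only.

    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GetStructField}
    import org.apache.spark.sql.types.{StringType, StructField, StructType}
    import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.IndexExpressionMatcher

    object IndexExpressionMatcherSketch extends App {
      // Hand-built analyzed form of struct_col.field1.subfield:
      // GetStructField("subfield") over GetStructField("field1") over AttributeReference("struct_col")
      val schema = StructType(
        Seq(StructField("field1", StructType(Seq(StructField("subfield", StringType))))))
      val structCol = AttributeReference("struct_col", schema)()
      val nestedField =
        GetStructField(GetStructField(structCol, 0, Some("field1")), 0, Some("subfield"))

      // Same usage pattern as rewritePredicate above: bind the matcher to the indexed column
      // name, then use it as an extractor in a pattern match.
      val IndexExpression = IndexExpressionMatcher("struct_col.field1.subfield")
      nestedField match {
        case IndexExpression(colName) => println(s"matched indexed column: $colName")
        case _ => println("no match")
      }
    }

Because the extractor only compares the resolved dotted column name, each strategy's rewritePredicate keeps its original shape and simply swaps the AttributeReference pattern for IndexExpression(_).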
@@ -62,6 +64,7 @@ case class MinMaxSkippingStrategy( rewritePredicate(LessThanOrEqual(column, Literal(maxVal))).get)) case _ => None } + } /** Need this because Scala pattern match doesn't work for generic type like Seq[Literal] */ object AllLiterals { diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategy.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategy.scala index 18fec0642..67f3ccb3f 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategy.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategy.scala @@ -6,9 +6,10 @@ package org.opensearch.flint.spark.skipping.partition import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy +import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.IndexExpressionMatcher import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{PARTITION, SkippingKind} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal} +import org.apache.spark.sql.catalyst.expressions.{EqualTo, Expression, Literal} import org.apache.spark.sql.catalyst.expressions.aggregate.First import org.apache.spark.sql.functions.col @@ -29,11 +30,13 @@ case class PartitionSkippingStrategy( Seq(First(col(columnName).expr, ignoreNulls = true).toAggregateExpression()) } - override def rewritePredicate(predicate: Expression): Option[Expression] = + override def rewritePredicate(predicate: Expression): Option[Expression] = { + val IndexExpression = IndexExpressionMatcher(columnName) predicate match { // Column has same name in index data, so just rewrite to the same equation - case EqualTo(AttributeReference(`columnName`, _, _, _), value: Literal) => + case EqualTo(IndexExpression(_), value: Literal) => Some((col(columnName) === value).expr) case _ => None } + } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategy.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategy.scala index 1db9e3d32..d095c797b 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategy.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategy.scala @@ -6,10 +6,11 @@ package org.opensearch.flint.spark.skipping.valueset import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy +import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.IndexExpressionMatcher import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{SkippingKind, VALUE_SET} import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy.{DEFAULT_VALUE_SET_MAX_SIZE, VALUE_SET_MAX_SIZE_KEY} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal} +import org.apache.spark.sql.catalyst.expressions.{EqualTo, Expression, Literal} import org.apache.spark.sql.functions._ /** @@ -44,17 +45,19 @@ case class ValueSetSkippingStrategy( Seq(aggregator.expr) } - override def rewritePredicate(predicate: Expression): Option[Expression] = + override def rewritePredicate(predicate: Expression): Option[Expression] = { /* * This is supposed to be rewritten to 
ARRAY_CONTAINS(columName, value). * However, due to push down limitation in Spark, we keep the equation. */ + val IndexExpression = IndexExpressionMatcher(columnName) predicate match { - case EqualTo(AttributeReference(`columnName`, _, _, _), value: Literal) => + case EqualTo(IndexExpression(_), value: Literal) => // Value set maybe null due to maximum size limit restriction Some((isnull(col(columnName)) || col(columnName) === value).expr) case _ => None } + } } object ValueSetSkippingStrategy { diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala index ec8fc0df9..f9366f7bc 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala @@ -647,7 +647,8 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { test("build skipping index for nested field and rewrite applicable query") { val testTable = "spark_catalog.default.nested_field_table" val testIndex = getSkippingIndexName(testTable) - sql(s""" + withTable(testTable) { + sql(s""" | CREATE TABLE $testTable | ( | int_col INT, @@ -655,28 +656,47 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { | ) | USING JSON |""".stripMargin) - sql(s""" + sql(s""" | INSERT INTO $testTable | SELECT /*+ COALESCE(1) */ * | FROM VALUES - | ( 30, STRUCT(STRUCT("subfieldValue1"),123) ), - | ( 40, STRUCT(STRUCT("subfieldValue2"),456) ) + | ( 30, STRUCT(STRUCT("value1"),123) ), + | ( 40, STRUCT(STRUCT("value2"),456) ) |""".stripMargin) - sql(s""" + sql(s""" | INSERT INTO $testTable - | VALUES ( 50, STRUCT(STRUCT("subfieldValue3"),789) ) + | VALUES ( 50, STRUCT(STRUCT("value3"),789) ) |""".stripMargin) - flint - .skippingIndex() - .onTable(testTable) - .addMinMax("struct_col.field2") - .addValueSet("struct_col.field1.subfield") - .create() - flint.refreshIndex(testIndex) - - // FIXME: add assertion once https://github.com/opensearch-project/opensearch-spark/issues/233 fixed - deleteTestIndex(testIndex) + flint + .skippingIndex() + .onTable(testTable) + .addMinMax("struct_col.field2") + .addValueSet("struct_col.field1.subfield") + .create() + flint.refreshIndex(testIndex) + + // FIXME: add assertion on index data once https://github.com/opensearch-project/opensearch-spark/issues/233 fixed + // Query rewrite nested field + val query1 = + sql(s"SELECT int_col FROM $testTable WHERE struct_col.field2 = 456".stripMargin) + checkAnswer(query1, Row(40)) + query1.queryExecution.executedPlan should + useFlintSparkSkippingFileIndex( + hasIndexFilter( + col("MinMax_struct_col.field2_0") <= 456 && col("MinMax_struct_col.field2_1") >= 456)) + + // Query rewrite deep nested field + val query2 = sql( + s"SELECT int_col FROM $testTable WHERE struct_col.field1.subfield = 'value3'".stripMargin) + checkAnswer(query2, Row(50)) + query2.queryExecution.executedPlan should + useFlintSparkSkippingFileIndex( + hasIndexFilter(isnull(col("struct_col.field1.subfield")) || + col("struct_col.field1.subfield") === "value3")) + + deleteTestIndex(testIndex) + } } // Custom matcher to check if a SparkPlan uses FlintSparkSkippingFileIndex @@ -711,7 +731,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { MatchResult( hasExpectedFilter, - "FlintSparkSkippingFileIndex does not have expected filter", + s"FlintSparkSkippingFileIndex does not have expected filter: 
${fileIndex.indexFilter}", "FlintSparkSkippingFileIndex has expected filter") } }
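For context, a hedged sketch (not part of this patch) of the file-skipping filters the rewrite is expected to push into FlintSparkSkippingFileIndex for the two new IT queries, derived from the assertions above; the MinMax_<column>_0 / MinMax_<column>_1 names are the min/max output columns those assertions rely on.

    import org.apache.spark.sql.functions.{col, isnull}

    // WHERE struct_col.field2 = 456 on the min-max column rewrites to a range check
    // against the min/max values stored in the skipping index:
    val minMaxFilter =
      col("MinMax_struct_col.field2_0") <= 456 && col("MinMax_struct_col.field2_1") >= 456

    // WHERE struct_col.field1.subfield = 'value3' on the value-set column keeps the equality
    // but also accepts null, because the stored value set may be null once it exceeds the
    // configured maximum size:
    val valueSetFilter =
      isnull(col("struct_col.field1.subfield")) || col("struct_col.field1.subfield") === "value3"

Either filter is evaluated against the index data, so a source file is scanned only when its index entry cannot rule the predicate out.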