diff --git a/docs/index.md b/docs/index.md
index 7a1c44a0e..88c2bc5e6 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -124,6 +124,8 @@ High level API is dependent on query engine implementation. Please see Query Eng
 
 #### Skipping Index
 
+The default maximum size of the value set is 100. If a column in a file has more distinct values than this limit, the file's value set is stored as null. This trade-off prevents excessive memory consumption, at the cost of losing the ability to skip that file.
+
 ```sql
 CREATE SKIPPING INDEX [IF NOT EXISTS]
 ON
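To make the documented trade-off concrete, here is a minimal sketch (not part of the patch itself) of the capped aggregation that `ValueSetSkippingStrategy.getAggregators` builds later in this change; the `address` column name and the hard-coded local `limit` are illustrative assumptions.

```scala
import org.apache.spark.sql.functions.{collect_set, lit, size, when}

// Collect the distinct values of a hypothetical "address" column, evaluated per source file.
val valueSet = collect_set("address")

// Cap the set: if a file holds more distinct values than the limit
// (DEFAULT_VALUE_SET_SIZE_LIMIT, 100 by default), store null instead,
// so that file simply cannot be skipped rather than bloating the index.
val limit = 100
val cappedValueSet = when(size(valueSet) > limit, lit(null)).otherwise(valueSet)
```

Because the stored set may be null, the rewritten predicate on the index (further down in this diff) also has to accept null, i.e. `isnull(col) || col === value`, so that files whose value set overflowed are still scanned.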
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala
index d83af5df5..120ca8219 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala
@@ -74,7 +74,7 @@ case class FlintSparkSkippingIndex(
     // Wrap aggregate function with output column name
     val namedAggFuncs =
       (outputNames, aggFuncs).zipped.map { case (name, aggFunc) =>
-        new Column(aggFunc.toAggregateExpression().as(name))
+        new Column(aggFunc.as(name))
       }
 
     df.getOrElse(spark.read.table(tableName))
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala
index eeba48cfe..042c968ec 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala
@@ -10,7 +10,6 @@ import org.json4s.JsonAST.JString
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.SkippingKind
 
 import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
 
 /**
  * Skipping index strategy that defines skipping data structure building and reading logic.
@@ -42,7 +41,7 @@ trait FlintSparkSkippingStrategy {
    * @return
    *   aggregators that generate skipping data structure
    */
-  def getAggregators: Seq[AggregateFunction]
+  def getAggregators: Seq[Expression]
 
   /**
    * Rewrite a filtering condition on source table into a new predicate on index data based on
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/minmax/MinMaxSkippingStrategy.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/minmax/MinMaxSkippingStrategy.scala
index dc57830c7..f7745e7a8 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/minmax/MinMaxSkippingStrategy.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/minmax/MinMaxSkippingStrategy.scala
@@ -9,7 +9,7 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, SkippingKind}
 
 import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual, Literal}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, Max, Min}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Max, Min}
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.functions.col
@@ -29,8 +29,11 @@ case class MinMaxSkippingStrategy(
   override def outputSchema(): Map[String, String] =
     Map(minColName -> columnType, maxColName -> columnType)
 
-  override def getAggregators: Seq[AggregateFunction] =
-    Seq(Min(col(columnName).expr), Max(col(columnName).expr))
+  override def getAggregators: Seq[Expression] = {
+    Seq(
+      Min(col(columnName).expr).toAggregateExpression(),
+      Max(col(columnName).expr).toAggregateExpression())
+  }
 
   override def rewritePredicate(predicate: Expression): Option[Expression] =
     predicate match {
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategy.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategy.scala
index fe717e7ad..18fec0642 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategy.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategy.scala
@@ -9,7 +9,7 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{PARTITION, SkippingKind}
 
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, First}
+import org.apache.spark.sql.catalyst.expressions.aggregate.First
 import org.apache.spark.sql.functions.col
 
 /**
  *
@@ -25,8 +25,8 @@ case class PartitionSkippingStrategy(
     Map(columnName -> columnType)
   }
 
-  override def getAggregators: Seq[AggregateFunction] = {
-    Seq(First(col(columnName).expr, ignoreNulls = true))
+  override def getAggregators: Seq[Expression] = {
+    Seq(First(col(columnName).expr, ignoreNulls = true).toAggregateExpression())
   }
 
   override def rewritePredicate(predicate: Expression): Option[Expression] =
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategy.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategy.scala
index f462b70c1..ff2d53d44 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategy.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategy.scala
@@ -7,10 +7,10 @@ package org.opensearch.flint.spark.skipping.valueset
 
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{SkippingKind, VALUE_SET}
+import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy.DEFAULT_VALUE_SET_SIZE_LIMIT
 
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, CollectSet}
-import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.functions._
 
 /**
  * Skipping strategy based on unique column value set.
@@ -24,8 +24,14 @@ case class ValueSetSkippingStrategy(
   override def outputSchema(): Map[String, String] =
     Map(columnName -> columnType)
 
-  override def getAggregators: Seq[AggregateFunction] =
-    Seq(CollectSet(col(columnName).expr))
+  override def getAggregators: Seq[Expression] = {
+    val limit = DEFAULT_VALUE_SET_SIZE_LIMIT
+    val collectSet = collect_set(columnName)
+    val aggregator =
+      when(size(collectSet) > limit, lit(null))
+        .otherwise(collectSet)
+    Seq(aggregator.expr)
+  }
 
   override def rewritePredicate(predicate: Expression): Option[Expression] =
     /*
@@ -34,7 +40,16 @@ case class ValueSetSkippingStrategy(
      */
     predicate match {
       case EqualTo(AttributeReference(`columnName`, _, _, _), value: Literal) =>
-        Some((col(columnName) === value).expr)
+        // Value set may be null due to the maximum size limit restriction
+        Some((isnull(col(columnName)) || col(columnName) === value).expr)
       case _ => None
     }
 }
+
+object ValueSetSkippingStrategy {
+
+  /**
+   * Default limit for value set size collected. TODO: make this val once it's configurable
+   */
+  var DEFAULT_VALUE_SET_SIZE_LIMIT = 100
+}
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala
index 491b7811a..9760e8cd2 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala
@@ -58,7 +58,8 @@ class FlintSparkSkippingIndexSuite extends FlintSuite {
   test("can build index building job with unique ID column") {
     val indexCol = mock[FlintSparkSkippingStrategy]
     when(indexCol.outputSchema()).thenReturn(Map("name" -> "string"))
-    when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("name").expr)))
+    when(indexCol.getAggregators).thenReturn(
+      Seq(CollectSet(col("name").expr).toAggregateExpression()))
 
     val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
     val df = spark.createDataFrame(Seq(("hello", 20))).toDF("name", "age")
@@ -66,276 +67,170 @@
     indexDf.schema.fieldNames should contain only ("name", FILE_PATH_COLUMN, ID_COLUMN)
   }
 
-  test("can build index for boolean column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema()).thenReturn(Map("boolean_col" -> "boolean"))
-    when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("boolean_col").expr)))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "boolean_col": {
-         |    "type": "boolean"
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
-  }
-
-  test("can build index for string column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema()).thenReturn(Map("string_col" -> "string"))
-    when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("string_col").expr)))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "string_col": {
-         |    "type": "keyword"
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
-  }
-
-  // TODO: test for osType "text"
-
-  test("can build index for varchar column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema()).thenReturn(Map("varchar_col" -> "varchar(20)"))
-    when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("varchar_col").expr)))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "varchar_col": {
-         |    "type": "keyword"
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
-  }
-
-  test("can build index for char column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema()).thenReturn(Map("char_col" -> "char(20)"))
-    when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("char_col").expr)))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "char_col": {
-         |    "type": "keyword"
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
-  }
-
-  test("can build index for long column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema()).thenReturn(Map("long_col" -> "bigint"))
-    when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("long_col").expr)))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "long_col": {
-         |    "type": "long"
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
-  }
-
-  test("can build index for int column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema()).thenReturn(Map("int_col" -> "int"))
-    when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("int_col").expr)))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "int_col": {
-         |    "type": "integer"
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
-  }
-
-  test("can build index for short column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema()).thenReturn(Map("short_col" -> "smallint"))
-    when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("short_col").expr)))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "short_col": {
-         |    "type": "short"
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
-  }
-
-  test("can build index for byte column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema()).thenReturn(Map("byte_col" -> "tinyint"))
-    when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("byte_col").expr)))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "byte_col": {
-         |    "type": "byte"
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
-  }
-
-  test("can build index for double column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema()).thenReturn(Map("double_col" -> "double"))
-    when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("double_col").expr)))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "double_col": {
-         |    "type": "double"
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
-  }
-
-  test("can build index for float column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema()).thenReturn(Map("float_col" -> "float"))
-    when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("float_col").expr)))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "float_col": {
-         |    "type": "float"
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
-  }
-
-  test("can build index for timestamp column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema()).thenReturn(Map("timestamp_col" -> "timestamp"))
-    when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("timestamp_col").expr)))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "timestamp_col": {
-         |    "type": "date",
-         |    "format": "strict_date_optional_time_nanos"
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
-  }
-
-  test("can build index for date column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema()).thenReturn(Map("date_col" -> "date"))
-    when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("date_col").expr)))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "date_col": {
-         |    "type": "date",
-         |    "format": "strict_date"
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
-  }
-
-  test("can build index for struct column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema())
-      .thenReturn(Map("struct_col" -> "struct<subfield1:string,subfield2:int>"))
-    when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("struct_col").expr)))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "struct_col": {
-         |    "properties": {
-         |      "subfield1": {
-         |        "type": "keyword"
-         |      },
-         |      "subfield2": {
-         |        "type": "integer"
-         |      }
-         |    }
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
+  // Test index build for different column types
+  Seq(
+    (
+      "boolean_col",
+      "boolean",
+      """{
+        |  "boolean_col": {
+        |    "type": "boolean"
+        |  },
+        |  "file_path": {
+        |    "type": "keyword"
+        |  }
+        |}"""),
+    (
+      "string_col",
+      "string",
+      """{
+        |  "string_col": {
+        |    "type": "keyword"
+        |  },
+        |  "file_path": {
+        |    "type": "keyword"
+        |  }
+        |}"""),
+    (
+      "varchar_col",
+      "varchar(20)",
+      """{
+        |  "varchar_col": {
+        |    "type": "keyword"
+        |  },
+        |  "file_path": {
+        |    "type": "keyword"
+        |  }
+        |}"""),
+    (
+      "char_col",
+      "char(20)",
+      """{
+        |  "char_col": {
+        |    "type": "keyword"
+        |  },
+        |  "file_path": {
+        |    "type": "keyword"
+        |  }
+        |}"""),
+    (
+      "long_col",
+      "bigint",
+      """{
+        |  "long_col": {
+        |    "type": "long"
+        |  },
+        |  "file_path": {
+        |    "type": "keyword"
+        |  }
+        |}"""),
+    (
+      "int_col",
+      "int",
+      """{
+        |  "int_col": {
+        |    "type": "integer"
+        |  },
+        |  "file_path": {
+        |    "type": "keyword"
+        |  }
+        |}"""),
+    (
+      "short_col",
+      "smallint",
+      """{
+        |  "short_col": {
+        |    "type": "short"
+        |  },
+        |  "file_path": {
+        |    "type": "keyword"
+        |  }
+        |}"""),
+    (
+      "byte_col",
+      "tinyint",
+      """{
+        |  "byte_col": {
+        |    "type": "byte"
+        |  },
+        |  "file_path": {
+        |    "type": "keyword"
+        |  }
+        |}"""),
+    (
+      "double_col",
+      "double",
+      """{
+        |  "double_col": {
+        |    "type": "double"
+        |  },
+        |  "file_path": {
+        |    "type": "keyword"
+        |  }
+        |}"""),
+    (
+      "float_col",
+      "float",
+      """{
+        |  "float_col": {
+        |    "type": "float"
+        |  },
+        |  "file_path": {
+        |    "type": "keyword"
+        |  }
+        |}"""),
+    (
+      "timestamp_col",
+      "timestamp",
+      """{
+        |  "timestamp_col": {
+        |    "type": "date",
+        |    "format": "strict_date_optional_time_nanos"
+        |  },
+        |  "file_path": {
+        |    "type": "keyword"
+        |  }
+        |}"""),
+    (
+      "date_col",
+      "date",
+      """{
+        |  "date_col": {
+        |    "type": "date",
+        |    "format": "strict_date"
+        |  },
+        |  "file_path": {
+        |    "type": "keyword"
+        |  }
+        |}"""),
+    (
+      "struct_col",
+      "struct<subfield1:string,subfield2:int>",
+      """{
+        |  "struct_col": {
+        |    "properties": {
+        |      "subfield1": {
+        |        "type": "keyword"
+        |      },
+        |      "subfield2": {
+        |        "type": "integer"
+        |      }
+        |    }
+        |  },
+        |  "file_path": {
+        |    "type": "keyword"
+        |  }
+        |}""")).foreach { case (columnName, columnType, expectedSchema) =>
+    test(s"can build index for $columnType column") {
+      val indexCol = mock[FlintSparkSkippingStrategy]
+      when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
+      when(indexCol.outputSchema()).thenReturn(Map(columnName -> columnType))
+      when(indexCol.getAggregators).thenReturn(
+        Seq(CollectSet(col(columnName).expr).toAggregateExpression()))
+
+      val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
+      schemaShouldMatch(index.metadata(), expectedSchema.stripMargin)
+    }
   }
 
   test("should fail if get index name without full table name") {
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategySuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategySuite.scala
index 09de75234..bc81d9fd9 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategySuite.scala
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategySuite.scala
@@ -10,7 +10,7 @@ import org.scalatest.matchers.should.Matchers
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.expressions.{Abs, AttributeReference, EqualTo, Literal}
-import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.functions.{col, isnull}
 import org.apache.spark.sql.types.StringType
 
 class ValueSetSkippingStrategySuite
@@ -24,7 +24,8 @@ class ValueSetSkippingStrategySuite
   private val name = AttributeReference("name", StringType, nullable = false)()
 
   test("should rewrite EqualTo(, )") {
-    EqualTo(name, Literal("hello")) shouldRewriteTo (col("name") === "hello")
+    EqualTo(name, Literal("hello")) shouldRewriteTo
+      (isnull(col("name")) || col("name") === "hello")
   }
 
   test("should not rewrite predicate with other column") {
diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala
index 40cb5c201..9cb4affec 100644
--- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala
+++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala
@@ -12,6 +12,7 @@ import org.opensearch.flint.spark.FlintSpark.RefreshMode.{FULL, INCREMENTAL}
 import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingFileIndex
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName
+import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy
 import org.scalatest.matchers.{Matcher, MatchResult}
 import org.scalatest.matchers.must.Matchers._
 import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
@@ -20,7 +21,7 @@ import org.apache.spark.sql.{Column, Row}
 import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
 import org.apache.spark.sql.execution.datasources.HadoopFsRelation
 import org.apache.spark.sql.flint.config.FlintSparkConf._
-import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.functions.{col, isnull}
 
 class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
@@ -31,7 +32,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
 
   override def beforeAll(): Unit = {
     super.beforeAll()
-    createPartitionedTable(testTable)
+    createPartitionedMultiRowTable(testTable)
   }
 
   override def afterEach(): Unit = {
@@ -255,34 +256,57 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
       .create()
     flint.refreshIndex(testIndex, FULL)
 
+    // Assert index data
+    checkAnswer(
+      flint.queryIndex(testIndex).select("year", "month"),
+      Seq(Row(2023, 4), Row(2023, 5)))
+
+    // Assert query rewrite
     val query = sql(s"""
                        | SELECT name
                        | FROM $testTable
                        | WHERE year = 2023 AND month = 4
                        |""".stripMargin)
 
-    checkAnswer(query, Row("Hello"))
+    checkAnswer(query, Seq(Row("Hello"), Row("World")))
     query.queryExecution.executedPlan should
       useFlintSparkSkippingFileIndex(hasIndexFilter(col("year") === 2023 && col("month") === 4))
   }
 
   test("can build value set skipping index and rewrite applicable query") {
-    flint
-      .skippingIndex()
-      .onTable(testTable)
-      .addValueSet("address")
-      .create()
-    flint.refreshIndex(testIndex, FULL)
+    val defaultLimit = ValueSetSkippingStrategy.DEFAULT_VALUE_SET_SIZE_LIMIT
+    try {
+      ValueSetSkippingStrategy.DEFAULT_VALUE_SET_SIZE_LIMIT = 2
+      flint
+        .skippingIndex()
+        .onTable(testTable)
+        .addValueSet("address")
+        .create()
+      flint.refreshIndex(testIndex, FULL)
 
-    val query = sql(s"""
-                       | SELECT name
+      // Assert index data
+      checkAnswer(
+        flint.queryIndex(testIndex).select("address"),
+        Seq(
+          Row("""["Seattle","Portland"]"""),
+          Row(null) // Value set that exceeded the size limit is expected to be null
+        ))
+
+      // Assert query rewrite that works when the value set may be null
+      val query = sql(s"""
+                         | SELECT age
                        | FROM $testTable
                        | WHERE address = 'Portland'
                        |""".stripMargin)
 
-    checkAnswer(query, Row("World"))
-    query.queryExecution.executedPlan should
-      useFlintSparkSkippingFileIndex(hasIndexFilter(col("address") === "Portland"))
+      query.queryExecution.executedPlan should
+        useFlintSparkSkippingFileIndex(
+          hasIndexFilter(isnull(col("address")) || col("address") === "Portland"))
+      checkAnswer(query, Seq(Row(30), Row(50)))
+
+    } finally {
+      ValueSetSkippingStrategy.DEFAULT_VALUE_SET_SIZE_LIMIT = defaultLimit
+    }
   }
 
   test("can build min max skipping index and rewrite applicable query") {
@@ -293,16 +317,22 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
       .create()
     flint.refreshIndex(testIndex, FULL)
 
+    // Assert index data
+    checkAnswer(
+      flint.queryIndex(testIndex).select("MinMax_age_0", "MinMax_age_1"),
+      Seq(Row(20, 30), Row(40, 60)))
+
+    // Assert query rewrite
     val query = sql(s"""
                        | SELECT name
                        | FROM $testTable
-                       | WHERE age = 25
+                       | WHERE age = 30
                        |""".stripMargin)
 
     checkAnswer(query, Row("World"))
     query.queryExecution.executedPlan should
       useFlintSparkSkippingFileIndex(
-        hasIndexFilter(col("MinMax_age_0") <= 25 && col("MinMax_age_1") >= 25))
+        hasIndexFilter(col("MinMax_age_0") <= 30 && col("MinMax_age_1") >= 30))
   }
 
   test("should rewrite applicable query with table name without database specified") {
@@ -374,7 +404,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
                        | WHERE month = 4
                        |""".stripMargin)
 
-    checkAnswer(query, Seq(Row("Seattle"), Row("Vancouver")))
+    checkAnswer(query, Seq(Row("Seattle"), Row("Portland"), Row("Vancouver")))
   }
 }
@@ -617,7 +647,8 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
     checkAnswer(query, Row("sample varchar", paddedChar))
     query.queryExecution.executedPlan should
       useFlintSparkSkippingFileIndex(
-        hasIndexFilter(col("varchar_col") === "sample varchar" && col("char_col") === paddedChar))
+        hasIndexFilter((isnull(col("varchar_col")) || col("varchar_col") === "sample varchar") &&
+          (isnull(col("char_col")) || col("char_col") === paddedChar)))
 
     flint.deleteIndex(testIndex)
   }
diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala
index 29b8b95a6..211ddb57b 100644
--- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala
+++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala
@@ -85,6 +85,46 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit
           | """.stripMargin)
   }
 
+  protected def createPartitionedMultiRowTable(testTable: String): Unit = {
+    sql(s"""
+           | CREATE TABLE $testTable
+           | (
+           |   name STRING,
+           |   age INT,
+           |   address STRING
+           | )
+           | USING CSV
+           | OPTIONS (
+           |   header 'false',
+           |   delimiter '\t'
+           | )
+           | PARTITIONED BY (
+           |   year INT,
+           |   month INT
+           | )
+           |""".stripMargin)
+
+    // Use a COALESCE hint to write all rows into a single CSV file
+    sql(s"""
+           | INSERT INTO $testTable
+           | PARTITION (year=2023, month=4)
+           | SELECT /*+ COALESCE(1) */ *
+           | FROM VALUES
+           |   ('Hello', 20, 'Seattle'),
+           |   ('World', 30, 'Portland')
+           |""".stripMargin)
+
+    sql(s"""
+           | INSERT INTO $testTable
+           | PARTITION (year=2023, month=5)
+           | SELECT /*+ COALESCE(1) */ *
+           | FROM VALUES
+           |   ('Scala', 40, 'Seattle'),
+           |   ('Java', 50, 'Portland'),
+           |   ('Test', 60, 'Vancouver')
+           |""".stripMargin)
+  }
+
   protected def createTimeSeriesTable(testTable: String): Unit = {
     sql(s"""
       | CREATE TABLE $testTable