diff --git a/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4 b/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4 index cb2e14144..b223ee2c2 100644 --- a/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4 +++ b/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4 @@ -139,7 +139,12 @@ indexColTypeList ; indexColType - : identifier skipType=(PARTITION | VALUE_SET | MIN_MAX) + : identifier skipType=(PARTITION | MIN_MAX) + | identifier valueSetType + ; + +valueSetType + : VALUE_SET (LEFT_PAREN limit=INTEGER_VALUE RIGHT_PAREN)? ; indexName diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala index 6d680ae39..2da98e3b5 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala @@ -50,7 +50,10 @@ object FlintSparkIndexFactory { case PARTITION => PartitionSkippingStrategy(columnName = columnName, columnType = columnType) case VALUE_SET => - ValueSetSkippingStrategy(columnName = columnName, columnType = columnType) + ValueSetSkippingStrategy( + columnName = columnName, + columnType = columnType, + properties = getSkippingProperties(colInfo)) case MIN_MAX => MinMaxSkippingStrategy(columnName = columnName, columnType = columnType) case other => @@ -90,4 +93,13 @@ object FlintSparkIndexFactory { Some(value.asInstanceOf[String]) } } + + private def getSkippingProperties( + colInfo: java.util.Map[String, AnyRef]): Map[String, String] = { + colInfo + .get("properties") + .asInstanceOf[java.util.Map[String, String]] + .asScala + .toMap + } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/function/CollectSetLimit.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/function/CollectSetLimit.scala new file mode 100644 index 000000000..a1655e6bd --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/function/CollectSetLimit.scala @@ -0,0 +1,83 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.function + +import scala.collection.mutable + +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription} +import org.apache.spark.sql.catalyst.expressions.aggregate.{Collect, CollectSet, ImperativeAggregate} +import org.apache.spark.sql.types.DataType + +/** + * Collect set of unique values with maximum limit. + */ +@ExpressionDescription( + usage = + "_FUNC_(expr, limit) - Collects and returns a set of unique elements up to maximum limit.", + examples = """ + Examples: + > SELECT _FUNC_(col, 2) FROM VALUES (1), (2), (1) AS tab(col); + [1,2] + > SELECT _FUNC_(col, 1) FROM VALUES (1), (2), (1) AS tab(col); + [1] + """, + note = """ + The function is non-deterministic because the order of collected results depends + on the order of the rows which may be non-deterministic after a shuffle. 
+ """, + group = "agg_funcs", + since = "2.0.0") +case class CollectSetLimit( + child: Expression, + limit: Int, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends Collect[mutable.HashSet[Any]] { + + /** Delegate to collect set (because Scala prohibit case-to-case inheritance) */ + private val collectSet = CollectSet(child, mutableAggBufferOffset, inputAggBufferOffset) + + override def update(buffer: mutable.HashSet[Any], input: InternalRow): mutable.HashSet[Any] = { + if (buffer.size < limit) { + super.update(buffer, input) + } else { + buffer + } + } + + override protected def convertToBufferElement(value: Any): Any = + collectSet.convertToBufferElement(value) + + override protected val bufferElementType: DataType = collectSet.bufferElementType + + override def createAggregationBuffer(): mutable.HashSet[Any] = + collectSet.createAggregationBuffer() + + override def eval(buffer: mutable.HashSet[Any]): Any = collectSet.eval(buffer) + + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) + + override def withNewMutableAggBufferOffset( + newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def prettyName: String = "collect_set_limit" +} + +object CollectSetLimit { + + /** Function DSL */ + def collect_set_limit(columnName: String, limit: Int): Column = + new Column( + CollectSetLimit(new Column(columnName).expr, limit) + .toAggregateExpression()) +} diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala index d83af5df5..08d5037da 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala @@ -50,7 +50,8 @@ case class FlintSparkSkippingIndex( Map[String, AnyRef]( "kind" -> col.kind.toString, "columnName" -> col.columnName, - "columnType" -> col.columnType).asJava) + "columnType" -> col.columnType, + "properties" -> col.properties.asJava).asJava) .toArray val fieldTypes = @@ -74,7 +75,7 @@ case class FlintSparkSkippingIndex( // Wrap aggregate function with output column name val namedAggFuncs = (outputNames, aggFuncs).zipped.map { case (name, aggFunc) => - new Column(aggFunc.toAggregateExpression().as(name)) + new Column(aggFunc.as(name)) } df.getOrElse(spark.read.table(tableName)) @@ -155,14 +156,20 @@ object FlintSparkSkippingIndex { * * @param colName * indexed column name + * @param properties + * value set skipping properties * @return * index builder */ - def addValueSet(colName: String): Builder = { + def addValueSet(colName: String, properties: Map[String, String] = Map.empty): Builder = { require(tableName.nonEmpty, "table name cannot be empty") val col = findColumn(colName) - addIndexedColumn(ValueSetSkippingStrategy(columnName = col.name, columnType = col.dataType)) + addIndexedColumn( + ValueSetSkippingStrategy( + columnName = col.name, + columnType = col.dataType, + properties = properties)) this } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala 
b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala index eeba48cfe..dbf024942 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala @@ -32,6 +32,11 @@ trait FlintSparkSkippingStrategy { */ val columnType: String + /** + * Skipping algorithm properties. + */ + val properties: Map[String, String] = Map.empty + /** * @return * output schema mapping from Flint field name to Flint field type @@ -42,7 +47,7 @@ trait FlintSparkSkippingStrategy { * @return * aggregators that generate skipping data structure */ - def getAggregators: Seq[AggregateFunction] + def getAggregators: Seq[Expression] /** * Rewrite a filtering condition on source table into a new predicate on index data based on diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/minmax/MinMaxSkippingStrategy.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/minmax/MinMaxSkippingStrategy.scala index dc57830c7..73b63da20 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/minmax/MinMaxSkippingStrategy.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/minmax/MinMaxSkippingStrategy.scala @@ -9,7 +9,7 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, SkippingKind} import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual, Literal} -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, Max, Min} +import org.apache.spark.sql.catalyst.expressions.aggregate.{Max, Min} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.functions.col @@ -29,8 +29,10 @@ case class MinMaxSkippingStrategy( override def outputSchema(): Map[String, String] = Map(minColName -> columnType, maxColName -> columnType) - override def getAggregators: Seq[AggregateFunction] = - Seq(Min(col(columnName).expr), Max(col(columnName).expr)) + override def getAggregators: Seq[Expression] = + Seq( + Min(col(columnName).expr).toAggregateExpression(), + Max(col(columnName).expr).toAggregateExpression()) override def rewritePredicate(predicate: Expression): Option[Expression] = predicate match { diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategy.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategy.scala index fe717e7ad..18fec0642 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategy.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategy.scala @@ -9,7 +9,7 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{PARTITION, SkippingKind} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal} -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, First} +import 
org.apache.spark.sql.catalyst.expressions.aggregate.First import org.apache.spark.sql.functions.col /** @@ -25,8 +25,8 @@ case class PartitionSkippingStrategy( Map(columnName -> columnType) } - override def getAggregators: Seq[AggregateFunction] = { - Seq(First(col(columnName).expr, ignoreNulls = true)) + override def getAggregators: Seq[Expression] = { + Seq(First(col(columnName).expr, ignoreNulls = true).toAggregateExpression()) } override def rewritePredicate(predicate: Expression): Option[Expression] = diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategy.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategy.scala index f462b70c1..53b2db7e6 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategy.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategy.scala @@ -5,12 +5,13 @@ package org.opensearch.flint.spark.skipping.valueset +import org.opensearch.flint.spark.function.CollectSetLimit.collect_set_limit import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{SkippingKind, VALUE_SET} +import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy.DEFAULT_VALUE_SET_SIZE_LIMIT import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal} -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, CollectSet} -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions._ /** * Skipping strategy based on unique column value set. 
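The hunk below swaps the unconditional CollectSet aggregator for a capped one: the collected value set is nulled out as soon as it grows past the configured limit. As a rough, self-contained sketch of that aggregation shape (not code from this change — the local SparkSession and the (file_path, address) sample rows are invented for illustration):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{collect_set, lit, size, when}

// Standalone sketch of the capped value-set aggregation used by the strategy below.
// The SparkSession setup and sample rows are illustrative assumptions only.
object ValueSetLimitSketch extends App {
  val spark = SparkSession.builder()
    .master("local[1]")
    .appName("value-set-limit-sketch")
    .getOrCreate()
  import spark.implicits._

  val rows = Seq(
    ("file-1", "Portland"),
    ("file-2", "Seattle"),
    ("file-2", "Vancouver")).toDF("file_path", "address")

  val limit = 1
  val valueSet = collect_set($"address")

  // Null out the set once it exceeds the limit, mirroring
  // when(size(collectSetLimit) > limit, lit(null)).otherwise(collectSetLimit)
  rows
    .groupBy($"file_path")
    .agg(when(size(valueSet) > limit, lit(null)).otherwise(valueSet).as("address"))
    .show()
  // file-1 -> [Portland]
  // file-2 -> null (2 distinct values exceed the limit of 1)

  spark.stop()
}
```

A file whose column holds more distinct values than the limit ends up with a null value set in the index, so the index entry stays small regardless of the file's cardinality.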
@@ -18,14 +19,26 @@ import org.apache.spark.sql.functions.col
 case class ValueSetSkippingStrategy(
     override val kind: SkippingKind = VALUE_SET,
     override val columnName: String,
-    override val columnType: String)
+    override val columnType: String,
+    override val properties: Map[String, String] = Map.empty)
     extends FlintSparkSkippingStrategy {

   override def outputSchema(): Map[String, String] =
     Map(columnName -> columnType)

-  override def getAggregators: Seq[AggregateFunction] =
-    Seq(CollectSet(col(columnName).expr))
+  override def getAggregators: Seq[Expression] = {
+    val limit = getValueSetSizeLimit()
+    val aggregator =
+      if (limit == 0) {
+        collect_set(columnName)
+      } else {
+        // TODO: stop collecting early with collect_set_limit(columnName, limit + 1)
+        val collectSetLimit = collect_set(columnName)
+        when(size(collectSetLimit) > limit, lit(null))
+          .otherwise(collectSetLimit)
+      }
+    Seq(aggregator.expr)
+  }

   override def rewritePredicate(predicate: Expression): Option[Expression] =
     /*
@@ -37,4 +50,13 @@ case class ValueSetSkippingStrategy(
         Some((col(columnName) === value).expr)
       case _ => None
     }
+
+  private def getValueSetSizeLimit(): Int =
+    properties.get("limit").map(_.toInt).getOrElse(DEFAULT_VALUE_SET_SIZE_LIMIT)
+}
+
+object ValueSetSkippingStrategy {
+
+  /** Default limit for value set size collected */
+  val DEFAULT_VALUE_SET_SIZE_LIMIT = 100
 }
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/skipping/FlintSparkSkippingIndexAstBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/skipping/FlintSparkSkippingIndexAstBuilder.scala
index 2b0bb6c48..7cd0a28a7 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/skipping/FlintSparkSkippingIndexAstBuilder.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/skipping/FlintSparkSkippingIndexAstBuilder.scala
@@ -10,7 +10,7 @@ import org.opensearch.flint.spark.FlintSpark
 import org.opensearch.flint.spark.FlintSpark.RefreshMode
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind
-import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, PARTITION, VALUE_SET}
+import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, PARTITION}
 import org.opensearch.flint.spark.sql.{FlintSparkSqlCommand, FlintSparkSqlExtensionsVisitor, SparkSqlAstBuilder}
 import org.opensearch.flint.spark.sql.FlintSparkSqlAstBuilder.{getFullTableName, getSqlText}
 import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser._
@@ -42,11 +42,19 @@ trait FlintSparkSkippingIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[A
     ctx.indexColTypeList().indexColType().forEach { colTypeCtx =>
       val colName = colTypeCtx.identifier().getText
-      val skipType = SkippingKind.withName(colTypeCtx.skipType.getText)
-      skipType match {
-        case PARTITION => indexBuilder.addPartitions(colName)
-        case VALUE_SET => indexBuilder.addValueSet(colName)
-        case MIN_MAX => indexBuilder.addMinMax(colName)
+      if (colTypeCtx.skipType == null) {
+        if (colTypeCtx.valueSetType().limit == null) {
+          indexBuilder.addValueSet(colName)
+        } else {
+          indexBuilder
+            .addValueSet(colName, Map("limit" -> colTypeCtx.valueSetType().limit.getText))
+        }
+      } else {
+        val skipType = SkippingKind.withName(colTypeCtx.skipType.getText)
+        skipType match {
+          case PARTITION => indexBuilder.addPartitions(colName)
+          case MIN_MAX =>
indexBuilder.addMinMax(colName) + } } } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/function/CollectSetLimitSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/function/CollectSetLimitSuite.scala new file mode 100644 index 000000000..e2415bc9a --- /dev/null +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/function/CollectSetLimitSuite.scala @@ -0,0 +1,49 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.function + +import org.mockito.ArgumentMatchers.any +import org.mockito.Mockito.when +import org.scalatest.matchers.should.Matchers +import org.scalatestplus.mockito.MockitoSugar.mock + +import org.apache.spark.FlintSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} + +class CollectSetLimitSuite extends FlintSuite with Matchers { + + var expression: Expression = _ + + override def beforeEach(): Unit = { + super.beforeEach() + + expression = mock[Expression] + when(expression.eval(any[InternalRow])).thenAnswer { invocation => + val row = invocation.getArgument[InternalRow](0) + val firstValue = row.getInt(0) + Literal(firstValue) + } + } + + test("should collect unique elements") { + val collectSetLimit = CollectSetLimit(expression, limit = 2) + var buffer = collectSetLimit.createAggregationBuffer() + + Seq(InternalRow(1), InternalRow(1)).foreach(row => + buffer = collectSetLimit.update(buffer, row)) + assert(buffer.size == 1) + } + + test("should collect unique elements up to the limit") { + val collectSetLimit = CollectSetLimit(expression, limit = 2) + var buffer = collectSetLimit.createAggregationBuffer() + + Seq(InternalRow(1), InternalRow(2), InternalRow(3)).foreach(row => + buffer = collectSetLimit.update(buffer, row)) + assert(buffer.size == 2) + } +} diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala index 491b7811a..9760e8cd2 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala @@ -58,7 +58,8 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { test("can build index building job with unique ID column") { val indexCol = mock[FlintSparkSkippingStrategy] when(indexCol.outputSchema()).thenReturn(Map("name" -> "string")) - when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("name").expr))) + when(indexCol.getAggregators).thenReturn( + Seq(CollectSet(col("name").expr).toAggregateExpression())) val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) val df = spark.createDataFrame(Seq(("hello", 20))).toDF("name", "age") @@ -66,276 +67,170 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { indexDf.schema.fieldNames should contain only ("name", FILE_PATH_COLUMN, ID_COLUMN) } - test("can build index for boolean column") { - val indexCol = mock[FlintSparkSkippingStrategy] - when(indexCol.kind).thenReturn(SkippingKind.PARTITION) - when(indexCol.outputSchema()).thenReturn(Map("boolean_col" -> "boolean")) - when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("boolean_col").expr))) - - val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) - 
schemaShouldMatch( - index.metadata(), - s"""{ - | "boolean_col": { - | "type": "boolean" - | }, - | "file_path": { - | "type": "keyword" - | } - |} - |""".stripMargin) - } - - test("can build index for string column") { - val indexCol = mock[FlintSparkSkippingStrategy] - when(indexCol.kind).thenReturn(SkippingKind.PARTITION) - when(indexCol.outputSchema()).thenReturn(Map("string_col" -> "string")) - when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("string_col").expr))) - - val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) - schemaShouldMatch( - index.metadata(), - s"""{ - | "string_col": { - | "type": "keyword" - | }, - | "file_path": { - | "type": "keyword" - | } - |} - |""".stripMargin) - } - - // TODO: test for osType "text" - - test("can build index for varchar column") { - val indexCol = mock[FlintSparkSkippingStrategy] - when(indexCol.kind).thenReturn(SkippingKind.PARTITION) - when(indexCol.outputSchema()).thenReturn(Map("varchar_col" -> "varchar(20)")) - when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("varchar_col").expr))) - - val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) - schemaShouldMatch( - index.metadata(), - s"""{ - | "varchar_col": { - | "type": "keyword" - | }, - | "file_path": { - | "type": "keyword" - | } - |} - |""".stripMargin) - } - - test("can build index for char column") { - val indexCol = mock[FlintSparkSkippingStrategy] - when(indexCol.kind).thenReturn(SkippingKind.PARTITION) - when(indexCol.outputSchema()).thenReturn(Map("char_col" -> "char(20)")) - when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("char_col").expr))) - - val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) - schemaShouldMatch( - index.metadata(), - s"""{ - | "char_col": { - | "type": "keyword" - | }, - | "file_path": { - | "type": "keyword" - | } - |} - |""".stripMargin) - } - - test("can build index for long column") { - val indexCol = mock[FlintSparkSkippingStrategy] - when(indexCol.kind).thenReturn(SkippingKind.PARTITION) - when(indexCol.outputSchema()).thenReturn(Map("long_col" -> "bigint")) - when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("long_col").expr))) - - val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) - schemaShouldMatch( - index.metadata(), - s"""{ - | "long_col": { - | "type": "long" - | }, - | "file_path": { - | "type": "keyword" - | } - |} - |""".stripMargin) - } - - test("can build index for int column") { - val indexCol = mock[FlintSparkSkippingStrategy] - when(indexCol.kind).thenReturn(SkippingKind.PARTITION) - when(indexCol.outputSchema()).thenReturn(Map("int_col" -> "int")) - when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("int_col").expr))) - - val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) - schemaShouldMatch( - index.metadata(), - s"""{ - | "int_col": { - | "type": "integer" - | }, - | "file_path": { - | "type": "keyword" - | } - |} - |""".stripMargin) - } - - test("can build index for short column") { - val indexCol = mock[FlintSparkSkippingStrategy] - when(indexCol.kind).thenReturn(SkippingKind.PARTITION) - when(indexCol.outputSchema()).thenReturn(Map("short_col" -> "smallint")) - when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("short_col").expr))) - - val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) - schemaShouldMatch( - index.metadata(), - s"""{ - | "short_col": { - | "type": "short" - | }, - | "file_path": { - | "type": "keyword" - | } - |} - |""".stripMargin) - } - - test("can build index for 
byte column") { - val indexCol = mock[FlintSparkSkippingStrategy] - when(indexCol.kind).thenReturn(SkippingKind.PARTITION) - when(indexCol.outputSchema()).thenReturn(Map("byte_col" -> "tinyint")) - when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("byte_col").expr))) - - val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) - schemaShouldMatch( - index.metadata(), - s"""{ - | "byte_col": { - | "type": "byte" - | }, - | "file_path": { - | "type": "keyword" - | } - |} - |""".stripMargin) - } - - test("can build index for double column") { - val indexCol = mock[FlintSparkSkippingStrategy] - when(indexCol.kind).thenReturn(SkippingKind.PARTITION) - when(indexCol.outputSchema()).thenReturn(Map("double_col" -> "double")) - when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("double_col").expr))) - - val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) - schemaShouldMatch( - index.metadata(), - s"""{ - | "double_col": { - | "type": "double" - | }, - | "file_path": { - | "type": "keyword" - | } - |} - |""".stripMargin) - } - - test("can build index for float column") { - val indexCol = mock[FlintSparkSkippingStrategy] - when(indexCol.kind).thenReturn(SkippingKind.PARTITION) - when(indexCol.outputSchema()).thenReturn(Map("float_col" -> "float")) - when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("float_col").expr))) - - val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) - schemaShouldMatch( - index.metadata(), - s"""{ - | "float_col": { - | "type": "float" - | }, - | "file_path": { - | "type": "keyword" - | } - |} - |""".stripMargin) - } - - test("can build index for timestamp column") { - val indexCol = mock[FlintSparkSkippingStrategy] - when(indexCol.kind).thenReturn(SkippingKind.PARTITION) - when(indexCol.outputSchema()).thenReturn(Map("timestamp_col" -> "timestamp")) - when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("timestamp_col").expr))) - - val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) - schemaShouldMatch( - index.metadata(), - s"""{ - | "timestamp_col": { - | "type": "date", - | "format": "strict_date_optional_time_nanos" - | }, - | "file_path": { - | "type": "keyword" - | } - |} - |""".stripMargin) - } - - test("can build index for date column") { - val indexCol = mock[FlintSparkSkippingStrategy] - when(indexCol.kind).thenReturn(SkippingKind.PARTITION) - when(indexCol.outputSchema()).thenReturn(Map("date_col" -> "date")) - when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("date_col").expr))) - - val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) - schemaShouldMatch( - index.metadata(), - s"""{ - | "date_col": { - | "type": "date", - | "format": "strict_date" - | }, - | "file_path": { - | "type": "keyword" - | } - |} - |""".stripMargin) - } - - test("can build index for struct column") { - val indexCol = mock[FlintSparkSkippingStrategy] - when(indexCol.kind).thenReturn(SkippingKind.PARTITION) - when(indexCol.outputSchema()) - .thenReturn(Map("struct_col" -> "struct")) - when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("struct_col").expr))) - - val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) - schemaShouldMatch( - index.metadata(), - s"""{ - | "struct_col": { - | "properties": { - | "subfield1": { - | "type": "keyword" - | }, - | "subfield2": { - | "type": "integer" - | } - | } - | }, - | "file_path": { - | "type": "keyword" - | } - |} - |""".stripMargin) + // Test index build for different column type + Seq( + ( + "boolean_col", + 
"boolean", + """{ + | "boolean_col": { + | "type": "boolean" + | }, + | "file_path": { + | "type": "keyword" + | } + |}"""), + ( + "string_col", + "string", + """{ + | "string_col": { + | "type": "keyword" + | }, + | "file_path": { + | "type": "keyword" + | } + |}"""), + ( + "varchar_col", + "varchar(20)", + """{ + | "varchar_col": { + | "type": "keyword" + | }, + | "file_path": { + | "type": "keyword" + | } + |}"""), + ( + "char_col", + "char(20)", + """{ + | "char_col": { + | "type": "keyword" + | }, + | "file_path": { + | "type": "keyword" + | } + |}"""), + ( + "long_col", + "bigint", + """{ + | "long_col": { + | "type": "long" + | }, + | "file_path": { + | "type": "keyword" + | } + |}"""), + ( + "int_col", + "int", + """{ + | "int_col": { + | "type": "integer" + | }, + | "file_path": { + | "type": "keyword" + | } + |}"""), + ( + "short_col", + "smallint", + """{ + | "short_col": { + | "type": "short" + | }, + | "file_path": { + | "type": "keyword" + | } + |}"""), + ( + "byte_col", + "tinyint", + """{ + | "byte_col": { + | "type": "byte" + | }, + | "file_path": { + | "type": "keyword" + | } + |}"""), + ( + "double_col", + "double", + """{ + | "double_col": { + | "type": "double" + | }, + | "file_path": { + | "type": "keyword" + | } + |}"""), + ( + "float_col", + "float", + """{ + | "float_col": { + | "type": "float" + | }, + | "file_path": { + | "type": "keyword" + | } + |}"""), + ( + "timestamp_col", + "timestamp", + """{ + | "timestamp_col": { + | "type": "date", + | "format": "strict_date_optional_time_nanos" + | }, + | "file_path": { + | "type": "keyword" + | } + |}"""), + ( + "date_col", + "date", + """{ + | "date_col": { + | "type": "date", + | "format": "strict_date" + | }, + | "file_path": { + | "type": "keyword" + | } + |}"""), + ( + "struct_col", + "struct", + """{ + | "struct_col": { + | "properties": { + | "subfield1": { + | "type": "keyword" + | }, + | "subfield2": { + | "type": "integer" + | } + | } + | }, + | "file_path": { + | "type": "keyword" + | } + |}""")).foreach { case (columnName, columnType, expectedSchema) => + test(s"can build index for $columnType column") { + val indexCol = mock[FlintSparkSkippingStrategy] + when(indexCol.kind).thenReturn(SkippingKind.PARTITION) + when(indexCol.outputSchema()).thenReturn(Map(columnName -> columnType)) + when(indexCol.getAggregators).thenReturn( + Seq(CollectSet(col(columnName).expr).toAggregateExpression())) + + val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) + schemaShouldMatch(index.metadata(), expectedSchema.stripMargin) + } } test("should fail if get index name without full table name") { diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala index 40cb5c201..4956483b2 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala @@ -285,6 +285,24 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { useFlintSparkSkippingFileIndex(hasIndexFilter(col("address") === "Portland")) } + test("can build value set skipping index with limit") { + // Add one more row for Seattle + sql(s""" + | INSERT INTO $testTable + | PARTITION (year=2023, month=4) + | VALUES ('Hello', 30, 'Seattle') + | """.stripMargin) + + flint + .skippingIndex() + .onTable(testTable) + .addValueSet("address", Map("limit" -> "1")) + .create() + 
flint.refreshIndex(testIndex, FULL) + + checkAnswer(flint.queryIndex(testIndex).select("address"), Seq(Row(Seq("Portland")), Row())) + } + test("can build min max skipping index and rewrite applicable query") { flint .skippingIndex() diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala index 21de15de7..cf29e0298 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala @@ -14,7 +14,7 @@ import org.json4s.native.Serialization import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.storage.FlintOpenSearchClient import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName -import org.scalatest.matchers.must.Matchers.{defined, have} +import org.scalatest.matchers.must.Matchers.defined import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, the} import org.apache.spark.sql.Row @@ -144,6 +144,24 @@ class FlintSparkSkippingIndexSqlITSuite extends FlintSparkSuite { indexData.count() shouldBe 2 } + test("create skipping index with value set limit") { + sql(s""" + | CREATE SKIPPING INDEX ON $testTable + | ( + | name VALUE_SET(10) + | ) + | WITH (auto_refresh = true) + | """.stripMargin) + + // Wait for streaming job complete current micro batch + val job = spark.streams.active.find(_.name == testIndex) + awaitStreamingComplete(job.get.id.toString) + + val indexData = spark.read.format(FLINT_DATASOURCE).load(testIndex) + flint.describeIndex(testIndex) shouldBe defined + indexData.count() shouldBe 2 + } + test("create skipping index if not exists") { sql(s""" | CREATE SKIPPING INDEX
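Pulling the pieces together, the new VALUE_SET(limit) surface can be exercised either through SQL or through the Scala builder DSL. The sketch below mirrors the suites above; the table and column names are placeholders, and `spark`/`flint` are assumed to be wired up the same way as in those suites:

```scala
import org.apache.spark.sql.SparkSession
import org.opensearch.flint.spark.FlintSpark
import org.opensearch.flint.spark.FlintSpark.RefreshMode.FULL
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName

object ValueSetLimitUsageSketch {

  /** SQL surface: the optional limit is parsed by the new VALUE_SET(limit) rule. */
  def createWithSql(spark: SparkSession, table: String): Unit = {
    spark.sql(s"""
         | CREATE SKIPPING INDEX ON $table
         | (
         |   name VALUE_SET(10)
         | )
         | WITH (auto_refresh = true)
         | """.stripMargin)
  }

  /** Scala DSL: the limit travels as the "limit" property on the indexed column. */
  def createWithDsl(flint: FlintSpark, table: String): Unit = {
    flint
      .skippingIndex()
      .onTable(table)
      .addValueSet("address", Map("limit" -> "10"))
      .create()
    // Full refresh by derived index name, as in the integration test above.
    flint.refreshIndex(getSkippingIndexName(table), FULL)
  }
}
```

The two paths are alternatives for the same table: the SQL route parses the optional limit via the new valueSetType grammar rule, while the DSL route passes it explicitly as the "limit" property consumed by ValueSetSkippingStrategy.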