diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala
index aa9d5e420..2da98e3b5 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala
@@ -53,7 +53,7 @@ object FlintSparkIndexFactory {
         ValueSetSkippingStrategy(
           columnName = columnName,
           columnType = columnType,
-          limit = colInfo.get("limit").asInstanceOf[Int])
+          properties = getSkippingProperties(colInfo))
       case MIN_MAX =>
         MinMaxSkippingStrategy(columnName = columnName, columnType = columnType)
       case other =>
@@ -93,4 +93,13 @@ object FlintSparkIndexFactory {
       Some(value.asInstanceOf[String])
     }
   }
+
+  private def getSkippingProperties(
+      colInfo: java.util.Map[String, AnyRef]): Map[String, String] = {
+    // Guard against metadata written before indexed columns carried "properties"
+    Option(colInfo.get("properties"))
+      .map(_.asInstanceOf[java.util.Map[String, String]])
+      .map(_.asScala.toMap)
+      .getOrElse(Map.empty)
+  }
 }
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala
index add890695..08d5037da 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala
@@ -50,7 +50,8 @@ case class FlintSparkSkippingIndex(
         Map[String, AnyRef](
           "kind" -> col.kind.toString,
           "columnName" -> col.columnName,
-          "columnType" -> col.columnType).asJava)
+          "columnType" -> col.columnType,
+          "properties" -> col.properties.asJava).asJava)
       .toArray
 
     val fieldTypes =
@@ -155,31 +156,20 @@ object FlintSparkSkippingIndex {
      *
      * @param colName
      *   indexed column name
+     * @param properties
+     *   value set skipping properties
      * @return
      *   index builder
      */
-    def addValueSet(colName: String): Builder = {
-      require(tableName.nonEmpty, "table name cannot be empty")
-
-      val col = findColumn(colName)
-      addIndexedColumn(ValueSetSkippingStrategy(columnName = col.name, columnType = col.dataType))
-      this
-    }
-
-    /**
-     * Add value set skipping indexed column.
-     *
-     * @param colName
-     *   indexed column name
-     * @return
-     *   index builder
-     */
-    def addValueSet(colName: String, limit: Int): Builder = {
+    def addValueSet(colName: String, properties: Map[String, String] = Map.empty): Builder = {
       require(tableName.nonEmpty, "table name cannot be empty")
 
       val col = findColumn(colName)
       addIndexedColumn(
-        ValueSetSkippingStrategy(columnName = col.name, columnType = col.dataType, limit = limit))
+        ValueSetSkippingStrategy(
+          columnName = col.name,
+          columnType = col.dataType,
+          properties = properties))
       this
     }
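Note: with the two addValueSet overloads collapsed into one, callers now pass the limit through the properties map and get Map.empty (and thus the default limit) when they omit it. A minimal usage sketch against the new builder API (the table and column names here are made up; `flint` is a FlintSpark instance as in the integration tests):

    flint
      .skippingIndex()
      .onTable("spark_catalog.default.alb_logs")
      .addValueSet("client_ip", Map("limit" -> "1000"))
      .create()
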
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala
index dbb8250e7..dbf024942 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala
@@ -32,6 +32,11 @@ trait FlintSparkSkippingStrategy {
    */
   val columnType: String
 
+  /**
+   * Skipping algorithm properties.
+   */
+  val properties: Map[String, String] = Map.empty
+
   /**
    * @return
    *   output schema mapping from Flint field name to Flint field type
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategy.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategy.scala
index d780cda06..53b2db7e6 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategy.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategy.scala
@@ -8,6 +8,6 @@ package org.opensearch.flint.spark.skipping.valueset
 
-import org.opensearch.flint.spark.function.CollectSetLimit.collect_set_limit
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{SkippingKind, VALUE_SET}
+import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy.DEFAULT_VALUE_SET_SIZE_LIMIT
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal}
 import org.apache.spark.sql.functions._
@@ -19,19 +19,22 @@ case class ValueSetSkippingStrategy(
     override val kind: SkippingKind = VALUE_SET,
     override val columnName: String,
     override val columnType: String,
-    limit: Int = 100)
+    override val properties: Map[String, String] = Map.empty)
   extends FlintSparkSkippingStrategy {
 
   override def outputSchema(): Map[String, String] =
     Map(columnName -> columnType)
 
   override def getAggregators: Seq[Expression] = {
+    val limit = getValueSetSizeLimit()
     val aggregator =
       if (limit == 0) {
         collect_set(columnName)
       } else {
-        val collectSetLimit = collect_set_limit(columnName, limit + 1)
-        when(size(collectSetLimit) === limit + 1, lit(null))
+        // Collect the full value set, then null it out once it exceeds the limit
+        // so that the file can no longer be skipped on this column
+        val collectSetLimit = collect_set(columnName)
+        when(size(collectSetLimit) > limit, lit(null))
           .otherwise(collectSetLimit)
       }
     Seq(aggregator.expr)
@@ -47,4 +50,13 @@ case class ValueSetSkippingStrategy(
       Some((col(columnName) === value).expr)
     case _ => None
   }
+
+  private def getValueSetSizeLimit(): Int =
+    properties.get("limit").map(_.toInt).getOrElse(DEFAULT_VALUE_SET_SIZE_LIMIT)
+}
+
+object ValueSetSkippingStrategy {
+
+  /** Default limit for value set size collected */
+  val DEFAULT_VALUE_SET_SIZE_LIMIT = 100
 }
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/skipping/FlintSparkSkippingIndexAstBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/skipping/FlintSparkSkippingIndexAstBuilder.scala
index fa922d713..7cd0a28a7 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/skipping/FlintSparkSkippingIndexAstBuilder.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/skipping/FlintSparkSkippingIndexAstBuilder.scala
@@ -46,7 +46,8 @@ trait FlintSparkSkippingIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[AnyRef] {
           if (colTypeCtx.valueSetType().limit == null) {
             indexBuilder.addValueSet(colName)
           } else {
-            indexBuilder.addValueSet(colName, colTypeCtx.valueSetType().limit.getText.toInt)
+            indexBuilder
+              .addValueSet(colName, Map("limit" -> colTypeCtx.valueSetType().limit.getText))
           }
         } else {
           val skipType = SkippingKind.withName(colTypeCtx.skipType.getText)
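Note: the rewritten getAggregators collects the full value set and nulls it out once it exceeds the limit, rather than capping collection with collect_set_limit. A self-contained sketch of that semantics using only stock Spark functions (the object name and sample data are mine, not from this patch):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.{collect_set, lit, size, when}

    object ValueSetLimitDemo extends App {
      val spark = SparkSession.builder().master("local[1]").appName("demo").getOrCreate()
      import spark.implicits._

      val limit = 2
      val df = Seq("Seattle", "Seattle", "Portland", "Vancouver").toDF("address")

      // Same shape as getAggregators: a null set means "too many distinct values to skip on"
      val set = collect_set($"address")
      val valueSet = when(size(set) > limit, lit(null)).otherwise(set)

      // Prints a single null row: 3 distinct values exceed the limit of 2
      df.agg(valueSet.as("address_set")).show(false)
    }

The when/otherwise guard runs after the full set is collected, so memory during aggregation is bounded by the number of distinct values per file rather than by the limit.
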
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala
index 9ddc2ac46..9760e8cd2 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala
@@ -67,289 +67,170 @@ class FlintSparkSkippingIndexSuite extends FlintSuite {
     indexDf.schema.fieldNames should contain only ("name", FILE_PATH_COLUMN, ID_COLUMN)
   }
 
-  test("can build index for boolean column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema()).thenReturn(Map("boolean_col" -> "boolean"))
-    when(indexCol.getAggregators).thenReturn(
-      Seq(CollectSet(col("boolean_col").expr).toAggregateExpression()))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "boolean_col": {
-         |    "type": "boolean"
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
-  }
-
-  test("can build index for string column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema()).thenReturn(Map("string_col" -> "string"))
-    when(indexCol.getAggregators).thenReturn(
-      Seq(CollectSet(col("string_col").expr).toAggregateExpression()))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "string_col": {
-         |    "type": "keyword"
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
-  }
-
-  // TODO: test for osType "text"
-
-  test("can build index for varchar column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema()).thenReturn(Map("varchar_col" -> "varchar(20)"))
-    when(indexCol.getAggregators).thenReturn(
-      Seq(CollectSet(col("varchar_col").expr).toAggregateExpression()))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "varchar_col": {
-         |    "type": "keyword"
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
-  }
-
-  test("can build index for char column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema()).thenReturn(Map("char_col" -> "char(20)"))
-    when(indexCol.getAggregators).thenReturn(
-      Seq(CollectSet(col("char_col").expr).toAggregateExpression()))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "char_col": {
-         |    "type": "keyword"
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
-  }
-
-  test("can build index for long column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema()).thenReturn(Map("long_col" -> "bigint"))
-    when(indexCol.getAggregators).thenReturn(
-      Seq(CollectSet(col("long_col").expr).toAggregateExpression()))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "long_col": {
-         |    "type": "long"
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
-  }
-
-  test("can build index for int column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema()).thenReturn(Map("int_col" -> "int"))
-    when(indexCol.getAggregators).thenReturn(
-      Seq(CollectSet(col("int_col").expr).toAggregateExpression()))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "int_col": {
-         |    "type": "integer"
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
-  }
-
-  test("can build index for short column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema()).thenReturn(Map("short_col" -> "smallint"))
-    when(indexCol.getAggregators).thenReturn(
-      Seq(CollectSet(col("short_col").expr).toAggregateExpression()))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "short_col": {
-         |    "type": "short"
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
-  }
-
-  test("can build index for byte column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema()).thenReturn(Map("byte_col" -> "tinyint"))
-    when(indexCol.getAggregators).thenReturn(
-      Seq(CollectSet(col("byte_col").expr).toAggregateExpression()))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "byte_col": {
-         |    "type": "byte"
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
-  }
-
-  test("can build index for double column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema()).thenReturn(Map("double_col" -> "double"))
-    when(indexCol.getAggregators).thenReturn(
-      Seq(CollectSet(col("double_col").expr).toAggregateExpression()))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "double_col": {
-         |    "type": "double"
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
-  }
-
-  test("can build index for float column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema()).thenReturn(Map("float_col" -> "float"))
-    when(indexCol.getAggregators).thenReturn(
-      Seq(CollectSet(col("float_col").expr).toAggregateExpression()))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "float_col": {
-         |    "type": "float"
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
-  }
-
-  test("can build index for timestamp column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema()).thenReturn(Map("timestamp_col" -> "timestamp"))
-    when(indexCol.getAggregators).thenReturn(
-      Seq(CollectSet(col("timestamp_col").expr).toAggregateExpression()))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "timestamp_col": {
-         |    "type": "date",
-         |    "format": "strict_date_optional_time_nanos"
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
-  }
-
-  test("can build index for date column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema()).thenReturn(Map("date_col" -> "date"))
-    when(indexCol.getAggregators).thenReturn(
-      Seq(CollectSet(col("date_col").expr).toAggregateExpression()))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "date_col": {
-         |    "type": "date",
-         |    "format": "strict_date"
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
-  }
-
-  test("can build index for struct column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema())
-      .thenReturn(Map("struct_col" -> "struct<subfield1:string,subfield2:int>"))
-    when(indexCol.getAggregators).thenReturn(
-      Seq(CollectSet(col("struct_col").expr).toAggregateExpression()))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "struct_col": {
-         |    "properties": {
-         |      "subfield1": {
-         |        "type": "keyword"
-         |      },
-         |      "subfield2": {
-         |        "type": "integer"
-         |      }
-         |    }
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
+  // Test index build for different column type
+  Seq(
+    (
+      "boolean_col",
+      "boolean",
+      """{
+        |  "boolean_col": {
+        |    "type": "boolean"
+        |  },
+        |  "file_path": {
+        |    "type": "keyword"
+        |  }
+        |}"""),
+    (
+      "string_col",
+      "string",
+      """{
+        |  "string_col": {
+        |    "type": "keyword"
+        |  },
+        |  "file_path": {
+        |    "type": "keyword"
+        |  }
+        |}"""),
+    (
+      "varchar_col",
+      "varchar(20)",
+      """{
+        |  "varchar_col": {
+        |    "type": "keyword"
+        |  },
+        |  "file_path": {
+        |    "type": "keyword"
+        |  }
+        |}"""),
+    (
+      "char_col",
+      "char(20)",
+      """{
+        |  "char_col": {
+        |    "type": "keyword"
+        |  },
+        |  "file_path": {
+        |    "type": "keyword"
+        |  }
+        |}"""),
+    (
+      "long_col",
+      "bigint",
+      """{
+        |  "long_col": {
+        |    "type": "long"
+        |  },
+        |  "file_path": {
+        |    "type": "keyword"
+        |  }
+        |}"""),
+    (
+      "int_col",
+      "int",
+      """{
+        |  "int_col": {
+        |    "type": "integer"
+        |  },
+        |  "file_path": {
+        |    "type": "keyword"
+        |  }
+        |}"""),
+    (
+      "short_col",
+      "smallint",
+      """{
+        |  "short_col": {
+        |    "type": "short"
+        |  },
+        |  "file_path": {
+        |    "type": "keyword"
+        |  }
+        |}"""),
+    (
+      "byte_col",
+      "tinyint",
+      """{
+        |  "byte_col": {
+        |    "type": "byte"
+        |  },
+        |  "file_path": {
+        |    "type": "keyword"
+        |  }
+        |}"""),
+    (
+      "double_col",
+      "double",
+      """{
+        |  "double_col": {
+        |    "type": "double"
+        |  },
+        |  "file_path": {
+        |    "type": "keyword"
+        |  }
+        |}"""),
+    (
+      "float_col",
+      "float",
+      """{
+        |  "float_col": {
+        |    "type": "float"
+        |  },
+        |  "file_path": {
+        |    "type": "keyword"
+        |  }
+        |}"""),
+    (
+      "timestamp_col",
+      "timestamp",
+      """{
+        |  "timestamp_col": {
+        |    "type": "date",
+        |    "format": "strict_date_optional_time_nanos"
+        |  },
+        |  "file_path": {
+        |    "type": "keyword"
+        |  }
+        |}"""),
+    (
+      "date_col",
+      "date",
+      """{
+        |  "date_col": {
+        |    "type": "date",
+        |    "format": "strict_date"
+        |  },
+        |  "file_path": {
+        |    "type": "keyword"
+        |  }
+        |}"""),
+    (
+      "struct_col",
+      "struct<subfield1:string,subfield2:int>",
+      """{
+        |  "struct_col": {
+        |    "properties": {
+        |      "subfield1": {
+        |        "type": "keyword"
+        |      },
+        |      "subfield2": {
+        |        "type": "integer"
+        |      }
+        |    }
+        |  },
+        |  "file_path": {
+        |    "type": "keyword"
+        |  }
+        |}""")).foreach { case (columnName, columnType, expectedSchema) =>
+    test(s"can build index for $columnType column") {
+      val indexCol = mock[FlintSparkSkippingStrategy]
+      when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
+      when(indexCol.outputSchema()).thenReturn(Map(columnName -> columnType))
+      when(indexCol.getAggregators).thenReturn(
+        Seq(CollectSet(col(columnName).expr).toAggregateExpression()))
+
+      val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
+      schemaShouldMatch(index.metadata(), expectedSchema.stripMargin)
+    }
   }
 
   test("should fail if get index name without full table name") {
diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala
index 40cb5c201..4956483b2 100644
--- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala
+++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala
@@ -285,6 +285,27 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
     useFlintSparkSkippingFileIndex(hasIndexFilter(col("address") === "Portland"))
   }
 
+  test("can build value set skipping index with limit") {
+    // Add one more row for Seattle
+    sql(s"""
+         | INSERT INTO $testTable
+         | PARTITION (year=2023, month=4)
+         | VALUES ('Hello', 30, 'Seattle')
+         | """.stripMargin)
+
+    flint
+      .skippingIndex()
+      .onTable(testTable)
+      .addValueSet("address", Map("limit" -> "1"))
+      .create()
+    flint.refreshIndex(testIndex, FULL)
+
+    // A file whose value set exceeds the limit stores null, so it can never be skipped
+    checkAnswer(
+      flint.queryIndex(testIndex).select("address"),
+      Seq(Row(Seq("Portland")), Row(null)))
+  }
+
   test("can build min max skipping index and rewrite applicable query") {
     flint
       .skippingIndex()
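Note: the integration test above drives the new limit through the builder API; the SQL route added in FlintSparkSkippingIndexAstBuilder should accept the same limit inline. The ANTLR grammar file is not part of this diff, so the VALUE_SET(1) form below is inferred from valueSetType().limit and should be treated as a sketch:

    sql(s"""
         | CREATE SKIPPING INDEX ON $testTable
         | (
         |   address VALUE_SET(1)
         | )
         |""".stripMargin)

Because the limit travels as a string in Map("limit" -> ...), getValueSetSizeLimit parses it with toInt at aggregation time; the SQL parser only yields integer literals here, but a programmatic caller passing a non-numeric value would hit a NumberFormatException.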