diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala
index d83af5df5..120ca8219 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala
@@ -74,7 +74,7 @@ case class FlintSparkSkippingIndex(
     // Wrap aggregate function with output column name
     val namedAggFuncs =
       (outputNames, aggFuncs).zipped.map { case (name, aggFunc) =>
-        new Column(aggFunc.toAggregateExpression().as(name))
+        new Column(aggFunc.as(name))
       }
 
     df.getOrElse(spark.read.table(tableName))
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala
index eeba48cfe..042c968ec 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala
@@ -10,7 +10,6 @@ import org.json4s.JsonAST.JString
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.SkippingKind
 
 import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
 
 /**
  * Skipping index strategy that defines skipping data structure building and reading logic.
@@ -42,7 +41,7 @@ trait FlintSparkSkippingStrategy {
    * @return
    *   aggregators that generate skipping data structure
    */
-  def getAggregators: Seq[AggregateFunction]
+  def getAggregators: Seq[Expression]
 
   /**
    * Rewrite a filtering condition on source table into a new predicate on index data based on
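(Reviewer sketch, not part of this patch: why the conversion moved into each strategy. A Catalyst AggregateFunction such as Min cannot be placed in a projection as-is; it must first be wrapped into an AggregateExpression, which is what toAggregateExpression() does. The output name "min_age" below is made up.)

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.expressions.aggregate.Min
import org.apache.spark.sql.functions.col

// Each strategy now returns a ready-to-use AggregateExpression,
val agg = Min(col("age").expr).toAggregateExpression()
// so the index builder only needs to alias it and wrap it in a Column.
val named = new Column(Alias(agg, "min_age")())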
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/minmax/MinMaxSkippingStrategy.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/minmax/MinMaxSkippingStrategy.scala
index dc57830c7..f7745e7a8 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/minmax/MinMaxSkippingStrategy.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/minmax/MinMaxSkippingStrategy.scala
@@ -9,7 +9,7 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, SkippingKind}
 
 import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual, Literal}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, Max, Min}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Max, Min}
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.functions.col
 
@@ -29,8 +29,11 @@ case class MinMaxSkippingStrategy(
   override def outputSchema(): Map[String, String] =
     Map(minColName -> columnType, maxColName -> columnType)
 
-  override def getAggregators: Seq[AggregateFunction] =
-    Seq(Min(col(columnName).expr), Max(col(columnName).expr))
+  override def getAggregators: Seq[Expression] = {
+    Seq(
+      Min(col(columnName).expr).toAggregateExpression(),
+      Max(col(columnName).expr).toAggregateExpression())
+  }
 
   override def rewritePredicate(predicate: Expression): Option[Expression] =
     predicate match {
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategy.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategy.scala
index fe717e7ad..18fec0642 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategy.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategy.scala
@@ -9,7 +9,7 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{PARTITION, SkippingKind}
 
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, First}
+import org.apache.spark.sql.catalyst.expressions.aggregate.First
 import org.apache.spark.sql.functions.col
 
 /**
@@ -25,8 +25,8 @@ case class PartitionSkippingStrategy(
     Map(columnName -> columnType)
   }
 
-  override def getAggregators: Seq[AggregateFunction] = {
-    Seq(First(col(columnName).expr, ignoreNulls = true))
+  override def getAggregators: Seq[Expression] = {
+    Seq(First(col(columnName).expr, ignoreNulls = true).toAggregateExpression())
   }
 
   override def rewritePredicate(predicate: Expression): Option[Expression] =
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategy.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategy.scala
index f462b70c1..9bcd2ba9b 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategy.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategy.scala
@@ -7,10 +7,10 @@ package org.opensearch.flint.spark.skipping.valueset
 
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{SkippingKind, VALUE_SET}
+import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy.DEFAULT_VALUE_SET_SIZE_LIMIT
 
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, CollectSet}
-import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.functions._
 
 /**
  * Skipping strategy based on unique column value set.
@@ -24,8 +24,16 @@ case class ValueSetSkippingStrategy(
   override def outputSchema(): Map[String, String] =
     Map(columnName -> columnType)
 
-  override def getAggregators: Seq[AggregateFunction] =
-    Seq(CollectSet(col(columnName).expr))
+  override def getAggregators: Seq[Expression] = {
+    val limit = DEFAULT_VALUE_SET_SIZE_LIMIT
+    val collectSet = collect_set(columnName)
+
+    // IF(ARRAY_SIZE(COLLECT_SET(col)) > default_limit, null, COLLECT_SET(col))
+    val aggregator =
+      when(size(collectSet) > limit, lit(null))
+        .otherwise(collectSet)
+    Seq(aggregator.expr)
+  }
 
   override def rewritePredicate(predicate: Expression): Option[Expression] =
     /*
@@ -38,3 +46,9 @@ case class ValueSetSkippingStrategy(
       case _ => None
     }
 }
+
+object ValueSetSkippingStrategy {
+
+  /** Default limit on the size of the collected value set */
+  val DEFAULT_VALUE_SET_SIZE_LIMIT = 100
+}
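(Reviewer sketch, not part of this patch: the semantics of the capped value set. Per source file, the index keeps the distinct values of the indexed column, but stores null once the set grows beyond DEFAULT_VALUE_SET_SIZE_LIMIT, so a high-cardinality file is never wrongly skipped. Table data and names below are made up.)

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().master("local").getOrCreate()
import spark.implicits._

val df = Seq(("file1", "a"), ("file1", "b"), ("file2", "c")).toDF("file_path", "name")

val limit = 100 // mirrors DEFAULT_VALUE_SET_SIZE_LIMIT
val values = collect_set("name")
val capped = when(size(values) > limit, lit(null)).otherwise(values)

// file1 -> [a, b]; file2 -> [c]; a file with more than 100 distinct
// values would get null here, which readers treat as "cannot skip".
df.groupBy("file_path").agg(capped.as("name")).show()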
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala
index 491b7811a..9760e8cd2 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala
@@ -58,7 +58,8 @@ class FlintSparkSkippingIndexSuite extends FlintSuite {
   test("can build index building job with unique ID column") {
     val indexCol = mock[FlintSparkSkippingStrategy]
     when(indexCol.outputSchema()).thenReturn(Map("name" -> "string"))
-    when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("name").expr)))
+    when(indexCol.getAggregators).thenReturn(
+      Seq(CollectSet(col("name").expr).toAggregateExpression()))
 
     val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
     val df = spark.createDataFrame(Seq(("hello", 20))).toDF("name", "age")
@@ -66,276 +67,170 @@ class FlintSparkSkippingIndexSuite extends FlintSuite {
     indexDf.schema.fieldNames should contain only ("name", FILE_PATH_COLUMN, ID_COLUMN)
   }
 
-  test("can build index for boolean column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema()).thenReturn(Map("boolean_col" -> "boolean"))
-    when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("boolean_col").expr)))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "boolean_col": {
-         |    "type": "boolean"
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
-  }
-
-  test("can build index for string column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema()).thenReturn(Map("string_col" -> "string"))
-    when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("string_col").expr)))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "string_col": {
-         |    "type": "keyword"
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
-  }
-
-  // TODO: test for osType "text"
-
-  test("can build index for varchar column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema()).thenReturn(Map("varchar_col" -> "varchar(20)"))
-    when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("varchar_col").expr)))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "varchar_col": {
-         |    "type": "keyword"
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
-  }
-
-  test("can build index for char column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema()).thenReturn(Map("char_col" -> "char(20)"))
-    when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("char_col").expr)))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "char_col": {
-         |    "type": "keyword"
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
-  }
-
-  test("can build index for long column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema()).thenReturn(Map("long_col" -> "bigint"))
-    when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("long_col").expr)))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "long_col": {
-         |    "type": "long"
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
-  }
-
-  test("can build index for int column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema()).thenReturn(Map("int_col" -> "int"))
-    when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("int_col").expr)))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "int_col": {
-         |    "type": "integer"
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
-  }
-
-  test("can build index for short column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema()).thenReturn(Map("short_col" -> "smallint"))
-    when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("short_col").expr)))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "short_col": {
-         |    "type": "short"
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
-  }
-
-  test("can build index for byte column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema()).thenReturn(Map("byte_col" -> "tinyint"))
-    when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("byte_col").expr)))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "byte_col": {
-         |    "type": "byte"
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
-  }
-
-  test("can build index for double column") {
-    val indexCol = mock[FlintSparkSkippingStrategy]
-    when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
-    when(indexCol.outputSchema()).thenReturn(Map("double_col" -> "double"))
-    when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("double_col").expr)))
-
-    val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
-    schemaShouldMatch(
-      index.metadata(),
-      s"""{
-         |  "double_col": {
-         |    "type": "double"
-         |  },
-         |  "file_path": {
-         |    "type": "keyword"
-         |  }
-         |}
-         |""".stripMargin)
-  }
-
"file_path": { - | "type": "keyword" - | } - |} - |""".stripMargin) - } - - test("can build index for float column") { - val indexCol = mock[FlintSparkSkippingStrategy] - when(indexCol.kind).thenReturn(SkippingKind.PARTITION) - when(indexCol.outputSchema()).thenReturn(Map("float_col" -> "float")) - when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("float_col").expr))) - - val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) - schemaShouldMatch( - index.metadata(), - s"""{ - | "float_col": { - | "type": "float" - | }, - | "file_path": { - | "type": "keyword" - | } - |} - |""".stripMargin) - } - - test("can build index for timestamp column") { - val indexCol = mock[FlintSparkSkippingStrategy] - when(indexCol.kind).thenReturn(SkippingKind.PARTITION) - when(indexCol.outputSchema()).thenReturn(Map("timestamp_col" -> "timestamp")) - when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("timestamp_col").expr))) - - val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) - schemaShouldMatch( - index.metadata(), - s"""{ - | "timestamp_col": { - | "type": "date", - | "format": "strict_date_optional_time_nanos" - | }, - | "file_path": { - | "type": "keyword" - | } - |} - |""".stripMargin) - } - - test("can build index for date column") { - val indexCol = mock[FlintSparkSkippingStrategy] - when(indexCol.kind).thenReturn(SkippingKind.PARTITION) - when(indexCol.outputSchema()).thenReturn(Map("date_col" -> "date")) - when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("date_col").expr))) - - val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) - schemaShouldMatch( - index.metadata(), - s"""{ - | "date_col": { - | "type": "date", - | "format": "strict_date" - | }, - | "file_path": { - | "type": "keyword" - | } - |} - |""".stripMargin) - } - - test("can build index for struct column") { - val indexCol = mock[FlintSparkSkippingStrategy] - when(indexCol.kind).thenReturn(SkippingKind.PARTITION) - when(indexCol.outputSchema()) - .thenReturn(Map("struct_col" -> "struct")) - when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("struct_col").expr))) - - val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) - schemaShouldMatch( - index.metadata(), - s"""{ - | "struct_col": { - | "properties": { - | "subfield1": { - | "type": "keyword" - | }, - | "subfield2": { - | "type": "integer" - | } - | } - | }, - | "file_path": { - | "type": "keyword" - | } - |} - |""".stripMargin) + // Test index build for different column type + Seq( + ( + "boolean_col", + "boolean", + """{ + | "boolean_col": { + | "type": "boolean" + | }, + | "file_path": { + | "type": "keyword" + | } + |}"""), + ( + "string_col", + "string", + """{ + | "string_col": { + | "type": "keyword" + | }, + | "file_path": { + | "type": "keyword" + | } + |}"""), + ( + "varchar_col", + "varchar(20)", + """{ + | "varchar_col": { + | "type": "keyword" + | }, + | "file_path": { + | "type": "keyword" + | } + |}"""), + ( + "char_col", + "char(20)", + """{ + | "char_col": { + | "type": "keyword" + | }, + | "file_path": { + | "type": "keyword" + | } + |}"""), + ( + "long_col", + "bigint", + """{ + | "long_col": { + | "type": "long" + | }, + | "file_path": { + | "type": "keyword" + | } + |}"""), + ( + "int_col", + "int", + """{ + | "int_col": { + | "type": "integer" + | }, + | "file_path": { + | "type": "keyword" + | } + |}"""), + ( + "short_col", + "smallint", + """{ + | "short_col": { + | "type": "short" + | }, + | "file_path": { + | "type": "keyword" + | } + |}"""), + ( + "byte_col", + 
"tinyint", + """{ + | "byte_col": { + | "type": "byte" + | }, + | "file_path": { + | "type": "keyword" + | } + |}"""), + ( + "double_col", + "double", + """{ + | "double_col": { + | "type": "double" + | }, + | "file_path": { + | "type": "keyword" + | } + |}"""), + ( + "float_col", + "float", + """{ + | "float_col": { + | "type": "float" + | }, + | "file_path": { + | "type": "keyword" + | } + |}"""), + ( + "timestamp_col", + "timestamp", + """{ + | "timestamp_col": { + | "type": "date", + | "format": "strict_date_optional_time_nanos" + | }, + | "file_path": { + | "type": "keyword" + | } + |}"""), + ( + "date_col", + "date", + """{ + | "date_col": { + | "type": "date", + | "format": "strict_date" + | }, + | "file_path": { + | "type": "keyword" + | } + |}"""), + ( + "struct_col", + "struct", + """{ + | "struct_col": { + | "properties": { + | "subfield1": { + | "type": "keyword" + | }, + | "subfield2": { + | "type": "integer" + | } + | } + | }, + | "file_path": { + | "type": "keyword" + | } + |}""")).foreach { case (columnName, columnType, expectedSchema) => + test(s"can build index for $columnType column") { + val indexCol = mock[FlintSparkSkippingStrategy] + when(indexCol.kind).thenReturn(SkippingKind.PARTITION) + when(indexCol.outputSchema()).thenReturn(Map(columnName -> columnType)) + when(indexCol.getAggregators).thenReturn( + Seq(CollectSet(col(columnName).expr).toAggregateExpression())) + + val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) + schemaShouldMatch(index.metadata(), expectedSchema.stripMargin) + } } test("should fail if get index name without full table name") {