diff --git a/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4 b/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4
index cb2e14144..b223ee2c2 100644
--- a/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4
+++ b/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4
@@ -139,7 +139,12 @@ indexColTypeList
     ;
 
 indexColType
-    : identifier skipType=(PARTITION | VALUE_SET | MIN_MAX)
+    : identifier skipType=(PARTITION | MIN_MAX)
+    | identifier valueSetType
+    ;
+
+valueSetType
+    : VALUE_SET (LEFT_PAREN limit=INTEGER_VALUE RIGHT_PAREN)?
     ;
 
 indexName
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/skipping/FlintSparkSkippingIndexAstBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/skipping/FlintSparkSkippingIndexAstBuilder.scala
index 2b0bb6c48..fa922d713 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/skipping/FlintSparkSkippingIndexAstBuilder.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/skipping/FlintSparkSkippingIndexAstBuilder.scala
@@ -10,7 +10,7 @@ import org.opensearch.flint.spark.FlintSpark
 import org.opensearch.flint.spark.FlintSpark.RefreshMode
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind
-import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, PARTITION, VALUE_SET}
+import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, PARTITION}
 import org.opensearch.flint.spark.sql.{FlintSparkSqlCommand, FlintSparkSqlExtensionsVisitor, SparkSqlAstBuilder}
 import org.opensearch.flint.spark.sql.FlintSparkSqlAstBuilder.{getFullTableName, getSqlText}
 import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser._
@@ -42,11 +42,28 @@ trait FlintSparkSkippingIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[A
 
       ctx.indexColTypeList().indexColType().forEach { colTypeCtx =>
         val colName = colTypeCtx.identifier().getText
-        val skipType = SkippingKind.withName(colTypeCtx.skipType.getText)
-        skipType match {
-          case PARTITION => indexBuilder.addPartitions(colName)
-          case VALUE_SET => indexBuilder.addValueSet(colName)
-          case MIN_MAX => indexBuilder.addMinMax(colName)
+        // skipType is only populated for PARTITION / MIN_MAX; a VALUE_SET column goes
+        // through valueSetType() instead, leaving skipType null.
+        Option(colTypeCtx.skipType) match {
+          case Some(kindToken) =>
+            SkippingKind.withName(kindToken.getText) match {
+              case PARTITION => indexBuilder.addPartitions(colName)
+              case MIN_MAX => indexBuilder.addMinMax(colName)
+              case other =>
+                throw new IllegalStateException(s"Unsupported skipping kind: $other")
+            }
+          case None =>
+            // The size limit is optional: both VALUE_SET and VALUE_SET(n) are accepted.
+            Option(colTypeCtx.valueSetType().limit).map(_.getText) match {
+              case Some(limitText) =>
+                // The token text may not fit in an Int; report that clearly instead of
+                // letting a bare NumberFormatException escape from toInt.
+                val limit = scala.util.Try(limitText.toInt).getOrElse(
+                  throw new IllegalArgumentException(
+                    s"Invalid VALUE_SET limit for column $colName: $limitText"))
+                indexBuilder.addValueSet(colName, limit)
+              case None => indexBuilder.addValueSet(colName)
+            }
         }
       }
 
diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala
index 21de15de7..cf29e0298 100644
--- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala
+++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala
@@ -14,7 +14,7 @@ import org.json4s.native.Serialization
 import org.opensearch.flint.core.FlintOptions
 import org.opensearch.flint.core.storage.FlintOpenSearchClient
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName
-import org.scalatest.matchers.must.Matchers.{defined, have}
+import org.scalatest.matchers.must.Matchers.defined
 import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, the}
 
 import org.apache.spark.sql.Row
@@ -144,6 +144,24 @@ class FlintSparkSkippingIndexSqlITSuite extends FlintSparkSuite {
     indexData.count() shouldBe 2
   }
 
+  test("create skipping index with value set limit") {
+    sql(s"""
+         | CREATE SKIPPING INDEX ON $testTable
+         | (
+         |   name VALUE_SET(10)
+         | )
+         | WITH (auto_refresh = true)
+         | """.stripMargin)
+
+    // Wait for streaming job complete current micro batch
+    val job = spark.streams.active.find(_.name == testIndex)
+    awaitStreamingComplete(job.get.id.toString)
+
+    val indexData = spark.read.format(FLINT_DATASOURCE).load(testIndex)
+    flint.describeIndex(testIndex) shouldBe defined
+    indexData.count() shouldBe 2
+  }
+
   test("create skipping index if not exists") {
     sql(s"""
          | CREATE SKIPPING INDEX