diff --git a/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4 b/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4 index dc097d596..2e8d634da 100644 --- a/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4 +++ b/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4 @@ -188,7 +188,7 @@ indexColTypeList ; indexColType - : identifier skipType=(PARTITION | VALUE_SET | MIN_MAX | BLOOM_FILTER) + : multipartIdentifier skipType=(PARTITION | VALUE_SET | MIN_MAX | BLOOM_FILTER) (LEFT_PAREN skipParams RIGHT_PAREN)? ; diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/skipping/FlintSparkSkippingIndexAstBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/skipping/FlintSparkSkippingIndexAstBuilder.scala index 367db9d3f..521898aa2 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/skipping/FlintSparkSkippingIndexAstBuilder.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/skipping/FlintSparkSkippingIndexAstBuilder.scala @@ -44,7 +44,7 @@ trait FlintSparkSkippingIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[A .onTable(getFullTableName(flint, ctx.tableName)) ctx.indexColTypeList().indexColType().forEach { colTypeCtx => - val colName = colTypeCtx.identifier().getText + val colName = colTypeCtx.multipartIdentifier().getText val skipType = SkippingKind.withName(colTypeCtx.skipType.getText) val skipParams = visitSkipParams(colTypeCtx.skipParams()) skipType match { diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala index 999fb3008..b2185a5a9 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala @@ -809,25 +809,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { val testTable = "spark_catalog.default.nested_field_table" val testIndex = getSkippingIndexName(testTable) withTable(testTable) { - sql(s""" - | CREATE TABLE $testTable - | ( - | int_col INT, - | struct_col STRUCT, field2: INT> - | ) - | USING JSON - |""".stripMargin) - sql(s""" - | INSERT INTO $testTable - | SELECT /*+ COALESCE(1) */ * - | FROM VALUES - | ( 30, STRUCT(STRUCT("value1"),123) ), - | ( 40, STRUCT(STRUCT("value2"),456) ) - |""".stripMargin) - sql(s""" - | INSERT INTO $testTable - | VALUES ( 50, STRUCT(STRUCT("value3"),789) ) - |""".stripMargin) + createStructTable(testTable) flint .skippingIndex() diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala index cdc599233..e10e6a29b 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala @@ -18,11 +18,11 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIn import org.scalatest.matchers.must.Matchers.defined import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, the} -import org.apache.spark.sql.Row +import org.apache.spark.sql.{ExplainSuiteHelper, Row} import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE import org.apache.spark.sql.flint.config.FlintSparkConf.CHECKPOINT_MANDATORY -class FlintSparkSkippingIndexSqlITSuite extends FlintSparkSuite { +class FlintSparkSkippingIndexSqlITSuite extends FlintSparkSuite with ExplainSuiteHelper { /** Test table and index name */ private val testTable = "spark_catalog.default.skipping_sql_test" @@ -166,6 +166,43 @@ class FlintSparkSkippingIndexSqlITSuite extends FlintSparkSuite { (settings \ "index.number_of_replicas").extract[String] shouldBe "2" } + Seq( + "struct_col.field1.subfield VALUE_SET, struct_col.field2 MIN_MAX", + "`struct_col.field1.subfield` VALUE_SET, `struct_col.field2` MIN_MAX", // ensure previous hack still works + "`struct_col`.`field1`.`subfield` VALUE_SET, `struct_col`.`field2` MIN_MAX").foreach { + columnSkipTypes => + test(s"build skipping index for nested field $columnSkipTypes") { + val testTable = "spark_catalog.default.nested_field_table" + val testIndex = getSkippingIndexName(testTable) + withTable(testTable) { + createStructTable(testTable) + sql(s""" + | CREATE SKIPPING INDEX ON $testTable + | ( $columnSkipTypes ) + | WITH ( + | auto_refresh = true + | ) + | """.stripMargin) + + val job = spark.streams.active.find(_.name == testIndex) + awaitStreamingComplete(job.get.id.toString) + + // Query rewrite nested field + val query1 = sql(s"SELECT int_col FROM $testTable WHERE struct_col.field2 = 456") + checkAnswer(query1, Row(40)) + checkKeywordsExistsInExplain(query1, "FlintSparkSkippingFileIndex") + + // Query rewrite deep nested field + val query2 = + sql(s"SELECT int_col FROM $testTable WHERE struct_col.field1.subfield = 'value3'") + checkAnswer(query2, Row(50)) + checkKeywordsExistsInExplain(query2, "FlintSparkSkippingFileIndex") + } + + deleteTestIndex(testIndex) + } + } + test("create skipping index with invalid option") { the[IllegalArgumentException] thrownBy sql(s""" diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala index 0c6282bb6..fbb2f89bd 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala @@ -5,11 +5,10 @@ package org.opensearch.flint.spark -import java.nio.file.{Files, Path, Paths, StandardCopyOption} +import java.nio.file.{Files, Paths} import java.util.Comparator import java.util.concurrent.{ScheduledExecutorService, ScheduledFuture} -import scala.collection.immutable.Map import scala.concurrent.duration.TimeUnit import scala.util.Try @@ -20,11 +19,9 @@ import org.opensearch.action.admin.indices.delete.DeleteIndexRequest import org.opensearch.client.RequestOptions import org.opensearch.client.indices.GetIndexRequest import org.opensearch.flint.OpenSearchSuite -import org.scalatest.prop.TableDrivenPropertyChecks.forAll import org.scalatestplus.mockito.MockitoSugar.mock -import org.apache.spark.FlintSuite -import org.apache.spark.SparkConf +import org.apache.spark.{FlintSuite, SparkConf} import org.apache.spark.sql.QueryTest import org.apache.spark.sql.flint.config.FlintSparkConf.{CHECKPOINT_MANDATORY, HOST_ENDPOINT, HOST_PORT, REFRESH_POLICY} import org.apache.spark.sql.streaming.StreamTest @@ -312,6 +309,30 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit | """.stripMargin) } + protected def createStructTable(testTable: String): Unit = { + // CSV doesn't support struct field + sql(s""" + | CREATE TABLE $testTable + | ( + | int_col INT, + | struct_col STRUCT, field2: INT> + | ) + | USING JSON + |""".stripMargin) + + sql(s""" + | INSERT INTO $testTable + | SELECT /*+ COALESCE(1) */ * + | FROM VALUES + | ( 30, STRUCT(STRUCT("value1"),123) ), + | ( 40, STRUCT(STRUCT("value2"),456) ) + |""".stripMargin) + sql(s""" + | INSERT INTO $testTable + | VALUES ( 50, STRUCT(STRUCT("value3"),789) ) + |""".stripMargin) + } + protected def createTableIssue112(testTable: String): Unit = { sql(s""" | CREATE TABLE $testTable (