diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala index c988186c2..e10e6a29b 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala @@ -18,11 +18,11 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIn import org.scalatest.matchers.must.Matchers.defined import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, the} -import org.apache.spark.sql.Row +import org.apache.spark.sql.{ExplainSuiteHelper, Row} import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE import org.apache.spark.sql.flint.config.FlintSparkConf.CHECKPOINT_MANDATORY -class FlintSparkSkippingIndexSqlITSuite extends FlintSparkSuite { +class FlintSparkSkippingIndexSqlITSuite extends FlintSparkSuite with ExplainSuiteHelper { /** Test table and index name */ private val testTable = "spark_catalog.default.skipping_sql_test" @@ -166,36 +166,41 @@ class FlintSparkSkippingIndexSqlITSuite extends FlintSparkSuite { (settings \ "index.number_of_replicas").extract[String] shouldBe "2" } - test("build skipping index for nested field") { - val testTable = "spark_catalog.default.nested_field_table" - val testIndex = getSkippingIndexName(testTable) - withTable(testTable) { - createStructTable(testTable) + Seq( + "struct_col.field1.subfield VALUE_SET, struct_col.field2 MIN_MAX", + "`struct_col.field1.subfield` VALUE_SET, `struct_col.field2` MIN_MAX", // ensure previous hack still works + "`struct_col`.`field1`.`subfield` VALUE_SET, `struct_col`.`field2` MIN_MAX").foreach { + columnSkipTypes => + test(s"build skipping index for nested field $columnSkipTypes") { + val testTable = "spark_catalog.default.nested_field_table" + val testIndex = getSkippingIndexName(testTable) + withTable(testTable) { + createStructTable(testTable) + sql(s""" + | CREATE SKIPPING INDEX ON $testTable + | ( $columnSkipTypes ) + | WITH ( + | auto_refresh = true + | ) + | """.stripMargin) - sql(s""" - | CREATE SKIPPING INDEX ON $testTable - | ( - | struct_col.field1.subfield VALUE_SET, - | struct_col.field2 MIN_MAX - | ) - | WITH ( - | auto_refresh = true - | ) - | """.stripMargin) + val job = spark.streams.active.find(_.name == testIndex) + awaitStreamingComplete(job.get.id.toString) - val job = spark.streams.active.find(_.name == testIndex) - awaitStreamingComplete(job.get.id.toString) + // Query rewrite nested field + val query1 = sql(s"SELECT int_col FROM $testTable WHERE struct_col.field2 = 456") + checkAnswer(query1, Row(40)) + checkKeywordsExistsInExplain(query1, "FlintSparkSkippingFileIndex") - // Query rewrite nested field - val query1 = s"SELECT int_col FROM $testTable WHERE struct_col.field2 = 456" - checkAnswer(sql(query1), Row(40)) - checkKeywordsExist(sql(s"EXPLAIN $query1"), "FlintSparkSkippingFileIndex") + // Query rewrite deep nested field + val query2 = + sql(s"SELECT int_col FROM $testTable WHERE struct_col.field1.subfield = 'value3'") + checkAnswer(query2, Row(50)) + checkKeywordsExistsInExplain(query2, "FlintSparkSkippingFileIndex") + } - // Query rewrite deep nested field - val query2 = s"SELECT int_col FROM $testTable WHERE struct_col.field1.subfield = 'value3'" - checkAnswer(sql(query2), Row(50)) - checkKeywordsExist(sql(s"EXPLAIN $query2"), "FlintSparkSkippingFileIndex") - } + deleteTestIndex(testIndex) + } } test("create skipping index with invalid option") {