diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala index c4f9080bb..8058f9bff 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala @@ -12,10 +12,11 @@ import org.json4s.native.JsonMethods.parse import org.json4s.native.Serialization import org.opensearch.flint.core.{FlintClient, FlintClientBuilder} import org.opensearch.flint.core.metadata.FlintMetadata -import org.opensearch.flint.spark.FlintSpark._ import org.opensearch.flint.spark.FlintSpark.RefreshMode.{FULL, INCREMENTAL, RefreshMode} import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN -import org.opensearch.flint.spark.skipping.{FlintSparkSkippingIndex, FlintSparkSkippingStrategy} +import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex +import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.COVERING_INDEX_TYPE +import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.SKIPPING_INDEX_TYPE import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.{SkippingKind, SkippingKindSerializer} import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, PARTITION, VALUE_SET} @@ -25,12 +26,10 @@ import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.SaveMode._ -import org.apache.spark.sql.catalog.Column import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.flint.config.FlintSparkConf.{DOC_ID_COLUMN_NAME, IGNORE_DOC_ID_COLUMN} import org.apache.spark.sql.streaming.OutputMode.Append -import org.apache.spark.sql.streaming.StreamingQuery /** * Flint Spark integration API entrypoint. @@ -42,8 +41,7 @@ class FlintSpark(val spark: SparkSession) { FlintSparkConf( Map( DOC_ID_COLUMN_NAME.optionKey -> ID_COLUMN, - IGNORE_DOC_ID_COLUMN.optionKey -> "true" - ).asJava) + IGNORE_DOC_ID_COLUMN.optionKey -> "true").asJava) /** Flint client for low-level index operation */ private val flintClient: FlintClient = FlintClientBuilder.build(flintSparkConf.flintOptions()) @@ -57,8 +55,18 @@ class FlintSpark(val spark: SparkSession) { * @return * index builder */ - def skippingIndex(): IndexBuilder = { - new IndexBuilder(this) + def skippingIndex(): FlintSparkSkippingIndex.Builder = { + new FlintSparkSkippingIndex.Builder(this) + } + + /** + * Create index builder for creating index with fluent API. + * + * @return + * index builder + */ + def coveringIndex(): FlintSparkCoveringIndex.Builder = { + new FlintSparkCoveringIndex.Builder(this) } /** @@ -199,6 +207,7 @@ class FlintSpark(val spark: SparkSession) { */ private def deserialize(metadata: FlintMetadata): FlintSparkIndex = { val meta = parse(metadata.getContent) \ "_meta" + val indexName = (meta \ "name").extract[String] val tableName = (meta \ "source").extract[String] val indexType = (meta \ "kind").extract[String] val indexedColumns = (meta \ "indexedColumns").asInstanceOf[JArray] @@ -222,6 +231,13 @@ class FlintSpark(val spark: SparkSession) { } } new FlintSparkSkippingIndex(tableName, strategies) + case COVERING_INDEX_TYPE => + new FlintSparkCoveringIndex( + indexName, + tableName, + indexedColumns.arr.map { obj => + ((obj \ "columnName").extract[String], (obj \ "columnType").extract[String]) + }.toMap) } } } @@ -236,102 +252,4 @@ object FlintSpark { type RefreshMode = Value val FULL, INCREMENTAL = Value } - - /** - * Helper class for index class construct. For now only skipping index supported. - */ - class IndexBuilder(flint: FlintSpark) { - var tableName: String = "" - var indexedColumns: Seq[FlintSparkSkippingStrategy] = Seq() - - lazy val allColumns: Map[String, Column] = { - flint.spark.catalog - .listColumns(tableName) - .collect() - .map(col => (col.name, col)) - .toMap - } - - /** - * Configure which source table the index is based on. - * - * @param tableName - * full table name - * @return - * index builder - */ - def onTable(tableName: String): IndexBuilder = { - this.tableName = tableName - this - } - - /** - * Add partition skipping indexed columns. - * - * @param colNames - * indexed column names - * @return - * index builder - */ - def addPartitions(colNames: String*): IndexBuilder = { - require(tableName.nonEmpty, "table name cannot be empty") - - colNames - .map(findColumn) - .map(col => PartitionSkippingStrategy(columnName = col.name, columnType = col.dataType)) - .foreach(addIndexedColumn) - this - } - - /** - * Add value set skipping indexed column. - * - * @param colName - * indexed column name - * @return - * index builder - */ - def addValueSet(colName: String): IndexBuilder = { - require(tableName.nonEmpty, "table name cannot be empty") - - val col = findColumn(colName) - addIndexedColumn(ValueSetSkippingStrategy(columnName = col.name, columnType = col.dataType)) - this - } - - /** - * Add min max skipping indexed column. - * - * @param colName - * indexed column name - * @return - * index builder - */ - def addMinMax(colName: String): IndexBuilder = { - val col = findColumn(colName) - indexedColumns = - indexedColumns :+ MinMaxSkippingStrategy(columnName = col.name, columnType = col.dataType) - this - } - - /** - * Create index. - */ - def create(): Unit = { - flint.createIndex(new FlintSparkSkippingIndex(tableName, indexedColumns)) - } - - private def findColumn(colName: String): Column = - allColumns.getOrElse( - colName, - throw new IllegalArgumentException(s"Column $colName does not exist")) - - private def addIndexedColumn(indexedCol: FlintSparkSkippingStrategy): Unit = { - require( - indexedColumns.forall(_.columnName != indexedCol.columnName), - s"${indexedCol.columnName} is already indexed") - - indexedColumns = indexedColumns :+ indexedCol - } - } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala index bbfa4c4ba..62e6b4668 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala @@ -49,4 +49,15 @@ object FlintSparkIndex { * ID column name. */ val ID_COLUMN: String = "__id__" + + /** + * Common prefix of Flint index name which is "flint_database_table_" + * + * @param fullTableName + * source full table name + * @return + * Flint index name + */ + def flintIndexNamePrefix(fullTableName: String): String = + s"flint_${fullTableName.replace(".", "_")}_" } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala new file mode 100644 index 000000000..740c02e1d --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark + +import org.apache.spark.sql.catalog.Column + +/** + * Flint Spark index builder base class. + * + * @param flint + * Flint Spark API entrypoint + */ +abstract class FlintSparkIndexBuilder(flint: FlintSpark) { + + /** Source table name */ + protected var tableName: String = "" + + /** All columns of the given source table */ + lazy protected val allColumns: Map[String, Column] = { + require(tableName.nonEmpty, "Source table name is not provided") + + flint.spark.catalog + .listColumns(tableName) + .collect() + .map(col => (col.name, col)) + .toMap + } + + /** + * Create Flint index. + */ + def create(): Unit = flint.createIndex(buildIndex()) + + /** + * Build method for concrete builder class to implement + */ + protected def buildIndex(): FlintSparkIndex + + protected def findColumn(colName: String): Column = + allColumns.getOrElse( + colName, + throw new IllegalArgumentException(s"Column $colName does not exist")) +} diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala new file mode 100644 index 000000000..f2f5933d6 --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala @@ -0,0 +1,159 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.covering + +import org.json4s.{Formats, NoTypeHints} +import org.json4s.JsonAST.{JArray, JObject, JString} +import org.json4s.native.JsonMethods.{compact, parse, render} +import org.json4s.native.Serialization +import org.opensearch.flint.core.metadata.FlintMetadata +import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndex, FlintSparkIndexBuilder} +import org.opensearch.flint.spark.FlintSparkIndex.flintIndexNamePrefix +import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.{getFlintIndexName, COVERING_INDEX_TYPE} + +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.flint.datatype.FlintDataType +import org.apache.spark.sql.types.StructType + +/** + * Flint covering index in Spark. + * + * @param indexName + * index name + * @param tableName + * source table name + * @param indexedColumns + * indexed column list + */ +class FlintSparkCoveringIndex( + indexName: String, + tableName: String, + indexedColumns: Map[String, String]) + extends FlintSparkIndex { + + require(indexedColumns.nonEmpty, "indexed columns must not be empty") + + /** Required by json4s write function */ + implicit val formats: Formats = Serialization.formats(NoTypeHints) + + override val kind: String = COVERING_INDEX_TYPE + + override def name(): String = getFlintIndexName(indexName, tableName) + + override def metadata(): FlintMetadata = { + new FlintMetadata(s"""{ + | "_meta": { + | "name": "$indexName", + | "kind": "$kind", + | "indexedColumns": $getMetaInfo, + | "source": "$tableName" + | }, + | "properties": $getSchema + | } + |""".stripMargin) + } + + override def build(df: DataFrame): DataFrame = { + val colNames = indexedColumns.keys.toSeq + df.select(colNames.head, colNames.tail: _*) + } + + // TODO: refactor all these once Flint metadata spec finalized + private def getMetaInfo: String = { + val objects = indexedColumns.map { case (colName, colType) => + JObject("columnName" -> JString(colName), "columnType" -> JString(colType)) + }.toList + Serialization.write(JArray(objects)) + } + + private def getSchema: String = { + val catalogDDL = + indexedColumns + .map { case (colName, colType) => s"$colName $colType not null" } + .mkString(",") + val properties = FlintDataType.serialize(StructType.fromDDL(catalogDDL)) + compact(render(parse(properties) \ "properties")) + } +} + +object FlintSparkCoveringIndex { + + /** Covering index type name */ + val COVERING_INDEX_TYPE = "covering" + + /** Flint covering index name suffix */ + val COVERING_INDEX_SUFFIX = "_index" + + /** + * Get Flint index name which follows the convention: "flint_" prefix + source table name + + + * given index name + "_index" suffix. + * + * This helps identify the Flint index because Flint index is not registered to Spark Catalog + * for now. + * + * @param tableName + * full table name + * @param indexName + * index name specified by user + * @return + * Flint covering index name + */ + def getFlintIndexName(indexName: String, tableName: String): String = { + require(tableName.contains("."), "Full table name database.table is required") + + flintIndexNamePrefix(tableName) + indexName + COVERING_INDEX_SUFFIX + } + + /** Builder class for covering index build */ + class Builder(flint: FlintSpark) extends FlintSparkIndexBuilder(flint) { + private var indexName: String = "" + private var indexedColumns: Map[String, String] = Map() + + /** + * Set covering index name. + * + * @param indexName + * index name + * @return + * index builder + */ + def name(indexName: String): Builder = { + this.indexName = indexName + this + } + + /** + * Configure which source table the index is based on. + * + * @param tableName + * full table name + * @return + * index builder + */ + def onTable(tableName: String): Builder = { + this.tableName = tableName + this + } + + /** + * Add indexed column name. + * + * @param colNames + * column names + * @return + * index builder + */ + def addIndexColumns(colNames: String*): Builder = { + colNames.foreach(colName => { + indexedColumns += (colName -> findColumn(colName).dataType) + }) + this + } + + override protected def buildIndex(): FlintSparkIndex = + new FlintSparkCoveringIndex(indexName, tableName, indexedColumns) + } +} diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala index d3579ef51..325f40254 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala @@ -10,22 +10,27 @@ import org.json4s.native.JsonMethods._ import org.json4s.native.Serialization import org.opensearch.flint.core.FlintVersion import org.opensearch.flint.core.metadata.FlintMetadata -import org.opensearch.flint.spark.FlintSparkIndex -import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN +import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndex, FlintSparkIndexBuilder} +import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, ID_COLUMN} import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, FILE_PATH_COLUMN, SKIPPING_INDEX_TYPE} import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKindSerializer +import org.opensearch.flint.spark.skipping.minmax.MinMaxSkippingStrategy +import org.opensearch.flint.spark.skipping.partition.PartitionSkippingStrategy +import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy import org.apache.spark.sql.{Column, DataFrame} import org.apache.spark.sql.catalyst.dsl.expressions.DslExpression import org.apache.spark.sql.flint.datatype.FlintDataType import org.apache.spark.sql.functions.{col, input_file_name, sha1} -import org.apache.spark.sql.types.{DataType, StringType, StructField, StructType} +import org.apache.spark.sql.types.StructType /** * Flint skipping index in Spark. * * @param tableName * source table name + * @param indexedColumns + * indexed column list */ class FlintSparkSkippingIndex( tableName: String, @@ -47,6 +52,7 @@ class FlintSparkSkippingIndex( override def metadata(): FlintMetadata = { new FlintMetadata(s"""{ | "_meta": { + | "name": "${name()}", | "version": "${FlintVersion.current()}", | "kind": "$SKIPPING_INDEX_TYPE", | "indexedColumns": $getMetaInfo, @@ -98,6 +104,9 @@ object FlintSparkSkippingIndex { /** File path column name */ val FILE_PATH_COLUMN = "file_path" + /** Flint skipping index name suffix */ + val SKIPPING_INDEX_SUFFIX = "skipping_index" + /** * Get skipping index name which follows the convention: "flint_" prefix + source table name + * "_skipping_index" suffix. @@ -113,6 +122,84 @@ object FlintSparkSkippingIndex { def getSkippingIndexName(tableName: String): String = { require(tableName.contains("."), "Full table name database.table is required") - s"flint_${tableName.replace(".", "_")}_skipping_index" + flintIndexNamePrefix(tableName) + SKIPPING_INDEX_SUFFIX + } + + /** Builder class for skipping index build */ + class Builder(flint: FlintSpark) extends FlintSparkIndexBuilder(flint) { + private var indexedColumns: Seq[FlintSparkSkippingStrategy] = Seq() + + /** + * Configure which source table the index is based on. + * + * @param tableName + * full table name + * @return + * index builder + */ + def onTable(tableName: String): Builder = { + this.tableName = tableName + this + } + + /** + * Add partition skipping indexed columns. + * + * @param colNames + * indexed column names + * @return + * index builder + */ + def addPartitions(colNames: String*): Builder = { + require(tableName.nonEmpty, "table name cannot be empty") + + colNames + .map(findColumn) + .map(col => PartitionSkippingStrategy(columnName = col.name, columnType = col.dataType)) + .foreach(addIndexedColumn) + this + } + + /** + * Add value set skipping indexed column. + * + * @param colName + * indexed column name + * @return + * index builder + */ + def addValueSet(colName: String): Builder = { + require(tableName.nonEmpty, "table name cannot be empty") + + val col = findColumn(colName) + addIndexedColumn(ValueSetSkippingStrategy(columnName = col.name, columnType = col.dataType)) + this + } + + /** + * Add min max skipping indexed column. + * + * @param colName + * indexed column name + * @return + * index builder + */ + def addMinMax(colName: String): Builder = { + val col = findColumn(colName) + indexedColumns = + indexedColumns :+ MinMaxSkippingStrategy(columnName = col.name, columnType = col.dataType) + this + } + + override def buildIndex(): FlintSparkIndex = + new FlintSparkSkippingIndex(tableName, indexedColumns) + + private def addIndexedColumn(indexedCol: FlintSparkSkippingStrategy): Unit = { + require( + indexedColumns.forall(_.columnName != indexedCol.columnName), + s"${indexedCol.columnName} is already indexed") + + indexedColumns = indexedColumns :+ indexedCol + } } } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala new file mode 100644 index 000000000..a50db1af2 --- /dev/null +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.covering + +import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper + +import org.apache.spark.FlintSuite + +class FlintSparkCoveringIndexSuite extends FlintSuite { + + test("get covering index name") { + val index = new FlintSparkCoveringIndex("ci", "default.test", Map("name" -> "string")) + index.name() shouldBe "flint_default_test_ci_index" + } + + test("should fail if get index name without full table name") { + val index = new FlintSparkCoveringIndex("ci", "test", Map("name" -> "string")) + assertThrows[IllegalArgumentException] { + index.name() + } + } + + test("should fail if no indexed column given") { + assertThrows[IllegalArgumentException] { + new FlintSparkCoveringIndex("ci", "default.test", Map.empty) + } + } +} diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala index d50f83073..1d65fd821 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala @@ -5,7 +5,6 @@ package org.opensearch.flint.spark.skipping -import com.stephenn.scalatest.jsonassert.JsonMatchers.matchJson import org.json4s.native.JsonMethods.parse import org.mockito.Mockito.when import org.opensearch.flint.core.metadata.FlintMetadata diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala new file mode 100644 index 000000000..20e4dca24 --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala @@ -0,0 +1,124 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark + +import com.stephenn.scalatest.jsonassert.JsonMatchers.matchJson +import org.opensearch.flint.spark.FlintSpark.RefreshMode.{FULL, INCREMENTAL} +import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.getFlintIndexName +import org.scalatest.matchers.must.Matchers.defined +import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper + +import org.apache.spark.sql.Row + +class FlintSparkCoveringIndexITSuite extends FlintSparkSuite { + + /** Test table and index name */ + private val testTable = "default.ci_test" + private val testIndex = "name_and_age" + private val testFlintIndex = getFlintIndexName(testIndex, testTable) + + override def beforeAll(): Unit = { + super.beforeAll() + + createPartitionedTable(testTable) + } + + override def afterEach(): Unit = { + super.afterEach() + + // Delete all test indices + flint.deleteIndex(testFlintIndex) + } + + test("create covering index with metadata successfully") { + flint + .coveringIndex() + .name(testIndex) + .onTable(testTable) + .addIndexColumns("name", "age") + .create() + + val index = flint.describeIndex(testFlintIndex) + index shouldBe defined + index.get.metadata().getContent should matchJson(s"""{ + | "_meta": { + | "name": "name_and_age", + | "kind": "covering", + | "indexedColumns": [ + | { + | "columnName": "name", + | "columnType": "string" + | }, + | { + | "columnName": "age", + | "columnType": "int" + | }], + | "source": "default.ci_test" + | }, + | "properties": { + | "name": { + | "type": "keyword" + | }, + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin) + } + + test("full refresh covering index successfully") { + flint + .coveringIndex() + .name(testIndex) + .onTable(testTable) + .addIndexColumns("name", "age") + .create() + + flint.refreshIndex(testFlintIndex, FULL) + + val indexData = flint.queryIndex(testFlintIndex) + checkAnswer(indexData, Seq(Row("Hello", 30), Row("World", 25))) + } + + test("incremental refresh covering index successfully") { + flint + .coveringIndex() + .name(testIndex) + .onTable(testTable) + .addIndexColumns("name", "age") + .create() + + val jobId = flint.refreshIndex(testFlintIndex, INCREMENTAL) + jobId shouldBe defined + + val job = spark.streams.get(jobId.get) + failAfter(streamingTimeout) { + job.processAllAvailable() + } + + val indexData = flint.queryIndex(testFlintIndex) + checkAnswer(indexData, Seq(Row("Hello", 30), Row("World", 25))) + } + + test("can have multiple covering indexes on a table") { + flint + .coveringIndex() + .name(testIndex) + .onTable(testTable) + .addIndexColumns("name", "age") + .create() + + val newIndex = testIndex + "_address" + flint + .coveringIndex() + .name(newIndex) + .onTable(testTable) + .addIndexColumns("address") + .create() + flint.deleteIndex(getFlintIndexName(newIndex, testTable)) + } +} diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala index e399b43c1..5d31d8724 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala @@ -6,7 +6,6 @@ package org.opensearch.flint.spark import com.stephenn.scalatest.jsonassert.JsonMatchers.matchJson -import org.opensearch.flint.OpenSearchSuite import org.opensearch.flint.core.FlintVersion.current import org.opensearch.flint.spark.FlintSpark.RefreshMode.{FULL, INCREMENTAL} import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN @@ -16,27 +15,13 @@ import org.scalatest.matchers.{Matcher, MatchResult} import org.scalatest.matchers.must.Matchers._ import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper -import org.apache.spark.FlintSuite -import org.apache.spark.sql.{Column, QueryTest, Row} +import org.apache.spark.sql.{Column, Row} import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.datasources.HadoopFsRelation import org.apache.spark.sql.flint.config.FlintSparkConf._ import org.apache.spark.sql.functions.col -import org.apache.spark.sql.streaming.StreamTest - -class FlintSparkSkippingIndexITSuite - extends QueryTest - with FlintSuite - with OpenSearchSuite - with StreamTest { - - /** Flint Spark high level API being tested */ - lazy val flint: FlintSpark = { - setFlintSparkConf(HOST_ENDPOINT, openSearchHost) - setFlintSparkConf(HOST_PORT, openSearchPort) - setFlintSparkConf(REFRESH_POLICY, "true") - new FlintSpark(spark) - } + +class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { /** Test table and index name */ private val testTable = "default.test" @@ -45,35 +30,7 @@ class FlintSparkSkippingIndexITSuite override def beforeAll(): Unit = { super.beforeAll() - sql(s""" - | CREATE TABLE $testTable - | ( - | name STRING, - | age INT, - | address STRING - | ) - | USING CSV - | OPTIONS ( - | header 'false', - | delimiter '\t' - | ) - | PARTITIONED BY ( - | year INT, - | month INT - | ) - |""".stripMargin) - - sql(s""" - | INSERT INTO $testTable - | PARTITION (year=2023, month=4) - | VALUES ('Hello', 30, 'Seattle') - | """.stripMargin) - - sql(s""" - | INSERT INTO $testTable - | PARTITION (year=2023, month=5) - | VALUES ('World', 25, 'Portland') - | """.stripMargin) + createPartitionedTable(testTable) } override def afterEach(): Unit = { @@ -97,6 +54,7 @@ class FlintSparkSkippingIndexITSuite index shouldBe defined index.get.metadata().getContent should matchJson(s"""{ | "_meta": { + | "name": "flint_default_test_skipping_index", | "version": "${current()}", | "kind": "skipping", | "indexedColumns": [ @@ -457,6 +415,7 @@ class FlintSparkSkippingIndexITSuite index.get.metadata().getContent should matchJson( s"""{ | "_meta": { + | "name": "flint_default_data_type_table_skipping_index", | "version": "${current()}", | "kind": "skipping", | "indexedColumns": [ diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala new file mode 100644 index 000000000..3ee6deda1 --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala @@ -0,0 +1,66 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark + +import org.opensearch.flint.OpenSearchSuite + +import org.apache.spark.FlintSuite +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.flint.config.FlintSparkConf.{HOST_ENDPOINT, HOST_PORT, REFRESH_POLICY} +import org.apache.spark.sql.streaming.StreamTest + +/** + * Flint Spark suite trait that initializes [[FlintSpark]] API instance. + */ +trait FlintSparkSuite + extends QueryTest + with FlintSuite + with OpenSearchSuite + with StreamTest { + + /** Flint Spark high level API being tested */ + lazy protected val flint: FlintSpark = { + setFlintSparkConf(HOST_ENDPOINT, openSearchHost) + setFlintSparkConf(HOST_PORT, openSearchPort) + setFlintSparkConf(REFRESH_POLICY, "true") + new FlintSpark(spark) + } + + protected def createPartitionedTable(testTable: String): Unit = { + sql( + s""" + | CREATE TABLE $testTable + | ( + | name STRING, + | age INT, + | address STRING + | ) + | USING CSV + | OPTIONS ( + | header 'false', + | delimiter '\t' + | ) + | PARTITIONED BY ( + | year INT, + | month INT + | ) + |""".stripMargin) + + sql( + s""" + | INSERT INTO $testTable + | PARTITION (year=2023, month=4) + | VALUES ('Hello', 30, 'Seattle') + | """.stripMargin) + + sql( + s""" + | INSERT INTO $testTable + | PARTITION (year=2023, month=5) + | VALUES ('World', 25, 'Portland') + | """.stripMargin) + } +}