diff --git a/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4 b/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4
index eb2cca410..e8e0264f2 100644
--- a/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4
+++ b/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4
@@ -28,21 +28,21 @@ skippingIndexStatement
 
 createSkippingIndexStatement
     : CREATE SKIPPING INDEX (IF NOT EXISTS)?
-        ON tableName=multipartIdentifier
+        ON tableName
         LEFT_PAREN indexColTypeList RIGHT_PAREN
         (WITH LEFT_PAREN propertyList RIGHT_PAREN)?
     ;
 
 refreshSkippingIndexStatement
-    : REFRESH SKIPPING INDEX ON tableName=multipartIdentifier
+    : REFRESH SKIPPING INDEX ON tableName
     ;
 
 describeSkippingIndexStatement
-    : (DESC | DESCRIBE) SKIPPING INDEX ON tableName=multipartIdentifier
+    : (DESC | DESCRIBE) SKIPPING INDEX ON tableName
     ;
 
 dropSkippingIndexStatement
-    : DROP SKIPPING INDEX ON tableName=multipartIdentifier
+    : DROP SKIPPING INDEX ON tableName
     ;
 
 coveringIndexStatement
@@ -54,26 +54,26 @@ coveringIndexStatement
     ;
 
 createCoveringIndexStatement
-    : CREATE INDEX (IF NOT EXISTS)? indexName=identifier
-        ON tableName=multipartIdentifier
+    : CREATE INDEX (IF NOT EXISTS)? indexName
+        ON tableName
         LEFT_PAREN indexColumns=multipartIdentifierPropertyList RIGHT_PAREN
         (WITH LEFT_PAREN propertyList RIGHT_PAREN)?
     ;
 
 refreshCoveringIndexStatement
-    : REFRESH INDEX indexName=identifier ON tableName=multipartIdentifier
+    : REFRESH INDEX indexName ON tableName
     ;
 
 showCoveringIndexStatement
-    : SHOW (INDEX | INDEXES) ON tableName=multipartIdentifier
+    : SHOW (INDEX | INDEXES) ON tableName
     ;
 
 describeCoveringIndexStatement
-    : (DESC | DESCRIBE) INDEX indexName=identifier ON tableName=multipartIdentifier
+    : (DESC | DESCRIBE) INDEX indexName ON tableName
     ;
 
 dropCoveringIndexStatement
-    : DROP INDEX indexName=identifier ON tableName=multipartIdentifier
+    : DROP INDEX indexName ON tableName
     ;
 
 indexColTypeList
@@ -82,4 +82,12 @@ indexColTypeList
 
 indexColType
     : identifier skipType=(PARTITION | VALUE_SET | MIN_MAX)
-    ;
\ No newline at end of file
+    ;
+
+indexName
+    : identifier
+    ;
+
+tableName
+    : multipartIdentifier
+    ;
diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala
new file mode 100644
index 000000000..b848f47b4
--- /dev/null
+++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala
@@ -0,0 +1,77 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.connector.catalog._
+
+/**
+ * Flint utility methods that rely on access to private code in Spark SQL package.
+ */
+package object flint {
+
+  /**
+   * Qualify a given table name.
+   *
+   * @param spark
+   *   Spark session
+   * @param tableName
+   *   table name, possibly qualified
+   * @return
+   *   qualified table name in catalog.database.table format
+   */
+  def qualifyTableName(spark: SparkSession, tableName: String): String = {
+    val (catalog, ident) = parseTableName(spark, tableName)
+
+    // Tricky that our Flint delegate catalog's name has to be spark_catalog,
+    // so we have to find its actual name in CatalogManager
+    val catalogMgr = spark.sessionState.catalogManager
+    val catalogName =
+      catalogMgr
+        .listCatalogs(Some("*"))
+        .find(catalogMgr.catalog(_) == catalog)
+        .getOrElse(catalog.name())
+
+    s"$catalogName.${ident.namespace.mkString(".")}.${ident.name}"
+  }
+
+  /**
+   * Parse a given table name into its catalog and table identifier.
+   *
+   * @param spark
+   *   Spark session
+   * @param tableName
+   *   table name, possibly qualified
+   * @return
+   *   Spark catalog and table identifier
+   */
+  def parseTableName(spark: SparkSession, tableName: String): (CatalogPlugin, Identifier) = {
+    // Create an anonymous class to access CatalogAndIdentifier
+    new LookupCatalog {
+      override protected val catalogManager: CatalogManager = spark.sessionState.catalogManager
+
+      def parseTableName(): (CatalogPlugin, Identifier) = {
+        val parts = tableName.split("\\.").toSeq
+        parts match {
+          case CatalogAndIdentifier(catalog, ident) => (catalog, ident)
+        }
+      }
+    }.parseTableName()
+  }
+
+  /**
+   * Load table for the given table identifier in the catalog.
+   *
+   * @param catalog
+   *   Spark catalog
+   * @param ident
+   *   table identifier
+   * @return
+   *   Spark table
+   */
+  def loadTable(catalog: CatalogPlugin, ident: Identifier): Option[Table] = {
+    CatalogV2Util.loadTable(catalog, ident)
+  }
+}
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala
index 2212826dc..d429244d1 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala
@@ -8,6 +8,9 @@ package org.opensearch.flint.spark
 import org.opensearch.flint.spark.FlintSparkIndexOptions.empty
 
 import org.apache.spark.sql.catalog.Column
+import org.apache.spark.sql.catalyst.util.CharVarcharUtils
+import org.apache.spark.sql.flint.{loadTable, parseTableName, qualifyTableName}
+import org.apache.spark.sql.types.StructField
 
 /**
  * Flint Spark index builder base class.
@@ -17,21 +20,22 @@ import org.apache.spark.sql.catalog.Column
  */
 abstract class FlintSparkIndexBuilder(flint: FlintSpark) {
 
-  /** Source table name */
-  protected var tableName: String = ""
+  /** Qualified source table name */
+  protected var qualifiedTableName: String = ""
 
   /** Index options */
   protected var indexOptions: FlintSparkIndexOptions = empty
 
   /** All columns of the given source table */
   lazy protected val allColumns: Map[String, Column] = {
-    require(tableName.nonEmpty, "Source table name is not provided")
+    require(qualifiedTableName.nonEmpty, "Source table name is not provided")
 
-    flint.spark.catalog
-      .listColumns(tableName)
-      .collect()
-      .map(col => (col.name, col))
-      .toMap
+    val (catalog, ident) = parseTableName(flint.spark, qualifiedTableName)
+    val table = loadTable(catalog, ident).getOrElse(
+      throw new IllegalStateException(s"Table $qualifiedTableName is not found"))
+
+    val allFields = table.schema().fields
+    allFields.map { field => field.name -> convertFieldToColumn(field) }.toMap
   }
 
   /**
@@ -61,8 +65,37 @@ abstract class FlintSparkIndexBuilder(flint: FlintSpark) {
    */
   protected def buildIndex(): FlintSparkIndex
 
+  /**
+   * Table name setter that qualifies the given table name for subclasses automatically.
+   */
+  protected def tableName_=(tableName: String): Unit = {
+    qualifiedTableName = qualifyTableName(flint.spark, tableName)
+  }
+
+  /**
+   * Table name getter
+   */
+  protected def tableName: String = {
+    qualifiedTableName
+  }
+
+  /**
+   * Find column with the given name.
+   */
   protected def findColumn(colName: String): Column =
     allColumns.getOrElse(
      colName,
       throw new IllegalArgumentException(s"Column $colName does not exist"))
+
+  private def convertFieldToColumn(field: StructField): Column = {
+    // Ref to CatalogImpl.listColumns(): Varchar/Char is StringType with real type name in metadata
+    new Column(
+      name = field.name,
+      description = field.getComment().orNull,
+      dataType =
+        CharVarcharUtils.getRawType(field.metadata).getOrElse(field.dataType).catalogString,
+      nullable = field.nullable,
+      isPartition = false, // useless for now so just set to false
+      isBucket = false)
+  }
 }
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala
index 29503919d..e18123a0e 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala
@@ -109,7 +109,9 @@ object FlintSparkCoveringIndex {
    *   Flint covering index name
    */
   def getFlintIndexName(indexName: String, tableName: String): String = {
-    require(tableName.contains("."), "Full table name database.table is required")
+    require(
+      tableName.split("\\.").length >= 3,
+      "Qualified table name catalog.database.table is required")
 
     flintIndexNamePrefix(tableName) + indexName + COVERING_INDEX_SUFFIX
   }
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala
index 56ea0e9bc..11f8ad304 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala
@@ -8,10 +8,12 @@ package org.opensearch.flint.spark.skipping
 import org.opensearch.flint.spark.FlintSpark
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, SKIPPING_INDEX_TYPE}
 
+import org.apache.spark.sql.catalyst.catalog.CatalogTable
 import org.apache.spark.sql.catalyst.expressions.{And, Expression, Or, Predicate}
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
+import org.apache.spark.sql.flint.qualifyTableName
 
 /**
  * Flint Spark skipping index apply rule that rewrites applicable query's filtering condition and
@@ -32,8 +34,7 @@ class ApplyFlintSparkSkippingIndex(flint: FlintSpark) extends Rule[LogicalPlan]
             Some(table),
             false))
           if hasNoDisjunction(condition) && !location.isInstanceOf[FlintSparkSkippingFileIndex] =>
-        val indexName = getSkippingIndexName(table.identifier.unquotedString)
-        val index = flint.describeIndex(indexName)
+        val index = flint.describeIndex(getIndexName(table))
         if (index.exists(_.kind == SKIPPING_INDEX_TYPE)) {
           val skippingIndex = index.get.asInstanceOf[FlintSparkSkippingIndex]
           val indexFilter = rewriteToIndexFilter(skippingIndex, condition)
@@ -58,9 +59,17 @@ class ApplyFlintSparkSkippingIndex(flint: FlintSpark) extends Rule[LogicalPlan]
     }
   }
 
+  private def getIndexName(table: CatalogTable): String = {
+    // Because Spark's qualified name only contains database.table without the catalog,
+    // the limitation here is that qualifyTableName always uses the current catalog.
+    val tableName = table.qualifiedName
+    val qualifiedTableName = qualifyTableName(flint.spark, tableName)
+    getSkippingIndexName(qualifiedTableName)
+  }
+
   private def hasNoDisjunction(condition: Expression): Boolean = {
-    condition.collectFirst {
-      case Or(_, _) => true
+    condition.collectFirst { case Or(_, _) =>
+      true
     }.isEmpty
   }
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala
index da69cc1fa..9efbc6c4e 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala
@@ -127,7 +127,9 @@ object FlintSparkSkippingIndex {
    *   Flint skipping index name
    */
   def getSkippingIndexName(tableName: String): String = {
-    require(tableName.contains("."), "Full table name database.table is required")
+    require(
+      tableName.split("\\.").length >= 3,
+      "Qualified table name catalog.database.table is required")
 
     flintIndexNamePrefix(tableName) + SKIPPING_INDEX_SUFFIX
   }
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlAstBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlAstBuilder.scala
index 9f3539b09..98abb3878 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlAstBuilder.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlAstBuilder.scala
@@ -11,6 +11,7 @@ import org.opensearch.flint.spark.sql.covering.FlintSparkCoveringIndexAstBuilder
 import org.opensearch.flint.spark.sql.skipping.FlintSparkSkippingIndexAstBuilder
 
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.flint.qualifyTableName
 
 /**
  * Flint Spark AST builder that builds Spark command for Flint index statement. This class mix-in
@@ -33,7 +34,9 @@ class FlintSparkSqlAstBuilder
 object FlintSparkSqlAstBuilder {
 
   /**
-   * Get full table name if database not specified.
+   * Get full table name if catalog or database is not specified. The reason we cannot do this
+   * in the common SparkSqlAstBuilder.visitTableName is that qualifying a table name requires a
+   * SparkSession, which is only available at execution time rather than at parse time.
    *
    * @param flint
    *   Flint Spark which has access to Spark Catalog
@@ -42,12 +45,6 @@ object FlintSparkSqlAstBuilder {
    * @return
    */
   def getFullTableName(flint: FlintSpark, tableNameCtx: RuleNode): String = {
-    val tableName = tableNameCtx.getText
-    if (tableName.contains(".")) {
-      tableName
-    } else {
-      val db = flint.spark.catalog.currentDatabase
-      s"$db.$tableName"
-    }
+    qualifyTableName(flint.spark, tableNameCtx.getText)
   }
 }
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/covering/FlintSparkCoveringIndexAstBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/covering/FlintSparkCoveringIndexAstBuilder.scala
index 65a87c568..c0bb47830 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/covering/FlintSparkCoveringIndexAstBuilder.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/covering/FlintSparkCoveringIndexAstBuilder.scala
@@ -28,7 +28,7 @@ trait FlintSparkCoveringIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[A
       ctx: CreateCoveringIndexStatementContext): Command = {
     FlintSparkSqlCommand() { flint =>
       val indexName = ctx.indexName.getText
-      val tableName = ctx.tableName.getText
+      val tableName = getFullTableName(flint, ctx.tableName)
 
       val indexBuilder = flint
         .coveringIndex()
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexBuilderSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexBuilderSuite.scala
new file mode 100644
index 000000000..0cd4a5293
--- /dev/null
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexBuilderSuite.scala
@@ -0,0 +1,108 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.spark
+
+import org.scalatest.matchers.must.Matchers.contain
+import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
+
+import org.apache.spark.FlintSuite
+
+class FlintSparkIndexBuilderSuite extends FlintSuite {
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+
+    sql("""
+        | CREATE TABLE spark_catalog.default.test
+        | ( name STRING, age INT )
+        | USING JSON
+      """.stripMargin)
+  }
+
+  protected override def afterAll(): Unit = {
+    sql("DROP TABLE spark_catalog.default.test")
+
+    super.afterAll()
+  }
+
+  test("should qualify table name in default database") {
+    builder()
+      .onTable("test")
+      .expectTableName("spark_catalog.default.test")
+      .expectAllColumns("name", "age")
+
+    builder()
+      .onTable("default.test")
+      .expectTableName("spark_catalog.default.test")
+      .expectAllColumns("name", "age")
+
+    builder()
+      .onTable("spark_catalog.default.test")
+      .expectTableName("spark_catalog.default.test")
+      .expectAllColumns("name", "age")
+  }
+
+  test("should qualify table name and get columns in other database") {
+    sql("CREATE DATABASE mydb")
+    sql("CREATE TABLE mydb.test2 (address STRING) USING JSON")
+    sql("USE mydb")
+
+    try {
+      builder()
+        .onTable("test2")
+        .expectTableName("spark_catalog.mydb.test2")
+        .expectAllColumns("address")
+
+      builder()
+        .onTable("mydb.test2")
+        .expectTableName("spark_catalog.mydb.test2")
+        .expectAllColumns("address")
+
+      builder()
+        .onTable("spark_catalog.mydb.test2")
+        .expectTableName("spark_catalog.mydb.test2")
+        .expectAllColumns("address")
+
+      // Can parse any specified table name
+      builder()
+        .onTable("spark_catalog.default.test")
+        .expectTableName("spark_catalog.default.test")
+        .expectAllColumns("name", "age")
+    } finally {
+      sql("DROP DATABASE mydb CASCADE")
+      sql("USE default")
+    }
+  }
+
+  private def builder(): FakeFlintSparkIndexBuilder = {
+    new FakeFlintSparkIndexBuilder
+  }
+
+  /**
+   * Fake builder that has access to internal methods for assertion
+   */
+  class FakeFlintSparkIndexBuilder extends FlintSparkIndexBuilder(new FlintSpark(spark)) {
+
+    def onTable(tableName: String): FakeFlintSparkIndexBuilder = {
+      this.tableName = tableName
+      this
+    }
+
+    def expectTableName(expected: String): FakeFlintSparkIndexBuilder = {
+      tableName shouldBe expected
+      this
+    }
+
+    def expectAllColumns(expected: String*): FakeFlintSparkIndexBuilder = {
+      allColumns.keys should contain theSameElementsAs expected
+      this
+    }
+
+    override protected def buildIndex(): FlintSparkIndex = {
+      null
+    }
+  }
+}
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala
index a50db1af2..8c144b46b 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala
@@ -12,8 +12,9 @@ import org.apache.spark.FlintSuite
 class FlintSparkCoveringIndexSuite extends FlintSuite {
 
   test("get covering index name") {
-    val index = new FlintSparkCoveringIndex("ci", "default.test", Map("name" -> "string"))
-    index.name() shouldBe "flint_default_test_ci_index"
+    val index =
+      new FlintSparkCoveringIndex("ci", "spark_catalog.default.test", Map("name" -> "string"))
+    index.name() shouldBe "flint_spark_catalog_default_test_ci_index"
   }
 
   test("should fail if get index name without full table name") {
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndexSuite.scala
index 98a2390d8..f9455fbfa 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndexSuite.scala
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndexSuite.scala
@@ -6,7 +6,7 @@
 package org.opensearch.flint.spark.skipping
 
 import org.mockito.ArgumentMatchers.any
-import org.mockito.Mockito.{doAnswer, when}
+import org.mockito.Mockito._
 import org.mockito.invocation.InvocationOnMock
 import org.opensearch.flint.spark.FlintSpark
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, SKIPPING_INDEX_TYPE}
@@ -30,7 +30,7 @@ import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructT
 class ApplyFlintSparkSkippingIndexSuite extends SparkFunSuite with Matchers {
 
   /** Test table and index */
-  private val testTable = "default.apply_skipping_index_test"
"spark_catalog.default.apply_skipping_index_test" private val testIndex = getSkippingIndexName(testTable) private val testSchema = StructType( Seq( @@ -112,7 +112,12 @@ class ApplyFlintSparkSkippingIndexSuite extends SparkFunSuite with Matchers { } private class AssertionHelper { - private val flint = mock[FlintSpark] + private val flint = { + val mockFlint = mock[FlintSpark](RETURNS_DEEP_STUBS) + when(mockFlint.spark.sessionState.catalogManager.currentCatalog.name()) + .thenReturn("spark_catalog") + mockFlint + } private val rule = new ApplyFlintSparkSkippingIndex(flint) private var relation: LogicalRelation = _ private var plan: LogicalPlan = _ diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala index 1d65fd821..b31e18480 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala @@ -20,16 +20,18 @@ import org.apache.spark.sql.functions.col class FlintSparkSkippingIndexSuite extends FlintSuite { + private val testTable = "spark_catalog.default.test" + test("get skipping index name") { - val index = new FlintSparkSkippingIndex("default.test", Seq(mock[FlintSparkSkippingStrategy])) - index.name() shouldBe "flint_default_test_skipping_index" + val index = new FlintSparkSkippingIndex(testTable, Seq(mock[FlintSparkSkippingStrategy])) + index.name() shouldBe "flint_spark_catalog_default_test_skipping_index" } test("can build index building job with unique ID column") { val indexCol = mock[FlintSparkSkippingStrategy] when(indexCol.outputSchema()).thenReturn(Map("name" -> "string")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("name").expr))) - val index = new FlintSparkSkippingIndex("default.test", Seq(indexCol)) + val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) val df = spark.createDataFrame(Seq(("hello", 20))).toDF("name", "age") val indexDf = index.build(df) @@ -41,7 +43,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { when(indexCol.outputSchema()).thenReturn(Map("boolean_col" -> "boolean")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("boolean_col").expr))) - val index = new FlintSparkSkippingIndex("default.test", Seq(indexCol)) + val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) schemaShouldMatch( index.metadata(), s"""{ @@ -60,7 +62,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { when(indexCol.outputSchema()).thenReturn(Map("string_col" -> "string")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("string_col").expr))) - val index = new FlintSparkSkippingIndex("default.test", Seq(indexCol)) + val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) schemaShouldMatch( index.metadata(), s"""{ @@ -81,7 +83,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { when(indexCol.outputSchema()).thenReturn(Map("varchar_col" -> "varchar(20)")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("varchar_col").expr))) - val index = new FlintSparkSkippingIndex("default.test", Seq(indexCol)) + val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) schemaShouldMatch( index.metadata(), s"""{ @@ -100,7 +102,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { 
when(indexCol.outputSchema()).thenReturn(Map("char_col" -> "char(20)")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("char_col").expr))) - val index = new FlintSparkSkippingIndex("default.test", Seq(indexCol)) + val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) schemaShouldMatch( index.metadata(), s"""{ @@ -119,7 +121,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { when(indexCol.outputSchema()).thenReturn(Map("long_col" -> "bigint")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("long_col").expr))) - val index = new FlintSparkSkippingIndex("default.test", Seq(indexCol)) + val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) schemaShouldMatch( index.metadata(), s"""{ @@ -138,7 +140,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { when(indexCol.outputSchema()).thenReturn(Map("int_col" -> "int")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("int_col").expr))) - val index = new FlintSparkSkippingIndex("default.test", Seq(indexCol)) + val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) schemaShouldMatch( index.metadata(), s"""{ @@ -157,7 +159,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { when(indexCol.outputSchema()).thenReturn(Map("short_col" -> "smallint")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("short_col").expr))) - val index = new FlintSparkSkippingIndex("default.test", Seq(indexCol)) + val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) schemaShouldMatch( index.metadata(), s"""{ @@ -176,7 +178,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { when(indexCol.outputSchema()).thenReturn(Map("byte_col" -> "tinyint")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("byte_col").expr))) - val index = new FlintSparkSkippingIndex("default.test", Seq(indexCol)) + val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) schemaShouldMatch( index.metadata(), s"""{ @@ -195,7 +197,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { when(indexCol.outputSchema()).thenReturn(Map("double_col" -> "double")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("double_col").expr))) - val index = new FlintSparkSkippingIndex("default.test", Seq(indexCol)) + val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) schemaShouldMatch( index.metadata(), s"""{ @@ -214,7 +216,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { when(indexCol.outputSchema()).thenReturn(Map("float_col" -> "float")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("float_col").expr))) - val index = new FlintSparkSkippingIndex("default.test", Seq(indexCol)) + val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) schemaShouldMatch( index.metadata(), s"""{ @@ -233,7 +235,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { when(indexCol.outputSchema()).thenReturn(Map("timestamp_col" -> "timestamp")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("timestamp_col").expr))) - val index = new FlintSparkSkippingIndex("default.test", Seq(indexCol)) + val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) schemaShouldMatch( index.metadata(), s"""{ @@ -253,7 +255,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { when(indexCol.outputSchema()).thenReturn(Map("date_col" -> "date")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("date_col").expr))) - val index = new FlintSparkSkippingIndex("default.test", Seq(indexCol)) + val index = new 
FlintSparkSkippingIndex(testTable, Seq(indexCol)) schemaShouldMatch( index.metadata(), s"""{ @@ -274,7 +276,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { .thenReturn(Map("struct_col" -> "struct")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("struct_col").expr))) - val index = new FlintSparkSkippingIndex("default.test", Seq(indexCol)) + val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) schemaShouldMatch( index.metadata(), s"""{ @@ -303,7 +305,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { test("should fail if no indexed column given") { assertThrows[IllegalArgumentException] { - new FlintSparkSkippingIndex("default.test", Seq.empty) + new FlintSparkSkippingIndex(testTable, Seq.empty) } } diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala index 140ecdd77..717438f7e 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala @@ -16,7 +16,7 @@ import org.apache.spark.sql.Row class FlintSparkCoveringIndexITSuite extends FlintSparkSuite { /** Test table and index name */ - private val testTable = "default.ci_test" + private val testTable = "spark_catalog.default.ci_test" private val testIndex = "name_and_age" private val testFlintIndex = getFlintIndexName(testIndex, testTable) @@ -56,7 +56,7 @@ class FlintSparkCoveringIndexITSuite extends FlintSparkSuite { | "columnName": "age", | "columnType": "int" | }], - | "source": "default.ci_test", + | "source": "spark_catalog.default.ci_test", | "options": {} | }, | "properties": { diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala index 78ec1619b..e50b17b0a 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala @@ -17,7 +17,7 @@ import org.apache.spark.sql.Row class FlintSparkCoveringIndexSqlITSuite extends FlintSparkSuite { /** Test table and index name */ - private val testTable = "default.covering_sql_test" + private val testTable = "spark_catalog.default.covering_sql_test" private val testIndex = "name_and_age" private val testFlintIndex = getFlintIndexName(testIndex, testTable) @@ -67,6 +67,46 @@ class FlintSparkCoveringIndexSqlITSuite extends FlintSparkSuite { indexData.count() shouldBe 2 } + test("create covering index on table without database name") { + sql(s"CREATE INDEX $testIndex ON covering_sql_test (name)") + + flint.describeIndex(testFlintIndex) shouldBe defined + } + + test("create covering index on table in other database") { + sql("CREATE SCHEMA sample") + sql("USE sample") + + // Create index without database name specified + sql("CREATE TABLE test1 (name STRING) USING CSV") + sql(s"CREATE INDEX $testIndex ON sample.test1 (name)") + + // Create index with database name specified + sql("CREATE TABLE test2 (name STRING) USING CSV") + sql(s"CREATE INDEX $testIndex ON sample.test2 (name)") + + try { + flint.describeIndex(s"flint_spark_catalog_sample_test1_${testIndex}_index") shouldBe defined + flint.describeIndex(s"flint_spark_catalog_sample_test2_${testIndex}_index") shouldBe defined + } finally { + 
sql("DROP DATABASE sample CASCADE") + } + } + + test("create covering index on table in other database than current") { + sql("CREATE SCHEMA sample") + sql("USE sample") + + // Specify database "default" in table name instead of current "sample" database + sql(s"CREATE INDEX $testIndex ON $testTable (name)") + + try { + flint.describeIndex(testFlintIndex) shouldBe defined + } finally { + sql("DROP DATABASE sample CASCADE") + } + } + test("create covering index if not exists") { sql(s""" | CREATE INDEX IF NOT EXISTS $testIndex diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala index 7d5598cd5..05f398b88 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.functions.col class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { /** Test table and index name */ - private val testTable = "default.test" + private val testTable = "spark_catalog.default.test" private val testIndex = getSkippingIndexName(testTable) override def beforeAll(): Unit = { @@ -49,12 +49,11 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { .addMinMax("age") .create() - val indexName = s"flint_default_test_skipping_index" - val index = flint.describeIndex(indexName) + val index = flint.describeIndex(testIndex) index shouldBe defined index.get.metadata().getContent should matchJson(s"""{ | "_meta": { - | "name": "flint_default_test_skipping_index", + | "name": "flint_spark_catalog_default_test_skipping_index", | "version": "${current()}", | "kind": "skipping", | "indexedColumns": [ @@ -78,7 +77,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { | "columnName": "age", | "columnType": "int" | }], - | "source": "default.test", + | "source": "spark_catalog.default.test", | "options": {} | }, | "properties": { @@ -371,7 +370,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { test("create skipping index for all supported data types successfully") { // Prepare test table - val testTable = "default.data_type_table" + val testTable = "spark_catalog.default.data_type_table" val testIndex = getSkippingIndexName(testTable) sql( s""" @@ -437,7 +436,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { index.get.metadata().getContent should matchJson( s"""{ | "_meta": { - | "name": "flint_default_data_type_table_skipping_index", + | "name": "flint_spark_catalog_default_data_type_table_skipping_index", | "version": "${current()}", | "kind": "skipping", | "indexedColumns": [ @@ -569,7 +568,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { } test("can build skipping index for varchar and char and rewrite applicable query") { - val testTable = "default.varchar_char_table" + val testTable = "spark_catalog.default.varchar_char_table" val testIndex = getSkippingIndexName(testTable) sql( s""" diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala index eec9dfca9..23a1bb542 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala @@ -17,7 +17,7 
@@ import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE class FlintSparkSkippingIndexSqlITSuite extends FlintSparkSuite { /** Test table and index name */ - private val testTable = "default.skipping_sql_test" + private val testTable = "spark_catalog.default.skipping_sql_test" private val testIndex = getSkippingIndexName(testTable) override def beforeAll(): Unit = { @@ -136,14 +136,7 @@ class FlintSparkSkippingIndexSqlITSuite extends FlintSparkSuite { } test("create skipping index on table without database name") { - sql(s""" - | CREATE SKIPPING INDEX ON skipping_sql_test - | ( - | year PARTITION, - | name VALUE_SET, - | age MIN_MAX - | ) - | """.stripMargin) + sql("CREATE SKIPPING INDEX ON skipping_sql_test ( year PARTITION )") flint.describeIndex(testIndex) shouldBe defined } @@ -151,10 +144,35 @@ class FlintSparkSkippingIndexSqlITSuite extends FlintSparkSuite { test("create skipping index on table in other database") { sql("CREATE SCHEMA sample") sql("USE sample") - sql("CREATE TABLE test (name STRING) USING CSV") - sql("CREATE SKIPPING INDEX ON test (name VALUE_SET)") - flint.describeIndex("flint_sample_test_skipping_index") shouldBe defined + // Create index without database name specified + sql("CREATE TABLE test1 (name STRING) USING CSV") + sql("CREATE SKIPPING INDEX ON test1 (name VALUE_SET)") + + // Create index with database name specified + sql("CREATE TABLE test2 (name STRING) USING CSV") + sql("CREATE SKIPPING INDEX ON sample.test2 (name VALUE_SET)") + + try { + flint.describeIndex("flint_spark_catalog_sample_test1_skipping_index") shouldBe defined + flint.describeIndex("flint_spark_catalog_sample_test2_skipping_index") shouldBe defined + } finally { + sql("DROP DATABASE sample CASCADE") + } + } + + test("create skipping index on table in other database than current") { + sql("CREATE SCHEMA sample") + sql("USE sample") + + // Specify database "default" in table name instead of current "sample" database + sql(s"CREATE SKIPPING INDEX ON $testTable (name VALUE_SET)") + + try { + flint.describeIndex(testIndex) shouldBe defined + } finally { + sql("DROP DATABASE sample CASCADE") + } } test("should return empty if no skipping index to describe") {