From 9b73e3ec957fb36664a835ddcd3051c5e3d17012 Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Wed, 27 Sep 2023 10:38:48 -0700 Subject: [PATCH 1/6] Change to qualified name in Flint Spark index API layer Signed-off-by: Chen Dai --- .../flint/spark/FlintSparkIndexBuilder.scala | 4 +- .../covering/FlintSparkCoveringIndex.scala | 4 +- .../ApplyFlintSparkSkippingIndex.scala | 16 +++-- .../skipping/FlintSparkSkippingIndex.scala | 4 +- .../flint/spark/util/QualifiedTableName.scala | 68 +++++++++++++++++++ .../FlintSparkCoveringIndexSuite.scala | 5 +- .../ApplyFlintSparkSkippingIndexSuite.scala | 11 ++- .../FlintSparkSkippingIndexSuite.scala | 36 +++++----- .../FlintSparkCoveringIndexITSuite.scala | 4 +- .../FlintSparkCoveringIndexSqlITSuite.scala | 2 +- .../FlintSparkSkippingIndexITSuite.scala | 15 ++-- .../FlintSparkSkippingIndexSqlITSuite.scala | 2 +- 12 files changed, 130 insertions(+), 41 deletions(-) create mode 100644 flint-spark-integration/src/main/scala/org/opensearch/flint/spark/util/QualifiedTableName.scala diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala index 95e351f7d..493a80751 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala @@ -6,6 +6,7 @@ package org.opensearch.flint.spark import org.opensearch.flint.spark.FlintSparkIndexOptions.empty +import org.opensearch.flint.spark.util.QualifiedTableName import org.apache.spark.sql.catalog.Column @@ -27,8 +28,9 @@ abstract class FlintSparkIndexBuilder(flint: FlintSpark) { lazy protected val allColumns: Map[String, Column] = { require(tableName.nonEmpty, "Source table name is not provided") + val qualified = new QualifiedTableName(tableName)(flint.spark) flint.spark.catalog - .listColumns(tableName) + 
.listColumns(qualified.nameWithoutCatalog) .collect() .map(col => (col.name, col)) .toMap diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala index 29503919d..e18123a0e 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala @@ -109,7 +109,9 @@ object FlintSparkCoveringIndex { * Flint covering index name */ def getFlintIndexName(indexName: String, tableName: String): String = { - require(tableName.contains("."), "Full table name database.table is required") + require( + tableName.split("\\.").length >= 3, + "Qualified table name catalog.database.table is required") flintIndexNamePrefix(tableName) + indexName + COVERING_INDEX_SUFFIX } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala index 56ea0e9bc..85068d413 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala @@ -7,7 +7,9 @@ package org.opensearch.flint.spark.skipping import org.opensearch.flint.spark.FlintSpark import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, SKIPPING_INDEX_TYPE} +import org.opensearch.flint.spark.util.QualifiedTableName +import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.{And, Expression, Or, Predicate} import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan} import 
org.apache.spark.sql.catalyst.rules.Rule @@ -32,8 +34,7 @@ class ApplyFlintSparkSkippingIndex(flint: FlintSpark) extends Rule[LogicalPlan] Some(table), false)) if hasNoDisjunction(condition) && !location.isInstanceOf[FlintSparkSkippingFileIndex] => - val indexName = getSkippingIndexName(table.identifier.unquotedString) - val index = flint.describeIndex(indexName) + val index = flint.describeIndex(getIndexName(table)) if (index.exists(_.kind == SKIPPING_INDEX_TYPE)) { val skippingIndex = index.get.asInstanceOf[FlintSparkSkippingIndex] val indexFilter = rewriteToIndexFilter(skippingIndex, condition) @@ -58,9 +59,16 @@ class ApplyFlintSparkSkippingIndex(flint: FlintSpark) extends Rule[LogicalPlan] } } + private def getIndexName(table: CatalogTable): String = { + // Spark qualified name only contains database.table without catalog + val tableName = table.qualifiedName + val qualifiedTableName = new QualifiedTableName(tableName)(flint.spark).name + getSkippingIndexName(qualifiedTableName) + } + private def hasNoDisjunction(condition: Expression): Boolean = { - condition.collectFirst { - case Or(_, _) => true + condition.collectFirst { case Or(_, _) => + true }.isEmpty } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala index da69cc1fa..9efbc6c4e 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala @@ -127,7 +127,9 @@ object FlintSparkSkippingIndex { * Flint skipping index name */ def getSkippingIndexName(tableName: String): String = { - require(tableName.contains("."), "Full table name database.table is required") + require( + tableName.split("\\.").length >= 3, + "Qualified table name catalog.database.table is required") 
flintIndexNamePrefix(tableName) + SKIPPING_INDEX_SUFFIX } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/util/QualifiedTableName.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/util/QualifiedTableName.scala new file mode 100644 index 000000000..ece282bee --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/util/QualifiedTableName.scala @@ -0,0 +1,68 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.util + +import org.opensearch.flint.spark.util.QualifiedTableName.{catalogName, tableNameWithoutCatalog} + +import org.apache.spark.sql.SparkSession + +/** + * Qualified table name class that encapsulates table name parsing and qualifying utility. This is + * useful because Spark doesn't associate catalog info in logical plan even after analyzed. + * + * @param tableName + * table name maybe qualified or not + * @param spark + * Spark session to get current catalog and database info + */ +class QualifiedTableName(tableName: String)(spark: SparkSession) { + + /** Qualified table name */ + lazy private val qualifiedTableName: String = { + val parts = tableName.split("\\.") + if (parts.length == 1) { + s"$currentCatalog.$currentDatabase.$tableName" + } else if (parts.length == 2) { + s"$currentCatalog.$tableName" + } else { + tableName + } + } + + def name: String = qualifiedTableName + + def catalog: String = catalogName(qualifiedTableName) + + def nameWithoutCatalog: String = tableNameWithoutCatalog(qualifiedTableName) + + private def currentCatalog: String = { + require(spark != null, "Spark session required to unqualify the given table name") + + val catalogMgr = spark.sessionState.catalogManager + catalogMgr.currentCatalog.name() + } + + private def currentDatabase: String = { + require(spark != null, "Spark session required to unqualify the given table name") + + val catalogMgr = 
spark.sessionState.catalogManager + catalogMgr.currentNamespace.mkString(".") + } +} + +/** + * Utility methods for table name already qualified and thus has no dependency on Spark session. + */ +object QualifiedTableName { + + def catalogName(qualifiedTableName: String): String = { + qualifiedTableName.substring(0, qualifiedTableName.indexOf(".")) + } + + def tableNameWithoutCatalog(qualifiedTableName: String): String = { + qualifiedTableName.substring(qualifiedTableName.indexOf(".") + 1) + } +} diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala index a50db1af2..8c144b46b 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala @@ -12,8 +12,9 @@ import org.apache.spark.FlintSuite class FlintSparkCoveringIndexSuite extends FlintSuite { test("get covering index name") { - val index = new FlintSparkCoveringIndex("ci", "default.test", Map("name" -> "string")) - index.name() shouldBe "flint_default_test_ci_index" + val index = + new FlintSparkCoveringIndex("ci", "spark_catalog.default.test", Map("name" -> "string")) + index.name() shouldBe "flint_spark_catalog_default_test_ci_index" } test("should fail if get index name without full table name") { diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndexSuite.scala index 98a2390d8..f9455fbfa 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndexSuite.scala +++ 
b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndexSuite.scala @@ -6,7 +6,7 @@ package org.opensearch.flint.spark.skipping import org.mockito.ArgumentMatchers.any -import org.mockito.Mockito.{doAnswer, when} +import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.opensearch.flint.spark.FlintSpark import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, SKIPPING_INDEX_TYPE} @@ -30,7 +30,7 @@ import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructT class ApplyFlintSparkSkippingIndexSuite extends SparkFunSuite with Matchers { /** Test table and index */ - private val testTable = "default.apply_skipping_index_test" + private val testTable = "spark_catalog.default.apply_skipping_index_test" private val testIndex = getSkippingIndexName(testTable) private val testSchema = StructType( Seq( @@ -112,7 +112,12 @@ class ApplyFlintSparkSkippingIndexSuite extends SparkFunSuite with Matchers { } private class AssertionHelper { - private val flint = mock[FlintSpark] + private val flint = { + val mockFlint = mock[FlintSpark](RETURNS_DEEP_STUBS) + when(mockFlint.spark.sessionState.catalogManager.currentCatalog.name()) + .thenReturn("spark_catalog") + mockFlint + } private val rule = new ApplyFlintSparkSkippingIndex(flint) private var relation: LogicalRelation = _ private var plan: LogicalPlan = _ diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala index 1d65fd821..b31e18480 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala @@ -20,16 +20,18 @@ import 
org.apache.spark.sql.functions.col class FlintSparkSkippingIndexSuite extends FlintSuite { + private val testTable = "spark_catalog.default.test" + test("get skipping index name") { - val index = new FlintSparkSkippingIndex("default.test", Seq(mock[FlintSparkSkippingStrategy])) - index.name() shouldBe "flint_default_test_skipping_index" + val index = new FlintSparkSkippingIndex(testTable, Seq(mock[FlintSparkSkippingStrategy])) + index.name() shouldBe "flint_spark_catalog_default_test_skipping_index" } test("can build index building job with unique ID column") { val indexCol = mock[FlintSparkSkippingStrategy] when(indexCol.outputSchema()).thenReturn(Map("name" -> "string")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("name").expr))) - val index = new FlintSparkSkippingIndex("default.test", Seq(indexCol)) + val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) val df = spark.createDataFrame(Seq(("hello", 20))).toDF("name", "age") val indexDf = index.build(df) @@ -41,7 +43,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { when(indexCol.outputSchema()).thenReturn(Map("boolean_col" -> "boolean")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("boolean_col").expr))) - val index = new FlintSparkSkippingIndex("default.test", Seq(indexCol)) + val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) schemaShouldMatch( index.metadata(), s"""{ @@ -60,7 +62,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { when(indexCol.outputSchema()).thenReturn(Map("string_col" -> "string")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("string_col").expr))) - val index = new FlintSparkSkippingIndex("default.test", Seq(indexCol)) + val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) schemaShouldMatch( index.metadata(), s"""{ @@ -81,7 +83,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { when(indexCol.outputSchema()).thenReturn(Map("varchar_col" -> "varchar(20)")) 
when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("varchar_col").expr))) - val index = new FlintSparkSkippingIndex("default.test", Seq(indexCol)) + val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) schemaShouldMatch( index.metadata(), s"""{ @@ -100,7 +102,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { when(indexCol.outputSchema()).thenReturn(Map("char_col" -> "char(20)")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("char_col").expr))) - val index = new FlintSparkSkippingIndex("default.test", Seq(indexCol)) + val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) schemaShouldMatch( index.metadata(), s"""{ @@ -119,7 +121,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { when(indexCol.outputSchema()).thenReturn(Map("long_col" -> "bigint")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("long_col").expr))) - val index = new FlintSparkSkippingIndex("default.test", Seq(indexCol)) + val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) schemaShouldMatch( index.metadata(), s"""{ @@ -138,7 +140,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { when(indexCol.outputSchema()).thenReturn(Map("int_col" -> "int")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("int_col").expr))) - val index = new FlintSparkSkippingIndex("default.test", Seq(indexCol)) + val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) schemaShouldMatch( index.metadata(), s"""{ @@ -157,7 +159,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { when(indexCol.outputSchema()).thenReturn(Map("short_col" -> "smallint")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("short_col").expr))) - val index = new FlintSparkSkippingIndex("default.test", Seq(indexCol)) + val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) schemaShouldMatch( index.metadata(), s"""{ @@ -176,7 +178,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { 
when(indexCol.outputSchema()).thenReturn(Map("byte_col" -> "tinyint")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("byte_col").expr))) - val index = new FlintSparkSkippingIndex("default.test", Seq(indexCol)) + val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) schemaShouldMatch( index.metadata(), s"""{ @@ -195,7 +197,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { when(indexCol.outputSchema()).thenReturn(Map("double_col" -> "double")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("double_col").expr))) - val index = new FlintSparkSkippingIndex("default.test", Seq(indexCol)) + val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) schemaShouldMatch( index.metadata(), s"""{ @@ -214,7 +216,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { when(indexCol.outputSchema()).thenReturn(Map("float_col" -> "float")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("float_col").expr))) - val index = new FlintSparkSkippingIndex("default.test", Seq(indexCol)) + val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) schemaShouldMatch( index.metadata(), s"""{ @@ -233,7 +235,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { when(indexCol.outputSchema()).thenReturn(Map("timestamp_col" -> "timestamp")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("timestamp_col").expr))) - val index = new FlintSparkSkippingIndex("default.test", Seq(indexCol)) + val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) schemaShouldMatch( index.metadata(), s"""{ @@ -253,7 +255,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { when(indexCol.outputSchema()).thenReturn(Map("date_col" -> "date")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("date_col").expr))) - val index = new FlintSparkSkippingIndex("default.test", Seq(indexCol)) + val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) schemaShouldMatch( index.metadata(), s"""{ @@ 
-274,7 +276,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { .thenReturn(Map("struct_col" -> "struct")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("struct_col").expr))) - val index = new FlintSparkSkippingIndex("default.test", Seq(indexCol)) + val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) schemaShouldMatch( index.metadata(), s"""{ @@ -303,7 +305,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { test("should fail if no indexed column given") { assertThrows[IllegalArgumentException] { - new FlintSparkSkippingIndex("default.test", Seq.empty) + new FlintSparkSkippingIndex(testTable, Seq.empty) } } diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala index 140ecdd77..717438f7e 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala @@ -16,7 +16,7 @@ import org.apache.spark.sql.Row class FlintSparkCoveringIndexITSuite extends FlintSparkSuite { /** Test table and index name */ - private val testTable = "default.ci_test" + private val testTable = "spark_catalog.default.ci_test" private val testIndex = "name_and_age" private val testFlintIndex = getFlintIndexName(testIndex, testTable) @@ -56,7 +56,7 @@ class FlintSparkCoveringIndexITSuite extends FlintSparkSuite { | "columnName": "age", | "columnType": "int" | }], - | "source": "default.ci_test", + | "source": "spark_catalog.default.ci_test", | "options": {} | }, | "properties": { diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala index 66d19a261..9fb580b60 100644 --- 
a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala @@ -17,7 +17,7 @@ import org.apache.spark.sql.Row class FlintSparkCoveringIndexSqlITSuite extends FlintSparkSuite { /** Test table and index name */ - private val testTable = "default.covering_sql_test" + private val testTable = "spark_catalog.default.covering_sql_test" private val testIndex = "name_and_age" private val testFlintIndex = getFlintIndexName(testIndex, testTable) diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala index 7d5598cd5..05f398b88 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.functions.col class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { /** Test table and index name */ - private val testTable = "default.test" + private val testTable = "spark_catalog.default.test" private val testIndex = getSkippingIndexName(testTable) override def beforeAll(): Unit = { @@ -49,12 +49,11 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { .addMinMax("age") .create() - val indexName = s"flint_default_test_skipping_index" - val index = flint.describeIndex(indexName) + val index = flint.describeIndex(testIndex) index shouldBe defined index.get.metadata().getContent should matchJson(s"""{ | "_meta": { - | "name": "flint_default_test_skipping_index", + | "name": "flint_spark_catalog_default_test_skipping_index", | "version": "${current()}", | "kind": "skipping", | "indexedColumns": [ @@ -78,7 +77,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { | "columnName": "age", | "columnType": "int" | 
}], - | "source": "default.test", + | "source": "spark_catalog.default.test", | "options": {} | }, | "properties": { @@ -371,7 +370,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { test("create skipping index for all supported data types successfully") { // Prepare test table - val testTable = "default.data_type_table" + val testTable = "spark_catalog.default.data_type_table" val testIndex = getSkippingIndexName(testTable) sql( s""" @@ -437,7 +436,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { index.get.metadata().getContent should matchJson( s"""{ | "_meta": { - | "name": "flint_default_data_type_table_skipping_index", + | "name": "flint_spark_catalog_default_data_type_table_skipping_index", | "version": "${current()}", | "kind": "skipping", | "indexedColumns": [ @@ -569,7 +568,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { } test("can build skipping index for varchar and char and rewrite applicable query") { - val testTable = "default.varchar_char_table" + val testTable = "spark_catalog.default.varchar_char_table" val testIndex = getSkippingIndexName(testTable) sql( s""" diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala index 5846cab23..bf599d137 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala @@ -17,7 +17,7 @@ import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE class FlintSparkSkippingIndexSqlITSuite extends FlintSparkSuite { /** Test table and index name */ - private val testTable = "default.skipping_sql_test" + private val testTable = "spark_catalog.default.skipping_sql_test" private val testIndex = getSkippingIndexName(testTable) override def beforeAll(): Unit = { From 
3d7de22fdc795112937e4425df114478d3807853 Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Wed, 27 Sep 2023 13:29:17 -0700 Subject: [PATCH 2/6] Qualify table name in Flint SQL layer Signed-off-by: Chen Dai --- .../main/antlr4/FlintSparkSqlExtensions.g4 | 28 ++++++++++++------- .../flint/spark/FlintSparkIndexBuilder.scala | 5 +++- .../ApplyFlintSparkSkippingIndex.scala | 2 +- .../spark/sql/FlintSparkSqlAstBuilder.scala | 13 ++++----- .../flint/spark/util/QualifiedTableName.scala | 18 ++++++++++-- .../FlintSparkSkippingIndexSqlITSuite.scala | 2 +- 6 files changed, 44 insertions(+), 24 deletions(-) diff --git a/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4 b/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4 index 12f69680e..2d50fbc49 100644 --- a/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4 +++ b/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4 @@ -27,21 +27,21 @@ skippingIndexStatement ; createSkippingIndexStatement - : CREATE SKIPPING INDEX ON tableName=multipartIdentifier + : CREATE SKIPPING INDEX ON tableName LEFT_PAREN indexColTypeList RIGHT_PAREN (WITH LEFT_PAREN propertyList RIGHT_PAREN)? ; refreshSkippingIndexStatement - : REFRESH SKIPPING INDEX ON tableName=multipartIdentifier + : REFRESH SKIPPING INDEX ON tableName ; describeSkippingIndexStatement - : (DESC | DESCRIBE) SKIPPING INDEX ON tableName=multipartIdentifier + : (DESC | DESCRIBE) SKIPPING INDEX ON tableName ; dropSkippingIndexStatement - : DROP SKIPPING INDEX ON tableName=multipartIdentifier + : DROP SKIPPING INDEX ON tableName ; coveringIndexStatement @@ -53,25 +53,25 @@ coveringIndexStatement ; createCoveringIndexStatement - : CREATE INDEX indexName=identifier ON tableName=multipartIdentifier + : CREATE INDEX indexName ON tableName LEFT_PAREN indexColumns=multipartIdentifierPropertyList RIGHT_PAREN (WITH LEFT_PAREN propertyList RIGHT_PAREN)? 
; refreshCoveringIndexStatement - : REFRESH INDEX indexName=identifier ON tableName=multipartIdentifier + : REFRESH INDEX indexName ON tableName ; showCoveringIndexStatement - : SHOW (INDEX | INDEXES) ON tableName=multipartIdentifier + : SHOW (INDEX | INDEXES) ON tableName ; describeCoveringIndexStatement - : (DESC | DESCRIBE) INDEX indexName=identifier ON tableName=multipartIdentifier + : (DESC | DESCRIBE) INDEX indexName ON tableName ; dropCoveringIndexStatement - : DROP INDEX indexName=identifier ON tableName=multipartIdentifier + : DROP INDEX indexName ON tableName ; indexColTypeList @@ -80,4 +80,12 @@ indexColTypeList indexColType : identifier skipType=(PARTITION | VALUE_SET | MIN_MAX) - ; \ No newline at end of file + ; + +indexName + : identifier + ; + +tableName + : multipartIdentifier + ; diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala index 493a80751..6b86c9cef 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala @@ -28,7 +28,7 @@ abstract class FlintSparkIndexBuilder(flint: FlintSpark) { lazy protected val allColumns: Map[String, Column] = { require(tableName.nonEmpty, "Source table name is not provided") - val qualified = new QualifiedTableName(tableName)(flint.spark) + val qualified = new QualifiedTableName(flint.spark, tableName) flint.spark.catalog .listColumns(qualified.nameWithoutCatalog) .collect() @@ -59,6 +59,9 @@ abstract class FlintSparkIndexBuilder(flint: FlintSpark) { */ protected def buildIndex(): FlintSparkIndex + /** + * Find column with the given name. 
+ */ protected def findColumn(colName: String): Column = allColumns.getOrElse( colName, diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala index 85068d413..3eac8ca20 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala @@ -62,7 +62,7 @@ class ApplyFlintSparkSkippingIndex(flint: FlintSpark) extends Rule[LogicalPlan] private def getIndexName(table: CatalogTable): String = { // Spark qualified name only contains database.table without catalog val tableName = table.qualifiedName - val qualifiedTableName = new QualifiedTableName(tableName)(flint.spark).name + val qualifiedTableName = new QualifiedTableName(flint.spark, tableName).name getSkippingIndexName(qualifiedTableName) } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlAstBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlAstBuilder.scala index 9f3539b09..96bb20003 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlAstBuilder.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlAstBuilder.scala @@ -9,6 +9,7 @@ import org.antlr.v4.runtime.tree.{ParseTree, RuleNode} import org.opensearch.flint.spark.FlintSpark import org.opensearch.flint.spark.sql.covering.FlintSparkCoveringIndexAstBuilder import org.opensearch.flint.spark.sql.skipping.FlintSparkSkippingIndexAstBuilder +import org.opensearch.flint.spark.util.QualifiedTableName import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -33,7 +34,9 @@ class FlintSparkSqlAstBuilder object FlintSparkSqlAstBuilder { /** - 
* Get full table name if database not specified. + * Get full table name if catalog or database not specified. The reason we cannot do this in + * common SparkSqlAstBuilder.visitTableName is that SparkSession is required to qualify table + * name which is only available at execution time instead of parsing time. * * @param flint * Flint Spark which has access to Spark Catalog @@ -42,12 +45,6 @@ object FlintSparkSqlAstBuilder { * @return */ def getFullTableName(flint: FlintSpark, tableNameCtx: RuleNode): String = { - val tableName = tableNameCtx.getText - if (tableName.contains(".")) { - tableName - } else { - val db = flint.spark.catalog.currentDatabase - s"$db.$tableName" - } + QualifiedTableName(flint.spark, tableNameCtx.getText).name } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/util/QualifiedTableName.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/util/QualifiedTableName.scala index ece282bee..d1da39692 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/util/QualifiedTableName.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/util/QualifiedTableName.scala @@ -18,10 +18,10 @@ import org.apache.spark.sql.SparkSession * @param spark * Spark session to get current catalog and database info */ -class QualifiedTableName(tableName: String)(spark: SparkSession) { +case class QualifiedTableName(spark: SparkSession, tableName: String) { /** Qualified table name */ - lazy private val qualifiedTableName: String = { + private val qualifiedTableName: String = { val parts = tableName.split("\\.") if (parts.length == 1) { s"$currentCatalog.$currentDatabase.$tableName" @@ -32,10 +32,22 @@ class QualifiedTableName(tableName: String)(spark: SparkSession) { } } + /** + * @return + * Qualified table name + */ def name: String = qualifiedTableName + /** + * @return + * catalog name only + */ def catalog: String = catalogName(qualifiedTableName) + /** + * @return + 
 * database and table name only + */ def nameWithoutCatalog: String = tableNameWithoutCatalog(qualifiedTableName) private def currentCatalog: String = { @@ -54,7 +66,7 @@ } /** - * Utility methods for table name already qualified and thus has no dependency on Spark session. + * Utility methods for table name already qualified and thus don't need Spark session. */ object QualifiedTableName { diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala index bf599d137..226b6e371 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala @@ -132,7 +132,7 @@ class FlintSparkSkippingIndexSqlITSuite extends FlintSparkSuite { sql("CREATE TABLE test (name STRING) USING CSV") sql("CREATE SKIPPING INDEX ON test (name VALUE_SET)") - flint.describeIndex("flint_sample_test_skipping_index") shouldBe defined + flint.describeIndex("flint_spark_catalog_sample_test_skipping_index") shouldBe defined } test("should return empty if no skipping index to describe") { From 84bf7932151dc3c9e3d87b6c4d4862da68ed7795 Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Wed, 27 Sep 2023 14:08:29 -0700 Subject: [PATCH 3/6] Add more IT Signed-off-by: Chen Dai --- .../FlintSparkCoveringIndexAstBuilder.scala | 2 +- .../FlintSparkCoveringIndexSqlITSuite.scala | 40 +++++++++++++++++++ .../FlintSparkSkippingIndexSqlITSuite.scala | 40 ++++++++++++++----- 3 files changed, 70 insertions(+), 12 deletions(-) diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/covering/FlintSparkCoveringIndexAstBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/covering/FlintSparkCoveringIndexAstBuilder.scala index 
689dfaefc..e05d5da64 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/covering/FlintSparkCoveringIndexAstBuilder.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/covering/FlintSparkCoveringIndexAstBuilder.scala @@ -28,7 +28,7 @@ trait FlintSparkCoveringIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[A ctx: CreateCoveringIndexStatementContext): Command = { FlintSparkSqlCommand() { flint => val indexName = ctx.indexName.getText - val tableName = ctx.tableName.getText + val tableName = getFullTableName(flint, ctx.tableName) val indexBuilder = flint .coveringIndex() diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala index 9fb580b60..3be51f573 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala @@ -67,6 +67,46 @@ class FlintSparkCoveringIndexSqlITSuite extends FlintSparkSuite { indexData.count() shouldBe 2 } + test("create covering index on table without database name") { + sql(s"CREATE INDEX $testIndex ON covering_sql_test (name)") + + flint.describeIndex(testFlintIndex) shouldBe defined + } + + test("create covering index on table in other database") { + sql("CREATE SCHEMA sample") + sql("USE sample") + + // Create index without database name specified + sql("CREATE TABLE test1 (name STRING) USING CSV") + sql(s"CREATE INDEX $testIndex ON sample.test1 (name)") + + // Create index with database name specified + sql("CREATE TABLE test2 (name STRING) USING CSV") + sql(s"CREATE INDEX $testIndex ON sample.test2 (name)") + + try { + flint.describeIndex(s"flint_spark_catalog_sample_test1_${testIndex}_index") shouldBe defined + flint.describeIndex(s"flint_spark_catalog_sample_test2_${testIndex}_index") 
shouldBe defined + } finally { + sql("DROP DATABASE sample CASCADE") + } + } + + test("create covering index on table in other database than current") { + sql("CREATE SCHEMA sample") + sql("USE sample") + + // Specify database "default" in table name instead of current "sample" database + sql(s"CREATE INDEX $testIndex ON $testTable (name)") + + try { + flint.describeIndex(testFlintIndex) shouldBe defined + } finally { + sql("DROP DATABASE sample CASCADE") + } + } + test("show all covering index on the source table") { flint .coveringIndex() diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala index 226b6e371..fc70f340f 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala @@ -114,14 +114,7 @@ class FlintSparkSkippingIndexSqlITSuite extends FlintSparkSuite { } test("create skipping index on table without database name") { - sql(s""" - | CREATE SKIPPING INDEX ON skipping_sql_test - | ( - | year PARTITION, - | name VALUE_SET, - | age MIN_MAX - | ) - | """.stripMargin) + sql("CREATE SKIPPING INDEX ON skipping_sql_test ( year PARTITION )") flint.describeIndex(testIndex) shouldBe defined } @@ -129,10 +122,35 @@ class FlintSparkSkippingIndexSqlITSuite extends FlintSparkSuite { test("create skipping index on table in other database") { sql("CREATE SCHEMA sample") sql("USE sample") - sql("CREATE TABLE test (name STRING) USING CSV") - sql("CREATE SKIPPING INDEX ON test (name VALUE_SET)") - flint.describeIndex("flint_spark_catalog_sample_test_skipping_index") shouldBe defined + // Create index without database name specified + sql("CREATE TABLE test1 (name STRING) USING CSV") + sql("CREATE SKIPPING INDEX ON test1 (name VALUE_SET)") + + // Create index with database name specified + 
sql("CREATE TABLE test2 (name STRING) USING CSV") + sql("CREATE SKIPPING INDEX ON sample.test2 (name VALUE_SET)") + + try { + flint.describeIndex("flint_spark_catalog_sample_test1_skipping_index") shouldBe defined + flint.describeIndex("flint_spark_catalog_sample_test2_skipping_index") shouldBe defined + } finally { + sql("DROP DATABASE sample CASCADE") + } + } + + test("create skipping index on table in other database than current") { + sql("CREATE SCHEMA sample") + sql("USE sample") + + // Specify database "default" in table name instead of current "sample" database + sql(s"CREATE SKIPPING INDEX ON $testTable (name VALUE_SET)") + + try { + flint.describeIndex(testIndex) shouldBe defined + } finally { + sql("DROP DATABASE sample CASCADE") + } } test("should return empty if no skipping index to describe") { From ccf379809177c04385785dbbb73a19593cc0536e Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Thu, 28 Sep 2023 14:19:55 -0700 Subject: [PATCH 4/6] Reuse Spark utility method for parsing Signed-off-by: Chen Dai --- .../org/apache/spark/sql/flint/package.scala | 34 ++++++++ .../flint/spark/FlintSparkIndexBuilder.scala | 26 ++++-- .../ApplyFlintSparkSkippingIndex.scala | 4 +- .../spark/sql/FlintSparkSqlAstBuilder.scala | 4 +- .../flint/spark/util/QualifiedTableName.scala | 80 ------------------- 5 files changed, 58 insertions(+), 90 deletions(-) create mode 100644 flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala delete mode 100644 flint-spark-integration/src/main/scala/org/opensearch/flint/spark/util/QualifiedTableName.scala diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala new file mode 100644 index 000000000..afdbf38ba --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: 
Apache-2.0 + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.connector.catalog._ + +package object flint { + + def qualifyTableName(spark: SparkSession, tableName: String): String = { + val (catalog, ident) = parseTableName(spark, tableName) + s"${catalog.name}.${ident.namespace.mkString(".")}.${ident.name}" + } + + def parseTableName(spark: SparkSession, tableName: String): (CatalogPlugin, Identifier) = { + new LookupCatalog { + override protected val catalogManager: CatalogManager = spark.sessionState.catalogManager + + def parseTableName(): (CatalogPlugin, Identifier) = { + val parts = tableName.split("\\.").toSeq + parts match { + case CatalogAndIdentifier(catalog, ident) => (catalog, ident) + // case _ => None + } + } + }.parseTableName() + } + + def loadTable(catalog: CatalogPlugin, ident: Identifier): Option[Table] = { + CatalogV2Util.loadTable(catalog, ident) + } +} diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala index 6b86c9cef..116764312 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala @@ -6,9 +6,10 @@ package org.opensearch.flint.spark import org.opensearch.flint.spark.FlintSparkIndexOptions.empty -import org.opensearch.flint.spark.util.QualifiedTableName import org.apache.spark.sql.catalog.Column +import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.flint.{loadTable, parseTableName} /** * Flint Spark index builder base class. 
@@ -28,11 +29,24 @@ abstract class FlintSparkIndexBuilder(flint: FlintSpark) { lazy protected val allColumns: Map[String, Column] = { require(tableName.nonEmpty, "Source table name is not provided") - val qualified = new QualifiedTableName(flint.spark, tableName) - flint.spark.catalog - .listColumns(qualified.nameWithoutCatalog) - .collect() - .map(col => (col.name, col)) + val (catalog, ident) = parseTableName(flint.spark, tableName) + val table = loadTable(catalog, ident).getOrElse( + throw new IllegalStateException(s"Table $tableName is not found")) + + table + .schema() + .fields + .map { field => + field.name -> new Column( + name = field.name, + description = field.getComment().orNull, + dataType = + // CatalogImpl.listColumns: Varchar/Char is StringType with real type name in metadata + CharVarcharUtils.getRawType(field.metadata).getOrElse(field.dataType).catalogString, + nullable = field.nullable, + isPartition = false, // useless for now so just set to false + isBucket = false) + } .toMap } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala index 3eac8ca20..a318908a6 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala @@ -7,13 +7,13 @@ package org.opensearch.flint.spark.skipping import org.opensearch.flint.spark.FlintSpark import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, SKIPPING_INDEX_TYPE} -import org.opensearch.flint.spark.util.QualifiedTableName import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.{And, Expression, Or, Predicate} import org.apache.spark.sql.catalyst.plans.logical.{Filter, 
LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.flint.qualifyTableName /** * Flint Spark skipping index apply rule that rewrites applicable query's filtering condition and @@ -62,7 +62,7 @@ class ApplyFlintSparkSkippingIndex(flint: FlintSpark) extends Rule[LogicalPlan] private def getIndexName(table: CatalogTable): String = { // Spark qualified name only contains database.table without catalog val tableName = table.qualifiedName - val qualifiedTableName = new QualifiedTableName(flint.spark, tableName).name + val qualifiedTableName = qualifyTableName(flint.spark, tableName) getSkippingIndexName(qualifiedTableName) } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlAstBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlAstBuilder.scala index 96bb20003..98abb3878 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlAstBuilder.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlAstBuilder.scala @@ -9,9 +9,9 @@ import org.antlr.v4.runtime.tree.{ParseTree, RuleNode} import org.opensearch.flint.spark.FlintSpark import org.opensearch.flint.spark.sql.covering.FlintSparkCoveringIndexAstBuilder import org.opensearch.flint.spark.sql.skipping.FlintSparkSkippingIndexAstBuilder -import org.opensearch.flint.spark.util.QualifiedTableName import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.flint.qualifyTableName /** * Flint Spark AST builder that builds Spark command for Flint index statement. 
This class mix-in @@ -45,6 +45,6 @@ object FlintSparkSqlAstBuilder { * @return */ def getFullTableName(flint: FlintSpark, tableNameCtx: RuleNode): String = { - QualifiedTableName(flint.spark, tableNameCtx.getText).name + qualifyTableName(flint.spark, tableNameCtx.getText) } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/util/QualifiedTableName.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/util/QualifiedTableName.scala deleted file mode 100644 index d1da39692..000000000 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/util/QualifiedTableName.scala +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.flint.spark.util - -import org.opensearch.flint.spark.util.QualifiedTableName.{catalogName, tableNameWithoutCatalog} - -import org.apache.spark.sql.SparkSession - -/** - * Qualified table name class that encapsulates table name parsing and qualifying utility. This is - * useful because Spark doesn't associate catalog info in logical plan even after analyzed. 
- * - * @param tableName - * table name maybe qualified or not - * @param spark - * Spark session to get current catalog and database info - */ -case class QualifiedTableName(spark: SparkSession, tableName: String) { - - /** Qualified table name */ - private val qualifiedTableName: String = { - val parts = tableName.split("\\.") - if (parts.length == 1) { - s"$currentCatalog.$currentDatabase.$tableName" - } else if (parts.length == 2) { - s"$currentCatalog.$tableName" - } else { - tableName - } - } - - /** - * @return - * Qualified table name - */ - def name: String = qualifiedTableName - - /** - * @return - * catalog name only - */ - def catalog: String = catalogName(qualifiedTableName) - - /** - * @return - * database and table name only - */ - def nameWithoutCatalog: String = tableNameWithoutCatalog(qualifiedTableName) - - private def currentCatalog: String = { - require(spark != null, "Spark session required to unqualify the given table name") - - val catalogMgr = spark.sessionState.catalogManager - catalogMgr.currentCatalog.name() - } - - private def currentDatabase: String = { - require(spark != null, "Spark session required to unqualify the given table name") - - val catalogMgr = spark.sessionState.catalogManager - catalogMgr.currentNamespace.mkString(".") - } -} - -/** - * Utility methods for table name already qualified and thus dont' need Spark session. 
- */ -object QualifiedTableName { - - def catalogName(qualifiedTableName: String): String = { - qualifiedTableName.substring(0, qualifiedTableName.indexOf(".")) - } - - def tableNameWithoutCatalog(qualifiedTableName: String): String = { - qualifiedTableName.substring(qualifiedTableName.indexOf(".") + 1) - } -} From 25805cc39cfb900bf4ff6f2d9e20400385b8c4bf Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Thu, 28 Sep 2023 14:46:41 -0700 Subject: [PATCH 5/6] Update javadoc Signed-off-by: Chen Dai --- .../org/apache/spark/sql/flint/package.scala | 35 ++++++++++++++++++- .../flint/spark/FlintSparkIndexBuilder.scala | 2 +- .../ApplyFlintSparkSkippingIndex.scala | 3 +- 3 files changed, 37 insertions(+), 3 deletions(-) diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala index afdbf38ba..36478c957 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala @@ -7,14 +7,38 @@ package org.apache.spark.sql import org.apache.spark.sql.connector.catalog._ +/** + * Flint utility methods that rely on access to private code in Spark SQL package. + */ package object flint { + /** + * Qualify a given table name. + * + * @param spark + * Spark session + * @param tableName + * table name maybe qualified or not + * @return + * qualified table name in catalog.database.table format + */ def qualifyTableName(spark: SparkSession, tableName: String): String = { val (catalog, ident) = parseTableName(spark, tableName) s"${catalog.name}.${ident.namespace.mkString(".")}.${ident.name}" } + /** + * Parse a given table name into its catalog and table identifier. 
+ * + * @param spark + * Spark session + * @param tableName + * table name maybe qualified or not + * @return + * Spark catalog and table identifier + */ def parseTableName(spark: SparkSession, tableName: String): (CatalogPlugin, Identifier) = { + // Create an anonymous class to access CatalogAndIdentifier new LookupCatalog { override protected val catalogManager: CatalogManager = spark.sessionState.catalogManager @@ -22,12 +46,21 @@ package object flint { val parts = tableName.split("\\.").toSeq parts match { case CatalogAndIdentifier(catalog, ident) => (catalog, ident) - // case _ => None } } }.parseTableName() } + /** + * Load table for the given table identifier in the catalog. + * + * @param catalog + * Spark catalog + * @param ident + * table identifier + * @return + * Spark table + */ def loadTable(catalog: CatalogPlugin, ident: Identifier): Option[Table] = { CatalogV2Util.loadTable(catalog, ident) } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala index 116764312..af6850e10 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala @@ -33,6 +33,7 @@ abstract class FlintSparkIndexBuilder(flint: FlintSpark) { val table = loadTable(catalog, ident).getOrElse( throw new IllegalStateException(s"Table $tableName is not found")) + // Ref to CatalogImpl.listColumns(): Varchar/Char is StringType with real type name in metadata table .schema() .fields @@ -41,7 +42,6 @@ abstract class FlintSparkIndexBuilder(flint: FlintSpark) { name = field.name, description = field.getComment().orNull, dataType = - // CatalogImpl.listColumns: Varchar/Char is StringType with real type name in metadata CharVarcharUtils.getRawType(field.metadata).getOrElse(field.dataType).catalogString, 
nullable = field.nullable, isPartition = false, // useless for now so just set to false diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala index a318908a6..11f8ad304 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala @@ -60,7 +60,8 @@ class ApplyFlintSparkSkippingIndex(flint: FlintSpark) extends Rule[LogicalPlan] } private def getIndexName(table: CatalogTable): String = { - // Spark qualified name only contains database.table without catalog + // Because Spark qualified name only contains database.table without catalog, + // the limitation here is that qualifyTableName always uses the current catalog. val tableName = table.qualifiedName val qualifiedTableName = qualifyTableName(flint.spark, tableName) getSkippingIndexName(qualifiedTableName) From 0e9f682f5899b09648d1b49199e08d1873942212 Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Fri, 29 Sep 2023 15:24:22 -0700 Subject: [PATCH 6/6] Fix catalog plugin name issue Signed-off-by: Chen Dai --- .../org/apache/spark/sql/flint/package.scala | 12 +- .../flint/spark/FlintSparkIndexBuilder.scala | 56 +++++---- .../spark/FlintSparkIndexBuilderSuite.scala | 108 ++++++++++++++++++ 3 files changed, 154 insertions(+), 22 deletions(-) create mode 100644 flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexBuilderSuite.scala diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala index 36478c957..b848f47b4 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala +++ 
b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala @@ -24,7 +24,17 @@ package object flint { */ def qualifyTableName(spark: SparkSession, tableName: String): String = { val (catalog, ident) = parseTableName(spark, tableName) - s"${catalog.name}.${ident.namespace.mkString(".")}.${ident.name}" + + // Tricky that our Flint delegate catalog's name has to be spark_catalog + // so we have to find its actual name in CatalogManager + val catalogMgr = spark.sessionState.catalogManager + val catalogName = + catalogMgr + .listCatalogs(Some("*")) + .find(catalogMgr.catalog(_) == catalog) + .getOrElse(catalog.name()) + + s"$catalogName.${ident.namespace.mkString(".")}.${ident.name}" } /** diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala index cdcedb36b..d429244d1 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala @@ -9,7 +9,8 @@ import org.opensearch.flint.spark.FlintSparkIndexOptions.empty import org.apache.spark.sql.catalog.Column import org.apache.spark.sql.catalyst.util.CharVarcharUtils -import org.apache.spark.sql.flint.{loadTable, parseTableName} +import org.apache.spark.sql.flint.{loadTable, parseTableName, qualifyTableName} +import org.apache.spark.sql.types.StructField /** * Flint Spark index builder base class. 
@@ -19,35 +20,22 @@ import org.apache.spark.sql.flint.{loadTable, parseTableName} */ abstract class FlintSparkIndexBuilder(flint: FlintSpark) { - /** Source table name */ - protected var tableName: String = "" + /** Qualified source table name */ + protected var qualifiedTableName: String = "" /** Index options */ protected var indexOptions: FlintSparkIndexOptions = empty /** All columns of the given source table */ lazy protected val allColumns: Map[String, Column] = { - require(tableName.nonEmpty, "Source table name is not provided") + require(qualifiedTableName.nonEmpty, "Source table name is not provided") - val (catalog, ident) = parseTableName(flint.spark, tableName) + val (catalog, ident) = parseTableName(flint.spark, qualifiedTableName) val table = loadTable(catalog, ident).getOrElse( - throw new IllegalStateException(s"Table $tableName is not found")) + throw new IllegalStateException(s"Table $qualifiedTableName is not found")) - // Ref to CatalogImpl.listColumns(): Varchar/Char is StringType with real type name in metadata - table - .schema() - .fields - .map { field => - field.name -> new Column( - name = field.name, - description = field.getComment().orNull, - dataType = - CharVarcharUtils.getRawType(field.metadata).getOrElse(field.dataType).catalogString, - nullable = field.nullable, - isPartition = false, // useless for now so just set to false - isBucket = false) - } - .toMap + val allFields = table.schema().fields + allFields.map { field => field.name -> convertFieldToColumn(field) }.toMap } /** @@ -77,6 +65,20 @@ abstract class FlintSparkIndexBuilder(flint: FlintSpark) { */ protected def buildIndex(): FlintSparkIndex + /** + * Table name setter that qualifies given table name for subclass automatically. 
+ */ + protected def tableName_=(tableName: String): Unit = { + qualifiedTableName = qualifyTableName(flint.spark, tableName) + } + + /** + * Table name getter + */ + protected def tableName: String = { + qualifiedTableName + } + /** * Find column with the given name. */ @@ -84,4 +86,16 @@ abstract class FlintSparkIndexBuilder(flint: FlintSpark) { allColumns.getOrElse( colName, throw new IllegalArgumentException(s"Column $colName does not exist")) + + private def convertFieldToColumn(field: StructField): Column = { + // Ref to CatalogImpl.listColumns(): Varchar/Char is StringType with real type name in metadata + new Column( + name = field.name, + description = field.getComment().orNull, + dataType = + CharVarcharUtils.getRawType(field.metadata).getOrElse(field.dataType).catalogString, + nullable = field.nullable, + isPartition = false, // useless for now so just set to false + isBucket = false) + } } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexBuilderSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexBuilderSuite.scala new file mode 100644 index 000000000..0cd4a5293 --- /dev/null +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexBuilderSuite.scala @@ -0,0 +1,108 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark + +import org.scalatest.matchers.must.Matchers.contain +import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper + +import org.apache.spark.FlintSuite + +class FlintSparkIndexBuilderSuite extends FlintSuite { + + override def beforeAll(): Unit = { + super.beforeAll() + + sql(""" + | CREATE TABLE spark_catalog.default.test + | ( name STRING, age INT ) + | USING JSON + """.stripMargin) + } + + protected override def afterAll(): Unit = { + sql("DROP TABLE spark_catalog.default.test") + + super.afterAll() + } + + test("should qualify table 
name in default database") { + builder() + .onTable("test") + .expectTableName("spark_catalog.default.test") + .expectAllColumns("name", "age") + + builder() + .onTable("default.test") + .expectTableName("spark_catalog.default.test") + .expectAllColumns("name", "age") + + builder() + .onTable("spark_catalog.default.test") + .expectTableName("spark_catalog.default.test") + .expectAllColumns("name", "age") + } + + test("should qualify table name and get columns in other database") { + sql("CREATE DATABASE mydb") + sql("CREATE TABLE mydb.test2 (address STRING) USING JSON") + sql("USE mydb") + + try { + builder() + .onTable("test2") + .expectTableName("spark_catalog.mydb.test2") + .expectAllColumns("address") + + builder() + .onTable("mydb.test2") + .expectTableName("spark_catalog.mydb.test2") + .expectAllColumns("address") + + builder() + .onTable("spark_catalog.mydb.test2") + .expectTableName("spark_catalog.mydb.test2") + .expectAllColumns("address") + + // Can parse any specified table name + builder() + .onTable("spark_catalog.default.test") + .expectTableName("spark_catalog.default.test") + .expectAllColumns("name", "age") + } finally { + sql("DROP DATABASE mydb CASCADE") + sql("USE default") + } + } + + private def builder(): FakeFlintSparkIndexBuilder = { + new FakeFlintSparkIndexBuilder + } + + /** + * Fake builder that have access to internal method for assertion + */ + class FakeFlintSparkIndexBuilder extends FlintSparkIndexBuilder(new FlintSpark(spark)) { + + def onTable(tableName: String): FakeFlintSparkIndexBuilder = { + this.tableName = tableName + this + } + + def expectTableName(expected: String): FakeFlintSparkIndexBuilder = { + tableName shouldBe expected + this + } + + def expectAllColumns(expected: String*): FakeFlintSparkIndexBuilder = { + allColumns.keys should contain theSameElementsAs expected + this + } + + override protected def buildIndex(): FlintSparkIndex = { + null + } + } +}