Fix catalog name missing in Flint index name (#48)
* Change to qualified name in Flint Spark index API layer

Signed-off-by: Chen Dai <[email protected]>

* Qualify table name in Flint SQL layer

Signed-off-by: Chen Dai <[email protected]>

* Add more IT

Signed-off-by: Chen Dai <[email protected]>

* Reuse Spark utility method for parsing

Signed-off-by: Chen Dai <[email protected]>

* Update javadoc

Signed-off-by: Chen Dai <[email protected]>

* Fix catalog plugin name issue

Signed-off-by: Chen Dai <[email protected]>

---------

Signed-off-by: Chen Dai <[email protected]>
dai-chen authored Oct 2, 2023
1 parent 73bd6e8 commit e3210f0
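
The net effect of this commit is that Flint index names derived from a source table now include the catalog. A minimal sketch of the intended behavior, assuming the underscore-based flint_<table>_skipping_index naming scheme and hypothetical catalog/table names:

import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName

// Before: the name was built from database.table only, e.g.
//   flint_default_http_logs_skipping_index
// After: the qualified catalog.database.table name is required, e.g.
val name = getSkippingIndexName("spark_catalog.default.http_logs")
// name == "flint_spark_catalog_default_http_logs_skipping_index"  (assumed naming scheme)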
Showing 16 changed files with 380 additions and 79 deletions.
30 changes: 19 additions & 11 deletions flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4
@@ -28,21 +28,21 @@ skippingIndexStatement

createSkippingIndexStatement
: CREATE SKIPPING INDEX (IF NOT EXISTS)?
ON tableName=multipartIdentifier
ON tableName
LEFT_PAREN indexColTypeList RIGHT_PAREN
(WITH LEFT_PAREN propertyList RIGHT_PAREN)?
;

refreshSkippingIndexStatement
: REFRESH SKIPPING INDEX ON tableName=multipartIdentifier
: REFRESH SKIPPING INDEX ON tableName
;

describeSkippingIndexStatement
: (DESC | DESCRIBE) SKIPPING INDEX ON tableName=multipartIdentifier
: (DESC | DESCRIBE) SKIPPING INDEX ON tableName
;

dropSkippingIndexStatement
: DROP SKIPPING INDEX ON tableName=multipartIdentifier
: DROP SKIPPING INDEX ON tableName
;

coveringIndexStatement
@@ -54,26 +54,26 @@ coveringIndexStatement
;

createCoveringIndexStatement
: CREATE INDEX (IF NOT EXISTS)? indexName=identifier
ON tableName=multipartIdentifier
: CREATE INDEX (IF NOT EXISTS)? indexName
ON tableName
LEFT_PAREN indexColumns=multipartIdentifierPropertyList RIGHT_PAREN
(WITH LEFT_PAREN propertyList RIGHT_PAREN)?
;

refreshCoveringIndexStatement
: REFRESH INDEX indexName=identifier ON tableName=multipartIdentifier
: REFRESH INDEX indexName ON tableName
;

showCoveringIndexStatement
: SHOW (INDEX | INDEXES) ON tableName=multipartIdentifier
: SHOW (INDEX | INDEXES) ON tableName
;

describeCoveringIndexStatement
: (DESC | DESCRIBE) INDEX indexName=identifier ON tableName=multipartIdentifier
: (DESC | DESCRIBE) INDEX indexName ON tableName
;

dropCoveringIndexStatement
: DROP INDEX indexName=identifier ON tableName=multipartIdentifier
: DROP INDEX indexName ON tableName
;

indexColTypeList
@@ -82,4 +82,12 @@ indexColTypeList

indexColType
: identifier skipType=(PARTITION | VALUE_SET | MIN_MAX)
;
;

indexName
: identifier
;

tableName
: multipartIdentifier
;
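
The tableName rule stays a multipartIdentifier, so a statement may name its table with or without the catalog and database; qualification now happens when the command executes rather than at parse time. A hedged usage sketch, assuming a SparkSession named spark with the Flint SQL extension enabled and hypothetical table, catalog, and column names:

// Unqualified name: resolved against the current catalog and database at execution time.
spark.sql("CREATE SKIPPING INDEX ON http_logs (status VALUE_SET)")
// Fully qualified name: the catalog.database.table form is used as given.
spark.sql("CREATE SKIPPING INDEX ON glue.default.http_logs (status VALUE_SET)")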
@@ -0,0 +1,77 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql

import org.apache.spark.sql.connector.catalog._

/**
* Flint utility methods that rely on access to private code in Spark SQL package.
*/
package object flint {

/**
* Qualify a given table name.
*
* @param spark
* Spark session
* @param tableName
* table name, which may or may not be qualified
* @return
* qualified table name in catalog.database.table format
*/
def qualifyTableName(spark: SparkSession, tableName: String): String = {
val (catalog, ident) = parseTableName(spark, tableName)

// It is tricky that our Flint delegate catalog's name has to be spark_catalog,
// so we have to find its actual name in the CatalogManager
val catalogMgr = spark.sessionState.catalogManager
val catalogName =
catalogMgr
.listCatalogs(Some("*"))
.find(catalogMgr.catalog(_) == catalog)
.getOrElse(catalog.name())

s"$catalogName.${ident.namespace.mkString(".")}.${ident.name}"
}

/**
* Parse a given table name into its catalog and table identifier.
*
* @param spark
* Spark session
* @param tableName
* table name, which may or may not be qualified
* @return
* Spark catalog and table identifier
*/
def parseTableName(spark: SparkSession, tableName: String): (CatalogPlugin, Identifier) = {
// Create an anonymous class to access CatalogAndIdentifier
new LookupCatalog {
override protected val catalogManager: CatalogManager = spark.sessionState.catalogManager

def parseTableName(): (CatalogPlugin, Identifier) = {
val parts = tableName.split("\\.").toSeq
parts match {
case CatalogAndIdentifier(catalog, ident) => (catalog, ident)
}
}
}.parseTableName()
}

/**
* Load table for the given table identifier in the catalog.
*
* @param catalog
* Spark catalog
* @param ident
* table identifier
* @return
* Spark table
*/
def loadTable(catalog: CatalogPlugin, ident: Identifier): Option[Table] = {
CatalogV2Util.loadTable(catalog, ident)
}
}
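
A hedged usage sketch of the new helpers, with illustrative names and the default session catalog assumed (spark is a SparkSession in scope):

import org.apache.spark.sql.flint.{parseTableName, qualifyTableName}

// An unqualified name picks up the current catalog and database, e.g.
val qualified = qualifyTableName(spark, "http_logs")
// qualified == "spark_catalog.default.http_logs"  (under the default session catalog)

// parseTableName splits a possibly qualified name into a catalog plugin and identifier.
val (catalog, ident) = parseTableName(spark, qualified)
// catalog.name() == "spark_catalog"; ident.namespace == Array("default"); ident.name == "http_logs"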
@@ -8,6 +8,9 @@ package org.opensearch.flint.spark
import org.opensearch.flint.spark.FlintSparkIndexOptions.empty

import org.apache.spark.sql.catalog.Column
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.flint.{loadTable, parseTableName, qualifyTableName}
import org.apache.spark.sql.types.StructField

/**
* Flint Spark index builder base class.
@@ -17,21 +20,22 @@ import org.apache.spark.sql.catalog.Column
*/
abstract class FlintSparkIndexBuilder(flint: FlintSpark) {

/** Source table name */
protected var tableName: String = ""
/** Qualified source table name */
protected var qualifiedTableName: String = ""

/** Index options */
protected var indexOptions: FlintSparkIndexOptions = empty

/** All columns of the given source table */
lazy protected val allColumns: Map[String, Column] = {
require(tableName.nonEmpty, "Source table name is not provided")
require(qualifiedTableName.nonEmpty, "Source table name is not provided")

flint.spark.catalog
.listColumns(tableName)
.collect()
.map(col => (col.name, col))
.toMap
val (catalog, ident) = parseTableName(flint.spark, qualifiedTableName)
val table = loadTable(catalog, ident).getOrElse(
throw new IllegalStateException(s"Table $qualifiedTableName is not found"))

val allFields = table.schema().fields
allFields.map { field => field.name -> convertFieldToColumn(field) }.toMap
}

/**
@@ -61,8 +65,37 @@
*/
protected def buildIndex(): FlintSparkIndex

/**
* Table name setter that automatically qualifies the given table name for subclasses.
*/
protected def tableName_=(tableName: String): Unit = {
qualifiedTableName = qualifyTableName(flint.spark, tableName)
}

/**
* Table name getter
*/
protected def tableName: String = {
qualifiedTableName
}

/**
* Find column with the given name.
*/
protected def findColumn(colName: String): Column =
allColumns.getOrElse(
colName,
throw new IllegalArgumentException(s"Column $colName does not exist"))

private def convertFieldToColumn(field: StructField): Column = {
// Ref to CatalogImpl.listColumns(): Varchar/Char is StringType with real type name in metadata
new Column(
name = field.name,
description = field.getComment().orNull,
dataType =
CharVarcharUtils.getRawType(field.metadata).getOrElse(field.dataType).catalogString,
nullable = field.nullable,
isPartition = false, // not used for now, so just set to false
isBucket = false)
}
}
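
For context, a sketch of how a concrete builder flows through this base class; the chained method names below are assumptions, and the point is only that the tableName setter qualifies whatever the caller passes, so allColumns and buildIndex always see catalog.database.table:

flint
  .skippingIndex()          // a FlintSparkIndexBuilder subclass (assumed API)
  .onTable("http_logs")     // setter qualifies to e.g. spark_catalog.default.http_logs
  .addValueSet("status")    // looked up via findColumn against the qualified table
  .create()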
@@ -109,7 +109,9 @@ object FlintSparkCoveringIndex
* Flint covering index name
*/
def getFlintIndexName(indexName: String, tableName: String): String = {
require(tableName.contains("."), "Full table name database.table is required")
require(
tableName.split("\\.").length >= 3,
"Qualified table name catalog.database.table is required")

flintIndexNamePrefix(tableName) + indexName + COVERING_INDEX_SUFFIX
}
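
Under the stricter precondition, the covering index name carries the catalog as well. Illustrative only, with the naming scheme and values assumed:

val name = getFlintIndexName("idx_status", "spark_catalog.default.http_logs")
// name == "flint_spark_catalog_default_http_logs_idx_status_index"  (assumed scheme)
// getFlintIndexName("idx_status", "default.http_logs") would now fail the require check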
@@ -8,10 +8,12 @@ package org.opensearch.flint.spark.skipping
import org.opensearch.flint.spark.FlintSpark
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, SKIPPING_INDEX_TYPE}

import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.expressions.{And, Expression, Or, Predicate}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.flint.qualifyTableName

/**
* Flint Spark skipping index apply rule that rewrites applicable query's filtering condition and
@@ -32,8 +34,7 @@ class ApplyFlintSparkSkippingIndex(flint: FlintSpark) extends Rule[LogicalPlan]
Some(table),
false))
if hasNoDisjunction(condition) && !location.isInstanceOf[FlintSparkSkippingFileIndex] =>
val indexName = getSkippingIndexName(table.identifier.unquotedString)
val index = flint.describeIndex(indexName)
val index = flint.describeIndex(getIndexName(table))
if (index.exists(_.kind == SKIPPING_INDEX_TYPE)) {
val skippingIndex = index.get.asInstanceOf[FlintSparkSkippingIndex]
val indexFilter = rewriteToIndexFilter(skippingIndex, condition)
@@ -58,9 +59,17 @@ class ApplyFlintSparkSkippingIndex(flint: FlintSpark) extends Rule[LogicalPlan]
}
}

private def getIndexName(table: CatalogTable): String = {
// Because the Spark qualified name only contains database.table without the catalog,
// the limitation here is that qualifyTableName always uses the current catalog.
val tableName = table.qualifiedName
val qualifiedTableName = qualifyTableName(flint.spark, tableName)
getSkippingIndexName(qualifiedTableName)
}

private def hasNoDisjunction(condition: Expression): Boolean = {
condition.collectFirst {
case Or(_, _) => true
condition.collectFirst { case Or(_, _) =>
true
}.isEmpty
}

@@ -127,7 +127,9 @@ object FlintSparkSkippingIndex
* Flint skipping index name
*/
def getSkippingIndexName(tableName: String): String = {
require(tableName.contains("."), "Full table name database.table is required")
require(
tableName.split("\\.").length >= 3,
"Qualified table name catalog.database.table is required")

flintIndexNamePrefix(tableName) + SKIPPING_INDEX_SUFFIX
}
@@ -11,6 +11,7 @@ import org.opensearch.flint.spark.sql.covering.FlintSparkCoveringIndexAstBuilder
import org.opensearch.flint.spark.sql.skipping.FlintSparkSkippingIndexAstBuilder

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.flint.qualifyTableName

/**
* Flint Spark AST builder that builds Spark command for Flint index statement. This class mix-in
@@ -33,7 +34,9 @@ class FlintSparkSqlAstBuilder
object FlintSparkSqlAstBuilder {

/**
* Get full table name if database not specified.
* Get the full table name if the catalog or database is not specified. The reason we cannot do
* this in the common SparkSqlAstBuilder.visitTableName is that the SparkSession required to
* qualify a table name is only available at execution time, not at parse time.
*
* @param flint
* Flint Spark which has access to Spark Catalog
@@ -42,12 +45,6 @@
* @return
*/
def getFullTableName(flint: FlintSpark, tableNameCtx: RuleNode): String = {
val tableName = tableNameCtx.getText
if (tableName.contains(".")) {
tableName
} else {
val db = flint.spark.catalog.currentDatabase
s"$db.$tableName"
}
qualifyTableName(flint.spark, tableNameCtx.getText)
}
}
@@ -28,7 +28,7 @@ trait FlintSparkCoveringIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[A
ctx: CreateCoveringIndexStatementContext): Command = {
FlintSparkSqlCommand() { flint =>
val indexName = ctx.indexName.getText
val tableName = ctx.tableName.getText
val tableName = getFullTableName(flint, ctx.tableName)
val indexBuilder =
flint
.coveringIndex()
