
Commit 0e9f682: Fix catalog plugin name issue

Signed-off-by: Chen Dai <[email protected]>
dai-chen committed Sep 29, 2023
1 parent a162cb5 commit 0e9f682
Showing 3 changed files with 154 additions and 22 deletions.

@@ -24,7 +24,17 @@ package object flint {
    */
   def qualifyTableName(spark: SparkSession, tableName: String): String = {
     val (catalog, ident) = parseTableName(spark, tableName)
-    s"${catalog.name}.${ident.namespace.mkString(".")}.${ident.name}"
+
+    // Tricky: the Flint delegate catalog has to be named spark_catalog,
+    // so we look up its actual (user-facing) name in the CatalogManager.
+    val catalogMgr = spark.sessionState.catalogManager
+    val catalogName =
+      catalogMgr
+        .listCatalogs(Some("*"))
+        .find(catalogMgr.catalog(_) == catalog)
+        .getOrElse(catalog.name())
+
+    s"$catalogName.${ident.namespace.mkString(".")}.${ident.name}"
   }
 
   /**
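
For context, a sketch of the behavior this hunk addresses: a V2 catalog that delegates to the Spark session catalog must report name() as spark_catalog, so qualifying a table with catalog.name() loses the user-facing name the catalog was registered under. The snippet below is a hypothetical illustration, not part of this commit; the catalog name "mycatalog" and the delegating implementation class are made up.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.flint.qualifyTableName

object QualifyNameExample extends App {
  // "mycatalog" and the delegating catalog class are assumptions for
  // illustration; any V2 catalog that delegates to the session catalog
  // and reports name() == "spark_catalog" reproduces the issue.
  val spark = SparkSession
    .builder()
    .master("local")
    .config("spark.sql.catalog.mycatalog", "com.example.DelegatingSessionCatalog")
    .getOrCreate()

  spark.sql("USE mycatalog")

  // Before this fix: "spark_catalog.default.test" (taken from catalog.name()).
  // After: "mycatalog.default.test" (found via CatalogManager.listCatalogs).
  println(qualifyTableName(spark, "default.test"))
}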

@@ -9,7 +9,8 @@ import org.opensearch.flint.spark.FlintSparkIndexOptions.empty
 
 import org.apache.spark.sql.catalog.Column
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
-import org.apache.spark.sql.flint.{loadTable, parseTableName}
+import org.apache.spark.sql.flint.{loadTable, parseTableName, qualifyTableName}
+import org.apache.spark.sql.types.StructField
 
 /**
  * Flint Spark index builder base class.
@@ -19,35 +20,22 @@ import org.apache.spark.sql.flint.{loadTable, parseTableName}
  */
 abstract class FlintSparkIndexBuilder(flint: FlintSpark) {
 
-  /** Source table name */
-  protected var tableName: String = ""
+  /** Qualified source table name */
+  protected var qualifiedTableName: String = ""
 
   /** Index options */
   protected var indexOptions: FlintSparkIndexOptions = empty
 
   /** All columns of the given source table */
   lazy protected val allColumns: Map[String, Column] = {
-    require(tableName.nonEmpty, "Source table name is not provided")
+    require(qualifiedTableName.nonEmpty, "Source table name is not provided")
 
-    val (catalog, ident) = parseTableName(flint.spark, tableName)
+    val (catalog, ident) = parseTableName(flint.spark, qualifiedTableName)
     val table = loadTable(catalog, ident).getOrElse(
-      throw new IllegalStateException(s"Table $tableName is not found"))
+      throw new IllegalStateException(s"Table $qualifiedTableName is not found"))
 
-    // Ref to CatalogImpl.listColumns(): Varchar/Char is StringType with real type name in metadata
-    table
-      .schema()
-      .fields
-      .map { field =>
-        field.name -> new Column(
-          name = field.name,
-          description = field.getComment().orNull,
-          dataType =
-            CharVarcharUtils.getRawType(field.metadata).getOrElse(field.dataType).catalogString,
-          nullable = field.nullable,
-          isPartition = false, // useless for now so just set to false
-          isBucket = false)
-      }
-      .toMap
+    val allFields = table.schema().fields
+    allFields.map { field => field.name -> convertFieldToColumn(field) }.toMap
   }
 
   /**
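
A side note on the CharVarcharUtils call that moves into convertFieldToColumn below: Spark's catalog stores CHAR/VARCHAR columns as StringType and records the declared type in the field metadata, so the raw type has to be recovered before taking catalogString. A minimal sketch of that mechanism, assuming Spark 3.x and its internal __CHAR_VARCHAR_TYPE_STRING metadata key (an implementation detail, not part of this commit):

import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.types.{MetadataBuilder, StringType, StructField}

// A VARCHAR(10) column as the catalog stores it: StringType plus metadata
// under Spark's internal "__CHAR_VARCHAR_TYPE_STRING" key (assumed here).
val field = StructField(
  "name",
  StringType,
  nullable = true,
  new MetadataBuilder()
    .putString("__CHAR_VARCHAR_TYPE_STRING", "varchar(10)")
    .build())

// getRawType recovers the declared type; falling back to field.dataType
// would report plain "string" instead of "varchar(10)".
val raw = CharVarcharUtils.getRawType(field.metadata).getOrElse(field.dataType)
println(raw.catalogString) // varchar(10)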
@@ -77,11 +65,37 @@
   */
  protected def buildIndex(): FlintSparkIndex
 
+  /**
+   * Table name setter that qualifies given table name for subclass automatically.
+   */
+  protected def tableName_=(tableName: String): Unit = {
+    qualifiedTableName = qualifyTableName(flint.spark, tableName)
+  }
+
+  /**
+   * Table name getter
+   */
+  protected def tableName: String = {
+    qualifiedTableName
+  }
+
   /**
    * Find column with the given name.
    */
   protected def findColumn(colName: String): Column =
     allColumns.getOrElse(
       colName,
       throw new IllegalArgumentException(s"Column $colName does not exist"))
+
+  private def convertFieldToColumn(field: StructField): Column = {
+    // Ref to CatalogImpl.listColumns(): Varchar/Char is StringType with real type name in metadata
+    new Column(
+      name = field.name,
+      description = field.getComment().orNull,
+      dataType =
+        CharVarcharUtils.getRawType(field.metadata).getOrElse(field.dataType).catalogString,
+      nullable = field.nullable,
+      isPartition = false, // useless for now so just set to false
+      isBucket = false)
+  }
 }
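
The tableName / tableName_= pair above uses Scala's assignment-operator convention: when a getter x and a setter x_= are defined together, obj.x = v compiles to a call of x_=, so subclasses keep writing this.tableName = ... while the stored value is qualified transparently. A minimal standalone sketch of the pattern (the qualification logic is a simplified stand-in, not the real qualifyTableName):

class TableHolder {
  private var qualifiedTableName: String = ""

  // Getter and setter defined as a pair: `holder.tableName = ...`
  // compiles to a call to tableName_=.
  def tableName: String = qualifiedTableName

  def tableName_=(name: String): Unit = {
    // Simplified stand-in for qualifyTableName(spark, name).
    qualifiedTableName =
      if (name.count(_ == '.') == 2) name
      else if (name.contains(".")) s"spark_catalog.$name"
      else s"spark_catalog.default.$name"
  }
}

val holder = new TableHolder
holder.tableName = "test" // invokes tableName_=
assert(holder.tableName == "spark_catalog.default.test")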

@@ -0,0 +1,108 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.spark
+
+import org.scalatest.matchers.must.Matchers.contain
+import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
+
+import org.apache.spark.FlintSuite
+
+class FlintSparkIndexBuilderSuite extends FlintSuite {
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+
+    sql("""
+        | CREATE TABLE spark_catalog.default.test
+        | ( name STRING, age INT )
+        | USING JSON
+      """.stripMargin)
+  }
+
+  protected override def afterAll(): Unit = {
+    sql("DROP TABLE spark_catalog.default.test")
+
+    super.afterAll()
+  }
+
+  test("should qualify table name in default database") {
+    builder()
+      .onTable("test")
+      .expectTableName("spark_catalog.default.test")
+      .expectAllColumns("name", "age")
+
+    builder()
+      .onTable("default.test")
+      .expectTableName("spark_catalog.default.test")
+      .expectAllColumns("name", "age")
+
+    builder()
+      .onTable("spark_catalog.default.test")
+      .expectTableName("spark_catalog.default.test")
+      .expectAllColumns("name", "age")
+  }
+
+  test("should qualify table name and get columns in other database") {
+    sql("CREATE DATABASE mydb")
+    sql("CREATE TABLE mydb.test2 (address STRING) USING JSON")
+    sql("USE mydb")
+
+    try {
+      builder()
+        .onTable("test2")
+        .expectTableName("spark_catalog.mydb.test2")
+        .expectAllColumns("address")
+
+      builder()
+        .onTable("mydb.test2")
+        .expectTableName("spark_catalog.mydb.test2")
+        .expectAllColumns("address")
+
+      builder()
+        .onTable("spark_catalog.mydb.test2")
+        .expectTableName("spark_catalog.mydb.test2")
+        .expectAllColumns("address")
+
+      // Can parse any specified table name
+      builder()
+        .onTable("spark_catalog.default.test")
+        .expectTableName("spark_catalog.default.test")
+        .expectAllColumns("name", "age")
+    } finally {
+      sql("DROP DATABASE mydb CASCADE")
+      sql("USE default")
+    }
+  }
+
+  private def builder(): FakeFlintSparkIndexBuilder = {
+    new FakeFlintSparkIndexBuilder
+  }
+
+  /**
+   * Fake builder that has access to internal methods for assertion.
+   */
+  class FakeFlintSparkIndexBuilder extends FlintSparkIndexBuilder(new FlintSpark(spark)) {
+
+    def onTable(tableName: String): FakeFlintSparkIndexBuilder = {
+      this.tableName = tableName
+      this
+    }
+
+    def expectTableName(expected: String): FakeFlintSparkIndexBuilder = {
+      tableName shouldBe expected
+      this
+    }
+
+    def expectAllColumns(expected: String*): FakeFlintSparkIndexBuilder = {
+      allColumns.keys should contain theSameElementsAs expected
+      this
+    }
+
+    override protected def buildIndex(): FlintSparkIndex = {
+      null
+    }
+  }
+}
