Skip to content

Commit

Permalink
Support table identifier contains dot with backticks
Browse files Browse the repository at this point in the history
Signed-off-by: Lantao Jin <[email protected]>
  • Loading branch information
LantaoJin committed Oct 11, 2024
1 parent 1e91e70 commit 6d62567
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

package org.opensearch.flint.spark.ppl

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, EqualTo, IsNotNull, Literal, Not, SortOrder}
Expand All @@ -22,12 +22,20 @@ class FlintSparkPPLBasicITSuite

/** Test table and index name */
private val testTable = "spark_catalog.default.flint_ppl_test"
private val t1 = "`spark_catalog`.`default`.`flint_ppl_test1`"
private val t2 = "`spark_catalog`.default.`flint_ppl_test2`"
private val t3 = "spark_catalog.`default`.`flint_ppl_test3`"
private val t4 = "`spark_catalog`.`default`.flint_ppl_test4"

override def beforeAll(): Unit = {
super.beforeAll()

// Create test table
createPartitionedStateCountryTable(testTable)
createPartitionedStateCountryTable(t1)
createPartitionedStateCountryTable(t2)
createPartitionedStateCountryTable(t3)
createPartitionedStateCountryTable(t4)
}

protected override def afterEach(): Unit = {
Expand Down Expand Up @@ -516,4 +524,64 @@ class FlintSparkPPLBasicITSuite
// Compare the two plans
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test backtick table names and name contains '.'") {
Seq(t1, t2, t3, t4).foreach { table =>
val frame = sql(s"""
| source = $table| head 2
| """.stripMargin)
assert(frame.collect().length == 2)
}
// test read table which is unable to create
val t5 = "`spark_catalog`.default.`flint/ppl/test5.log`"
val t6 = "spark_catalog.default.`flint_ppl_test6.log`"
Seq(t5, t6).foreach { table =>
val ex = intercept[AnalysisException](sql(s"""
| source = $table| head 2
| """.stripMargin))
assert(ex.getMessage().contains("TABLE_OR_VIEW_NOT_FOUND"))
}
}

test("test describe backtick table names and name contains '.'") {
Seq(t1, t2, t3, t4).foreach { table =>
val frame = sql(s"""
| describe $table
| """.stripMargin)
assert(frame.collect().length > 0)
}
// test read table which is unable to create
val t5 = "`spark_catalog`.default.`flint/ppl/test4.log`"
val t6 = "spark_catalog.default.`flint_ppl_test5.log`"
Seq(t5, t6).foreach { table =>
val ex = intercept[AnalysisException](sql(s"""
| describe $table
| """.stripMargin))
assert(ex.getMessage().contains("TABLE_OR_VIEW_NOT_FOUND"))
}
}

test("test explain backtick table names and name contains '.'") {
Seq(t1, t2, t3, t4).foreach { table =>
val frame = sql(s"""
| explain extended | source = $table
| """.stripMargin)
assert(frame.collect().length > 0)
}
// test read table which is unable to create
val table = "`spark_catalog`.default.`flint/ppl/test4.log`"
val frame = sql(s"""
| explain extended | source = $table
| """.stripMargin)
// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val relation = UnresolvedRelation(Seq("spark_catalog", "default", "flint/ppl/test4.log"))
val expectedPlan: LogicalPlan =
ExplainCommand(
Project(Seq(UnresolvedStar(None)), relation),
ExplainMode.fromString("extended"))
// Compare the two plans
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ public LogicalPlan visitExplain(Explain node, CatalystPlanContext context) {
public LogicalPlan visitRelation(Relation node, CatalystPlanContext context) {
if (node instanceof DescribeRelation) {
TableIdentifier identifier;
if (node.getTableQualifiedName().getParts().size() == 1) {
if (node.getTableQualifiedName().getParts().isEmpty()) {
throw new IllegalArgumentException("Empty table name is invalid");
} else if (node.getTableQualifiedName().getParts().size() == 1) {
identifier = new TableIdentifier(node.getTableQualifiedName().getParts().get(0));
} else if (node.getTableQualifiedName().getParts().size() == 2) {
identifier = new TableIdentifier(
Expand All @@ -160,8 +162,8 @@ public LogicalPlan visitRelation(Relation node, CatalystPlanContext context) {
} else if (node.getTableQualifiedName().getParts().size() == 3) {
identifier = new TableIdentifier(
node.getTableQualifiedName().getParts().get(2),
Option$.MODULE$.apply(node.getTableQualifiedName().getParts().get(0)),
Option$.MODULE$.apply(node.getTableQualifiedName().getParts().get(1)));
Option$.MODULE$.apply(node.getTableQualifiedName().getParts().get(1)),
Option$.MODULE$.apply(node.getTableQualifiedName().getParts().get(0)));
} else {
throw new IllegalArgumentException("Invalid table name: " + node.getTableQualifiedName()
+ " Syntax: [ database_name. ] table_name");
Expand All @@ -176,7 +178,7 @@ public LogicalPlan visitRelation(Relation node, CatalystPlanContext context) {
//regular sql algebraic relations
node.getTableName().forEach(t ->
// Resolving the qualifiedName which is composed of a datasource.schema.table
context.withRelation(new UnresolvedRelation(seq(of(t.split("\\."))), CaseInsensitiveStringMap.empty(), false))
context.withRelation(new UnresolvedRelation(seq(node.getTableQualifiedName().getParts()), CaseInsensitiveStringMap.empty(), false))
);
return context.getPlan();
}
Expand Down

0 comments on commit 6d62567

Please sign in to comment.