From 6d625677f6698c6e091c21aebe6fa449e4156338 Mon Sep 17 00:00:00 2001
From: Lantao Jin
Date: Fri, 11 Oct 2024 14:16:06 +0800
Subject: [PATCH] Support table identifier contains dot with backticks

Signed-off-by: Lantao Jin
---
 .../spark/ppl/FlintSparkPPLBasicITSuite.scala | 70 ++++++++++++++++++-
 .../sql/ppl/CatalystQueryPlanVisitor.java     | 10 +--
 2 files changed, 75 insertions(+), 5 deletions(-)

diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala
index 4c38e1471..087a2080a 100644
--- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala
+++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala
@@ -5,7 +5,7 @@
 
 package org.opensearch.flint.spark.ppl
 
-import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, EqualTo, IsNotNull, Literal, Not, SortOrder}
@@ -22,12 +22,20 @@ class FlintSparkPPLBasicITSuite
 
   /** Test table and index name */
   private val testTable = "spark_catalog.default.flint_ppl_test"
+  private val t1 = "`spark_catalog`.`default`.`flint_ppl_test1`"
+  private val t2 = "`spark_catalog`.default.`flint_ppl_test2`"
+  private val t3 = "spark_catalog.`default`.`flint_ppl_test3`"
+  private val t4 = "`spark_catalog`.`default`.flint_ppl_test4"
 
   override def beforeAll(): Unit = {
     super.beforeAll()
 
     // Create test table
     createPartitionedStateCountryTable(testTable)
+    createPartitionedStateCountryTable(t1)
+    createPartitionedStateCountryTable(t2)
+    createPartitionedStateCountryTable(t3)
+    createPartitionedStateCountryTable(t4)
   }
 
   protected override def afterEach(): Unit = {
@@ -516,4 +524,64 @@ class FlintSparkPPLBasicITSuite
     // Compare the two plans
     comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
   }
+
+  test("test backtick table names and name contains '.'") {
+    Seq(t1, t2, t3, t4).foreach { table =>
+      val frame = sql(s"""
+           | source = $table| head 2
+           | """.stripMargin)
+      assert(frame.collect().length == 2)
+    }
+    // test read table which is unable to create
+    val t5 = "`spark_catalog`.default.`flint/ppl/test5.log`"
+    val t6 = "spark_catalog.default.`flint_ppl_test6.log`"
+    Seq(t5, t6).foreach { table =>
+      val ex = intercept[AnalysisException](sql(s"""
+           | source = $table| head 2
+           | """.stripMargin))
+      assert(ex.getMessage().contains("TABLE_OR_VIEW_NOT_FOUND"))
+    }
+  }
+
+  test("test describe backtick table names and name contains '.'") {
+    Seq(t1, t2, t3, t4).foreach { table =>
+      val frame = sql(s"""
+           | describe $table
+           | """.stripMargin)
+      assert(frame.collect().length > 0)
+    }
+    // test read table which is unable to create
+    val t5 = "`spark_catalog`.default.`flint/ppl/test4.log`"
+    val t6 = "spark_catalog.default.`flint_ppl_test5.log`"
+    Seq(t5, t6).foreach { table =>
+      val ex = intercept[AnalysisException](sql(s"""
+           | describe $table
+           | """.stripMargin))
+      assert(ex.getMessage().contains("TABLE_OR_VIEW_NOT_FOUND"))
+    }
+  }
+
+  test("test explain backtick table names and name contains '.'") {
+    Seq(t1, t2, t3, t4).foreach { table =>
+      val frame = sql(s"""
+           | explain extended | source = $table
+           | """.stripMargin)
+      assert(frame.collect().length > 0)
+    }
+    // test read table which is unable to create
+    val table = "`spark_catalog`.default.`flint/ppl/test4.log`"
+    val frame = sql(s"""
+         | explain extended | source = $table
+         | """.stripMargin)
+    // Retrieve the logical plan
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    // Define the expected logical plan
+    val relation = UnresolvedRelation(Seq("spark_catalog", "default", "flint/ppl/test4.log"))
+    val expectedPlan: LogicalPlan =
+      ExplainCommand(
+        Project(Seq(UnresolvedStar(None)), relation),
+        ExplainMode.fromString("extended"))
+    // Compare the two plans
+    comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
+  }
 }
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
index 26ad4198a..28a9c5f32 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
@@ -151,7 +151,9 @@ public LogicalPlan visitExplain(Explain node, CatalystPlanContext context) {
     public LogicalPlan visitRelation(Relation node, CatalystPlanContext context) {
         if (node instanceof DescribeRelation) {
             TableIdentifier identifier;
-            if (node.getTableQualifiedName().getParts().size() == 1) {
+            if (node.getTableQualifiedName().getParts().isEmpty()) {
+                throw new IllegalArgumentException("Empty table name is invalid");
+            } else if (node.getTableQualifiedName().getParts().size() == 1) {
                 identifier = new TableIdentifier(node.getTableQualifiedName().getParts().get(0));
             } else if (node.getTableQualifiedName().getParts().size() == 2) {
                 identifier = new TableIdentifier(
@@ -160,8 +162,8 @@ public LogicalPlan visitRelation(Relation node, CatalystPlanContext context) {
             } else if (node.getTableQualifiedName().getParts().size() == 3) {
                 identifier = new TableIdentifier(
                     node.getTableQualifiedName().getParts().get(2),
-                    Option$.MODULE$.apply(node.getTableQualifiedName().getParts().get(0)),
-                    Option$.MODULE$.apply(node.getTableQualifiedName().getParts().get(1)));
+                    Option$.MODULE$.apply(node.getTableQualifiedName().getParts().get(1)),
+                    Option$.MODULE$.apply(node.getTableQualifiedName().getParts().get(0)));
             } else {
                 throw new IllegalArgumentException("Invalid table name: " + node.getTableQualifiedName()
                     + " Syntax: [ database_name. ] table_name");
@@ -176,7 +178,7 @@ public LogicalPlan visitRelation(Relation node, CatalystPlanContext context) {
         //regular sql algebraic relations
         node.getTableName().forEach(t ->
             // Resolving the qualifiedName which is composed of a datasource.schema.table
-            context.withRelation(new UnresolvedRelation(seq(of(t.split("\\."))), CaseInsensitiveStringMap.empty(), false))
+            context.withRelation(new UnresolvedRelation(seq(node.getTableQualifiedName().getParts()), CaseInsensitiveStringMap.empty(), false))
         );
         return context.getPlan();
     }