backport 768 bug fix to 0.5
Signed-off-by: YANGDB <[email protected]>
YANG-DB committed Oct 11, 2024
1 parent e10058b commit 0d5d188
Showing 6 changed files with 150 additions and 30 deletions.
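The change centralizes qualified-table-name handling in a shared `RelationUtils.getTableIdentifier` helper, so backtick-quoted name parts that themselves contain dots (or other special characters) survive translation to Catalyst instead of being torn apart by a naive `split("\\.")`. A minimal Scala sketch, not part of the commit, of the failure mode the removed code had:

```scala
object NaiveSplitDemo extends App {
  // The removed code in the plan visitor split the rendered name on '.',
  // ignoring backtick quoting, so a quoted segment containing a dot was
  // torn into two bogus name parts.
  val name = "spark_catalog.default.`flint_ppl_test6.log`"
  println(name.split("\\.").toList)
  // => List(spark_catalog, default, `flint_ppl_test6, log`)
}
```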
FlintSparkPPLBasicITSuite.scala
@@ -5,12 +5,13 @@

package org.opensearch.flint.spark.ppl

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.command.DescribeTableCommand
import org.apache.spark.sql.execution.ExplainMode
import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExplainCommand}
import org.apache.spark.sql.streaming.StreamTest

class FlintSparkPPLBasicITSuite
@@ -21,12 +22,20 @@ class FlintSparkPPLBasicITSuite

/** Test table and index name */
private val testTable = "spark_catalog.default.flint_ppl_test"

private val t1 = "`spark_catalog`.`default`.`flint_ppl_test1`"
private val t2 = "`spark_catalog`.default.`flint_ppl_test2`"
private val t3 = "spark_catalog.`default`.`flint_ppl_test3`"
private val t4 = "`spark_catalog`.`default`.flint_ppl_test4"

override def beforeAll(): Unit = {
super.beforeAll()

// Create test table
createPartitionedStateCountryTable(testTable)
createPartitionedStateCountryTable(t1)
createPartitionedStateCountryTable(t2)
createPartitionedStateCountryTable(t3)
createPartitionedStateCountryTable(t4)
}

protected override def afterEach(): Unit = {
@@ -118,6 +127,88 @@ class FlintSparkPPLBasicITSuite
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test backtick table names and name contains '.'") {
Seq(t1, t2, t3, t4).foreach { table =>
val frame = sql(
s"""
| source = $table| head 2
| """.stripMargin)
assert(frame.collect().length == 2)
}
// test reading tables that cannot be created (invalid characters in the name)
val t5 = "`spark_catalog`.default.`flint/ppl/test5.log`"
val t6 = "spark_catalog.default.`flint_ppl_test6.log`"
Seq(t5, t6).foreach { table =>
val ex = intercept[AnalysisException](sql(
s"""
| source = $table| head 2
| """.stripMargin))
assert(ex.getMessage().contains("TABLE_OR_VIEW_NOT_FOUND"))
}
val t7 = "spark_catalog.default.flint_ppl_test7.log"
val ex = intercept[IllegalArgumentException](sql(
s"""
| source = $t7| head 2
| """.stripMargin))
assert(ex.getMessage().contains("Invalid table name"))
}

test("test describe backtick table names and name contains '.'") {
Seq(t1, t2, t3, t4).foreach { table =>
val frame = sql(
s"""
| describe $table
| """.stripMargin)
assert(frame.collect().length > 0)
}
// test reading tables that cannot be created (invalid characters in the name)
val t5 = "`spark_catalog`.default.`flint/ppl/test5.log`"
val t6 = "spark_catalog.default.`flint_ppl_test6.log`"
Seq(t5, t6).foreach { table =>
val ex = intercept[AnalysisException](sql(
s"""
| describe $table
| """.stripMargin))
assert(ex.getMessage().contains("TABLE_OR_VIEW_NOT_FOUND"))
}
val t7 = "spark_catalog.default.flint_ppl_test7.log"
val ex = intercept[IllegalArgumentException](sql(
s"""
| describe $t7
| """.stripMargin))
assert(ex.getMessage().contains("Invalid table name"))
}

test("test explain backtick table names and name contains '.'") {
Seq(t1, t2, t3, t4).foreach { table =>
val frame = sql(
s"""
| explain extended | source = $table
| """.stripMargin)
assert(frame.collect().length > 0)
}
// test reading tables that cannot be created (invalid characters in the name)
val table = "`spark_catalog`.default.`flint/ppl/test4.log`"
val frame = sql(
s"""
| explain extended | source = $table
| """.stripMargin)
val logicalPlan: LogicalPlan = frame.queryExecution.logical
val relation = UnresolvedRelation(Seq("spark_catalog", "default", "flint/ppl/test4.log"))
val expectedPlan: LogicalPlan =
ExplainCommand(
Project(Seq(UnresolvedStar(None)), relation),
ExplainMode.fromString("extended"))
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)

val t7 = "spark_catalog.default.flint_ppl_test7.log"
val ex = intercept[IllegalArgumentException](sql(
s"""
| explain extended | source = $t7
| """.stripMargin))
assert(ex.getMessage().contains("Invalid table name"))
}

test("create ppl simple query test") {
val testTableQuoted = "`spark_catalog`.`default`.`flint_ppl_test`"
Seq(testTable, testTableQuoted).foreach { table =>
6 changes: 5 additions & 1 deletion ppl-spark-integration/README.md
@@ -226,7 +226,11 @@ See the next samples of PPL queries:

**Describe**
- `describe table`: equivalent to the `DESCRIBE EXTENDED table` SQL command

- `describe schema.table`
- `` describe schema.`table` ``
- `describe catalog.schema.table`
- `` describe catalog.schema.`table` ``
- `` describe `catalog`.`schema`.`table` ``
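For example, `` describe `spark_catalog`.default.`access.logs` `` (hypothetical table name) describes a table whose last name part itself contains a dot.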
**Fields**
- `source = table`
- `source = table | fields a,b,c`
Relation.java
@@ -38,15 +38,18 @@ public Relation(UnresolvedExpression tableName, String alias) {
private String alias;

/**
* Return table name.
*
* @return table name
*/
public List<String> getTableName() {
return tableName.stream().map(Object::toString).collect(Collectors.toList());
}



public List<QualifiedName> getQualifiedNames() {
return tableName.stream().map(t -> (QualifiedName) t).collect(Collectors.toList());
}

/**
* Return alias.
*
CatalystQueryPlanVisitor.java
@@ -92,6 +92,7 @@
import static org.opensearch.sql.ppl.utils.DedupeTransformer.retainOneDuplicateEvent;
import static org.opensearch.sql.ppl.utils.DedupeTransformer.retainOneDuplicateEventAndKeepEmpty;
import static org.opensearch.sql.ppl.utils.JoinSpecTransformer.join;
import static org.opensearch.sql.ppl.utils.RelationUtils.getTableIdentifier;
import static org.opensearch.sql.ppl.utils.RelationUtils.resolveField;
import static org.opensearch.sql.ppl.utils.WindowSpecTransformer.window;

@@ -126,33 +127,18 @@ public LogicalPlan visitExplain(Explain node, CatalystPlanContext context) {
@Override
public LogicalPlan visitRelation(Relation node, CatalystPlanContext context) {
if (node instanceof DescribeRelation) {
TableIdentifier identifier;
if (node.getTableQualifiedName().getParts().size() == 1) {
identifier = new TableIdentifier(node.getTableQualifiedName().getParts().get(0));
} else if (node.getTableQualifiedName().getParts().size() == 2) {
identifier = new TableIdentifier(
node.getTableQualifiedName().getParts().get(1),
Option$.MODULE$.apply(node.getTableQualifiedName().getParts().get(0)));
} else if (node.getTableQualifiedName().getParts().size() == 3) {
identifier = new TableIdentifier(
node.getTableQualifiedName().getParts().get(2),
Option$.MODULE$.apply(node.getTableQualifiedName().getParts().get(0)),
Option$.MODULE$.apply(node.getTableQualifiedName().getParts().get(1)));
} else {
throw new IllegalArgumentException("Invalid table name: " + node.getTableQualifiedName()
+ " Syntax: [ database_name. ] table_name");
}
TableIdentifier identifier = getTableIdentifier(node.getTableQualifiedName());
return context.with(
new DescribeTableCommand(
identifier,
scala.collection.immutable.Map$.MODULE$.<String, String>empty(),
true,
DescribeRelation$.MODULE$.getOutputAttrs()));
}
//regular sql algebraic relations
node.getTableName().forEach(t ->
//regular sql algebraic relations
node.getQualifiedNames().forEach(q ->
// Resolving the qualifiedName which is composed of a datasource.schema.table
context.withRelation(new UnresolvedRelation(seq(of(t.split("\\."))), CaseInsensitiveStringMap.empty(), false))
context.withRelation(new UnresolvedRelation(getTableIdentifier(q).nameParts(), CaseInsensitiveStringMap.empty(), false))
);
return context.getPlan();
}
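With the helper in place, relation resolution goes through `getTableIdentifier(q).nameParts()`. A short Scala sketch (assuming the Spark 3.4 `TableIdentifier` API, where the identifier carries an optional database and catalog) of what that yields for a hypothetical three-part name:

```scala
import org.apache.spark.sql.catalyst.TableIdentifier

object NamePartsSketch extends App {
  // The dotted table segment stays a single name part instead of being
  // split apart by the removed split("\\.") call.
  val id = TableIdentifier("flint_ppl_test6.log", Some("default"), Some("spark_catalog"))
  println(id.nameParts) // expected: List(spark_catalog, default, flint_ppl_test6.log)
}
```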
RelationUtils.java
@@ -1,8 +1,10 @@
package org.opensearch.sql.ppl.utils;

import org.apache.spark.sql.catalyst.TableIdentifier;
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.opensearch.sql.ast.expression.QualifiedName;
import scala.Option$;

import java.util.List;
import java.util.Optional;
@@ -15,7 +17,6 @@ public interface RelationUtils {
*
* @param relations
* @param node
* @param contextRelations
* @return
*/
static Optional<QualifiedName> resolveField(List<UnresolvedRelation> relations, QualifiedName node, List<LogicalPlan> tables) {
@@ -29,4 +30,26 @@ static Optional<QualifiedName> resolveField(List<UnresolvedRelation> relations,
.findFirst()
.map(rel -> node);
}

static TableIdentifier getTableIdentifier(QualifiedName qualifiedName) {
TableIdentifier identifier;
if (qualifiedName.getParts().isEmpty()) {
throw new IllegalArgumentException("Empty table name is invalid");
} else if (qualifiedName.getParts().size() == 1) {
identifier = new TableIdentifier(qualifiedName.getParts().get(0));
} else if (qualifiedName.getParts().size() == 2) {
identifier = new TableIdentifier(
qualifiedName.getParts().get(1),
Option$.MODULE$.apply(qualifiedName.getParts().get(0)));
} else if (qualifiedName.getParts().size() == 3) {
identifier = new TableIdentifier(
qualifiedName.getParts().get(2),
Option$.MODULE$.apply(qualifiedName.getParts().get(1)),
Option$.MODULE$.apply(qualifiedName.getParts().get(0)));
} else {
throw new IllegalArgumentException("Invalid table name: " + qualifiedName
+ " Syntax: [ database_name. ] table_name");
}
return identifier;
}
}
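A compact Scala rendering of the same part-count dispatch, offered as a sketch that mirrors the Java helper above rather than replacing it:

```scala
import org.apache.spark.sql.catalyst.TableIdentifier

object TableIdentifierSketch {
  // 1 part -> table, 2 -> database.table, 3 -> catalog.database.table;
  // anything else (including an empty name) is rejected, as in the helper above.
  def toIdentifier(parts: Seq[String]): TableIdentifier = parts match {
    case Seq(t)       => TableIdentifier(t)
    case Seq(d, t)    => TableIdentifier(t, Some(d))
    case Seq(c, d, t) => TableIdentifier(t, Some(d), Some(c))
    case _ =>
      throw new IllegalArgumentException(s"Invalid table name: ${parts.mkString(".")}")
  }
}
```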
PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala
@@ -36,14 +36,27 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite
assert(
thrown.getMessage === "Invalid table name: t.b.c.d Syntax: [ database_name. ] table_name")
}

test("test describe with backticks") {
val context = new CatalystPlanContext
val logPlan =
planTransformer.visit(plan(pplParser, "describe t.b.`c.d`", false), context)

val expectedPlan = DescribeTableCommand(
TableIdentifier("c.d", Option("b"), Option("t")),
Map.empty[String, String].empty,
isExtended = true,
output = DescribeRelation.getOutputAttrs)
comparePlans(expectedPlan, logPlan, false)
}

test("test describe FQN table clause") {
val context = new CatalystPlanContext
val logPlan =
planTransformer.visit(plan(pplParser, "describe schema.default.http_logs", false), context)
planTransformer.visit(plan(pplParser, "describe catalog.schema.http_logs", false), context)

val expectedPlan = DescribeTableCommand(
TableIdentifier("http_logs", Option("schema"), Option("default")),
TableIdentifier("http_logs", Option("schema"), Option("catalog")),
Map.empty[String, String].empty,
isExtended = true,
output = DescribeRelation.getOutputAttrs)
@@ -64,10 +77,10 @@

test("test FQN table describe table clause") {
val context = new CatalystPlanContext
val logPlan = planTransformer.visit(plan(pplParser, "describe catalog.t", false), context)
val logPlan = planTransformer.visit(plan(pplParser, "describe schema.t", false), context)

val expectedPlan = DescribeTableCommand(
TableIdentifier("t", Option("catalog")),
TableIdentifier("t", Option("schema")),
Map.empty[String, String].empty,
isExtended = true,
output = DescribeRelation.getOutputAttrs)
