Commit ee80aaf
Also fix the bug of the describe command
Signed-off-by: Lantao Jin <[email protected]>
LantaoJin committed Oct 11, 2024
1 parent 6d62567 commit ee80aaf
Showing 5 changed files with 53 additions and 29 deletions.
@@ -541,6 +541,11 @@ class FlintSparkPPLBasicITSuite
| """.stripMargin))
assert(ex.getMessage().contains("TABLE_OR_VIEW_NOT_FOUND"))
}
val t7 = "spark_catalog.default.flint_ppl_test7.log"
val ex = intercept[IllegalArgumentException](sql(s"""
| source = $t7| head 2
| """.stripMargin))
assert(ex.getMessage().contains("Invalid table name"))
}

test("test describe backtick table names and name contains '.'") {
@@ -551,14 +556,19 @@ class FlintSparkPPLBasicITSuite
assert(frame.collect().length > 0)
}
// test reading a table which cannot be created
val t5 = "`spark_catalog`.default.`flint/ppl/test4.log`"
val t6 = "spark_catalog.default.`flint_ppl_test5.log`"
val t5 = "`spark_catalog`.default.`flint/ppl/test5.log`"
val t6 = "spark_catalog.default.`flint_ppl_test6.log`"
Seq(t5, t6).foreach { table =>
val ex = intercept[AnalysisException](sql(s"""
| describe $table
| """.stripMargin))
assert(ex.getMessage().contains("TABLE_OR_VIEW_NOT_FOUND"))
}
val t7 = "spark_catalog.default.flint_ppl_test7.log"
val ex = intercept[IllegalArgumentException](sql(s"""
| describe $t7
| """.stripMargin))
assert(ex.getMessage().contains("Invalid table name"))
}

test("test explain backtick table names and name contains '.'") {
@@ -573,15 +583,18 @@ class FlintSparkPPLBasicITSuite
val frame = sql(s"""
| explain extended | source = $table
| """.stripMargin)
// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val relation = UnresolvedRelation(Seq("spark_catalog", "default", "flint/ppl/test4.log"))
val expectedPlan: LogicalPlan =
ExplainCommand(
Project(Seq(UnresolvedStar(None)), relation),
ExplainMode.fromString("extended"))
// Compare the two plans
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)

val t7 = "spark_catalog.default.flint_ppl_test7.log"
val ex = intercept[IllegalArgumentException](sql(s"""
| explain extended | source = $t7
| """.stripMargin))
assert(ex.getMessage().contains("Invalid table name"))
}
}
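The three new t7 assertions in this suite all exercise the same guard: a dotted name with more than three parts (catalog.database.table) now fails fast with an IllegalArgumentException instead of reaching Spark's resolver. Below is a minimal Scala sketch of the rule, assuming a naive dot-split; the real parser keeps backticked segments such as `flint_ppl_test6.log` together as a single part, which a plain split does not model, and the actual check lives in RelationUtils.getTableIdentifier further down in this diff.

// Sketch only: reject names with zero or more than three parts.
val parts = "spark_catalog.default.flint_ppl_test7.log".split('.').toSeq
if (parts.isEmpty || parts.size > 3) // at most catalog.database.table
  throw new IllegalArgumentException(
    s"Invalid table name: ${parts.mkString(".")} Syntax: [ database_name. ] table_name")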
@@ -49,6 +49,10 @@ public List<String> getTableName() {
return tableName.stream().map(Object::toString).collect(Collectors.toList());
}

public List<QualifiedName> getQualifiedNames() {
return tableName.stream().map(t -> (QualifiedName) t).collect(Collectors.toList());
}
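The new accessor complements getTableName() above: flattening each name to a string loses the part structure, while QualifiedName keeps it, which is exactly what the visitor needs to build a TableIdentifier. A hypothetical Scala sketch (the real class is Java) of why the distinction matters:

// Two structurally different names can flatten to the same string.
case class QualifiedName(parts: Seq[String]) {
  override def toString: String = parts.mkString(".")
}
val backticked = QualifiedName(Seq("default", "t.log"))    // `default`.`t.log` -> 2 parts
val plain      = QualifiedName(Seq("default", "t", "log")) // default.t.log     -> 3 parts
assert(backticked.toString == plain.toString)     // identical once flattened
assert(backticked.parts.size != plain.parts.size) // the structure survives only in parts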

/**
* Return alias.
*
@@ -86,7 +86,6 @@
import org.opensearch.sql.ppl.utils.ParseStrategy;
import org.opensearch.sql.ppl.utils.SortUtils;
import scala.Option;
import scala.Option$;
import scala.Tuple2;
import scala.collection.IterableLike;
import scala.collection.Seq;
@@ -111,6 +110,7 @@
import static org.opensearch.sql.ppl.utils.LookupTransformer.buildLookupRelationProjectList;
import static org.opensearch.sql.ppl.utils.LookupTransformer.buildOutputProjectList;
import static org.opensearch.sql.ppl.utils.LookupTransformer.buildProjectListFromFields;
import static org.opensearch.sql.ppl.utils.RelationUtils.getTableIdentifier;
import static org.opensearch.sql.ppl.utils.RelationUtils.resolveField;
import static org.opensearch.sql.ppl.utils.WindowSpecTransformer.window;

@@ -150,24 +150,7 @@ public LogicalPlan visitExplain(Explain node, CatalystPlanContext context) {
@Override
public LogicalPlan visitRelation(Relation node, CatalystPlanContext context) {
if (node instanceof DescribeRelation) {
TableIdentifier identifier;
if (node.getTableQualifiedName().getParts().isEmpty()) {
throw new IllegalArgumentException("Empty table name is invalid");
} else if (node.getTableQualifiedName().getParts().size() == 1) {
identifier = new TableIdentifier(node.getTableQualifiedName().getParts().get(0));
} else if (node.getTableQualifiedName().getParts().size() == 2) {
identifier = new TableIdentifier(
node.getTableQualifiedName().getParts().get(1),
Option$.MODULE$.apply(node.getTableQualifiedName().getParts().get(0)));
} else if (node.getTableQualifiedName().getParts().size() == 3) {
identifier = new TableIdentifier(
node.getTableQualifiedName().getParts().get(2),
Option$.MODULE$.apply(node.getTableQualifiedName().getParts().get(1)),
Option$.MODULE$.apply(node.getTableQualifiedName().getParts().get(0)));
} else {
throw new IllegalArgumentException("Invalid table name: " + node.getTableQualifiedName()
+ " Syntax: [ database_name. ] table_name");
}
TableIdentifier identifier = getTableIdentifier(node.getTableQualifiedName());
return context.with(
new DescribeTableCommand(
identifier,
@@ -176,9 +159,9 @@ public LogicalPlan visitRelation(Relation node, CatalystPlanContext context) {
DescribeRelation$.MODULE$.getOutputAttrs()));
}
//regular sql algebraic relations
node.getTableName().forEach(t ->
node.getQualifiedNames().forEach(q ->
// Resolving the qualifiedName which is composed of a datasource.schema.table
context.withRelation(new UnresolvedRelation(seq(node.getTableQualifiedName().getParts()), CaseInsensitiveStringMap.empty(), false))
context.withRelation(new UnresolvedRelation(getTableIdentifier(q).nameParts(), CaseInsensitiveStringMap.empty(), false))
);
return context.getPlan();
}
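Both branches of visitRelation now funnel every name through getTableIdentifier, so describe and ordinary source relations share one validation and normalization path. A minimal sketch of the relation branch, assuming Spark 3.4+ (TableIdentifier with catalog support, nameParts ordered catalog, database, table):

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.util.CaseInsensitiveStringMap

// Normalize the qualified name once, then hand Spark the multipart identifier.
val identifier = TableIdentifier("t", Some("default"), Some("spark_catalog"))
val relation = UnresolvedRelation(
  identifier.nameParts, CaseInsensitiveStringMap.empty(), isStreaming = false)
// relation.multipartIdentifier == Seq("spark_catalog", "default", "t")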
@@ -327,7 +310,7 @@ public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext context) {
seq(new ArrayList<Expression>())));
context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Sort(sortElements, true, logicalPlan));
}
//visit TopAggregation results limit
if ((node instanceof TopAggregation) && ((TopAggregation) node).getResults().isPresent()) {
context.apply(p -> (LogicalPlan) Limit.apply(new org.apache.spark.sql.catalyst.expressions.Literal(
((TopAggregation) node).getResults().get().getValue(), org.apache.spark.sql.types.DataTypes.IntegerType), p));
@@ -1,8 +1,10 @@
package org.opensearch.sql.ppl.utils;

import org.apache.spark.sql.catalyst.TableIdentifier;
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.opensearch.sql.ast.expression.QualifiedName;
import scala.Option$;

import java.util.List;
import java.util.Optional;
@@ -15,7 +17,7 @@ public interface RelationUtils {
*
* @param relations
* @param node
* @param contextRelations
* @param tables
* @return
*/
static Optional<QualifiedName> resolveField(List<UnresolvedRelation> relations, QualifiedName node, List<LogicalPlan> tables) {
@@ -29,4 +31,26 @@ static Optional<QualifiedName> resolveField(List<UnresolvedRelation> relations,
.findFirst()
.map(rel -> node);
}

static TableIdentifier getTableIdentifier(QualifiedName qualifiedName) {
TableIdentifier identifier;
if (qualifiedName.getParts().isEmpty()) {
throw new IllegalArgumentException("Empty table name is invalid");
} else if (qualifiedName.getParts().size() == 1) {
identifier = new TableIdentifier(qualifiedName.getParts().get(0));
} else if (qualifiedName.getParts().size() == 2) {
identifier = new TableIdentifier(
qualifiedName.getParts().get(1),
Option$.MODULE$.apply(qualifiedName.getParts().get(0)));
} else if (qualifiedName.getParts().size() == 3) {
identifier = new TableIdentifier(
qualifiedName.getParts().get(2),
Option$.MODULE$.apply(qualifiedName.getParts().get(1)),
Option$.MODULE$.apply(qualifiedName.getParts().get(0)));
} else {
throw new IllegalArgumentException("Invalid table name: " + qualifiedName
+ " Syntax: [ database_name. ] table_name");
}
return identifier;
}
}
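For readers more at home in Scala, the extracted helper is equivalent to a pattern match on the number of parts; this is a sketch, not project code. Parts arrive most-significant first, so the TableIdentifier(table, database, catalog) constructor reads them in reverse, and anything outside one to three parts raises the IllegalArgumentException the new t7 tests assert on:

import org.apache.spark.sql.catalyst.TableIdentifier

def getTableIdentifier(parts: Seq[String]): TableIdentifier = parts match {
  case Seq(table)              => TableIdentifier(table, None, None)
  case Seq(db, table)          => TableIdentifier(table, Some(db), None)
  case Seq(catalog, db, table) => TableIdentifier(table, Some(db), Some(catalog))
  case Seq() => throw new IllegalArgumentException("Empty table name is invalid")
  case _ => throw new IllegalArgumentException(
    s"Invalid table name: ${parts.mkString(".")} Syntax: [ database_name. ] table_name")
}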
@@ -43,7 +43,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite
planTransformer.visit(plan(pplParser, "describe schema.default.http_logs"), context)

val expectedPlan = DescribeTableCommand(
TableIdentifier("http_logs", Option("schema"), Option("default")),
TableIdentifier("http_logs", Option("default"), Option("schema")),
Map.empty[String, String].empty,
isExtended = true,
output = DescribeRelation.getOutputAttrs)
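This expectation swap lines up with the describe fix in the commit title: TableIdentifier's arguments are (table, database, catalog), so for describe schema.default.http_logs the catalog is "schema" and the database is "default"; the previous expectation passed the two options in the opposite order. A quick check, assuming Spark 3.4+:

import org.apache.spark.sql.catalyst.TableIdentifier

val id = TableIdentifier("http_logs", Option("default"), Option("schema"))
assert(id.database.contains("default"))
assert(id.catalog.contains("schema"))
assert(id.nameParts == Seq("schema", "default", "http_logs")) // catalog, database, table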