Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into issues/711
Browse files Browse the repository at this point in the history
  • Loading branch information
LantaoJin committed Oct 12, 2024
2 parents 7f1b8dd + fe5148c commit 6974171
Show file tree
Hide file tree
Showing 8 changed files with 163 additions and 37 deletions.
5 changes: 5 additions & 0 deletions docs/ppl-lang/PPL-Example-Commands.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

#### **Describe**
- `describe table` This command is equal to the `DESCRIBE EXTENDED table` SQL command
- `describe schema.table`
- `` describe schema.`table` ``
- `describe catalog.schema.table`
- `` describe catalog.schema.`table` ``
- `` describe `catalog`.`schema`.`table` ``

#### **Explain**
- `explain simple | source = table | where a = 1 | fields a,b,c`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

package org.opensearch.flint.spark.ppl

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, EqualTo, IsNotNull, Literal, Not, SortOrder}
Expand All @@ -22,12 +22,20 @@ class FlintSparkPPLBasicITSuite

/** Test table and index name */
private val testTable = "spark_catalog.default.flint_ppl_test"
private val t1 = "`spark_catalog`.`default`.`flint_ppl_test1`"
private val t2 = "`spark_catalog`.default.`flint_ppl_test2`"
private val t3 = "spark_catalog.`default`.`flint_ppl_test3`"
private val t4 = "`spark_catalog`.`default`.flint_ppl_test4"

override def beforeAll(): Unit = {
super.beforeAll()

// Create test table
createPartitionedStateCountryTable(testTable)
createPartitionedStateCountryTable(t1)
createPartitionedStateCountryTable(t2)
createPartitionedStateCountryTable(t3)
createPartitionedStateCountryTable(t4)
}

protected override def afterEach(): Unit = {
Expand Down Expand Up @@ -516,4 +524,77 @@ class FlintSparkPPLBasicITSuite
// Compare the two plans
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test backtick table names and name contains '.'") {
Seq(t1, t2, t3, t4).foreach { table =>
val frame = sql(s"""
| source = $table| head 2
| """.stripMargin)
assert(frame.collect().length == 2)
}
// test read table which is unable to create
val t5 = "`spark_catalog`.default.`flint/ppl/test5.log`"
val t6 = "spark_catalog.default.`flint_ppl_test6.log`"
Seq(t5, t6).foreach { table =>
val ex = intercept[AnalysisException](sql(s"""
| source = $table| head 2
| """.stripMargin))
assert(ex.getMessage().contains("TABLE_OR_VIEW_NOT_FOUND"))
}
val t7 = "spark_catalog.default.flint_ppl_test7.log"
val ex = intercept[IllegalArgumentException](sql(s"""
| source = $t7| head 2
| """.stripMargin))
assert(ex.getMessage().contains("Invalid table name"))
}

test("test describe backtick table names and name contains '.'") {
Seq(t1, t2, t3, t4).foreach { table =>
val frame = sql(s"""
| describe $table
| """.stripMargin)
assert(frame.collect().length > 0)
}
// test read table which is unable to create
val t5 = "`spark_catalog`.default.`flint/ppl/test5.log`"
val t6 = "spark_catalog.default.`flint_ppl_test6.log`"
Seq(t5, t6).foreach { table =>
val ex = intercept[AnalysisException](sql(s"""
| describe $table
| """.stripMargin))
assert(ex.getMessage().contains("TABLE_OR_VIEW_NOT_FOUND"))
}
val t7 = "spark_catalog.default.flint_ppl_test7.log"
val ex = intercept[IllegalArgumentException](sql(s"""
| describe $t7
| """.stripMargin))
assert(ex.getMessage().contains("Invalid table name"))
}

test("test explain backtick table names and name contains '.'") {
Seq(t1, t2, t3, t4).foreach { table =>
val frame = sql(s"""
| explain extended | source = $table
| """.stripMargin)
assert(frame.collect().length > 0)
}
// test read table which is unable to create
val table = "`spark_catalog`.default.`flint/ppl/test4.log`"
val frame = sql(s"""
| explain extended | source = $table
| """.stripMargin)
val logicalPlan: LogicalPlan = frame.queryExecution.logical
val relation = UnresolvedRelation(Seq("spark_catalog", "default", "flint/ppl/test4.log"))
val expectedPlan: LogicalPlan =
ExplainCommand(
Project(Seq(UnresolvedStar(None)), relation),
ExplainMode.fromString("extended"))
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)

val t7 = "spark_catalog.default.flint_ppl_test7.log"
val ex = intercept[IllegalArgumentException](sql(s"""
| explain extended | source = $t7
| """.stripMargin))
assert(ex.getMessage().contains("Invalid table name"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -357,8 +357,8 @@ expression
logicalExpression
: NOT logicalExpression # logicalNot
| comparisonExpression # comparsion
| left = logicalExpression OR right = logicalExpression # logicalOr
| left = logicalExpression (AND)? right = logicalExpression # logicalAnd
| left = logicalExpression OR right = logicalExpression # logicalOr
| left = logicalExpression XOR right = logicalExpression # logicalXor
| booleanExpression # booleanExpr
;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ public List<String> getTableName() {
return tableName.stream().map(Object::toString).collect(Collectors.toList());
}

public List<QualifiedName> getQualifiedNames() {
return tableName.stream().map(t -> (QualifiedName) t).collect(Collectors.toList());
}

/**
* Return alias.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@
import org.opensearch.sql.ppl.utils.ParseStrategy;
import org.opensearch.sql.ppl.utils.SortUtils;
import scala.Option;
import scala.Option$;
import scala.Tuple2;
import scala.collection.IterableLike;
import scala.collection.Seq;
Expand All @@ -113,6 +112,7 @@
import static org.opensearch.sql.ppl.utils.LookupTransformer.buildLookupRelationProjectList;
import static org.opensearch.sql.ppl.utils.LookupTransformer.buildOutputProjectList;
import static org.opensearch.sql.ppl.utils.LookupTransformer.buildProjectListFromFields;
import static org.opensearch.sql.ppl.utils.RelationUtils.getTableIdentifier;
import static org.opensearch.sql.ppl.utils.RelationUtils.resolveField;
import static org.opensearch.sql.ppl.utils.WindowSpecTransformer.window;

Expand Down Expand Up @@ -152,22 +152,7 @@ public LogicalPlan visitExplain(Explain node, CatalystPlanContext context) {
@Override
public LogicalPlan visitRelation(Relation node, CatalystPlanContext context) {
if (node instanceof DescribeRelation) {
TableIdentifier identifier;
if (node.getTableQualifiedName().getParts().size() == 1) {
identifier = new TableIdentifier(node.getTableQualifiedName().getParts().get(0));
} else if (node.getTableQualifiedName().getParts().size() == 2) {
identifier = new TableIdentifier(
node.getTableQualifiedName().getParts().get(1),
Option$.MODULE$.apply(node.getTableQualifiedName().getParts().get(0)));
} else if (node.getTableQualifiedName().getParts().size() == 3) {
identifier = new TableIdentifier(
node.getTableQualifiedName().getParts().get(2),
Option$.MODULE$.apply(node.getTableQualifiedName().getParts().get(0)),
Option$.MODULE$.apply(node.getTableQualifiedName().getParts().get(1)));
} else {
throw new IllegalArgumentException("Invalid table name: " + node.getTableQualifiedName()
+ " Syntax: [ database_name. ] table_name");
}
TableIdentifier identifier = getTableIdentifier(node.getTableQualifiedName());
return context.with(
new DescribeTableCommand(
identifier,
Expand All @@ -176,9 +161,9 @@ public LogicalPlan visitRelation(Relation node, CatalystPlanContext context) {
DescribeRelation$.MODULE$.getOutputAttrs()));
}
//regular sql algebraic relations
node.getTableName().forEach(t ->
node.getQualifiedNames().forEach(q ->
// Resolving the qualifiedName which is composed of a datasource.schema.table
context.withRelation(new UnresolvedRelation(seq(of(t.split("\\."))), CaseInsensitiveStringMap.empty(), false))
context.withRelation(new UnresolvedRelation(getTableIdentifier(q).nameParts(), CaseInsensitiveStringMap.empty(), false))
);
return context.getPlan();
}
Expand Down Expand Up @@ -327,7 +312,7 @@ public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext contex
seq(new ArrayList<Expression>())));
context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Sort(sortElements, true, logicalPlan));
}
//visit TopAggregation results limit
//visit TopAggregation results limit
if ((node instanceof TopAggregation) && ((TopAggregation) node).getResults().isPresent()) {
context.apply(p -> (LogicalPlan) Limit.apply(new org.apache.spark.sql.catalyst.expressions.Literal(
((TopAggregation) node).getResults().get().getValue(), org.apache.spark.sql.types.DataTypes.IntegerType), p));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package org.opensearch.sql.ppl.utils;

import org.apache.spark.sql.catalyst.TableIdentifier;
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.opensearch.sql.ast.expression.QualifiedName;
import scala.Option$;

import java.util.List;
import java.util.Optional;
Expand All @@ -15,7 +17,7 @@ public interface RelationUtils {
*
* @param relations
* @param node
* @param contextRelations
* @param tables
* @return
*/
static Optional<QualifiedName> resolveField(List<UnresolvedRelation> relations, QualifiedName node, List<LogicalPlan> tables) {
Expand All @@ -29,4 +31,26 @@ static Optional<QualifiedName> resolveField(List<UnresolvedRelation> relations,
.findFirst()
.map(rel -> node);
}

static TableIdentifier getTableIdentifier(QualifiedName qualifiedName) {
TableIdentifier identifier;
if (qualifiedName.getParts().isEmpty()) {
throw new IllegalArgumentException("Empty table name is invalid");
} else if (qualifiedName.getParts().size() == 1) {
identifier = new TableIdentifier(qualifiedName.getParts().get(0));
} else if (qualifiedName.getParts().size() == 2) {
identifier = new TableIdentifier(
qualifiedName.getParts().get(1),
Option$.MODULE$.apply(qualifiedName.getParts().get(0)));
} else if (qualifiedName.getParts().size() == 3) {
identifier = new TableIdentifier(
qualifiedName.getParts().get(2),
Option$.MODULE$.apply(qualifiedName.getParts().get(1)),
Option$.MODULE$.apply(qualifiedName.getParts().get(0)));
} else {
throw new IllegalArgumentException("Invalid table name: " + qualifiedName
+ " Syntax: [ database_name. ] table_name");
}
return identifier;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,26 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite
thrown.getMessage === "Invalid table name: t.b.c.d Syntax: [ database_name. ] table_name")
}

test("test describe with backticks") {
val context = new CatalystPlanContext
val logPlan =
planTransformer.visit(plan(pplParser, "describe t.b.`c.d`"), context)

val expectedPlan = DescribeTableCommand(
TableIdentifier("c.d", Option("b"), Option("t")),
Map.empty[String, String].empty,
isExtended = true,
output = DescribeRelation.getOutputAttrs)
comparePlans(expectedPlan, logPlan, false)
}

test("test describe FQN table clause") {
val context = new CatalystPlanContext
val logPlan =
planTransformer.visit(plan(pplParser, "describe schema.default.http_logs"), context)
planTransformer.visit(plan(pplParser, "describe catalog.schema.http_logs"), context)

val expectedPlan = DescribeTableCommand(
TableIdentifier("http_logs", Option("schema"), Option("default")),
TableIdentifier("http_logs", Option("schema"), Option("catalog")),
Map.empty[String, String].empty,
isExtended = true,
output = DescribeRelation.getOutputAttrs)
Expand All @@ -64,10 +77,10 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite

test("test FQN table describe table clause") {
val context = new CatalystPlanContext
val logPlan = planTransformer.visit(plan(pplParser, "describe catalog.t"), context)
val logPlan = planTransformer.visit(plan(pplParser, "describe schema.t"), context)

val expectedPlan = DescribeTableCommand(
TableIdentifier("t", Option("catalog")),
TableIdentifier("t", Option("schema")),
Map.empty[String, String].empty,
isExtended = true,
output = DescribeRelation.getOutputAttrs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,15 @@

package org.opensearch.flint.spark.ppl

import org.junit.Assert.assertEquals
import org.mockito.Mockito.when
import org.opensearch.flint.spark.ppl.PlaneUtils.plan
import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
import org.scalatest.matchers.should.Matchers
import org.scalatestplus.mockito.MockitoSugar.mock

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry, TableFunctionRegistry, UnresolvedAttribute, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Descending, Divide, EqualTo, Floor, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Like, Literal, NamedExpression, Not, Or, SortOrder, UnixTimestamp}
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{And, Ascending, EqualTo, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Not, Or, SortOrder}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

class PPLLogicalPlanFiltersTranslatorTestSuite
extends SparkFunSuite
Expand Down Expand Up @@ -219,4 +211,26 @@ class PPLLogicalPlanFiltersTranslatorTestSuite

comparePlans(expectedPlan, logPlan, false)
}

test("test order of evaluation of predicate expression") {
val context = new CatalystPlanContext
val logPlan = planTransformer.visit(
plan(
pplParser,
"source=employees | where department = 'HR' OR job_title = 'Manager' AND salary > 50000"),
context)

val table = UnresolvedRelation(Seq("employees"))
val filter =
Filter(
Or(
EqualTo(UnresolvedAttribute("department"), Literal("HR")),
And(
EqualTo(UnresolvedAttribute("job_title"), Literal("Manager")),
GreaterThan(UnresolvedAttribute("salary"), Literal(50000)))),
table)

val expectedPlan = Project(Seq(UnresolvedStar(None)), filter)
comparePlans(expectedPlan, logPlan, false)
}
}

0 comments on commit 6974171

Please sign in to comment.