Skip to content

Commit

Permalink
add head support
Browse files Browse the repository at this point in the history
add README.md details for supported commands and planned future support

Signed-off-by: YANGDB <[email protected]>
  • Loading branch information
YANG-DB committed Sep 13, 2023
1 parent eaa4e33 commit 157bbb7
Show file tree
Hide file tree
Showing 5 changed files with 200 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ package org.opensearch.flint.spark

import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Divide, EqualTo, Floor, GreaterThan, LessThan, LessThanOrEqual, Literal, Multiply, Not, Or}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Limit, LogicalPlan, Project}
import org.apache.spark.sql.streaming.StreamTest
import org.apache.spark.sql.{QueryTest, Row}

Expand Down Expand Up @@ -66,7 +66,7 @@ class FlintSparkPPLITSuite
}
}

test("create ppl simple query with start fields result test") {
test("create ppl simple query test") {
val frame = sql(
s"""
| source = $testTable
Expand All @@ -75,7 +75,6 @@ class FlintSparkPPLITSuite
// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
// [John,25,Ontario,Canada,2023,4], [Jane,25,Quebec,Canada,2023,4], [Jake,70,California,USA,2023,4], [Hello,30,New York,USA,2023,4]
val expectedResults: Array[Row] = Array(
Row("Jake", 70, "California", "USA", 2023, 4),
Row("Hello", 30, "New York", "USA", 2023, 4),
Expand All @@ -95,6 +94,24 @@ class FlintSparkPPLITSuite
assert(expectedPlan === logicalPlan)
}

test("create ppl simple query with head (limit) 2 test") {
  // NOTE: the name previously said "3" while the query and assertions use 2.
  val frame = sql(s"""
       | source = $testTable | head 2
       | """.stripMargin)

  // Retrieve the results; head 2 must cap the row count at two.
  val results: Array[Row] = frame.collect()
  assert(results.length == 2)

  // Retrieve the unresolved logical plan of the executed query.
  val logicalPlan: LogicalPlan = frame.queryExecution.logical
  // Expected plan: Limit(2) over a star projection of the source table.
  val expectedPlan: LogicalPlan = Limit(
    Literal(2),
    Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("default", "flint_ppl_test"))))
  // Compare by string form to ignore differing expression ids.
  assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("create ppl simple query two with fields result test") {
val frame = sql(
s"""
Expand Down Expand Up @@ -124,6 +141,25 @@ class FlintSparkPPLITSuite
assert(expectedPlan === logicalPlan)
}

test("create ppl simple query two with fields and head (limit) test") {
  val frame = sql(s"""
       | source = $testTable| fields name, age | head 1
       | """.stripMargin)

  // The head command must cap the result set at a single row.
  assert(frame.collect().length == 1)

  // Build the plan we expect the PPL translator to have produced:
  // Limit(1) over a star projection wrapping the name/age field projection.
  val relation = UnresolvedRelation(Seq("default", "flint_ppl_test"))
  val fieldsProjection =
    Project(Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), relation)
  val expectedPlan: LogicalPlan =
    Limit(Literal(1), Project(Seq(UnresolvedStar(None)), fieldsProjection))

  // The unresolved logical plan of the executed query must match it exactly.
  assert(expectedPlan === frame.queryExecution.logical)
}

test("create ppl simple age literal equal filter query with two fields result test") {
val frame = sql(
s"""
Expand Down Expand Up @@ -217,6 +253,30 @@ class FlintSparkPPLITSuite
assert(expectedPlan === logicalPlan)
}

test("create ppl simple age less than or equal filter OR country equal filter query with two fields result and head (limit) test") {
  // NOTE: the name previously said "equal than ... country not equal", which
  // contradicted the query (`age<=20 OR country = 'USA'`).
  val frame = sql(s"""
       | source = $testTable age<=20 OR country = 'USA' | fields name, age | head 1
       | """.stripMargin)

  // Retrieve the results; head 1 must cap the row count at one.
  val results: Array[Row] = frame.collect()
  assert(results.length == 1)

  // Retrieve the unresolved logical plan of the executed query.
  val logicalPlan: LogicalPlan = frame.queryExecution.logical
  // Expected plan: Limit(1) over Project(*) over the fields projection of the filtered table.
  // NOTE(review): the OR operands appear reversed relative to the query text —
  // presumably the translator emits (country = 'USA') OR (age <= 20); confirm if this fails.
  val table = UnresolvedRelation(Seq("default", "flint_ppl_test"))
  val filterExpr = Or(
    EqualTo(UnresolvedAttribute("country"), Literal("USA")),
    LessThanOrEqual(UnresolvedAttribute("age"), Literal(20)))
  val filterPlan = Filter(filterExpr, table)
  val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age"))
  val projectPlan = Project(Seq(UnresolvedStar(None)), Project(projectList, filterPlan))
  val expectedPlan = Limit(Literal(1), projectPlan)
  // Compare the two plans
  assert(expectedPlan === logicalPlan)
}

test("create ppl simple age literal greater than filter query with two fields result test") {
val frame = sql(
s"""
Expand Down Expand Up @@ -437,6 +497,35 @@ class FlintSparkPPLITSuite
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("create ppl simple age avg group by country head (limit) query test ") {
  val frame = sql(s"""
       | source = $testTable| stats avg(age) by country | head 1
       | """.stripMargin)

  // head 1 on the aggregation must yield exactly one grouped row.
  assert(frame.collect().length == 1)

  // Expected plan: Limit(1) over Project(*) over an Aggregate grouping by country.
  val relation = UnresolvedRelation(Seq("default", "flint_ppl_test"))
  val country = UnresolvedAttribute("country")
  val avgAge = Alias(
    UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("age")), isDistinct = false),
    "avg(age)")()
  val grouping = Seq(Alias(country, "country")())
  val aggregate = Aggregate(grouping, Seq(avgAge, Alias(country, "country")()), relation)
  val expectedPlan = Limit(Literal(1), Project(Seq(UnresolvedStar(None)), aggregate))

  // Compare by string form since Alias expression ids differ between plans.
  assert(compareByString(expectedPlan) === compareByString(frame.queryExecution.logical))
}

test("create ppl simple age max group by country query test ") {
val frame = sql(
Expand Down Expand Up @@ -564,7 +653,7 @@ class FlintSparkPPLITSuite
)

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](0))
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1))
assert(
results.sorted.sameElements(expectedResults.sorted),
s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}"
Expand Down Expand Up @@ -721,6 +810,34 @@ class FlintSparkPPLITSuite
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("create ppl simple avg age by span of interval of 10 years with head (limit) query test ") {
  val frame = sql(s"""
       | source = $testTable| stats avg(age) by span(age, 10) as age_span | head 2
       | """.stripMargin)

  // Retrieve the results; head 2 must cap the row count at two.
  val results: Array[Row] = frame.collect()
  assert(results.length == 2)

  // Retrieve the unresolved logical plan of the executed query.
  val logicalPlan: LogicalPlan = frame.queryExecution.logical
  // Expected plan: Limit(2) over Project(*) over an Aggregate keyed by the 10-year age span.
  // (Removed an unused `countryField` local that the query never references.)
  val star = Seq(UnresolvedStar(None))
  val ageField = UnresolvedAttribute("age")
  val table = UnresolvedRelation(Seq("default", "flint_ppl_test"))

  val aggregateExpressions =
    Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")()
  // span(age, 10) is modeled as floor(age / 10) * 10; the alias text mirrors the translator output.
  val span = Alias(
    Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)),
    "span (age,10,NONE)")()
  val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table)
  val projectPlan = Project(star, aggregatePlan)
  val expectedPlan = Limit(Literal(2), projectPlan)

  // Compare by string form to ignore differing expression ids.
  assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

/**
* +--------+-------+-----------+
Expand Down Expand Up @@ -767,4 +884,31 @@ class FlintSparkPPLITSuite
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

ignore("create ppl average age by span of interval of 10 years group by country head (limit) 2 query test ") {
  val frame = sql(s"""
       | source = $testTable | stats avg(age) by span(age, 10) as age_span, country | head 2
       | """.stripMargin)

  // Retrieve the results; head 2 must cap the row count at two.
  val results: Array[Row] = frame.collect()
  assert(results.length == 2)

  // Retrieve the unresolved logical plan of the executed query.
  val logicalPlan: LogicalPlan = frame.queryExecution.logical
  // Expected plan: Limit(2) over Project(*) over an Aggregate grouped by both the
  // 10-year age span and country. Fixed to match the query: AVG/"avg(age)" instead
  // of the previous COUNT/"count(age)", and Limit(2) instead of Limit(1) for `head 2`.
  val star = Seq(UnresolvedStar(None))
  val ageField = UnresolvedAttribute("age")
  val countryField = UnresolvedAttribute("country")
  val table = UnresolvedRelation(Seq("default", "flint_ppl_test"))

  val aggregateExpressions =
    Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")()
  val span = Alias(
    Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)),
    "span (age,10,NONE)")()
  val countryAlias = Alias(countryField, "country")()
  // NOTE(review): the grouping-key/output ordering the translator produces for
  // multi-key span grouping is unverified (test is ignored) — confirm when enabling.
  val aggregatePlan =
    Aggregate(Seq(span, countryAlias), Seq(aggregateExpressions, span, countryAlias), table)
  val projectPlan = Project(star, aggregatePlan)
  val expectedPlan = Limit(Literal(2), projectPlan)

  // Compare by string form to ignore differing expression ids.
  assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}
}
19 changes: 16 additions & 3 deletions ppl-spark-integration/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -231,26 +231,39 @@ The next samples of PPL queries are currently supported:
- `source = table | where a >= 1 | fields a,b,c`
- `source = table | where a < 1 | fields a,b,c`
- `source = table | where b != 'test' | fields a,b,c`
- `source = table | where c = 'test' | fields a,b,c`
- `source = table | where c = 'test' | fields a,b,c | head 3`

**Filters With Logical Conditions**
- `source = table | where c = 'test' AND a = 1 | fields a,b,c`
- `source = table | where c != 'test' OR a > 1 | fields a,b,c`
- `source = table | where c != 'test' OR a > 1 | fields a,b,c | head 1`
- `source = table | where c = 'test' NOT a > 1 | fields a,b,c`

**Aggregations**
- `source = table | stats avg(a) `
- `source = table | where a < 50 | stats avg(c) `
- `source = table | stats max(c) by b`
- `source = table | stats count(c) by b | head 5`

**Aggregations With Span**
- `source = table | stats count(a) by span(a, 10) as a_span`

#### Supported Commands:
- `search` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/search.rst)
- `where` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/where.rst)
- `fields` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/fields.rst)
- `head` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/head.rst)
- `stats` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/stats.rst)

> For additional details review the next [Integration Test ](../integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala)
---

#### Planned Support

- support the `explain` command to return the explained PPL query logical plan and expected execution plan
- support the `explain` command to return the explained PPL query logical plan and expected execution plan
- add [sort](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/sort.rst) support
- add [conditions](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/functions/condition.rst) support
- add [top](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/top.rst) support
- add [cast](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/functions/conversion.rst) support
- add [math](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/functions/math.rst) support
- add [deduplicate](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/dedup.rst) support
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public class CatalystPlanContext {
* Catalyst evolving logical plan
**/
private Stack<LogicalPlan> planBranches = new Stack<>();
private int limit = Integer.MIN_VALUE;

/**
* NamedExpression contextual parameters
Expand Down Expand Up @@ -48,6 +49,14 @@ public void with(LogicalPlan plan) {
this.planBranches.push(plan);
}

/**
 * Records the head/limit row count for the query being translated; consumed later
 * when the translator wraps the final plan in a Limit node.
 *
 * @param limit maximum number of rows requested by the head command
 */
public void limit(int limit) {
this.limit = limit;
}

/**
 * @return the recorded limit, or the {@code Integer.MIN_VALUE} sentinel when no
 *         head/limit command was seen (callers treat only values &gt; 0 as active)
 */
public int getLimit() {
return limit;
}

public void plan(Function<LogicalPlan, LogicalPlan> transformFunction) {
this.planBranches.replaceAll(transformFunction::apply);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
import org.apache.spark.sql.catalyst.expressions.NamedExpression;
import org.apache.spark.sql.catalyst.expressions.Predicate;
import org.apache.spark.sql.catalyst.plans.logical.Aggregate;
import org.apache.spark.sql.catalyst.plans.logical.Limit;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.AggregateFunction;
Expand All @@ -27,6 +30,7 @@
import org.opensearch.sql.ast.expression.And;
import org.opensearch.sql.ast.expression.Argument;
import org.opensearch.sql.ast.expression.Compare;
import org.opensearch.sql.ast.expression.DataType;
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.expression.Function;
import org.opensearch.sql.ast.expression.Interval;
Expand Down Expand Up @@ -204,8 +208,12 @@ public String visitProject(Project node, CatalystPlanContext context) {
// Create a projection list from the existing expressions
Seq<?> projectList = asScalaBuffer(context.getNamedParseExpressions()).toSeq();
if (!projectList.isEmpty()) {
Seq<NamedExpression> namedExpressionSeq = asScalaBuffer(context.getNamedParseExpressions().stream()
.map(v -> (NamedExpression) v).collect(Collectors.toList())).toSeq();
// build the plan with the projection step
context.plan(p -> new org.apache.spark.sql.catalyst.plans.logical.Project((Seq<NamedExpression>) projectList, p));
context.plan(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(namedExpressionSeq, p));
//now remove all context.getNamedParseExpressions()
context.getNamedParseExpressions().retainAll(emptyList());
}
if (node.hasArgument()) {
Argument argument = node.getArgExprList().get(0);
Expand All @@ -214,6 +222,10 @@ public String visitProject(Project node, CatalystPlanContext context) {
arg = "-";
}
}
if(context.getLimit() > 0) {
context.plan(p-> (LogicalPlan) Limit.apply(new org.apache.spark.sql.catalyst.expressions.Literal(
context.getLimit(), DataTypes.IntegerType), p));
}
return format("%s | fields %s %s", child, arg, fields);
}

Expand Down Expand Up @@ -259,6 +271,7 @@ public String visitDedupe(Dedupe node, CatalystPlanContext context) {
/**
 * Visits a PPL {@code head} command: records its size on the context (consumed by
 * {@code visitProject}, which wraps the plan in a Limit when the stored limit is &gt; 0)
 * and appends the command to the textual plan description.
 *
 * @param node    the head AST node carrying the requested row count
 * @param context shared translation context accumulating the Catalyst plan
 * @return the textual PPL representation of the visited subtree
 */
public String visitHead(Head node, CatalystPlanContext context) {
String child = node.getChild().get(0).accept(this, context);
Integer size = node.getSize();
context.limit(size);
return format("%s | head %d", child, size);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ package org.opensearch.flint.spark.ppl

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.NamedExpression
import org.apache.spark.sql.catalyst.expressions.{Literal, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical._
import org.junit.Assert.assertEquals
import org.opensearch.flint.spark.ppl.PlaneUtils.plan
Expand Down Expand Up @@ -76,6 +76,18 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "source=[t] | fields + A,B")
}
test("test simple search with only one table with two fields with head (limit ) command projected") {
  val context = new CatalystPlanContext
  val logPlan =
    planTrnasformer.visit(plan(pplParser, "source=t | fields A, B | head 5", false), context)

  // Expected: head 5 becomes a global+local limit over the star projection of fields A, B.
  val fieldsProjected =
    Project(Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B")), UnresolvedRelation(Seq("t")))
  val starProjected = Project(Seq(UnresolvedStar(None)), fieldsProjected)
  val expectedPlan = GlobalLimit(Literal(5), LocalLimit(Literal(5), starProjected))

  assertEquals(expectedPlan, context.getPlan)
  assertEquals(logPlan, "source=[t] | fields + A,B | head 5 | fields + *")
}

test("Search multiple tables - translated into union call - fields expected to exist in both tables ") {
val context = new CatalystPlanContext
Expand Down

0 comments on commit 157bbb7

Please sign in to comment.