From 157bbb7d024883e53689d54834766edb74577cd3 Mon Sep 17 00:00:00 2001 From: YANGDB Date: Wed, 13 Sep 2023 15:53:31 -0700 Subject: [PATCH] add `head` support add README.md details for supported commands and planned future support Signed-off-by: YANGDB --- .../flint/spark/FlintSparkPPLITSuite.scala | 152 +++++++++++++++++- ppl-spark-integration/README.md | 19 ++- .../sql/ppl/CatalystPlanContext.java | 9 ++ .../sql/ppl/CatalystQueryPlanVisitor.java | 15 +- ...lPlanBasicQueriesTranslatorTestSuite.scala | 14 +- 5 files changed, 200 insertions(+), 9 deletions(-) diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala index 09b3dbdd7..1786c676d 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala @@ -7,7 +7,7 @@ package org.opensearch.flint.spark import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Alias, And, Divide, EqualTo, Floor, GreaterThan, LessThan, LessThanOrEqual, Literal, Multiply, Not, Or} -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Limit, LogicalPlan, Project} import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.{QueryTest, Row} @@ -66,7 +66,7 @@ class FlintSparkPPLITSuite } } - test("create ppl simple query with start fields result test") { + test("create ppl simple query test") { val frame = sql( s""" | source = $testTable @@ -75,7 +75,6 @@ class FlintSparkPPLITSuite // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - // [John,25,Ontario,Canada,2023,4], [Jane,25,Quebec,Canada,2023,4], [Jake,70,California,USA,2023,4], [Hello,30,New York,USA,2023,4] val expectedResults: Array[Row] = Array( Row("Jake", 70, "California", "USA", 2023, 4), Row("Hello", 30, "New York", "USA", 2023, 4), @@ -95,6 +94,24 @@ class FlintSparkPPLITSuite assert(expectedPlan === logicalPlan) } + test("create ppl simple query with head (limit) 3 test") { + val frame = sql( + s""" + | source = $testTable | head 2 + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 2) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val expectedPlan: LogicalPlan = Limit(Literal(2), Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("default", "flint_ppl_test")))) + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + test("create ppl simple query two with fields result test") { val frame = sql( s""" @@ -124,6 +141,25 @@ class FlintSparkPPLITSuite assert(expectedPlan === logicalPlan) } + test("create ppl simple query two with fields and head (limit) test") { + val frame = sql( + s""" + | source = $testTable| fields name, age | head 1 + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 1) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val project = Project(Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), UnresolvedRelation(Seq("default", "flint_ppl_test"))) + // Define the expected logical plan + val expectedPlan: LogicalPlan = Limit(Literal(1), Project(Seq(UnresolvedStar(None)), project)) + // Compare the two plans + assert(expectedPlan === logicalPlan) + } + test("create ppl simple age literal equal filter query with two fields result test") { val frame = sql( s""" @@ -217,6 +253,30 @@ class FlintSparkPPLITSuite assert(expectedPlan === logicalPlan) } + test("create ppl simple age literal equal than filter OR country not equal filter query with two fields result and head (limit) test") { + val frame = sql( + s""" + | source = $testTable age<=20 OR country = 'USA' | fields name, age | head 1 + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 1) + + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + val filterExpr = Or(EqualTo(UnresolvedAttribute("country"), Literal("USA")), LessThanOrEqual(UnresolvedAttribute("age"), Literal(20))) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val projectPlan = Project(Seq(UnresolvedStar(None)), Project(projectList, filterPlan)) + val expectedPlan = Limit(Literal(1), projectPlan) + // Compare the two plans + assert(expectedPlan === logicalPlan) + } + test("create ppl simple age literal greater than filter query with two fields result test") { val frame = sql( s""" @@ -437,6 +497,35 @@ class FlintSparkPPLITSuite // Compare the two plans assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } + + test("create ppl simple age avg group by country head (limit) query test ") { + val frame = sql( + s""" + | source = $testTable| stats avg(age) by country | head 1 + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 1) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val aggregateExpressions = Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val productAlias = Alias(countryField, "country")() + + val aggregatePlan = Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val projectPlan = Project(Seq(UnresolvedStar(None)), aggregatePlan) + val expectedPlan = Limit(Literal(1), projectPlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } test("create ppl simple age max group by country query test ") { val frame = sql( @@ -564,7 +653,7 @@ class FlintSparkPPLITSuite ) // Compare the results - implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](0)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) assert( results.sorted.sameElements(expectedResults.sorted), s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}" @@ -721,6 +810,34 @@ class FlintSparkPPLITSuite // Compare the two plans assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } + + test("create ppl simple avg age by span of interval of 10 years with head (limit) query test ") { + val frame = sql( + s""" + | source = $testTable| stats avg(age) by span(age, 10) as age_span | head 2 + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 2) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + + val aggregateExpressions = Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val span = Alias( Multiply(Floor( Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), "span (age,10,NONE)")() + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table) + val projectPlan = Project(star, aggregatePlan) + val expectedPlan = Limit(Literal(2), projectPlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } /** * +--------+-------+-----------+ @@ -767,4 +884,31 @@ class FlintSparkPPLITSuite // Compare the two plans assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } + + ignore("create ppl average age by span of interval of 10 years group by country head (limit) 2 query test ") { + val frame = sql( + s""" + | source = $testTable | stats avg(age) by span(age, 10) as age_span, country | head 2 + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 2) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + + val aggregateExpressions = Alias(UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = false), "count(age)")() + val span = Alias( Multiply(Floor( Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), "span (age,10,NONE)")() + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table) + val projectPlan = Project(star, aggregatePlan) + val expectedPlan = Limit(Literal(1), projectPlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } } diff --git a/ppl-spark-integration/README.md b/ppl-spark-integration/README.md index a497fcd6d..01c101cff 100644 --- a/ppl-spark-integration/README.md +++ b/ppl-spark-integration/README.md @@ -231,21 +231,28 @@ The next samples of PPL queries are currently supported: - `source = table | where a >= 1 | fields a,b,c` - `source = table | where a < 1 | fields a,b,c` - `source = table | where b != 'test' | fields a,b,c` - - `source = table | where c = 'test' | fields a,b,c` + - `source = table | where c = 'test' | fields a,b,c | head 3` **Filters With Logical Conditions** - `source = table | where c = 'test' AND a = 1 | fields a,b,c` - - `source = table | where c != 'test' OR a > 1 | fields a,b,c` + - `source = table | where c != 'test' OR a > 1 | fields a,b,c | head 1` - `source = table | where c = 'test' NOT a > 1 | fields a,b,c` **Aggregations** - `source = table | stats avg(a) ` - `source = table | where a < 50 | stats avg(c) ` - `source = table | stats max(c) by b` + - `source = table | stats count(c) by b | head 5` **Aggregations With Span** - `source = table | stats count(a) by span(a, 10) as a_span` +#### Supported Commands: + - `search` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/search.rst) + - `where` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/where.rst) + - `fields` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/fields.rst) + - `head` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/head.rst) + - `stats` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/stats.rst) > For additional details review the next [Integration Test ](../integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala) @@ -253,4 +260,10 @@ The next samples of PPL queries are currently supported: #### Planned Support - - support the `explain` command to return the explained PPL query logical plan and expected execution plan \ No newline at end of file + - support the `explain` command to return the explained PPL query logical plan and expected execution plan + - add [sort](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/sort.rst) support + - add [conditions](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/functions/condition.rst) support + - add [top](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/top.rst) support + - add [cast](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/functions/conversion.rst) support + - add [math](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/functions/math.rst) support + - add [deduplicate](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/dedup.rst) support \ No newline at end of file diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java index 63a05440e..f85fe27bc 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java @@ -21,6 +21,7 @@ public class CatalystPlanContext { * Catalyst evolving logical plan **/ private Stack planBranches = new Stack<>(); + private int limit = Integer.MIN_VALUE; /** * NamedExpression contextual parameters @@ -48,6 +49,14 @@ public void with(LogicalPlan plan) { this.planBranches.push(plan); } + public void limit(int limit) { + this.limit = limit; + } + + public int getLimit() { + return limit; + } + public void plan(Function transformFunction) { this.planBranches.replaceAll(transformFunction::apply); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 039459150..20d117efa 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -19,6 +19,9 @@ import org.apache.spark.sql.catalyst.expressions.NamedExpression; import org.apache.spark.sql.catalyst.expressions.Predicate; import org.apache.spark.sql.catalyst.plans.logical.Aggregate; +import org.apache.spark.sql.catalyst.plans.logical.Limit; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.util.CaseInsensitiveStringMap; import org.opensearch.sql.ast.AbstractNodeVisitor; import org.opensearch.sql.ast.expression.AggregateFunction; @@ -27,6 +30,7 @@ import org.opensearch.sql.ast.expression.And; import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.Compare; +import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.Interval; @@ -204,8 +208,12 @@ public String visitProject(Project node, CatalystPlanContext context) { // Create a projection list from the existing expressions Seq projectList = asScalaBuffer(context.getNamedParseExpressions()).toSeq(); if (!projectList.isEmpty()) { + Seq namedExpressionSeq = asScalaBuffer(context.getNamedParseExpressions().stream() + .map(v -> (NamedExpression) v).collect(Collectors.toList())).toSeq(); // build the plan with the projection step - context.plan(p -> new org.apache.spark.sql.catalyst.plans.logical.Project((Seq) projectList, p)); + context.plan(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(namedExpressionSeq, p)); + //now remove all context.getNamedParseExpressions() + context.getNamedParseExpressions().retainAll(emptyList()); } if (node.hasArgument()) { Argument argument = node.getArgExprList().get(0); @@ -214,6 +222,10 @@ public String visitProject(Project node, CatalystPlanContext context) { arg = "-"; } } + if(context.getLimit() > 0) { + context.plan(p-> (LogicalPlan) Limit.apply(new org.apache.spark.sql.catalyst.expressions.Literal( + context.getLimit(), DataTypes.IntegerType), p)); + } return format("%s | fields %s %s", child, arg, fields); } @@ -259,6 +271,7 @@ public String visitDedupe(Dedupe node, CatalystPlanContext context) { public String visitHead(Head node, CatalystPlanContext context) { String child = node.getChild().get(0).accept(this, context); Integer size = node.getSize(); + context.limit(size); return format("%s | head %d", child, size); } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala index 517db2ec7..26e31b60c 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala @@ -7,7 +7,7 @@ package org.opensearch.flint.spark.ppl import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.NamedExpression +import org.apache.spark.sql.catalyst.expressions.{Literal, NamedExpression} import org.apache.spark.sql.catalyst.plans.logical._ import org.junit.Assert.assertEquals import org.opensearch.flint.spark.ppl.PlaneUtils.plan @@ -76,6 +76,18 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite assertEquals(expectedPlan, context.getPlan) assertEquals(logPlan, "source=[t] | fields + A,B") } + test("test simple search with only one table with two fields with head (limit ) command projected") { + val context = new CatalystPlanContext + val logPlan = planTrnasformer.visit(plan(pplParser, "source=t | fields A, B | head 5", false), context) + + + val table = UnresolvedRelation(Seq("t")) + val projectList = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B")) + val planWithLimit = Project(Seq(UnresolvedStar(None)), Project(projectList, table)) + val expectedPlan = GlobalLimit(Literal(5), LocalLimit(Literal(5), planWithLimit)) + assertEquals(expectedPlan, context.getPlan) + assertEquals(logPlan, "source=[t] | fields + A,B | head 5 | fields + *") + } test("Search multiple tables - translated into union call - fields expected to exist in both tables ") { val context = new CatalystPlanContext