From 036590ac241a088cc5ad384046d178d62b6ca488 Mon Sep 17 00:00:00 2001 From: YANGDB Date: Thu, 15 Aug 2024 20:27:50 -0700 Subject: [PATCH] add additional support for `rare` & `top` commands options including top N ... Signed-off-by: YANGDB --- .../ppl/FlintSparkPPLTopAndRareITSuite.scala | 50 +++++++++++++++++++ .../sql/ast/tree/TopAggregation.java | 9 ++++ .../sql/ppl/CatalystQueryPlanVisitor.java | 6 ++- .../opensearch/sql/ppl/parser/AstBuilder.java | 2 + ...TopAndRareQueriesTranslatorTestSuite.scala | 36 ++++++++++++- 5 files changed, 100 insertions(+), 3 deletions(-) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala index 1b54ef277..5ec14e3cf 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala @@ -162,4 +162,54 @@ class FlintSparkPPLTopAndRareITSuite val expectedPlan = Project(projectList, sortedPlan) comparePlans(expectedPlan, logicalPlan, false) } + + test("create ppl top 3 countries by occupation field query test") { + val newTestTable = "spark_catalog.default.new_flint_ppl_test" + createOccupationTable(newTestTable) + + val frame = sql(s""" + | source = $newTestTable| top 3 country by occupation + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 3) + + val expectedRows = Set(Row(1, "Canada", "Doctor"), Row(1, "Canada", "Scientist"), Row(1, "Canada", "Unemployed")) + val actualRows = results.take(3).toSet + + // Compare the sets + assert( + actualRows == expectedRows, + s"The first two results do not match the expected rows. Expected: $expectedRows, Actual: $actualRows") + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val countryField = UnresolvedAttribute("country") + val occupationField = UnresolvedAttribute("occupation") + val occupationFieldAlias = Alias(occupationField, "occupation")() + + val countExpr = Alias(UnresolvedFunction(Seq("COUNT"), Seq(countryField), isDistinct = false), "count(country)")() + val aggregateExpressions = Seq( + countExpr, + countryField, + occupationFieldAlias) + val aggregatePlan = + Aggregate( + Seq(countryField, occupationFieldAlias), + aggregateExpressions, + UnresolvedRelation(Seq("spark_catalog", "default", "new_flint_ppl_test"))) + + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("country"), Ascending)), + global = true, + aggregatePlan) + + val planWithLimit = + GlobalLimit(Literal(3), LocalLimit(Literal(3), sortedPlan)) + val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) + comparePlans(expectedPlan, logicalPlan, false) + } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/TopAggregation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/TopAggregation.java index 1aaa69dde..451446cc3 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/TopAggregation.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/TopAggregation.java @@ -5,19 +5,28 @@ package org.opensearch.sql.ast.tree; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.UnresolvedExpression; import java.util.Collections; import java.util.List; +import java.util.Optional; /** Logical plan node of Top (Aggregation) command, the interface for building aggregation actions in queries. */ public class TopAggregation extends Aggregation { + private final Optional results; + /** Aggregation Constructor without span and argument. */ public TopAggregation( + Optional results, List aggExprList, List sortExprList, List groupExprList) { super(aggExprList, sortExprList, groupExprList, null, Collections.emptyList()); + this.results = results; } + public Optional getResults() { + return results; + } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 4c28354ba..9d02d1596 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -191,7 +191,6 @@ public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext contex context.getGroupingParseExpressions().add(context.getNamedParseExpressions().peek()); } // build the aggregation logical step -// context.apply(p -> extractedAggregation(context)); TODO remove LogicalPlan logicalPlan = extractedAggregation(context); // set sort direction according to command type (`rare` is Asc, `top` is Desc, default to Asc) @@ -207,6 +206,11 @@ public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext contex seq(new ArrayList()))); context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Sort(sortElements, true, logicalPlan)); } + //visit TopAggregation results limit + if((node instanceof TopAggregation) && ((TopAggregation) node).getResults().isPresent()) { + context.apply(p ->(LogicalPlan) Limit.apply(new org.apache.spark.sql.catalyst.expressions.Literal( + ((TopAggregation) node).getResults().get().getValue(), org.apache.spark.sql.types.DataTypes.IntegerType), p)); + } return logicalPlan; } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 3a814ece9..7d91bbb7a 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -309,8 +309,10 @@ public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) ctx.fieldList().fieldExpression().forEach(field -> { sortListBuilder.add(internalVisitExpression(field)); }); + UnresolvedExpression unresolvedPlan = (ctx.number != null ? internalVisitExpression(ctx.number) : null); TopAggregation aggregation = new TopAggregation( + Optional.ofNullable((Literal) unresolvedPlan), aggListBuilder.build(), sortListBuilder.build(), groupListBuilder.build()); diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala index d1fc4441e..50e449432 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala @@ -8,13 +8,12 @@ package org.opensearch.flint.spark.ppl import org.opensearch.flint.spark.ppl.PlaneUtils.plan import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} import org.scalatest.matchers.should.Matchers - import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, Literal, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _} import org.apache.spark.sql.execution.command.DescribeTableCommand class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite @@ -114,4 +113,37 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite val expectedPlan = Project(projectList, sortedPlan) comparePlans(expectedPlan, logPlan, false) } + + test("test simple top 1 command by age field") { + // if successful build ppl logical plan and translate to catalyst logical plan + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=accounts | top 1 address by age", false), context) + + val addressField = UnresolvedAttribute("address") + val ageField = UnresolvedAttribute("age") + val ageAlias = Alias(ageField, "age")() + + val countExpr = Alias(UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), "count(address)")() + val aggregateExpressions = Seq( + countExpr, + addressField, + ageAlias) + val aggregatePlan = + Aggregate( + Seq(addressField, ageAlias), + aggregateExpressions, + UnresolvedRelation(Seq("accounts"))) + + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("address"), Ascending)), + global = true, + aggregatePlan) + + val planWithLimit = + GlobalLimit(Literal(1), LocalLimit(Literal(1), sortedPlan)) + val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) + comparePlans(expectedPlan, logPlan, false) + } }