diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala index 2418ecc14..93583bd89 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala @@ -40,7 +40,7 @@ class FlintSparkPPLTopAndRareITSuite test("create ppl rare address field query test") { val frame = sql(s""" - | source = $testTable| rare address" + | source = $testTable| rare address | """.stripMargin) // Retrieve the results diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 7eba00c94..7c112c41a 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -179,7 +179,6 @@ public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext contex node.getChild().get(0).accept(this, context); List aggsExpList = visitExpressionList(node.getAggExprList(), context); List groupExpList = visitExpressionList(node.getGroupExprList(), context); - List sortExpList = visitExpressionList(node.getSortExprList(), context); if (!groupExpList.isEmpty()) { //add group by fields to context context.getGroupingParseExpressions().addAll(groupExpList); @@ -199,7 +198,7 @@ public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext contex List sortDirections = new ArrayList<>(); sortDirections.add(node instanceof RareAggregation ? Ascending$.MODULE$ : node instanceof TopAggregation ? 
Descending$.MODULE$ : Ascending$.MODULE$); - if (!sortExpList.isEmpty()) { + if (!node.getSortExprList().isEmpty()) { visitExpressionList(node.getSortExprList(), context); Seq sortElements = context.retainAllNamedParseExpressions(exp -> new SortOrder(exp, diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index ca988a7d8..3a814ece9 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -277,14 +277,20 @@ public UnresolvedPlan visitPatternsCommand(OpenSearchPPLParser.PatternsCommandCo @Override public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) { ImmutableList.Builder aggListBuilder = new ImmutableList.Builder<>(); + ImmutableList.Builder groupListBuilder = new ImmutableList.Builder<>(); + ImmutableList.Builder sortListBuilder = new ImmutableList.Builder<>(); ctx.fieldList().fieldExpression().forEach(field -> { UnresolvedExpression aggExpression = new AggregateFunction("count",internalVisitExpression(field), Collections.singletonList(new Argument("countParam", new Literal(1, DataType.INTEGER)))); String name = field.qualifiedName().getText(); - Alias alias = new Alias(name, aggExpression); + Alias alias = new Alias("count("+name+")", aggExpression); aggListBuilder.add(alias); + // group by the `field-list` as the mandatory groupBy fields + groupListBuilder.add(internalVisitExpression(field)); }); - List groupList = + + // group by the `by-clause` as the optional groupBy fields + groupListBuilder.addAll( Optional.ofNullable(ctx.byClause()) .map(OpenSearchPPLParser.ByClauseContext::fieldList) .map( @@ -297,15 +303,17 @@ public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) getTextInQuery(groupCtx), internalVisitExpression(groupCtx))) 
.collect(Collectors.toList())) - .orElse(emptyList()); - - - + .orElse(emptyList()) + ); + //build the sort fields + ctx.fieldList().fieldExpression().forEach(field -> { + sortListBuilder.add(internalVisitExpression(field)); + }); TopAggregation aggregation = new TopAggregation( aggListBuilder.build(), - emptyList(), - groupList); + sortListBuilder.build(), + groupListBuilder.build()); return aggregation; } @@ -313,15 +321,20 @@ public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) @Override public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ctx) { ImmutableList.Builder aggListBuilder = new ImmutableList.Builder<>(); + ImmutableList.Builder groupListBuilder = new ImmutableList.Builder<>(); ImmutableList.Builder sortListBuilder = new ImmutableList.Builder<>(); ctx.fieldList().fieldExpression().forEach(field -> { UnresolvedExpression aggExpression = new AggregateFunction("count",internalVisitExpression(field), Collections.singletonList(new Argument("countParam", new Literal(1, DataType.INTEGER)))); String name = field.qualifiedName().getText(); - Alias alias = new Alias(name, aggExpression); + Alias alias = new Alias("count("+name+")", aggExpression); aggListBuilder.add(alias); + // group by the `field-list` as the mandatory groupBy fields + groupListBuilder.add(internalVisitExpression(field)); }); - List groupList = + + // group by the `by-clause` as the optional groupBy fields + groupListBuilder.addAll( Optional.ofNullable(ctx.byClause()) .map(OpenSearchPPLParser.ByClauseContext::fieldList) .map( @@ -334,7 +347,8 @@ public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ct getTextInQuery(groupCtx), internalVisitExpression(groupCtx))) .collect(Collectors.toList())) - .orElse(emptyList()); + .orElse(emptyList()) + ); //build the sort fields ctx.fieldList().fieldExpression().forEach(field -> { sortListBuilder.add(internalVisitExpression(field)); @@ -343,9 +357,8 @@ public UnresolvedPlan 
visitRareCommand(OpenSearchPPLParser.RareCommandContext ct new RareAggregation( aggListBuilder.build(), sortListBuilder.build(), - groupList); + groupListBuilder.build()); return aggregation; - } /** From clause. */ diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala index ba634cc1c..7db9d1af2 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala @@ -41,6 +41,47 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + + test("test count price") { + // if successful build ppl logical plan and translate to catalyst logical plan + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source = table | stats count(price) ", false), context) + // SQL: SELECT count(price) FROM table + val star = Seq(UnresolvedStar(None)) + + val priceField = UnresolvedAttribute("price") + val tableRelation = UnresolvedRelation(Seq("table")) + val aggregateExpressions = Seq( + Alias(UnresolvedFunction(Seq("COUNT"), Seq(priceField), isDistinct = false), "count(price)")()) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, tableRelation) + val expectedPlan = Project(star, aggregatePlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test count price by country") { + // if successful build ppl logical plan and translate to catalyst logical plan + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source = table | stats count(price) by product ", false), context) + // SQL: SELECT count(price) AS 
count_price FROM table GROUP BY product + val star = Seq(UnresolvedStar(None)) + val productField = UnresolvedAttribute("product") + val priceField = UnresolvedAttribute("price") + val tableRelation = UnresolvedRelation(Seq("table")) + + val groupByAttributes = Seq(Alias(productField, "product")()) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("COUNT"), Seq(priceField), isDistinct = false), "count(price)")() + val productAlias = Alias(productField, "product")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), tableRelation) + val expectedPlan = Project(star, aggregatePlan) + + comparePlans(expectedPlan, logPlan, false) + } test("test average price with Alias") { // if successful build ppl logical plan and translate to catalyst logical plan diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala index f33cc3000..b23d183fa 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala @@ -7,8 +7,8 @@ package org.opensearch.flint.spark.ppl import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Literal, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, Literal, NamedExpression, SortOrder} import 
org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.command.DescribeTableCommand @@ -29,149 +29,25 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite // if successful build ppl logical plan and translate to catalyst logical plan val context = new CatalystPlanContext val logPlan = planTransformer.visit(plan(pplParser, "source=accounts | rare gender", false), context) + val genderField = UnresolvedAttribute("gender") + val tableRelation = UnresolvedRelation(Seq("accounts")) val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) - val expectedPlan = Project(projectList, UnresolvedRelation(Seq("accounts"))) - comparePlans(expectedPlan, logPlan, false) - } - - test("test simple search with escaped table name") { - // if successful build ppl logical plan and translate to catalyst logical plan - val context = new CatalystPlanContext - val logPlan = planTransformer.visit(plan(pplParser, "source=`table`", false), context) - - val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) - val expectedPlan = Project(projectList, UnresolvedRelation(Seq("table"))) - comparePlans(expectedPlan, logPlan, false) - } - - test("test simple search with schema.table and no explicit fields (defaults to all fields)") { - // if successful build ppl logical plan and translate to catalyst logical plan - val context = new CatalystPlanContext - val logPlan = planTransformer.visit(plan(pplParser, "source=schema.table", false), context) - - val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) - val expectedPlan = Project(projectList, UnresolvedRelation(Seq("schema", "table"))) - comparePlans(expectedPlan, logPlan, false) - - } - - test("test simple search with schema.table and one field projected") { - val context = new CatalystPlanContext - val logPlan = - planTransformer.visit(plan(pplParser, "source=schema.table | fields A", false), context) - - val projectList: 
Seq[NamedExpression] = Seq(UnresolvedAttribute("A")) - val expectedPlan = Project(projectList, UnresolvedRelation(Seq("schema", "table"))) - comparePlans(expectedPlan, logPlan, false) - } - - test("test simple search with only one table with one field projected") { - val context = new CatalystPlanContext - val logPlan = - planTransformer.visit(plan(pplParser, "source=table | fields A", false), context) - - val projectList: Seq[NamedExpression] = Seq(UnresolvedAttribute("A")) - val expectedPlan = Project(projectList, UnresolvedRelation(Seq("table"))) - comparePlans(expectedPlan, logPlan, false) - } - - test("test simple search with only one table with two fields projected") { - val context = new CatalystPlanContext - val logPlan = planTransformer.visit(plan(pplParser, "source=t | fields A, B", false), context) - - val table = UnresolvedRelation(Seq("t")) - val projectList = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B")) - val expectedPlan = Project(projectList, table) - comparePlans(expectedPlan, logPlan, false) - } - - test("test simple search with one table with two fields projected sorted by one field") { - val context = new CatalystPlanContext - val logPlan = - planTransformer.visit(plan(pplParser, "source=t | sort A | fields A, B", false), context) - - val table = UnresolvedRelation(Seq("t")) - val projectList = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B")) - // Sort by A ascending - val sortOrder = Seq(SortOrder(UnresolvedAttribute("A"), Ascending)) - val sorted = Sort(sortOrder, true, table) - val expectedPlan = Project(projectList, sorted) - - comparePlans(expectedPlan, logPlan, false) - } - - test( - "test simple search with only one table with two fields with head (limit ) command projected") { - val context = new CatalystPlanContext - val logPlan = - planTransformer.visit(plan(pplParser, "source=t | fields A, B | head 5", false), context) - - val table = UnresolvedRelation(Seq("t")) - val projectList = Seq(UnresolvedAttribute("A"), 
UnresolvedAttribute("B")) - val planWithLimit = - GlobalLimit(Literal(5), LocalLimit(Literal(5), Project(projectList, table))) - val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) - comparePlans(expectedPlan, logPlan, false) - } - - test( - "test simple search with only one table with two fields with head (limit ) command projected sorted by one descending field") { - val context = new CatalystPlanContext - val logPlan = planTransformer.visit( - plan(pplParser, "source=t | sort - A | fields A, B | head 5", false), - context) - - val table = UnresolvedRelation(Seq("t")) - val sortOrder = Seq(SortOrder(UnresolvedAttribute("A"), Descending)) - val sorted = Sort(sortOrder, true, table) - val projectList = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B")) - val projectAB = Project(projectList, sorted) - - val planWithLimit = GlobalLimit(Literal(5), LocalLimit(Literal(5), projectAB)) - val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) - comparePlans(expectedPlan, logPlan, false) - } - - test( - "Search multiple tables - translated into union call - fields expected to exist in both tables ") { - val context = new CatalystPlanContext - val logPlan = planTransformer.visit( - plan(pplParser, "search source = table1, table2 | fields A, B", false), - context) - - val table1 = UnresolvedRelation(Seq("table1")) - val table2 = UnresolvedRelation(Seq("table2")) - - val allFields1 = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B")) - val allFields2 = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B")) - - val projectedTable1 = Project(allFields1, table1) - val projectedTable2 = Project(allFields2, table2) - - val expectedPlan = - Union(Seq(projectedTable1, projectedTable2), byName = true, allowMissingCol = true) - - comparePlans(expectedPlan, logPlan, false) - } - - test("Search multiple tables - translated into union call with fields") { - val context = new CatalystPlanContext - val logPlan = - 
planTransformer.visit(plan(pplParser, "source = table1, table2 ", false), context) - - val table1 = UnresolvedRelation(Seq("table1")) - val table2 = UnresolvedRelation(Seq("table2")) - - val allFields1 = UnresolvedStar(None) - val allFields2 = UnresolvedStar(None) - val projectedTable1 = Project(Seq(allFields1), table1) - val projectedTable2 = Project(Seq(allFields2), table2) + val aggregateExpressions = Seq( + Alias(UnresolvedFunction(Seq("COUNT"), Seq(genderField), isDistinct = false), "count(gender)")(), + genderField + ) - val expectedPlan = - Union(Seq(projectedTable1, projectedTable2), byName = true, allowMissingCol = true) + val aggregatePlan = + Aggregate(Seq(genderField), aggregateExpressions, tableRelation) + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("gender"), Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project(projectList, sortedPlan) comparePlans(expectedPlan, logPlan, false) } }