diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala index 93583bd89..70640d47b 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala @@ -6,8 +6,8 @@ package org.opensearch.flint.spark.ppl import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Ascending, Literal, SortOrder} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, Literal, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.command.DescribeTableCommand import org.apache.spark.sql.streaming.StreamTest @@ -45,143 +45,65 @@ class FlintSparkPPLTopAndRareITSuite // Retrieve the results val results: Array[Row] = frame.collect() - assert(results.length == 2) - + assert(results.length == 3) + + val expectedRow = Row(1, "Vancouver") + assert(results.head == expectedRow, s"Expected least frequent result to be $expectedRow, but got ${results.head}") + // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical // Define the expected logical plan - val limitPlan: LogicalPlan = - Limit(Literal(2), UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) - val expectedPlan = Project(Seq(UnresolvedStar(None)), limitPlan) - - // Compare the two plans - assert(compareByString(expectedPlan) === compareByString(logicalPlan)) - } - - test("create ppl simple query with head (limit) and sorted test") { - val frame = sql(s""" - | source = $testTable| sort name | head 2 - | """.stripMargin) - - // Retrieve the results - val results: Array[Row] = frame.collect() - assert(results.length == 2) - - // Retrieve the logical plan - val logicalPlan: LogicalPlan = frame.queryExecution.logical + val addressField = UnresolvedAttribute("address") + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + val aggregateExpressions = Seq( + Alias(UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), "count(address)")(), addressField) + val aggregatePlan = + Aggregate(Seq(addressField), aggregateExpressions, UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) val sortedPlan: LogicalPlan = Sort( - Seq(SortOrder(UnresolvedAttribute("name"), Ascending)), + Seq(SortOrder(UnresolvedAttribute("address"), Descending)), global = true, - UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) - - // Define the expected logical plan - val expectedPlan: LogicalPlan = - Project(Seq(UnresolvedStar(None)), Limit(Literal(2), sortedPlan)) - - // Compare the two plans - assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + aggregatePlan) + val expectedPlan = Project(projectList, sortedPlan) + comparePlans(expectedPlan, logicalPlan, false) } - - test("create ppl simple query two with fields result test") { + + test("create ppl top address field query test") { val frame = sql(s""" - | source = $testTable| fields name, age + | source = $testTable| top address | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results - val expectedResults: Array[Row] = - Array(Row("Jake", 70), Row("Hello", 30), Row("John", 25), Row("Jane", 20)) - // Compare the results - implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) - assert(results.sorted.sameElements(expectedResults.sorted)) + assert(results.length == 3) + + val expectedRows = Set( + Row(2, "Portland"), + Row(2, "Seattle") + ) + val actualRows = results.take(2).toSet + + // Compare the sets + assert(actualRows == expectedRows, + s"The first two results do not match the expected rows. Expected: $expectedRows, Actual: $actualRows") // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical // Define the expected logical plan - val expectedPlan: LogicalPlan = Project( - Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), - UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) - // Compare the two plans - assert(expectedPlan === logicalPlan) - } - - test("create ppl simple sorted query two with fields result test sorted") { - val frame = sql(s""" - | source = $testTable| sort age | fields name, age - | """.stripMargin) - - // Retrieve the results - val results: Array[Row] = frame.collect() - // Define the expected results - val expectedResults: Array[Row] = - Array(Row("Jane", 20), Row("John", 25), Row("Hello", 30), Row("Jake", 70)) - assert(results === expectedResults) - - // Retrieve the logical plan - val logicalPlan: LogicalPlan = frame.queryExecution.logical + val addressField = UnresolvedAttribute("address") + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + val aggregateExpressions = Seq( + Alias(UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), "count(address)")(), addressField) + val aggregatePlan = + Aggregate(Seq(addressField), aggregateExpressions, UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) val sortedPlan: LogicalPlan = Sort( - Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + Seq(SortOrder(UnresolvedAttribute("address"), Ascending)), global = true, - UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) - - // Define the expected logical plan - val expectedPlan: LogicalPlan = - Project(Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), sortedPlan) - - // Compare the two plans - assert(compareByString(expectedPlan) === compareByString(logicalPlan)) - } - - test("create ppl simple query two with fields and head (limit) test") { - val frame = sql(s""" - | source = $testTable| fields name, age | head 1 - | """.stripMargin) - - // Retrieve the results - val results: Array[Row] = frame.collect() - assert(results.length == 1) - - // Retrieve the logical plan - val logicalPlan: LogicalPlan = frame.queryExecution.logical - val project = Project( - Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), - UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) - // Define the expected logical plan - val limitPlan: LogicalPlan = Limit(Literal(1), project) - val expectedPlan: LogicalPlan = Project(Seq(UnresolvedStar(None)), limitPlan) - // Compare the two plans - assert(compareByString(expectedPlan) === compareByString(logicalPlan)) - } - - test("create ppl simple query two with fields and head (limit) with sorting test") { - Seq(("name, age", "age"), ("`name`, `age`", "`age`")).foreach { - case (selectFields, sortField) => - val frame = sql(s""" - | source = $testTable| fields $selectFields | head 1 | sort $sortField - | """.stripMargin) - - // Retrieve the results - val results: Array[Row] = frame.collect() - assert(results.length == 1) - - // Retrieve the logical plan - val logicalPlan: LogicalPlan = frame.queryExecution.logical - val project = Project( - Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), - UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) - // Define the expected logical plan - val limitPlan: LogicalPlan = Limit(Literal(1), project) - val sortedPlan: LogicalPlan = - Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, limitPlan) - - val expectedPlan = Project(Seq(UnresolvedStar(None)), sortedPlan) - // Compare the two plans - assert(compareByString(expectedPlan) === compareByString(logicalPlan)) - } + aggregatePlan) + val expectedPlan = Project(projectList, sortedPlan) + comparePlans(expectedPlan, logicalPlan, false) } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 7c112c41a..4c28354ba 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -196,7 +196,7 @@ public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext contex // set sort direction according to command type (`rare` is Asc, `top` is Desc, default to Asc) List sortDirections = new ArrayList<>(); - sortDirections.add(node instanceof RareAggregation ? Ascending$.MODULE$ : node instanceof TopAggregation ? Descending$.MODULE$ : Ascending$.MODULE$); + sortDirections.add(node instanceof RareAggregation ? Descending$.MODULE$ : Ascending$.MODULE$); if (!node.getSortExprList().isEmpty()) { visitExpressionList(node.getSortExprList(), context); diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala index b23d183fa..8dfded480 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala @@ -39,6 +39,32 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite genderField ) + val aggregatePlan = + Aggregate(Seq(genderField), aggregateExpressions, tableRelation) + + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("gender"), Descending)), + global = true, + aggregatePlan) + val expectedPlan = Project(projectList, sortedPlan) + comparePlans(expectedPlan, logPlan, false) + } + + test("test simple top command with a single field") { + // if successful build ppl logical plan and translate to catalyst logical plan + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=accounts | top gender", false), context) + val genderField = UnresolvedAttribute("gender") + val tableRelation = UnresolvedRelation(Seq("accounts")) + + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + + val aggregateExpressions = Seq( + Alias(UnresolvedFunction(Seq("COUNT"), Seq(genderField), isDistinct = false), "count(gender)")(), + genderField + ) + val aggregatePlan = Aggregate(Seq(genderField), aggregateExpressions, tableRelation)