diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala index a686b5835..5fc8c6745 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala @@ -7,7 +7,7 @@ package org.opensearch.flint.spark import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, GreaterThan, LessThanOrEqual, Literal, Not} -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project} import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.{QueryTest, Row} @@ -24,12 +24,15 @@ class FlintSparkPPLITSuite super.beforeAll() // Create test table + // Update table creation sql( s""" | CREATE TABLE $testTable | ( | name STRING, - | age INT + | age INT, + | state STRING, + | country STRING | ) | USING CSV | OPTIONS ( @@ -42,15 +45,15 @@ class FlintSparkPPLITSuite | ) |""".stripMargin) - // Insert data + // Update data insertion sql( s""" | INSERT INTO $testTable | PARTITION (year=2023, month=4) - | VALUES ('Jake', 70), - | ('Hello', 30), - | ('John', 25), - | ('Jane', 25) + | VALUES ('Jake', 70, 'California', 'USA'), + | ('Hello', 30, 'New York', 'USA'), + | ('John', 25, 'Ontario', 'Canada'), + | ('Jane', 25, 'Quebec', 'Canada') | """.stripMargin) } @@ -72,11 +75,12 @@ class FlintSparkPPLITSuite // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results +// [John,25,Ontario,Canada,2023,4], [Jane,25,Quebec,Canada,2023,4], [Jake,70,California,USA,2023,4], [Hello,30,New York,USA,2023,4] val expectedResults: Array[Row] = Array( - Row("Jake", 70, 2023, 4), - Row("Hello", 30, 2023, 4), - Row("John", 25, 2023, 4), - Row("Jane", 25, 2023, 4) + Row("Jake",70,"California","USA",2023,4), + Row("Hello",30,"New York","USA",2023,4), + Row("John",25,"Ontario","Canada",2023,4), + Row("Jane",25,"Quebec","Canada",2023,4) ) // Compare the results assert(results === expectedResults) @@ -286,14 +290,14 @@ class FlintSparkPPLITSuite assert(compareByString(aggregatePlan) === compareByString(logicalPlan)) } - ignore("create ppl simple age avg group by query test ") { - val checkData = sql(s"SELECT name, AVG(age) AS avg_age FROM $testTable group by name"); + test("create ppl simple age avg group by country query test ") { + val checkData = sql(s"SELECT country, AVG(age) AS avg_age FROM $testTable group by country"); checkData.show() checkData.queryExecution.logical.show() val frame = sql( s""" - | source = $testTable| stats avg(age) by name + | source = $testTable| stats avg(age) by country | """.stripMargin) @@ -301,21 +305,30 @@ class FlintSparkPPLITSuite val results: Array[Row] = frame.collect() // Define the expected results val expectedResults: Array[Row] = Array( - Row(37.5), + Row(25.0,"Canada"), + Row(50.0,"USA"), ) // Compare the results - assert(results === expectedResults) - + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical // Define the expected logical plan - val priceField = UnresolvedAttribute("price") + val star = Seq(UnresolvedStar(None)) + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val aggregateExpressions = Seq(Alias(UnresolvedFunction(Seq("AVG"), Seq(priceField), isDistinct = false), "avg(price)")()) - val aggregatePlan = Project( aggregateExpressions, table) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val aggregateExpressions = Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val productAlias = Alias(countryField, "country")() + + val aggregatePlan = Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val expectedPlan = Project(star, aggregatePlan) // Compare the two plans - assert(aggregatePlan === logicalPlan) + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 47085395e..471016e03 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -150,10 +150,13 @@ public String visitAggregation(Aggregation node, CatalystPlanContext context) { final String visitExpressionList = visitExpressionList(node.getAggExprList(), context); final String group = visitExpressionList(node.getGroupExprList(), context); - NamedExpression namedExpression = (NamedExpression) context.getNamedParseExpressions().peek(); - Seq namedExpressionSeq = asScalaBuffer(singletonList(namedExpression)).toSeq(); if(!isNullOrEmpty(group)) { + NamedExpression namedExpression = (NamedExpression) context.getNamedParseExpressions().peek(); + Seq namedExpressionSeq = asScalaBuffer(context.getNamedParseExpressions().stream() + .map(v->(NamedExpression)v).collect(Collectors.toList())).toSeq(); + //now remove all context.getNamedParseExpressions() + context.getNamedParseExpressions().retainAll(emptyList()); context.plan(p->new Aggregate(asScalaBuffer(singletonList((Expression) namedExpression)),namedExpressionSeq,p)); } return format( @@ -183,11 +186,12 @@ public String visitProject(Project node, CatalystPlanContext context) { String arg = "+"; String fields = visitExpressionList(node.getProjectList(), context); - // Create an UnresolvedStar for all-fields projection + // Create a projection list from the existing expressions Seq projectList = asScalaBuffer(context.getNamedParseExpressions()).toSeq(); - // Create a Project node with the UnresolvedStar - context.plan(p -> new org.apache.spark.sql.catalyst.plans.logical.Project((Seq) projectList, p)); - + if(!projectList.isEmpty()) { + // build the plan with the projection step + context.plan(p -> new org.apache.spark.sql.catalyst.plans.logical.Project((Seq) projectList, p)); + } if (node.hasArgument()) { Argument argument = node.getArgExprList().get(0); Boolean exclude = (Boolean) argument.getValue().getValue(); diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala index 7a3d1c243..473cbdd8a 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala @@ -6,7 +6,7 @@ package org.opensearch.flint.spark.ppl import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.Alias import org.apache.spark.sql.catalyst.plans.logical._ import org.junit.Assert.assertEquals @@ -37,23 +37,24 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite assertEquals(compareByString(aggregatePlan), compareByString(context.getPlan)) } - ignore("test average price group by product ") { + test("test average price group by product ") { // if successful build ppl logical plan and translate to catalyst logical plan val context = new CatalystPlanContext val logPlan = planTrnasformer.visit(plan(pplParser, "source = table | stats avg(price) by product", false), context) //SQL: SELECT product, AVG(price) AS avg_price FROM table GROUP BY product - + val star = Seq(UnresolvedStar(None)) val productField = UnresolvedAttribute("product") val priceField = UnresolvedAttribute("price") val tableRelation = UnresolvedRelation(Seq("table")) val groupByAttributes = Seq(Alias(productField, "product")()) - val aggregateExpressions = Seq(Alias(UnresolvedFunction(Seq("AVG"), Seq(priceField), isDistinct = false), "avg(price)")()) + val aggregateExpressions = Alias(UnresolvedFunction(Seq("AVG"), Seq(priceField), isDistinct = false), "avg(price)")() + val productAlias = Alias(productField, "product")() - val aggregatePlan = Aggregate(groupByAttributes, aggregateExpressions, tableRelation) - val expectedPlan = Project(Seq(productField), aggregatePlan) + val aggregatePlan = Aggregate(groupByAttributes, Seq(aggregateExpressions,productAlias), tableRelation) + val expectedPlan = Project(star, aggregatePlan) - assertEquals(logPlan, "source=[table] | stats avg(price) by product | fields + 'product AS product#1") + assertEquals(logPlan, "source=[table] | stats avg(price) by product | fields + *") assertEquals(compareByString(expectedPlan), compareByString(context.getPlan)) }