From ffe81e5aa0c61993580cb2226a716e5c2e22f824 Mon Sep 17 00:00:00 2001
From: Lantao Jin
Date: Wed, 28 Aug 2024 16:15:57 +0800
Subject: [PATCH 1/2] Add UT and IT for 2+ level aggregations PPL command

Cover chained (multi-level) `stats` commands, optionally with `eval` or
`where` between the levels, in both the translator unit tests and the
Spark integration tests.

Signed-off-by: Lantao Jin
---
 .../FlintSparkPPLAggregationsITSuite.scala    | 209 +++++++++++++++++-
 ...ggregationQueriesTranslatorTestSuite.scala | 126 +++++++++++
 2 files changed, 334 insertions(+), 1 deletion(-)

diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala
index 7a3a886dd..55d3d0709 100644
--- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala
+++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala
@@ -7,7 +7,7 @@ package org.opensearch.flint.spark.ppl
 
 import org.apache.spark.sql.{QueryTest, Row}
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
-import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, EqualTo, LessThan, Literal, Not, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, EqualTo, GreaterThan, LessThan, Literal, Not, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.streaming.StreamTest
 
@@ -919,4 +919,211 @@ class FlintSparkPPLAggregationsITSuite
     // Compare the two plans
     assert(compareByString(expectedPlan) === compareByString(logicalPlan))
   }
+
+  test("two-level stats") {
+    // Stage 1: avg(age) per (state, country); stage 2: avg of those averages per country.
+    val frame = sql(s"""
+         | source = $testTable| stats avg(age) as avg_age by state, country | stats avg(avg_age) as avg_state_age by country
+         | """.stripMargin)
+
+    val results: Array[Row] = frame.collect()
+    val expectedResults: Array[Row] = Array(Row(22.5, "Canada"), Row(50.0, "USA"))
+
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0))
+    assert(results.sorted.sameElements(expectedResults.sorted))
+
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    val star = Seq(UnresolvedStar(None))
+    val stateField = UnresolvedAttribute("state")
+    val countryField = UnresolvedAttribute("country")
+    val ageField = UnresolvedAttribute("age")
+    val avgAgeField = UnresolvedAttribute("avg_age")
+    val stateAlias = Alias(stateField, "state")()
+    val countryAlias = Alias(countryField, "country")()
+    val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
+
+    val groupByAttributes1 = Seq(stateAlias, countryAlias)
+    val aggregateExpressions1 =
+      Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg_age")()
+    val aggregatePlan1 =
+      Aggregate(groupByAttributes1, Seq(aggregateExpressions1, stateAlias, countryAlias), table)
+
+    val groupByAttributes2 = Seq(countryAlias)
+    val aggregateExpressions2 =
+      Alias(
+        UnresolvedFunction(Seq("AVG"), Seq(avgAgeField), isDistinct = false),
+        "avg_state_age")()
+
+    val aggregatePlan2 =
+      Aggregate(groupByAttributes2, Seq(aggregateExpressions2, countryAlias), aggregatePlan1)
+    val expectedPlan = Project(star, aggregatePlan2)
+
+    comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
+  }
+
+  test("two-level stats with eval") {
+    // Same as "two-level stats", with an eval (avg_age - 10) between the two stats stages.
+    val frame = sql(s"""
+         | source = $testTable| stats avg(age) as avg_age by state, country | eval new_avg_age = avg_age - 10 | stats avg(new_avg_age) as avg_state_age by country
+         | """.stripMargin)
+
+    val results: Array[Row] = frame.collect()
+    val expectedResults: Array[Row] = Array(Row(12.5, "Canada"), Row(40.0, "USA"))
+
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0))
+    assert(results.sorted.sameElements(expectedResults.sorted))
+
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    val star = Seq(UnresolvedStar(None))
+    val stateField = UnresolvedAttribute("state")
+    val countryField = UnresolvedAttribute("country")
+    val ageField = UnresolvedAttribute("age")
+    val avgAgeField = UnresolvedAttribute("avg_age")
+    val stateAlias = Alias(stateField, "state")()
+    val countryAlias = Alias(countryField, "country")()
+    val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
+
+    val groupByAttributes1 = Seq(stateAlias, countryAlias)
+    val aggregateExpressions1 =
+      Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg_age")()
+    val aggregatePlan1 =
+      Aggregate(groupByAttributes1, Seq(aggregateExpressions1, stateAlias, countryAlias), table)
+
+    val newAvgAgeAlias =
+      Alias(
+        UnresolvedFunction(Seq("-"), Seq(avgAgeField, Literal(10)), isDistinct = false),
+        "new_avg_age")()
+    val evalProject = Project(Seq(UnresolvedStar(None), newAvgAgeAlias), aggregatePlan1)
+
+    val groupByAttributes2 = Seq(countryAlias)
+    val aggregateExpressions2 =
+      Alias(
+        UnresolvedFunction(
+          Seq("AVG"),
+          Seq(UnresolvedAttribute("new_avg_age")),
+          isDistinct = false),
+        "avg_state_age")()
+
+    val aggregatePlan2 =
+      Aggregate(groupByAttributes2, Seq(aggregateExpressions2, countryAlias), evalProject)
+    val expectedPlan = Project(star, aggregatePlan2)
+
+    comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
+  }
+
+  test("two-level stats with filter") {
+    // A where clause between the stats stages becomes a Filter over the first Aggregate.
+    val frame = sql(s"""
+         | source = $testTable| stats avg(age) as avg_age by country, state | where avg_age > 0 | stats count(avg_age) as count_state_age by country
+         | """.stripMargin)
+
+    val results: Array[Row] = frame.collect()
+    val expectedResults: Array[Row] = Array(Row(2L, "Canada"), Row(2L, "USA"))
+
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1))
+    assert(results.sorted.sameElements(expectedResults.sorted))
+
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    val star = Seq(UnresolvedStar(None))
+    val stateField = UnresolvedAttribute("state")
+    val countryField = UnresolvedAttribute("country")
+    val ageField = UnresolvedAttribute("age")
+    val avgAgeField = UnresolvedAttribute("avg_age")
+    val stateAlias = Alias(stateField, "state")()
+    val countryAlias = Alias(countryField, "country")()
+    val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
+
+    val groupByAttributes1 = Seq(countryAlias, stateAlias)
+    val aggregateExpressions1 =
+      Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg_age")()
+    val aggregatePlan1 =
+      Aggregate(groupByAttributes1, Seq(aggregateExpressions1, countryAlias, stateAlias), table)
+
+    val filterExpr = GreaterThan(avgAgeField, Literal(0))
+    val filterPlan = Filter(filterExpr, aggregatePlan1)
+
+    val groupByAttributes2 = Seq(countryAlias)
+    val aggregateExpressions2 =
+      Alias(
+        UnresolvedFunction(Seq("COUNT"), Seq(avgAgeField), isDistinct = false),
+        "count_state_age")()
+
+    val aggregatePlan2 =
+      Aggregate(groupByAttributes2, Seq(aggregateExpressions2, countryAlias), filterPlan)
+    val expectedPlan = Project(star, aggregatePlan2)
+
+    comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
+  }
+
+  test("three-level stats with eval and filter") {
+    // Three chained stats; note avg_age_divide_20 is computed as avg_age - 20 (a subtraction).
+    val frame = sql(s"""
+         | source = $testTable| stats avg(age) as avg_age by country, state, name | eval avg_age_divide_20 = avg_age - 20 | stats avg(avg_age_divide_20)
+         | as avg_state_age by country, state | where avg_state_age > 0 | stats count(avg_state_age) as count_country_age_greater_20 by country
+         | """.stripMargin)
+
+    val results: Array[Row] = frame.collect()
+    val expectedResults: Array[Row] = Array(Row(1L, "Canada"), Row(2L, "USA"))
+
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1))
+    assert(results.sorted.sameElements(expectedResults.sorted))
+
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    val star = Seq(UnresolvedStar(None))
+    val nameField = UnresolvedAttribute("name")
+    val stateField = UnresolvedAttribute("state")
+    val countryField = UnresolvedAttribute("country")
+    val ageField = UnresolvedAttribute("age")
+    val avgAgeField = UnresolvedAttribute("avg_age")
+    val nameAlias = Alias(nameField, "name")()
+    val stateAlias = Alias(stateField, "state")()
+    val countryAlias = Alias(countryField, "country")()
+    val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
+
+    val groupByAttributes1 = Seq(countryAlias, stateAlias, nameAlias)
+    val aggregateExpressions1 =
+      Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg_age")()
+    val aggregatePlan1 =
+      Aggregate(
+        groupByAttributes1,
+        Seq(aggregateExpressions1, countryAlias, stateAlias, nameAlias),
+        table)
+
+    val avg_age_divide_20_Alias =
+      Alias(
+        UnresolvedFunction(Seq("-"), Seq(avgAgeField, Literal(20)), isDistinct = false),
+        "avg_age_divide_20")()
+    val evalProject = Project(Seq(UnresolvedStar(None), avg_age_divide_20_Alias), aggregatePlan1)
+    val groupByAttributes2 = Seq(countryAlias, stateAlias)
+    val aggregateExpressions2 =
+      Alias(
+        UnresolvedFunction(
+          Seq("AVG"),
+          Seq(UnresolvedAttribute("avg_age_divide_20")),
+          isDistinct = false),
+        "avg_state_age")()
+    val aggregatePlan2 =
+      Aggregate(
+        groupByAttributes2,
+        Seq(aggregateExpressions2, countryAlias, stateAlias),
+        evalProject)
+
+    val filterExpr = GreaterThan(UnresolvedAttribute("avg_state_age"), Literal(0))
+    val filterPlan = Filter(filterExpr, aggregatePlan2)
+
+    val groupByAttributes3 = Seq(countryAlias)
+    val aggregateExpressions3 =
+      Alias(
+        UnresolvedFunction(
+          Seq("COUNT"),
+          Seq(UnresolvedAttribute("avg_state_age")),
+          isDistinct = false),
+        "count_country_age_greater_20")()
+
+    val aggregatePlan3 =
+      Aggregate(groupByAttributes3, Seq(aggregateExpressions3, countryAlias), filterPlan)
+    val expectedPlan = Project(star, aggregatePlan3)
+
+    comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
+  }
 }
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala
index e47ebedbf..e9fc7e5e1 100644
--- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala
@@ -876,4 +876,130 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
     comparePlans(expectedPlan, logPlan, false)
   }
 
+  test("multiple stats - test average price and average age") {
+    // Consecutive stats commands nest: the second Aggregate's child is the first Aggregate.
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(
+        plan(pplParser, "source = table | stats avg(price) | stats avg(age)", false),
+        context)
+    val star = Seq(UnresolvedStar(None))
+
+    val priceField = UnresolvedAttribute("price")
+    val ageField = UnresolvedAttribute("age")
+    val tableRelation = UnresolvedRelation(Seq("table"))
+    val aggregateExpressions1 = Seq(
+      Alias(UnresolvedFunction(Seq("AVG"), Seq(priceField), isDistinct = false), "avg(price)")())
+    val aggregatePlan1 = Aggregate(Seq(), aggregateExpressions1, tableRelation)
+    val aggregateExpressions2 =
+      Seq(Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")())
+    val aggregatePlan2 = Aggregate(Seq(), aggregateExpressions2, aggregatePlan1)
+    val expectedPlan = Project(star, aggregatePlan2)
+
+    comparePlans(expectedPlan, logPlan, false)
+  }
+
+  test("multiple stats - test average price and average age with Alias") {
+    // Same nesting as the un-aliased case; the explicit aliases name the aggregate outputs.
+    val context = new CatalystPlanContext
+    val logPlan = planTransformer.visit(
+      plan(
+        pplParser,
+        "source = table | stats avg(price) as avg_price | stats avg(age) as avg_age",
+        false),
+      context)
+    val star = Seq(UnresolvedStar(None))
+
+    val priceField = UnresolvedAttribute("price")
+    val ageField = UnresolvedAttribute("age")
+    val tableRelation = UnresolvedRelation(Seq("table"))
+    val aggregateExpressions1 = Seq(
+      Alias(UnresolvedFunction(Seq("AVG"), Seq(priceField), isDistinct = false), "avg_price")())
+    val aggregatePlan1 = Aggregate(Seq(), aggregateExpressions1, tableRelation)
+    val aggregateExpressions2 =
+      Seq(Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg_age")())
+    val aggregatePlan2 = Aggregate(Seq(), aggregateExpressions2, aggregatePlan1)
+    val expectedPlan = Project(star, aggregatePlan2)
+
+    comparePlans(expectedPlan, logPlan, false)
+  }
+
+  test(
+    "multiple stats - test average price group by product and average age by span of interval of 10 years") {
+    // Group-by field in the first stats, span(age, 10) bucket in the second.
+    val context = new CatalystPlanContext
+    val logPlan = planTransformer.visit(
+      plan(
+        pplParser,
+        "source = table | stats avg(price) by product | stats avg(age) by span(age, 10) as age_span",
+        false),
+      context)
+    val star = Seq(UnresolvedStar(None))
+    val productField = UnresolvedAttribute("product")
+    val priceField = UnresolvedAttribute("price")
+    val ageField = UnresolvedAttribute("age")
+    val tableRelation = UnresolvedRelation(Seq("table"))
+
+    val groupByAttributes = Seq(Alias(productField, "product")())
+    val aggregateExpressions1 =
+      Alias(UnresolvedFunction(Seq("AVG"), Seq(priceField), isDistinct = false), "avg(price)")()
+    val productAlias = Alias(productField, "product")()
+
+    val aggregatePlan1 =
+      Aggregate(groupByAttributes, Seq(aggregateExpressions1, productAlias), tableRelation)
+
+    val aggregateExpressions2 =
+      Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")()
+    val span = Alias(
+      Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)),
+      "age_span")()
+    val aggregatePlan2 = Aggregate(Seq(span), Seq(aggregateExpressions2, span), aggregatePlan1)
+
+    val expectedPlan = Project(star, aggregatePlan2)
+
+    comparePlans(expectedPlan, logPlan, false)
+  }
+
+  test("multiple levels stats") {
+    // The second stats aggregates avg_response_time, the alias produced by the first stats.
+    val context = new CatalystPlanContext
+    val logPlan = planTransformer.visit(
+      plan(
+        pplParser,
+        "source = table | stats avg(response_time) as avg_response_time by host, service | stats avg(avg_response_time) as avg_host_response_time by service",
+        false),
+      context)
+    val star = Seq(UnresolvedStar(None))
+    val hostField = UnresolvedAttribute("host")
+    val serviceField = UnresolvedAttribute("service")
+    val responseTimeField = UnresolvedAttribute("response_time")
+    val tableRelation = UnresolvedRelation(Seq("table"))
+    val hostAlias = Alias(hostField, "host")()
+    val serviceAlias = Alias(serviceField, "service")()
+
+    val groupByAttributes1 = Seq(hostAlias, serviceAlias)
+    val aggregateExpressions1 =
+      Alias(
+        UnresolvedFunction(Seq("AVG"), Seq(responseTimeField), isDistinct = false),
+        "avg_response_time")()
+    val aggregatePlan1 =
+      Aggregate(
+        groupByAttributes1,
+        Seq(aggregateExpressions1, hostAlias, serviceAlias),
+        tableRelation)
+
+    val avgResponseTimeField = UnresolvedAttribute("avg_response_time")
+    val groupByAttributes2 = Seq(serviceAlias)
+    val aggregateExpressions2 =
+      Alias(
+        UnresolvedFunction(Seq("AVG"), Seq(avgResponseTimeField), isDistinct = false),
+        "avg_host_response_time")()
+
+    val aggregatePlan2 =
+      Aggregate(groupByAttributes2, Seq(aggregateExpressions2, serviceAlias), aggregatePlan1)
+
+    val expectedPlan = Project(star, aggregatePlan2)
+
+    comparePlans(expectedPlan, logPlan, false)
+  }
 }

From cd0bdc628a46529106a72862caf809589134486a Mon Sep 17 00:00:00 2001
From: Lantao Jin
Date: Wed, 28 Aug 2024 16:27:37 +0800
Subject: [PATCH 2/2] doc: add multi-level aggregation examples to README

Signed-off-by: Lantao Jin
---
 ppl-spark-integration/README.md | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/ppl-spark-integration/README.md b/ppl-spark-integration/README.md
index 24639e444..9d293ae57 100644
--- a/ppl-spark-integration/README.md
+++ b/ppl-spark-integration/README.md
@@ -285,6 +285,10 @@ Limitation: Overriding existing field is unsupported, following queries throw ex
 - `source = table | stats sum(productsAmount) by span(transactionDate, 1d) as age_date | sort age_date`
 - `source = table | stats sum(productsAmount) by span(transactionDate, 1w) as age_date, productId`
 
+**Aggregations Group by Multiple Levels**
+- `source = table | stats avg(age) as avg_state_age by country, state | stats avg(avg_state_age) as avg_country_age by country`
+- `source = table | stats avg(age) as avg_city_age by country, state, city | eval new_avg_city_age = avg_city_age - 1 | stats avg(new_avg_city_age) as avg_state_age by country, state | where avg_state_age > 18 | stats avg(avg_state_age) as avg_adult_country_age by country`
+
 **Dedup**
 - `source = table | dedup a | fields a,b,c`
 - `source = table | dedup a,b | fields a,b,c`