Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add UT and IT for 2+ level aggregations PPL command #603

Merged
merged 2 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ package org.opensearch.flint.spark.ppl

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, EqualTo, LessThan, Literal, Not, SortOrder}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, EqualTo, GreaterThan, LessThan, Literal, Not, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.streaming.StreamTest

Expand Down Expand Up @@ -919,4 +919,207 @@ class FlintSparkPPLAggregationsITSuite
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("two-level stats") {
val frame = sql(s"""
| source = $testTable| stats avg(age) as avg_age by state, country | stats avg(avg_age) as avg_state_age by country
| """.stripMargin)

val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] = Array(Row(22.5, "Canada"), Row(50.0, "USA"))

implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0))
assert(results.sorted.sameElements(expectedResults.sorted))

val logicalPlan: LogicalPlan = frame.queryExecution.logical
val star = Seq(UnresolvedStar(None))
val stateField = UnresolvedAttribute("state")
val countryField = UnresolvedAttribute("country")
val ageField = UnresolvedAttribute("age")
val avgAgeField = UnresolvedAttribute("avg_age")
val stateAlias = Alias(stateField, "state")()
val countryAlias = Alias(countryField, "country")()
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

val groupByAttributes1 = Seq(stateAlias, countryAlias)
val aggregateExpressions1 =
Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg_age")()
val aggregatePlan1 =
Aggregate(groupByAttributes1, Seq(aggregateExpressions1, stateAlias, countryAlias), table)

val groupByAttributes2 = Seq(countryAlias)
val aggregateExpressions2 =
Alias(
UnresolvedFunction(Seq("AVG"), Seq(avgAgeField), isDistinct = false),
"avg_state_age")()

val aggregatePlan2 =
Aggregate(groupByAttributes2, Seq(aggregateExpressions2, countryAlias), aggregatePlan1)
val expectedPlan = Project(star, aggregatePlan2)

comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("two-level stats with eval") {
val frame = sql(s"""
| source = $testTable| stats avg(age) as avg_age by state, country | eval new_avg_age = avg_age - 10 | stats avg(new_avg_age) as avg_state_age by country
| """.stripMargin)

val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] = Array(Row(12.5, "Canada"), Row(40.0, "USA"))

implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0))
assert(results.sorted.sameElements(expectedResults.sorted))

val logicalPlan: LogicalPlan = frame.queryExecution.logical
val star = Seq(UnresolvedStar(None))
val stateField = UnresolvedAttribute("state")
val countryField = UnresolvedAttribute("country")
val ageField = UnresolvedAttribute("age")
val avgAgeField = UnresolvedAttribute("avg_age")
val stateAlias = Alias(stateField, "state")()
val countryAlias = Alias(countryField, "country")()
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

val groupByAttributes1 = Seq(stateAlias, countryAlias)
val aggregateExpressions1 =
Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg_age")()
val aggregatePlan1 =
Aggregate(groupByAttributes1, Seq(aggregateExpressions1, stateAlias, countryAlias), table)

val newAvgAgeAlias =
Alias(
UnresolvedFunction(Seq("-"), Seq(avgAgeField, Literal(10)), isDistinct = false),
"new_avg_age")()
val evalProject = Project(Seq(UnresolvedStar(None), newAvgAgeAlias), aggregatePlan1)

val groupByAttributes2 = Seq(countryAlias)
val aggregateExpressions2 =
Alias(
UnresolvedFunction(
Seq("AVG"),
Seq(UnresolvedAttribute("new_avg_age")),
isDistinct = false),
"avg_state_age")()

val aggregatePlan2 =
Aggregate(groupByAttributes2, Seq(aggregateExpressions2, countryAlias), evalProject)
val expectedPlan = Project(star, aggregatePlan2)

comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("two-level stats with filter") {
val frame = sql(s"""
| source = $testTable| stats avg(age) as avg_age by country, state | where avg_age > 0 | stats count(avg_age) as count_state_age by country
| """.stripMargin)

val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] = Array(Row(2L, "Canada"), Row(2L, "USA"))

implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1))
assert(results.sorted.sameElements(expectedResults.sorted))

val logicalPlan: LogicalPlan = frame.queryExecution.logical
val star = Seq(UnresolvedStar(None))
val stateField = UnresolvedAttribute("state")
val countryField = UnresolvedAttribute("country")
val ageField = UnresolvedAttribute("age")
val avgAgeField = UnresolvedAttribute("avg_age")
val stateAlias = Alias(stateField, "state")()
val countryAlias = Alias(countryField, "country")()
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

val groupByAttributes1 = Seq(countryAlias, stateAlias)
val aggregateExpressions1 =
Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg_age")()
val aggregatePlan1 =
Aggregate(groupByAttributes1, Seq(aggregateExpressions1, countryAlias, stateAlias), table)

val filterExpr = GreaterThan(avgAgeField, Literal(0))
val filterPlan = Filter(filterExpr, aggregatePlan1)

val groupByAttributes2 = Seq(countryAlias)
val aggregateExpressions2 =
Alias(
UnresolvedFunction(Seq("COUNT"), Seq(avgAgeField), isDistinct = false),
"count_state_age")()

val aggregatePlan2 =
Aggregate(groupByAttributes2, Seq(aggregateExpressions2, countryAlias), filterPlan)
val expectedPlan = Project(star, aggregatePlan2)

comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("three-level stats with eval and filter") {
val frame = sql(s"""
| source = $testTable| stats avg(age) as avg_age by country, state, name | eval avg_age_divide_20 = avg_age - 20 | stats avg(avg_age_divide_20)
| as avg_state_age by country, state | where avg_state_age > 0 | stats count(avg_state_age) as count_country_age_greater_20 by country
| """.stripMargin)

val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] = Array(Row(1L, "Canada"), Row(2L, "USA"))

implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1))
assert(results.sorted.sameElements(expectedResults.sorted))

val logicalPlan: LogicalPlan = frame.queryExecution.logical
val star = Seq(UnresolvedStar(None))
val nameField = UnresolvedAttribute("name")
val stateField = UnresolvedAttribute("state")
val countryField = UnresolvedAttribute("country")
val ageField = UnresolvedAttribute("age")
val avgAgeField = UnresolvedAttribute("avg_age")
val nameAlias = Alias(nameField, "name")()
val stateAlias = Alias(stateField, "state")()
val countryAlias = Alias(countryField, "country")()
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

val groupByAttributes1 = Seq(countryAlias, stateAlias, nameAlias)
val aggregateExpressions1 =
Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg_age")()
val aggregatePlan1 =
Aggregate(
groupByAttributes1,
Seq(aggregateExpressions1, countryAlias, stateAlias, nameAlias),
table)

val avg_age_divide_20_Alias =
Alias(
UnresolvedFunction(Seq("-"), Seq(avgAgeField, Literal(20)), isDistinct = false),
"avg_age_divide_20")()
val evalProject = Project(Seq(UnresolvedStar(None), avg_age_divide_20_Alias), aggregatePlan1)
val groupByAttributes2 = Seq(countryAlias, stateAlias)
val aggregateExpressions2 =
Alias(
UnresolvedFunction(
Seq("AVG"),
Seq(UnresolvedAttribute("avg_age_divide_20")),
isDistinct = false),
"avg_state_age")()
val aggregatePlan2 =
Aggregate(
groupByAttributes2,
Seq(aggregateExpressions2, countryAlias, stateAlias),
evalProject)

val filterExpr = GreaterThan(UnresolvedAttribute("avg_state_age"), Literal(0))
val filterPlan = Filter(filterExpr, aggregatePlan2)

val groupByAttributes3 = Seq(countryAlias)
val aggregateExpressions3 =
Alias(
UnresolvedFunction(
Seq("COUNT"),
Seq(UnresolvedAttribute("avg_state_age")),
isDistinct = false),
"count_country_age_greater_20")()

val aggregatePlan3 =
Aggregate(groupByAttributes3, Seq(aggregateExpressions3, countryAlias), filterPlan)
val expectedPlan = Project(star, aggregatePlan3)

comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}
}
4 changes: 4 additions & 0 deletions ppl-spark-integration/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,10 @@ Limitation: Overriding existing field is unsupported, following queries throw ex
- `source = table | stats sum(productsAmount) by span(transactionDate, 1d) as age_date | sort age_date`
- `source = table | stats sum(productsAmount) by span(transactionDate, 1w) as age_date, productId`

**Aggregations Group by Multiple Levels**
- `source = table | stats avg(age) as avg_state_age by country, state | stats avg(avg_state_age) as avg_country_age by country`
- `source = table | stats avg(age) as avg_city_age by country, state, city | eval new_avg_city_age = avg_city_age - 1 | stats avg(new_avg_city_age) as avg_state_age by country, state | where avg_state_age > 18 | stats avg(avg_state_age) as avg_adult_country_age by country`

**Dedup**
- `source = table | dedup a | fields a,b,c`
- `source = table | dedup a,b | fields a,b,c`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -876,4 +876,128 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
comparePlans(expectedPlan, logPlan, false)
}

test("multiple stats - test average price and average age") {
val context = new CatalystPlanContext
val logPlan =
planTransformer.visit(
plan(pplParser, "source = table | stats avg(price) | stats avg(age)", false),
context)
val star = Seq(UnresolvedStar(None))

val priceField = UnresolvedAttribute("price")
val ageField = UnresolvedAttribute("age")
val tableRelation = UnresolvedRelation(Seq("table"))
val aggregateExpressions1 = Seq(
Alias(UnresolvedFunction(Seq("AVG"), Seq(priceField), isDistinct = false), "avg(price)")())
val aggregatePlan1 = Aggregate(Seq(), aggregateExpressions1, tableRelation)
val aggregateExpressions2 =
Seq(Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")())
val aggregatePlan2 = Aggregate(Seq(), aggregateExpressions2, aggregatePlan1)
val expectedPlan = Project(star, aggregatePlan2)

comparePlans(expectedPlan, logPlan, false)
}

test("multiple stats - test average price and average age with Alias") {
val context = new CatalystPlanContext
val logPlan = planTransformer.visit(
plan(
pplParser,
"source = table | stats avg(price) as avg_price | stats avg(age) as avg_age",
false),
context)
val star = Seq(UnresolvedStar(None))

val priceField = UnresolvedAttribute("price")
val ageField = UnresolvedAttribute("age")
val tableRelation = UnresolvedRelation(Seq("table"))
val aggregateExpressions1 = Seq(
Alias(UnresolvedFunction(Seq("AVG"), Seq(priceField), isDistinct = false), "avg_price")())
val aggregatePlan1 = Aggregate(Seq(), aggregateExpressions1, tableRelation)
val aggregateExpressions2 =
Seq(Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg_age")())
val aggregatePlan2 = Aggregate(Seq(), aggregateExpressions2, aggregatePlan1)
val expectedPlan = Project(star, aggregatePlan2)

comparePlans(expectedPlan, logPlan, false)
}

test(
"multiple stats - test average price group by product and average age by span of interval of 10 years") {
val context = new CatalystPlanContext
val logPlan = planTransformer.visit(
plan(
pplParser,
"source = table | stats avg(price) by product | stats avg(age) by span(age, 10) as age_span",
false),
context)
val star = Seq(UnresolvedStar(None))
val productField = UnresolvedAttribute("product")
val priceField = UnresolvedAttribute("price")
val ageField = UnresolvedAttribute("age")
val tableRelation = UnresolvedRelation(Seq("table"))

val groupByAttributes = Seq(Alias(productField, "product")())
val aggregateExpressions1 =
Alias(UnresolvedFunction(Seq("AVG"), Seq(priceField), isDistinct = false), "avg(price)")()
val productAlias = Alias(productField, "product")()

val aggregatePlan1 =
Aggregate(groupByAttributes, Seq(aggregateExpressions1, productAlias), tableRelation)

val aggregateExpressions2 =
Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")()
val span = Alias(
Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)),
"age_span")()
val aggregatePlan2 = Aggregate(Seq(span), Seq(aggregateExpressions2, span), aggregatePlan1)

val expectedPlan = Project(star, aggregatePlan2)

comparePlans(expectedPlan, logPlan, false)
}

test("multiple levels stats") {
val context = new CatalystPlanContext
val logPlan = planTransformer.visit(
plan(
pplParser,
"source = table | stats avg(response_time) as avg_response_time by host, service | stats avg(avg_response_time) as avg_host_response_time by service",
false),
context)
val star = Seq(UnresolvedStar(None))
val hostField = UnresolvedAttribute("host")
val serviceField = UnresolvedAttribute("service")
val ageField = UnresolvedAttribute("age")
val responseTimeField = UnresolvedAttribute("response_time")
val tableRelation = UnresolvedRelation(Seq("table"))
val hostAlias = Alias(hostField, "host")()
val serviceAlias = Alias(serviceField, "service")()

val groupByAttributes1 = Seq(Alias(hostField, "host")(), Alias(serviceField, "service")())
val aggregateExpressions1 =
Alias(
UnresolvedFunction(Seq("AVG"), Seq(responseTimeField), isDistinct = false),
"avg_response_time")()
val responseTimeAlias = Alias(responseTimeField, "response_time")()
val aggregatePlan1 =
Aggregate(
groupByAttributes1,
Seq(aggregateExpressions1, hostAlias, serviceAlias),
tableRelation)

val avgResponseTimeField = UnresolvedAttribute("avg_response_time")
val groupByAttributes2 = Seq(Alias(serviceField, "service")())
val aggregateExpressions2 =
Alias(
UnresolvedFunction(Seq("AVG"), Seq(avgResponseTimeField), isDistinct = false),
"avg_host_response_time")()

val aggregatePlan2 =
Aggregate(groupByAttributes2, Seq(aggregateExpressions2, serviceAlias), aggregatePlan1)

val expectedPlan = Project(star, aggregatePlan2)

comparePlans(expectedPlan, logPlan, false)
}
}
Loading