Skip to content

Commit

Permalink
Fix distinct_count ppl stats function (opensearch-project#548)
Browse files Browse the repository at this point in the history
* Fix distinct_count, Add test for distinct_count

Signed-off-by: Hendrik Saly <[email protected]>

* add tests for stats: distinct_count (#1)

add tests for stats: distinct_count

Signed-off-by: Kacper Trochimiak <[email protected]>

* Add functions to readme

Signed-off-by: Hendrik Saly <[email protected]>

* Fix comparePlans

Signed-off-by: Hendrik Saly <[email protected]>

---------

Signed-off-by: Hendrik Saly <[email protected]>
Signed-off-by: Kacper Trochimiak <[email protected]>
Co-authored-by: Kacper Trochimiak <[email protected]>
  • Loading branch information
salyh and kt-eliatra authored Aug 21, 2024
1 parent 5c190b3 commit b407a06
Show file tree
Hide file tree
Showing 5 changed files with 293 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -448,4 +448,50 @@ class FlintSparkPPLAggregationWithSpanITSuite
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

/**
* | age_span | count_age |
* |:---------|----------:|
* | 20 | 1 |
* | 30 | 1 |
* | 70 | 1 |
*/
test(
"create ppl simple distinct count age by span of interval of 10 years query with state filter test ") {
val frame = sql(s"""
| source = $testTable | where state != 'Quebec' | stats distinct_count(age) by span(age, 10) as age_span
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(Row(1, 70L), Row(1, 30L), Row(1, 20L))

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](1))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val ageField = UnresolvedAttribute("age")
val stateField = UnresolvedAttribute("state")
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

val aggregateExpressions =
Alias(
UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = true),
"distinct_count(age)")()
val span = Alias(
Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)),
"age_span")()
val filterExpr = Not(EqualTo(stateField, Literal("Quebec")))
val filterPlan = Filter(filterExpr, table)
val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), filterPlan)
val expectedPlan = Project(star, aggregatePlan)

// Compare the two plans
comparePlans(expectedPlan, logicalPlan, false)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -797,4 +797,126 @@ class FlintSparkPPLAggregationsITSuite
}
assert(thrown.getMessage === "Unsupported value 'percent': -4 (expected: >= 0 <= 100))")
}

test("create ppl simple country distinct_count ") {
val frame = sql(s"""
| source = $testTable| stats distinct_count(country)
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()

// Define the expected results
val expectedResults: Array[Row] = Array(Row(2L))

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1))
assert(
results.sorted.sameElements(expectedResults.sorted),
s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}")

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val countryField = UnresolvedAttribute("country")
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

val aggregateExpressions =
Alias(
UnresolvedFunction(Seq("COUNT"), Seq(countryField), isDistinct = true),
"distinct_count(country)")()

val aggregatePlan =
Aggregate(Seq.empty, Seq(aggregateExpressions), table)
val expectedPlan = Project(star, aggregatePlan)

// Compare the two plans
comparePlans(expectedPlan, logicalPlan, false)
}

test("create ppl simple age distinct_count group by country query test with sort") {
val frame = sql(s"""
| source = $testTable | stats distinct_count(age) by country | sort country
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(Row(2L, "Canada"), Row(2L, "USA"))

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1))
assert(
results.sorted.sameElements(expectedResults.sorted),
s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}")

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val countryField = UnresolvedAttribute("country")
val ageField = UnresolvedAttribute("age")
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

val groupByAttributes = Seq(Alias(countryField, "country")())
val aggregateExpressions =
Alias(
UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = true),
"distinct_count(age)")()
val productAlias = Alias(countryField, "country")()

val aggregatePlan =
Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table)
val sortedPlan: LogicalPlan =
Sort(
Seq(SortOrder(UnresolvedAttribute("country"), Ascending)),
global = true,
aggregatePlan)
val expectedPlan = Project(star, sortedPlan)

// Compare the two plans
assert(
compareByString(expectedPlan) === compareByString(logicalPlan),
s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}")
}

test("create ppl simple age distinct_count group by country with state filter query test") {
val frame = sql(s"""
| source = $testTable | where state != 'Ontario' | stats distinct_count(age) by country
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(Row(1L, "Canada"), Row(2L, "USA"))

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val stateField = UnresolvedAttribute("state")
val countryField = UnresolvedAttribute("country")
val ageField = UnresolvedAttribute("age")
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

val groupByAttributes = Seq(Alias(countryField, "country")())
val filterExpr = Not(EqualTo(stateField, Literal("Ontario")))
val filterPlan = Filter(filterExpr, table)
val aggregateExpressions =
Alias(
UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = true),
"distinct_count(age)")()
val productAlias = Alias(countryField, "country")()
val aggregatePlan =
Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan)
val expectedPlan = Project(star, aggregatePlan)

// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}
}
1 change: 1 addition & 0 deletions ppl-spark-integration/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ Limitation: Overriding existing field is unsupported, following queries throw ex
- `source = table | where a < 50 | stats avg(c) `
- `source = table | stats max(c) by b`
- `source = table | stats count(c) by b | head 5`
- `source = table | stats distinct_count(c)`
- `source = table | stats stddev_samp(c)`
- `source = table | stats stddev_pop(c)`
- `source = table | stats percentile(c, 90)`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ static Expression aggregator(org.opensearch.sql.ast.expression.AggregateFunction
// Additional aggregation function operators will be added here
switch (BuiltinFunctionName.ofAggregation(aggregateFunction.getFuncName()).get()) {
case MAX:
return new UnresolvedFunction(seq("MAX"), seq(arg),false, empty(),false);
return new UnresolvedFunction(seq("MAX"), seq(arg), aggregateFunction.getDistinct(), empty(),false);
case MIN:
return new UnresolvedFunction(seq("MIN"), seq(arg),false, empty(),false);
return new UnresolvedFunction(seq("MIN"), seq(arg), aggregateFunction.getDistinct(), empty(),false);
case AVG:
return new UnresolvedFunction(seq("AVG"), seq(arg),false, empty(),false);
return new UnresolvedFunction(seq("AVG"), seq(arg), aggregateFunction.getDistinct(), empty(),false);
case COUNT:
return new UnresolvedFunction(seq("COUNT"), seq(arg),false, empty(),false);
return new UnresolvedFunction(seq("COUNT"), seq(arg), aggregateFunction.getDistinct(), empty(),false);
case SUM:
return new UnresolvedFunction(seq("SUM"), seq(arg),false, empty(),false);
return new UnresolvedFunction(seq("SUM"), seq(arg), aggregateFunction.getDistinct(), empty(),false);
case STDDEV_POP:
return new UnresolvedFunction(seq("STDDEV_POP"), seq(arg), aggregateFunction.getDistinct(), empty(),false);
case STDDEV_SAMP:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import org.scalatest.matchers.should.Matchers

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Divide, EqualTo, Floor, GreaterThanOrEqual, Literal, Multiply, SortOrder, TimeWindow}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Divide, EqualTo, Floor, GreaterThan, GreaterThanOrEqual, Literal, Multiply, SortOrder, TimeWindow}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._

Expand Down Expand Up @@ -758,4 +758,122 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
comparePlans(expectedPlan, logPlan, false)
}

test("test distinct count product group by brand sorted") {
val context = new CatalystPlanContext
val logPlan = planTransformer.visit(
plan(
pplParser,
"source = table | stats distinct_count(product) by brand | sort brand",
false),
context)
val star = Seq(UnresolvedStar(None))
val brandField = UnresolvedAttribute("brand")
val productField = UnresolvedAttribute("product")
val tableRelation = UnresolvedRelation(Seq("table"))

val groupByAttributes = Seq(Alias(brandField, "brand")())
val aggregateExpressions =
Alias(
UnresolvedFunction(Seq("COUNT"), Seq(productField), isDistinct = true),
"distinct_count(product)")()
val brandAlias = Alias(brandField, "brand")()

val aggregatePlan =
Aggregate(groupByAttributes, Seq(aggregateExpressions, brandAlias), tableRelation)
val sortedPlan: LogicalPlan =
Sort(Seq(SortOrder(brandField, Ascending)), global = true, aggregatePlan)
val expectedPlan = Project(star, sortedPlan)

comparePlans(expectedPlan, logPlan, false)
}

test("test distinct count product with alias and filter") {
val context = new CatalystPlanContext
val logPlan = planTransformer.visit(
plan(
pplParser,
"source = table price > 100 | stats distinct_count(product) as dc_product",
false),
context)
val star = Seq(UnresolvedStar(None))
val productField = UnresolvedAttribute("product")
val priceField = UnresolvedAttribute("price")
val tableRelation = UnresolvedRelation(Seq("table"))

val aggregateExpressions = Seq(
Alias(
UnresolvedFunction(Seq("COUNT"), Seq(productField), isDistinct = true),
"dc_product")())
val filterExpr = GreaterThan(priceField, Literal(100))
val filterPlan = Filter(filterExpr, tableRelation)
val aggregatePlan = Aggregate(Seq(), aggregateExpressions, filterPlan)
val expectedPlan = Project(star, aggregatePlan)

comparePlans(expectedPlan, logPlan, false)
}

test("test distinct count age by span of interval of 10 years query with sort ") {
val context = new CatalystPlanContext
val logPlan = planTransformer.visit(
plan(
pplParser,
"source = table | stats distinct_count(age) by span(age, 10) as age_span | sort age",
false),
context)
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val ageField = UnresolvedAttribute("age")
val tableRelation = UnresolvedRelation(Seq("table"))

val aggregateExpressions =
Alias(
UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = true),
"distinct_count(age)")()
val span = Alias(
Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)),
"age_span")()
val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), tableRelation)
val sortedPlan: LogicalPlan =
Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, aggregatePlan)
val expectedPlan = Project(star, sortedPlan)

comparePlans(expectedPlan, logPlan, false)
}

test("test distinct count status by week window and group by status with limit") {
val context = new CatalystPlanContext
val logPlan = planTransformer.visit(
plan(
pplParser,
"source = table | stats distinct_count(status) by span(@timestamp, 1w) as status_count_by_week, status | head 100",
false),
context)
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val status = Alias(UnresolvedAttribute("status"), "status")()
val statusCount = UnresolvedAttribute("status")
val table = UnresolvedRelation(Seq("table"))

val windowExpression = Alias(
TimeWindow(
UnresolvedAttribute("`@timestamp`"),
TimeWindow.parseExpression(Literal("1 week")),
TimeWindow.parseExpression(Literal("1 week")),
0),
"status_count_by_week")()

val aggregateExpressions =
Alias(
UnresolvedFunction(Seq("COUNT"), Seq(statusCount), isDistinct = true),
"distinct_count(status)")()
val aggregatePlan = Aggregate(
Seq(status, windowExpression),
Seq(aggregateExpressions, status, windowExpression),
table)
val planWithLimit = GlobalLimit(Literal(100), LocalLimit(Literal(100), aggregatePlan))
val expectedPlan = Project(star, planWithLimit)
// Compare the two plans
comparePlans(expectedPlan, logPlan, false)
}

}

0 comments on commit b407a06

Please sign in to comment.