Skip to content

Commit

Permalink
Add percentile PPL function (#547)
Browse files Browse the repository at this point in the history
* percentile prototype

Signed-off-by: Hendrik Saly <[email protected]>

* add tests for stats: percentile

Signed-off-by: Kacper Trochimiak <[email protected]>

* Add PERCENTILE_APPROX

Signed-off-by: Hendrik Saly <[email protected]>

* Add functions to readme

Signed-off-by: Hendrik Saly <[email protected]>

* Add null checks

Signed-off-by: Hendrik Saly <[email protected]>

* Fix tests, add tests

Signed-off-by: Hendrik Saly <[email protected]>

---------

Signed-off-by: Hendrik Saly <[email protected]>
Signed-off-by: Kacper Trochimiak <[email protected]>
Co-authored-by: Kacper Trochimiak <[email protected]>
  • Loading branch information
salyh and kt-eliatra authored Aug 20, 2024
1 parent aa55968 commit f283df8
Show file tree
Hide file tree
Showing 9 changed files with 497 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -351,4 +351,101 @@ class FlintSparkPPLAggregationWithSpanITSuite
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

/**
* | age_span | age_percentile |
* |:---------|---------------:|
* | 20 | 25 |
* | 30 | 30 |
* | 70 | 70 |
*/
test(
"create ppl simple age 60th percentile by span of interval of 10 years query with state filter test ") {
val frame = sql(s"""
| source = $testTable | where state != 'Quebec' | stats percentile(age, 60) by span(age, 10) as age_span
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(Row(70d, 70L), Row(30d, 30L), Row(25d, 20L))

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](1))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val ageField = UnresolvedAttribute("age")
val percentage = Literal(0.6)
val stateField = UnresolvedAttribute("state")
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

val aggregateExpressions =
Alias(
UnresolvedFunction(Seq("PERCENTILE"), Seq(ageField, percentage), isDistinct = false),
"percentile(age, 60)")()
val span = Alias(
Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)),
"age_span")()
val filterExpr = Not(EqualTo(stateField, Literal("Quebec")))
val filterPlan = Filter(filterExpr, table)
val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), filterPlan)
val expectedPlan = Project(star, aggregatePlan)

// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

/**
* | age_span | age_percentile |
* |:---------|---------------:|
* | 20 | 25 |
* | 30 | 30 |
* | 70 | 70 |
*/
test(
"create ppl simple age 60th percentile approx by span of interval of 10 years query with state filter test ") {
val frame = sql(s"""
| source = $testTable | where state != 'Quebec' | stats percentile_approx(age, 60) by span(age, 10) as age_span
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(Row(70d, 70L), Row(30d, 30L), Row(25d, 20L))

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](1))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val ageField = UnresolvedAttribute("age")
val percentage = Literal(0.6)
val stateField = UnresolvedAttribute("state")
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

val aggregateExpressions =
Alias(
UnresolvedFunction(
Seq("PERCENTILE_APPROX"),
Seq(ageField, percentage),
isDistinct = false),
"percentile_approx(age, 60)")()
val span = Alias(
Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)),
"age_span")()
val filterExpr = Not(EqualTo(stateField, Literal("Quebec")))
val filterPlan = Filter(filterExpr, table)
val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), filterPlan)
val expectedPlan = Project(star, aggregatePlan)

// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -615,4 +615,186 @@ class FlintSparkPPLAggregationsITSuite
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("create ppl simple age 50th percentile ") {
val frame = sql(s"""
| source = $testTable| stats percentile(age, 50)
| """.stripMargin)
// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(Row(27.5))
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](1))
assert(
results.sorted.sameElements(expectedResults.sorted),
s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}")
// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val ageField = UnresolvedAttribute("age")
val percentage = Literal("0.5")
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val aggregateExpressions =
Alias(
UnresolvedFunction(Seq("PERCENTILE"), Seq(ageField, percentage), isDistinct = false),
"percentile(age, 50)")()
val aggregatePlan =
Aggregate(Seq.empty, Seq(aggregateExpressions), table)
val expectedPlan = Project(star, aggregatePlan)
// Compare the two plans
assert(
compareByString(expectedPlan) === compareByString(logicalPlan),
s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}")
}

test("create ppl simple age 20th percentile group by country query test with sort") {
val frame = sql(s"""
| source = $testTable | stats percentile(age, 20) by country | sort country
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(Row(21d, "Canada"), Row(38d, "USA"))

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1))
assert(
results.sorted.sameElements(expectedResults.sorted),
s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}")

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val countryField = UnresolvedAttribute("country")
val ageField = UnresolvedAttribute("age")
val percentage = Literal("0.2")
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

val groupByAttributes = Seq(Alias(countryField, "country")())
val aggregateExpressions =
Alias(
UnresolvedFunction(Seq("PERCENTILE"), Seq(ageField, percentage), isDistinct = false),
"percentile(age, 20)")()
val productAlias = Alias(countryField, "country")()

val aggregatePlan =
Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table)
val sortedPlan: LogicalPlan =
Sort(
Seq(SortOrder(UnresolvedAttribute("country"), Ascending)),
global = true,
aggregatePlan)
val expectedPlan = Project(star, sortedPlan)

// Compare the two plans
assert(
compareByString(expectedPlan) === compareByString(logicalPlan),
s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}")
}

test("create ppl simple age 40th percentile group by country with state filter query test") {
val frame = sql(s"""
| source = $testTable | where state != 'Ontario' | stats percentile(age, 40) by country
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(Row(20d, "Canada"), Row(46d, "USA"))

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val stateField = UnresolvedAttribute("state")
val countryField = UnresolvedAttribute("country")
val ageField = UnresolvedAttribute("age")
val percentage = Literal("0.4")
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

val groupByAttributes = Seq(Alias(countryField, "country")())
val filterExpr = Not(EqualTo(stateField, Literal("Ontario")))
val filterPlan = Filter(filterExpr, table)
val aggregateExpressions =
Alias(
UnresolvedFunction(Seq("PERCENTILE"), Seq(ageField, percentage), isDistinct = false),
"percentile(age, 40)")()
val productAlias = Alias(countryField, "country")()
val aggregatePlan =
Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan)
val expectedPlan = Project(star, aggregatePlan)

// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test(
"create ppl simple age 40th percentile approx group by country with state filter query test") {
val frame = sql(s"""
| source = $testTable | where state != 'Ontario' | stats percentile_approx(age, 40) by country
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(Row(20d, "Canada"), Row(30d, "USA"))

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val stateField = UnresolvedAttribute("state")
val countryField = UnresolvedAttribute("country")
val ageField = UnresolvedAttribute("age")
val percentage = Literal("0.4")
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

val groupByAttributes = Seq(Alias(countryField, "country")())
val filterExpr = Not(EqualTo(stateField, Literal("Ontario")))
val filterPlan = Filter(filterExpr, table)
val aggregateExpressions =
Alias(
UnresolvedFunction(
Seq("PERCENTILE_APPROX"),
Seq(ageField, percentage),
isDistinct = false),
"percentile_approx(age, 40)")()
val productAlias = Alias(countryField, "country")()
val aggregatePlan =
Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan)
val expectedPlan = Project(star, aggregatePlan)

// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("create failing ppl percentile approx - due to too high percentage value test") {
val thrown = intercept[IllegalStateException] {
val frame = sql(s"""
| source = $testTable | stats percentile_approx(age, 200) by country
| """.stripMargin)
}
assert(thrown.getMessage === "Unsupported value 'percent': 200 (expected: >= 0 <= 100))")
}

test("create failing ppl percentile approx - due to too low percentage value test") {
val thrown = intercept[IllegalStateException] {
val frame = sql(s"""
| source = $testTable | stats percentile_approx(age, -4) by country
| """.stripMargin)
}
assert(thrown.getMessage === "Unsupported value 'percent': -4 (expected: >= 0 <= 100))")
}
}
2 changes: 2 additions & 0 deletions ppl-spark-integration/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,8 @@ Limitation: Overriding existing field is unsupported, following queries throw ex
- `source = table | stats count(c) by b | head 5`
- `source = table | stats stddev_samp(c)`
- `source = table | stats stddev_pop(c)`
- `source = table | stats percentile(c, 90)`
- `source = table | stats percentile_approx(c, 99)`

**Aggregations With Span**
- `source = table | stats count(a) by span(a, 10) as a_span`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ VAR_POP: 'VAR_POP';
STDDEV_SAMP: 'STDDEV_SAMP';
STDDEV_POP: 'STDDEV_POP';
PERCENTILE: 'PERCENTILE';
PERCENTILE_APPROX: 'PERCENTILE_APPROX';
TAKE: 'TAKE';
FIRST: 'FIRST';
LAST: 'LAST';
Expand Down
8 changes: 5 additions & 3 deletions ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,10 @@ statsAggTerm

// aggregation functions
statsFunction
: statsFunctionName LT_PRTHS valueExpression RT_PRTHS # statsFunctionCall
| COUNT LT_PRTHS RT_PRTHS # countAllFunctionCall
| (DISTINCT_COUNT | DC) LT_PRTHS valueExpression RT_PRTHS # distinctCountFunctionCall
: statsFunctionName LT_PRTHS valueExpression RT_PRTHS # statsFunctionCall
| COUNT LT_PRTHS RT_PRTHS # countAllFunctionCall
| (DISTINCT_COUNT | DC) LT_PRTHS valueExpression RT_PRTHS # distinctCountFunctionCall
| percentileFunctionName = (PERCENTILE | PERCENTILE_APPROX) LT_PRTHS valueExpression COMMA percent = integerLiteral RT_PRTHS # percentileFunctionCall
;

statsFunctionName
Expand Down Expand Up @@ -897,6 +898,7 @@ keywordsCanBeId
| STDDEV_SAMP
| STDDEV_POP
| PERCENTILE
| PERCENTILE_APPROX
| TAKE
| FIRST
| LAST
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ public enum BuiltinFunctionName {
TAKE(FunctionName.of("take")),
// Not always an aggregation query
NESTED(FunctionName.of("nested")),
PERCENTILE(FunctionName.of("percentile")),
PERCENTILE_APPROX(FunctionName.of("percentile_approx")),

/** Text Functions. */
ASCII(FunctionName.of("ascii")),
Expand Down Expand Up @@ -285,6 +287,8 @@ public FunctionName getName() {
.put("stddev_pop", BuiltinFunctionName.STDDEV_POP)
.put("stddev_samp", BuiltinFunctionName.STDDEV_SAMP)
.put("take", BuiltinFunctionName.TAKE)
.put("percentile", BuiltinFunctionName.PERCENTILE)
.put("percentile_approx", BuiltinFunctionName.PERCENTILE_APPROX)
.build();

public static Optional<BuiltinFunctionName> of(String str) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,14 @@ public UnresolvedExpression visitPercentileAggFunction(OpenSearchPPLParser.Perce
Collections.singletonList(new Argument("rank", (Literal) visit(ctx.value))));
}

@Override
public UnresolvedExpression visitPercentileFunctionCall(OpenSearchPPLParser.PercentileFunctionCallContext ctx) {
return new AggregateFunction(
ctx.percentileFunctionName.getText(),
visit(ctx.valueExpression()),
Collections.singletonList(new Argument("percent", (Literal) visit(ctx.percent))));
}

/**
* Eval function.
*/
Expand Down
Loading

0 comments on commit f283df8

Please sign in to comment.