Skip to content

Commit

Permalink
add tests for stats: stdev_samp, stdev_pop
Browse files Browse the repository at this point in the history
Signed-off-by: Kacper Trochimiak <[email protected]>
  • Loading branch information
kt-eliatra committed Aug 14, 2024
1 parent afe6886 commit a834a0d
Show file tree
Hide file tree
Showing 3 changed files with 546 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ package org.opensearch.flint.spark.ppl

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Descending, Divide, Floor, Literal, Multiply, SortOrder}
import org.apache.spark.sql.catalyst.expressions.{Alias, Descending, Divide, EqualTo, Floor, Literal, Multiply, Not, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.streaming.StreamTest

Expand Down Expand Up @@ -262,4 +262,93 @@ class FlintSparkPPLAggregationWithSpanITSuite
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

/**
* | age_span | age_stddev_samp |
* |:---------|-------------------:|
* | 20 | 3.5355339059327378 |
*/
test(
"create ppl age sample stddev by span of interval of 10 years query with country filter test ") {
val frame = sql(s"""
| source = $testTable | where country != 'USA' | stats stddev_samp(age) by span(age, 10) as age_span
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(Row(3.5355339059327378d, 20L))

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](1))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val ageField = UnresolvedAttribute("age")
val countryField = UnresolvedAttribute("country")
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

val aggregateExpressions =
Alias(
UnresolvedFunction(Seq("STDDEV_SAMP"), Seq(ageField), isDistinct = false),
"stddev_samp(age)")()
val span = Alias(
Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)),
"age_span")()
val filterExpr = Not(EqualTo(countryField, Literal("USA")))
val filterPlan = Filter(filterExpr, table)
val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), filterPlan)
val expectedPlan = Project(star, aggregatePlan)

// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

/**
* | age_span | age_stddev_pop |
* |:---------|------------------:|
* | 20 | 2.5 |
* | 30 | 0 |
*/
test(
"create ppl age population stddev by span of interval of 10 years query with state filter test ") {
val frame = sql(s"""
| source = $testTable | where state != 'California' | stats stddev_pop(age) by span(age, 10) as age_span
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(Row(2.5d, 20L), Row(0d, 30L))

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](1))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val ageField = UnresolvedAttribute("age")
val stateField = UnresolvedAttribute("state")
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

val aggregateExpressions =
Alias(
UnresolvedFunction(Seq("STDDEV_POP"), Seq(ageField), isDistinct = false),
"stddev_pop(age)")()
val span = Alias(
Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)),
"age_span")()
val filterExpr = Not(EqualTo(stateField, Literal("California")))
val filterPlan = Filter(filterExpr, table)
val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), filterPlan)
val expectedPlan = Project(star, aggregatePlan)

// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -380,4 +380,238 @@ class FlintSparkPPLAggregationsITSuite
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("create ppl age sample stddev") {
val frame = sql(s"""
| source = $testTable| stats stddev_samp(age)
| """.stripMargin)
// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(Row(22.86737122335374d))
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](1))
assert(
results.sorted.sameElements(expectedResults.sorted),
s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}")
// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val ageField = UnresolvedAttribute("age")
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val aggregateExpressions =
Alias(
UnresolvedFunction(Seq("STDDEV_SAMP"), Seq(ageField), isDistinct = false),
"stddev_samp(age)")()
val aggregatePlan =
Aggregate(Seq.empty, Seq(aggregateExpressions), table)
val expectedPlan = Project(star, aggregatePlan)
// Compare the two plans
assert(
compareByString(expectedPlan) === compareByString(logicalPlan),
s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}")
}

test("create ppl age sample stddev group by country query test with sort") {
val frame = sql(s"""
| source = $testTable | stats stddev_samp(age) by country | sort country
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(Row(3.5355339059327378d, "Canada"), Row(28.284271247461902d, "USA"))

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1))
assert(
results.sorted.sameElements(expectedResults.sorted),
s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}")

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val countryField = UnresolvedAttribute("country")
val ageField = UnresolvedAttribute("age")
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

val groupByAttributes = Seq(Alias(countryField, "country")())
val aggregateExpressions =
Alias(
UnresolvedFunction(Seq("STDDEV_SAMP"), Seq(ageField), isDistinct = false),
"stddev_samp(age)")()
val productAlias = Alias(countryField, "country")()

val aggregatePlan =
Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table)
val sortedPlan: LogicalPlan =
Sort(
Seq(SortOrder(UnresolvedAttribute("country"), Ascending)),
global = true,
aggregatePlan)
val expectedPlan = Project(star, sortedPlan)

// Compare the two plans
assert(
compareByString(expectedPlan) === compareByString(logicalPlan),
s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}")
}

test("create ppl age sample stddev group by country with state filter query test") {
val frame = sql(s"""
| source = $testTable | where state != 'Ontario' | stats stddev_samp(age) by country
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(Row(null, "Canada"), Row(28.284271247461902d, "USA"))

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val stateField = UnresolvedAttribute("state")
val countryField = UnresolvedAttribute("country")
val ageField = UnresolvedAttribute("age")
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

val groupByAttributes = Seq(Alias(countryField, "country")())
val filterExpr = Not(EqualTo(stateField, Literal("Ontario")))
val filterPlan = Filter(filterExpr, table)
val aggregateExpressions =
Alias(
UnresolvedFunction(Seq("STDDEV_SAMP"), Seq(ageField), isDistinct = false),
"stddev_samp(age)")()
val productAlias = Alias(countryField, "country")()
val aggregatePlan =
Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan)
val expectedPlan = Project(star, aggregatePlan)

// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("create ppl age population stddev") {
val frame = sql(s"""
| source = $testTable| stats stddev_pop(age)
| """.stripMargin)
// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(Row(19.803724397193573d))
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](1))
assert(
results.sorted.sameElements(expectedResults.sorted),
s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}")
// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val ageField = UnresolvedAttribute("age")
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val aggregateExpressions =
Alias(
UnresolvedFunction(Seq("STDDEV_POP"), Seq(ageField), isDistinct = false),
"stddev_pop(age)")()
val aggregatePlan =
Aggregate(Seq.empty, Seq(aggregateExpressions), table)
val expectedPlan = Project(star, aggregatePlan)
// Compare the two plans
assert(
compareByString(expectedPlan) === compareByString(logicalPlan),
s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}")
}

test("create ppl age population stddev group by country query test with sort") {
val frame = sql(s"""
| source = $testTable | stats stddev_pop(age) by country | sort country
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(Row(2.5d, "Canada"), Row(20d, "USA"))

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1))
assert(
results.sorted.sameElements(expectedResults.sorted),
s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}")

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val countryField = UnresolvedAttribute("country")
val ageField = UnresolvedAttribute("age")
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

val groupByAttributes = Seq(Alias(countryField, "country")())
val aggregateExpressions =
Alias(
UnresolvedFunction(Seq("STDDEV_POP"), Seq(ageField), isDistinct = false),
"stddev_pop(age)")()
val productAlias = Alias(countryField, "country")()

val aggregatePlan =
Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table)
val sortedPlan: LogicalPlan =
Sort(
Seq(SortOrder(UnresolvedAttribute("country"), Ascending)),
global = true,
aggregatePlan)
val expectedPlan = Project(star, sortedPlan)

// Compare the two plans
assert(
compareByString(expectedPlan) === compareByString(logicalPlan),
s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}")
}

test("create ppl age population stddev group by country with state filter query test") {
val frame = sql(s"""
| source = $testTable | where state != 'Ontario' | stats stddev_pop(age) by country
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(Row(0d, "Canada"), Row(20d, "USA"))

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val stateField = UnresolvedAttribute("state")
val countryField = UnresolvedAttribute("country")
val ageField = UnresolvedAttribute("age")
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

val groupByAttributes = Seq(Alias(countryField, "country")())
val filterExpr = Not(EqualTo(stateField, Literal("Ontario")))
val filterPlan = Filter(filterExpr, table)
val aggregateExpressions =
Alias(
UnresolvedFunction(Seq("STDDEV_POP"), Seq(ageField), isDistinct = false),
"stddev_pop(age)")()
val productAlias = Alias(countryField, "country")()
val aggregatePlan =
Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan)
val expectedPlan = Project(star, aggregatePlan)

// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}
}
Loading

0 comments on commit a834a0d

Please sign in to comment.