From b407a06b52aea0aa7b82e41ba2f0128bcb35adf5 Mon Sep 17 00:00:00 2001
From: Hendrik Saly
Date: Wed, 21 Aug 2024 18:07:29 +0200
Subject: [PATCH] Fix distinct_count ppl stats function (#548)

* Fix distinct_count, Add test for distinct_count

Signed-off-by: Hendrik Saly

* add tests for stats: distinct_count (#1)

add tests for stats: distinct_count

Signed-off-by: Kacper Trochimiak

* Add functions to readme

Signed-off-by: Hendrik Saly

* Fix comparePlans

Signed-off-by: Hendrik Saly

---------

Signed-off-by: Hendrik Saly
Signed-off-by: Kacper Trochimiak
Co-authored-by: Kacper Trochimiak
---
 ...ntSparkPPLAggregationWithSpanITSuite.scala |  46 +++++++
 .../FlintSparkPPLAggregationsITSuite.scala    | 122 ++++++++++++++++++
 ppl-spark-integration/README.md               |   1 +
 .../sql/ppl/utils/AggregatorTranslator.java   |  10 +-
 ...ggregationQueriesTranslatorTestSuite.scala | 120 ++++++++++++++++-
 5 files changed, 293 insertions(+), 6 deletions(-)

diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala
index 3ffe05e81..0bebca9b0 100644
--- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala
+++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala
@@ -448,4 +448,50 @@ class FlintSparkPPLAggregationWithSpanITSuite
     // Compare the two plans
     assert(compareByString(expectedPlan) === compareByString(logicalPlan))
   }
+
+  /**
+   * | age_span | count_age |
+   * |:---------|----------:|
+   * | 20       |         1 |
+   * | 30       |         1 |
+   * | 70       |         1 |
+   */
+  test(
+    "create ppl simple distinct count age by span of interval of 10 years query with state filter test ") {
+    val frame = sql(s"""
+         | source = $testTable | where state != 'Quebec' | stats distinct_count(age) by span(age, 10) as age_span
+         | """.stripMargin)
+
+    // Retrieve the results
+    val results: Array[Row] = frame.collect()
+    // Define the expected results
+    val expectedResults: Array[Row] = Array(Row(1, 70L), Row(1, 30L), Row(1, 20L))
+
+    // Compare the results
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](1))
+    assert(results.sorted.sameElements(expectedResults.sorted))
+
+    // Retrieve the logical plan
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    // Define the expected logical plan
+    val star = Seq(UnresolvedStar(None))
+    val ageField = UnresolvedAttribute("age")
+    val stateField = UnresolvedAttribute("state")
+    val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
+
+    val aggregateExpressions =
+      Alias(
+        UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = true),
+        "distinct_count(age)")()
+    val span = Alias(
+      Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)),
+      "age_span")()
+    val filterExpr = Not(EqualTo(stateField, Literal("Quebec")))
+    val filterPlan = Filter(filterExpr, table)
+    val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), filterPlan)
+    val expectedPlan = Project(star, aggregatePlan)
+
+    // Compare the two plans
+    comparePlans(expectedPlan, logicalPlan, false)
+  }
 }
diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala
index c638cd750..7a3a886dd 100644
--- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala
+++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala
@@ -797,4 +797,126 @@ class FlintSparkPPLAggregationsITSuite
     }
     assert(thrown.getMessage === "Unsupported value 'percent': -4 (expected: >= 0 <= 100))")
   }
+
+  test("create ppl simple country distinct_count ") {
+    val frame = sql(s"""
+         | source = $testTable| stats distinct_count(country)
+         | """.stripMargin)
+
+    // Retrieve the results
+    val results: Array[Row] = frame.collect()
+
+    // Define the expected results
+    val expectedResults: Array[Row] = Array(Row(2L))
+
+    // Compare the results
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1))
+    assert(
+      results.sorted.sameElements(expectedResults.sorted),
+      s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}")
+
+    // Retrieve the logical plan
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    // Define the expected logical plan
+    val star = Seq(UnresolvedStar(None))
+    val countryField = UnresolvedAttribute("country")
+    val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
+
+    val aggregateExpressions =
+      Alias(
+        UnresolvedFunction(Seq("COUNT"), Seq(countryField), isDistinct = true),
+        "distinct_count(country)")()
+
+    val aggregatePlan =
+      Aggregate(Seq.empty, Seq(aggregateExpressions), table)
+    val expectedPlan = Project(star, aggregatePlan)
+
+    // Compare the two plans
+    comparePlans(expectedPlan, logicalPlan, false)
+  }
+
+  test("create ppl simple age distinct_count group by country query test with sort") {
+    val frame = sql(s"""
+         | source = $testTable | stats distinct_count(age) by country | sort country
+         | """.stripMargin)
+
+    // Retrieve the results
+    val results: Array[Row] = frame.collect()
+    // Define the expected results
+    val expectedResults: Array[Row] = Array(Row(2L, "Canada"), Row(2L, "USA"))
+
+    // Compare the results
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1))
+    assert(
+      results.sorted.sameElements(expectedResults.sorted),
+      s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}")
+
+    // Retrieve the logical plan
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    // Define the expected logical plan
+    val star = Seq(UnresolvedStar(None))
+    val countryField = UnresolvedAttribute("country")
+    val ageField = UnresolvedAttribute("age")
+    val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
+
+    val groupByAttributes = Seq(Alias(countryField, "country")())
+    val aggregateExpressions =
+      Alias(
+        UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = true),
+        "distinct_count(age)")()
+    val productAlias = Alias(countryField, "country")()
+
+    val aggregatePlan =
+      Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table)
+    val sortedPlan: LogicalPlan =
+      Sort(
+        Seq(SortOrder(UnresolvedAttribute("country"), Ascending)),
+        global = true,
+        aggregatePlan)
+    val expectedPlan = Project(star, sortedPlan)
+
+    // Compare the two plans
+    assert(
+      compareByString(expectedPlan) === compareByString(logicalPlan),
+      s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}")
+  }
+
+  test("create ppl simple age distinct_count group by country with state filter query test") {
+    val frame = sql(s"""
+         | source = $testTable | where state != 'Ontario' | stats distinct_count(age) by country
+         | """.stripMargin)
+
+    // Retrieve the results
+    val results: Array[Row] = frame.collect()
+    // Define the expected results
+    val expectedResults: Array[Row] = Array(Row(1L, "Canada"), Row(2L, "USA"))
+
+    // Compare the results
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1))
+    assert(results.sorted.sameElements(expectedResults.sorted))
+
+    // Retrieve the logical plan
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    // Define the expected logical plan
+    val star = Seq(UnresolvedStar(None))
+    val stateField = UnresolvedAttribute("state")
+    val countryField = UnresolvedAttribute("country")
+    val ageField = UnresolvedAttribute("age")
+    val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
+
+    val groupByAttributes = Seq(Alias(countryField, "country")())
+    val filterExpr = Not(EqualTo(stateField, Literal("Ontario")))
+    val filterPlan = Filter(filterExpr, table)
+    val aggregateExpressions =
+      Alias(
+        UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = true),
+        "distinct_count(age)")()
+    val productAlias = Alias(countryField, "country")()
+    val aggregatePlan =
+      Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan)
+    val expectedPlan = Project(star, aggregatePlan)
+
+    // Compare the two plans
+    assert(compareByString(expectedPlan) === compareByString(logicalPlan))
+  }
 }
diff --git a/ppl-spark-integration/README.md b/ppl-spark-integration/README.md
index 3ea0a477a..bc8a96c52 100644
--- a/ppl-spark-integration/README.md
+++ b/ppl-spark-integration/README.md
@@ -270,6 +270,7 @@ Limitation: Overriding existing field is unsupported, following queries throw ex
 - `source = table | where a < 50 | stats avg(c) `
 - `source = table | stats max(c) by b`
 - `source = table | stats count(c) by b | head 5`
+ - `source = table | stats distinct_count(c)`
 - `source = table | stats stddev_samp(c)`
 - `source = table | stats stddev_pop(c)`
 - `source = table | stats percentile(c, 90)`
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java
index 244f71f09..3c367a948 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java
@@ -34,15 +34,15 @@ static Expression aggregator(org.opensearch.sql.ast.expression.AggregateFunction
         // Additional aggregation function operators will be added here
         switch (BuiltinFunctionName.ofAggregation(aggregateFunction.getFuncName()).get()) {
             case MAX:
-                return new UnresolvedFunction(seq("MAX"), seq(arg),false, empty(),false);
+                return new UnresolvedFunction(seq("MAX"), seq(arg), aggregateFunction.getDistinct(), empty(),false);
             case MIN:
-                return new UnresolvedFunction(seq("MIN"), seq(arg),false, empty(),false);
+                return new UnresolvedFunction(seq("MIN"), seq(arg), aggregateFunction.getDistinct(), empty(),false);
             case AVG:
-                return new UnresolvedFunction(seq("AVG"), seq(arg),false, empty(),false);
+                return new UnresolvedFunction(seq("AVG"), seq(arg), aggregateFunction.getDistinct(), empty(),false);
             case COUNT:
-                return new UnresolvedFunction(seq("COUNT"), seq(arg),false, empty(),false);
+                return new UnresolvedFunction(seq("COUNT"), seq(arg), aggregateFunction.getDistinct(), empty(),false);
             case SUM:
-                return new UnresolvedFunction(seq("SUM"), seq(arg),false, empty(),false);
+                return new UnresolvedFunction(seq("SUM"), seq(arg), aggregateFunction.getDistinct(), empty(),false);
             case STDDEV_POP:
                 return new UnresolvedFunction(seq("STDDEV_POP"), seq(arg), aggregateFunction.getDistinct(), empty(),false);
             case STDDEV_SAMP:
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala
index 457faeaa3..e47ebedbf 100644
--- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala
@@ -11,7 +11,7 @@ import org.scalatest.matchers.should.Matchers
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
-import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Divide, EqualTo, Floor, GreaterThanOrEqual, Literal, Multiply, SortOrder, TimeWindow}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Divide, EqualTo, Floor, GreaterThan, GreaterThanOrEqual, Literal, Multiply, SortOrder, TimeWindow}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 
@@ -758,4 +758,122 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
     comparePlans(expectedPlan, logPlan, false)
   }
 
+  test("test distinct count product group by brand sorted") {
+    val context = new CatalystPlanContext
+    val logPlan = planTransformer.visit(
+      plan(
+        pplParser,
+        "source = table | stats distinct_count(product) by brand | sort brand",
+        false),
+      context)
+    val star = Seq(UnresolvedStar(None))
+    val brandField = UnresolvedAttribute("brand")
+    val productField = UnresolvedAttribute("product")
+    val tableRelation = UnresolvedRelation(Seq("table"))
+
+    val groupByAttributes = Seq(Alias(brandField, "brand")())
+    val aggregateExpressions =
+      Alias(
+        UnresolvedFunction(Seq("COUNT"), Seq(productField), isDistinct = true),
+        "distinct_count(product)")()
+    val brandAlias = Alias(brandField, "brand")()
+
+    val aggregatePlan =
+      Aggregate(groupByAttributes, Seq(aggregateExpressions, brandAlias), tableRelation)
+    val sortedPlan: LogicalPlan =
+      Sort(Seq(SortOrder(brandField, Ascending)), global = true, aggregatePlan)
+    val expectedPlan = Project(star, sortedPlan)
+
+    comparePlans(expectedPlan, logPlan, false)
+  }
+
+  test("test distinct count product with alias and filter") {
+    val context = new CatalystPlanContext
+    val logPlan = planTransformer.visit(
+      plan(
+        pplParser,
+        "source = table price > 100 | stats distinct_count(product) as dc_product",
+        false),
+      context)
+    val star = Seq(UnresolvedStar(None))
+    val productField = UnresolvedAttribute("product")
+    val priceField = UnresolvedAttribute("price")
+    val tableRelation = UnresolvedRelation(Seq("table"))
+
+    val aggregateExpressions = Seq(
+      Alias(
+        UnresolvedFunction(Seq("COUNT"), Seq(productField), isDistinct = true),
+        "dc_product")())
+    val filterExpr = GreaterThan(priceField, Literal(100))
+    val filterPlan = Filter(filterExpr, tableRelation)
+    val aggregatePlan = Aggregate(Seq(), aggregateExpressions, filterPlan)
+    val expectedPlan = Project(star, aggregatePlan)
+
+    comparePlans(expectedPlan, logPlan, false)
+  }
+
+  test("test distinct count age by span of interval of 10 years query with sort ") {
+    val context = new CatalystPlanContext
+    val logPlan = planTransformer.visit(
+      plan(
+        pplParser,
+        "source = table | stats distinct_count(age) by span(age, 10) as age_span | sort age",
+        false),
+      context)
+    // Define the expected logical plan
+    val star = Seq(UnresolvedStar(None))
+    val ageField = UnresolvedAttribute("age")
+    val tableRelation = UnresolvedRelation(Seq("table"))
+
+    val aggregateExpressions =
+      Alias(
+        UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = true),
+        "distinct_count(age)")()
+    val span = Alias(
+      Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)),
+      "age_span")()
+    val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), tableRelation)
+    val sortedPlan: LogicalPlan =
+      Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, aggregatePlan)
+    val expectedPlan = Project(star, sortedPlan)
+
+    comparePlans(expectedPlan, logPlan, false)
+  }
+
+  test("test distinct count status by week window and group by status with limit") {
+    val context = new CatalystPlanContext
+    val logPlan = planTransformer.visit(
+      plan(
+        pplParser,
+        "source = table | stats distinct_count(status) by span(@timestamp, 1w) as status_count_by_week, status | head 100",
+        false),
+      context)
+    // Define the expected logical plan
+    val star = Seq(UnresolvedStar(None))
+    val status = Alias(UnresolvedAttribute("status"), "status")()
+    val statusCount = UnresolvedAttribute("status")
+    val table = UnresolvedRelation(Seq("table"))
+
+    val windowExpression = Alias(
+      TimeWindow(
+        UnresolvedAttribute("`@timestamp`"),
+        TimeWindow.parseExpression(Literal("1 week")),
+        TimeWindow.parseExpression(Literal("1 week")),
+        0),
+      "status_count_by_week")()
+
+    val aggregateExpressions =
+      Alias(
+        UnresolvedFunction(Seq("COUNT"), Seq(statusCount), isDistinct = true),
+        "distinct_count(status)")()
+    val aggregatePlan = Aggregate(
+      Seq(status, windowExpression),
+      Seq(aggregateExpressions, status, windowExpression),
+      table)
+    val planWithLimit = GlobalLimit(Literal(100), LocalLimit(Literal(100), aggregatePlan))
+    val expectedPlan = Project(star, planWithLimit)
+    // Compare the two plans
+    comparePlans(expectedPlan, logPlan, false)
+  }
+
 }
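Note (illustration only, not part of the patch): the functional change above is in AggregatorTranslator.java, which now forwards aggregateFunction.getDistinct() into Spark's UnresolvedFunction instead of hard-coding false, so a PPL query such as `source = table | stats distinct_count(c)` is planned as a COUNT aggregate with the DISTINCT flag set. The short standalone Scala sketch below mirrors the unresolved Catalyst plan shape that the new tests assert; the object name DistinctCountPlanSketch, the relation name "people", and the column "country" are made-up placeholders, and it assumes spark-catalyst is on the classpath.

// Sketch only: builds the unresolved logical plan that a PPL
// `stats distinct_count(country)` is expected to translate into.
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project}

object DistinctCountPlanSketch {
  def main(args: Array[String]): Unit = {
    val table = UnresolvedRelation(Seq("people")) // hypothetical relation name
    // isDistinct = true is the flag the fix forwards from the PPL AST; before
    // the change it was hard-coded to false, so distinct_count counted all rows.
    val distinctCount = Alias(
      UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedAttribute("country")), isDistinct = true),
      "distinct_count(country)")()
    val plan = Project(Seq(UnresolvedStar(None)), Aggregate(Seq.empty, Seq(distinctCount), table))
    println(plan.treeString)
  }
}

For reference, the new FlintSparkPPLAggregationsITSuite test expects Row(2L) for `stats distinct_count(country)` over the test data (two distinct countries, Canada and USA), whereas a plain count would also include duplicate rows.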