From 8984490d04395e64918cf48f4ec5f0595cd6fa0d Mon Sep 17 00:00:00 2001 From: YANGDB Date: Mon, 11 Nov 2024 15:51:24 -0800 Subject: [PATCH] Ppl count approximate support (#884) * add functional approximation support for: - distinct count - top - rare Signed-off-by: YANGDB * update license and scalafmt Signed-off-by: YANGDB * update additional tests using APPROX_COUNT_DISTINCT Signed-off-by: YANGDB * add visitFirstChild(node, context) method for the PlanVisitor for simplify node inner child access visibility Signed-off-by: YANGDB * update inline documentation Signed-off-by: YANGDB * update according to PR comments - DISTINCT_COUNT_APPROX should be added to keywordsCanBeId Signed-off-by: YANGDB --------- Signed-off-by: YANGDB --- docs/ppl-lang/PPL-Example-Commands.md | 5 + docs/ppl-lang/ppl-rare-command.md | 10 +- docs/ppl-lang/ppl-top-command.md | 7 +- ...ntSparkPPLAggregationWithSpanITSuite.scala | 39 +++ .../FlintSparkPPLAggregationsITSuite.scala | 124 ++++++++ .../ppl/FlintSparkPPLTopAndRareITSuite.scala | 270 ++++++++++++++++++ .../src/main/antlr4/OpenSearchPPLLexer.g4 | 3 + .../src/main/antlr4/OpenSearchPPLParser.g4 | 9 +- .../sql/ast/tree/CountedAggregation.java | 16 ++ .../sql/ast/tree/RareAggregation.java | 10 +- .../sql/ast/tree/TopAggregation.java | 2 +- .../function/BuiltinFunctionName.java | 2 + .../sql/ppl/CatalystPlanContext.java | 3 +- .../opensearch/sql/ppl/parser/AstBuilder.java | 20 +- .../sql/ppl/parser/AstExpressionBuilder.java | 3 +- .../sql/ppl/utils/AggregatorTransformer.java | 2 + .../ppl/utils/BuiltinFunctionTransformer.java | 3 + ...ggregationQueriesTranslatorTestSuite.scala | 92 ++++++ ...TopAndRareQueriesTranslatorTestSuite.scala | 36 +++ 19 files changed, 635 insertions(+), 21 deletions(-) create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/CountedAggregation.java diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md index 4ea564111..cb50431f6 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -177,6 +177,7 @@ source = table | where ispresent(a) | - `source = table | stats max(c) by b` - `source = table | stats count(c) by b | head 5` - `source = table | stats distinct_count(c)` +- `source = table | stats distinct_count_approx(c)` - `source = table | stats stddev_samp(c)` - `source = table | stats stddev_pop(c)` - `source = table | stats percentile(c, 90)` @@ -202,6 +203,7 @@ source = table | where ispresent(a) | - `source = table | where a < 50 | eventstats avg(c) ` - `source = table | eventstats max(c) by b` - `source = table | eventstats count(c) by b | head 5` +- `source = table | eventstats count(c) by b | head 5` - `source = table | eventstats stddev_samp(c)` - `source = table | eventstats stddev_pop(c)` - `source = table | eventstats percentile(c, 90)` @@ -246,12 +248,15 @@ source = table | where ispresent(a) | - `source=accounts | rare gender` - `source=accounts | rare age by gender` +- `source=accounts | rare 5 age by gender` +- `source=accounts | rare_approx age by gender` #### **Top** [See additional command details](ppl-top-command.md) - `source=accounts | top gender` - `source=accounts | top 1 gender` +- `source=accounts | top_approx 5 gender` - `source=accounts | top 1 age by gender` #### **Parse** diff --git a/docs/ppl-lang/ppl-rare-command.md b/docs/ppl-lang/ppl-rare-command.md index 5645382f8..e3ad21f4e 100644 --- a/docs/ppl-lang/ppl-rare-command.md +++ b/docs/ppl-lang/ppl-rare-command.md @@ -6,10 +6,13 @@ Using ``rare`` command to 
find the least common tuple of values of all fields in **Note**: A maximum of 10 results is returned for each distinct tuple of values of the group-by fields. **Syntax** -`rare [by-clause]` +`rare [N] [by-clause]` +`rare_approx [N] [by-clause]` +* N: number of results to return. **Default**: 10 * field-list: mandatory. comma-delimited list of field names. * by-clause: optional. one or more fields to group the results by. +* rare_approx: approximate count of the rare (n) fields by using estimated [cardinality by HyperLogLog++ algorithm](https://spark.apache.org/docs/3.5.2/sql-ref-functions-builtin.html). ### Example 1: Find the least common values in a field @@ -19,6 +22,8 @@ The example finds least common gender of all the accounts. PPL query: os> source=accounts | rare gender; + os> source=accounts | rare_approx 10 gender; + os> source=accounts | rare_approx gender; fetched rows / total rows = 2/2 +----------+ | gender | @@ -34,7 +39,8 @@ The example finds least common age of all the accounts group by gender. PPL query: - os> source=accounts | rare age by gender; + os> source=accounts | rare 5 age by gender; + os> source=accounts | rare_approx 5 age by gender; fetched rows / total rows = 4/4 +----------+-------+ | gender | age | diff --git a/docs/ppl-lang/ppl-top-command.md b/docs/ppl-lang/ppl-top-command.md index 4ba56f692..93d3a7148 100644 --- a/docs/ppl-lang/ppl-top-command.md +++ b/docs/ppl-lang/ppl-top-command.md @@ -6,11 +6,12 @@ Using ``top`` command to find the most common tuple of values of all fields in t ### Syntax `top [N] [by-clause]` +`top_approx [N] [by-clause]` * N: number of results to return. **Default**: 10 * field-list: mandatory. comma-delimited list of field names. * by-clause: optional. one or more fields to group the results by. - +* top_approx: approximate count of the (n) top fields by using estimated [cardinality by HyperLogLog++ algorithm](https://spark.apache.org/docs/3.5.2/sql-ref-functions-builtin.html). ### Example 1: Find the most common values in a field @@ -19,6 +20,7 @@ The example finds most common gender of all the accounts. PPL query: os> source=accounts | top gender; + os> source=accounts | top_approx gender; fetched rows / total rows = 2/2 +----------+ | gender | @@ -33,7 +35,7 @@ The example finds most common gender of all the accounts. PPL query: - os> source=accounts | top 1 gender; + os> source=accounts | top_approx 1 gender; fetched rows / total rows = 1/1 +----------+ | gender | @@ -48,6 +50,7 @@ The example finds most common age of all the accounts group by gender. 
PPL query: os> source=accounts | top 1 age by gender; + os> source=accounts | top_approx 1 age by gender; fetched rows / total rows = 2/2 +----------+-------+ | gender | age | diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala index 0bebca9b0..aa96d0991 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala @@ -494,4 +494,43 @@ class FlintSparkPPLAggregationWithSpanITSuite // Compare the two plans comparePlans(expectedPlan, logicalPlan, false) } + + test( + "create ppl simple distinct count age by span of interval of 10 years query with state filter test using approximation") { + val frame = sql(s""" + | source = $testTable | where state != 'Quebec' | stats distinct_count_approx(age) by span(age, 10) as age_span + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(1, 70L), Row(1, 30L), Row(1, 20L)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](1)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val stateField = UnresolvedAttribute("state") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(ageField), isDistinct = true), + "distinct_count_approx(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val filterExpr = Not(EqualTo(stateField, Literal("Quebec"))) + val filterPlan = Filter(filterExpr, table) + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), filterPlan) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + comparePlans(expectedPlan, logicalPlan, false) + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala index bcfe22764..2275c775c 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala @@ -835,6 +835,43 @@ class FlintSparkPPLAggregationsITSuite comparePlans(expectedPlan, logicalPlan, false) } + test("create ppl simple country distinct_count using approximation ") { + val frame = sql(s""" + | source = $testTable| stats distinct_count_approx(country) + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + + // Define the expected results + val expectedResults: Array[Row] = Array(Row(2L)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) + assert( + results.sorted.sameElements(expectedResults.sorted), + s"Expected: 
${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}") + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val countryField = UnresolvedAttribute("country") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(countryField), isDistinct = true), + "distinct_count_approx(country)")() + + val aggregatePlan = + Aggregate(Seq.empty, Seq(aggregateExpressions), table) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + comparePlans(expectedPlan, logicalPlan, false) + } + test("create ppl simple age distinct_count group by country query test with sort") { val frame = sql(s""" | source = $testTable | stats distinct_count(age) by country | sort country @@ -881,6 +918,53 @@ class FlintSparkPPLAggregationsITSuite s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}") } + test( + "create ppl simple age distinct_count group by country query test with sort using approximation") { + val frame = sql(s""" + | source = $testTable | stats distinct_count_approx(age) by country | sort country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(2L, "Canada"), Row(2L, "USA")) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) + assert( + results.sorted.sameElements(expectedResults.sorted), + s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}") + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(ageField), isDistinct = true), + "distinct_count_approx(age)")() + val productAlias = Alias(countryField, "country")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("country"), Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + + // Compare the two plans + assert( + compareByString(expectedPlan) === compareByString(logicalPlan), + s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}") + } + test("create ppl simple age distinct_count group by country with state filter query test") { val frame = sql(s""" | source = $testTable | where state != 'Ontario' | stats distinct_count(age) by country @@ -920,6 +1004,46 @@ class FlintSparkPPLAggregationsITSuite assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } + test( + "create ppl simple age distinct_count group by country with state filter query test using approximation") { + val frame = sql(s""" + | source = $testTable | where state != 'Ontario' | stats distinct_count_approx(age) by country + | """.stripMargin) + + // Retrieve the results + val results: 
Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(1L, "Canada"), Row(2L, "USA")) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val stateField = UnresolvedAttribute("state") + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val filterExpr = Not(EqualTo(stateField, Literal("Ontario"))) + val filterPlan = Filter(filterExpr, table) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(ageField), isDistinct = true), + "distinct_count_approx(age)")() + val productAlias = Alias(countryField, "country")() + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + test("two-level stats") { val frame = sql(s""" | source = $testTable| stats avg(age) as avg_age by state, country | stats avg(avg_age) as avg_state_age by country diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala index f10b6e2f5..4a1633035 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala @@ -84,6 +84,48 @@ class FlintSparkPPLTopAndRareITSuite comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } + test("create ppl rare address field query test with approximation") { + val frame = sql(s""" + | source = $testTable| rare_approx address + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 3) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val addressField = UnresolvedAttribute("address") + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(addressField), isDistinct = false), + "count_address")(), + addressField) + val aggregatePlan = + Aggregate( + Seq(addressField), + aggregateExpressions, + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction( + Seq("APPROX_COUNT_DISTINCT"), + Seq(addressField), + isDistinct = false), + "count_address")(), + Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project(projectList, sortedPlan) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + test("create ppl rare address by age field query test") { val frame = sql(s""" | source = $testTable| rare address by age @@ -132,6 +174,104 @@ class FlintSparkPPLTopAndRareITSuite comparePlans(expectedPlan, logicalPlan, false) } + 
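For orientation between the test cases that follow: the `rare N` / `rare_approx N` variants asserted below keep the same aggregate → sort-ascending → limit plan shape and differ only in the aggregator (COUNT vs. Spark's HyperLogLog++-based APPROX_COUNT_DISTINCT). A minimal standalone sketch of the equivalent Spark SQL, using a hypothetical local session and an in-memory `testData` view that are not part of this patch:

```scala
import org.apache.spark.sql.SparkSession

object RareApproxSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("rare-approx-sketch").getOrCreate()
    import spark.implicits._

    // Hypothetical sample rows standing in for the integration-test table.
    Seq(("Seattle", 30), ("Vancouver", 60), ("Seattle", 25), ("Portland", 30))
      .toDF("address", "age")
      .createOrReplaceTempView("testData")

    // Roughly the shape `... | rare_approx 3 address by age` is expected to translate to:
    // group by (address, age), aggregate with approx_count_distinct, sort ascending, limit 3.
    spark.sql(
      """SELECT approx_count_distinct(address) AS count_address, address, age
        |FROM testData
        |GROUP BY address, age
        |ORDER BY count_address ASC
        |LIMIT 3""".stripMargin).show()

    spark.stop()
  }
}
```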
test("create ppl rare 3 address by age field query test") { + val frame = sql(s""" + | source = $testTable| rare 3 address by age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 3) + + val expectedRow = Row(1, "Vancouver", 60) + assert( + results.head == expectedRow, + s"Expected least frequent result to be $expectedRow, but got ${results.head}") + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val addressField = UnresolvedAttribute("address") + val ageField = UnresolvedAttribute("age") + val ageAlias = Alias(ageField, "age")() + + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + + val countExpr = Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count_address")() + + val aggregateExpressions = Seq(countExpr, addressField, ageAlias) + val aggregatePlan = + Aggregate( + Seq(addressField, ageAlias), + aggregateExpressions, + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) + + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count_address")(), + Ascending)), + global = true, + aggregatePlan) + + val planWithLimit = + GlobalLimit(Literal(3), LocalLimit(Literal(3), sortedPlan)) + val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) + comparePlans(expectedPlan, logicalPlan, false) + } + + test("create ppl rare 3 address by age field query test with approximation") { + val frame = sql(s""" + | source = $testTable| rare_approx 3 address by age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 3) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val addressField = UnresolvedAttribute("address") + val ageField = UnresolvedAttribute("age") + val ageAlias = Alias(ageField, "age")() + + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + + val countExpr = Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(addressField), isDistinct = false), + "count_address")() + + val aggregateExpressions = Seq(countExpr, addressField, ageAlias) + val aggregatePlan = + Aggregate( + Seq(addressField, ageAlias), + aggregateExpressions, + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) + + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction( + Seq("APPROX_COUNT_DISTINCT"), + Seq(addressField), + isDistinct = false), + "count_address")(), + Ascending)), + global = true, + aggregatePlan) + + val planWithLimit = + GlobalLimit(Literal(3), LocalLimit(Literal(3), sortedPlan)) + val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) + comparePlans(expectedPlan, logicalPlan, false) + } + test("create ppl top address field query test") { val frame = sql(s""" | source = $testTable| top address @@ -179,6 +319,48 @@ class FlintSparkPPLTopAndRareITSuite comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } + test("create ppl top address field query test with approximation") { + val frame = sql(s""" + | source = $testTable| top_approx address + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 3) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val 
addressField = UnresolvedAttribute("address") + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(addressField), isDistinct = false), + "count_address")(), + addressField) + val aggregatePlan = + Aggregate( + Seq(addressField), + aggregateExpressions, + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction( + Seq("APPROX_COUNT_DISTINCT"), + Seq(addressField), + isDistinct = false), + "count_address")(), + Descending)), + global = true, + aggregatePlan) + val expectedPlan = Project(projectList, sortedPlan) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + test("create ppl top 3 countries query test") { val frame = sql(s""" | source = $newTestTable| top 3 country @@ -226,6 +408,48 @@ class FlintSparkPPLTopAndRareITSuite comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } + test("create ppl top 3 countries query test with approximation") { + val frame = sql(s""" + | source = $newTestTable| top_approx 3 country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 3) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val countryField = UnresolvedAttribute("country") + val countExpr = Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(countryField), isDistinct = false), + "count_country")() + val aggregateExpressions = Seq(countExpr, countryField) + val aggregatePlan = + Aggregate( + Seq(countryField), + aggregateExpressions, + UnresolvedRelation(Seq("spark_catalog", "default", "new_flint_ppl_test"))) + + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction( + Seq("APPROX_COUNT_DISTINCT"), + Seq(countryField), + isDistinct = false), + "count_country")(), + Descending)), + global = true, + aggregatePlan) + + val planWithLimit = + GlobalLimit(Literal(3), LocalLimit(Literal(3), sortedPlan)) + val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + test("create ppl top 2 countries by occupation field query test") { val frame = sql(s""" | source = $newTestTable| top 3 country by occupation @@ -277,4 +501,50 @@ class FlintSparkPPLTopAndRareITSuite comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } + + test("create ppl top 2 countries by occupation field query test with approximation") { + val frame = sql(s""" + | source = $newTestTable| top_approx 3 country by occupation + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 3) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val countryField = UnresolvedAttribute("country") + val occupationField = UnresolvedAttribute("occupation") + val occupationFieldAlias = Alias(occupationField, "occupation")() + + val countExpr = Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(countryField), isDistinct = false), + "count_country")() + val aggregateExpressions = Seq(countExpr, countryField, occupationFieldAlias) + val aggregatePlan = + Aggregate( + Seq(countryField, occupationFieldAlias), + aggregateExpressions, + UnresolvedRelation(Seq("spark_catalog", "default", "new_flint_ppl_test"))) + + val sortedPlan: LogicalPlan = + Sort( + Seq( 
+ SortOrder( + Alias( + UnresolvedFunction( + Seq("APPROX_COUNT_DISTINCT"), + Seq(countryField), + isDistinct = false), + "count_country")(), + Descending)), + global = true, + aggregatePlan) + + val planWithLimit = + GlobalLimit(Literal(3), LocalLimit(Literal(3), sortedPlan)) + val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + + } } diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index e205d03d2..02818c1fb 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -23,7 +23,9 @@ DEDUP: 'DEDUP'; SORT: 'SORT'; EVAL: 'EVAL'; HEAD: 'HEAD'; +TOP_APPROX: 'TOP_APPROX'; TOP: 'TOP'; +RARE_APPROX: 'RARE_APPROX'; RARE: 'RARE'; PARSE: 'PARSE'; METHOD: 'METHOD'; @@ -216,6 +218,7 @@ BIT_XOR_OP: '^'; AVG: 'AVG'; COUNT: 'COUNT'; DISTINCT_COUNT: 'DISTINCT_COUNT'; +DISTINCT_COUNT_APPROX: 'DISTINCT_COUNT_APPROX'; ESTDC: 'ESTDC'; ESTDC_ERROR: 'ESTDC_ERROR'; MAX: 'MAX'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index d58868ab2..ade568bc7 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -74,7 +74,9 @@ commandName | SORT | HEAD | TOP + | TOP_APPROX | RARE + | RARE_APPROX | EVAL | GROK | PARSE @@ -176,11 +178,11 @@ headCommand ; topCommand - : TOP (number = integerLiteral)? fieldList (byClause)? + : (TOP | TOP_APPROX) (number = integerLiteral)? fieldList (byClause)? ; rareCommand - : RARE fieldList (byClause)? + : (RARE | RARE_APPROX) (number = integerLiteral)? fieldList (byClause)? 
   ;

 grokCommand
@@ -381,7 +383,7 @@ statsAggTerm
 statsFunction
    : statsFunctionName LT_PRTHS valueExpression RT_PRTHS # statsFunctionCall
    | COUNT LT_PRTHS RT_PRTHS # countAllFunctionCall
-   | (DISTINCT_COUNT | DC) LT_PRTHS valueExpression RT_PRTHS # distinctCountFunctionCall
+   | (DISTINCT_COUNT | DC | DISTINCT_COUNT_APPROX) LT_PRTHS valueExpression RT_PRTHS # distinctCountFunctionCall
    | percentileFunctionName = (PERCENTILE | PERCENTILE_APPROX) LT_PRTHS valueExpression COMMA percent = integerLiteral RT_PRTHS # percentileFunctionCall
    ;
@@ -1118,6 +1120,7 @@ keywordsCanBeId
 // AGGREGATIONS
    | statsFunctionName
    | DISTINCT_COUNT
+   | DISTINCT_COUNT_APPROX
    | PERCENTILE
    | PERCENTILE_APPROX
    | ESTDC
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/CountedAggregation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/CountedAggregation.java
new file mode 100644
index 000000000..9a4aa5d7d
--- /dev/null
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/CountedAggregation.java
@@ -0,0 +1,16 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.sql.ast.tree;
+
+import org.opensearch.sql.ast.expression.Literal;
+
+import java.util.Optional;
+
+/**
+ * Marker interface for count-limited aggregations (a specific number of returned results).
+ */
+public interface CountedAggregation {
+    Optional<Literal> getResults();
+}
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/RareAggregation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/RareAggregation.java
index d5a637f3d..8e454685a 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/RareAggregation.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/RareAggregation.java
@@ -6,21 +6,29 @@ package org.opensearch.sql.ast.tree;

 import lombok.EqualsAndHashCode;
+import lombok.Getter;
 import lombok.ToString;
+import org.opensearch.sql.ast.expression.Literal;
 import org.opensearch.sql.ast.expression.UnresolvedExpression;

 import java.util.Collections;
 import java.util.List;
+import java.util.Optional;

 /** Logical plan node of Rare (Aggregation) command, the interface for building aggregation actions in queries. */
 @ToString
+@Getter
 @EqualsAndHashCode(callSuper = true)
-public class RareAggregation extends Aggregation {
+public class RareAggregation extends Aggregation implements CountedAggregation {
+  private final Optional<Literal> results;
+
  /** Aggregation Constructor without span and argument.
*/ diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index 1959d0f6d..f039bf47f 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -185,6 +185,7 @@ public enum BuiltinFunctionName { NESTED(FunctionName.of("nested")), PERCENTILE(FunctionName.of("percentile")), PERCENTILE_APPROX(FunctionName.of("percentile_approx")), + APPROX_COUNT_DISTINCT(FunctionName.of("approx_count_distinct")), /** Text Functions. */ ASCII(FunctionName.of("ascii")), @@ -332,6 +333,7 @@ public FunctionName getName() { .put("take", BuiltinFunctionName.TAKE) .put("percentile", BuiltinFunctionName.PERCENTILE) .put("percentile_approx", BuiltinFunctionName.PERCENTILE_APPROX) + .put("approx_count_distinct", BuiltinFunctionName.APPROX_COUNT_DISTINCT) .build(); public static Optional of(String str) { diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java index 53dc17576..1621e65d5 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java @@ -26,6 +26,7 @@ import java.util.Stack; import java.util.function.BiFunction; import java.util.function.Function; +import java.util.function.UnaryOperator; import java.util.stream.Collectors; import static java.util.Collections.emptyList; @@ -187,7 +188,7 @@ public LogicalPlan reduce(BiFunction tran return result; }).orElse(getPlan())); } - + /** * apply for each plan with the given function * diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index f6581016f..7d1cc072b 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -432,8 +432,9 @@ private Trendline.TrendlineComputation toTrendlineComputation(OpenSearchPPLParse public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) { ImmutableList.Builder aggListBuilder = new ImmutableList.Builder<>(); ImmutableList.Builder groupListBuilder = new ImmutableList.Builder<>(); + String funcName = ctx.TOP_APPROX() != null ? "approx_count_distinct" : "count"; ctx.fieldList().fieldExpression().forEach(field -> { - UnresolvedExpression aggExpression = new AggregateFunction("count",internalVisitExpression(field), + AggregateFunction aggExpression = new AggregateFunction(funcName,internalVisitExpression(field), Collections.singletonList(new Argument("countParam", new Literal(1, DataType.INTEGER)))); String name = field.qualifiedName().getText(); Alias alias = new Alias("count_"+name, aggExpression); @@ -458,14 +459,12 @@ public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) .collect(Collectors.toList())) .orElse(emptyList()) ); - UnresolvedExpression unresolvedPlan = (ctx.number != null ? 
internalVisitExpression(ctx.number) : null); - TopAggregation aggregation = - new TopAggregation( - Optional.ofNullable((Literal) unresolvedPlan), + UnresolvedExpression expectedResults = (ctx.number != null ? internalVisitExpression(ctx.number) : null); + return new TopAggregation( + Optional.ofNullable((Literal) expectedResults), aggListBuilder.build(), aggListBuilder.build(), groupListBuilder.build()); - return aggregation; } /** Fieldsummary command. */ @@ -479,8 +478,9 @@ public UnresolvedPlan visitFieldsummaryCommand(OpenSearchPPLParser.FieldsummaryC public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ctx) { ImmutableList.Builder aggListBuilder = new ImmutableList.Builder<>(); ImmutableList.Builder groupListBuilder = new ImmutableList.Builder<>(); + String funcName = ctx.RARE_APPROX() != null ? "approx_count_distinct" : "count"; ctx.fieldList().fieldExpression().forEach(field -> { - UnresolvedExpression aggExpression = new AggregateFunction("count",internalVisitExpression(field), + AggregateFunction aggExpression = new AggregateFunction(funcName,internalVisitExpression(field), Collections.singletonList(new Argument("countParam", new Literal(1, DataType.INTEGER)))); String name = field.qualifiedName().getText(); Alias alias = new Alias("count_"+name, aggExpression); @@ -505,12 +505,12 @@ public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ct .collect(Collectors.toList())) .orElse(emptyList()) ); - RareAggregation aggregation = - new RareAggregation( + UnresolvedExpression expectedResults = (ctx.number != null ? internalVisitExpression(ctx.number) : null); + return new RareAggregation( + Optional.ofNullable((Literal) expectedResults), aggListBuilder.build(), aggListBuilder.build(), groupListBuilder.build()); - return aggregation; } @Override diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index bf029c49c..089c92b13 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -213,7 +213,8 @@ public UnresolvedExpression visitCountAllFunctionCall(OpenSearchPPLParser.CountA @Override public UnresolvedExpression visitDistinctCountFunctionCall(OpenSearchPPLParser.DistinctCountFunctionCallContext ctx) { - return new AggregateFunction("count", visit(ctx.valueExpression()), true); + String funcName = ctx.DISTINCT_COUNT_APPROX()!=null ? 
"approx_count_distinct" :"count"; + return new AggregateFunction(funcName, visit(ctx.valueExpression()), true); } @Override diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTransformer.java index 9788ac1bc..c06f37aa3 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTransformer.java @@ -57,6 +57,8 @@ static Expression aggregator(org.opensearch.sql.ast.expression.AggregateFunction return new UnresolvedFunction(seq("PERCENTILE"), seq(arg, new Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), distinct, empty(),false); case PERCENTILE_APPROX: return new UnresolvedFunction(seq("PERCENTILE_APPROX"), seq(arg, new Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), distinct, empty(),false); + case APPROX_COUNT_DISTINCT: + return new UnresolvedFunction(seq("APPROX_COUNT_DISTINCT"), seq(arg), distinct, empty(),false); } throw new IllegalStateException("Not Supported value: " + aggregateFunction.getFuncName()); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java index 0b0fb8314..0a4f19b53 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java @@ -26,8 +26,10 @@ import java.util.Map; import java.util.function.Function; +import static org.opensearch.flint.spark.ppl.OpenSearchPPLLexer.DISTINCT_COUNT_APPROX; import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADD; import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADDDATE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.APPROX_COUNT_DISTINCT; import static org.opensearch.sql.expression.function.BuiltinFunctionName.ARRAY_LENGTH; import static org.opensearch.sql.expression.function.BuiltinFunctionName.DATEDIFF; import static org.opensearch.sql.expression.function.BuiltinFunctionName.DATE_ADD; @@ -109,6 +111,7 @@ public interface BuiltinFunctionTransformer { .put(TO_JSON_STRING, "to_json") .put(JSON_KEYS, "json_object_keys") .put(JSON_EXTRACT, "get_json_object") + .put(APPROX_COUNT_DISTINCT, "approx_count_distinct") .build(); /** diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala index 9946bff6a..42cc7ed10 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala @@ -754,6 +754,34 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + test("test approx distinct count product group by brand sorted") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table | stats distinct_count_approx(product) by brand | sort brand"), + context) + val 
star = Seq(UnresolvedStar(None)) + val brandField = UnresolvedAttribute("brand") + val productField = UnresolvedAttribute("product") + val tableRelation = UnresolvedRelation(Seq("table")) + + val groupByAttributes = Seq(Alias(brandField, "brand")()) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(productField), isDistinct = true), + "distinct_count_approx(product)")() + val brandAlias = Alias(brandField, "brand")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, brandAlias), tableRelation) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(brandField, Ascending)), global = true, aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + test("test distinct count product with alias and filter") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( @@ -803,6 +831,34 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + test( + "test distinct count age by span of interval of 10 years query with sort using approximation ") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table | stats distinct_count_approx(age) by span(age, 10) as age_span | sort age"), + context) + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val tableRelation = UnresolvedRelation(Seq("table")) + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(ageField), isDistinct = true), + "distinct_count_approx(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), tableRelation) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + test("test distinct count status by week window and group by status with limit") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( @@ -838,6 +894,42 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + test( + "test distinct count status by week window and group by status with limit using approximation") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table | stats distinct_count_approx(status) by span(@timestamp, 1w) as status_count_by_week, status | head 100"), + context) + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val status = Alias(UnresolvedAttribute("status"), "status")() + val statusCount = UnresolvedAttribute("status") + val table = UnresolvedRelation(Seq("table")) + + val windowExpression = Alias( + TimeWindow( + UnresolvedAttribute("`@timestamp`"), + TimeWindow.parseExpression(Literal("1 week")), + TimeWindow.parseExpression(Literal("1 week")), + 0), + "status_count_by_week")() + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(statusCount), isDistinct = true), + "distinct_count_approx(status)")() + val aggregatePlan = Aggregate( + Seq(status, windowExpression), + Seq(aggregateExpressions, status, windowExpression), + table) + val planWithLimit = GlobalLimit(Literal(100), 
LocalLimit(Literal(100), aggregatePlan)) + val expectedPlan = Project(star, planWithLimit) + // Compare the two plans + comparePlans(expectedPlan, logPlan, false) + } + test("multiple stats - test average price and average age") { val context = new CatalystPlanContext val logPlan = diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala index 792a2dee6..106cba93a 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala @@ -59,6 +59,42 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, checkAnalysis = false) } + test("test simple rare command with a single field approximation") { + // if successful build ppl logical plan and translate to catalyst logical plan + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=accounts | rare_approx address"), context) + val addressField = UnresolvedAttribute("address") + val tableRelation = UnresolvedRelation(Seq("accounts")) + + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(addressField), isDistinct = false), + "count_address")(), + addressField) + + val aggregatePlan = + Aggregate(Seq(addressField), aggregateExpressions, tableRelation) + + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction( + Seq("APPROX_COUNT_DISTINCT"), + Seq(addressField), + isDistinct = false), + "count_address")(), + Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project(projectList, sortedPlan) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + test("test simple rare command with a by field test") { // if successful build ppl logical plan and translate to catalyst logical plan val context = new CatalystPlanContext
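As a closing illustration for this excerpt: on the Spark side, the new `distinct_count_approx` function resolves to the HyperLogLog++-based `approx_count_distinct`, which is also available through the DataFrame API (optionally with a relative-standard-deviation argument). This is an editor's sketch with hypothetical sample data, not code from the patch:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{approx_count_distinct, countDistinct}

object DistinctCountApproxSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("distinct-count-approx").getOrCreate()
    import spark.implicits._

    // Hypothetical stand-in for the flint_ppl_test table used by the integration tests.
    val df = Seq(("Canada", 25), ("Canada", 30), ("USA", 30), ("USA", 70)).toDF("country", "age")

    // Exact distinct count next to the HyperLogLog++ estimate that
    // `stats distinct_count_approx(age) by country` maps to.
    df.groupBy($"country")
      .agg(
        countDistinct($"age").as("distinct_count(age)"),
        approx_count_distinct($"age").as("distinct_count_approx(age)"),
        approx_count_distinct($"age", rsd = 0.01).as("distinct_count_approx_rsd_1pct(age)"))
      .orderBy($"country")
      .show()

    spark.stop()
  }
}
```

The estimate trades a small, configurable relative error for constant memory per group, which is why the PPL surface exposes it as a separate `_approx` family rather than changing the exact `distinct_count`, `top`, and `rare` behavior.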