From b53a6993ed028671d33a6debe36574751b05d9de Mon Sep 17 00:00:00 2001 From: YANGDB Date: Mon, 11 Nov 2024 15:51:24 -0800 Subject: [PATCH] Ppl count approximate support (#884) * add functional approximation support for: - distinct count - top - rare Signed-off-by: YANGDB * update license and scalafmt Signed-off-by: YANGDB * update additional tests using APPROX_COUNT_DISTINCT Signed-off-by: YANGDB * add visitFirstChild(node, context) method for the PlanVisitor for simplify node inner child access visibility Signed-off-by: YANGDB * update inline documentation Signed-off-by: YANGDB * update according to PR comments - DISTINCT_COUNT_APPROX should be added to keywordsCanBeId Signed-off-by: YANGDB --------- Signed-off-by: YANGDB --- docs/ppl-lang/PPL-Example-Commands.md | 5 + docs/ppl-lang/ppl-rare-command.md | 10 +- docs/ppl-lang/ppl-top-command.md | 7 +- ...ntSparkPPLAggregationWithSpanITSuite.scala | 39 +++ .../FlintSparkPPLAggregationsITSuite.scala | 124 ++++++++ .../ppl/FlintSparkPPLTopAndRareITSuite.scala | 270 ++++++++++++++++++ .../src/main/antlr4/OpenSearchPPLLexer.g4 | 3 + .../src/main/antlr4/OpenSearchPPLParser.g4 | 9 +- .../sql/ast/tree/CountedAggregation.java | 16 ++ .../sql/ast/tree/RareAggregation.java | 10 +- .../sql/ast/tree/TopAggregation.java | 2 +- .../function/BuiltinFunctionName.java | 2 + .../sql/ppl/CatalystPlanContext.java | 3 +- .../sql/ppl/CatalystQueryPlanVisitor.java | 68 +++-- .../opensearch/sql/ppl/parser/AstBuilder.java | 20 +- .../sql/ppl/parser/AstExpressionBuilder.java | 3 +- .../sql/ppl/utils/AggregatorTransformer.java | 2 + .../ppl/utils/BuiltinFunctionTransformer.java | 3 + ...ggregationQueriesTranslatorTestSuite.scala | 92 ++++++ ...TopAndRareQueriesTranslatorTestSuite.scala | 36 +++ 20 files changed, 668 insertions(+), 56 deletions(-) create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/CountedAggregation.java diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md 
index 4ea564111..cb50431f6 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -177,6 +177,7 @@ source = table | where ispresent(a) | - `source = table | stats max(c) by b` - `source = table | stats count(c) by b | head 5` - `source = table | stats distinct_count(c)` +- `source = table | stats distinct_count_approx(c)` - `source = table | stats stddev_samp(c)` - `source = table | stats stddev_pop(c)` - `source = table | stats percentile(c, 90)` @@ -202,6 +203,7 @@ source = table | where ispresent(a) | - `source = table | where a < 50 | eventstats avg(c) ` - `source = table | eventstats max(c) by b` - `source = table | eventstats count(c) by b | head 5` +- `source = table | eventstats count(c) by b | head 5` - `source = table | eventstats stddev_samp(c)` - `source = table | eventstats stddev_pop(c)` - `source = table | eventstats percentile(c, 90)` @@ -246,12 +248,15 @@ source = table | where ispresent(a) | - `source=accounts | rare gender` - `source=accounts | rare age by gender` +- `source=accounts | rare 5 age by gender` +- `source=accounts | rare_approx age by gender` #### **Top** [See additional command details](ppl-top-command.md) - `source=accounts | top gender` - `source=accounts | top 1 gender` +- `source=accounts | top_approx 5 gender` - `source=accounts | top 1 age by gender` #### **Parse** diff --git a/docs/ppl-lang/ppl-rare-command.md b/docs/ppl-lang/ppl-rare-command.md index 5645382f8..e3ad21f4e 100644 --- a/docs/ppl-lang/ppl-rare-command.md +++ b/docs/ppl-lang/ppl-rare-command.md @@ -6,10 +6,13 @@ Using ``rare`` command to find the least common tuple of values of all fields in **Note**: A maximum of 10 results is returned for each distinct tuple of values of the group-by fields. **Syntax** -`rare [by-clause]` +`rare [N] [by-clause]` +`rare_approx [N] [by-clause]` +* N: number of results to return. **Default**: 10 * field-list: mandatory. comma-delimited list of field names. * by-clause: optional. 
one or more fields to group the results by. +* rare_approx: same as `rare`, but the per-value count is approximated with Spark's `approx_count_distinct`, which estimates [cardinality by HyperLogLog++ algorithm](https://spark.apache.org/docs/3.5.2/sql-ref-functions-builtin.html). ### Example 1: Find the least common values in a field @@ -19,6 +22,8 @@ The example finds least common gender of all the accounts. PPL query: os> source=accounts | rare gender; + os> source=accounts | rare_approx 10 gender; + os> source=accounts | rare_approx gender; fetched rows / total rows = 2/2 +----------+ | gender | @@ -34,7 +39,8 @@ The example finds least common age of all the accounts group by gender. PPL query: - os> source=accounts | rare age by gender; + os> source=accounts | rare 5 age by gender; + os> source=accounts | rare_approx 5 age by gender; fetched rows / total rows = 4/4 +----------+-------+ | gender | age | diff --git a/docs/ppl-lang/ppl-top-command.md b/docs/ppl-lang/ppl-top-command.md index 4ba56f692..93d3a7148 100644 --- a/docs/ppl-lang/ppl-top-command.md +++ b/docs/ppl-lang/ppl-top-command.md @@ -6,11 +6,12 @@ Using ``top`` command to find the most common tuple of values of all fields in t ### Syntax `top [N] <field-list> [by-clause]` +`top_approx [N] <field-list> [by-clause]` * N: number of results to return. **Default**: 10 * field-list: mandatory. comma-delimited list of field names. * by-clause: optional. one or more fields to group the results by. - +* top_approx: same as `top`, but the per-value count is approximated with Spark's `approx_count_distinct`, which estimates [cardinality by HyperLogLog++ algorithm](https://spark.apache.org/docs/3.5.2/sql-ref-functions-builtin.html). ### Example 1: Find the most common values in a field @@ -19,6 +20,7 @@ The example finds most common gender of all the accounts. PPL query: os> source=accounts | top gender; + os> source=accounts | top_approx gender; fetched rows / total rows = 2/2 +----------+ | gender | @@ -33,7 +35,7 @@ The example finds most common gender of all the accounts. 
PPL query: - os> source=accounts | top 1 gender; + os> source=accounts | top_approx 1 gender; fetched rows / total rows = 1/1 +----------+ | gender | @@ -48,6 +50,7 @@ The example finds most common age of all the accounts group by gender. PPL query: os> source=accounts | top 1 age by gender; + os> source=accounts | top_approx 1 age by gender; fetched rows / total rows = 2/2 +----------+-------+ | gender | age | diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala index 0bebca9b0..aa96d0991 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala @@ -494,4 +494,43 @@ class FlintSparkPPLAggregationWithSpanITSuite // Compare the two plans comparePlans(expectedPlan, logicalPlan, false) } + + test( + "create ppl simple distinct count age by span of interval of 10 years query with state filter test using approximation") { + val frame = sql(s""" + | source = $testTable | where state != 'Quebec' | stats distinct_count_approx(age) by span(age, 10) as age_span + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(1, 70L), Row(1, 30L), Row(1, 20L)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](1)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val stateField = UnresolvedAttribute("state") + val table = 
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(ageField), isDistinct = true), + "distinct_count_approx(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val filterExpr = Not(EqualTo(stateField, Literal("Quebec"))) + val filterPlan = Filter(filterExpr, table) + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), filterPlan) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + comparePlans(expectedPlan, logicalPlan, false) + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala index bcfe22764..2275c775c 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala @@ -835,6 +835,43 @@ class FlintSparkPPLAggregationsITSuite comparePlans(expectedPlan, logicalPlan, false) } + test("create ppl simple country distinct_count using approximation ") { + val frame = sql(s""" + | source = $testTable| stats distinct_count_approx(country) + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + + // Define the expected results + val expectedResults: Array[Row] = Array(Row(2L)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) + assert( + results.sorted.sameElements(expectedResults.sorted), + s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}") + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + 
val star = Seq(UnresolvedStar(None)) + val countryField = UnresolvedAttribute("country") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(countryField), isDistinct = true), + "distinct_count_approx(country)")() + + val aggregatePlan = + Aggregate(Seq.empty, Seq(aggregateExpressions), table) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + comparePlans(expectedPlan, logicalPlan, false) + } + test("create ppl simple age distinct_count group by country query test with sort") { val frame = sql(s""" | source = $testTable | stats distinct_count(age) by country | sort country @@ -881,6 +918,53 @@ class FlintSparkPPLAggregationsITSuite s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}") } + test( + "create ppl simple age distinct_count group by country query test with sort using approximation") { + val frame = sql(s""" + | source = $testTable | stats distinct_count_approx(age) by country | sort country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(2L, "Canada"), Row(2L, "USA")) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) + assert( + results.sorted.sameElements(expectedResults.sorted), + s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}") + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) 
+ val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(ageField), isDistinct = true), + "distinct_count_approx(age)")() + val productAlias = Alias(countryField, "country")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("country"), Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + + // Compare the two plans + assert( + compareByString(expectedPlan) === compareByString(logicalPlan), + s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}") + } + test("create ppl simple age distinct_count group by country with state filter query test") { val frame = sql(s""" | source = $testTable | where state != 'Ontario' | stats distinct_count(age) by country @@ -920,6 +1004,46 @@ class FlintSparkPPLAggregationsITSuite assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } + test( + "create ppl simple age distinct_count group by country with state filter query test using approximation") { + val frame = sql(s""" + | source = $testTable | where state != 'Ontario' | stats distinct_count_approx(age) by country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(1L, "Canada"), Row(2L, "USA")) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val stateField = UnresolvedAttribute("state") + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val table = 
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val filterExpr = Not(EqualTo(stateField, Literal("Ontario"))) + val filterPlan = Filter(filterExpr, table) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(ageField), isDistinct = true), + "distinct_count_approx(age)")() + val productAlias = Alias(countryField, "country")() + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + test("two-level stats") { val frame = sql(s""" | source = $testTable| stats avg(age) as avg_age by state, country | stats avg(avg_age) as avg_state_age by country diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala index f10b6e2f5..4a1633035 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala @@ -84,6 +84,48 @@ class FlintSparkPPLTopAndRareITSuite comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } + test("create ppl rare address field query test with approximation") { + val frame = sql(s""" + | source = $testTable| rare_approx address + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 3) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val addressField = UnresolvedAttribute("address") + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + + val 
aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(addressField), isDistinct = false), + "count_address")(), + addressField) + val aggregatePlan = + Aggregate( + Seq(addressField), + aggregateExpressions, + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction( + Seq("APPROX_COUNT_DISTINCT"), + Seq(addressField), + isDistinct = false), + "count_address")(), + Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project(projectList, sortedPlan) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + test("create ppl rare address by age field query test") { val frame = sql(s""" | source = $testTable| rare address by age @@ -132,6 +174,104 @@ class FlintSparkPPLTopAndRareITSuite comparePlans(expectedPlan, logicalPlan, false) } + test("create ppl rare 3 address by age field query test") { + val frame = sql(s""" + | source = $testTable| rare 3 address by age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 3) + + val expectedRow = Row(1, "Vancouver", 60) + assert( + results.head == expectedRow, + s"Expected least frequent result to be $expectedRow, but got ${results.head}") + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val addressField = UnresolvedAttribute("address") + val ageField = UnresolvedAttribute("age") + val ageAlias = Alias(ageField, "age")() + + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + + val countExpr = Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count_address")() + + val aggregateExpressions = Seq(countExpr, addressField, ageAlias) + val aggregatePlan = + Aggregate( + Seq(addressField, ageAlias), + aggregateExpressions, + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) + + 
val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count_address")(), + Ascending)), + global = true, + aggregatePlan) + + val planWithLimit = + GlobalLimit(Literal(3), LocalLimit(Literal(3), sortedPlan)) + val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) + comparePlans(expectedPlan, logicalPlan, false) + } + + test("create ppl rare 3 address by age field query test with approximation") { + val frame = sql(s""" + | source = $testTable| rare_approx 3 address by age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 3) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val addressField = UnresolvedAttribute("address") + val ageField = UnresolvedAttribute("age") + val ageAlias = Alias(ageField, "age")() + + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + + val countExpr = Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(addressField), isDistinct = false), + "count_address")() + + val aggregateExpressions = Seq(countExpr, addressField, ageAlias) + val aggregatePlan = + Aggregate( + Seq(addressField, ageAlias), + aggregateExpressions, + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) + + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction( + Seq("APPROX_COUNT_DISTINCT"), + Seq(addressField), + isDistinct = false), + "count_address")(), + Ascending)), + global = true, + aggregatePlan) + + val planWithLimit = + GlobalLimit(Literal(3), LocalLimit(Literal(3), sortedPlan)) + val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) + comparePlans(expectedPlan, logicalPlan, false) + } + test("create ppl top address field query test") { val frame = sql(s""" | source = $testTable| top address @@ -179,6 +319,48 @@ class FlintSparkPPLTopAndRareITSuite 
comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } + test("create ppl top address field query test with approximation") { + val frame = sql(s""" + | source = $testTable| top_approx address + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 3) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val addressField = UnresolvedAttribute("address") + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(addressField), isDistinct = false), + "count_address")(), + addressField) + val aggregatePlan = + Aggregate( + Seq(addressField), + aggregateExpressions, + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction( + Seq("APPROX_COUNT_DISTINCT"), + Seq(addressField), + isDistinct = false), + "count_address")(), + Descending)), + global = true, + aggregatePlan) + val expectedPlan = Project(projectList, sortedPlan) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + test("create ppl top 3 countries query test") { val frame = sql(s""" | source = $newTestTable| top 3 country @@ -226,6 +408,48 @@ class FlintSparkPPLTopAndRareITSuite comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } + test("create ppl top 3 countries query test with approximation") { + val frame = sql(s""" + | source = $newTestTable| top_approx 3 country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 3) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val countryField = UnresolvedAttribute("country") + val countExpr = Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), 
Seq(countryField), isDistinct = false), + "count_country")() + val aggregateExpressions = Seq(countExpr, countryField) + val aggregatePlan = + Aggregate( + Seq(countryField), + aggregateExpressions, + UnresolvedRelation(Seq("spark_catalog", "default", "new_flint_ppl_test"))) + + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction( + Seq("APPROX_COUNT_DISTINCT"), + Seq(countryField), + isDistinct = false), + "count_country")(), + Descending)), + global = true, + aggregatePlan) + + val planWithLimit = + GlobalLimit(Literal(3), LocalLimit(Literal(3), sortedPlan)) + val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + test("create ppl top 2 countries by occupation field query test") { val frame = sql(s""" | source = $newTestTable| top 3 country by occupation @@ -277,4 +501,50 @@ class FlintSparkPPLTopAndRareITSuite comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } + + test("create ppl top 2 countries by occupation field query test with approximation") { + val frame = sql(s""" + | source = $newTestTable| top_approx 3 country by occupation + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 3) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val countryField = UnresolvedAttribute("country") + val occupationField = UnresolvedAttribute("occupation") + val occupationFieldAlias = Alias(occupationField, "occupation")() + + val countExpr = Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(countryField), isDistinct = false), + "count_country")() + val aggregateExpressions = Seq(countExpr, countryField, occupationFieldAlias) + val aggregatePlan = + Aggregate( + Seq(countryField, occupationFieldAlias), + aggregateExpressions, + UnresolvedRelation(Seq("spark_catalog", "default", "new_flint_ppl_test"))) + + val 
sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction( + Seq("APPROX_COUNT_DISTINCT"), + Seq(countryField), + isDistinct = false), + "count_country")(), + Descending)), + global = true, + aggregatePlan) + + val planWithLimit = + GlobalLimit(Literal(3), LocalLimit(Literal(3), sortedPlan)) + val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + + } } diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index 2c3344b3c..10b2e01b8 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -23,7 +23,9 @@ DEDUP: 'DEDUP'; SORT: 'SORT'; EVAL: 'EVAL'; HEAD: 'HEAD'; +TOP_APPROX: 'TOP_APPROX'; TOP: 'TOP'; +RARE_APPROX: 'RARE_APPROX'; RARE: 'RARE'; PARSE: 'PARSE'; METHOD: 'METHOD'; @@ -216,6 +218,7 @@ BIT_XOR_OP: '^'; AVG: 'AVG'; COUNT: 'COUNT'; DISTINCT_COUNT: 'DISTINCT_COUNT'; +DISTINCT_COUNT_APPROX: 'DISTINCT_COUNT_APPROX'; ESTDC: 'ESTDC'; ESTDC_ERROR: 'ESTDC_ERROR'; MAX: 'MAX'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 1cfd172f7..63efd8c6c 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -76,7 +76,9 @@ commandName | SORT | HEAD | TOP + | TOP_APPROX | RARE + | RARE_APPROX | EVAL | GROK | PARSE @@ -180,11 +182,11 @@ headCommand ; topCommand - : TOP (number = integerLiteral)? fieldList (byClause)? + : (TOP | TOP_APPROX) (number = integerLiteral)? fieldList (byClause)? ; rareCommand - : RARE fieldList (byClause)? + : (RARE | RARE_APPROX) (number = integerLiteral)? fieldList (byClause)? 
; grokCommand @@ -400,7 +402,7 @@ statsAggTerm statsFunction : statsFunctionName LT_PRTHS valueExpression RT_PRTHS # statsFunctionCall | COUNT LT_PRTHS RT_PRTHS # countAllFunctionCall - | (DISTINCT_COUNT | DC) LT_PRTHS valueExpression RT_PRTHS # distinctCountFunctionCall + | (DISTINCT_COUNT | DC | DISTINCT_COUNT_APPROX) LT_PRTHS valueExpression RT_PRTHS # distinctCountFunctionCall | percentileFunctionName = (PERCENTILE | PERCENTILE_APPROX) LT_PRTHS valueExpression COMMA percent = integerLiteral RT_PRTHS # percentileFunctionCall ; @@ -1122,6 +1124,7 @@ keywordsCanBeId // AGGREGATIONS | statsFunctionName | DISTINCT_COUNT + | DISTINCT_COUNT_APPROX | PERCENTILE | PERCENTILE_APPROX | ESTDC diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/CountedAggregation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/CountedAggregation.java new file mode 100644 index 000000000..9a4aa5d7d --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/CountedAggregation.java @@ -0,0 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.sql.ast.tree; + +import org.opensearch.sql.ast.expression.Literal; + +import java.util.Optional; + +/** + * marker interface for numeric based count aggregation (specific number of returned results) + */ +public interface CountedAggregation { + Optional getResults(); +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/RareAggregation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/RareAggregation.java index d5a637f3d..8e454685a 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/RareAggregation.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/RareAggregation.java @@ -6,21 +6,29 @@ package org.opensearch.sql.ast.tree; import lombok.EqualsAndHashCode; +import lombok.Getter; import lombok.ToString; +import 
org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.UnresolvedExpression; import java.util.Collections; import java.util.List; +import java.util.Optional; /** Logical plan node of Rare (Aggregation) command, the interface for building aggregation actions in queries. */ @ToString +@Getter @EqualsAndHashCode(callSuper = true) -public class RareAggregation extends Aggregation { +public class RareAggregation extends Aggregation implements CountedAggregation{ + private final Optional results; + /** Aggregation Constructor without span and argument. */ public RareAggregation( + Optional results, List aggExprList, List sortExprList, List groupExprList) { super(aggExprList, sortExprList, groupExprList, null, Collections.emptyList()); + this.results = results; } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/TopAggregation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/TopAggregation.java index e87a3b0b0..90aac5838 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/TopAggregation.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/TopAggregation.java @@ -20,7 +20,7 @@ @ToString @Getter @EqualsAndHashCode(callSuper = true) -public class TopAggregation extends Aggregation { +public class TopAggregation extends Aggregation implements CountedAggregation { private final Optional results; /** Aggregation Constructor without span and argument. 
*/ diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index 1959d0f6d..f039bf47f 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -185,6 +185,7 @@ public enum BuiltinFunctionName { NESTED(FunctionName.of("nested")), PERCENTILE(FunctionName.of("percentile")), PERCENTILE_APPROX(FunctionName.of("percentile_approx")), + APPROX_COUNT_DISTINCT(FunctionName.of("approx_count_distinct")), /** Text Functions. */ ASCII(FunctionName.of("ascii")), @@ -332,6 +333,7 @@ public FunctionName getName() { .put("take", BuiltinFunctionName.TAKE) .put("percentile", BuiltinFunctionName.PERCENTILE) .put("percentile_approx", BuiltinFunctionName.PERCENTILE_APPROX) + .put("approx_count_distinct", BuiltinFunctionName.APPROX_COUNT_DISTINCT) .build(); public static Optional of(String str) { diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java index 53dc17576..1621e65d5 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java @@ -26,6 +26,7 @@ import java.util.Stack; import java.util.function.BiFunction; import java.util.function.Function; +import java.util.function.UnaryOperator; import java.util.stream.Collectors; import static java.util.Collections.emptyList; @@ -187,7 +188,7 @@ public LogicalPlan reduce(BiFunction tran return result; }).orElse(getPlan())); } - + /** * apply for each plan with the given function * diff --git 
a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index d2ee46ae6..00a7905f0 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -14,13 +14,6 @@ import org.apache.spark.sql.catalyst.expressions.Explode; import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.expressions.GeneratorOuter; -import org.apache.spark.sql.catalyst.expressions.In$; -import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual; -import org.apache.spark.sql.catalyst.expressions.InSubquery$; -import org.apache.spark.sql.catalyst.expressions.LessThan; -import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual; -import org.apache.spark.sql.catalyst.expressions.ListQuery$; -import org.apache.spark.sql.catalyst.expressions.MakeInterval$; import org.apache.spark.sql.catalyst.expressions.NamedExpression; import org.apache.spark.sql.catalyst.expressions.SortDirection; import org.apache.spark.sql.catalyst.expressions.SortOrder; @@ -38,6 +31,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap; import org.opensearch.flint.spark.FlattenGenerator; import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; import org.opensearch.sql.ast.expression.Alias; import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.Field; @@ -53,6 +47,7 @@ import org.opensearch.sql.ast.statement.Statement; import org.opensearch.sql.ast.tree.Aggregation; import org.opensearch.sql.ast.tree.Correlation; +import org.opensearch.sql.ast.tree.CountedAggregation; import org.opensearch.sql.ast.tree.Dedupe; import org.opensearch.sql.ast.tree.DescribeRelation; import org.opensearch.sql.ast.tree.Eval; @@ -72,7 +67,6 @@ import 
org.opensearch.sql.ast.tree.Rename; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.ast.tree.SubqueryAlias; -import org.opensearch.sql.ast.tree.TopAggregation; import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.ast.tree.Window; import org.opensearch.sql.common.antlr.SyntaxCheckException; @@ -90,6 +84,7 @@ import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.function.BiConsumer; import java.util.stream.Collectors; import static java.util.Collections.emptyList; @@ -132,6 +127,10 @@ public LogicalPlan visitQuery(Query node, CatalystPlanContext context) { return node.getPlan().accept(this, context); } + public LogicalPlan visitFirstChild(Node node, CatalystPlanContext context) { + return node.getChild().get(0).accept(this, context); + } + @Override public LogicalPlan visitExplain(Explain node, CatalystPlanContext context) { node.getStatement().accept(this, context); @@ -140,6 +139,7 @@ public LogicalPlan visitExplain(Explain node, CatalystPlanContext context) { @Override public LogicalPlan visitRelation(Relation node, CatalystPlanContext context) { + // relations don't have a visitFirstChild call since they are the leaves of the AST tree if (node instanceof DescribeRelation) { TableIdentifier identifier = getTableIdentifier(node.getTableQualifiedName()); return context.with( @@ -159,7 +159,7 @@ public LogicalPlan visitRelation(Relation node, CatalystPlanContext context) { @Override public LogicalPlan visitFilter(Filter node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + visitFirstChild(node, context); return context.apply(p -> { Expression conditionExpression = visitExpression(node.getCondition(), context); Optional innerConditionExpression = context.popNamedParseExpressions(); @@ -173,8 +173,7 @@ public LogicalPlan visitFilter(Filter node, CatalystPlanContext context) { */ @Override public LogicalPlan visitLookup(Lookup node, CatalystPlanContext context) { -
node.getChild().get(0).accept(this, context); - + visitFirstChild(node, context); return context.apply( searchSide -> { LogicalPlan lookupTable = node.getLookupRelation().accept(this, context); Expression lookupCondition = buildLookupMappingCondition(node, expressionAnalyzer, context); @@ -230,8 +229,7 @@ public LogicalPlan visitLookup(Lookup node, CatalystPlanContext context) { @Override public LogicalPlan visitTrendline(Trendline node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); - + visitFirstChild(node, context); node.getSortByField() .ifPresent(sortField -> { Expression sortFieldExpression = visitExpression(sortField, context); @@ -254,7 +252,7 @@ public LogicalPlan visitTrendline(Trendline node, CatalystPlanContext context) { @Override public LogicalPlan visitCorrelation(Correlation node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + visitFirstChild(node, context); context.reduce((left, right) -> { visitFieldList(node.getFieldsList().stream().map(Field::new).collect(Collectors.toList()), context); Seq fields = context.retainAllNamedParseExpressions(e -> e); @@ -272,7 +270,7 @@ public LogicalPlan visitCorrelation(Correlation node, CatalystPlanContext contex @Override public LogicalPlan visitJoin(Join node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + visitFirstChild(node, context); return context.apply(left -> { LogicalPlan right = node.getRight().accept(this, context); Optional joinCondition = node.getJoinCondition() @@ -285,7 +283,7 @@ public LogicalPlan visitJoin(Join node, CatalystPlanContext context) { @Override public LogicalPlan visitSubqueryAlias(SubqueryAlias node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + visitFirstChild(node, context); return context.apply(p -> { var alias = org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias$.MODULE$.apply(node.getAlias(), p); context.withSubqueryAlias(alias); @@ 
-296,7 +294,7 @@ public LogicalPlan visitSubqueryAlias(SubqueryAlias node, CatalystPlanContext co @Override public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + visitFirstChild(node, context); List aggsExpList = visitExpressionList(node.getAggExprList(), context); List groupExpList = visitExpressionList(node.getGroupExprList(), context); if (!groupExpList.isEmpty()) { @@ -327,9 +325,9 @@ public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext contex context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Sort(sortElements, true, logicalPlan)); } //visit TopAggregation results limit - if ((node instanceof TopAggregation) && ((TopAggregation) node).getResults().isPresent()) { + if ((node instanceof CountedAggregation) && ((CountedAggregation) node).getResults().isPresent()) { context.apply(p -> (LogicalPlan) Limit.apply(new org.apache.spark.sql.catalyst.expressions.Literal( - ((TopAggregation) node).getResults().get().getValue(), org.apache.spark.sql.types.DataTypes.IntegerType), p)); + ((CountedAggregation) node).getResults().get().getValue(), org.apache.spark.sql.types.DataTypes.IntegerType), p)); } return logicalPlan; } @@ -342,7 +340,7 @@ private static LogicalPlan extractedAggregation(CatalystPlanContext context) { @Override public LogicalPlan visitWindow(Window node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + visitFirstChild(node, context); List windowFunctionExpList = visitExpressionList(node.getWindowFunctionList(), context); Seq windowFunctionExpressions = context.retainAllNamedParseExpressions(p -> p); List partitionExpList = visitExpressionList(node.getPartExprList(), context); @@ -372,10 +370,11 @@ public LogicalPlan visitAlias(Alias node, CatalystPlanContext context) { @Override public LogicalPlan visitProject(Project node, CatalystPlanContext context) { + //update plan's context prior to visiting node 
children if (node.isExcluded()) { List intersect = context.getProjectedFields().stream() - .filter(node.getProjectList()::contains) - .collect(Collectors.toList()); + .filter(node.getProjectList()::contains) + .collect(Collectors.toList()); if (!intersect.isEmpty()) { // Fields in parent projection, but they have be excluded in child. For example, // source=t | fields - A, B | fields A, B, C will throw "[Field A, Field B] can't be resolved" @@ -384,7 +383,7 @@ public LogicalPlan visitProject(Project node, CatalystPlanContext context) { } else { context.withProjectedFields(node.getProjectList()); } - LogicalPlan child = node.getChild().get(0).accept(this, context); + LogicalPlan child = visitFirstChild(node, context); visitExpressionList(node.getProjectList(), context); // Create a projection list from the existing expressions @@ -405,7 +404,7 @@ public LogicalPlan visitProject(Project node, CatalystPlanContext context) { @Override public LogicalPlan visitSort(Sort node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + visitFirstChild(node, context); visitFieldList(node.getSortList(), context); Seq sortElements = context.retainAllNamedParseExpressions(exp -> SortUtils.getSortDirection(node, (NamedExpression) exp)); return context.apply(p -> (LogicalPlan) new org.apache.spark.sql.catalyst.plans.logical.Sort(sortElements, true, p)); @@ -413,20 +412,20 @@ public LogicalPlan visitSort(Sort node, CatalystPlanContext context) { @Override public LogicalPlan visitHead(Head node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + visitFirstChild(node, context); return context.apply(p -> (LogicalPlan) Limit.apply(new org.apache.spark.sql.catalyst.expressions.Literal( node.getSize(), DataTypes.IntegerType), p)); } @Override public LogicalPlan visitFieldSummary(FieldSummary fieldSummary, CatalystPlanContext context) { - fieldSummary.getChild().get(0).accept(this, context); + visitFirstChild(fieldSummary, context); 
return FieldSummaryTransformer.translate(fieldSummary, context); } @Override public LogicalPlan visitFillNull(FillNull fillNull, CatalystPlanContext context) { - fillNull.getChild().get(0).accept(this, context); + visitFirstChild(fillNull, context); List aliases = new ArrayList<>(); for(FillNull.NullableFieldFill nullableFieldFill : fillNull.getNullableFieldFills()) { Field field = nullableFieldFill.getNullableFieldReference(); @@ -457,7 +456,7 @@ public LogicalPlan visitFillNull(FillNull fillNull, CatalystPlanContext context) @Override public LogicalPlan visitFlatten(Flatten flatten, CatalystPlanContext context) { - flatten.getChild().get(0).accept(this, context); + visitFirstChild(flatten, context); if (context.getNamedParseExpressions().isEmpty()) { // Create an UnresolvedStar for all-fields projection context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); @@ -471,7 +470,7 @@ public LogicalPlan visitFlatten(Flatten flatten, CatalystPlanContext context) { @Override public LogicalPlan visitExpand(org.opensearch.sql.ast.tree.Expand node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + visitFirstChild(node, context); if (context.getNamedParseExpressions().isEmpty()) { // Create an UnresolvedStar for all-fields projection context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); @@ -507,7 +506,7 @@ private Expression visitExpression(UnresolvedExpression expression, CatalystPlan @Override public LogicalPlan visitParse(Parse node, CatalystPlanContext context) { - LogicalPlan child = node.getChild().get(0).accept(this, context); + visitFirstChild(node, context); Expression sourceField = visitExpression(node.getSourceField(), context); ParseMethod parseMethod = node.getParseMethod(); java.util.Map arguments = node.getArguments(); @@ -517,7 +516,7 @@ public LogicalPlan visitParse(Parse node, CatalystPlanContext context) { @Override public LogicalPlan visitRename(Rename node, 
CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + visitFirstChild(node, context); if (context.getNamedParseExpressions().isEmpty()) { // Create an UnresolvedStar for all-fields projection context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.empty())); @@ -534,7 +533,7 @@ public LogicalPlan visitRename(Rename node, CatalystPlanContext context) { @Override public LogicalPlan visitEval(Eval node, CatalystPlanContext context) { - LogicalPlan child = node.getChild().get(0).accept(this, context); + visitFirstChild(node, context); List aliases = new ArrayList<>(); List letExpressions = node.getExpressionList(); for (Let let : letExpressions) { @@ -548,8 +547,7 @@ public LogicalPlan visitEval(Eval node, CatalystPlanContext context) { List expressionList = visitExpressionList(aliases, context); Seq projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p); // build the plan with the projection step - child = context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p)); - return child; + return context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p)); } @Override @@ -574,7 +572,7 @@ public LogicalPlan visitWindowFunction(WindowFunction node, CatalystPlanContext @Override public LogicalPlan visitDedupe(Dedupe node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + visitFirstChild(node, context); List options = node.getOptions(); Integer allowedDuplication = (Integer) options.get(0).getValue().getValue(); Boolean keepEmpty = (Boolean) options.get(1).getValue().getValue(); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index f6581016f..7d1cc072b 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ 
b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -432,8 +432,9 @@ private Trendline.TrendlineComputation toTrendlineComputation(OpenSearchPPLParse public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) { ImmutableList.Builder aggListBuilder = new ImmutableList.Builder<>(); ImmutableList.Builder groupListBuilder = new ImmutableList.Builder<>(); + String funcName = ctx.TOP_APPROX() != null ? "approx_count_distinct" : "count"; ctx.fieldList().fieldExpression().forEach(field -> { - UnresolvedExpression aggExpression = new AggregateFunction("count",internalVisitExpression(field), + AggregateFunction aggExpression = new AggregateFunction(funcName,internalVisitExpression(field), Collections.singletonList(new Argument("countParam", new Literal(1, DataType.INTEGER)))); String name = field.qualifiedName().getText(); Alias alias = new Alias("count_"+name, aggExpression); @@ -458,14 +459,12 @@ public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) .collect(Collectors.toList())) .orElse(emptyList()) ); - UnresolvedExpression unresolvedPlan = (ctx.number != null ? internalVisitExpression(ctx.number) : null); - TopAggregation aggregation = - new TopAggregation( - Optional.ofNullable((Literal) unresolvedPlan), + UnresolvedExpression expectedResults = (ctx.number != null ? internalVisitExpression(ctx.number) : null); + return new TopAggregation( + Optional.ofNullable((Literal) expectedResults), aggListBuilder.build(), aggListBuilder.build(), groupListBuilder.build()); - return aggregation; } /** Fieldsummary command. */ @@ -479,8 +478,9 @@ public UnresolvedPlan visitFieldsummaryCommand(OpenSearchPPLParser.FieldsummaryC public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ctx) { ImmutableList.Builder aggListBuilder = new ImmutableList.Builder<>(); ImmutableList.Builder groupListBuilder = new ImmutableList.Builder<>(); + String funcName = ctx.RARE_APPROX() != null ? 
"approx_count_distinct" : "count"; ctx.fieldList().fieldExpression().forEach(field -> { - UnresolvedExpression aggExpression = new AggregateFunction("count",internalVisitExpression(field), + AggregateFunction aggExpression = new AggregateFunction(funcName,internalVisitExpression(field), Collections.singletonList(new Argument("countParam", new Literal(1, DataType.INTEGER)))); String name = field.qualifiedName().getText(); Alias alias = new Alias("count_"+name, aggExpression); @@ -505,12 +505,12 @@ public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ct .collect(Collectors.toList())) .orElse(emptyList()) ); - RareAggregation aggregation = - new RareAggregation( + UnresolvedExpression expectedResults = (ctx.number != null ? internalVisitExpression(ctx.number) : null); + return new RareAggregation( + Optional.ofNullable((Literal) expectedResults), aggListBuilder.build(), aggListBuilder.build(), groupListBuilder.build()); - return aggregation; } @Override diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 4b7c8a1c1..36d9f9577 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -211,7 +211,8 @@ public UnresolvedExpression visitCountAllFunctionCall(OpenSearchPPLParser.CountA @Override public UnresolvedExpression visitDistinctCountFunctionCall(OpenSearchPPLParser.DistinctCountFunctionCallContext ctx) { - return new AggregateFunction("count", visit(ctx.valueExpression()), true); + String funcName = ctx.DISTINCT_COUNT_APPROX()!=null ? 
"approx_count_distinct" :"count"; + return new AggregateFunction(funcName, visit(ctx.valueExpression()), true); } @Override diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTransformer.java index 9788ac1bc..c06f37aa3 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTransformer.java @@ -57,6 +57,8 @@ static Expression aggregator(org.opensearch.sql.ast.expression.AggregateFunction return new UnresolvedFunction(seq("PERCENTILE"), seq(arg, new Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), distinct, empty(),false); case PERCENTILE_APPROX: return new UnresolvedFunction(seq("PERCENTILE_APPROX"), seq(arg, new Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), distinct, empty(),false); + case APPROX_COUNT_DISTINCT: + return new UnresolvedFunction(seq("APPROX_COUNT_DISTINCT"), seq(arg), distinct, empty(),false); } throw new IllegalStateException("Not Supported value: " + aggregateFunction.getFuncName()); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java index 0b0fb8314..0a4f19b53 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java @@ -26,8 +26,10 @@ import java.util.Map; import java.util.function.Function; +import static org.opensearch.flint.spark.ppl.OpenSearchPPLLexer.DISTINCT_COUNT_APPROX; import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADD; import static 
org.opensearch.sql.expression.function.BuiltinFunctionName.ADDDATE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.APPROX_COUNT_DISTINCT; import static org.opensearch.sql.expression.function.BuiltinFunctionName.ARRAY_LENGTH; import static org.opensearch.sql.expression.function.BuiltinFunctionName.DATEDIFF; import static org.opensearch.sql.expression.function.BuiltinFunctionName.DATE_ADD; @@ -109,6 +111,7 @@ public interface BuiltinFunctionTransformer { .put(TO_JSON_STRING, "to_json") .put(JSON_KEYS, "json_object_keys") .put(JSON_EXTRACT, "get_json_object") + .put(APPROX_COUNT_DISTINCT, "approx_count_distinct") .build(); /** diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala index 9946bff6a..42cc7ed10 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala @@ -754,6 +754,34 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + test("test approx distinct count product group by brand sorted") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table | stats distinct_count_approx(product) by brand | sort brand"), + context) + val star = Seq(UnresolvedStar(None)) + val brandField = UnresolvedAttribute("brand") + val productField = UnresolvedAttribute("product") + val tableRelation = UnresolvedRelation(Seq("table")) + + val groupByAttributes = Seq(Alias(brandField, "brand")()) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(productField), isDistinct = true), + 
"distinct_count_approx(product)")() + val brandAlias = Alias(brandField, "brand")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, brandAlias), tableRelation) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(brandField, Ascending)), global = true, aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + test("test distinct count product with alias and filter") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( @@ -803,6 +831,34 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + test( + "test distinct count age by span of interval of 10 years query with sort using approximation ") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table | stats distinct_count_approx(age) by span(age, 10) as age_span | sort age"), + context) + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val tableRelation = UnresolvedRelation(Seq("table")) + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(ageField), isDistinct = true), + "distinct_count_approx(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), tableRelation) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + test("test distinct count status by week window and group by status with limit") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( @@ -838,6 +894,42 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite 
comparePlans(expectedPlan, logPlan, false) } + test( + "test distinct count status by week window and group by status with limit using approximation") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table | stats distinct_count_approx(status) by span(@timestamp, 1w) as status_count_by_week, status | head 100"), + context) + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val status = Alias(UnresolvedAttribute("status"), "status")() + val statusCount = UnresolvedAttribute("status") + val table = UnresolvedRelation(Seq("table")) + + val windowExpression = Alias( + TimeWindow( + UnresolvedAttribute("`@timestamp`"), + TimeWindow.parseExpression(Literal("1 week")), + TimeWindow.parseExpression(Literal("1 week")), + 0), + "status_count_by_week")() + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(statusCount), isDistinct = true), + "distinct_count_approx(status)")() + val aggregatePlan = Aggregate( + Seq(status, windowExpression), + Seq(aggregateExpressions, status, windowExpression), + table) + val planWithLimit = GlobalLimit(Literal(100), LocalLimit(Literal(100), aggregatePlan)) + val expectedPlan = Project(star, planWithLimit) + // Compare the two plans + comparePlans(expectedPlan, logPlan, false) + } + test("multiple stats - test average price and average age") { val context = new CatalystPlanContext val logPlan = diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala index 792a2dee6..106cba93a 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala +++ 
b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala @@ -59,6 +59,42 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, checkAnalysis = false) } + test("test simple rare command with a single field approximation") { + // if successful build ppl logical plan and translate to catalyst logical plan + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=accounts | rare_approx address"), context) + val addressField = UnresolvedAttribute("address") + val tableRelation = UnresolvedRelation(Seq("accounts")) + + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(addressField), isDistinct = false), + "count_address")(), + addressField) + + val aggregatePlan = + Aggregate(Seq(addressField), aggregateExpressions, tableRelation) + + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction( + Seq("APPROX_COUNT_DISTINCT"), + Seq(addressField), + isDistinct = false), + "count_address")(), + Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project(projectList, sortedPlan) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + test("test simple rare command with a by field test") { // if successful build ppl logical plan and translate to catalyst logical plan val context = new CatalystPlanContext