From c2882222ddef94217f6eb90533d6d2a03a02cf91 Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Tue, 29 Oct 2024 08:49:46 +0800 Subject: [PATCH] All-fields as an arg of aggregator count() can be resolved after other fields (#814) Signed-off-by: Lantao Jin --- .../FlintSparkPPLAggregationsITSuite.scala | 162 ++++++++++++++++++ .../sql/ppl/CatalystQueryPlanVisitor.java | 6 +- ...ggregationQueriesTranslatorTestSuite.scala | 54 ++++++ 3 files changed, 217 insertions(+), 5 deletions(-) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala index 55d3d0709..bcfe22764 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala @@ -1122,4 +1122,166 @@ class FlintSparkPPLAggregationsITSuite comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } + + test("test count() at the first of stats clause") { + val frame = sql(s""" + | source = $testTable | eval a = 1 | stats count() as cnt, sum(a) as sum, avg(a) as avg + | """.stripMargin) + assertSameRows(Seq(Row(4, 4, 1.0)), frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val eval = Project(Seq(UnresolvedStar(None), Alias(Literal(1), "a")()), table) + val sum = + Alias( + UnresolvedFunction(Seq("SUM"), Seq(UnresolvedAttribute("a")), isDistinct = false), + "sum")() + val avg = + Alias( + UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("a")), isDistinct = false), + "avg")() + val count = + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false), + "cnt")() + val aggregate = Aggregate(Seq.empty, Seq(count, sum, avg), eval) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test count() in the middle of stats clause") { + val frame = sql(s""" + | source = $testTable | eval a = 1 | stats sum(a) as sum, count() as cnt, avg(a) as avg + | """.stripMargin) + assertSameRows(Seq(Row(4, 4, 1.0)), frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val eval = Project(Seq(UnresolvedStar(None), Alias(Literal(1), "a")()), table) + val sum = + Alias( + UnresolvedFunction(Seq("SUM"), Seq(UnresolvedAttribute("a")), isDistinct = false), + "sum")() + val avg = + Alias( + UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("a")), isDistinct = false), + "avg")() + val count = + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false), + "cnt")() + val aggregate = Aggregate(Seq.empty, Seq(sum, count, avg), eval) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test count() at the end of stats clause") { + val frame = sql(s""" + | source = $testTable | eval a = 1 | stats sum(a) as sum, avg(a) as avg, count() as cnt + | """.stripMargin) + assertSameRows(Seq(Row(4, 1.0, 4)), frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val eval = Project(Seq(UnresolvedStar(None), Alias(Literal(1), "a")()), table) + val sum = + Alias( + UnresolvedFunction(Seq("SUM"), Seq(UnresolvedAttribute("a")), isDistinct = false), + "sum")() + val avg = + Alias( + UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("a")), isDistinct = false), + "avg")() + val count = + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false), + "cnt")() + val aggregate = Aggregate(Seq.empty, Seq(sum, avg, count), eval) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test count() at the first of stats by clause") { + val frame = sql(s""" + | source = $testTable | eval a = 1 | stats count() as cnt, sum(a) as sum, avg(a) as avg by country + | """.stripMargin) + assertSameRows(Seq(Row(2, 2, 1.0, "Canada"), Row(2, 2, 1.0, "USA")), frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val eval = Project(Seq(UnresolvedStar(None), Alias(Literal(1), "a")()), table) + val sum = + Alias( + UnresolvedFunction(Seq("SUM"), Seq(UnresolvedAttribute("a")), isDistinct = false), + "sum")() + val avg = + Alias( + UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("a")), isDistinct = false), + "avg")() + val count = + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false), + "cnt")() + val grouping = + Alias(UnresolvedAttribute("country"), "country")() + val aggregate = Aggregate(Seq(grouping), Seq(count, sum, avg, grouping), eval) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test count() in the middle of stats by clause") { + val frame = sql(s""" + | source = $testTable | eval a = 1 | stats sum(a) as sum, count() as cnt, avg(a) as avg by country + | """.stripMargin) + assertSameRows(Seq(Row(2, 2, 1.0, "Canada"), Row(2, 2, 1.0, "USA")), frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val eval = Project(Seq(UnresolvedStar(None), Alias(Literal(1), "a")()), table) + val sum = + Alias( + UnresolvedFunction(Seq("SUM"), Seq(UnresolvedAttribute("a")), isDistinct = false), + "sum")() + val avg = + Alias( + UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("a")), isDistinct = false), + "avg")() + val count = + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false), + "cnt")() + val grouping = + Alias(UnresolvedAttribute("country"), "country")() + val aggregate = Aggregate(Seq(grouping), Seq(sum, count, avg, grouping), eval) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test count() at the end of stats by clause") { + val frame = sql(s""" + | source = $testTable | eval a = 1 | stats sum(a) as sum, avg(a) as avg, count() as cnt by country + | """.stripMargin) + assertSameRows(Seq(Row(2, 1.0, 2, "Canada"), Row(2, 1.0, 2, "USA")), frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val eval = Project(Seq(UnresolvedStar(None), Alias(Literal(1), "a")()), table) + val sum = + Alias( + UnresolvedFunction(Seq("SUM"), Seq(UnresolvedAttribute("a")), isDistinct = false), + "sum")() + val avg = + Alias( + UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("a")), isDistinct = false), + "avg")() + val count = + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false), + "cnt")() + val grouping = + Alias(UnresolvedAttribute("country"), "country")() + val aggregate = Aggregate(Seq(grouping), Seq(sum, avg, count, grouping), eval) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index ef806d3ef..8bfc01c03 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -672,11 +672,7 @@ public Expression visitCorrelationMapping(FieldsMapping node, CatalystPlanContex @Override public Expression visitAllFields(AllFields node, CatalystPlanContext context) { - // Case of aggregation step - no start projection can be added - if (context.getNamedParseExpressions().isEmpty()) { - // Create an UnresolvedStar for all-fields projection - context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); - } + context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); return context.getNamedParseExpressions().peek(); } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala index 03d7f0ab0..9946bff6a 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala @@ -959,4 +959,58 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + + test("test count() as the last aggregator in stats clause") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table | eval a = 1 | stats sum(a) as sum, avg(a) as avg, count() as cnt"), + context) + val tableRelation = UnresolvedRelation(Seq("table")) + val eval = Project(Seq(UnresolvedStar(None), Alias(Literal(1), "a")()), tableRelation) + val sum = + Alias( + UnresolvedFunction(Seq("SUM"), Seq(UnresolvedAttribute("a")), isDistinct = false), + "sum")() + val avg = + Alias( + UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("a")), isDistinct = false), + "avg")() + val count = + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false), + "cnt")() + val aggregate = Aggregate(Seq.empty, Seq(sum, avg, count), eval) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test count() as the last aggregator in stats by clause") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table | eval a = 1 | stats sum(a) as sum, avg(a) as avg, count() as cnt by country"), + context) + val tableRelation = UnresolvedRelation(Seq("table")) + val eval = Project(Seq(UnresolvedStar(None), Alias(Literal(1), "a")()), tableRelation) + val sum = + Alias( + UnresolvedFunction(Seq("SUM"), Seq(UnresolvedAttribute("a")), isDistinct = false), + "sum")() + val avg = + Alias( + UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("a")), isDistinct = false), + "avg")() + val count = + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false), + "cnt")() + val grouping = + Alias(UnresolvedAttribute("country"), "country")() + val aggregate = Aggregate(Seq(grouping), Seq(sum, avg, count, grouping), eval) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } }