From fe1113496c309d40b05eee7e84af878f2d04d8d7 Mon Sep 17 00:00:00 2001
From: YANGDB
Date: Tue, 12 Sep 2023 15:53:54 -0700
Subject: [PATCH] add Max,Min,Count,Sum aggregation functions support

Signed-off-by: YANGDB
---
Review notes (text in this area is ignored when the patch is applied):

* AggregatorTranslator now maps MAX, MIN, COUNT and SUM to an
  UnresolvedFunction the same way AVG already was; the STDDEV_POP,
  STDDEV_SAMP, TAKE, VARPOP and VARSAMP cases are removed and therefore
  reach the trailing IllegalStateException.
* The "count(age) by country" integration test orders rows by country
  (column 1) rather than by the aggregate (column 0): both expected rows
  have count 2, so ordering by the count is a tie and leaves the
  comparison dependent on the order in which the groups are returned.
* Line breaks and indentation were reconstructed from a whitespace-collapsed
  copy; verify leading whitespace against the target tree before applying.

 .../flint/spark/FlintSparkPPLITSuite.scala    | 158 +++++++++++++++++-
 .../sql/ppl/utils/AggregatorTranslator.java   |  22 +--
 ...ggregationQueriesTranslatorTestSuite.scala |   1 +
 3 files changed, 165 insertions(+), 16 deletions(-)

diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala
index 820ca3f16..49fc2879f 100644
--- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala
+++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala
@@ -153,7 +153,7 @@ class FlintSparkPPLITSuite
     // Compare the two plans
     assert(expectedPlan === logicalPlan)
   }
-  
+
   test("create ppl simple age literal greater than filter AND country not equal filter query with two fields result test") {
     val frame = sql(
       s"""
@@ -184,7 +184,7 @@ class FlintSparkPPLITSuite
     // Compare the two plans
     assert(expectedPlan === logicalPlan)
   }
-  
+
   test("create ppl simple age literal equal than filter OR country not equal filter query with two fields result test") {
     val frame = sql(
       s"""
@@ -438,6 +438,160 @@ class FlintSparkPPLITSuite
     assert(compareByString(expectedPlan) === compareByString(logicalPlan))
   }
 
+  test("create ppl simple age max group by country query test ") {
+    val frame = sql(
+      s"""
+         | source = $testTable| stats max(age) by country
+         | """.stripMargin)
+
+    // Retrieve the results
+    val results: Array[Row] = frame.collect()
+    // Define the expected results
+    val expectedResults: Array[Row] = Array(
+      Row(70, "USA"),
+      Row(25, "Canada"),
+    )
+
+    // Compare the results
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0))
+    assert(results.sorted.sameElements(expectedResults.sorted))
+
+    // Retrieve the logical plan
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    // Define the expected logical plan
+    val star = Seq(UnresolvedStar(None))
+    val countryField = UnresolvedAttribute("country")
+    val ageField = UnresolvedAttribute("age")
+    val table = UnresolvedRelation(Seq("default", "flint_ppl_test"))
+
+    val groupByAttributes = Seq(Alias(countryField, "country")())
+    val aggregateExpressions = Alias(UnresolvedFunction(Seq("MAX"), Seq(ageField), isDistinct = false), "max(age)")()
+    val productAlias = Alias(countryField, "country")()
+
+    val aggregatePlan = Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table)
+    val expectedPlan = Project(star, aggregatePlan)
+
+    // Compare the two plans
+    assert(compareByString(expectedPlan) === compareByString(logicalPlan))
+  }
+
+  test("create ppl simple age min group by country query test ") {
+    val frame = sql(
+      s"""
+         | source = $testTable| stats min(age) by country
+         | """.stripMargin)
+
+    // Retrieve the results
+    val results: Array[Row] = frame.collect()
+    // Define the expected results
+    val expectedResults: Array[Row] = Array(
+      Row(30, "USA"),
+      Row(20, "Canada"),
+    )
+
+    // Compare the results
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0))
+    assert(results.sorted.sameElements(expectedResults.sorted))
+
+    // Retrieve the logical plan
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    // Define the expected logical plan
+    val star = Seq(UnresolvedStar(None))
+    val countryField = UnresolvedAttribute("country")
+    val ageField = UnresolvedAttribute("age")
+    val table = UnresolvedRelation(Seq("default", "flint_ppl_test"))
+
+    val groupByAttributes = Seq(Alias(countryField, "country")())
+    val aggregateExpressions = Alias(UnresolvedFunction(Seq("MIN"), Seq(ageField), isDistinct = false), "min(age)")()
+    val productAlias = Alias(countryField, "country")()
+
+    val aggregatePlan = Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table)
+    val expectedPlan = Project(star, aggregatePlan)
+
+    // Compare the two plans
+    assert(compareByString(expectedPlan) === compareByString(logicalPlan))
+  }
+
+  test("create ppl simple age sum group by country query test ") {
+    val frame = sql(
+      s"""
+         | source = $testTable| stats sum(age) by country
+         | """.stripMargin)
+
+    // Retrieve the results
+    val results: Array[Row] = frame.collect()
+    // Define the expected results
+    val expectedResults: Array[Row] = Array(
+      Row(100L, "USA"),
+      Row(45L, "Canada"),
+    )
+
+    // Compare the results
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](0))
+    assert(results.sorted.sameElements(expectedResults.sorted))
+
+    // Retrieve the logical plan
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    // Define the expected logical plan
+    val star = Seq(UnresolvedStar(None))
+    val countryField = UnresolvedAttribute("country")
+    val ageField = UnresolvedAttribute("age")
+    val table = UnresolvedRelation(Seq("default", "flint_ppl_test"))
+
+    val groupByAttributes = Seq(Alias(countryField, "country")())
+    val aggregateExpressions = Alias(UnresolvedFunction(Seq("SUM"), Seq(ageField), isDistinct = false), "sum(age)")()
+    val productAlias = Alias(countryField, "country")()
+
+    val aggregatePlan = Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table)
+    val expectedPlan = Project(star, aggregatePlan)
+
+    // Compare the two plans
+    assert(compareByString(expectedPlan) === compareByString(logicalPlan))
+  }
+
+  test("create ppl simple age count group by country query test ") {
+    val frame = sql(
+      s"""
+         | source = $testTable| stats count(age) by country
+         | """.stripMargin)
+
+    // Retrieve the results
+    val results: Array[Row] = frame.collect()
+    // Define the expected results
+    val expectedResults: Array[Row] = Array(
+      Row(2L, "Canada"),
+      Row(2L, "USA"),
+    )
+
+    // Compare the results; both groups have count 2, so order by country (column 1) rather than the tied count
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1))
+    assert(
+      results.sorted.sameElements(expectedResults.sorted),
+      s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}"
+    )
+
+    // Retrieve the logical plan
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    // Define the expected logical plan
+    val star = Seq(UnresolvedStar(None))
+    val countryField = UnresolvedAttribute("country")
+    val ageField = UnresolvedAttribute("age")
+    val table = UnresolvedRelation(Seq("default", "flint_ppl_test"))
+
+    val groupByAttributes = Seq(Alias(countryField, "country")())
+    val aggregateExpressions = Alias(UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = false), "count(age)")()
+    val productAlias = Alias(countryField, "country")()
+
+    val aggregatePlan = Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table)
+    val expectedPlan = Project(star, aggregatePlan)
+
+    // Compare the two plans
+    assert(
+      compareByString(expectedPlan) === compareByString(logicalPlan),
+      s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}"
+    )
+  }
+
   test("create ppl simple age avg group by country with state filter query test ") {
     val frame = sql(
       s"""
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java
index 25daa5590..7dcebe8dc 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java
@@ -23,26 +23,20 @@ static Expression aggregator(org.opensearch.sql.ast.expression.AggregateFunction
         // Additional aggregation function operators will be added here
         switch (BuiltinFunctionName.ofAggregation(aggregateFunction.getFuncName()).get()) {
             case MAX:
-                break;
+                return new UnresolvedFunction(asScalaBuffer(of("MAX")).toSeq(),
+                        asScalaBuffer(of(context.getNamedParseExpressions().pop())).toSeq(),false, empty(),false);
             case MIN:
-                break;
+                return new UnresolvedFunction(asScalaBuffer(of("MIN")).toSeq(),
+                        asScalaBuffer(of(context.getNamedParseExpressions().pop())).toSeq(),false, empty(),false);
             case AVG:
                 return new UnresolvedFunction(asScalaBuffer(of("AVG")).toSeq(),
                         asScalaBuffer(of(context.getNamedParseExpressions().pop())).toSeq(),false, empty(),false);
             case COUNT:
-                break;
+                return new UnresolvedFunction(asScalaBuffer(of("COUNT")).toSeq(),
+                        asScalaBuffer(of(context.getNamedParseExpressions().pop())).toSeq(),false, empty(),false);
             case SUM:
-                break;
-            case STDDEV_POP:
-                break;
-            case STDDEV_SAMP:
-                break;
-            case TAKE:
-                break;
-            case VARPOP:
-                break;
-            case VARSAMP:
-                break;
+                return new UnresolvedFunction(asScalaBuffer(of("SUM")).toSeq(),
+                        asScalaBuffer(of(context.getNamedParseExpressions().pop())).toSeq(),false, empty(),false);
         }
         throw new IllegalStateException("Not Supported value: " + aggregateFunction.getFuncName());
     }
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala
index 38984f516..000f77afc 100644
--- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala
@@ -57,6 +57,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
     assertEquals(logPlan, "source=[table] | stats avg(price) by product | fields + *")
     assertEquals(compareByString(expectedPlan), compareByString(context.getPlan))
   }
+
   test("test average price group by product and filter") {
     // if successful build ppl logical plan and translate to catalyst logical plan
     val context = new CatalystPlanContext