Skip to content

Commit

Permalink
add Max,Min,Count,Sum aggregation functions support
Browse files Browse the repository at this point in the history
Signed-off-by: YANGDB <[email protected]>
  • Loading branch information
YANG-DB committed Sep 12, 2023
1 parent d5f33b0 commit fe11134
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ class FlintSparkPPLITSuite
// Compare the two plans
assert(expectedPlan === logicalPlan)
}

test("create ppl simple age literal greater than filter AND country not equal filter query with two fields result test") {
val frame = sql(
s"""
Expand Down Expand Up @@ -184,7 +184,7 @@ class FlintSparkPPLITSuite
// Compare the two plans
assert(expectedPlan === logicalPlan)
}

test("create ppl simple age literal equal than filter OR country not equal filter query with two fields result test") {
val frame = sql(
s"""
Expand Down Expand Up @@ -438,6 +438,160 @@ class FlintSparkPPLITSuite
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("create ppl simple age max group by country query test ") {
val frame = sql(
s"""
| source = $testTable| stats max(age) by country
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(
Row(70, "USA"),
Row(25, "Canada"),
)

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val countryField = UnresolvedAttribute("country")
val ageField = UnresolvedAttribute("age")
val table = UnresolvedRelation(Seq("default", "flint_ppl_test"))

val groupByAttributes = Seq(Alias(countryField, "country")())
val aggregateExpressions = Alias(UnresolvedFunction(Seq("MAX"), Seq(ageField), isDistinct = false), "max(age)")()
val productAlias = Alias(countryField, "country")()

val aggregatePlan = Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table)
val expectedPlan = Project(star, aggregatePlan)

// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("create ppl simple age min group by country query test ") {
val frame = sql(
s"""
| source = $testTable| stats min(age) by country
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(
Row(30, "USA"),
Row(20, "Canada"),
)

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val countryField = UnresolvedAttribute("country")
val ageField = UnresolvedAttribute("age")
val table = UnresolvedRelation(Seq("default", "flint_ppl_test"))

val groupByAttributes = Seq(Alias(countryField, "country")())
val aggregateExpressions = Alias(UnresolvedFunction(Seq("MIN"), Seq(ageField), isDistinct = false), "min(age)")()
val productAlias = Alias(countryField, "country")()

val aggregatePlan = Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table)
val expectedPlan = Project(star, aggregatePlan)

// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("create ppl simple age sum group by country query test ") {
val frame = sql(
s"""
| source = $testTable| stats sum(age) by country
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(
Row(100L, "USA"),
Row(45L, "Canada"),
)

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](0))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val countryField = UnresolvedAttribute("country")
val ageField = UnresolvedAttribute("age")
val table = UnresolvedRelation(Seq("default", "flint_ppl_test"))

val groupByAttributes = Seq(Alias(countryField, "country")())
val aggregateExpressions = Alias(UnresolvedFunction(Seq("SUM"), Seq(ageField), isDistinct = false), "sum(age)")()
val productAlias = Alias(countryField, "country")()

val aggregatePlan = Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table)
val expectedPlan = Project(star, aggregatePlan)

// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("create ppl simple age count group by country query test ") {
val frame = sql(
s"""
| source = $testTable| stats count(age) by country
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(
Row(2L, "Canada"),
Row(2L, "USA"),
)

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](0))
assert(
results.sorted.sameElements(expectedResults.sorted),
s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}"
)

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val countryField = UnresolvedAttribute("country")
val ageField = UnresolvedAttribute("age")
val table = UnresolvedRelation(Seq("default", "flint_ppl_test"))

val groupByAttributes = Seq(Alias(countryField, "country")())
val aggregateExpressions = Alias(UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = false), "count(age)")()
val productAlias = Alias(countryField, "country")()

val aggregatePlan = Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table)
val expectedPlan = Project(star, aggregatePlan)

// Compare the two plans
assert(
compareByString(expectedPlan) === compareByString(logicalPlan),
s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}"
)
}

test("create ppl simple age avg group by country with state filter query test ") {
val frame = sql(
s"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,26 +23,20 @@ static Expression aggregator(org.opensearch.sql.ast.expression.AggregateFunction
// Additional aggregation function operators will be added here
switch (BuiltinFunctionName.ofAggregation(aggregateFunction.getFuncName()).get()) {
case MAX:
break;
return new UnresolvedFunction(asScalaBuffer(of("MAX")).toSeq(),
asScalaBuffer(of(context.getNamedParseExpressions().pop())).toSeq(),false, empty(),false);
case MIN:
break;
return new UnresolvedFunction(asScalaBuffer(of("MIN")).toSeq(),
asScalaBuffer(of(context.getNamedParseExpressions().pop())).toSeq(),false, empty(),false);
case AVG:
return new UnresolvedFunction(asScalaBuffer(of("AVG")).toSeq(),
asScalaBuffer(of(context.getNamedParseExpressions().pop())).toSeq(),false, empty(),false);
case COUNT:
break;
return new UnresolvedFunction(asScalaBuffer(of("COUNT")).toSeq(),
asScalaBuffer(of(context.getNamedParseExpressions().pop())).toSeq(),false, empty(),false);
case SUM:
break;
case STDDEV_POP:
break;
case STDDEV_SAMP:
break;
case TAKE:
break;
case VARPOP:
break;
case VARSAMP:
break;
return new UnresolvedFunction(asScalaBuffer(of("SUM")).toSeq(),
asScalaBuffer(of(context.getNamedParseExpressions().pop())).toSeq(),false, empty(),false);
}
throw new IllegalStateException("Not Supported value: " + aggregateFunction.getFuncName());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
assertEquals(logPlan, "source=[table] | stats avg(price) by product | fields + *")
assertEquals(compareByString(expectedPlan), compareByString(context.getPlan))
}

test("test average price group by product and filter") {
// if successful build ppl logical plan and translate to catalyst logical plan
val context = new CatalystPlanContext
Expand Down

0 comments on commit fe11134

Please sign in to comment.