Skip to content

Commit

Permalink
add AggregatorTranslator support
Browse files Browse the repository at this point in the history
Signed-off-by: YANGDB <[email protected]>
  • Loading branch information
YANG-DB committed Sep 11, 2023
1 parent 89dd114 commit 65f4372
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,40 +5,42 @@

package org.opensearch.flint.spark

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{EqualTo, GreaterThan, LessThanOrEqual, Literal, Not}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, GreaterThan, LessThanOrEqual, Literal, Not}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
import org.apache.spark.sql.streaming.StreamTest
import org.apache.spark.sql.{QueryTest, Row}

class FlintSparkPPLITSuite
extends QueryTest
extends QueryTest
with LogicalPlanTestUtils
with FlintPPLSuite
with StreamTest {

/** Test table and index name */
private val testTable = "default.flint_ppl_tst"
private val testTable = "default.flint_ppl_test"

override def beforeAll(): Unit = {
super.beforeAll()

// Create test table
sql(s"""
| CREATE TABLE $testTable
| (
| name STRING,
| age INT
| )
| USING CSV
| OPTIONS (
| header 'false',
| delimiter '\t'
| )
| PARTITIONED BY (
| year INT,
| month INT
| )
|""".stripMargin)
sql(
s"""
| CREATE TABLE $testTable
| (
| name STRING,
| age INT
| )
| USING CSV
| OPTIONS (
| header 'false',
| delimiter '\t'
| )
| PARTITIONED BY (
| year INT,
| month INT
| )
|""".stripMargin)

// Insert data
sql(
Expand All @@ -60,7 +62,7 @@ class FlintSparkPPLITSuite
job.awaitTermination()
}
}

test("create ppl simple query with start fields result test") {
val frame = sql(
s"""
Expand All @@ -82,15 +84,15 @@ class FlintSparkPPLITSuite
// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val expectedPlan: LogicalPlan = Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("default","flint_ppl_tst")))
val expectedPlan: LogicalPlan = Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("default", "flint_ppl_test")))
// Compare the two plans
assert(expectedPlan === logicalPlan)
}

test("create ppl simple query two with fields result test") {
val frame = sql(
s"""
| source = $testTable | fields name, age
| source = $testTable| fields name, age
| """.stripMargin)

// Retrieve the results
Expand All @@ -108,12 +110,12 @@ class FlintSparkPPLITSuite
// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val expectedPlan: LogicalPlan = Project(Seq(UnresolvedAttribute("name"),UnresolvedAttribute("age")),
UnresolvedRelation(Seq("default","flint_ppl_tst")))
val expectedPlan: LogicalPlan = Project(Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")),
UnresolvedRelation(Seq("default", "flint_ppl_test")))
// Compare the two plans
assert(expectedPlan === logicalPlan)
}

test("create ppl simple age literal equal filter query with two fields result test") {
val frame = sql(
s"""
Expand All @@ -134,15 +136,15 @@ class FlintSparkPPLITSuite
// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val table = UnresolvedRelation(Seq("default","flint_ppl_tst"))
val table = UnresolvedRelation(Seq("default", "flint_ppl_test"))
val filterExpr = EqualTo(UnresolvedAttribute("age"), Literal(25))
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedAttribute("name"),UnresolvedAttribute("age"))
val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age"))
val expectedPlan = Project(projectList, filterPlan)
// Compare the two plans
assert(expectedPlan === logicalPlan)
}

test("create ppl simple age literal greater than filter query with two fields result test") {
val frame = sql(
s"""
Expand All @@ -162,15 +164,15 @@ class FlintSparkPPLITSuite
// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val table = UnresolvedRelation(Seq("default","flint_ppl_tst"))
val table = UnresolvedRelation(Seq("default", "flint_ppl_test"))
val filterExpr = GreaterThan(UnresolvedAttribute("age"), Literal(25))
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedAttribute("name"),UnresolvedAttribute("age"))
val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age"))
val expectedPlan = Project(projectList, filterPlan)
// Compare the two plans
assert(expectedPlan === logicalPlan)
}
}

test("create ppl simple age literal smaller than equals filter query with two fields result test") {
val frame = sql(
s"""
Expand All @@ -191,15 +193,15 @@ class FlintSparkPPLITSuite
// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val table = UnresolvedRelation(Seq("default","flint_ppl_tst"))
val table = UnresolvedRelation(Seq("default", "flint_ppl_test"))
val filterExpr = LessThanOrEqual(UnresolvedAttribute("age"), Literal(65))
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedAttribute("name"),UnresolvedAttribute("age"))
val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age"))
val expectedPlan = Project(projectList, filterPlan)
// Compare the two plans
assert(expectedPlan === logicalPlan)
}

test("create ppl simple name literal equal filter query with two fields result test") {
val frame = sql(
s"""
Expand All @@ -218,15 +220,15 @@ class FlintSparkPPLITSuite
// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val table = UnresolvedRelation(Seq("default","flint_ppl_tst"))
val table = UnresolvedRelation(Seq("default", "flint_ppl_test"))
val filterExpr = EqualTo(UnresolvedAttribute("name"), Literal("Jake"))
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedAttribute("name"),UnresolvedAttribute("age"))
val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age"))
val expectedPlan = Project(projectList, filterPlan)
// Compare the two plans
assert(expectedPlan === logicalPlan)
}
}

test("create ppl simple name literal not equal filter query with two fields result test") {
val frame = sql(
s"""
Expand All @@ -247,12 +249,73 @@ class FlintSparkPPLITSuite
// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val table = UnresolvedRelation(Seq("default","flint_ppl_tst"))
val table = UnresolvedRelation(Seq("default", "flint_ppl_test"))
val filterExpr = Not(EqualTo(UnresolvedAttribute("name"), Literal("Jake")))
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedAttribute("name"),UnresolvedAttribute("age"))
val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age"))
val expectedPlan = Project(projectList, filterPlan)
// Compare the two plans
assert(expectedPlan === logicalPlan)
}

test("create ppl simple age avg query test") {
val frame = sql(
s"""
| source = $testTable| stats avg(age)
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(
Row(37.5),
)

// Compare the results
assert(results === expectedResults)

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val priceField = UnresolvedAttribute("age")
val table = UnresolvedRelation(Seq("default", "flint_ppl_test"))
val aggregateExpressions = Seq(Alias(UnresolvedFunction(Seq("AVG"), Seq(priceField), isDistinct = false), "avg(age)")())
val aggregatePlan = Project(aggregateExpressions, table)

// Compare the two plans
assert(compareByString(aggregatePlan) === compareByString(logicalPlan))
}

ignore("create ppl simple age avg group by query test ") {
val checkData = sql(s"SELECT name, AVG(age) AS avg_age FROM $testTable group by name");
checkData.show()
checkData.queryExecution.logical.show()

val frame = sql(
s"""
| source = $testTable| stats avg(age) by name
| """.stripMargin)


// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(
Row(37.5),
)

// Compare the results
assert(results === expectedResults)

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val priceField = UnresolvedAttribute("price")
val table = UnresolvedRelation(Seq("default", "flint_ppl_test"))
val aggregateExpressions = Seq(Alias(UnresolvedFunction(Seq("AVG"), Seq(priceField), isDistinct = false), "avg(price)")())
val aggregatePlan = Project( aggregateExpressions, table)

// Compare the two plans
assert(aggregatePlan === logicalPlan)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package org.opensearch.flint.spark

import org.apache.spark.sql.catalyst.expressions.{Alias, ExprId}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}

/**
* general utility functions for ppl to spark transformation test
*/
trait LogicalPlanTestUtils {
/**
* utility function to compare two logical plans while ignoring the auto-generated expressionId associated with the alias
* which is used for projection or aggregation
* @param plan
* @return
*/
def compareByString(plan: LogicalPlan): String = {
// Create a rule to replace Alias's ExprId with a dummy id
val rule: PartialFunction[LogicalPlan, LogicalPlan] = {
case p: Project =>
val newProjections = p.projectList.map {
case alias: Alias => Alias(alias.child, alias.name)(exprId = ExprId(0), qualifier = alias.qualifier)
case other => other
}
p.copy(projectList = newProjections)

case agg: Aggregate =>
val newGrouping = agg.groupingExpressions.map {
case alias: Alias => Alias(alias.child, alias.name)(exprId = ExprId(0), qualifier = alias.qualifier)
case other => other
}
val newAggregations = agg.aggregateExpressions.map {
case alias: Alias => Alias(alias.child, alias.name)(exprId = ExprId(0), qualifier = alias.qualifier)
case other => other
}
agg.copy(groupingExpressions = newGrouping, aggregateExpressions = newAggregations)

case other => other
}

// Apply the rule using transform
val transformedPlan = plan.transform(rule)

// Return the string representation of the transformed plan
transformedPlan.toString
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,12 @@ public String visitAggregation(Aggregation node, CatalystPlanContext context) {
final String visitExpressionList = visitExpressionList(node.getAggExprList(), context);
final String group = visitExpressionList(node.getGroupExprList(), context);

Seq<Expression> groupBy = isNullOrEmpty(group) ? asScalaBuffer(emptyList()) : asScalaBuffer(singletonList(context.getNamedParseExpressions().pop())).toSeq();
context.plan(p->new Aggregate(groupBy,asScalaBuffer(singletonList((NamedExpression) context.getNamedParseExpressions().pop())).toSeq(),p));
NamedExpression namedExpression = (NamedExpression) context.getNamedParseExpressions().peek();
Seq<NamedExpression> namedExpressionSeq = asScalaBuffer(singletonList(namedExpression)).toSeq();

if(!isNullOrEmpty(group)) {
context.plan(p->new Aggregate(asScalaBuffer(singletonList((Expression) namedExpression)),namedExpressionSeq,p));
}
return format(
"%s | stats %s",
child, String.join(" ", visitExpressionList, groupBy(group)).trim());
Expand Down Expand Up @@ -311,8 +315,8 @@ public String visitNot(Not node, CatalystPlanContext context) {
@Override
public String visitAggregateFunction(AggregateFunction node, CatalystPlanContext context) {
String arg = node.getField().accept(this, context);
org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction aggregator = AggregatorTranslator.aggregator(node, context);
context.getNamedParseExpressions().add((org.apache.spark.sql.catalyst.expressions.Expression) aggregator);
org.apache.spark.sql.catalyst.expressions.Expression aggregator = AggregatorTranslator.aggregator(node, context);
context.getNamedParseExpressions().add(aggregator);
return format("%s(%s)", node.getFuncName(), arg);
}

Expand Down Expand Up @@ -342,16 +346,23 @@ public String visitField(Field node, CatalystPlanContext context) {

@Override
public String visitAllFields(AllFields node, CatalystPlanContext context) {
// Create an UnresolvedStar for all-fields projection
context.getNamedParseExpressions().add(UnresolvedStar$.MODULE$.apply(Option.<Seq<String>>empty()));
return "*";
// Case of aggregation step - no start projection can be added
if(!context.getNamedParseExpressions().isEmpty()) {
// if named expression exist - just return their names
return context.getNamedParseExpressions().peek().toString();
} else {
// Create an UnresolvedStar for all-fields projection
context.getNamedParseExpressions().add(UnresolvedStar$.MODULE$.apply(Option.<Seq<String>>empty()));
return "*";
}
}

@Override
public String visitAlias(Alias node, CatalystPlanContext context) {
String expr = node.getDelegated().accept(this, context);
Expression expression = (Expression) context.getNamedParseExpressions().pop();
context.getNamedParseExpressions().add(
org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply((Expression) context.getNamedParseExpressions().pop(),
org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply((Expression) expression,
expr,
NamedExpression.newExprId(),
asScalaBufferConverter(new java.util.ArrayList<String>()).asScala().seq(),
Expand Down
Loading

0 comments on commit 65f4372

Please sign in to comment.