Skip to content

Commit

Permalink
update ppl CatalystPlan visitor to produce the logical plan as part o…
Browse files Browse the repository at this point in the history
…f the visitor instead of String

Signed-off-by: YANGDB <[email protected]>
  • Loading branch information
YANG-DB committed Sep 26, 2023
1 parent 284f082 commit 119fd5e
Show file tree
Hide file tree
Showing 8 changed files with 283 additions and 287 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,17 @@
package org.opensearch.sql.ppl;

import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.NamedExpression;
import org.apache.spark.sql.catalyst.expressions.SortOrder;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.catalyst.plans.logical.Union;
import scala.collection.Seq;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Stack;
import java.util.function.Function;
import java.util.stream.Collectors;

import static java.util.Collections.emptyList;
import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq;
import static scala.collection.JavaConverters.asScalaBuffer;

/**
Expand Down Expand Up @@ -49,14 +47,14 @@ public class CatalystPlanContext {
/**
* SortOrder sort by parameters
**/
private Seq<SortOrder> sortOrders = asScalaBuffer(Collections.emptyList());
private Seq<SortOrder> sortOrders = seq(emptyList());

public LogicalPlan getPlan() {
if (this.planBranches.size() == 1) {
return planBranches.peek();
}
//default unify sub-plans
return new Union(asScalaBuffer(this.planBranches).toSeq(), true, true);
return new Union(asScalaBuffer(this.planBranches), true, true);
}

public Stack<Expression> getNamedParseExpressions() {
Expand Down Expand Up @@ -100,9 +98,9 @@ public void sort(Seq<SortOrder> sortOrders) {
* @return
*/
public <T> Seq<T> retainAllNamedParseExpressions(Function<Expression, T> transformFunction) {
Seq<T> aggregateExpressions = asScalaBuffer(getNamedParseExpressions().stream()
.map(transformFunction::apply).collect(Collectors.toList())).toSeq();
getNamedParseExpressions().retainAll(Collections.emptyList());
Seq<T> aggregateExpressions = seq(getNamedParseExpressions().stream()
.map(transformFunction::apply).collect(Collectors.toList()));
getNamedParseExpressions().retainAll(emptyList());
return aggregateExpressions;
}

Expand All @@ -111,9 +109,9 @@ public <T> Seq<T> retainAllNamedParseExpressions(Function<Expression, T> transfo
* @return
*/
public <T> Seq<T> retainAllGroupingNamedParseExpressions(Function<Expression, T> transformFunction) {
Seq<T> aggregateExpressions = asScalaBuffer(getGroupingParseExpressions().stream()
.map(transformFunction::apply).collect(Collectors.toList())).toSeq();
getGroupingParseExpressions().retainAll(Collections.emptyList());
Seq<T> aggregateExpressions = seq(getGroupingParseExpressions().stream()
.map(transformFunction::apply).collect(Collectors.toList()));
getGroupingParseExpressions().retainAll(emptyList());
return aggregateExpressions;
}
}

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,9 @@
import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.opensearch.sql.expression.function.BuiltinFunctionName;
import org.opensearch.sql.ppl.CatalystPlanContext;

import static java.util.List.of;
import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq;
import static scala.Option.empty;
import static scala.collection.JavaConverters.asScalaBuffer;

/**
* aggregator expression builder building a catalyst aggregation function from PPL's aggregation logical step
Expand All @@ -21,27 +19,22 @@
*/
public interface AggregatorTranslator {

static Expression aggregator(org.opensearch.sql.ast.expression.AggregateFunction aggregateFunction, CatalystPlanContext context) {
static Expression aggregator(org.opensearch.sql.ast.expression.AggregateFunction aggregateFunction, Expression arg) {
if (BuiltinFunctionName.ofAggregation(aggregateFunction.getFuncName()).isEmpty())
throw new IllegalStateException("Unexpected value: " + aggregateFunction.getFuncName());

// Additional aggregation function operators will be added here
switch (BuiltinFunctionName.ofAggregation(aggregateFunction.getFuncName()).get()) {
case MAX:
return new UnresolvedFunction(asScalaBuffer(of("MAX")).toSeq(),
asScalaBuffer(of(context.getNamedParseExpressions().pop())).toSeq(),false, empty(),false);
return new UnresolvedFunction(seq("MAX"), seq(arg),false, empty(),false);
case MIN:
return new UnresolvedFunction(asScalaBuffer(of("MIN")).toSeq(),
asScalaBuffer(of(context.getNamedParseExpressions().pop())).toSeq(),false, empty(),false);
return new UnresolvedFunction(seq("MIN"), seq(arg),false, empty(),false);
case AVG:
return new UnresolvedFunction(asScalaBuffer(of("AVG")).toSeq(),
asScalaBuffer(of(context.getNamedParseExpressions().pop())).toSeq(),false, empty(),false);
return new UnresolvedFunction(seq("AVG"), seq(arg),false, empty(),false);
case COUNT:
return new UnresolvedFunction(asScalaBuffer(of("COUNT")).toSeq(),
asScalaBuffer(of(context.getNamedParseExpressions().pop())).toSeq(),false, empty(),false);
return new UnresolvedFunction(seq("COUNT"), seq(arg),false, empty(),false);
case SUM:
return new UnresolvedFunction(asScalaBuffer(of("SUM")).toSeq(),
asScalaBuffer(of(context.getNamedParseExpressions().pop())).toSeq(),false, empty(),false);
return new UnresolvedFunction(seq("SUM"), seq(arg),false, empty(),false);
}
throw new IllegalStateException("Not Supported value: " + aggregateFunction.getFuncName());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import org.apache.spark.sql.catalyst.expressions.Predicate;
import org.opensearch.sql.ast.expression.Compare;
import org.opensearch.sql.expression.function.BuiltinFunctionName;
import org.opensearch.sql.ppl.CatalystPlanContext;

/**
* Transform the PPL Logical comparator into catalyst comparator
Expand All @@ -26,16 +25,17 @@ public interface ComparatorTransformer {
*
* @return
*/
static Predicate comparator(Compare expression, CatalystPlanContext context) {
static Predicate comparator(Compare expression, Expression left, Expression right) {
if (BuiltinFunctionName.of(expression.getOperator()).isEmpty())
throw new IllegalStateException("Unexpected value: " + expression.getOperator());

if (context.getNamedParseExpressions().isEmpty()) {
throw new IllegalStateException("Unexpected value: No operands found in expression");
if (left == null) {
throw new IllegalStateException("Unexpected value: No Left operands found in expression");
}

Expression right = context.getNamedParseExpressions().pop();
Expression left = context.getNamedParseExpressions().isEmpty() ? null : context.getNamedParseExpressions().pop();
if (right == null) {
throw new IllegalStateException("Unexpected value: No Right operands found in expression");
}

// Additional function operators will be added here
switch (BuiltinFunctionName.of(expression.getOperator()).get()) {
Expand All @@ -54,4 +54,5 @@ static Predicate comparator(Compare expression, CatalystPlanContext context) {
}
throw new IllegalStateException("Not Supported value: " + expression.getOperator());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,9 @@ static SortOrder getSortDirection(Sort node, NamedExpression expression) {
.filter(f -> f.getField().toString().equals(expression.name()))
.findAny();

if(field.isPresent()) {
return sortOrder((Expression) expression, (Boolean) field.get().getFieldArgs().get(0).getValue().getValue());
}
return null;
return field.map(value -> sortOrder((Expression) expression,
(Boolean) value.getFieldArgs().get(0).getValue().getValue()))
.orElse(null);
}

@NotNull
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import org.scalatest.matchers.should.Matchers

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Divide, EqualTo, Floor, Literal, Multiply, SortOrder}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Divide, EqualTo, Floor, Literal, Multiply, SortOrder, TimeWindow}
import org.apache.spark.sql.catalyst.plans.logical._

class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
Expand All @@ -38,8 +38,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
val aggregatePlan = Aggregate(Seq(), aggregateExpressions, tableRelation)
val expectedPlan = Project(star, aggregatePlan)

assertEquals(compareByString(expectedPlan), compareByString(context.getPlan))
assertEquals(logPlan, "source=[table] | stats avg(price) | fields + *")
assertEquals(compareByString(expectedPlan), compareByString(logPlan))
}

ignore("test average price with Alias") {
Expand All @@ -58,8 +57,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
val aggregatePlan = Aggregate(Seq(), aggregateExpressions, tableRelation)
val expectedPlan = Project(star, aggregatePlan)

assertEquals(compareByString(expectedPlan), compareByString(context.getPlan))
assertEquals(logPlan, "source=[table] | stats avg(price) as avg_price | fields + *")
assertEquals(compareByString(expectedPlan), compareByString(logPlan))
}

test("test average price group by product ") {
Expand All @@ -83,8 +81,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), tableRelation)
val expectedPlan = Project(star, aggregatePlan)

assertEquals(logPlan, "source=[table] | stats avg(price) by product | fields + *")
assertEquals(compareByString(expectedPlan), compareByString(context.getPlan))
assertEquals(compareByString(expectedPlan), compareByString(logPlan))
}

test("test average price group by product and filter") {
Expand Down Expand Up @@ -112,10 +109,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan)
val expectedPlan = Project(star, aggregatePlan)

assertEquals(
logPlan,
"source=[table] | where country = 'USA' | stats avg(price) by product | fields + *")
assertEquals(compareByString(expectedPlan), compareByString(context.getPlan))
assertEquals(compareByString(expectedPlan), compareByString(logPlan))
}

test("test average price group by product and filter sorted") {
Expand Down Expand Up @@ -148,10 +142,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
val sortedPlan: LogicalPlan =
Sort(Seq(SortOrder(UnresolvedAttribute("product"), Ascending)), global = true, expectedPlan)

assertEquals(
logPlan,
"source=[table] | where country = 'USA' | stats avg(price) by product | sort product | fields + *")
assertEquals(compareByString(sortedPlan), compareByString(context.getPlan))
assertEquals(compareByString(sortedPlan), compareByString(logPlan))
}
test("create ppl simple avg age by span of interval of 10 years query test ") {
val context = new CatalystPlanContext
Expand All @@ -171,8 +162,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), tableRelation)
val expectedPlan = Project(star, aggregatePlan)

assertEquals(logPlan, "source=[table] | stats avg(age) | fields + *")
assert(compareByString(expectedPlan) === compareByString(context.getPlan))
assert(compareByString(expectedPlan) === compareByString(logPlan))
}

test("create ppl simple avg age by span of interval of 10 years query with sort test ") {
Expand All @@ -198,8 +188,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
val sortedPlan: LogicalPlan =
Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, expectedPlan)

assertEquals(logPlan, "source=[table] | stats avg(age) | sort age | fields + *")
assert(compareByString(sortedPlan) === compareByString(context.getPlan))
assert(compareByString(sortedPlan) === compareByString(logPlan))
}

test("create ppl simple avg age by span of interval of 10 years by country query test ") {
Expand Down Expand Up @@ -228,8 +217,85 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
tableRelation)
val expectedPlan = Project(star, aggregatePlan)

assertEquals(logPlan, "source=[table] | stats avg(age) by country | fields + *")
assert(compareByString(expectedPlan) === compareByString(context.getPlan))
assert(compareByString(expectedPlan) === compareByString(logPlan))
}

test("create ppl query count sales by weeks window and productId with sorting test") {
val context = new CatalystPlanContext
val logPlan = planTrnasformer.visit(
plan(
pplParser,
"source = table | stats sum(productsAmount) by span(transactionDate, 1w) as age_date | sort age_date",
false),
context)

// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val productsAmount = UnresolvedAttribute("productsAmount")
val table = UnresolvedRelation(Seq("table"))

val windowExpression = Alias(
TimeWindow(
UnresolvedAttribute("transactionDate"),
TimeWindow.parseExpression(Literal("1 week")),
TimeWindow.parseExpression(Literal("1 week")),
0),
"age_date")()

val aggregateExpressions =
Alias(
UnresolvedFunction(Seq("SUM"), Seq(productsAmount), isDistinct = false),
"sum(productsAmount)")()
val aggregatePlan = Aggregate(
Seq(windowExpression),
Seq(aggregateExpressions, windowExpression),
table)
val expectedPlan = Project(star, aggregatePlan)
val sortedPlan: LogicalPlan = Sort(
Seq(SortOrder(UnresolvedAttribute("age_date"), Ascending)),
global = true,
expectedPlan)
// Compare the two plans
assert(compareByString(sortedPlan) === compareByString(logPlan))
}

test("create ppl query count sales by days window and productId with sorting test") {
val context = new CatalystPlanContext
val logPlan = planTrnasformer.visit(
plan(
pplParser,
"source = table | stats sum(productsAmount) by span(transactionDate, 1d) as age_date, productId | sort age_date",
false),
context)
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val productsId = Alias(UnresolvedAttribute("productId"), "productId")()
val productsAmount = UnresolvedAttribute("productsAmount")
val table = UnresolvedRelation(Seq("table"))

val windowExpression = Alias(
TimeWindow(
UnresolvedAttribute("transactionDate"),
TimeWindow.parseExpression(Literal("1 day")),
TimeWindow.parseExpression(Literal("1 day")),
0),
"age_date")()

val aggregateExpressions =
Alias(
UnresolvedFunction(Seq("SUM"), Seq(productsAmount), isDistinct = false),
"sum(productsAmount)")()
val aggregatePlan = Aggregate(
Seq(productsId, windowExpression),
Seq(aggregateExpressions, productsId, windowExpression),
table)
val expectedPlan = Project(star, aggregatePlan)
val sortedPlan: LogicalPlan = Sort(
Seq(SortOrder(UnresolvedAttribute("age_date"), Ascending)),
global = true,
expectedPlan)
// Compare the two plans
assert(compareByString(sortedPlan) === compareByString(logPlan))
}

}
Loading

0 comments on commit 119fd5e

Please sign in to comment.