diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java
index 321c08e5b..7aab2c8f7 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java
@@ -6,19 +6,17 @@
 package org.opensearch.sql.ppl;
 
 import org.apache.spark.sql.catalyst.expressions.Expression;
-import org.apache.spark.sql.catalyst.expressions.NamedExpression;
 import org.apache.spark.sql.catalyst.expressions.SortOrder;
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
 import org.apache.spark.sql.catalyst.plans.logical.Union;
 import scala.collection.Seq;
 
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
 import java.util.Stack;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 
+import static java.util.Collections.emptyList;
+import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq;
 import static scala.collection.JavaConverters.asScalaBuffer;
 
 /**
@@ -49,14 +47,14 @@ public class CatalystPlanContext {
     /**
      * SortOrder sort by parameters
     **/
-    private Seq<SortOrder> sortOrders = asScalaBuffer(Collections.emptyList());
+    private Seq<SortOrder> sortOrders = seq(emptyList());
 
     public LogicalPlan getPlan() {
         if (this.planBranches.size() == 1) {
             return planBranches.peek();
         }
         //default unify sub-plans
-        return new Union(asScalaBuffer(this.planBranches).toSeq(), true, true);
+        return new Union(asScalaBuffer(this.planBranches), true, true);
     }
 
     public Stack<Expression> getNamedParseExpressions() {
@@ -100,9 +98,9 @@ public void sort(Seq<SortOrder> sortOrders) {
      * @return
      */
    public <T> Seq<T> retainAllNamedParseExpressions(Function<Expression, T> transformFunction) {
-        Seq<T> aggregateExpressions = asScalaBuffer(getNamedParseExpressions().stream()
-                .map(transformFunction::apply).collect(Collectors.toList())).toSeq();
-        getNamedParseExpressions().retainAll(Collections.emptyList());
+        Seq<T> aggregateExpressions = seq(getNamedParseExpressions().stream()
+                .map(transformFunction::apply).collect(Collectors.toList()));
+        getNamedParseExpressions().retainAll(emptyList());
         return aggregateExpressions;
     }
 
@@ -111,9 +109,9 @@ public <T> Seq<T> retainAllNamedParseExpressions(Function<Expression, T> transformFunction) {
      * @return
      */
    public <T> Seq<T> retainAllGroupingNamedParseExpressions(Function<Expression, T> transformFunction) {
-        Seq<T> aggregateExpressions = asScalaBuffer(getGroupingParseExpressions().stream()
-                .map(transformFunction::apply).collect(Collectors.toList())).toSeq();
-        getGroupingParseExpressions().retainAll(Collections.emptyList());
+        Seq<T> aggregateExpressions = seq(getGroupingParseExpressions().stream()
+                .map(transformFunction::apply).collect(Collectors.toList()));
+        getGroupingParseExpressions().retainAll(emptyList());
         return aggregateExpressions;
     }
 }
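
Both context helpers above now route through `DataTypeTransformer.seq` instead of spelling out `asScalaBuffer(...).toSeq()` at each call site. The helper itself is outside this change set; a minimal sketch of the shape its call sites assume (`seq(emptyList())` here, `seq("MAX")` and `seq(arg)` further down) would be:

```java
// Hypothetical sketch of DataTypeTransformer.seq as assumed by the call sites in
// this diff; the real implementation is not part of the change set.
import java.util.Arrays;
import java.util.List;

import scala.collection.JavaConverters;
import scala.collection.Seq;

public interface DataTypeTransformerSketch {
    // list form, e.g. seq(emptyList()) or seq(of(t.split("\\.")))
    static <T> Seq<T> seq(List<T> list) {
        return JavaConverters.asScalaBuffer(list).toSeq();
    }

    // varargs form, e.g. seq("MAX") or seq(arg)
    @SafeVarargs
    static <T> Seq<T> seq(T... elements) {
        return seq(Arrays.asList(elements));
    }
}
```
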
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
index 71e41cc16..210bfeeae 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
@@ -5,9 +5,6 @@
 
 package org.opensearch.sql.ppl;
 
-import com.google.common.collect.ImmutableList;
-import org.apache.commons.lang3.tuple.ImmutablePair;
-import org.apache.commons.lang3.tuple.Pair;
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$;
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
 import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$;
@@ -31,7 +28,6 @@
 import org.opensearch.sql.ast.expression.Function;
 import org.opensearch.sql.ast.expression.In;
 import org.opensearch.sql.ast.expression.Interval;
-import org.opensearch.sql.ast.expression.Let;
 import org.opensearch.sql.ast.expression.Literal;
 import org.opensearch.sql.ast.expression.Not;
 import org.opensearch.sql.ast.expression.Or;
@@ -62,19 +58,16 @@
 import java.util.Objects;
 import java.util.stream.Collectors;
 
-import static com.google.common.base.Strings.isNullOrEmpty;
-import static java.lang.String.format;
-import static java.util.Collections.singletonList;
+import static java.util.Collections.emptyList;
 import static java.util.List.of;
 import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq;
 import static org.opensearch.sql.ppl.utils.DataTypeTransformer.translate;
 import static org.opensearch.sql.ppl.utils.WindowSpecTransformer.window;
-import static scala.collection.JavaConverters.asScalaBuffer;
 
 /**
  * Utility class to traverse PPL logical plan and translate it into catalyst logical plan
 */
-public class CatalystQueryPlanVisitor extends AbstractNodeVisitor<String, CatalystPlanContext> {
+public class CatalystQueryPlanVisitor extends AbstractNodeVisitor<LogicalPlan, CatalystPlanContext> {
 
     private final ExpressionAnalyzer expressionAnalyzer;
 
@@ -82,75 +75,73 @@ public CatalystQueryPlanVisitor() {
         this.expressionAnalyzer = new ExpressionAnalyzer();
     }
 
-    public String visit(Statement plan, CatalystPlanContext context) {
+    public LogicalPlan visit(Statement plan, CatalystPlanContext context) {
         //build plan
-        String planDesc = plan.accept(this, context);
+        plan.accept(this, context);
         //add limit statement
         visitLimit(context);
         //add order statement
         visitSort(context);
-        return planDesc;
+        return context.getPlan();
     }
 
     /**
      * Handle Query Statement.
     */
     @Override
-    public String visitQuery(Query node, CatalystPlanContext context) {
+    public LogicalPlan visitQuery(Query node, CatalystPlanContext context) {
         return node.getPlan().accept(this, context);
     }
 
     @Override
-    public String visitExplain(Explain node, CatalystPlanContext context) {
+    public LogicalPlan visitExplain(Explain node, CatalystPlanContext context) {
         return node.getStatement().accept(this, context);
     }
 
     @Override
-    public String visitRelation(Relation node, CatalystPlanContext context) {
+    public LogicalPlan visitRelation(Relation node, CatalystPlanContext context) {
         node.getTableName().forEach(t -> {
             // Resolving the qualifiedName which is composed of a datasource.schema.table
-            context.with(new UnresolvedRelation(asScalaBuffer(of(t.split("\\."))).toSeq(), CaseInsensitiveStringMap.empty(), false));
+            context.with(new UnresolvedRelation(seq(of(t.split("\\."))), CaseInsensitiveStringMap.empty(), false));
         });
-        return format("source=%s", node.getTableName());
+        return context.getPlan();
     }
 
     @Override
-    public String visitFilter(Filter node, CatalystPlanContext context) {
-        String child = node.getChild().get(0).accept(this, context);
-        String innerCondition = visitExpression(node.getCondition(), context);
+    public LogicalPlan visitFilter(Filter node, CatalystPlanContext context) {
+        LogicalPlan translatedPlan = node.getChild().get(0).accept(this, context);
+        Expression conditionExpression = visitExpression(node.getCondition(), context);
         Expression innerConditionExpression = context.getNamedParseExpressions().pop();
         context.plan(p -> new org.apache.spark.sql.catalyst.plans.logical.Filter(innerConditionExpression, p));
-        return format("%s | where %s", child, innerCondition);
+        return translatedPlan;
     }
 
     @Override
-    public String visitAggregation(Aggregation node, CatalystPlanContext context) {
-        String child = node.getChild().get(0).accept(this, context);
-        final String visitExpressionList = visitExpressionList(node.getAggExprList(), context);
-        final String group = visitExpressionList(node.getGroupExprList(), context);
-
-        if (!isNullOrEmpty(group)) {
+    public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext context) {
+        LogicalPlan child = node.getChild().get(0).accept(this, context);
+        List<Expression> aggsExpList = visitExpressionList(node.getAggExprList(), context);
+        List<Expression> groupExpList = visitExpressionList(node.getGroupExprList(), context);
+
+        if (!groupExpList.isEmpty()) {
             //add group by fields to context
-            extractedGroupBy(node.getGroupExprList().size(),context);
+            extractedGroupBy(node.getGroupExprList().size(), context);
         }
-
+
         UnresolvedExpression span = node.getSpan();
         if (!Objects.isNull(span)) {
             span.accept(this, context);
             //add span's group by field to context
-            extractedGroupBy(1,context);
+            extractedGroupBy(1, context);
         }
         // build the aggregation logical step
         extractedAggregation(context);
-        return format(
-            "%s | stats %s",
-            child, String.join(" ", visitExpressionList, groupBy(group)).trim());
+        return child;
     }
 
     private static void extractedGroupBy(int groupByElementsCount, CatalystPlanContext context) {
         //copy the group by aliases from the namedExpressionList to the groupByExpressionList
         for (int i = 1; i <= groupByElementsCount; i++) {
-            context.getGroupingParseExpressions().add(context.getNamedParseExpressions().get(context.getNamedParseExpressions().size()-i));
+            context.getGroupingParseExpressions().add(context.getNamedParseExpressions().get(context.getNamedParseExpressions().size() - i));
         }
     }
 
@@ -161,18 +152,18 @@ private static void extractedAggregation(CatalystPlanContext context) {
     }
 
     @Override
-    public String visitAlias(Alias node, CatalystPlanContext context) {
-        return expressionAnalyzer.visitAlias(node, context);
+    public LogicalPlan visitAlias(Alias node, CatalystPlanContext context) {
+        expressionAnalyzer.visitAlias(node, context);
+        return context.getPlan();
     }
 
     @Override
-    public String visitProject(Project node, CatalystPlanContext context) {
-        String child = node.getChild().get(0).accept(this, context);
-        String arg = "+";
-        String fields = visitExpressionList(node.getProjectList(), context);
+    public LogicalPlan visitProject(Project node, CatalystPlanContext context) {
+        LogicalPlan child = node.getChild().get(0).accept(this, context);
+        visitExpressionList(node.getProjectList(), context);
 
         // Create a projection list from the existing expressions
-        Seq<Expression> projectList = asScalaBuffer(context.getNamedParseExpressions()).toSeq();
+        Seq<Expression> projectList = seq(context.getNamedParseExpressions());
         if (!projectList.isEmpty()) {
             Seq<NamedExpression> projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p);
             // build the plan with the projection step
@@ -180,12 +171,9 @@ public String visitProject(Project node, CatalystPlanContext context) {
         }
         if (node.hasArgument()) {
             Argument argument = node.getArgExprList().get(0);
-            Boolean exclude = (Boolean) argument.getValue().getValue();
-            if (exclude) {
-                arg = "-";
-            }
+            // TODO: exclude the argument from the projected arguments list
         }
-        return format("%s | fields %s %s", child, arg, fields);
+        return child;
     }
 
     private static void visitSort(CatalystPlanContext context) {
@@ -202,242 +190,216 @@ private static void visitLimit(CatalystPlanContext context) {
     }
 
     @Override
-    public String visitEval(Eval node, CatalystPlanContext context) {
-        String child = node.getChild().get(0).accept(this, context);
-        ImmutableList.Builder<Pair<String, String>> expressionsBuilder = new ImmutableList.Builder<>();
-        for (Let let : node.getExpressionList()) {
-            String expression = visitExpression(let.getExpression(), context);
-            String target = let.getVar().getField().toString();
-            expressionsBuilder.add(ImmutablePair.of(target, expression));
-        }
-        String expressions =
-            expressionsBuilder.build().stream()
-                .map(pair -> format("%s" + "=%s", pair.getLeft(), pair.getRight()))
-                .collect(Collectors.joining(" "));
-        return format("%s | eval %s", child, expressions);
-    }
-
-    @Override
-    public String visitSort(Sort node, CatalystPlanContext context) {
-        String child = node.getChild().get(0).accept(this, context);
-        String sortList = visitFieldList(node.getSortList(), context);
+    public LogicalPlan visitSort(Sort node, CatalystPlanContext context) {
+        LogicalPlan child = node.getChild().get(0).accept(this, context);
+        visitFieldList(node.getSortList(), context);
         context.sort(context.retainAllNamedParseExpressions(exp -> SortUtils.getSortDirection(node, (NamedExpression) exp)));
-        return format("%s | sort %s", child, sortList);
+        return child;
     }
 
     @Override
-    public String visitHead(Head node, CatalystPlanContext context) {
-        String child = node.getChild().get(0).accept(this, context);
-        Integer size = node.getSize();
-        context.limit(size);
-        return format("%s | head %d", child, size);
+    public LogicalPlan visitHead(Head node, CatalystPlanContext context) {
+        LogicalPlan child = node.getChild().get(0).accept(this, context);
+        context.limit(node.getSize());
+        return child;
     }
 
-    private String visitFieldList(List<Field> fieldList, CatalystPlanContext context) {
-        return fieldList.stream().map(field -> visitExpression(field, context)).collect(Collectors.joining(","));
+    private void visitFieldList(List<Field> fieldList, CatalystPlanContext context) {
+        fieldList.forEach(field -> visitExpression(field, context));
     }
 
-    private String visitExpressionList(List<UnresolvedExpression> expressionList, CatalystPlanContext context) {
+    private List<Expression> visitExpressionList(List<UnresolvedExpression> expressionList, CatalystPlanContext context) {
         return expressionList.isEmpty()
-            ? ""
+            ? emptyList()
            : expressionList.stream().map(field -> visitExpression(field, context))
-            .collect(Collectors.joining(","));
+            .collect(Collectors.toList());
     }
 
-    private String visitExpression(UnresolvedExpression expression, CatalystPlanContext context) {
+    private Expression visitExpression(UnresolvedExpression expression, CatalystPlanContext context) {
         return expressionAnalyzer.analyze(expression, context);
     }
 
-    private String groupBy(String groupBy) {
-        return isNullOrEmpty(groupBy) ? "" : format("by %s", groupBy);
+    @Override
+    public LogicalPlan visitEval(Eval node, CatalystPlanContext context) {
+        throw new IllegalStateException("Not Supported operation : Eval");
     }
 
     @Override
-    public String visitKmeans(Kmeans node, CatalystPlanContext context) {
-        throw new IllegalStateException("Not Supported operation : Kmeans" );
+    public LogicalPlan visitKmeans(Kmeans node, CatalystPlanContext context) {
+        throw new IllegalStateException("Not Supported operation : Kmeans");
     }
 
     @Override
-    public String visitIn(In node, CatalystPlanContext context) {
-        throw new IllegalStateException("Not Supported operation : In" );
+    public LogicalPlan visitIn(In node, CatalystPlanContext context) {
+        throw new IllegalStateException("Not Supported operation : In");
    }
 
     @Override
-    public String visitCase(Case node, CatalystPlanContext context) {
-        throw new IllegalStateException("Not Supported operation : Case" );
+    public LogicalPlan visitCase(Case node, CatalystPlanContext context) {
+        throw new IllegalStateException("Not Supported operation : Case");
     }
 
     @Override
-    public String visitRareTopN(RareTopN node, CatalystPlanContext context) {
-        throw new IllegalStateException("Not Supported operation : RareTopN" );
+    public LogicalPlan visitRareTopN(RareTopN node, CatalystPlanContext context) {
+        throw new IllegalStateException("Not Supported operation : RareTopN");
     }
 
     @Override
-    public String visitWindowFunction(WindowFunction node, CatalystPlanContext context) {
-        throw new IllegalStateException("Not Supported operation : WindowFunction" );
+    public LogicalPlan visitWindowFunction(WindowFunction node, CatalystPlanContext context) {
+        throw new IllegalStateException("Not Supported operation : WindowFunction");
     }
 
     @Override
-    public String visitDedupe(Dedupe node, CatalystPlanContext context) {
-        throw new IllegalStateException("Not Supported operation : dedupe " );
+    public LogicalPlan visitDedupe(Dedupe node, CatalystPlanContext context) {
+        throw new IllegalStateException("Not Supported operation : Dedupe");
     }
 
+
     /**
      * Expression Analyzer.
     */
-    private static class ExpressionAnalyzer extends AbstractNodeVisitor<String, CatalystPlanContext> {
+    private static class ExpressionAnalyzer extends AbstractNodeVisitor<Expression, CatalystPlanContext> {
 
-        public String analyze(UnresolvedExpression unresolved, CatalystPlanContext context) {
+        public Expression analyze(UnresolvedExpression unresolved, CatalystPlanContext context) {
             return unresolved.accept(this, context);
         }
 
         @Override
-        public String visitLiteral(Literal node, CatalystPlanContext context) {
-            context.getNamedParseExpressions().add(new org.apache.spark.sql.catalyst.expressions.Literal(
+        public Expression visitLiteral(Literal node, CatalystPlanContext context) {
+            return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Literal(
                 translate(node.getValue(), node.getType()), translate(node.getType())));
-            return node.toString();
-        }
-
-        @Override
-        public String visitInterval(Interval node, CatalystPlanContext context) {
-            String value = node.getValue().accept(this, context);
-            String unit = node.getUnit().name();
-            return format("INTERVAL %s %s", value, unit);
         }
 
         @Override
-        public String visitAnd(And node, CatalystPlanContext context) {
-            String left = node.getLeft().accept(this, context);
-            String right = node.getRight().accept(this, context);
-            context.getNamedParseExpressions().add(new org.apache.spark.sql.catalyst.expressions.And(
-                (Expression) context.getNamedParseExpressions().pop(), context.getNamedParseExpressions().pop()));
-            return format("%s and %s", left, right);
+        public Expression visitAnd(And node, CatalystPlanContext context) {
+            node.getLeft().accept(this, context);
+            Expression left = (Expression) context.getNamedParseExpressions().pop();
+            node.getRight().accept(this, context);
+            Expression right = (Expression) context.getNamedParseExpressions().pop();
+            return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.And(left, right));
         }
 
         @Override
-        public String visitOr(Or node, CatalystPlanContext context) {
-            String left = node.getLeft().accept(this, context);
-            String right = node.getRight().accept(this, context);
-            context.getNamedParseExpressions().add(new org.apache.spark.sql.catalyst.expressions.Or(
-                (Expression) context.getNamedParseExpressions().pop(), context.getNamedParseExpressions().pop()));
-            return format("%s or %s", left, right);
+        public Expression visitOr(Or node, CatalystPlanContext context) {
+            node.getLeft().accept(this, context);
+            Expression left = (Expression) context.getNamedParseExpressions().pop();
+            node.getRight().accept(this, context);
+            Expression right = (Expression) context.getNamedParseExpressions().pop();
+            return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Or(left, right));
        }
 
         @Override
-        public String visitXor(Xor node, CatalystPlanContext context) {
-            String left = node.getLeft().accept(this, context);
-            String right = node.getRight().accept(this, context);
-            context.getNamedParseExpressions().add(new org.apache.spark.sql.catalyst.expressions.BitwiseXor(
-                (Expression) context.getNamedParseExpressions().pop(), context.getNamedParseExpressions().pop()));
-            return format("%s xor %s", left, right);
+        public Expression visitXor(Xor node, CatalystPlanContext context) {
+            node.getLeft().accept(this, context);
+            Expression left = (Expression) context.getNamedParseExpressions().pop();
+            node.getRight().accept(this, context);
+            Expression right = (Expression) context.getNamedParseExpressions().pop();
+            return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.BitwiseXor(left, right));
         }
 
         @Override
-        public String visitNot(Not node, CatalystPlanContext context) {
-            String expr = node.getExpression().accept(this, context);
-            context.getNamedParseExpressions().add(new org.apache.spark.sql.catalyst.expressions.Not(
-                (Expression) context.getNamedParseExpressions().pop()));
-            return format("not %s", expr);
+        public Expression visitNot(Not node, CatalystPlanContext context) {
+            node.getExpression().accept(this, context);
+            Expression arg = (Expression) context.getNamedParseExpressions().pop();
+            return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Not(arg));
         }
 
         @Override
-        public String visitSpan(Span node, CatalystPlanContext context) {
-            String field = node.getField().accept(this, context);
-            String value = node.getValue().accept(this, context);
-            String unit = node.getUnit().name();
-
-            Expression valueExpression = context.getNamedParseExpressions().pop();
-            Expression fieldExpression = context.getNamedParseExpressions().pop();
-
-            context.getNamedParseExpressions().push(window(fieldExpression,valueExpression,node.getUnit()));
-            return format("span (%s,%s,%s)", field, value, unit);
+        public Expression visitSpan(Span node, CatalystPlanContext context) {
+            node.getField().accept(this, context);
+            Expression field = (Expression) context.getNamedParseExpressions().pop();
+            node.getValue().accept(this, context);
+            Expression value = (Expression) context.getNamedParseExpressions().pop();
+            return context.getNamedParseExpressions().push(window(field, value, node.getUnit()));
         }
 
         @Override
-        public String visitAggregateFunction(AggregateFunction node, CatalystPlanContext context) {
-            String arg = node.getField().accept(this, context);
-            org.apache.spark.sql.catalyst.expressions.Expression aggregator = AggregatorTranslator.aggregator(node, context);
-            context.getNamedParseExpressions().add(aggregator);
-            return format("%s(%s)", node.getFuncName(), arg);
+        public Expression visitAggregateFunction(AggregateFunction node, CatalystPlanContext context) {
+            node.getField().accept(this, context);
+            Expression arg = (Expression) context.getNamedParseExpressions().pop();
+            Expression aggregator = AggregatorTranslator.aggregator(node, arg);
+            return context.getNamedParseExpressions().push(aggregator);
         }
 
         @Override
-        public String visitFunction(Function node, CatalystPlanContext context) {
-            String arguments =
-                node.getFuncArgs().stream()
-                    .map(unresolvedExpression -> analyze(unresolvedExpression, context))
-                    .collect(Collectors.joining(","));
-            return format("%s(%s)", node.getFuncName(), arguments);
+        public Expression visitCompare(Compare node, CatalystPlanContext context) {
+            analyze(node.getLeft(), context);
+            Expression left = (Expression) context.getNamedParseExpressions().pop();
+            analyze(node.getRight(), context);
+            Expression right = (Expression) context.getNamedParseExpressions().pop();
+            Predicate comparator = ComparatorTransformer.comparator(node, left, right);
+            return context.getNamedParseExpressions().push((org.apache.spark.sql.catalyst.expressions.Expression) comparator);
         }
 
         @Override
-        public String visitCompare(Compare node, CatalystPlanContext context) {
-            String left = analyze(node.getLeft(), context);
-            String right = analyze(node.getRight(), context);
-            Predicate comparator = ComparatorTransformer.comparator(node, context);
-            context.getNamedParseExpressions().add((org.apache.spark.sql.catalyst.expressions.Expression) comparator);
-            return format("%s %s %s", left, node.getOperator(), right);
+        public Expression visitField(Field node, CatalystPlanContext context) {
+            return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getField().toString())));
         }
 
         @Override
-        public String visitField(Field node, CatalystPlanContext context) {
-            context.getNamedParseExpressions().add(UnresolvedAttribute$.MODULE$.apply(asScalaBuffer(singletonList(node.getField().toString()))));
-            return node.getField().toString();
-        }
-
-        @Override
-        public String visitAllFields(AllFields node, CatalystPlanContext context) {
+        public Expression visitAllFields(AllFields node, CatalystPlanContext context) {
             // Case of aggregation step - no start projection can be added
-            if (!context.getNamedParseExpressions().isEmpty()) {
-                // if named expression exist - just return their names
-                return context.getNamedParseExpressions().peek().toString();
-            } else {
+            if (context.getNamedParseExpressions().isEmpty()) {
                 // Create an UnresolvedStar for all-fields projection
-                context.getNamedParseExpressions().add(UnresolvedStar$.MODULE$.apply(Option.<Seq<String>>empty()));
-                return "*";
+                context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.<Seq<String>>empty()));
             }
+            return context.getNamedParseExpressions().peek();
         }
 
         @Override
-        public String visitAlias(Alias node, CatalystPlanContext context) {
-            String expr = node.getDelegated().accept(this, context);
-            Expression expression = (Expression) context.getNamedParseExpressions().pop();
-            context.getNamedParseExpressions().add(
-                org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply((Expression) expression,
-                    node.getAlias()!=null ? node.getAlias() : expr,
+        public Expression visitAlias(Alias node, CatalystPlanContext context) {
+            node.getDelegated().accept(this, context);
+            Expression arg = context.getNamedParseExpressions().pop();
+            return context.getNamedParseExpressions().push(
+                org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(arg,
+                    node.getAlias() != null ? node.getAlias() : node.getName(),
                     NamedExpression.newExprId(),
                     seq(new java.util.ArrayList<String>()),
                     Option.empty(),
                     seq(new java.util.ArrayList<String>())));
-            return format("%s", expr);
         }
 
         @Override
-        public String visitDedupe(Dedupe node, CatalystPlanContext context) {
-            throw new IllegalStateException("Not Supported operation : Dedupe" );
+        public Expression visitEval(Eval node, CatalystPlanContext context) {
+            throw new IllegalStateException("Not Supported operation : Eval");
+        }
+
+        @Override
+        public Expression visitFunction(Function node, CatalystPlanContext context) {
+            throw new IllegalStateException("Not Supported operation : Function");
+        }
+
+        @Override
+        public Expression visitInterval(Interval node, CatalystPlanContext context) {
+            throw new IllegalStateException("Not Supported operation : Interval");
+        }
+
+        @Override
+        public Expression visitDedupe(Dedupe node, CatalystPlanContext context) {
+            throw new IllegalStateException("Not Supported operation : Dedupe");
         }
 
         @Override
-        public String visitIn(In node, CatalystPlanContext context) {
+        public Expression visitIn(In node, CatalystPlanContext context) {
             throw new IllegalStateException("Not Supported operation : In");
         }
 
         @Override
-        public String visitKmeans(Kmeans node, CatalystPlanContext context) {
+        public Expression visitKmeans(Kmeans node, CatalystPlanContext context) {
             throw new IllegalStateException("Not Supported operation : Kmeans");
         }
 
         @Override
-        public String visitCase(Case node, CatalystPlanContext context) {
-            throw new IllegalStateException("Not Supported operation : Case" );
+        public Expression visitCase(Case node, CatalystPlanContext context) {
+            throw new IllegalStateException("Not Supported operation : Case");
         }
 
         @Override
-        public String visitRareTopN(RareTopN node, CatalystPlanContext context) {
+        public Expression visitRareTopN(RareTopN node, CatalystPlanContext context) {
             throw new IllegalStateException("Not Supported operation : RareTopN");
         }
 
         @Override
-        public String visitWindowFunction(WindowFunction node, CatalystPlanContext context) {
+        public Expression visitWindowFunction(WindowFunction node, CatalystPlanContext context) {
             throw new IllegalStateException("Not Supported operation : WindowFunction");
         }
     }
 }
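
The refactor above replaces the old dual-purpose visitor, which returned a PPL description string while mutating the context as a side effect, with a single protocol: every expression visit pushes exactly one translated Catalyst expression onto the context stack, and composite nodes pop their operands in visit order before pushing the composed result. `Stack.push` returning its argument is what makes the one-line `return ...push(...)` bodies work. A stripped-down, self-contained sketch of that discipline (names here are illustrative only):

```java
// Sketch of the push/pop protocol followed by the refactored ExpressionAnalyzer.
import java.util.Stack;

class ExpressionStackSketch {
    private final Stack<String> expressions = new Stack<>();

    String visitLiteral(String value) {
        // leaf nodes push their translation and return it (Stack.push returns its argument)
        return expressions.push(value);
    }

    String visitAnd(String leftSource, String rightSource) {
        visitLiteral(leftSource);        // translate the left child...
        String left = expressions.pop(); // ...and claim its result
        visitLiteral(rightSource);       // same for the right child
        String right = expressions.pop();
        return expressions.push("And(" + left + ", " + right + ")");
    }

    public static void main(String[] args) {
        ExpressionStackSketch analyzer = new ExpressionStackSketch();
        System.out.println(analyzer.visitAnd("a = 1", "b != 2")); // And(a = 1, b != 2)
    }
}
```
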
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java
index 9b66d370f..e15324cc0 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java
@@ -8,11 +8,9 @@
 import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction;
 import org.apache.spark.sql.catalyst.expressions.Expression;
 import org.opensearch.sql.expression.function.BuiltinFunctionName;
-import org.opensearch.sql.ppl.CatalystPlanContext;
 
-import static java.util.List.of;
+import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq;
 import static scala.Option.empty;
-import static scala.collection.JavaConverters.asScalaBuffer;
 
 /**
  * aggregator expression builder building a catalyst aggregation function from PPL's aggregation logical step
@@ -21,27 +19,22 @@
 */
 public interface AggregatorTranslator {
 
-    static Expression aggregator(org.opensearch.sql.ast.expression.AggregateFunction aggregateFunction, CatalystPlanContext context) {
+    static Expression aggregator(org.opensearch.sql.ast.expression.AggregateFunction aggregateFunction, Expression arg) {
         if (BuiltinFunctionName.ofAggregation(aggregateFunction.getFuncName()).isEmpty())
             throw new IllegalStateException("Unexpected value: " + aggregateFunction.getFuncName());
 
         // Additional aggregation function operators will be added here
         switch (BuiltinFunctionName.ofAggregation(aggregateFunction.getFuncName()).get()) {
             case MAX:
-                return new UnresolvedFunction(asScalaBuffer(of("MAX")).toSeq(),
-                        asScalaBuffer(of(context.getNamedParseExpressions().pop())).toSeq(), false, empty(), false);
+                return new UnresolvedFunction(seq("MAX"), seq(arg), false, empty(), false);
             case MIN:
-                return new UnresolvedFunction(asScalaBuffer(of("MIN")).toSeq(),
-                        asScalaBuffer(of(context.getNamedParseExpressions().pop())).toSeq(), false, empty(), false);
+                return new UnresolvedFunction(seq("MIN"), seq(arg), false, empty(), false);
             case AVG:
-                return new UnresolvedFunction(asScalaBuffer(of("AVG")).toSeq(),
-                        asScalaBuffer(of(context.getNamedParseExpressions().pop())).toSeq(), false, empty(), false);
+                return new UnresolvedFunction(seq("AVG"), seq(arg), false, empty(), false);
             case COUNT:
-                return new UnresolvedFunction(asScalaBuffer(of("COUNT")).toSeq(),
-                        asScalaBuffer(of(context.getNamedParseExpressions().pop())).toSeq(), false, empty(), false);
+                return new UnresolvedFunction(seq("COUNT"), seq(arg), false, empty(), false);
             case SUM:
-                return new UnresolvedFunction(asScalaBuffer(of("SUM")).toSeq(),
-                        asScalaBuffer(of(context.getNamedParseExpressions().pop())).toSeq(), false, empty(), false);
+                return new UnresolvedFunction(seq("SUM"), seq(arg), false, empty(), false);
         }
         throw new IllegalStateException("Not Supported value: " + aggregateFunction.getFuncName());
     }
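
With the `CatalystPlanContext` parameter gone, the translator no longer pops its operand off a shared stack; the visitor resolves the argument first and passes it in. A usage sketch of the new contract, mirroring `visitField`/`visitAggregateFunction` above (how the `AggregateFunction` node is obtained is assumed, not shown here):

```java
// Usage sketch of the decoupled AggregatorTranslator contract.
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.opensearch.sql.ast.expression.AggregateFunction;
import org.opensearch.sql.ppl.utils.AggregatorTranslator;

import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq;

class AggregatorUsageSketch {
    // node: the parsed PPL aggregate call, e.g. avg(price), supplied by the visitor
    static Expression translate(AggregateFunction node) {
        // the caller resolves the operand first (as visitAggregateFunction now does)...
        Expression arg = UnresolvedAttribute$.MODULE$.apply(seq(node.getField().toString()));
        // ...and the translator only assembles the Catalyst UnresolvedFunction
        return AggregatorTranslator.aggregator(node, arg);
    }
}
```
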
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ComparatorTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ComparatorTransformer.java
index 44f5cb9f4..2a176ec3d 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ComparatorTransformer.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ComparatorTransformer.java
@@ -15,7 +15,6 @@
 import org.apache.spark.sql.catalyst.expressions.Predicate;
 import org.opensearch.sql.ast.expression.Compare;
 import org.opensearch.sql.expression.function.BuiltinFunctionName;
-import org.opensearch.sql.ppl.CatalystPlanContext;
 
 /**
  * Transform the PPL Logical comparator into catalyst comparator
@@ -26,16 +25,17 @@
     *
     * @return
     */
-    static Predicate comparator(Compare expression, CatalystPlanContext context) {
+    static Predicate comparator(Compare expression, Expression left, Expression right) {
         if (BuiltinFunctionName.of(expression.getOperator()).isEmpty())
             throw new IllegalStateException("Unexpected value: " + expression.getOperator());
 
-        if (context.getNamedParseExpressions().isEmpty()) {
-            throw new IllegalStateException("Unexpected value: No operands found in expression");
+        if (left == null) {
+            throw new IllegalStateException("Unexpected value: No left operand found in expression");
         }
 
-        Expression right = context.getNamedParseExpressions().pop();
-        Expression left = context.getNamedParseExpressions().isEmpty() ? null : context.getNamedParseExpressions().pop();
+        if (right == null) {
+            throw new IllegalStateException("Unexpected value: No right operand found in expression");
+        }
 
         // Additional function operators will be added here
         switch (BuiltinFunctionName.of(expression.getOperator()).get()) {
@@ -54,4 +54,5 @@ static Predicate comparator(Compare expression, CatalystPlanContext context) {
         }
         throw new IllegalStateException("Not Supported value: " + expression.getOperator());
     }
+
 }
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/SortUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/SortUtils.java
index 17ed3fa94..83603b031 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/SortUtils.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/SortUtils.java
@@ -37,10 +37,9 @@ static SortOrder getSortDirection(Sort node, NamedExpression expression) {
                 .filter(f -> f.getField().toString().equals(expression.name()))
                 .findAny();
 
-        if(field.isPresent()) {
-            return sortOrder((Expression) expression, (Boolean) field.get().getFieldArgs().get(0).getValue().getValue());
-        }
-        return null;
+        return field.map(value -> sortOrder((Expression) expression,
+                (Boolean) value.getFieldArgs().get(0).getValue().getValue()))
+            .orElse(null);
     }
 
     @NotNull
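
`ComparatorTransformer` gets the same decoupling, with explicit null guards replacing the old stack-emptiness check, so it can be exercised without a populated context. A sketch of a call site under the new contract (the `Compare` node would come from a parsed PPL `where` clause; the literal construction mirrors `visitLiteral` above):

```java
// Usage sketch of the decoupled ComparatorTransformer contract.
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.Literal;
import org.apache.spark.sql.catalyst.expressions.Predicate;
import org.apache.spark.sql.types.DataTypes;
import org.opensearch.sql.ast.expression.Compare;
import org.opensearch.sql.ppl.utils.ComparatorTransformer;

import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq;

class ComparatorUsageSketch {
    // node: a parsed PPL comparison such as `a = 1`, supplied by the visitor
    static Predicate translate(Compare node) {
        Expression left = UnresolvedAttribute$.MODULE$.apply(seq("a"));
        Expression right = new Literal(1, DataTypes.IntegerType);
        return ComparatorTransformer.comparator(node, left, right); // e.g. EqualTo(a, 1)
    }
}
```
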
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala
index e45f30c6c..95ad2d19f 100644
--- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala
@@ -12,7 +12,7 @@
 import org.scalatest.matchers.should.Matchers
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
-import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Divide, EqualTo, Floor, Literal, Multiply, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Divide, EqualTo, Floor, Literal, Multiply, SortOrder, TimeWindow}
 import org.apache.spark.sql.catalyst.plans.logical._
 
 class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
@@ -38,8 +38,7 @@
     val aggregatePlan = Aggregate(Seq(), aggregateExpressions, tableRelation)
     val expectedPlan = Project(star, aggregatePlan)
 
-    assertEquals(compareByString(expectedPlan), compareByString(context.getPlan))
-    assertEquals(logPlan, "source=[table] | stats avg(price) | fields + *")
+    assertEquals(compareByString(expectedPlan), compareByString(logPlan))
   }
 
   ignore("test average price with Alias") {
@@ -58,8 +57,7 @@
     val aggregatePlan = Aggregate(Seq(), aggregateExpressions, tableRelation)
     val expectedPlan = Project(star, aggregatePlan)
 
-    assertEquals(compareByString(expectedPlan), compareByString(context.getPlan))
-    assertEquals(logPlan, "source=[table] | stats avg(price) as avg_price | fields + *")
+    assertEquals(compareByString(expectedPlan), compareByString(logPlan))
   }
 
   test("test average price group by product ") {
@@ -83,8 +81,7 @@
       Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), tableRelation)
     val expectedPlan = Project(star, aggregatePlan)
 
-    assertEquals(logPlan, "source=[table] | stats avg(price) by product | fields + *")
-    assertEquals(compareByString(expectedPlan), compareByString(context.getPlan))
+    assertEquals(compareByString(expectedPlan), compareByString(logPlan))
   }
 
   test("test average price group by product and filter") {
@@ -112,10 +109,7 @@
       Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan)
     val expectedPlan = Project(star, aggregatePlan)
 
-    assertEquals(
-      logPlan,
-      "source=[table] | where country = 'USA' | stats avg(price) by product | fields + *")
-    assertEquals(compareByString(expectedPlan), compareByString(context.getPlan))
+    assertEquals(compareByString(expectedPlan), compareByString(logPlan))
   }
 
   test("test average price group by product and filter sorted") {
@@ -148,10 +142,7 @@
     val sortedPlan: LogicalPlan =
       Sort(Seq(SortOrder(UnresolvedAttribute("product"), Ascending)), global = true, expectedPlan)
 
-    assertEquals(
-      logPlan,
-      "source=[table] | where country = 'USA' | stats avg(price) by product | sort product | fields + *")
-    assertEquals(compareByString(sortedPlan), compareByString(context.getPlan))
+    assertEquals(compareByString(sortedPlan), compareByString(logPlan))
   }
   test("create ppl simple avg age by span of interval of 10 years query test ") {
     val context = new CatalystPlanContext
@@ -171,8 +162,7 @@
     val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), tableRelation)
     val expectedPlan = Project(star, aggregatePlan)
 
-    assertEquals(logPlan, "source=[table] | stats avg(age) | fields + *")
-    assert(compareByString(expectedPlan) === compareByString(context.getPlan))
+    assert(compareByString(expectedPlan) === compareByString(logPlan))
   }
 
   test("create ppl simple avg age by span of interval of 10 years query with sort test ") {
@@ -198,8 +188,7 @@
     val sortedPlan: LogicalPlan =
       Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, expectedPlan)
 
-    assertEquals(logPlan, "source=[table] | stats avg(age) | sort age | fields + *")
-    assert(compareByString(sortedPlan) === compareByString(context.getPlan))
+    assert(compareByString(sortedPlan) === compareByString(logPlan))
   }
 
   test("create ppl simple avg age by span of interval of 10 years by country query test ") {
@@ -228,8 +217,85 @@
       tableRelation)
     val expectedPlan = Project(star, aggregatePlan)
 
-    assertEquals(logPlan, "source=[table] | stats avg(age) by country | fields + *")
-    assert(compareByString(expectedPlan) === compareByString(context.getPlan))
+    assert(compareByString(expectedPlan) === compareByString(logPlan))
+  }
+
+  test("create ppl query sum sales by weeks window with sorting test") {
+    val context = new CatalystPlanContext
+    val logPlan = planTrnasformer.visit(
+      plan(
+        pplParser,
+        "source = table | stats sum(productsAmount) by span(transactionDate, 1w) as age_date | sort age_date",
+        false),
+      context)
+
+    // Define the expected logical plan
+    val star = Seq(UnresolvedStar(None))
+    val productsAmount = UnresolvedAttribute("productsAmount")
+    val table = UnresolvedRelation(Seq("table"))
+
+    val windowExpression = Alias(
+      TimeWindow(
+        UnresolvedAttribute("transactionDate"),
+        TimeWindow.parseExpression(Literal("1 week")),
+        TimeWindow.parseExpression(Literal("1 week")),
+        0),
+      "age_date")()
+
+    val aggregateExpressions =
+      Alias(
+        UnresolvedFunction(Seq("SUM"), Seq(productsAmount), isDistinct = false),
+        "sum(productsAmount)")()
+    val aggregatePlan = Aggregate(
+      Seq(windowExpression),
+      Seq(aggregateExpressions, windowExpression),
+      table)
+    val expectedPlan = Project(star, aggregatePlan)
+    val sortedPlan: LogicalPlan = Sort(
+      Seq(SortOrder(UnresolvedAttribute("age_date"), Ascending)),
+      global = true,
+      expectedPlan)
+    // Compare the two plans
+    assert(compareByString(sortedPlan) === compareByString(logPlan))
+  }
+
+  test("create ppl query sum sales by days window and productId with sorting test") {
+    val context = new CatalystPlanContext
+    val logPlan = planTrnasformer.visit(
+      plan(
+        pplParser,
+        "source = table | stats sum(productsAmount) by span(transactionDate, 1d) as age_date, productId | sort age_date",
+        false),
+      context)
+    // Define the expected logical plan
+    val star = Seq(UnresolvedStar(None))
+    val productsId = Alias(UnresolvedAttribute("productId"), "productId")()
+    val productsAmount = UnresolvedAttribute("productsAmount")
+    val table = UnresolvedRelation(Seq("table"))
+
+    val windowExpression = Alias(
+      TimeWindow(
+        UnresolvedAttribute("transactionDate"),
+        TimeWindow.parseExpression(Literal("1 day")),
+        TimeWindow.parseExpression(Literal("1 day")),
+        0),
+      "age_date")()
+
+    val aggregateExpressions =
+      Alias(
+        UnresolvedFunction(Seq("SUM"), Seq(productsAmount), isDistinct = false),
+        "sum(productsAmount)")()
+    val aggregatePlan = Aggregate(
+      Seq(productsId, windowExpression),
+      Seq(aggregateExpressions, productsId, windowExpression),
+      table)
+    val expectedPlan = Project(star, aggregatePlan)
+    val sortedPlan: LogicalPlan = Sort(
+      Seq(SortOrder(UnresolvedAttribute("age_date"), Ascending)),
+      global = true,
+      expectedPlan)
+    // Compare the two plans
+    assert(compareByString(sortedPlan) === compareByString(logPlan))
+  }
 }
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala
index 3b5f99e8f..ed600c064 100644
--- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala
@@ -31,8 +31,7 @@
     val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None))
     val expectedPlan = Project(projectList, UnresolvedRelation(Seq("table")))
 
-    assertEquals(expectedPlan, context.getPlan)
-    assertEquals(logPlan, "source=[table] | fields + *")
+    assertEquals(expectedPlan, logPlan)
   }
 
@@ -43,8 +42,7 @@
     val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None))
     val expectedPlan = Project(projectList, UnresolvedRelation(Seq("schema", "table")))
 
-    assertEquals(expectedPlan, context.getPlan)
-    assertEquals(logPlan, "source=[schema.table] | fields + *")
+    assertEquals(expectedPlan, logPlan)
   }
 
@@ -55,8 +53,7 @@
     val projectList: Seq[NamedExpression] = Seq(UnresolvedAttribute("A"))
     val expectedPlan = Project(projectList, UnresolvedRelation(Seq("schema", "table")))
 
-    assertEquals(expectedPlan, context.getPlan)
-    assertEquals(logPlan, "source=[schema.table] | fields + A")
+    assertEquals(expectedPlan, logPlan)
   }
 
   test("test simple search with only one table with one field projected") {
@@ -66,8 +63,7 @@
     val projectList: Seq[NamedExpression] = Seq(UnresolvedAttribute("A"))
     val expectedPlan = Project(projectList, UnresolvedRelation(Seq("table")))
 
-    assertEquals(expectedPlan, context.getPlan)
-    assertEquals(logPlan, "source=[table] | fields + A")
+    assertEquals(expectedPlan, logPlan)
   }
 
   test("test simple search with only one table with two fields projected") {
@@ -77,8 +73,7 @@
     val table = UnresolvedRelation(Seq("t"))
     val projectList = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B"))
     val expectedPlan = Project(projectList, table)
-    assertEquals(expectedPlan, context.getPlan)
-    assertEquals(logPlan, "source=[t] | fields + A,B")
+    assertEquals(expectedPlan, logPlan)
   }
 
   test("test simple search with one table with two fields projected sorted by one field") {
@@ -93,8 +88,7 @@
     val sortOrder = Seq(SortOrder(UnresolvedAttribute("A"), Ascending))
     val sorted = Sort(sortOrder, true, expectedPlan)
 
-    assert(compareByString(sorted) === compareByString(context.getPlan))
-    assertEquals(logPlan, "source=[t] | sort A | fields + A,B")
+    assert(compareByString(sorted) === compareByString(logPlan))
   }
 
   test(
@@ -107,8 +101,7 @@
     val projectList = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B"))
     val planWithLimit = Project(Seq(UnresolvedStar(None)), Project(projectList, table))
     val expectedPlan = GlobalLimit(Literal(5), LocalLimit(Literal(5), planWithLimit))
-    assertEquals(expectedPlan, context.getPlan)
-    assertEquals(logPlan, "source=[t] | fields + A,B | head 5 | fields + *")
+    assertEquals(expectedPlan, logPlan)
   }
 
   test(
@@ -127,8 +120,7 @@
     val sortOrder = Seq(SortOrder(UnresolvedAttribute("A"), Descending))
     val sorted = Sort(sortOrder, true, expectedPlan)
 
-    assertEquals(logPlan, "source=[t] | sort A | fields + A,B | head 5 | fields + *")
-    assertEquals(compareByString(sorted), compareByString(context.getPlan))
+    assertEquals(compareByString(sorted), compareByString(logPlan))
   }
 
   test(
@@ -150,8 +142,7 @@
     val expectedPlan =
       Union(Seq(projectedTable1, projectedTable2), byName = true, allowMissingCol = true)
 
-    assertEquals(logPlan, "source=[table1, table2] | fields + A,B")
-    assertEquals(expectedPlan, context.getPlan)
+    assertEquals(expectedPlan, logPlan)
   }
 
   test("Search multiple tables - translated into union call with fields") {
@@ -171,7 +162,6 @@
     val expectedPlan =
       Union(Seq(projectedTable1, projectedTable2), byName = true, allowMissingCol = true)
 
-    assertEquals(logPlan, "source=[table1, table2] | fields + *")
-    assertEquals(expectedPlan, context.getPlan)
+    assertEquals(expectedPlan, logPlan)
   }
 }
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFiltersTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFiltersTranslatorTestSuite.scala
index f87267200..5f2b42848 100644
--- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFiltersTranslatorTestSuite.scala
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFiltersTranslatorTestSuite.scala
@@ -40,8 +40,7 @@
     val filterPlan = Filter(filterExpr, table)
     val projectList = Seq(UnresolvedStar(None))
     val expectedPlan = Project(projectList, filterPlan)
-    assertEquals(expectedPlan, context.getPlan)
-    assertEquals(logPlan, "source=[t] | where a = 1 | fields + *")
+    assertEquals(expectedPlan, logPlan)
   }
 
   test("test simple search with only one table with two field with 'and' filtered ") {
@@ -52,11 +51,10 @@
     val table = UnresolvedRelation(Seq("t"))
     val filterAExpr = EqualTo(UnresolvedAttribute("a"), Literal(1))
     val filterBExpr = Not(EqualTo(UnresolvedAttribute("b"), Literal(2)))
-    val filterPlan = Filter(And(filterBExpr, filterAExpr), table)
+    val filterPlan = Filter(And(filterAExpr, filterBExpr), table)
     val projectList = Seq(UnresolvedStar(None))
     val expectedPlan = Project(projectList, filterPlan)
-    assertEquals(expectedPlan, context.getPlan)
-    assertEquals(logPlan, "source=[t] | where a = 1 and b != 2 | fields + *")
+    assertEquals(expectedPlan, logPlan)
   }
 
   test("test simple search with only one table with two field with 'or' filtered ") {
@@ -67,11 +65,10 @@
     val table = UnresolvedRelation(Seq("t"))
     val filterAExpr = EqualTo(UnresolvedAttribute("a"), Literal(1))
     val filterBExpr = Not(EqualTo(UnresolvedAttribute("b"), Literal(2)))
-    val filterPlan = Filter(Or(filterBExpr, filterAExpr), table)
+    val filterPlan = Filter(Or(filterAExpr, filterBExpr), table)
     val projectList = Seq(UnresolvedStar(None))
     val expectedPlan = Project(projectList, filterPlan)
-    assertEquals(expectedPlan, context.getPlan)
-    assertEquals(logPlan, "source=[t] | where a = 1 or b != 2 | fields + *")
+    assertEquals(expectedPlan, logPlan)
   }
 
   test("test simple search with only one table with two field with 'not' filtered ") {
@@ -82,11 +79,10 @@
     val table = UnresolvedRelation(Seq("t"))
     val filterAExpr = Not(EqualTo(UnresolvedAttribute("a"), Literal(1)))
     val filterBExpr = Not(EqualTo(UnresolvedAttribute("b"), Literal(2)))
-    val filterPlan = Filter(Or(filterBExpr, filterAExpr), table)
+    val filterPlan = Filter(Or(filterAExpr, filterBExpr), table)
     val projectList = Seq(UnresolvedStar(None))
     val expectedPlan = Project(projectList, filterPlan)
-    assertEquals(expectedPlan, context.getPlan)
-    assertEquals(logPlan, "source=[t] | where not a = 1 or b != 2 | fields + *")
+    assertEquals(expectedPlan, logPlan)
   }
 
   test(
@@ -100,8 +96,7 @@
     val filterPlan = Filter(filterExpr, table)
     val projectList = Seq(UnresolvedAttribute("a"))
     val expectedPlan = Project(projectList, filterPlan)
-    assertEquals(expectedPlan, context.getPlan)
-    assertEquals(logPlan, "source=[t] | where a = 1 | fields + a")
+    assertEquals(expectedPlan, logPlan)
   }
 
   test(
@@ -116,8 +111,7 @@
     val projectList = Seq(UnresolvedAttribute("a"))
     val expectedPlan = Project(projectList, filterPlan)
 
-    assertEquals(expectedPlan, context.getPlan)
-    assertEquals(logPlan, "source=[t] | where a = 'hi' | fields + a")
+    assertEquals(expectedPlan, logPlan)
   }
 
   test(
@@ -133,8 +127,7 @@
     val projectList = Seq(UnresolvedAttribute("a"))
     val expectedPlan = Project(projectList, filterPlan)
 
-    assertEquals(expectedPlan, context.getPlan)
-    assertEquals(logPlan, "source=[t] | where a != 'bye' | fields + a")
+    assertEquals(expectedPlan, logPlan)
   }
 
   test(
@@ -148,8 +141,7 @@
     val filterPlan = Filter(filterExpr, table)
     val projectList = Seq(UnresolvedAttribute("a"))
     val expectedPlan = Project(projectList, filterPlan)
-    assertEquals(expectedPlan, context.getPlan)
-    assertEquals(logPlan, "source=[t] | where a > 1 | fields + a")
+    assertEquals(expectedPlan, logPlan)
   }
 
   test(
@@ -163,8 +155,7 @@
     val filterPlan = Filter(filterExpr, table)
     val projectList = Seq(UnresolvedAttribute("a"))
     val expectedPlan = Project(projectList, filterPlan)
-    assertEquals(expectedPlan, context.getPlan)
-    assertEquals(logPlan, "source=[t] | where a >= 1 | fields + a")
+    assertEquals(expectedPlan, logPlan)
   }
 
   test(
@@ -178,8 +169,7 @@
     val filterPlan = Filter(filterExpr, table)
     val projectList = Seq(UnresolvedAttribute("a"))
     val expectedPlan = Project(projectList, filterPlan)
-    assertEquals(expectedPlan, context.getPlan)
-    assertEquals(logPlan, "source=[t] | where a < 1 | fields + a")
+    assertEquals(expectedPlan, logPlan)
   }
 
   test(
@@ -193,8 +183,7 @@
     val filterPlan = Filter(filterExpr, table)
     val projectList = Seq(UnresolvedAttribute("a"))
     val expectedPlan = Project(projectList, filterPlan)
-    assertEquals(expectedPlan, context.getPlan)
-    assertEquals(logPlan, "source=[t] | where a <= 1 | fields + a")
+    assertEquals(expectedPlan, logPlan)
  }
 
   test(
@@ -208,8 +197,7 @@
     val filterPlan = Filter(filterExpr, table)
     val projectList = Seq(UnresolvedAttribute("a"))
     val expectedPlan = Project(projectList, filterPlan)
-    assertEquals(expectedPlan, context.getPlan)
-    assertEquals(logPlan, "source=[t] | where a != 1 | fields + a")
+    assertEquals(expectedPlan, logPlan)
   }
 
   test(
@@ -227,7 +215,6 @@
     val sortedPlan: LogicalPlan =
       Sort(Seq(SortOrder(UnresolvedAttribute("a"), Ascending)), global = true, expectedPlan)
 
-    assertEquals(compareByString(sortedPlan), compareByString(context.getPlan))
-    assertEquals(logPlan, "source=[t] | where a != 1 | fields + a | sort a | fields + *")
+    assertEquals(compareByString(sortedPlan), compareByString(logPlan))
   }
 }