From 98962c31c1b5cf8ea8156982f0688c2df5a93f2b Mon Sep 17 00:00:00 2001 From: YANGDB Date: Thu, 31 Oct 2024 11:47:50 -0700 Subject: [PATCH] Refactor `CatalystQueryPlanVisitor` into distinct Plan & Expression visitors (#852) * We would like to refactor the CatalystQueryPlanVisitor and separate it into two distinct visitors: Plan Visitor ( which extends AbstractNodeVisitor ) Expression Visitor (which extends AbstractNodeVisitor) This would match the existing PPL AST visitors composition: AstBuilder ( which extends OpenSearchPPLParserBaseVisitor) AstExpressionBuilder ( which extends OpenSearchPPLParserBaseVisitor ) In addition unify the ppl utils classes to match one of the following naming: *Transformer - transforms PPL (logical) expressions into Spark (logical) expressions *Utils - utility class Signed-off-by: YANGDB * update the AstBuilder ctor Signed-off-by: YANGDB * resolve latest merge conflicts Signed-off-by: YANGDB --------- Signed-off-by: YANGDB --- .../sql/ppl/CatalystExpressionVisitor.java | 432 ++++++++++++++++++ .../sql/ppl/CatalystQueryPlanVisitor.java | 358 +-------------- .../opensearch/sql/ppl/parser/AstBuilder.java | 4 +- .../sql/ppl/parser/AstExpressionBuilder.java | 11 +- .../sql/ppl/parser/AstStatementBuilder.java | 4 +- ...slator.java => AggregatorTransformer.java} | 4 +- ...r.java => BuiltinFunctionTransformer.java} | 2 +- .../sql/ppl/utils/DedupeTransformer.java | 14 +- .../sql/ppl/utils/LookupTransformer.java | 9 +- ...rseStrategy.java => ParseTransformer.java} | 3 +- .../flint/spark/ppl/PPLSyntaxParser.scala | 9 +- 11 files changed, 463 insertions(+), 387 deletions(-) create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java rename ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/{AggregatorTranslator.java => AggregatorTransformer.java} (97%) rename ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/{BuiltinFunctionTranslator.java => BuiltinFunctionTransformer.java} (99%) rename ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/{ParseStrategy.java => ParseTransformer.java} (97%) diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java new file mode 100644 index 000000000..571905f8a --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java @@ -0,0 +1,432 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl; + +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$; +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; +import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; +import org.apache.spark.sql.catalyst.expressions.CaseWhen; +import org.apache.spark.sql.catalyst.expressions.Exists$; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual; +import org.apache.spark.sql.catalyst.expressions.In$; +import org.apache.spark.sql.catalyst.expressions.InSubquery$; +import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual; +import org.apache.spark.sql.catalyst.expressions.ListQuery$; +import org.apache.spark.sql.catalyst.expressions.MakeInterval$; +import org.apache.spark.sql.catalyst.expressions.NamedExpression; +import org.apache.spark.sql.catalyst.expressions.Predicate; +import org.apache.spark.sql.catalyst.expressions.ScalaUDF; +import org.apache.spark.sql.catalyst.expressions.ScalarSubquery$; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.types.DataTypes; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.AggregateFunction; +import org.opensearch.sql.ast.expression.Alias; +import org.opensearch.sql.ast.expression.AllFields; +import org.opensearch.sql.ast.expression.And; +import org.opensearch.sql.ast.expression.Between; +import org.opensearch.sql.ast.expression.BinaryExpression; +import org.opensearch.sql.ast.expression.Case; +import org.opensearch.sql.ast.expression.Compare; +import org.opensearch.sql.ast.expression.FieldsMapping; +import org.opensearch.sql.ast.expression.Function; +import org.opensearch.sql.ast.expression.In; +import org.opensearch.sql.ast.expression.Interval; +import org.opensearch.sql.ast.expression.IsEmpty; +import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.ast.expression.Not; +import org.opensearch.sql.ast.expression.Or; +import org.opensearch.sql.ast.expression.QualifiedName; +import org.opensearch.sql.ast.expression.Span; +import org.opensearch.sql.ast.expression.UnresolvedExpression; +import org.opensearch.sql.ast.expression.When; +import org.opensearch.sql.ast.expression.WindowFunction; +import org.opensearch.sql.ast.expression.Xor; +import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; +import org.opensearch.sql.ast.expression.subquery.InSubquery; +import org.opensearch.sql.ast.expression.subquery.ScalarSubquery; +import org.opensearch.sql.ast.tree.Dedupe; +import org.opensearch.sql.ast.tree.Eval; +import org.opensearch.sql.ast.tree.FillNull; +import org.opensearch.sql.ast.tree.Kmeans; +import org.opensearch.sql.ast.tree.RareTopN; +import org.opensearch.sql.ast.tree.UnresolvedPlan; +import org.opensearch.sql.expression.function.SerializableUdf; +import org.opensearch.sql.ppl.utils.AggregatorTransformer; +import org.opensearch.sql.ppl.utils.BuiltinFunctionTransformer; +import org.opensearch.sql.ppl.utils.ComparatorTransformer; +import scala.Option; +import scala.Tuple2; +import scala.collection.Seq; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.Stack; +import java.util.function.BiFunction; +import java.util.stream.Collectors; + +import static java.util.Collections.emptyList; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.EQUAL; +import static org.opensearch.sql.ppl.CatalystPlanContext.findRelation; +import static org.opensearch.sql.ppl.utils.BuiltinFunctionTransformer.createIntervalArgs; +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.translate; +import static org.opensearch.sql.ppl.utils.RelationUtils.resolveField; +import static org.opensearch.sql.ppl.utils.WindowSpecTransformer.window; + +/** + * Class of building catalyst AST Expression nodes. + */ +public class CatalystExpressionVisitor extends AbstractNodeVisitor { + + private final AbstractNodeVisitor planVisitor; + + public CatalystExpressionVisitor(AbstractNodeVisitor planVisitor) { + this.planVisitor = planVisitor; + } + + public Expression analyze(UnresolvedExpression unresolved, CatalystPlanContext context) { + return unresolved.accept(this, context); + } + + @Override + public Expression visitLiteral(Literal node, CatalystPlanContext context) { + return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Literal( + translate(node.getValue(), node.getType()), translate(node.getType()))); + } + + /** + * generic binary (And, Or, Xor , ...) arithmetic expression resolver + * + * @param node + * @param transformer + * @param context + * @return + */ + public Expression visitBinaryArithmetic(BinaryExpression node, BiFunction transformer, CatalystPlanContext context) { + node.getLeft().accept(this, context); + Optional left = context.popNamedParseExpressions(); + node.getRight().accept(this, context); + Optional right = context.popNamedParseExpressions(); + if (left.isPresent() && right.isPresent()) { + return transformer.apply(left.get(), right.get()); + } else if (left.isPresent()) { + return context.getNamedParseExpressions().push(left.get()); + } else if (right.isPresent()) { + return context.getNamedParseExpressions().push(right.get()); + } + return null; + + } + + @Override + public Expression visitAnd(And node, CatalystPlanContext context) { + return visitBinaryArithmetic(node, + (left, right) -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.And(left, right)), context); + } + + @Override + public Expression visitOr(Or node, CatalystPlanContext context) { + return visitBinaryArithmetic(node, + (left, right) -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Or(left, right)), context); + } + + @Override + public Expression visitXor(Xor node, CatalystPlanContext context) { + return visitBinaryArithmetic(node, + (left, right) -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.BitwiseXor(left, right)), context); + } + + @Override + public Expression visitNot(Not node, CatalystPlanContext context) { + node.getExpression().accept(this, context); + Optional arg = context.popNamedParseExpressions(); + return arg.map(expression -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Not(expression))).orElse(null); + } + + @Override + public Expression visitSpan(Span node, CatalystPlanContext context) { + node.getField().accept(this, context); + Expression field = (Expression) context.popNamedParseExpressions().get(); + node.getValue().accept(this, context); + Expression value = (Expression) context.popNamedParseExpressions().get(); + return context.getNamedParseExpressions().push(window(field, value, node.getUnit())); + } + + @Override + public Expression visitAggregateFunction(AggregateFunction node, CatalystPlanContext context) { + node.getField().accept(this, context); + Expression arg = (Expression) context.popNamedParseExpressions().get(); + Expression aggregator = AggregatorTransformer.aggregator(node, arg); + return context.getNamedParseExpressions().push(aggregator); + } + + @Override + public Expression visitCompare(Compare node, CatalystPlanContext context) { + analyze(node.getLeft(), context); + Optional left = context.popNamedParseExpressions(); + analyze(node.getRight(), context); + Optional right = context.popNamedParseExpressions(); + if (left.isPresent() && right.isPresent()) { + Predicate comparator = ComparatorTransformer.comparator(node, left.get(), right.get()); + return context.getNamedParseExpressions().push((org.apache.spark.sql.catalyst.expressions.Expression) comparator); + } + return null; + } + + @Override + public Expression visitQualifiedName(QualifiedName node, CatalystPlanContext context) { + List relation = findRelation(context.traversalContext()); + if (!relation.isEmpty()) { + Optional resolveField = resolveField(relation, node, context.getRelations()); + return resolveField.map(qualifiedName -> context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(qualifiedName.getParts())))) + .orElse(resolveQualifiedNameWithSubqueryAlias(node, context)); + } + return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts()))); + } + + /** + * Resolve the qualified name with subquery alias:
+ * - subqueryAlias1.joinKey = subqueryAlias2.joinKey
+ * - tableName1.joinKey = subqueryAlias2.joinKey
+ * - subqueryAlias1.joinKey = tableName2.joinKey
+ */ + private Expression resolveQualifiedNameWithSubqueryAlias(QualifiedName node, CatalystPlanContext context) { + if (node.getPrefix().isPresent() && + context.traversalContext().peek() instanceof org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias) { + if (context.getSubqueryAlias().stream().map(p -> (org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias) p) + .anyMatch(a -> a.alias().equalsIgnoreCase(node.getPrefix().get().toString()))) { + return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts()))); + } else if (context.getRelations().stream().map(p -> (UnresolvedRelation) p) + .anyMatch(a -> a.tableName().equalsIgnoreCase(node.getPrefix().get().toString()))) { + return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts()))); + } + } + return null; + } + + @Override + public Expression visitCorrelationMapping(FieldsMapping node, CatalystPlanContext context) { + return node.getChild().stream().map(expression -> + visitCompare((Compare) expression, context) + ).reduce(org.apache.spark.sql.catalyst.expressions.And::new).orElse(null); + } + + @Override + public Expression visitAllFields(AllFields node, CatalystPlanContext context) { + context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); + return context.getNamedParseExpressions().peek(); + } + + @Override + public Expression visitAlias(Alias node, CatalystPlanContext context) { + node.getDelegated().accept(this, context); + Expression arg = context.popNamedParseExpressions().get(); + return context.getNamedParseExpressions().push( + org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(arg, + node.getAlias() != null ? node.getAlias() : node.getName(), + NamedExpression.newExprId(), + seq(new java.util.ArrayList()), + Option.empty(), + seq(new java.util.ArrayList()))); + } + + @Override + public Expression visitEval(Eval node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : Eval"); + } + + @Override + public Expression visitFunction(Function node, CatalystPlanContext context) { + List arguments = + node.getFuncArgs().stream() + .map( + unresolvedExpression -> { + var ret = analyze(unresolvedExpression, context); + if (ret == null) { + throw new UnsupportedOperationException( + String.format("Invalid use of expression %s", unresolvedExpression)); + } else { + return context.popNamedParseExpressions().get(); + } + }) + .collect(Collectors.toList()); + Expression function = BuiltinFunctionTransformer.builtinFunction(node, arguments); + return context.getNamedParseExpressions().push(function); + } + + @Override + public Expression visitIsEmpty(IsEmpty node, CatalystPlanContext context) { + Stack namedParseExpressions = new Stack<>(); + namedParseExpressions.addAll(context.getNamedParseExpressions()); + Expression expression = visitCase(node.getCaseValue(), context); + namedParseExpressions.add(expression); + context.setNamedParseExpressions(namedParseExpressions); + return expression; + } + + @Override + public Expression visitFillNull(FillNull fillNull, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : FillNull"); + } + + @Override + public Expression visitInterval(Interval node, CatalystPlanContext context) { + node.getValue().accept(this, context); + Expression value = context.getNamedParseExpressions().pop(); + Expression[] intervalArgs = createIntervalArgs(node.getUnit(), value); + Expression interval = MakeInterval$.MODULE$.apply( + intervalArgs[0], intervalArgs[1], intervalArgs[2], intervalArgs[3], + intervalArgs[4], intervalArgs[5], intervalArgs[6], true); + return context.getNamedParseExpressions().push(interval); + } + + @Override + public Expression visitDedupe(Dedupe node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : Dedupe"); + } + + @Override + public Expression visitIn(In node, CatalystPlanContext context) { + node.getField().accept(this, context); + Expression value = context.popNamedParseExpressions().get(); + List list = node.getValueList().stream().map( expression -> { + expression.accept(this, context); + return context.popNamedParseExpressions().get(); + }).collect(Collectors.toList()); + return context.getNamedParseExpressions().push(In$.MODULE$.apply(value, seq(list))); + } + + @Override + public Expression visitKmeans(Kmeans node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : Kmeans"); + } + + @Override + public Expression visitCase(Case node, CatalystPlanContext context) { + Stack initialNameExpressions = new Stack<>(); + initialNameExpressions.addAll(context.getNamedParseExpressions()); + analyze(node.getElseClause(), context); + Expression elseValue = context.getNamedParseExpressions().pop(); + List> whens = new ArrayList<>(); + for (When when : node.getWhenClauses()) { + if (node.getCaseValue() == null) { + whens.add( + new Tuple2<>( + analyze(when.getCondition(), context), + analyze(when.getResult(), context) + ) + ); + } else { + // Merge case value and condition (compare value) into a single equal condition + Compare compare = new Compare(EQUAL.getName().getFunctionName(), node.getCaseValue(), when.getCondition()); + whens.add( + new Tuple2<>( + analyze(compare, context), analyze(when.getResult(), context) + ) + ); + } + context.retainAllNamedParseExpressions(e -> e); + } + context.setNamedParseExpressions(initialNameExpressions); + return context.getNamedParseExpressions().push(new CaseWhen(seq(whens), Option.apply(elseValue))); + } + + @Override + public Expression visitRareTopN(RareTopN node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : RareTopN"); + } + + @Override + public Expression visitWindowFunction(WindowFunction node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : WindowFunction"); + } + + @Override + public Expression visitInSubquery(InSubquery node, CatalystPlanContext outerContext) { + CatalystPlanContext innerContext = new CatalystPlanContext(); + visitExpressionList(node.getChild(), innerContext); + Seq values = innerContext.retainAllNamedParseExpressions(p -> p); + UnresolvedPlan outerPlan = node.getQuery(); + LogicalPlan subSearch = outerPlan.accept(planVisitor, innerContext); + Expression inSubQuery = InSubquery$.MODULE$.apply( + values, + ListQuery$.MODULE$.apply( + subSearch, + seq(new java.util.ArrayList()), + NamedExpression.newExprId(), + -1, + seq(new java.util.ArrayList()), + Option.empty())); + return outerContext.getNamedParseExpressions().push(inSubQuery); + } + + @Override + public Expression visitScalarSubquery(ScalarSubquery node, CatalystPlanContext context) { + CatalystPlanContext innerContext = new CatalystPlanContext(); + UnresolvedPlan outerPlan = node.getQuery(); + LogicalPlan subSearch = outerPlan.accept(planVisitor, innerContext); + Expression scalarSubQuery = ScalarSubquery$.MODULE$.apply( + subSearch, + seq(new java.util.ArrayList()), + NamedExpression.newExprId(), + seq(new java.util.ArrayList()), + Option.empty(), + Option.empty()); + return context.getNamedParseExpressions().push(scalarSubQuery); + } + + @Override + public Expression visitExistsSubquery(ExistsSubquery node, CatalystPlanContext context) { + CatalystPlanContext innerContext = new CatalystPlanContext(); + UnresolvedPlan outerPlan = node.getQuery(); + LogicalPlan subSearch = outerPlan.accept(planVisitor, innerContext); + Expression existsSubQuery = Exists$.MODULE$.apply( + subSearch, + seq(new java.util.ArrayList()), + NamedExpression.newExprId(), + seq(new java.util.ArrayList()), + Option.empty()); + return context.getNamedParseExpressions().push(existsSubQuery); + } + + @Override + public Expression visitBetween(Between node, CatalystPlanContext context) { + Expression value = analyze(node.getValue(), context); + Expression lower = analyze(node.getLowerBound(), context); + Expression upper = analyze(node.getUpperBound(), context); + context.retainAllNamedParseExpressions(p -> p); + return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.And(new GreaterThanOrEqual(value, lower), new LessThanOrEqual(value, upper))); + } + + @Override + public Expression visitCidr(org.opensearch.sql.ast.expression.Cidr node, CatalystPlanContext context) { + analyze(node.getIpAddress(), context); + Expression ipAddressExpression = context.getNamedParseExpressions().pop(); + analyze(node.getCidrBlock(), context); + Expression cidrBlockExpression = context.getNamedParseExpressions().pop(); + + ScalaUDF udf = new ScalaUDF(SerializableUdf.cidrFunction, + DataTypes.BooleanType, + seq(ipAddressExpression,cidrBlockExpression), + seq(), + Option.empty(), + Option.apply("cidr"), + false, + true); + + return context.getNamedParseExpressions().push(udf); + } + + private List visitExpressionList(List expressionList, CatalystPlanContext context) { + return expressionList.isEmpty() + ? emptyList() + : expressionList.stream().map(field -> analyze(field, context)) + .collect(Collectors.toList()); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 90df01e66..5d2fe986b 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -91,11 +91,11 @@ import org.opensearch.sql.ast.tree.Window; import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.expression.function.SerializableUdf; -import org.opensearch.sql.ppl.utils.AggregatorTranslator; -import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator; +import org.opensearch.sql.ppl.utils.AggregatorTransformer; +import org.opensearch.sql.ppl.utils.BuiltinFunctionTransformer; import org.opensearch.sql.ppl.utils.ComparatorTransformer; import org.opensearch.sql.ppl.utils.FieldSummaryTransformer; -import org.opensearch.sql.ppl.utils.ParseStrategy; +import org.opensearch.sql.ppl.utils.ParseTransformer; import org.opensearch.sql.ppl.utils.SortUtils; import org.opensearch.sql.ppl.utils.WindowSpecTransformer; import scala.Option; @@ -115,7 +115,6 @@ import static java.util.List.of; import static org.opensearch.sql.expression.function.BuiltinFunctionName.EQUAL; import static org.opensearch.sql.ppl.CatalystPlanContext.findRelation; -import static org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator.createIntervalArgs; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.translate; import static org.opensearch.sql.ppl.utils.DedupeTransformer.retainMultipleDuplicateEvents; @@ -138,20 +137,16 @@ */ public class CatalystQueryPlanVisitor extends AbstractNodeVisitor { - private final ExpressionAnalyzer expressionAnalyzer; + private final CatalystExpressionVisitor expressionAnalyzer; public CatalystQueryPlanVisitor() { - this.expressionAnalyzer = new ExpressionAnalyzer(); + this.expressionAnalyzer = new CatalystExpressionVisitor(this); } public LogicalPlan visit(Statement plan, CatalystPlanContext context) { return plan.accept(this, context); } - - public LogicalPlan visitSubSearch(UnresolvedPlan plan, CatalystPlanContext context) { - return plan.accept(this, context); - } - + /** * Handle Query Statement. */ @@ -480,7 +475,7 @@ public LogicalPlan visitParse(Parse node, CatalystPlanContext context) { ParseMethod parseMethod = node.getParseMethod(); java.util.Map arguments = node.getArguments(); String pattern = (String) node.getPattern().getValue(); - return ParseStrategy.visitParseCommand(node, sourceField, parseMethod, arguments, pattern, context); + return ParseTransformer.visitParseCommand(node, sourceField, parseMethod, arguments, pattern, context); } @Override @@ -574,343 +569,4 @@ public LogicalPlan visitDedupe(Dedupe node, CatalystPlanContext context) { } } } - - /** - * Expression Analyzer. - */ - public class ExpressionAnalyzer extends AbstractNodeVisitor { - - public Expression analyze(UnresolvedExpression unresolved, CatalystPlanContext context) { - return unresolved.accept(this, context); - } - - @Override - public Expression visitLiteral(Literal node, CatalystPlanContext context) { - return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Literal( - translate(node.getValue(), node.getType()), translate(node.getType()))); - } - - /** - * generic binary (And, Or, Xor , ...) arithmetic expression resolver - * - * @param node - * @param transformer - * @param context - * @return - */ - public Expression visitBinaryArithmetic(BinaryExpression node, BiFunction transformer, CatalystPlanContext context) { - node.getLeft().accept(this, context); - Optional left = context.popNamedParseExpressions(); - node.getRight().accept(this, context); - Optional right = context.popNamedParseExpressions(); - if (left.isPresent() && right.isPresent()) { - return transformer.apply(left.get(), right.get()); - } else if (left.isPresent()) { - return context.getNamedParseExpressions().push(left.get()); - } else if (right.isPresent()) { - return context.getNamedParseExpressions().push(right.get()); - } - return null; - - } - - @Override - public Expression visitAnd(And node, CatalystPlanContext context) { - return visitBinaryArithmetic(node, - (left, right) -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.And(left, right)), context); - } - - @Override - public Expression visitOr(Or node, CatalystPlanContext context) { - return visitBinaryArithmetic(node, - (left, right) -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Or(left, right)), context); - } - - @Override - public Expression visitXor(Xor node, CatalystPlanContext context) { - return visitBinaryArithmetic(node, - (left, right) -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.BitwiseXor(left, right)), context); - } - - @Override - public Expression visitNot(Not node, CatalystPlanContext context) { - node.getExpression().accept(this, context); - Optional arg = context.popNamedParseExpressions(); - return arg.map(expression -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Not(expression))).orElse(null); - } - - @Override - public Expression visitSpan(Span node, CatalystPlanContext context) { - node.getField().accept(this, context); - Expression field = (Expression) context.popNamedParseExpressions().get(); - node.getValue().accept(this, context); - Expression value = (Expression) context.popNamedParseExpressions().get(); - return context.getNamedParseExpressions().push(window(field, value, node.getUnit())); - } - - @Override - public Expression visitAggregateFunction(AggregateFunction node, CatalystPlanContext context) { - node.getField().accept(this, context); - Expression arg = (Expression) context.popNamedParseExpressions().get(); - Expression aggregator = AggregatorTranslator.aggregator(node, arg); - return context.getNamedParseExpressions().push(aggregator); - } - - @Override - public Expression visitCompare(Compare node, CatalystPlanContext context) { - analyze(node.getLeft(), context); - Optional left = context.popNamedParseExpressions(); - analyze(node.getRight(), context); - Optional right = context.popNamedParseExpressions(); - if (left.isPresent() && right.isPresent()) { - Predicate comparator = ComparatorTransformer.comparator(node, left.get(), right.get()); - return context.getNamedParseExpressions().push((org.apache.spark.sql.catalyst.expressions.Expression) comparator); - } - return null; - } - - @Override - public Expression visitQualifiedName(QualifiedName node, CatalystPlanContext context) { - List relation = findRelation(context.traversalContext()); - if (!relation.isEmpty()) { - Optional resolveField = resolveField(relation, node, context.getRelations()); - return resolveField.map(qualifiedName -> context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(qualifiedName.getParts())))) - .orElse(resolveQualifiedNameWithSubqueryAlias(node, context)); - } - return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts()))); - } - - /** - * Resolve the qualified name with subquery alias:
- * - subqueryAlias1.joinKey = subqueryAlias2.joinKey
- * - tableName1.joinKey = subqueryAlias2.joinKey
- * - subqueryAlias1.joinKey = tableName2.joinKey
- */ - private Expression resolveQualifiedNameWithSubqueryAlias(QualifiedName node, CatalystPlanContext context) { - if (node.getPrefix().isPresent() && - context.traversalContext().peek() instanceof org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias) { - if (context.getSubqueryAlias().stream().map(p -> (org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias) p) - .anyMatch(a -> a.alias().equalsIgnoreCase(node.getPrefix().get().toString()))) { - return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts()))); - } else if (context.getRelations().stream().map(p -> (UnresolvedRelation) p) - .anyMatch(a -> a.tableName().equalsIgnoreCase(node.getPrefix().get().toString()))) { - return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts()))); - } - } - return null; - } - - @Override - public Expression visitCorrelationMapping(FieldsMapping node, CatalystPlanContext context) { - return node.getChild().stream().map(expression -> - visitCompare((Compare) expression, context) - ).reduce(org.apache.spark.sql.catalyst.expressions.And::new).orElse(null); - } - - @Override - public Expression visitAllFields(AllFields node, CatalystPlanContext context) { - context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); - return context.getNamedParseExpressions().peek(); - } - - @Override - public Expression visitAlias(Alias node, CatalystPlanContext context) { - node.getDelegated().accept(this, context); - Expression arg = context.popNamedParseExpressions().get(); - return context.getNamedParseExpressions().push( - org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(arg, - node.getAlias() != null ? node.getAlias() : node.getName(), - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty(), - seq(new java.util.ArrayList()))); - } - - @Override - public Expression visitEval(Eval node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : Eval"); - } - - @Override - public Expression visitFunction(Function node, CatalystPlanContext context) { - List arguments = - node.getFuncArgs().stream() - .map( - unresolvedExpression -> { - var ret = analyze(unresolvedExpression, context); - if (ret == null) { - throw new UnsupportedOperationException( - String.format("Invalid use of expression %s", unresolvedExpression)); - } else { - return context.popNamedParseExpressions().get(); - } - }) - .collect(Collectors.toList()); - Expression function = BuiltinFunctionTranslator.builtinFunction(node, arguments); - return context.getNamedParseExpressions().push(function); - } - - @Override - public Expression visitIsEmpty(IsEmpty node, CatalystPlanContext context) { - Stack namedParseExpressions = new Stack<>(); - namedParseExpressions.addAll(context.getNamedParseExpressions()); - Expression expression = visitCase(node.getCaseValue(), context); - namedParseExpressions.add(expression); - context.setNamedParseExpressions(namedParseExpressions); - return expression; - } - - @Override - public Expression visitFillNull(FillNull fillNull, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : FillNull"); - } - - @Override - public Expression visitInterval(Interval node, CatalystPlanContext context) { - node.getValue().accept(this, context); - Expression value = context.getNamedParseExpressions().pop(); - Expression[] intervalArgs = createIntervalArgs(node.getUnit(), value); - Expression interval = MakeInterval$.MODULE$.apply( - intervalArgs[0], intervalArgs[1], intervalArgs[2], intervalArgs[3], - intervalArgs[4], intervalArgs[5], intervalArgs[6], true); - return context.getNamedParseExpressions().push(interval); - } - - @Override - public Expression visitDedupe(Dedupe node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : Dedupe"); - } - - @Override - public Expression visitIn(In node, CatalystPlanContext context) { - node.getField().accept(this, context); - Expression value = context.popNamedParseExpressions().get(); - List list = node.getValueList().stream().map( expression -> { - expression.accept(this, context); - return context.popNamedParseExpressions().get(); - }).collect(Collectors.toList()); - return context.getNamedParseExpressions().push(In$.MODULE$.apply(value, seq(list))); - } - - @Override - public Expression visitKmeans(Kmeans node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : Kmeans"); - } - - @Override - public Expression visitCase(Case node, CatalystPlanContext context) { - Stack initialNameExpressions = new Stack<>(); - initialNameExpressions.addAll(context.getNamedParseExpressions()); - analyze(node.getElseClause(), context); - Expression elseValue = context.getNamedParseExpressions().pop(); - List> whens = new ArrayList<>(); - for (When when : node.getWhenClauses()) { - if (node.getCaseValue() == null) { - whens.add( - new Tuple2<>( - analyze(when.getCondition(), context), - analyze(when.getResult(), context) - ) - ); - } else { - // Merge case value and condition (compare value) into a single equal condition - Compare compare = new Compare(EQUAL.getName().getFunctionName(), node.getCaseValue(), when.getCondition()); - whens.add( - new Tuple2<>( - analyze(compare, context), analyze(when.getResult(), context) - ) - ); - } - context.retainAllNamedParseExpressions(e -> e); - } - context.setNamedParseExpressions(initialNameExpressions); - return context.getNamedParseExpressions().push(new CaseWhen(seq(whens), Option.apply(elseValue))); - } - - @Override - public Expression visitRareTopN(RareTopN node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : RareTopN"); - } - - @Override - public Expression visitWindowFunction(WindowFunction node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : WindowFunction"); - } - - @Override - public Expression visitInSubquery(InSubquery node, CatalystPlanContext outerContext) { - CatalystPlanContext innerContext = new CatalystPlanContext(); - visitExpressionList(node.getChild(), innerContext); - Seq values = innerContext.retainAllNamedParseExpressions(p -> p); - UnresolvedPlan outerPlan = node.getQuery(); - LogicalPlan subSearch = CatalystQueryPlanVisitor.this.visitSubSearch(outerPlan, innerContext); - Expression inSubQuery = InSubquery$.MODULE$.apply( - values, - ListQuery$.MODULE$.apply( - subSearch, - seq(new java.util.ArrayList()), - NamedExpression.newExprId(), - -1, - seq(new java.util.ArrayList()), - Option.empty())); - return outerContext.getNamedParseExpressions().push(inSubQuery); - } - - @Override - public Expression visitScalarSubquery(ScalarSubquery node, CatalystPlanContext context) { - CatalystPlanContext innerContext = new CatalystPlanContext(); - UnresolvedPlan outerPlan = node.getQuery(); - LogicalPlan subSearch = CatalystQueryPlanVisitor.this.visitSubSearch(outerPlan, innerContext); - Expression scalarSubQuery = ScalarSubquery$.MODULE$.apply( - subSearch, - seq(new java.util.ArrayList()), - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty(), - Option.empty()); - return context.getNamedParseExpressions().push(scalarSubQuery); - } - - @Override - public Expression visitExistsSubquery(ExistsSubquery node, CatalystPlanContext context) { - CatalystPlanContext innerContext = new CatalystPlanContext(); - UnresolvedPlan outerPlan = node.getQuery(); - LogicalPlan subSearch = CatalystQueryPlanVisitor.this.visitSubSearch(outerPlan, innerContext); - Expression existsSubQuery = Exists$.MODULE$.apply( - subSearch, - seq(new java.util.ArrayList()), - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty()); - return context.getNamedParseExpressions().push(existsSubQuery); - } - - @Override - public Expression visitBetween(Between node, CatalystPlanContext context) { - Expression value = analyze(node.getValue(), context); - Expression lower = analyze(node.getLowerBound(), context); - Expression upper = analyze(node.getUpperBound(), context); - context.retainAllNamedParseExpressions(p -> p); - return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.And(new GreaterThanOrEqual(value, lower), new LessThanOrEqual(value, upper))); - } - - @Override - public Expression visitCidr(org.opensearch.sql.ast.expression.Cidr node, CatalystPlanContext context) { - analyze(node.getIpAddress(), context); - Expression ipAddressExpression = context.getNamedParseExpressions().pop(); - analyze(node.getCidrBlock(), context); - Expression cidrBlockExpression = context.getNamedParseExpressions().pop(); - - ScalaUDF udf = new ScalaUDF(SerializableUdf.cidrFunction, - DataTypes.BooleanType, - seq(ipAddressExpression,cidrBlockExpression), - seq(), - Option.empty(), - Option.apply("cidr"), - false, - true); - - return context.getNamedParseExpressions().push(udf); - } - } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index ed7717188..c69e9541e 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -82,8 +82,8 @@ public class AstBuilder extends OpenSearchPPLParserBaseVisitor { */ private String query; - public AstBuilder(AstExpressionBuilder expressionBuilder, String query) { - this.expressionBuilder = expressionBuilder; + public AstBuilder(String query) { + this.expressionBuilder = new AstExpressionBuilder(this); this.query = query; } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 6eb72c91e..b6dfd0447 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -67,7 +67,6 @@ */ public class AstExpressionBuilder extends OpenSearchPPLParserBaseVisitor { - private static final int DEFAULT_TAKE_FUNCTION_SIZE_VALUE = 10; /** * The function name mapping between fronted and core engine. */ @@ -79,16 +78,10 @@ public class AstExpressionBuilder extends OpenSearchPPLParserBaseVisitor dedupeFields, - ExpressionAnalyzer expressionAnalyzer, + CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { context.apply(p -> { Expression isNullExpr = buildIsNullFilterExpression(node, expressionAnalyzer, context); @@ -63,7 +63,7 @@ static LogicalPlan retainOneDuplicateEventAndKeepEmpty( static LogicalPlan retainOneDuplicateEvent( Dedupe node, Seq dedupeFields, - ExpressionAnalyzer expressionAnalyzer, + CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { Expression isNotNullExpr = buildIsNotNullFilterExpression(node, expressionAnalyzer, context); context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Filter(isNotNullExpr, p)); @@ -87,7 +87,7 @@ static LogicalPlan retainOneDuplicateEvent( static LogicalPlan retainMultipleDuplicateEventsAndKeepEmpty( Dedupe node, Integer allowedDuplication, - ExpressionAnalyzer expressionAnalyzer, + CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { context.apply(p -> { // Build isnull Filter for right @@ -137,7 +137,7 @@ static LogicalPlan retainMultipleDuplicateEventsAndKeepEmpty( static LogicalPlan retainMultipleDuplicateEvents( Dedupe node, Integer allowedDuplication, - ExpressionAnalyzer expressionAnalyzer, + CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { // Build isnotnull Filter Expression isNotNullExpr = buildIsNotNullFilterExpression(node, expressionAnalyzer, context); @@ -163,7 +163,7 @@ static LogicalPlan retainMultipleDuplicateEvents( return context.apply(p -> new DataFrameDropColumns(seq(rowNumber.toAttribute()), p)); } - private static Expression buildIsNotNullFilterExpression(Dedupe node, ExpressionAnalyzer expressionAnalyzer, CatalystPlanContext context) { + private static Expression buildIsNotNullFilterExpression(Dedupe node, CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { node.getFields().forEach(field -> expressionAnalyzer.analyze(field, context)); Seq isNotNullExpressions = context.retainAllNamedParseExpressions( @@ -180,7 +180,7 @@ private static Expression buildIsNotNullFilterExpression(Dedupe node, Expression return isNotNullExpr; } - private static Expression buildIsNullFilterExpression(Dedupe node, ExpressionAnalyzer expressionAnalyzer, CatalystPlanContext context) { + private static Expression buildIsNullFilterExpression(Dedupe node, CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { node.getFields().forEach(field -> expressionAnalyzer.analyze(field, context)); Seq isNullExpressions = context.retainAllNamedParseExpressions( diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/LookupTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/LookupTransformer.java index 58ef15ea9..3673d96d6 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/LookupTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/LookupTransformer.java @@ -15,6 +15,7 @@ import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.tree.Lookup; +import org.opensearch.sql.ppl.CatalystExpressionVisitor; import org.opensearch.sql.ppl.CatalystPlanContext; import org.opensearch.sql.ppl.CatalystQueryPlanVisitor; import scala.Option; @@ -32,7 +33,7 @@ public interface LookupTransformer { /** lookup mapping fields + input fields*/ static List buildLookupRelationProjectList( Lookup node, - CatalystQueryPlanVisitor.ExpressionAnalyzer expressionAnalyzer, + CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { List inputFields = new ArrayList<>(node.getInputFieldList()); if (inputFields.isEmpty()) { @@ -45,7 +46,7 @@ static List buildLookupRelationProjectList( static List buildProjectListFromFields( List fields, - CatalystQueryPlanVisitor.ExpressionAnalyzer expressionAnalyzer, + CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { return fields.stream().map(field -> expressionAnalyzer.visitField(field, context)) .map(NamedExpression.class::cast) @@ -54,7 +55,7 @@ static List buildProjectListFromFields( static Expression buildLookupMappingCondition( Lookup node, - CatalystQueryPlanVisitor.ExpressionAnalyzer expressionAnalyzer, + CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { // only equi-join conditions are accepted in lookup command List equiConditions = new ArrayList<>(); @@ -81,7 +82,7 @@ static Expression buildLookupMappingCondition( static List buildOutputProjectList( Lookup node, Lookup.OutputStrategy strategy, - CatalystQueryPlanVisitor.ExpressionAnalyzer expressionAnalyzer, + CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { List outputProjectList = new ArrayList<>(); for (Map.Entry entry : node.getOutputCandidateMap().entrySet()) { diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseStrategy.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseTransformer.java similarity index 97% rename from ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseStrategy.java rename to ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseTransformer.java index 8775d077b..eed7db228 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseStrategy.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseTransformer.java @@ -8,7 +8,6 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.expressions.NamedExpression; -import org.apache.spark.sql.catalyst.expressions.RegExpExtract; import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; import org.opensearch.sql.ast.expression.AllFields; import org.opensearch.sql.ast.expression.Field; @@ -27,7 +26,7 @@ import static org.apache.spark.sql.types.DataTypes.StringType; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; -public interface ParseStrategy { +public interface ParseTransformer { /** * transform the parse/grok/patterns command into a standard catalyst RegExpExtract expression * Since spark's RegExpExtract cant accept actual regExp group name we need to translate the group's name into its corresponding index diff --git a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLSyntaxParser.scala b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLSyntaxParser.scala index c435af53d..ed498e98b 100644 --- a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLSyntaxParser.scala +++ b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLSyntaxParser.scala @@ -29,11 +29,8 @@ class PPLSyntaxParser extends Parser { object PlaneUtils { def plan(parser: PPLSyntaxParser, query: String): Statement = { - val astExpressionBuilder = new AstExpressionBuilder() - val astBuilder = new AstBuilder(astExpressionBuilder, query) - astExpressionBuilder.setAstBuilder(astBuilder) - val builder = - new AstStatementBuilder(astBuilder, AstStatementBuilder.StatementBuilderContext.builder()) - builder.visit(parser.parse(query)) + new AstStatementBuilder( + new AstBuilder(query), + AstStatementBuilder.StatementBuilderContext.builder()).visit(parser.parse(query)) } }