From e56874657c1ec5fa924c3fc01df689f369a9a2e4 Mon Sep 17 00:00:00 2001 From: Heng Qian Date: Mon, 4 Nov 2024 17:02:46 +0800 Subject: [PATCH] Support parsing lambda function Signed-off-by: Heng Qian --- .../src/main/antlr4/OpenSearchPPLLexer.g4 | 1 + .../src/main/antlr4/OpenSearchPPLParser.g4 | 2 + .../sql/ast/AbstractNodeVisitor.java | 5 ++ .../sql/ast/expression/LambdaFunction.java | 48 +++++++++++++++++++ .../sql/ppl/CatalystExpressionVisitor.java | 34 +++++++++---- .../sql/ppl/parser/AstExpressionBuilder.java | 12 ++++- .../sql/ppl/utils/JavaToScalaTransformer.java | 28 +++++++++++ ...lPlanBasicQueriesTranslatorTestSuite.scala | 20 +++++++- 8 files changed, 138 insertions(+), 12 deletions(-) create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/LambdaFunction.java create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/JavaToScalaTransformer.java diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index 991a4dffe..80cfb9287 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -203,6 +203,7 @@ RT_SQR_PRTHS: ']'; SINGLE_QUOTE: '\''; DOUBLE_QUOTE: '"'; BACKTICK: '`'; +ARROW: '->'; // Operators. 
Bit diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index cd6fe5dc1..59305aec1 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -443,6 +443,8 @@ valueExpression | timestampFunction # timestampFunctionCall | LT_PRTHS valueExpression RT_PRTHS # parentheticValueExpr | LT_SQR_PRTHS subSearch RT_SQR_PRTHS # scalarSubqueryExpr + | ident ARROW expression # lambda + | LT_PRTHS ident (COMMA ident)+ RT_PRTHS ARROW expression # lambda ; primaryExpression diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index 525a0954c..189d9084a 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -18,6 +18,7 @@ import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.FieldList; +import org.opensearch.sql.ast.expression.LambdaFunction; import org.opensearch.sql.ast.tree.FieldSummary; import org.opensearch.sql.ast.expression.FieldsMapping; import org.opensearch.sql.ast.expression.Function; @@ -183,6 +184,10 @@ public T visitFunction(Function node, C context) { return visitChildren(node, context); } + public T visitLambdaFunction(LambdaFunction node, C context) { + return visitChildren(node, context); + } + public T visitIsEmpty(IsEmpty node, C context) { return visitChildren(node, context); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/LambdaFunction.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/LambdaFunction.java new file mode 100644 index 000000000..e1ee755b8 --- /dev/null +++ 
b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/LambdaFunction.java @@ -0,0 +1,48 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.ast.AbstractNodeVisitor; + +/** + * Expression node of a lambda function. Holds the lambda body (@function) and the lambda + * parameters (@funcArgs); toString() renders it as "(args) -> body". + */ +@Getter +@EqualsAndHashCode(callSuper = false) +@RequiredArgsConstructor +public class LambdaFunction extends UnresolvedExpression { + private final UnresolvedExpression function; + private final List funcArgs; + + @Override + public List getChild() { + List children = new ArrayList<>(); + children.add(function); + children.addAll(funcArgs); + return children; + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitLambdaFunction(this, context); + } + + @Override + public String toString() { + return String.format( + "(%s) -> %s", + funcArgs.stream().map(Object::toString).collect(Collectors.joining(", ")), + function.toString() + ); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java index 69a89b83a..43b7ec9a6 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java @@ -5,28 +5,25 @@ package org.opensearch.sql.ppl; +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute; import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$; import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; import 
org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; import org.apache.spark.sql.catalyst.expressions.CaseWhen; -import org.apache.spark.sql.catalyst.expressions.CurrentRow$; import org.apache.spark.sql.catalyst.expressions.Exists$; import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual; import org.apache.spark.sql.catalyst.expressions.In$; import org.apache.spark.sql.catalyst.expressions.InSubquery$; -import org.apache.spark.sql.catalyst.expressions.LessThan; import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual; import org.apache.spark.sql.catalyst.expressions.ListQuery$; import org.apache.spark.sql.catalyst.expressions.MakeInterval$; import org.apache.spark.sql.catalyst.expressions.NamedExpression; import org.apache.spark.sql.catalyst.expressions.Predicate; -import org.apache.spark.sql.catalyst.expressions.RowFrame$; import org.apache.spark.sql.catalyst.expressions.ScalaUDF; import org.apache.spark.sql.catalyst.expressions.ScalarSubquery$; -import org.apache.spark.sql.catalyst.expressions.SpecifiedWindowFrame; -import org.apache.spark.sql.catalyst.expressions.WindowExpression; -import org.apache.spark.sql.catalyst.expressions.WindowSpecDefinition; +import org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable; +import org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable$; import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; import org.apache.spark.sql.types.DataTypes; import org.opensearch.sql.ast.AbstractNodeVisitor; @@ -38,7 +35,6 @@ import org.opensearch.sql.ast.expression.BinaryExpression; import org.opensearch.sql.ast.expression.Case; import org.opensearch.sql.ast.expression.Compare; -import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.FieldsMapping; import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.In; @@ -47,6 +43,7 @@ import 
org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.Not; import org.opensearch.sql.ast.expression.Or; +import org.opensearch.sql.ast.expression.LambdaFunction; import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.expression.Span; import org.opensearch.sql.ast.expression.UnresolvedExpression; @@ -61,14 +58,14 @@ import org.opensearch.sql.ast.tree.FillNull; import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.RareTopN; -import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.ast.tree.UnresolvedPlan; -import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.SerializableUdf; import org.opensearch.sql.ppl.utils.AggregatorTransformer; import org.opensearch.sql.ppl.utils.BuiltinFunctionTransformer; import org.opensearch.sql.ppl.utils.ComparatorTransformer; +import org.opensearch.sql.ppl.utils.JavaToScalaTransformer; import scala.Option; +import scala.PartialFunction; import scala.Tuple2; import scala.collection.Seq; @@ -432,6 +429,25 @@ public Expression visitCidr(org.opensearch.sql.ast.expression.Cidr node, Catalys return context.getNamedParseExpressions().push(udf); } + @Override + public Expression visitLambdaFunction(LambdaFunction node, CatalystPlanContext context) { + PartialFunction transformer = JavaToScalaTransformer.toPartialFunction( + expr -> expr instanceof UnresolvedAttribute, + expr -> { + UnresolvedAttribute attr = (UnresolvedAttribute) expr; + return new UnresolvedNamedLambdaVariable(attr.nameParts()); + } + ); + Expression functionResult = node.getFunction().accept(this, context).transformUp(transformer); + context.popNamedParseExpressions(); + List argsResult = node.getFuncArgs().stream() + .map(arg -> UnresolvedNamedLambdaVariable$.MODULE$.apply(seq(arg.getParts()))) + .collect(Collectors.toList()); + org.apache.spark.sql.catalyst.expressions.LambdaFunction lambdaFunction = new 
org.apache.spark.sql.catalyst.expressions.LambdaFunction(functionResult, seq(argsResult), false); + context.getNamedParseExpressions().push(lambdaFunction); + return lambdaFunction; + } + private List visitExpressionList(List expressionList, CatalystPlanContext context) { return expressionList.isEmpty() ? emptyList() diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 5e0f0775d..8f3e10338 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -29,6 +29,7 @@ import org.opensearch.sql.ast.expression.Interval; import org.opensearch.sql.ast.expression.IntervalUnit; import org.opensearch.sql.ast.expression.IsEmpty; +import org.opensearch.sql.ast.expression.LambdaFunction; import org.opensearch.sql.ast.expression.Let; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.Not; @@ -43,8 +44,6 @@ import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; import org.opensearch.sql.ast.expression.subquery.InSubquery; import org.opensearch.sql.ast.expression.subquery.ScalarSubquery; -import org.opensearch.sql.ast.tree.Trendline; -import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.ppl.utils.ArgumentFactory; @@ -429,6 +428,15 @@ public UnresolvedExpression visitTimestampFunctionCall( ctx.timestampFunction().timestampFunctionName().getText(), timestampFunctionArguments(ctx)); } + @Override + public UnresolvedExpression visitLambda(OpenSearchPPLParser.LambdaContext ctx) { + + List arguments = ctx.ident().stream().map(x -> this.visitIdentifiers(Collections.singletonList(x))).collect( + Collectors.toList()); + UnresolvedExpression function = 
visitExpression(ctx.expression()); + return new LambdaFunction(function, arguments); + } + private List timestampFunctionArguments( OpenSearchPPLParser.TimestampFunctionCallContext ctx) { List args = diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/JavaToScalaTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/JavaToScalaTransformer.java new file mode 100644 index 000000000..34e8b8460 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/JavaToScalaTransformer.java @@ -0,0 +1,28 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.utils; + +import scala.PartialFunction; +import scala.runtime.AbstractPartialFunction; + +public interface JavaToScalaTransformer { + static PartialFunction toPartialFunction( + java.util.function.Predicate isDefinedAt, + java.util.function.Function apply) { + return new AbstractPartialFunction() { + @Override + public boolean isDefinedAt(T t) { + return isDefinedAt.test(t); + } + + @Override + public T apply(T t) { + if (isDefinedAt.test(t)) return apply.apply(t); + else return t; + } + }; + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala index 2da93d5d8..f64403109 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala @@ -8,12 +8,13 @@ package org.opensearch.flint.spark.ppl import org.opensearch.flint.spark.ppl.PlaneUtils.plan import org.opensearch.sql.common.antlr.SyntaxCheckException import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} 
+import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Descending, GreaterThan, Literal, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, GreaterThan, LambdaFunction, Literal, NamedExpression, SortOrder, UnresolvedNamedLambdaVariable} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.command.DescribeTableCommand @@ -396,4 +397,21 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite |""".stripMargin), context) } + + test("test lambda function") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, """source=t | eval lambda = (x -> x > 0) """.stripMargin), + context) + val table = UnresolvedRelation(Seq("t")) + val lambda = LambdaFunction( + GreaterThan(UnresolvedNamedLambdaVariable(seq("x")), Literal(0)), + Seq(UnresolvedNamedLambdaVariable(seq("x")))) + val alias = Alias(lambda, "lambda")() + val evalProject = Project(Seq(UnresolvedStar(None), alias), table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, evalProject) + comparePlans(expectedPlan, logPlan, false) + } }