From aaba489a75a7b4ef09c5ff55b111fc294864dccb Mon Sep 17 00:00:00 2001
From: qianheng
Date: Tue, 5 Nov 2024 19:27:46 +0800
Subject: [PATCH] Support Lambda and add related array functions (#864)

* json function enhancement

Signed-off-by: Heng Qian

* Add JavaToScalaTransformer

Signed-off-by: Heng Qian

* Apply scalafmtAll

Signed-off-by: Heng Qian

* Address comments

Signed-off-by: Heng Qian

* Add IT and change to use the same function name as spark

Signed-off-by: Heng Qian

* Address comments

Signed-off-by: Heng Qian

* Add document and separate lambda functions from json functions

Signed-off-by: Heng Qian

* Add lambda functions transform and reduce

Signed-off-by: Heng Qian

* polish lambda function document

Signed-off-by: Heng Qian

* polish lambda function document

Signed-off-by: Heng Qian

* Minor fix

Signed-off-by: Heng Qian

* Minor change to polish the documents

Signed-off-by: Heng Qian

---------

Signed-off-by: Heng Qian
---
 docs/ppl-lang/README.md                       |   4 +-
 docs/ppl-lang/functions/ppl-lambda.md         | 187 ++++++++++++++++
 .../FlintSparkPPLLambdaFunctionITSuite.scala  | 132 +++++++++++
 .../src/main/antlr4/OpenSearchPPLLexer.g4     |  14 +-
 .../src/main/antlr4/OpenSearchPPLParser.g4    |  11 +
 .../sql/ast/AbstractNodeVisitor.java          |   5 +
 .../sql/ast/expression/LambdaFunction.java    |  48 ++++
 .../function/BuiltinFunctionName.java         |   7 +
 .../sql/ppl/CatalystExpressionVisitor.java    |  24 ++
 .../sql/ppl/parser/AstExpressionBuilder.java  |  10 +
 .../sql/ppl/utils/JavaToScalaTransformer.java |  29 +++
 ...lPlanBasicQueriesTranslatorTestSuite.scala |   2 +-
 ...PlanJsonFunctionsTranslatorTestSuite.scala |   2 +-
 ...anLambdaFunctionsTranslatorTestSuite.scala | 211 ++++++++++++++++++
 14 files changed, 680 insertions(+), 6 deletions(-)
 create mode 100644 docs/ppl-lang/functions/ppl-lambda.md
 create mode 100644 integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLLambdaFunctionITSuite.scala
 create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/LambdaFunction.java
 create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/JavaToScalaTransformer.java
 create mode 100644 ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanLambdaFunctionsTranslatorTestSuite.scala

diff --git a/docs/ppl-lang/README.md b/docs/ppl-lang/README.md
index 6ba49b031..d78f4c030 100644
--- a/docs/ppl-lang/README.md
+++ b/docs/ppl-lang/README.md
@@ -91,6 +91,8 @@ For additional examples see the next [documentation](PPL-Example-Commands.md).
     - [`Cryptographic Functions`](functions/ppl-cryptographic.md)
 
     - [`IP Address Functions`](functions/ppl-ip.md)
+
+    - [`Lambda Functions`](functions/ppl-lambda.md)
 
 ---
 ### PPL On Spark
@@ -109,4 +111,4 @@ See samples of [PPL queries](PPL-Example-Commands.md)
 
 ---
 ### PPL Project Roadmap
-[PPL Github Project Roadmap](https://github.com/orgs/opensearch-project/projects/214)
\ No newline at end of file
+[PPL Github Project Roadmap](https://github.com/orgs/opensearch-project/projects/214)
diff --git a/docs/ppl-lang/functions/ppl-lambda.md b/docs/ppl-lang/functions/ppl-lambda.md
new file mode 100644
index 000000000..cdb6f9e8f
--- /dev/null
+++ b/docs/ppl-lang/functions/ppl-lambda.md
@@ -0,0 +1,187 @@
+## Lambda Functions
+
+### `FORALL`
+
+**Description**
+
+`forall(array, lambda)` Evaluates whether a lambda predicate holds for all elements in the array.
+
+**Argument type:** ARRAY, LAMBDA
+
+**Return type:** BOOLEAN
+
+Returns `TRUE` if all elements in the array satisfy the lambda predicate; otherwise, returns `FALSE`.
+
+Example:
+
+    os> source=people | eval array = json_array(1, -1, 2), result = forall(array, x -> x > 0) | fields result
+    fetched rows / total rows = 1/1
+    +-----------+
+    | result    |
+    +-----------+
+    | false     |
+    +-----------+
+
+    os> source=people | eval array = json_array(1, 3, 2), result = forall(array, x -> x > 0) | fields result
+    fetched rows / total rows = 1/1
+    +-----------+
+    | result    |
+    +-----------+
+    | true      |
+    +-----------+
+
+**Note:** The lambda expression can access nested fields of the array elements. This applies to all lambda functions introduced in this document.
+
+Consider constructing the following array:
+
+    array = [
+        {"a":1, "b":1},
+        {"a":-1, "b":2}
+    ]
+
+and applying lambda functions to the nested fields `a` and `b`:
+
+    os> source=people | eval array = json_array(json_object("a", 1, "b", 1), json_object("a", -1, "b", 2)), result = forall(array, x -> x.a > 0) | fields result
+    fetched rows / total rows = 1/1
+    +-----------+
+    | result    |
+    +-----------+
+    | false     |
+    +-----------+
+
+    os> source=people | eval array = json_array(json_object("a", 1, "b", 1), json_object("a", -1, "b", 2)), result = forall(array, x -> x.b > 0) | fields result
+    fetched rows / total rows = 1/1
+    +-----------+
+    | result    |
+    +-----------+
+    | true      |
+    +-----------+
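+
+Because `forall` returns an ordinary boolean, its result can also drive row-level filtering. The following example is an illustrative sketch only, assuming the same hypothetical `people` index used above:
+
+    os> source=people | eval array = json_array(1, 3, 2), all_positive = forall(array, x -> x > 0) | where all_positive = true | fields all_positive
+    fetched rows / total rows = 1/1
+    +----------------+
+    | all_positive   |
+    +----------------+
+    | true           |
+    +----------------+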
+
+### `EXISTS`
+
+**Description**
+
+`exists(array, lambda)` Evaluates whether a lambda predicate holds for at least one element in the array.
+
+**Argument type:** ARRAY, LAMBDA
+
+**Return type:** BOOLEAN
+
+Returns `TRUE` if at least one element in the array satisfies the lambda predicate; otherwise, returns `FALSE`.
+
+Example:
+
+    os> source=people | eval array = json_array(1, -1, 2), result = exists(array, x -> x > 0) | fields result
+    fetched rows / total rows = 1/1
+    +-----------+
+    | result    |
+    +-----------+
+    | true      |
+    +-----------+
+
+    os> source=people | eval array = json_array(-1, -3, -2), result = exists(array, x -> x > 0) | fields result
+    fetched rows / total rows = 1/1
+    +-----------+
+    | result    |
+    +-----------+
+    | false     |
+    +-----------+
+
+### `FILTER`
+
+**Description**
+
+`filter(array, lambda)` Filters the input array using the given lambda predicate.
+
+**Argument type:** ARRAY, LAMBDA
+
+**Return type:** ARRAY
+
+Returns an ARRAY containing all elements of the input array that satisfy the lambda predicate.
+
+Example:
+
+    os> source=people | eval array = json_array(1, -1, 2), result = filter(array, x -> x > 0) | fields result
+    fetched rows / total rows = 1/1
+    +-----------+
+    | result    |
+    +-----------+
+    | [1, 2]    |
+    +-----------+
+
+    os> source=people | eval array = json_array(-1, -3, -2), result = filter(array, x -> x > 0) | fields result
+    fetched rows / total rows = 1/1
+    +-----------+
+    | result    |
+    +-----------+
+    | []        |
+    +-----------+
+
+### `TRANSFORM`
+
+**Description**
+
+`transform(array, lambda)` Transforms each element in the array using the lambda transform function, similar to `map` in functional programming. When a binary lambda function is used, its second argument receives the index of the element.
+
+**Argument type:** ARRAY, LAMBDA
+
+**Return type:** ARRAY
+
+Returns an ARRAY containing the result of applying the lambda transform function to each element of the input array.
+
+Example:
+
+    os> source=people | eval array = json_array(1, 2, 3), result = transform(array, x -> x + 1) | fields result
+    fetched rows / total rows = 1/1
+    +--------------+
+    | result       |
+    +--------------+
+    | [2, 3, 4]    |
+    +--------------+
+
+    os> source=people | eval array = json_array(1, 2, 3), result = transform(array, (x, i) -> x + i) | fields result
+    fetched rows / total rows = 1/1
+    +--------------+
+    | result       |
+    +--------------+
+    | [1, 3, 5]    |
+    +--------------+
+
+### `REDUCE`
+
+**Description**
+
+`reduce(array, start, merge_lambda [, finish_lambda])` Applies the binary merge lambda function to a start value and each element of the array, reducing them to a single state. If the optional finish lambda function is given, it is applied to convert the final state into the result.
+
+**Argument type:** ARRAY, ANY, LAMBDA, LAMBDA
+
+**Return type:** ANY
+
+Returns the final result of applying the lambda functions to the start value and the input array.
+
+Example:
+
+    os> source=people | eval array = json_array(1, 2, 3), result = reduce(array, 0, (acc, x) -> acc + x) | fields result
+    fetched rows / total rows = 1/1
+    +-----------+
+    | result    |
+    +-----------+
+    | 6         |
+    +-----------+
+
+    os> source=people | eval array = json_array(1, 2, 3), result = reduce(array, 10, (acc, x) -> acc + x) | fields result
+    fetched rows / total rows = 1/1
+    +-----------+
+    | result    |
+    +-----------+
+    | 16        |
+    +-----------+
+
+    os> source=people | eval array = json_array(1, 2, 3), result = reduce(array, 0, (acc, x) -> acc + x, acc -> acc * 10) | fields result
+    fetched rows / total rows = 1/1
+    +-----------+
+    | result    |
+    +-----------+
+    | 60        |
+    +-----------+
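+
+The functions compose within a single `eval`. The following end-to-end example is an illustrative sketch only (hypothetical `people` index as above; the result follows from the individual function semantics):
+
+    os> source=people | eval array = json_array(1, -2, 3), positive = filter(array, x -> x > 0), doubled = transform(positive, x -> x * 2), total = reduce(doubled, 0, (acc, x) -> acc + x) | fields total
+    fetched rows / total rows = 1/1
+    +---------+
+    | total   |
+    +---------+
+    | 8       |
+    +---------+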
diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLLambdaFunctionITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLLambdaFunctionITSuite.scala
new file mode 100644
index 000000000..f86502521
--- /dev/null
+++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLLambdaFunctionITSuite.scala
@@ -0,0 +1,132 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.spark.ppl
+
+import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.functions.{col, to_json}
+import org.apache.spark.sql.streaming.StreamTest
+
+class FlintSparkPPLLambdaFunctionITSuite
+    extends QueryTest
+    with LogicalPlanTestUtils
+    with FlintPPLSuite
+    with StreamTest {
+
+  /** Test table and index name */
+  private val testTable = "spark_catalog.default.flint_ppl_test"
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    // Create test table
+    createNullableJsonContentTable(testTable)
+  }
+
+  protected override def afterEach(): Unit = {
+    super.afterEach()
+    // Stop all streaming jobs if any
+    spark.streams.active.foreach { job =>
+      job.stop()
+      job.awaitTermination()
+    }
+  }
+
+  test("test forall()") {
+    val frame = sql(s"""
+         | source = $testTable | eval array = json_array(1,2,0,-1,1.1,-0.11), result = forall(array, x -> x > 0) | head 1 | fields result
+         | """.stripMargin)
+    assertSameRows(Seq(Row(false)), frame)
+
+    val frame2 = sql(s"""
+         | source = $testTable | eval array = json_array(1,2,0,-1,1.1,-0.11), result = forall(array, x -> x > -10) | head 1 | fields result
+         | """.stripMargin)
+    assertSameRows(Seq(Row(true)), frame2)
+
+    val frame3 = sql(s"""
+         | source = $testTable | eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = forall(array, x -> x.a > 0) | head 1 | fields result
+         | """.stripMargin)
+    assertSameRows(Seq(Row(false)), frame3)
+
+    val frame4 = sql(s"""
+         | source = $testTable | eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = exists(array, x -> x.b < 0) | head 1 | fields result
+         | """.stripMargin)
+    assertSameRows(Seq(Row(true)), frame4)
+  }
+
+  test("test exists()") {
+    val frame = sql(s"""
+         | source = $testTable | eval array = json_array(1,2,0,-1,1.1,-0.11), result = exists(array, x -> x > 0) | head 1 | fields result
+         | """.stripMargin)
+    assertSameRows(Seq(Row(true)), frame)
+
+    val frame2 = sql(s"""
+         | source = $testTable | eval array = json_array(1,2,0,-1,1.1,-0.11), result = exists(array, x -> x > 10) | head 1 | fields result
+         | """.stripMargin)
+    assertSameRows(Seq(Row(false)), frame2)
+
+    val frame3 = sql(s"""
+         | source = $testTable | eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = exists(array, x -> x.a > 0) | head 1 | fields result
+         | """.stripMargin)
+    assertSameRows(Seq(Row(true)), frame3)
+
+    val frame4 = sql(s"""
+         | source = $testTable | eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = exists(array, x -> x.b > 0) | head 1 | fields result
+         | """.stripMargin)
+    assertSameRows(Seq(Row(false)), frame4)
+  }
+
+  test("test filter()") {
+    val frame = sql(s"""
+         | source = $testTable | eval array = json_array(1,2,0,-1,1.1,-0.11), result = filter(array, x -> x > 0) | head 1 | fields result
+         | """.stripMargin)
+    assertSameRows(Seq(Row(Seq(1, 2, 1.1))), frame)
+
+    val frame2 = sql(s"""
+         | source = $testTable | eval array = json_array(1,2,0,-1,1.1,-0.11), result = filter(array, x -> x > 10) | head 1 | fields result
+         | """.stripMargin)
+    assertSameRows(Seq(Row(Seq())), frame2)
+
+    val frame3 = sql(s"""
+         | source = $testTable | eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = filter(array, x -> x.a > 0) | head 1 | fields result
+         | """.stripMargin)
+    assertSameRows(Seq(Row("""[{"a":1,"b":-1}]""")), frame3.select(to_json(col("result"))))
+
+    val frame4 = sql(s"""
+         | source = $testTable | eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = filter(array, x -> x.b > 0) | head 1 | fields result
+         | """.stripMargin)
+    assertSameRows(Seq(Row("""[]""")), frame4.select(to_json(col("result"))))
+  }
+
+  test("test transform()") {
+    val frame = sql(s"""
+         | source = $testTable | eval array = json_array(1,2,3), result = transform(array, x -> x + 1) | head 1 | fields result
+         | """.stripMargin)
+    assertSameRows(Seq(Row(Seq(2, 3, 4))), frame)
+
+    val frame2 = sql(s"""
+         | source = $testTable | eval array = json_array(1,2,3), result = transform(array, (x, y) -> x + y) | head 1 | fields result
+         | """.stripMargin)
+    assertSameRows(Seq(Row(Seq(1, 3, 5))), frame2)
+  }
+
+  test("test reduce()") {
+    val frame = sql(s"""
+         | source = $testTable | eval array = json_array(1,2,3), result = reduce(array, 0, (acc, x) -> acc + x) | head 1 | fields result
+         | """.stripMargin)
+    assertSameRows(Seq(Row(6)), frame)
+
+    val frame2 = sql(s"""
+         | source = $testTable | eval array = json_array(1,2,3), result = reduce(array, 1, (acc, x) -> acc + x) | head 1 | fields result
+         | """.stripMargin)
+    assertSameRows(Seq(Row(7)), frame2)
+
+    val frame3 = sql(s"""
+         | source = $testTable | eval array = json_array(1,2,3), result = reduce(array, 0, (acc, x) -> acc + x, acc -> acc * 10) | head 1 | fields result
+         | """.stripMargin)
+    assertSameRows(Seq(Row(60)), frame3)
+  }
+}
diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4
index 991a4dffe..fcec4d13f 100644
--- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4
+++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4
@@ -203,6 +203,7 @@ RT_SQR_PRTHS: ']';
 SINGLE_QUOTE: '\'';
 DOUBLE_QUOTE: '"';
 BACKTICK: '`';
+ARROW: '->';
 
 // Operators. Bit
@@ -384,15 +385,22 @@ JSON_VALID: 'JSON_VALID';
 //JSON_DELETE: 'JSON_DELETE';
 //JSON_EXTEND: 'JSON_EXTEND';
 //JSON_SET: 'JSON_SET';
-//JSON_ARRAY_ALL_MATCH: 'JSON_ALL_MATCH';
-//JSON_ARRAY_ANY_MATCH: 'JSON_ANY_MATCH';
-//JSON_ARRAY_FILTER: 'JSON_FILTER';
+//JSON_ARRAY_ALL_MATCH: 'JSON_ARRAY_ALL_MATCH';
+//JSON_ARRAY_ANY_MATCH: 'JSON_ARRAY_ANY_MATCH';
+//JSON_ARRAY_FILTER: 'JSON_ARRAY_FILTER';
 //JSON_ARRAY_MAP: 'JSON_ARRAY_MAP';
 //JSON_ARRAY_REDUCE: 'JSON_ARRAY_REDUCE';
 
 // COLLECTION FUNCTIONS
 ARRAY: 'ARRAY';
 
+// LAMBDA FUNCTIONS
+//EXISTS: 'EXISTS';
+FORALL: 'FORALL';
+FILTER: 'FILTER';
+TRANSFORM: 'TRANSFORM';
+REDUCE: 'REDUCE';
+
 // BOOL FUNCTIONS
 LIKE: 'LIKE';
 ISNULL: 'ISNULL';
diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
index cd6fe5dc1..b7f293a4a 100644
--- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
+++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
@@ -443,6 +443,8 @@ valueExpression
     | timestampFunction                                       # timestampFunctionCall
     | LT_PRTHS valueExpression RT_PRTHS                       # parentheticValueExpr
     | LT_SQR_PRTHS subSearch RT_SQR_PRTHS                     # scalarSubqueryExpr
+    | ident ARROW expression                                  # lambda
+    | LT_PRTHS ident (COMMA ident)+ RT_PRTHS ARROW expression # lambda
     ;
 
 primaryExpression
@@ -568,6 +570,7 @@ evalFunctionName
     | cryptographicFunctionName
     | jsonFunctionName
     | collectionFunctionName
+    | lambdaFunctionName
     ;
 
 functionArgs
@@ -875,6 +878,14 @@ collectionFunctionName
     : ARRAY
     ;
 
+lambdaFunctionName
+    : FORALL
+    | EXISTS
+    | FILTER
+    | TRANSFORM
+    | REDUCE
+    ;
+
 positionFunctionName
     : POSITION
     ;
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java
index 525a0954c..189d9084a 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java
@@ -18,6 +18,7 @@
 import org.opensearch.sql.ast.expression.EqualTo;
 import org.opensearch.sql.ast.expression.Field;
 import org.opensearch.sql.ast.expression.FieldList;
+import org.opensearch.sql.ast.expression.LambdaFunction;
 import org.opensearch.sql.ast.tree.FieldSummary;
 import org.opensearch.sql.ast.expression.FieldsMapping;
 import org.opensearch.sql.ast.expression.Function;
@@ -183,6 +184,10 @@ public T visitFunction(Function node, C context) {
     return visitChildren(node, context);
   }
 
+  public T visitLambdaFunction(LambdaFunction node, C context) {
+    return visitChildren(node, context);
+  }
+
   public T visitIsEmpty(IsEmpty node, C context) {
     return visitChildren(node, context);
   }
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/LambdaFunction.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/LambdaFunction.java
new file mode 100644
index 000000000..e1ee755b8
--- /dev/null
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/LambdaFunction.java
@@ -0,0 +1,48 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.ast.expression;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Collectors;
+import lombok.EqualsAndHashCode;
+import lombok.Getter;
+import lombok.RequiredArgsConstructor;
+import org.opensearch.sql.ast.AbstractNodeVisitor;
+
+/**
+ * Expression node of a lambda function. Params include the lambda body (@function) and the lambda
+ * arguments (@funcArgs). For example, for `x -> x > 0` the body is `x > 0` and the argument list
+ * is `[x]`.
+ */
+@Getter
+@EqualsAndHashCode(callSuper = false)
+@RequiredArgsConstructor
+public class LambdaFunction extends UnresolvedExpression {
+  private final UnresolvedExpression function;
+  private final List<QualifiedName> funcArgs;
+
+  @Override
+  public List<UnresolvedExpression> getChild() {
+    List<UnresolvedExpression> children = new ArrayList<>();
+    children.add(function);
+    children.addAll(funcArgs);
+    return children;
+  }
+
+  @Override
+  public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
+    return nodeVisitor.visitLambdaFunction(this, context);
+  }
+
+  @Override
+  public String toString() {
+    return String.format(
+        "(%s) -> %s",
+        funcArgs.stream().map(Object::toString).collect(Collectors.joining(", ")),
+        function.toString());
+  }
+}
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java
index d81dc7ce4..13b5c20ef 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java
@@ -229,6 +229,13 @@ public enum BuiltinFunctionName {
   /** COLLECTION Functions **/
   ARRAY(FunctionName.of("array")),
 
+  /** LAMBDA Functions **/
+  ARRAY_FORALL(FunctionName.of("forall")),
+  ARRAY_EXISTS(FunctionName.of("exists")),
+  ARRAY_FILTER(FunctionName.of("filter")),
+  ARRAY_TRANSFORM(FunctionName.of("transform")),
+  ARRAY_AGGREGATE(FunctionName.of("reduce")),
+
   /** NULL Test. */
   IS_NULL(FunctionName.of("is null")),
   IS_NOT_NULL(FunctionName.of("is not null")),
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java
index 69a89b83a..a0506ceee 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java
@@ -5,6 +5,7 @@
 
 package org.opensearch.sql.ppl;
 
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute;
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$;
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
 import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$;
@@ -15,6 +16,7 @@
 import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual;
 import org.apache.spark.sql.catalyst.expressions.In$;
 import org.apache.spark.sql.catalyst.expressions.InSubquery$;
+import org.apache.spark.sql.catalyst.expressions.LambdaFunction$;
 import org.apache.spark.sql.catalyst.expressions.LessThan;
 import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual;
 import org.apache.spark.sql.catalyst.expressions.ListQuery$;
@@ -24,6 +26,8 @@
 import org.apache.spark.sql.catalyst.expressions.RowFrame$;
 import org.apache.spark.sql.catalyst.expressions.ScalaUDF;
 import org.apache.spark.sql.catalyst.expressions.ScalarSubquery$;
+import org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable;
+import org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable$;
 import org.apache.spark.sql.catalyst.expressions.SpecifiedWindowFrame;
 import org.apache.spark.sql.catalyst.expressions.WindowExpression;
 import org.apache.spark.sql.catalyst.expressions.WindowSpecDefinition;
@@ -47,6 +51,7 @@
 import org.opensearch.sql.ast.expression.Literal;
 import org.opensearch.sql.ast.expression.Not;
 import org.opensearch.sql.ast.expression.Or;
+import org.opensearch.sql.ast.expression.LambdaFunction;
 import org.opensearch.sql.ast.expression.QualifiedName;
 import org.opensearch.sql.ast.expression.Span;
 import org.opensearch.sql.ast.expression.UnresolvedExpression;
@@ -68,7 +73,9 @@
 import org.opensearch.sql.ppl.utils.AggregatorTransformer;
 import org.opensearch.sql.ppl.utils.BuiltinFunctionTransformer;
 import org.opensearch.sql.ppl.utils.ComparatorTransformer;
+import org.opensearch.sql.ppl.utils.JavaToScalaTransformer;
 import scala.Option;
+import scala.PartialFunction;
 import scala.Tuple2;
 import scala.collection.Seq;
 
@@ -432,6 +439,23 @@ public Expression visitCidr(org.opensearch.sql.ast.expression.Cidr node, CatalystPlanContext context) {
     return context.getNamedParseExpressions().push(udf);
   }
 
+  @Override
+  public Expression visitLambdaFunction(LambdaFunction node, CatalystPlanContext context) {
+    PartialFunction<Expression, Expression> transformer =
+        JavaToScalaTransformer.toPartialFunction(
+            expr -> expr instanceof UnresolvedAttribute,
+            expr -> {
+              UnresolvedAttribute attr = (UnresolvedAttribute) expr;
+              return new UnresolvedNamedLambdaVariable(attr.nameParts());
+            });
+    Expression functionResult = node.getFunction().accept(this, context).transformUp(transformer);
+    context.popNamedParseExpressions();
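+    // Illustrative trace: for the PPL lambda `x -> x > 0`, visiting the body above
+    // yields GreaterThan(UnresolvedAttribute(["x"]), Literal(0)), and transformUp
+    // rewrites each UnresolvedAttribute into an UnresolvedNamedLambdaVariable so
+    // that Catalyst binds `x` to the lambda argument instead of resolving it as a
+    // column. The argument list below receives the same conversion, producing
+    // LambdaFunction(GreaterThan(x', Literal(0)), Seq(x'), hidden = false), where
+    // x' = UnresolvedNamedLambdaVariable(["x"]).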
+    List<NamedExpression> argsResult = node.getFuncArgs().stream()
+        .map(arg -> UnresolvedNamedLambdaVariable$.MODULE$.apply(seq(arg.getParts())))
+        .collect(Collectors.toList());
+    return context.getNamedParseExpressions()
+        .push(LambdaFunction$.MODULE$.apply(functionResult, seq(argsResult), false));
+  }
+
   private List<Expression> visitExpressionList(List<UnresolvedExpression> expressionList, CatalystPlanContext context) {
     return expressionList.isEmpty()
         ? emptyList()
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java
index 5e0f0775d..4b7c8a1c1 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java
@@ -29,6 +29,7 @@
 import org.opensearch.sql.ast.expression.Interval;
 import org.opensearch.sql.ast.expression.IntervalUnit;
 import org.opensearch.sql.ast.expression.IsEmpty;
+import org.opensearch.sql.ast.expression.LambdaFunction;
 import org.opensearch.sql.ast.expression.Let;
 import org.opensearch.sql.ast.expression.Literal;
 import org.opensearch.sql.ast.expression.Not;
@@ -429,6 +430,15 @@ public UnresolvedExpression visitTimestampFunctionCall(
         ctx.timestampFunction().timestampFunctionName().getText(), timestampFunctionArguments(ctx));
   }
 
+  @Override
+  public UnresolvedExpression visitLambda(OpenSearchPPLParser.LambdaContext ctx) {
+    List<QualifiedName> arguments = ctx.ident().stream()
+        .map(x -> this.visitIdentifiers(Collections.singletonList(x)))
+        .collect(Collectors.toList());
+    UnresolvedExpression function = visitExpression(ctx.expression());
+    return new LambdaFunction(function, arguments);
+  }
+
   private List<UnresolvedExpression> timestampFunctionArguments(
       OpenSearchPPLParser.TimestampFunctionCallContext ctx) {
     List<UnresolvedExpression> args =
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/JavaToScalaTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/JavaToScalaTransformer.java
new file mode 100644
index 000000000..40246d7c9
--- /dev/null
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/JavaToScalaTransformer.java
@@ -0,0 +1,29 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.ppl.utils;
+
+import scala.PartialFunction;
+import scala.runtime.AbstractPartialFunction;
+
+public interface JavaToScalaTransformer {
+  static <T> PartialFunction<T, T> toPartialFunction(
+      java.util.function.Predicate<T> isDefinedAt,
+      java.util.function.Function<T, T> apply) {
+    return new AbstractPartialFunction<T, T>() {
+      @Override
+      public boolean isDefinedAt(T t) {
+        return isDefinedAt.test(t);
+      }
+
+      @Override
+      public T apply(T t) {
+        if (isDefinedAt.test(t)) return apply.apply(t);
+        else return t;
+      }
+    };
+  }
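+
+  // Usage sketch (illustrative): bridges a Java predicate/function pair into the
+  // scala.PartialFunction that Catalyst's Expression.transformUp expects, e.g.
+  //
+  //   PartialFunction<Expression, Expression> pf =
+  //       JavaToScalaTransformer.toPartialFunction(
+  //           e -> e instanceof UnresolvedAttribute,
+  //           e -> new UnresolvedNamedLambdaVariable(((UnresolvedAttribute) e).nameParts()));
+  //   expression.transformUp(pf);
+  //
+  // Values for which isDefinedAt is false are returned unchanged by apply.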
+}
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala
index 2da93d5d8..2a569dbdf 100644
--- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala
@@ -13,7 +13,7 @@ import org.scalatest.matchers.should.Matchers
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar}
-import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Descending, GreaterThan, Literal, NamedExpression, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, GreaterThan, Literal, NamedExpression, SortOrder}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.command.DescribeTableCommand
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala
index f5dfc4ec8..216c0f232 100644
--- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala
@@ -11,7 +11,7 @@ import org.scalatest.matchers.should.Matchers
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
-import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Literal}
+import org.apache.spark.sql.catalyst.expressions.{EqualTo, Literal}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project}
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanLambdaFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanLambdaFunctionsTranslatorTestSuite.scala
new file mode 100644
index 000000000..9c3c1c8a0
--- /dev/null
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanLambdaFunctionsTranslatorTestSuite.scala
@@ -0,0 +1,211 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.spark.ppl
+
+import org.opensearch.flint.spark.ppl.PlaneUtils.plan
+import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
+import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq
+import org.scalatest.matchers.should.Matchers
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.expressions.{Alias, GreaterThan, LambdaFunction, Literal, UnresolvedNamedLambdaVariable}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.Project
+
+class PPLLogicalPlanLambdaFunctionsTranslatorTestSuite
+    extends SparkFunSuite
+    with PlanTest
+    with LogicalPlanTestUtils
+    with Matchers {
+
+  private val planTransformer = new CatalystQueryPlanVisitor()
+  private val pplParser = new PPLSyntaxParser()
+
+  test("test forall()") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(
+        plan(
+          pplParser,
+          """source=t | eval a = json_array(1, 2, 3), b = forall(a, x -> x > 0)""".stripMargin),
+        context)
+    val table = UnresolvedRelation(Seq("t"))
+    val jsonFunc =
+      UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false)
+    val aliasA = Alias(jsonFunc, "a")()
+    val lambda = LambdaFunction(
+      GreaterThan(UnresolvedNamedLambdaVariable(seq("x")), Literal(0)),
+      Seq(UnresolvedNamedLambdaVariable(seq("x"))))
+    val aliasB =
+      Alias(UnresolvedFunction("forall", Seq(UnresolvedAttribute("a"), lambda), false), "b")()
+    val evalProject = Project(Seq(UnresolvedStar(None), aliasA, aliasB), table)
+    val projectList = Seq(UnresolvedStar(None))
+    val expectedPlan = Project(projectList, evalProject)
+    comparePlans(expectedPlan, logPlan, false)
+  }
+
+  test("test exists()") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(
+        plan(
+          pplParser,
+          """source=t | eval a = json_array(1, 2, 3), b = exists(a, x -> x > 0)""".stripMargin),
+        context)
+    val table = UnresolvedRelation(Seq("t"))
+    val jsonFunc =
+      UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false)
+    val aliasA = Alias(jsonFunc, "a")()
+    val lambda = LambdaFunction(
+      GreaterThan(UnresolvedNamedLambdaVariable(seq("x")), Literal(0)),
+      Seq(UnresolvedNamedLambdaVariable(seq("x"))))
+    val aliasB =
+      Alias(UnresolvedFunction("exists", Seq(UnresolvedAttribute("a"), lambda), false), "b")()
+    val evalProject = Project(Seq(UnresolvedStar(None), aliasA, aliasB), table)
+    val projectList = Seq(UnresolvedStar(None))
+    val expectedPlan = Project(projectList, evalProject)
+    comparePlans(expectedPlan, logPlan, false)
+  }
+
+  test("test filter()") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(
+        plan(
+          pplParser,
+          """source=t | eval a = json_array(1, 2, 3), b = filter(a, x -> x > 0)""".stripMargin),
+        context)
+    val table = UnresolvedRelation(Seq("t"))
+    val jsonFunc =
+      UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false)
+    val aliasA = Alias(jsonFunc, "a")()
+    val lambda = LambdaFunction(
+      GreaterThan(UnresolvedNamedLambdaVariable(seq("x")), Literal(0)),
+      Seq(UnresolvedNamedLambdaVariable(seq("x"))))
+    val aliasB =
+      Alias(UnresolvedFunction("filter", Seq(UnresolvedAttribute("a"), lambda), false), "b")()
+    val evalProject = Project(Seq(UnresolvedStar(None), aliasA, aliasB), table)
+    val projectList = Seq(UnresolvedStar(None))
+    val expectedPlan = Project(projectList, evalProject)
+    comparePlans(expectedPlan, logPlan, false)
+  }
+
+  test("test transform()") {
+    val context = new CatalystPlanContext
+    // test single argument of lambda
+    val logPlan =
+      planTransformer.visit(
+        plan(
+          pplParser,
+          """source=t | eval a = json_array(1, 2, 3), b = transform(a, x -> x + 1)""".stripMargin),
+        context)
+    val table = UnresolvedRelation(Seq("t"))
+    val jsonFunc =
+      UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false)
+    val aliasA = Alias(jsonFunc, "a")()
+    val lambda = LambdaFunction(
+      UnresolvedFunction("+", Seq(UnresolvedNamedLambdaVariable(seq("x")), Literal(1)), false),
+      Seq(UnresolvedNamedLambdaVariable(seq("x"))))
+    val aliasB =
+      Alias(UnresolvedFunction("transform", Seq(UnresolvedAttribute("a"), lambda), false), "b")()
+    val evalProject = Project(Seq(UnresolvedStar(None), aliasA, aliasB), table)
+    val projectList = Seq(UnresolvedStar(None))
+    val expectedPlan = Project(projectList, evalProject)
+    comparePlans(expectedPlan, logPlan, false)
+  }
+
+  test("test transform() - test binary lambda") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(
+        plan(
+          pplParser,
+          """source=t | eval a = json_array(1, 2, 3), b = transform(a, (x, y) -> x + y)""".stripMargin),
+        context)
+    val table = UnresolvedRelation(Seq("t"))
+    val jsonFunc =
+      UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false)
+    val aliasA = Alias(jsonFunc, "a")()
+    val lambda = LambdaFunction(
+      UnresolvedFunction(
+        "+",
+        Seq(UnresolvedNamedLambdaVariable(seq("x")), UnresolvedNamedLambdaVariable(seq("y"))),
+        false),
+      Seq(UnresolvedNamedLambdaVariable(seq("x")), UnresolvedNamedLambdaVariable(seq("y"))))
+    val aliasB =
+      Alias(UnresolvedFunction("transform", Seq(UnresolvedAttribute("a"), lambda), false), "b")()
+    val expectedPlan = Project(
+      Seq(UnresolvedStar(None)),
+      Project(Seq(UnresolvedStar(None), aliasA, aliasB), table))
+    comparePlans(expectedPlan, logPlan, false)
+  }
+
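+  // Note (illustrative): with Spark's `transform`, a binary lambda receives the
+  // element's index as its second argument, which is why the integration test
+  // expects (x, y) -> x + y over [1, 2, 3] to yield [1, 3, 5] (1+0, 2+1, 3+2).
+  // The translator treats `y` as an ordinary lambda variable; the index
+  // semantics come from Spark at execution time.
+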
+  test("test reduce() - without finish lambda") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(
+        plan(
+          pplParser,
+          """source=t | eval a = json_array(1, 2, 3), b = reduce(a, 0, (x, y) -> x + y)""".stripMargin),
+        context)
+    val table = UnresolvedRelation(Seq("t"))
+    val jsonFunc =
+      UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false)
+    val aliasA = Alias(jsonFunc, "a")()
+    val mergeLambda = LambdaFunction(
+      UnresolvedFunction(
+        "+",
+        Seq(UnresolvedNamedLambdaVariable(seq("x")), UnresolvedNamedLambdaVariable(seq("y"))),
+        false),
+      Seq(UnresolvedNamedLambdaVariable(seq("x")), UnresolvedNamedLambdaVariable(seq("y"))))
+    val aliasB =
+      Alias(
+        UnresolvedFunction(
+          "reduce",
+          Seq(UnresolvedAttribute("a"), Literal(0), mergeLambda),
+          false),
+        "b")()
+    val expectedPlan = Project(
+      Seq(UnresolvedStar(None)),
+      Project(Seq(UnresolvedStar(None), aliasA, aliasB), table))
+    comparePlans(expectedPlan, logPlan, false)
+  }
+
+  test("test reduce() - with finish lambda") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(
+        plan(
+          pplParser,
+          """source=t | eval a = json_array(1, 2, 3), b = reduce(a, 0, (x, y) -> x + y, x -> x * 10)""".stripMargin),
+        context)
+    val table = UnresolvedRelation(Seq("t"))
+    val jsonFunc =
+      UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false)
+    val aliasA = Alias(jsonFunc, "a")()
+    val mergeLambda = LambdaFunction(
+      UnresolvedFunction(
+        "+",
+        Seq(UnresolvedNamedLambdaVariable(seq("x")), UnresolvedNamedLambdaVariable(seq("y"))),
+        false),
+      Seq(UnresolvedNamedLambdaVariable(seq("x")), UnresolvedNamedLambdaVariable(seq("y"))))
+    val finishLambda = LambdaFunction(
+      UnresolvedFunction("*", Seq(UnresolvedNamedLambdaVariable(seq("x")), Literal(10)), false),
+      Seq(UnresolvedNamedLambdaVariable(seq("x"))))
+    val aliasB =
+      Alias(
+        UnresolvedFunction(
+          "reduce",
+          Seq(UnresolvedAttribute("a"), Literal(0), mergeLambda, finishLambda),
+          false),
+        "b")()
+    val expectedPlan = Project(
+      Seq(UnresolvedStar(None)),
+      Project(Seq(UnresolvedStar(None), aliasA, aliasB), table))
+    comparePlans(expectedPlan, logPlan, false)
+  }
+}