diff --git a/docs/ppl-lang/README.md b/docs/ppl-lang/README.md index 6ba49b031..d78f4c030 100644 --- a/docs/ppl-lang/README.md +++ b/docs/ppl-lang/README.md @@ -91,6 +91,8 @@ For additional examples see the next [documentation](PPL-Example-Commands.md). - [`Cryptographic Functions`](functions/ppl-cryptographic.md) - [`IP Address Functions`](functions/ppl-ip.md) + + - [`Lambda Functions`](functions/ppl-lambda.md) --- ### PPL On Spark @@ -109,4 +111,4 @@ See samples of [PPL queries](PPL-Example-Commands.md) --- ### PPL Project Roadmap -[PPL Github Project Roadmap](https://github.com/orgs/opensearch-project/projects/214) \ No newline at end of file +[PPL Github Project Roadmap](https://github.com/orgs/opensearch-project/projects/214) diff --git a/docs/ppl-lang/functions/ppl-lambda.md b/docs/ppl-lang/functions/ppl-lambda.md new file mode 100644 index 000000000..cdb6f9e8f --- /dev/null +++ b/docs/ppl-lang/functions/ppl-lambda.md @@ -0,0 +1,187 @@ +## Lambda Functions + +### `FORALL` + +**Description** + +`forall(array, lambda)` Evaluates whether a lambda predicate holds for all elements in the array. + +**Argument type:** ARRAY, LAMBDA + +**Return type:** BOOLEAN + +Returns `TRUE` if all elements in the array satisfy the lambda predicate, otherwise `FALSE`. + +Example: + + os> source=people | eval array = json_array(1, -1, 2), result = forall(array, x -> x > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | false | + +-----------+ + + os> source=people | eval array = json_array(1, 3, 2), result = forall(array, x -> x > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | true | + +-----------+ + + **Note:** The lambda expression can access the nested fields of the array elements. This applies to all lambda functions introduced in this document. + +Consider constructing the following array: + + array = [ + {"a":1, "b":1}, + {"a":-1, "b":2} + ] + +and perform lambda functions against the nested fields `a` or `b`. See the examples: + + os> source=people | eval array = json_array(json_object("a", 1, "b", 1), json_object("a" , -1, "b", 2)), result = forall(array, x -> x.a > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | false | + +-----------+ + + os> source=people | eval array = json_array(json_object("a", 1, "b", 1), json_object("a" , -1, "b", 2)), result = forall(array, x -> x.b > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | true | + +-----------+ + +### `EXISTS` + +**Description** + +`exists(array, lambda)` Evaluates whether a lambda predicate holds for one or more elements in the array. + +**Argument type:** ARRAY, LAMBDA + +**Return type:** BOOLEAN + +Returns `TRUE` if at least one element in the array satisfies the lambda predicate, otherwise `FALSE`. + +Example: + + os> source=people | eval array = json_array(1, -1, 2), result = exists(array, x -> x > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | true | + +-----------+ + + os> source=people | eval array = json_array(-1, -3, -2), result = exists(array, x -> x > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | false | + +-----------+ + + +### `FILTER` + +**Description** + +`filter(array, lambda)` Filters the input array using the given lambda function. + +**Argument type:** ARRAY, LAMBDA + +**Return type:** ARRAY + +An ARRAY that contains all elements in the input array that satisfy the lambda predicate. + +Example: + + os> source=people | eval array = json_array(1, -1, 2), result = filter(array, x -> x > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | [1, 2] | + +-----------+ + + os> source=people | eval array = json_array(-1, -3, -2), result = filter(array, x -> x > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | [] | + +-----------+ + +### `TRANSFORM` + +**Description** + +`transform(array, lambda)` Transform elements in an array using the lambda transform function. The second argument implies the index of the element if using binary lambda function. This is similar to a `map` in functional programming. + +**Argument type:** ARRAY, LAMBDA + +**Return type:** ARRAY + +An ARRAY that contains the result of applying the lambda transform function to each element in the input array. + +Example: + + os> source=people | eval array = json_array(1, 2, 3), result = transform(array, x -> x + 1) | fields result + fetched rows / total rows = 1/1 + +--------------+ + | result | + +--------------+ + | [2, 3, 4] | + +--------------+ + + os> source=people | eval array = json_array(1, 2, 3), result = transform(array, (x, i) -> x + i) | fields result + fetched rows / total rows = 1/1 + +--------------+ + | result | + +--------------+ + | [1, 3, 5] | + +--------------+ + +### `REDUCE` + +**Description** + +`reduce(array, start, merge_lambda, finish_lambda)` Applies a binary merge lambda function to a start value and all elements in the array, and reduces this to a single state. The final state is converted into the final result by applying a finish lambda function. + +**Argument type:** ARRAY, ANY, LAMBDA, LAMBDA + +**Return type:** ANY + +The final result of applying the lambda functions to the start value and the input array. + +Example: + + os> source=people | eval array = json_array(1, 2, 3), result = reduce(array, 0, (acc, x) -> acc + x) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | 6 | + +-----------+ + + os> source=people | eval array = json_array(1, 2, 3), result = reduce(array, 10, (acc, x) -> acc + x) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | 16 | + +-----------+ + + os> source=people | eval array = json_array(1, 2, 3), result = reduce(array, 0, (acc, x) -> acc + x, acc -> acc * 10) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | 60 | + +-----------+ diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLLambdaFunctionITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLLambdaFunctionITSuite.scala new file mode 100644 index 000000000..f86502521 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLLambdaFunctionITSuite.scala @@ -0,0 +1,132 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.functions.{col, to_json} +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLLambdaFunctionITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + // Create test table + createNullableJsonContentTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test forall()") { + val frame = sql(s""" + | source = $testTable | eval array = json_array(1,2,0,-1,1.1,-0.11), result = forall(array, x -> x > 0) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(false)), frame) + + val frame2 = sql(s""" + | source = $testTable | eval array = json_array(1,2,0,-1,1.1,-0.11), result = forall(array, x -> x > -10) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(true)), frame2) + + val frame3 = sql(s""" + | source = $testTable | eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = forall(array, x -> x.a > 0) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(false)), frame3) + + val frame4 = sql(s""" + | source = $testTable | eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = exists(array, x -> x.b < 0) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(true)), frame4) + } + + test("test exists()") { + val frame = sql(s""" + | source = $testTable | eval array = json_array(1,2,0,-1,1.1,-0.11), result = exists(array, x -> x > 0) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(true)), frame) + + val frame2 = sql(s""" + | source = $testTable | eval array = json_array(1,2,0,-1,1.1,-0.11), result = exists(array, x -> x > 10) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(false)), frame2) + + val frame3 = sql(s""" + | source = $testTable | eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = exists(array, x -> x.a > 0) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(true)), frame3) + + val frame4 = sql(s""" + | source = $testTable | eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = exists(array, x -> x.b > 0) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(false)), frame4) + + } + + test("test filter()") { + val frame = sql(s""" + | source = $testTable| eval array = json_array(1,2,0,-1,1.1,-0.11), result = filter(array, x -> x > 0) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(Seq(1, 2, 1.1))), frame) + + val frame2 = sql(s""" + | source = $testTable| eval array = json_array(1,2,0,-1,1.1,-0.11), result = filter(array, x -> x > 10) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(Seq())), frame2) + + val frame3 = sql(s""" + | source = $testTable| eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = filter(array, x -> x.a > 0) | head 1 | fields result + | """.stripMargin) + + assertSameRows(Seq(Row("""[{"a":1,"b":-1}]""")), frame3.select(to_json(col("result")))) + + val frame4 = sql(s""" + | source = $testTable| eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = filter(array, x -> x.b > 0) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row("""[]""")), frame4.select(to_json(col("result")))) + } + + test("test transform()") { + val frame = sql(s""" + | source = $testTable | eval array = json_array(1,2,3), result = transform(array, x -> x + 1) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(Seq(2, 3, 4))), frame) + + val frame2 = sql(s""" + | source = $testTable | eval array = json_array(1,2,3), result = transform(array, (x, y) -> x + y) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(Seq(1, 3, 5))), frame2) + } + + test("test reduce()") { + val frame = sql(s""" + | source = $testTable | eval array = json_array(1,2,3), result = reduce(array, 0, (acc, x) -> acc + x) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(6)), frame) + + val frame2 = sql(s""" + | source = $testTable | eval array = json_array(1,2,3), result = reduce(array, 1, (acc, x) -> acc + x) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(7)), frame2) + + val frame3 = sql(s""" + | source = $testTable | eval array = json_array(1,2,3), result = reduce(array, 0, (acc, x) -> acc + x, acc -> acc * 10) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(60)), frame3) + } +} diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index 38bd1f9d2..1a4d717b9 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -205,6 +205,7 @@ RT_SQR_PRTHS: ']'; SINGLE_QUOTE: '\''; DOUBLE_QUOTE: '"'; BACKTICK: '`'; +ARROW: '->'; // Operators. Bit @@ -386,15 +387,22 @@ JSON_VALID: 'JSON_VALID'; //JSON_DELETE: 'JSON_DELETE'; //JSON_EXTEND: 'JSON_EXTEND'; //JSON_SET: 'JSON_SET'; -//JSON_ARRAY_ALL_MATCH: 'JSON_ALL_MATCH'; -//JSON_ARRAY_ANY_MATCH: 'JSON_ANY_MATCH'; -//JSON_ARRAY_FILTER: 'JSON_FILTER'; +//JSON_ARRAY_ALL_MATCH: 'JSON_ARRAY_ALL_MATCH'; +//JSON_ARRAY_ANY_MATCH: 'JSON_ARRAY_ANY_MATCH'; +//JSON_ARRAY_FILTER: 'JSON_ARRAY_FILTER'; //JSON_ARRAY_MAP: 'JSON_ARRAY_MAP'; //JSON_ARRAY_REDUCE: 'JSON_ARRAY_REDUCE'; // COLLECTION FUNCTIONS ARRAY: 'ARRAY'; +// LAMBDA FUNCTIONS +//EXISTS: 'EXISTS'; +FORALL: 'FORALL'; +FILTER: 'FILTER'; +TRANSFORM: 'TRANSFORM'; +REDUCE: 'REDUCE'; + // BOOL FUNCTIONS LIKE: 'LIKE'; ISNULL: 'ISNULL'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index a55d4fe14..1cc701776 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -451,6 +451,8 @@ valueExpression | geoipFunction # geoFunctionCall | LT_PRTHS valueExpression RT_PRTHS # parentheticValueExpr | LT_SQR_PRTHS subSearch RT_SQR_PRTHS # scalarSubqueryExpr + | ident ARROW expression # lambda + | LT_PRTHS ident (COMMA ident)+ RT_PRTHS ARROW expression # lambda ; primaryExpression @@ -582,6 +584,7 @@ evalFunctionName | jsonFunctionName | collectionFunctionName | geoipFunctionName + | lambdaFunctionName ; functionArgs @@ -889,6 +892,14 @@ collectionFunctionName : ARRAY ; +lambdaFunctionName + : FORALL + | EXISTS + | FILTER + | TRANSFORM + | REDUCE + ; + geoipFunctionName : GEOIP ; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index 525a0954c..189d9084a 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -18,6 +18,7 @@ import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.FieldList; +import org.opensearch.sql.ast.expression.LambdaFunction; import org.opensearch.sql.ast.tree.FieldSummary; import org.opensearch.sql.ast.expression.FieldsMapping; import org.opensearch.sql.ast.expression.Function; @@ -183,6 +184,10 @@ public T visitFunction(Function node, C context) { return visitChildren(node, context); } + public T visitLambdaFunction(LambdaFunction node, C context) { + return visitChildren(node, context); + } + public T visitIsEmpty(IsEmpty node, C context) { return visitChildren(node, context); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/LambdaFunction.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/LambdaFunction.java new file mode 100644 index 000000000..e1ee755b8 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/LambdaFunction.java @@ -0,0 +1,48 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.ast.AbstractNodeVisitor; + +/** + * Expression node of lambda function. Params include function name (@funcName) and function + * arguments (@funcArgs) + */ +@Getter +@EqualsAndHashCode(callSuper = false) +@RequiredArgsConstructor +public class LambdaFunction extends UnresolvedExpression { + private final UnresolvedExpression function; + private final List funcArgs; + + @Override + public List getChild() { + List children = new ArrayList<>(); + children.add(function); + children.addAll(funcArgs); + return children; + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitLambdaFunction(this, context); + } + + @Override + public String toString() { + return String.format( + "(%s) -> %s", + funcArgs.stream().map(Object::toString).collect(Collectors.joining(", ")), + function.toString() + ); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index d81dc7ce4..746fcc099 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -206,7 +206,8 @@ public enum BuiltinFunctionName { SUBSTRING(FunctionName.of("substring")), TRIM(FunctionName.of("trim")), UPPER(FunctionName.of("upper")), - + /** GEOSPATIAL Functions. */ + GEOIP(FunctionName.of("geoip")), /** JSON Functions. */ // If the function argument is a valid JSON, return itself, or return NULL JSON(FunctionName.of("json")), @@ -229,6 +230,13 @@ public enum BuiltinFunctionName { /** COLLECTION Functions **/ ARRAY(FunctionName.of("array")), + /** LAMBDA Functions **/ + ARRAY_FORALL(FunctionName.of("forall")), + ARRAY_EXISTS(FunctionName.of("exists")), + ARRAY_FILTER(FunctionName.of("filter")), + ARRAY_TRANSFORM(FunctionName.of("transform")), + ARRAY_AGGREGATE(FunctionName.of("reduce")), + /** NULL Test. */ IS_NULL(FunctionName.of("is null")), IS_NOT_NULL(FunctionName.of("is not null")), diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 5e0f0775d..4b7c8a1c1 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -29,6 +29,7 @@ import org.opensearch.sql.ast.expression.Interval; import org.opensearch.sql.ast.expression.IntervalUnit; import org.opensearch.sql.ast.expression.IsEmpty; +import org.opensearch.sql.ast.expression.LambdaFunction; import org.opensearch.sql.ast.expression.Let; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.Not; @@ -429,6 +430,15 @@ public UnresolvedExpression visitTimestampFunctionCall( ctx.timestampFunction().timestampFunctionName().getText(), timestampFunctionArguments(ctx)); } + @Override + public UnresolvedExpression visitLambda(OpenSearchPPLParser.LambdaContext ctx) { + + List arguments = ctx.ident().stream().map(x -> this.visitIdentifiers(Collections.singletonList(x))).collect( + Collectors.toList()); + UnresolvedExpression function = visitExpression(ctx.expression()); + return new LambdaFunction(function, arguments); + } + private List timestampFunctionArguments( OpenSearchPPLParser.TimestampFunctionCallContext ctx) { List args = diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/JavaToScalaTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/JavaToScalaTransformer.java new file mode 100644 index 000000000..40246d7c9 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/JavaToScalaTransformer.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.utils; + + +import scala.PartialFunction; +import scala.runtime.AbstractPartialFunction; + +public interface JavaToScalaTransformer { + static PartialFunction toPartialFunction( + java.util.function.Predicate isDefinedAt, + java.util.function.Function apply) { + return new AbstractPartialFunction() { + @Override + public boolean isDefinedAt(T t) { + return isDefinedAt.test(t); + } + + @Override + public T apply(T t) { + if (isDefinedAt.test(t)) return apply.apply(t); + else return t; + } + }; + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala index 2da93d5d8..bede3953a 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala @@ -13,7 +13,7 @@ import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Descending, GreaterThan, Literal, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, GreaterThan, Literal, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.command.DescribeTableCommand @@ -27,6 +27,27 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite private val planTransformer = new CatalystQueryPlanVisitor() private val pplParser = new PPLSyntaxParser() + test("test geoip command") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source = t | where isV6 = false | eval country = geoip(ip_field, 'country')"), context) + + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, UnresolvedRelation(Seq("schema", "table"))) + comparePlans(expectedPlan, logPlan, false) + + } + test("test geoip command lat / lon") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source = t | eval lat = geoip(ip_field, 'lat'), lon = geoip(ip_field, 'lon')"), context) + + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, UnresolvedRelation(Seq("schema", "table"))) + comparePlans(expectedPlan, logPlan, false) + + } + test("test error describe clause") { val context = new CatalystPlanContext val thrown = intercept[IllegalArgumentException] { diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala index f5dfc4ec8..216c0f232 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala @@ -11,7 +11,7 @@ import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Literal} +import org.apache.spark.sql.catalyst.expressions.{EqualTo, Literal} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanLambdaFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanLambdaFunctionsTranslatorTestSuite.scala new file mode 100644 index 000000000..9c3c1c8a0 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanLambdaFunctionsTranslatorTestSuite.scala @@ -0,0 +1,211 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, GreaterThan, LambdaFunction, Literal, UnresolvedNamedLambdaVariable} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.Project + +class PPLLogicalPlanLambdaFunctionsTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test forall()") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + """source=t | eval a = json_array(1, 2, 3), b = forall(a, x -> x > 0)""".stripMargin), + context) + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "a")() + val lambda = LambdaFunction( + GreaterThan(UnresolvedNamedLambdaVariable(seq("x")), Literal(0)), + Seq(UnresolvedNamedLambdaVariable(seq("x")))) + val aliasB = + Alias(UnresolvedFunction("forall", Seq(UnresolvedAttribute("a"), lambda), false), "b")() + val evalProject = Project(Seq(UnresolvedStar(None), aliasA, aliasB), table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, evalProject) + comparePlans(expectedPlan, logPlan, false) + } + + test("test exits()") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + """source=t | eval a = json_array(1, 2, 3), b = exists(a, x -> x > 0)""".stripMargin), + context) + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "a")() + val lambda = LambdaFunction( + GreaterThan(UnresolvedNamedLambdaVariable(seq("x")), Literal(0)), + Seq(UnresolvedNamedLambdaVariable(seq("x")))) + val aliasB = + Alias(UnresolvedFunction("exists", Seq(UnresolvedAttribute("a"), lambda), false), "b")() + val evalProject = Project(Seq(UnresolvedStar(None), aliasA, aliasB), table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, evalProject) + comparePlans(expectedPlan, logPlan, false) + } + + test("test filter()") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + """source=t | eval a = json_array(1, 2, 3), b = filter(a, x -> x > 0)""".stripMargin), + context) + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "a")() + val lambda = LambdaFunction( + GreaterThan(UnresolvedNamedLambdaVariable(seq("x")), Literal(0)), + Seq(UnresolvedNamedLambdaVariable(seq("x")))) + val aliasB = + Alias(UnresolvedFunction("filter", Seq(UnresolvedAttribute("a"), lambda), false), "b")() + val evalProject = Project(Seq(UnresolvedStar(None), aliasA, aliasB), table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, evalProject) + comparePlans(expectedPlan, logPlan, false) + } + + test("test transform()") { + val context = new CatalystPlanContext + // test single argument of lambda + val logPlan = + planTransformer.visit( + plan( + pplParser, + """source=t | eval a = json_array(1, 2, 3), b = transform(a, x -> x + 1)""".stripMargin), + context) + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "a")() + val lambda = LambdaFunction( + UnresolvedFunction("+", Seq(UnresolvedNamedLambdaVariable(seq("x")), Literal(1)), false), + Seq(UnresolvedNamedLambdaVariable(seq("x")))) + val aliasB = + Alias(UnresolvedFunction("transform", Seq(UnresolvedAttribute("a"), lambda), false), "b")() + val evalProject = Project(Seq(UnresolvedStar(None), aliasA, aliasB), table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, evalProject) + comparePlans(expectedPlan, logPlan, false) + } + + test("test transform() - test binary lambda") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + """source=t | eval a = json_array(1, 2, 3), b = transform(a, (x, y) -> x + y)""".stripMargin), + context) + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "a")() + val lambda = LambdaFunction( + UnresolvedFunction( + "+", + Seq(UnresolvedNamedLambdaVariable(seq("x")), UnresolvedNamedLambdaVariable(seq("y"))), + false), + Seq(UnresolvedNamedLambdaVariable(seq("x")), UnresolvedNamedLambdaVariable(seq("y")))) + val aliasB = + Alias(UnresolvedFunction("transform", Seq(UnresolvedAttribute("a"), lambda), false), "b")() + val expectedPlan = Project( + Seq(UnresolvedStar(None)), + Project(Seq(UnresolvedStar(None), aliasA, aliasB), table)) + comparePlans(expectedPlan, logPlan, false) + } + + test("test reduce() - without finish lambda") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + """source=t | eval a = json_array(1, 2, 3), b = reduce(a, 0, (x, y) -> x + y)""".stripMargin), + context) + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "a")() + val mergeLambda = LambdaFunction( + UnresolvedFunction( + "+", + Seq(UnresolvedNamedLambdaVariable(seq("x")), UnresolvedNamedLambdaVariable(seq("y"))), + false), + Seq(UnresolvedNamedLambdaVariable(seq("x")), UnresolvedNamedLambdaVariable(seq("y")))) + val aliasB = + Alias( + UnresolvedFunction( + "reduce", + Seq(UnresolvedAttribute("a"), Literal(0), mergeLambda), + false), + "b")() + val expectedPlan = Project( + Seq(UnresolvedStar(None)), + Project(Seq(UnresolvedStar(None), aliasA, aliasB), table)) + comparePlans(expectedPlan, logPlan, false) + } + + test("test reduce() - with finish lambda") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + """source=t | eval a = json_array(1, 2, 3), b = reduce(a, 0, (x, y) -> x + y, x -> x * 10)""".stripMargin), + context) + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "a")() + val mergeLambda = LambdaFunction( + UnresolvedFunction( + "+", + Seq(UnresolvedNamedLambdaVariable(seq("x")), UnresolvedNamedLambdaVariable(seq("y"))), + false), + Seq(UnresolvedNamedLambdaVariable(seq("x")), UnresolvedNamedLambdaVariable(seq("y")))) + val finishLambda = LambdaFunction( + UnresolvedFunction("*", Seq(UnresolvedNamedLambdaVariable(seq("x")), Literal(10)), false), + Seq(UnresolvedNamedLambdaVariable(seq("x")))) + val aliasB = + Alias( + UnresolvedFunction( + "reduce", + Seq(UnresolvedAttribute("a"), Literal(0), mergeLambda, finishLambda), + false), + "b")() + val expectedPlan = Project( + Seq(UnresolvedStar(None)), + Project(Seq(UnresolvedStar(None), aliasA, aliasB), table)) + comparePlans(expectedPlan, logPlan, false) + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseCidrmatchTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseCidrmatchTestSuite.scala new file mode 100644 index 000000000..213f201cc --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseCidrmatchTestSuite.scala @@ -0,0 +1,155 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.expression.function.SerializableUdf +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, EqualTo, GreaterThan, Literal, NullsFirst, NullsLast, RegExpExtract, ScalaUDF, SortOrder} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types.DataTypes + +class PPLLogicalPlanParseCidrmatchTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test cidrmatch for ipv4 for 192.168.1.0/24") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | where isV6 = false and isValid = true and cidrmatch(ipAddress, '192.168.1.0/24')"), + context) + + val ipAddress = UnresolvedAttribute("ipAddress") + val cidrExpression = Literal("192.168.1.0/24") + + val filterIpv6 = EqualTo(UnresolvedAttribute("isV6"), Literal(false)) + val filterIsValid = EqualTo(UnresolvedAttribute("isValid"), Literal(true)) + val cidr = ScalaUDF( + SerializableUdf.cidrFunction, + DataTypes.BooleanType, + seq(ipAddress, cidrExpression), + seq(), + Option.empty, + Option.apply("cidr"), + false, + true) + + val expectedPlan = Project( + Seq(UnresolvedStar(None)), + Filter(And(And(filterIpv6, filterIsValid), cidr), UnresolvedRelation(Seq("t")))) + assert(compareByString(expectedPlan) === compareByString(logPlan)) + } + + test("test cidrmatch for ipv6 for 2003:db8::/32") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | where isV6 = true and isValid = false and cidrmatch(ipAddress, '2003:db8::/32')"), + context) + + val ipAddress = UnresolvedAttribute("ipAddress") + val cidrExpression = Literal("2003:db8::/32") + + val filterIpv6 = EqualTo(UnresolvedAttribute("isV6"), Literal(true)) + val filterIsValid = EqualTo(UnresolvedAttribute("isValid"), Literal(false)) + val cidr = ScalaUDF( + SerializableUdf.cidrFunction, + DataTypes.BooleanType, + seq(ipAddress, cidrExpression), + seq(), + Option.empty, + Option.apply("cidr"), + false, + true) + + val expectedPlan = Project( + Seq(UnresolvedStar(None)), + Filter(And(And(filterIpv6, filterIsValid), cidr), UnresolvedRelation(Seq("t")))) + assert(compareByString(expectedPlan) === compareByString(logPlan)) + } + + test("test cidrmatch for ipv6 for 2003:db8::/32 with ip field projected") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | where isV6 = true and cidrmatch(ipAddress, '2003:db8::/32') | fields ip"), + context) + + val ipAddress = UnresolvedAttribute("ipAddress") + val cidrExpression = Literal("2003:db8::/32") + + val filterIpv6 = EqualTo(UnresolvedAttribute("isV6"), Literal(true)) + val cidr = ScalaUDF( + SerializableUdf.cidrFunction, + DataTypes.BooleanType, + seq(ipAddress, cidrExpression), + seq(), + Option.empty, + Option.apply("cidr"), + false, + true) + + val expectedPlan = Project( + Seq(UnresolvedAttribute("ip")), + Filter(And(filterIpv6, cidr), UnresolvedRelation(Seq("t")))) + assert(compareByString(expectedPlan) === compareByString(logPlan)) + } + + test("test cidrmatch for ipv6 for 2003:db8::/32 with ip field bool respond for each ip") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | where isV6 = true | eval inRange = case(cidrmatch(ipAddress, '2003:db8::/32'), 'in' else 'out') | fields ip, inRange"), + context) + + val ipAddress = UnresolvedAttribute("ipAddress") + val cidrExpression = Literal("2003:db8::/32") + + val filterIpv6 = EqualTo(UnresolvedAttribute("isV6"), Literal(true)) + val filterClause = Filter(filterIpv6, UnresolvedRelation(Seq("t"))) + val cidr = ScalaUDF( + SerializableUdf.cidrFunction, + DataTypes.BooleanType, + seq(ipAddress, cidrExpression), + seq(), + Option.empty, + Option.apply("cidr"), + false, + true) + + val equalTo = EqualTo(Literal(true), cidr) + val caseFunction = CaseWhen(Seq((equalTo, Literal("in"))), Literal("out")) + val aliasStatusCategory = Alias(caseFunction, "inRange")() + val evalProjectList = Seq(UnresolvedStar(None), aliasStatusCategory) + val evalProject = Project(evalProjectList, filterClause) + + val expectedPlan = + Project(Seq(UnresolvedAttribute("ip"), UnresolvedAttribute("inRange")), evalProject) + + assert(compareByString(expectedPlan) === compareByString(logPlan)) + } + +}