Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Lambda and add related array functions #864

Merged
merged 13 commits into from
Nov 5, 2024
4 changes: 3 additions & 1 deletion docs/ppl-lang/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ For additional examples see the next [documentation](PPL-Example-Commands.md).
- [`Cryptographic Functions`](functions/ppl-cryptographic.md)

- [`IP Address Functions`](functions/ppl-ip.md)

- [`Lambda Functions`](functions/ppl-lambda.md)

---
### PPL On Spark
Expand All @@ -109,4 +111,4 @@ See samples of [PPL queries](PPL-Example-Commands.md)

---
### PPL Project Roadmap
[PPL Github Project Roadmap](https://github.com/orgs/opensearch-project/projects/214)
[PPL Github Project Roadmap](https://github.com/orgs/opensearch-project/projects/214)
178 changes: 178 additions & 0 deletions docs/ppl-lang/functions/ppl-lambda.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
## Lambda Functions

### `FORALL`

**Description**

`forall(array, lambda)` Evaluates whether a lambda predicate holds for all elements in the array.

**Argument type:** ARRAY, LAMBDA

**Return type:** BOOLEAN

Returns `TRUE` if all elements in the array satisfy the lambda predicate, otherwise `FALSE`.

Example:

os> source=people | eval array = json_array(1, -1, 2), result = forall(array, x -> x > 0) | fields result
fetched rows / total rows = 1/1
+-----------+
| result |
+-----------+
| false |
+-----------+

os> source=people | eval array = json_array(1, 3, 2), result = forall(array, x -> x > 0) | fields result
fetched rows / total rows = 1/1
+-----------+
| result |
+-----------+
| true |
+-----------+

**Note:** The lambda expression can access the nested fields of the array elements. This applies to all lambda functions introduced in this document. See the examples below:

os> source=people | eval array = json_array(json_object("a", 1, "b", 1), json_object("a" , -1, "b", 2)), result = forall(array, x -> x.a > 0) | fields result
fetched rows / total rows = 1/1
+-----------+
| result |
+-----------+
| false |
+-----------+

os> source=people | eval array = json_array(json_object("a", 1, "b", 1), json_object("a" , -1, "b", 2)), result = forall(array, x -> x.b > 0) | fields result
fetched rows / total rows = 1/1
+-----------+
| result |
+-----------+
| true |
+-----------+

### `EXISTS`

**Description**

`exists(array, lambda)` Evaluates whether a lambda predicate holds for one or more elements in the array.

**Argument type:** ARRAY, LAMBDA

**Return type:** BOOLEAN

Returns `TRUE` if at least one element in the array satisfies the lambda predicate, otherwise `FALSE`.

Example:

os> source=people | eval array = json_array(1, -1, 2), result = exists(array, x -> x > 0) | fields result
fetched rows / total rows = 1/1
+-----------+
| result |
+-----------+
| true |
+-----------+

os> source=people | eval array = json_array(-1, -3, -2), result = exists(array, x -> x > 0) | fields result
fetched rows / total rows = 1/1
+-----------+
| result |
+-----------+
| false |
+-----------+


### `FILTER`

**Description**

`filter(array, lambda)` Filters the input array using the given lambda function.

**Argument type:** ARRAY, LAMBDA

**Return type:** ARRAY

An ARRAY that contains all elements in the input array that satisfy the lambda predicate.

Example:

os> source=people | eval array = json_array(1, -1, 2), result = filter(array, x -> x > 0) | fields result
fetched rows / total rows = 1/1
+-----------+
| result |
+-----------+
| [1, 2] |
+-----------+

os> source=people | eval array = json_array(-1, -3, -2), result = filter(array, x -> x > 0) | fields result
fetched rows / total rows = 1/1
+-----------+
| result |
+-----------+
| [] |
+-----------+

### `TRANSFORM`

**Description**

`transform(array, lambda)` Transform elements in an array using the lambda transform function. The second argument implies the index of the element if using binary lambda function. This is similar to a `map` in functional programming.

**Argument type:** ARRAY, LAMBDA

**Return type:** ARRAY

An ARRAY that contains the result of applying the lambda transform function to each element in the input array.

Example:

os> source=people | eval array = json_array(1, 2, 3), result = transform(array, x -> x + 1) | fields result
fetched rows / total rows = 1/1
+--------------+
| result |
+--------------+
| [2, 3, 4] |
+--------------+

os> source=people | eval array = json_array(1, 2, 3), result = transform(array, (x, i) -> x + i) | fields result
fetched rows / total rows = 1/1
+--------------+
| result |
+--------------+
| [1, 3, 5] |
+--------------+

### `REDUCE`

**Description**

`reduce(array, start, merge_lambda, finish_lambda)` Applies a binary merge lambda function to a start value and all elements in the array, and reduces this to a single state. The final state is converted into the final result by applying a finish lambda function.

**Argument type:** ARRAY, ANY, LAMBDA, LAMBDA

**Return type:** ANY

The final result of applying the lambda functions to the start value and the input array.

Example:

os> source=people | eval array = json_array(1, 2, 3), result = reduce(array, 0, (acc, x) -> acc + x) | fields result
fetched rows / total rows = 1/1
+-----------+
| result |
+-----------+
| 6 |
+-----------+

os> source=people | eval array = json_array(1, 2, 3), result = reduce(array, 10, (acc, x) -> acc + x) | fields result
fetched rows / total rows = 1/1
+-----------+
| result |
+-----------+
| 16 |
+-----------+

os> source=people | eval array = json_array(1, 2, 3), result = reduce(array, 0, (acc, x) -> acc + x, acc -> acc * 10) | fields result
fetched rows / total rows = 1/1
+-----------+
| result |
+-----------+
| 60 |
+-----------+
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.ppl

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.functions.{col, to_json}
import org.apache.spark.sql.streaming.StreamTest

class FlintSparkPPLLambdaFunctionITSuite
extends QueryTest
with LogicalPlanTestUtils
with FlintPPLSuite
with StreamTest {

/** Test table and index name */
private val testTable = "spark_catalog.default.flint_ppl_test"

override def beforeAll(): Unit = {
super.beforeAll()
// Create test table
createNullableJsonContentTable(testTable)
}

protected override def afterEach(): Unit = {
super.afterEach()
// Stop all streaming jobs if any
spark.streams.active.foreach { job =>
job.stop()
job.awaitTermination()
}
}

test("test forall()") {
val frame = sql(s"""
| source = $testTable | eval array = json_array(1,2,0,-1,1.1,-0.11), result = forall(array, x -> x > 0) | head 1 | fields result
| """.stripMargin)
assertSameRows(Seq(Row(false)), frame)

val frame2 = sql(s"""
| source = $testTable | eval array = json_array(1,2,0,-1,1.1,-0.11), result = forall(array, x -> x > -10) | head 1 | fields result
| """.stripMargin)
assertSameRows(Seq(Row(true)), frame2)

val frame3 = sql(s"""
| source = $testTable | eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = forall(array, x -> x.a > 0) | head 1 | fields result
| """.stripMargin)
assertSameRows(Seq(Row(false)), frame3)

val frame4 = sql(s"""
| source = $testTable | eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = exists(array, x -> x.b < 0) | head 1 | fields result
| """.stripMargin)
assertSameRows(Seq(Row(true)), frame4)
}

test("test exists()") {
val frame = sql(s"""
| source = $testTable | eval array = json_array(1,2,0,-1,1.1,-0.11), result = exists(array, x -> x > 0) | head 1 | fields result
| """.stripMargin)
assertSameRows(Seq(Row(true)), frame)

val frame2 = sql(s"""
| source = $testTable | eval array = json_array(1,2,0,-1,1.1,-0.11), result = exists(array, x -> x > 10) | head 1 | fields result
| """.stripMargin)
assertSameRows(Seq(Row(false)), frame2)

val frame3 = sql(s"""
| source = $testTable | eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = exists(array, x -> x.a > 0) | head 1 | fields result
| """.stripMargin)
assertSameRows(Seq(Row(true)), frame3)

val frame4 = sql(s"""
| source = $testTable | eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = exists(array, x -> x.b > 0) | head 1 | fields result
| """.stripMargin)
assertSameRows(Seq(Row(false)), frame4)

}

test("test filter()") {
val frame = sql(s"""
| source = $testTable| eval array = json_array(1,2,0,-1,1.1,-0.11), result = filter(array, x -> x > 0) | head 1 | fields result
| """.stripMargin)
assertSameRows(Seq(Row(Seq(1, 2, 1.1))), frame)

val frame2 = sql(s"""
| source = $testTable| eval array = json_array(1,2,0,-1,1.1,-0.11), result = filter(array, x -> x > 10) | head 1 | fields result
| """.stripMargin)
assertSameRows(Seq(Row(Seq())), frame2)

val frame3 = sql(s"""
| source = $testTable| eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = filter(array, x -> x.a > 0) | head 1 | fields result
| """.stripMargin)

assertSameRows(Seq(Row("""[{"a":1,"b":-1}]""")), frame3.select(to_json(col("result"))))

val frame4 = sql(s"""
| source = $testTable| eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = filter(array, x -> x.b > 0) | head 1 | fields result
| """.stripMargin)
assertSameRows(Seq(Row("""[]""")), frame4.select(to_json(col("result"))))
}

test("test transform()") {
val frame = sql(s"""
| source = $testTable | eval array = json_array(1,2,3), result = transform(array, x -> x + 1) | head 1 | fields result
| """.stripMargin)
assertSameRows(Seq(Row(Seq(2, 3, 4))), frame)

val frame2 = sql(s"""
| source = $testTable | eval array = json_array(1,2,3), result = transform(array, (x, y) -> x + y) | head 1 | fields result
| """.stripMargin)
assertSameRows(Seq(Row(Seq(1, 3, 5))), frame2)
}

test("test reduce()") {
val frame = sql(s"""
| source = $testTable | eval array = json_array(1,2,3), result = reduce(array, 0, (acc, x) -> acc + x) | head 1 | fields result
| """.stripMargin)
assertSameRows(Seq(Row(6)), frame)

val frame2 = sql(s"""
| source = $testTable | eval array = json_array(1,2,3), result = reduce(array, 1, (acc, x) -> acc + x) | head 1 | fields result
| """.stripMargin)
assertSameRows(Seq(Row(7)), frame2)

val frame3 = sql(s"""
| source = $testTable | eval array = json_array(1,2,3), result = reduce(array, 0, (acc, x) -> acc + x, acc -> acc * 10) | head 1 | fields result
| """.stripMargin)
assertSameRows(Seq(Row(60)), frame3)
}
}
14 changes: 11 additions & 3 deletions ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ RT_SQR_PRTHS: ']';
SINGLE_QUOTE: '\'';
DOUBLE_QUOTE: '"';
BACKTICK: '`';
ARROW: '->';

// Operators. Bit

Expand Down Expand Up @@ -384,15 +385,22 @@ JSON_VALID: 'JSON_VALID';
//JSON_DELETE: 'JSON_DELETE';
//JSON_EXTEND: 'JSON_EXTEND';
//JSON_SET: 'JSON_SET';
//JSON_ARRAY_ALL_MATCH: 'JSON_ALL_MATCH';
//JSON_ARRAY_ANY_MATCH: 'JSON_ANY_MATCH';
//JSON_ARRAY_FILTER: 'JSON_FILTER';
//JSON_ARRAY_ALL_MATCH: 'JSON_ARRAY_ALL_MATCH';
//JSON_ARRAY_ANY_MATCH: 'JSON_ARRAY_ANY_MATCH';
//JSON_ARRAY_FILTER: 'JSON_ARRAY_FILTER';
//JSON_ARRAY_MAP: 'JSON_ARRAY_MAP';
//JSON_ARRAY_REDUCE: 'JSON_ARRAY_REDUCE';

// COLLECTION FUNCTIONS
ARRAY: 'ARRAY';

// LAMBDA FUNCTIONS
//EXISTS: 'EXISTS';
FORALL: 'FORALL';
FILTER: 'FILTER';
TRANSFORM: 'TRANSFORM';
REDUCE: 'REDUCE';

// BOOL FUNCTIONS
LIKE: 'LIKE';
ISNULL: 'ISNULL';
Expand Down
11 changes: 11 additions & 0 deletions ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,8 @@ valueExpression
| timestampFunction # timestampFunctionCall
| LT_PRTHS valueExpression RT_PRTHS # parentheticValueExpr
| LT_SQR_PRTHS subSearch RT_SQR_PRTHS # scalarSubqueryExpr
| ident ARROW expression # lambda
| LT_PRTHS ident (COMMA ident)+ RT_PRTHS ARROW expression # lambda
;

primaryExpression
Expand Down Expand Up @@ -568,6 +570,7 @@ evalFunctionName
| cryptographicFunctionName
| jsonFunctionName
| collectionFunctionName
| lambdaFunctionName
;

functionArgs
Expand Down Expand Up @@ -875,6 +878,14 @@ collectionFunctionName
: ARRAY
;

lambdaFunctionName
: FORALL
| EXISTS
| FILTER
| TRANSFORM
| REDUCE
;

positionFunctionName
: POSITION
;
Expand Down
Loading
Loading