diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md index 3efca3205..e80f8c906 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -58,8 +58,16 @@ _- **Limitation: new field added by eval command with a function cannot be dropp - `source = table | where a not in (1, 2, 3) | fields a,b,c` - `source = table | where a between 1 and 4` - Note: This returns a >= 1 and a <= 4, i.e. [1, 4] - `source = table | where b not between '2024-09-10' and '2025-09-10'` - Note: This returns b >= '2024-09-10' and b <= '2025-09-10' -- `source = table | where cidrmatch(ip, '192.169.1.0/24')` +- `source = table | where cidrmatch(ip, '192.169.1.0/24')` - `source = table | where cidrmatch(ipv6, '2003:db8::/32')` +- `source = table | trendline sma(2, temperature) as temp_trend` + +#### **IP related queries** +[See additional command details](functions/ppl-ip.md) + +- `source = table | where cidrmatch(ip, '192.169.1.0/24')` +- `source = table | where isV6 = false and isValid = true and cidrmatch(ipAddress, '192.168.1.0/24')` +- `source = table | where isV6 = true | eval inRange = case(cidrmatch(ipAddress, '2003:db8::/32'), 'in' else 'out') | fields ip, inRange` ```sql source = table | eval status_category = @@ -122,6 +130,15 @@ Assumptions: `a`, `b`, `c`, `d`, `e` are existing fields in `table` - `source = table | fillnull using a = 101, b = 102` - `source = table | fillnull using a = concat(b, c), d = 2 * pi() * e` +### Flatten +[See additional command details](ppl-flatten-command.md) +Assumptions: `bridges`, `coor` are existing fields in `table`, and the field's types are `struct` or `array>` +- `source = table | flatten bridges` +- `source = table | flatten coor` +- `source = table | flatten bridges | flatten coor` +- `source = table | fields bridges | flatten bridges` +- `source = table | fields country, bridges | flatten bridges | fields country, length | stats avg(length) as avg by country` + ```sql source = table | eval e = eval status_category = case(a >= 200 AND a < 300, 'Success', @@ -289,7 +306,11 @@ source = table | where ispresent(a) | - `source = table1 | left semi join left = l right = r on l.a = r.a table2` - `source = table1 | left anti join left = l right = r on l.a = r.a table2` - `source = table1 | join left = l right = r [ source = table2 | where d > 10 | head 5 ]` - +- `source = table1 | inner join on table1.a = table2.a table2 | fields table1.a, table2.a, table1.b, table1.c` (directly refer table name) +- `source = table1 | inner join on a = c table2 | fields a, b, c, d` (ignore side aliases as long as no ambiguous) +- `source = table1 as t1 | join left = l right = r on l.a = r.a table2 as t2 | fields l.a, r.a` (side alias overrides table alias) +- `source = table1 as t1 | join left = l right = r on l.a = r.a table2 as t2 | fields t1.a, t2.a` (error, side alias overrides table alias) +- `source = table1 | join left = l right = r on l.a = r.a [ source = table2 ] as s | fields l.a, s.a` (error, side alias overrides subquery alias) #### **Lookup** [See additional command details](ppl-lookup-command.md) diff --git a/docs/ppl-lang/README.md b/docs/ppl-lang/README.md index 8d9b86eda..d78f4c030 100644 --- a/docs/ppl-lang/README.md +++ b/docs/ppl-lang/README.md @@ -31,6 +31,8 @@ For additional examples see the next [documentation](PPL-Example-Commands.md). 
- [`describe command`](PPL-Example-Commands.md/#describe) - [`fillnull command`](ppl-fillnull-command.md) + + - [`flatten command`](ppl-flatten-command.md) - [`eval command`](ppl-eval-command.md) @@ -67,7 +69,8 @@ For additional examples see the next [documentation](PPL-Example-Commands.md). - [`subquery commands`](ppl-subquery-command.md) - [`correlation commands`](ppl-correlation-command.md) - + + - [`trendline commands`](ppl-trendline-command.md) * **Functions** @@ -88,6 +91,8 @@ For additional examples see the next [documentation](PPL-Example-Commands.md). - [`Cryptographic Functions`](functions/ppl-cryptographic.md) - [`IP Address Functions`](functions/ppl-ip.md) + + - [`Lambda Functions`](functions/ppl-lambda.md) --- ### PPL On Spark @@ -106,4 +111,4 @@ See samples of [PPL queries](PPL-Example-Commands.md) --- ### PPL Project Roadmap -[PPL Github Project Roadmap](https://github.com/orgs/opensearch-project/projects/214) \ No newline at end of file +[PPL Github Project Roadmap](https://github.com/orgs/opensearch-project/projects/214) diff --git a/docs/ppl-lang/functions/ppl-json.md b/docs/ppl-lang/functions/ppl-json.md index 1953e8c70..5b26ee427 100644 --- a/docs/ppl-lang/functions/ppl-json.md +++ b/docs/ppl-lang/functions/ppl-json.md @@ -4,11 +4,11 @@ **Description** -`json(value)` Evaluates whether a value can be parsed as JSON. Returns the json string if valid, null otherwise. +`json(value)` Evaluates whether a string can be parsed as JSON format. Returns the string value if valid, null otherwise. -**Argument type:** STRING/JSON_ARRAY/JSON_OBJECT +**Argument type:** STRING -**Return type:** STRING +**Return type:** STRING/NULL A STRING expression of a valid JSON object format. @@ -47,7 +47,7 @@ A StructType expression of a valid JSON object. Example: - os> source=people | eval result = json(json_object('key', 123.45)) | fields result + os> source=people | eval result = json_object('key', 123.45) | fields result fetched rows / total rows = 1/1 +------------------+ | result | @@ -55,7 +55,7 @@ Example: | {"key":123.45} | +------------------+ - os> source=people | eval result = json(json_object('outer', json_object('inner', 123.45))) | fields result + os> source=people | eval result = json_object('outer', json_object('inner', 123.45)) | fields result fetched rows / total rows = 1/1 +------------------------------+ | result | @@ -81,13 +81,13 @@ Example: os> source=people | eval `json_array` = json_array(1, 2, 0, -1, 1.1, -0.11) fetched rows / total rows = 1/1 - +----------------------------+ - | json_array | - +----------------------------+ - | 1.0,2.0,0.0,-1.0,1.1,-0.11 | - +----------------------------+ + +------------------------------+ + | json_array | + +------------------------------+ + | [1.0,2.0,0.0,-1.0,1.1,-0.11] | + +------------------------------+ - os> source=people | eval `json_array_object` = json(json_object("array", json_array(1, 2, 0, -1, 1.1, -0.11))) + os> source=people | eval `json_array_object` = json_object("array", json_array(1, 2, 0, -1, 1.1, -0.11)) fetched rows / total rows = 1/1 +----------------------------------------+ | json_array_object | @@ -95,15 +95,44 @@ Example: | {"array":[1.0,2.0,0.0,-1.0,1.1,-0.11]} | +----------------------------------------+ +### `TO_JSON_STRING` + +**Description** + +`to_json_string(jsonObject)` Returns a JSON string with a given json object value. 
+ +**Argument type:** JSON_OBJECT (Spark StructType/ArrayType) + +**Return type:** STRING + +Example: + + os> source=people | eval `json_string` = to_json_string(json_array(1, 2, 0, -1, 1.1, -0.11)) | fields json_string + fetched rows / total rows = 1/1 + +--------------------------------+ + | json_string | + +--------------------------------+ + | [1.0,2.0,0.0,-1.0,1.1,-0.11] | + +--------------------------------+ + + os> source=people | eval `json_string` = to_json_string(json_object('key', 123.45)) | fields json_string + fetched rows / total rows = 1/1 + +-----------------+ + | json_string | + +-----------------+ + | {'key', 123.45} | + +-----------------+ + + ### `JSON_ARRAY_LENGTH` **Description** -`json_array_length(jsonArray)` Returns the number of elements in the outermost JSON array. +`json_array_length(jsonArrayString)` Returns the number of elements in the outermost JSON array string. -**Argument type:** STRING/JSON_ARRAY +**Argument type:** STRING -A STRING expression of a valid JSON array format, or JSON_ARRAY object. +A STRING expression of a valid JSON array format. **Return type:** INTEGER @@ -119,6 +148,21 @@ Example: | 4 | 5 | null | +-----------+-----------+-------------+ + +### `ARRAY_LENGTH` + +**Description** + +`array_length(jsonArray)` Returns the number of elements in the outermost array. + +**Argument type:** ARRAY + +ARRAY or JSON_ARRAY object. + +**Return type:** INTEGER + +Example: + os> source=people | eval `json_array` = json_array_length(json_array(1,2,3,4)), `empty_array` = json_array_length(json_array()) fetched rows / total rows = 1/1 +--------------+---------------+ @@ -127,6 +171,7 @@ Example: | 4 | 0 | +--------------+---------------+ + ### `JSON_EXTRACT` **Description** diff --git a/docs/ppl-lang/functions/ppl-lambda.md b/docs/ppl-lang/functions/ppl-lambda.md new file mode 100644 index 000000000..cdb6f9e8f --- /dev/null +++ b/docs/ppl-lang/functions/ppl-lambda.md @@ -0,0 +1,187 @@ +## Lambda Functions + +### `FORALL` + +**Description** + +`forall(array, lambda)` Evaluates whether a lambda predicate holds for all elements in the array. + +**Argument type:** ARRAY, LAMBDA + +**Return type:** BOOLEAN + +Returns `TRUE` if all elements in the array satisfy the lambda predicate, otherwise `FALSE`. + +Example: + + os> source=people | eval array = json_array(1, -1, 2), result = forall(array, x -> x > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | false | + +-----------+ + + os> source=people | eval array = json_array(1, 3, 2), result = forall(array, x -> x > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | true | + +-----------+ + + **Note:** The lambda expression can access the nested fields of the array elements. This applies to all lambda functions introduced in this document. + +Consider constructing the following array: + + array = [ + {"a":1, "b":1}, + {"a":-1, "b":2} + ] + +and perform lambda functions against the nested fields `a` or `b`. 
See the examples: + + os> source=people | eval array = json_array(json_object("a", 1, "b", 1), json_object("a" , -1, "b", 2)), result = forall(array, x -> x.a > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | false | + +-----------+ + + os> source=people | eval array = json_array(json_object("a", 1, "b", 1), json_object("a" , -1, "b", 2)), result = forall(array, x -> x.b > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | true | + +-----------+ + +### `EXISTS` + +**Description** + +`exists(array, lambda)` Evaluates whether a lambda predicate holds for one or more elements in the array. + +**Argument type:** ARRAY, LAMBDA + +**Return type:** BOOLEAN + +Returns `TRUE` if at least one element in the array satisfies the lambda predicate, otherwise `FALSE`. + +Example: + + os> source=people | eval array = json_array(1, -1, 2), result = exists(array, x -> x > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | true | + +-----------+ + + os> source=people | eval array = json_array(-1, -3, -2), result = exists(array, x -> x > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | false | + +-----------+ + + +### `FILTER` + +**Description** + +`filter(array, lambda)` Filters the input array using the given lambda function. + +**Argument type:** ARRAY, LAMBDA + +**Return type:** ARRAY + +An ARRAY that contains all elements in the input array that satisfy the lambda predicate. + +Example: + + os> source=people | eval array = json_array(1, -1, 2), result = filter(array, x -> x > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | [1, 2] | + +-----------+ + + os> source=people | eval array = json_array(-1, -3, -2), result = filter(array, x -> x > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | [] | + +-----------+ + +### `TRANSFORM` + +**Description** + +`transform(array, lambda)` Transform elements in an array using the lambda transform function. The second argument implies the index of the element if using binary lambda function. This is similar to a `map` in functional programming. + +**Argument type:** ARRAY, LAMBDA + +**Return type:** ARRAY + +An ARRAY that contains the result of applying the lambda transform function to each element in the input array. + +Example: + + os> source=people | eval array = json_array(1, 2, 3), result = transform(array, x -> x + 1) | fields result + fetched rows / total rows = 1/1 + +--------------+ + | result | + +--------------+ + | [2, 3, 4] | + +--------------+ + + os> source=people | eval array = json_array(1, 2, 3), result = transform(array, (x, i) -> x + i) | fields result + fetched rows / total rows = 1/1 + +--------------+ + | result | + +--------------+ + | [1, 3, 5] | + +--------------+ + +### `REDUCE` + +**Description** + +`reduce(array, start, merge_lambda, finish_lambda)` Applies a binary merge lambda function to a start value and all elements in the array, and reduces this to a single state. The final state is converted into the final result by applying a finish lambda function. + +**Argument type:** ARRAY, ANY, LAMBDA, LAMBDA + +**Return type:** ANY + +The final result of applying the lambda functions to the start value and the input array. 
+ +Example: + + os> source=people | eval array = json_array(1, 2, 3), result = reduce(array, 0, (acc, x) -> acc + x) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | 6 | + +-----------+ + + os> source=people | eval array = json_array(1, 2, 3), result = reduce(array, 10, (acc, x) -> acc + x) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | 16 | + +-----------+ + + os> source=people | eval array = json_array(1, 2, 3), result = reduce(array, 0, (acc, x) -> acc + x, acc -> acc * 10) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | 60 | + +-----------+ diff --git a/docs/ppl-lang/ppl-flatten-command.md b/docs/ppl-lang/ppl-flatten-command.md new file mode 100644 index 000000000..4c1ae5d0d --- /dev/null +++ b/docs/ppl-lang/ppl-flatten-command.md @@ -0,0 +1,90 @@ +## PPL `flatten` command + +### Description +Using `flatten` command to flatten a field of type: +- `struct` +- `array>` + + +### Syntax +`flatten ` + +* field: to be flattened. The field must be of supported type. + +### Test table +#### Schema +| col\_name | data\_type | +|-----------|-------------------------------------------------| +| \_time | string | +| bridges | array\\> | +| city | string | +| coor | struct\ | +| country | string | +#### Data +| \_time | bridges | city | coor | country | +|---------------------|----------------------------------------------|---------|------------------------|---------------| +| 2024-09-13T12:00:00 | [{801, Tower Bridge}, {928, London Bridge}] | London | {35, 51.5074, -0.1278} | England | +| 2024-09-13T12:00:00 | [{232, Pont Neuf}, {160, Pont Alexandre III}]| Paris | {35, 48.8566, 2.3522} | France | +| 2024-09-13T12:00:00 | [{48, Rialto Bridge}, {11, Bridge of Sighs}] | Venice | {2, 45.4408, 12.3155} | Italy | +| 2024-09-13T12:00:00 | [{516, Charles Bridge}, {343, Legion Bridge}]| Prague | {200, 50.0755, 14.4378}| Czech Republic| +| 2024-09-13T12:00:00 | [{375, Chain Bridge}, {333, Liberty Bridge}] | Budapest| {96, 47.4979, 19.0402} | Hungary | +| 1990-09-13T12:00:00 | NULL | Warsaw | NULL | Poland | + + + +### Example 1: flatten struct +This example shows how to flatten a struct field. +PPL query: + - `source=table | flatten coor` + +| \_time | bridges | city | country | alt | lat | long | +|---------------------|----------------------------------------------|---------|---------------|-----|--------|--------| +| 2024-09-13T12:00:00 | [{801, Tower Bridge}, {928, London Bridge}] | London | England | 35 | 51.5074| -0.1278| +| 2024-09-13T12:00:00 | [{232, Pont Neuf}, {160, Pont Alexandre III}]| Paris | France | 35 | 48.8566| 2.3522 | +| 2024-09-13T12:00:00 | [{48, Rialto Bridge}, {11, Bridge of Sighs}] | Venice | Italy | 2 | 45.4408| 12.3155| +| 2024-09-13T12:00:00 | [{516, Charles Bridge}, {343, Legion Bridge}]| Prague | Czech Republic| 200 | 50.0755| 14.4378| +| 2024-09-13T12:00:00 | [{375, Chain Bridge}, {333, Liberty Bridge}] | Budapest| Hungary | 96 | 47.4979| 19.0402| +| 1990-09-13T12:00:00 | NULL | Warsaw | Poland | NULL| NULL | NULL | + + + +### Example 2: flatten array + +The example shows how to flatten an array of struct fields. 
+ +PPL query: + - `source=table | flatten bridges` + +| \_time | city | coor | country | length | name | +|---------------------|---------|------------------------|---------------|--------|-------------------| +| 2024-09-13T12:00:00 | London | {35, 51.5074, -0.1278} | England | 801 | Tower Bridge | +| 2024-09-13T12:00:00 | London | {35, 51.5074, -0.1278} | England | 928 | London Bridge | +| 2024-09-13T12:00:00 | Paris | {35, 48.8566, 2.3522} | France | 232 | Pont Neuf | +| 2024-09-13T12:00:00 | Paris | {35, 48.8566, 2.3522} | France | 160 | Pont Alexandre III| +| 2024-09-13T12:00:00 | Venice | {2, 45.4408, 12.3155} | Italy | 48 | Rialto Bridge | +| 2024-09-13T12:00:00 | Venice | {2, 45.4408, 12.3155} | Italy | 11 | Bridge of Sighs | +| 2024-09-13T12:00:00 | Prague | {200, 50.0755, 14.4378}| Czech Republic| 516 | Charles Bridge | +| 2024-09-13T12:00:00 | Prague | {200, 50.0755, 14.4378}| Czech Republic| 343 | Legion Bridge | +| 2024-09-13T12:00:00 | Budapest| {96, 47.4979, 19.0402} | Hungary | 375 | Chain Bridge | +| 2024-09-13T12:00:00 | Budapest| {96, 47.4979, 19.0402} | Hungary | 333 | Liberty Bridge | +| 1990-09-13T12:00:00 | Warsaw | NULL | Poland | NULL | NULL | + + +### Example 3: flatten array and struct +This example shows how to flatten multiple fields. +PPL query: + - `source=table | flatten bridges | flatten coor` + +| \_time | city | country | length | name | alt | lat | long | +|---------------------|---------|---------------|--------|-------------------|------|--------|--------| +| 2024-09-13T12:00:00 | London | England | 801 | Tower Bridge | 35 | 51.5074| -0.1278| +| 2024-09-13T12:00:00 | London | England | 928 | London Bridge | 35 | 51.5074| -0.1278| +| 2024-09-13T12:00:00 | Paris | France | 232 | Pont Neuf | 35 | 48.8566| 2.3522 | +| 2024-09-13T12:00:00 | Paris | France | 160 | Pont Alexandre III| 35 | 48.8566| 2.3522 | +| 2024-09-13T12:00:00 | Venice | Italy | 48 | Rialto Bridge | 2 | 45.4408| 12.3155| +| 2024-09-13T12:00:00 | Venice | Italy | 11 | Bridge of Sighs | 2 | 45.4408| 12.3155| +| 2024-09-13T12:00:00 | Prague | Czech Republic| 516 | Charles Bridge | 200 | 50.0755| 14.4378| +| 2024-09-13T12:00:00 | Prague | Czech Republic| 343 | Legion Bridge | 200 | 50.0755| 14.4378| +| 2024-09-13T12:00:00 | Budapest| Hungary | 375 | Chain Bridge | 96 | 47.4979| 19.0402| +| 2024-09-13T12:00:00 | Budapest| Hungary | 333 | Liberty Bridge | 96 | 47.4979| 19.0402| +| 1990-09-13T12:00:00 | Warsaw | Poland | NULL | NULL | NULL | NULL | NULL | \ No newline at end of file diff --git a/docs/ppl-lang/ppl-join-command.md b/docs/ppl-lang/ppl-join-command.md index 525373f7c..b374bce5f 100644 --- a/docs/ppl-lang/ppl-join-command.md +++ b/docs/ppl-lang/ppl-join-command.md @@ -65,8 +65,8 @@ WHERE t1.serviceName = `order` SEARCH source= | | [joinType] JOIN - leftAlias - rightAlias + [leftAlias] + [rightAlias] [joinHints] ON joinCriteria @@ -79,12 +79,12 @@ SEARCH source= **leftAlias** - Syntax: `left = ` -- Required +- Optional - Description: The subquery alias to use with the left join side, to avoid ambiguous naming. **rightAlias** - Syntax: `right = ` -- Required +- Optional - Description: The subquery alias to use with the right join side, to avoid ambiguous naming. 
**joinHints** @@ -138,11 +138,11 @@ Rewritten by PPL Join query: ```sql SEARCH source=customer | FIELDS c_custkey -| LEFT OUTER JOIN left = c, right = o - ON c.c_custkey = o.o_custkey AND o_comment NOT LIKE '%unusual%packages%' +| LEFT OUTER JOIN + ON c_custkey = o_custkey AND o_comment NOT LIKE '%unusual%packages%' orders -| STATS count(o_orderkey) AS c_count BY c.c_custkey -| STATS count(1) AS custdist BY c_count +| STATS count(o_orderkey) AS c_count BY c_custkey +| STATS count() AS custdist BY c_count | SORT - custdist, - c_count ``` _- **Limitation: sub-searches is unsupported in join right side**_ @@ -151,14 +151,15 @@ If sub-searches is supported, above ppl query could be rewritten as: ```sql SEARCH source=customer | FIELDS c_custkey -| LEFT OUTER JOIN left = c, right = o ON c.c_custkey = o.o_custkey +| LEFT OUTER JOIN + ON c_custkey = o_custkey [ SEARCH source=orders | WHERE o_comment NOT LIKE '%unusual%packages%' | FIELDS o_orderkey, o_custkey ] -| STATS count(o_orderkey) AS c_count BY c.c_custkey -| STATS count(1) AS custdist BY c_count +| STATS count(o_orderkey) AS c_count BY c_custkey +| STATS count() AS custdist BY c_count | SORT - custdist, - c_count ``` diff --git a/docs/ppl-lang/ppl-trendline-command.md b/docs/ppl-lang/ppl-trendline-command.md new file mode 100644 index 000000000..393a9dd59 --- /dev/null +++ b/docs/ppl-lang/ppl-trendline-command.md @@ -0,0 +1,60 @@ +## PPL trendline Command + +**Description** +Using ``trendline`` command to calculate moving averages of fields. + + +### Syntax +`TRENDLINE [sort <[+|-] sort-field>] SMA(number-of-datapoints, field) [AS alias] [SMA(number-of-datapoints, field) [AS alias]]...` + +* [+|-]: optional. The plus [+] stands for ascending order and NULL/MISSING first and a minus [-] stands for descending order and NULL/MISSING last. **Default:** ascending order and NULL/MISSING first. +* sort-field: mandatory when sorting is used. The field used to sort. +* number-of-datapoints: mandatory. number of datapoints to calculate the moving average (must be greater than zero). +* field: mandatory. the name of the field the moving average should be calculated for. +* alias: optional. the name of the resulting column containing the moving average. + +And the moment only the Simple Moving Average (SMA) type is supported. + +It is calculated like + + f[i]: The value of field 'f' in the i-th data-point + n: The number of data-points in the moving window (period) + t: The current time index + + SMA(t) = (1/n) * Σ(f[i]), where i = t-n+1 to t + +### Example 1: Calculate simple moving average for a timeseries of temperatures + +The example calculates the simple moving average over temperatures using two datapoints. + +PPL query: + + os> source=t | trendline sma(2, temperature) as temp_trend; + fetched rows / total rows = 5/5 + +-----------+---------+--------------------+----------+ + |temperature|device-id| timestamp|temp_trend| + +-----------+---------+--------------------+----------+ + | 12| 1492|2023-04-06 17:07:...| NULL| + | 12| 1492|2023-04-06 17:07:...| 12.0| + | 13| 256|2023-04-06 17:07:...| 12.5| + | 14| 257|2023-04-06 17:07:...| 13.5| + | 15| 258|2023-04-06 17:07:...| 14.5| + +-----------+---------+--------------------+----------+ + +### Example 2: Calculate simple moving averages for a timeseries of temperatures with sorting + +The example calculates two simple moving average over temperatures using two and three datapoints sorted descending by device-id. 
+ +PPL query: + + os> source=t | trendline sort - device-id sma(2, temperature) as temp_trend_2 sma(3, temperature) as temp_trend_3; + fetched rows / total rows = 5/5 + +-----------+---------+--------------------+------------+------------------+ + |temperature|device-id| timestamp|temp_trend_2| temp_trend_3| + +-----------+---------+--------------------+------------+------------------+ + | 15| 258|2023-04-06 17:07:...| NULL| NULL| + | 14| 257|2023-04-06 17:07:...| 14.5| NULL| + | 13| 256|2023-04-06 17:07:...| 13.5| 14.0| + | 12| 1492|2023-04-06 17:07:...| 12.5| 13.0| + | 12| 1492|2023-04-06 17:07:...| 12.0|12.333333333333334| + +-----------+---------+--------------------+------------+------------------+ diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala index b4412a3d4..68d2409ee 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala @@ -60,7 +60,7 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w private val flintMetadataCacheWriter = FlintMetadataCacheWriterBuilder.build(flintSparkConf) private val flintAsyncQueryScheduler: AsyncQueryScheduler = { - AsyncQuerySchedulerBuilder.build(flintSparkConf.flintOptions()) + AsyncQuerySchedulerBuilder.build(spark, flintSparkConf.flintOptions()) } override protected val flintMetadataLogService: FlintMetadataLogService = { @@ -183,7 +183,7 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w attachLatestLogEntry(indexName, metadata) } .toList - .flatMap(FlintSparkIndexFactory.create) + .flatMap(metadata => FlintSparkIndexFactory.create(spark, metadata)) } else { Seq.empty } @@ -202,7 +202,7 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w if (flintClient.exists(indexName)) { val metadata = flintIndexMetadataService.getIndexMetadata(indexName) val metadataWithEntry = attachLatestLogEntry(indexName, metadata) - FlintSparkIndexFactory.create(metadataWithEntry) + FlintSparkIndexFactory.create(spark, metadataWithEntry) } else { Option.empty } @@ -327,7 +327,7 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w val index = describeIndex(indexName) if (index.exists(_.options.autoRefresh())) { - val updatedIndex = FlintSparkIndexFactory.createWithDefaultOptions(index.get).get + val updatedIndex = FlintSparkIndexFactory.createWithDefaultOptions(spark, index.get).get FlintSparkIndexRefresh .create(updatedIndex.name(), updatedIndex) .validate(spark) diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala index 0391741cf..2ff2883a9 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala @@ -92,7 +92,7 @@ abstract class FlintSparkIndexBuilder(flint: FlintSpark) { val updatedMetadata = index .metadata() .copy(options = updatedOptions.options.mapValues(_.asInstanceOf[AnyRef]).asJava) - validateIndex(FlintSparkIndexFactory.create(updatedMetadata).get) + validateIndex(FlintSparkIndexFactory.create(flint.spark, updatedMetadata).get) } /** diff --git 
a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala index 78636d992..ca659550d 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala @@ -25,6 +25,7 @@ import org.opensearch.flint.spark.skipping.partition.PartitionSkippingStrategy import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession /** * Flint Spark index factory that encapsulates specific Flint index instance creation. This is for @@ -35,14 +36,16 @@ object FlintSparkIndexFactory extends Logging { /** * Creates Flint index from generic Flint metadata. * + * @param spark + * Spark session * @param metadata * Flint metadata * @return * Flint index instance, or None if any error during creation */ - def create(metadata: FlintMetadata): Option[FlintSparkIndex] = { + def create(spark: SparkSession, metadata: FlintMetadata): Option[FlintSparkIndex] = { try { - Some(doCreate(metadata)) + Some(doCreate(spark, metadata)) } catch { case e: Exception => logWarning(s"Failed to create Flint index from metadata $metadata", e) @@ -53,24 +56,26 @@ object FlintSparkIndexFactory extends Logging { /** * Creates Flint index with default options. * + * @param spark + * Spark session * @param index * Flint index - * @param metadata - * Flint metadata * @return * Flint index with default options */ - def createWithDefaultOptions(index: FlintSparkIndex): Option[FlintSparkIndex] = { + def createWithDefaultOptions( + spark: SparkSession, + index: FlintSparkIndex): Option[FlintSparkIndex] = { val originalOptions = index.options val updatedOptions = FlintSparkIndexOptions.updateOptionsWithDefaults(index.name(), originalOptions) val updatedMetadata = index .metadata() .copy(options = updatedOptions.options.mapValues(_.asInstanceOf[AnyRef]).asJava) - this.create(updatedMetadata) + this.create(spark, updatedMetadata) } - private def doCreate(metadata: FlintMetadata): FlintSparkIndex = { + private def doCreate(spark: SparkSession, metadata: FlintMetadata): FlintSparkIndex = { val indexOptions = FlintSparkIndexOptions( metadata.options.asScala.mapValues(_.asInstanceOf[String]).toMap) val latestLogEntry = metadata.latestLogEntry @@ -118,6 +123,7 @@ object FlintSparkIndexFactory extends Logging { FlintSparkMaterializedView( metadata.name, metadata.source, + getMvSourceTables(spark, metadata), metadata.indexedColumns.map { colInfo => getString(colInfo, "columnName") -> getString(colInfo, "columnType") }.toMap, @@ -134,6 +140,15 @@ object FlintSparkIndexFactory extends Logging { .toMap } + private def getMvSourceTables(spark: SparkSession, metadata: FlintMetadata): Array[String] = { + val sourceTables = getArrayString(metadata.properties, "sourceTables") + if (sourceTables.isEmpty) { + FlintSparkMaterializedView.extractSourceTableNames(spark, metadata.source) + } else { + sourceTables + } + } + private def getString(map: java.util.Map[String, AnyRef], key: String): String = { map.get(key).asInstanceOf[String] } @@ -146,4 +161,12 @@ object FlintSparkIndexFactory extends Logging { Some(value.asInstanceOf[String]) } } + + private def getArrayString(map: java.util.Map[String, AnyRef], key: String): Array[String] = { + map.get(key) match { + case list: java.util.ArrayList[_] => + 
list.toArray.map(_.toString) + case _ => Array.empty[String] + } + } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkValidationHelper.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkValidationHelper.scala index 1aaa85075..7e9922655 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkValidationHelper.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkValidationHelper.scala @@ -11,9 +11,8 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.command.DDLUtils -import org.apache.spark.sql.flint.{loadTable, parseTableName, qualifyTableName} +import org.apache.spark.sql.flint.{loadTable, parseTableName} /** * Flint Spark validation helper. @@ -31,16 +30,10 @@ trait FlintSparkValidationHelper extends Logging { * true if all non Hive, otherwise false */ def isTableProviderSupported(spark: SparkSession, index: FlintSparkIndex): Boolean = { - // Extract source table name (possibly more than one for MV query) val tableNames = index match { case skipping: FlintSparkSkippingIndex => Seq(skipping.tableName) case covering: FlintSparkCoveringIndex => Seq(covering.tableName) - case mv: FlintSparkMaterializedView => - spark.sessionState.sqlParser - .parsePlan(mv.query) - .collect { case relation: UnresolvedRelation => - qualifyTableName(spark, relation.tableName) - } + case mv: FlintSparkMaterializedView => mv.sourceTables.toSeq } // Validate if any source table is not supported (currently Hive only) diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCache.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCache.scala index e8a91e1be..e1c0f318c 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCache.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCache.scala @@ -10,6 +10,7 @@ import scala.collection.JavaConverters.mapAsScalaMapConverter import org.opensearch.flint.common.metadata.FlintMetadata import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry import org.opensearch.flint.spark.FlintSparkIndexOptions +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE import org.opensearch.flint.spark.scheduler.util.IntervalSchedulerParser /** @@ -46,9 +47,7 @@ case class FlintMetadataCache( object FlintMetadataCache { - // TODO: constant for version - val mockTableName = - "dataSourceName.default.logGroups(logGroupIdentifier:['arn:aws:logs:us-east-1:123456:test-llt-xa', 'arn:aws:logs:us-east-1:123456:sample-lg-1'])" + val metadataCacheVersion = "1.0" def apply(metadata: FlintMetadata): FlintMetadataCache = { val indexOptions = FlintSparkIndexOptions( @@ -61,6 +60,15 @@ object FlintMetadataCache { } else { None } + val sourceTables = metadata.kind match { + case MV_INDEX_TYPE => + metadata.properties.get("sourceTables") match { + case list: java.util.ArrayList[_] => + list.toArray.map(_.toString) + case _ => Array.empty[String] + } + case _ => Array(metadata.source) + } val lastRefreshTime: Option[Long] = metadata.latestLogEntry.flatMap { entry => entry.lastRefreshCompleteTime match { case 
FlintMetadataLogEntry.EMPTY_TIMESTAMP => None @@ -68,7 +76,6 @@ object FlintMetadataCache { } } - // TODO: get source tables from metadata - FlintMetadataCache("1.0", refreshInterval, Array(mockTableName), lastRefreshTime) + FlintMetadataCache(metadataCacheVersion, refreshInterval, sourceTables, lastRefreshTime) } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala index caa75be75..aecfc99df 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala @@ -34,6 +34,8 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap * MV name * @param query * source query that generates MV data + * @param sourceTables + * source table names * @param outputSchema * output schema * @param options @@ -44,6 +46,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap case class FlintSparkMaterializedView( mvName: String, query: String, + sourceTables: Array[String], outputSchema: Map[String, String], override val options: FlintSparkIndexOptions = empty, override val latestLogEntry: Option[FlintMetadataLogEntry] = None) @@ -64,6 +67,7 @@ case class FlintSparkMaterializedView( metadataBuilder(this) .name(mvName) .source(query) + .addProperty("sourceTables", sourceTables) .indexedColumns(indexColumnMaps) .schema(schema) .build() @@ -165,10 +169,30 @@ object FlintSparkMaterializedView { flintIndexNamePrefix(mvName) } + /** + * Extract source table names (possibly more than one) from the query. + * + * @param spark + * Spark session + * @param query + * source query that generates MV data + * @return + * source table names + */ + def extractSourceTableNames(spark: SparkSession, query: String): Array[String] = { + spark.sessionState.sqlParser + .parsePlan(query) + .collect { case relation: UnresolvedRelation => + qualifyTableName(spark, relation.tableName) + } + .toArray + } + /** Builder class for MV build */ class Builder(flint: FlintSpark) extends FlintSparkIndexBuilder(flint) { private var mvName: String = "" private var query: String = "" + private var sourceTables: Array[String] = Array.empty[String] /** * Set MV name. 
@@ -193,6 +217,7 @@ object FlintSparkMaterializedView { */ def query(query: String): Builder = { this.query = query + this.sourceTables = extractSourceTableNames(flint.spark, query) this } @@ -221,7 +246,7 @@ object FlintSparkMaterializedView { field.name -> field.dataType.simpleString } .toMap - FlintSparkMaterializedView(mvName, query, outputSchema, indexOptions) + FlintSparkMaterializedView(mvName, query, sourceTables, outputSchema, indexOptions) } } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/AsyncQuerySchedulerBuilder.java b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/AsyncQuerySchedulerBuilder.java index 3620608b0..330b38f02 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/AsyncQuerySchedulerBuilder.java +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/AsyncQuerySchedulerBuilder.java @@ -7,9 +7,12 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.flint.config.FlintSparkConf; import org.opensearch.flint.common.scheduler.AsyncQueryScheduler; import org.opensearch.flint.core.FlintOptions; +import java.io.IOException; import java.lang.reflect.Constructor; /** @@ -28,11 +31,27 @@ public enum AsyncQuerySchedulerAction { REMOVE } - public static AsyncQueryScheduler build(FlintOptions options) { + public static AsyncQueryScheduler build(SparkSession sparkSession, FlintOptions options) throws IOException { + return new AsyncQuerySchedulerBuilder().doBuild(sparkSession, options); + } + + /** + * Builds an AsyncQueryScheduler based on the provided options. + * + * @param sparkSession The SparkSession to be used. + * @param options The FlintOptions containing configuration details. + * @return An instance of AsyncQueryScheduler. + */ + protected AsyncQueryScheduler doBuild(SparkSession sparkSession, FlintOptions options) throws IOException { String className = options.getCustomAsyncQuerySchedulerClass(); if (className.isEmpty()) { - return new OpenSearchAsyncQueryScheduler(options); + OpenSearchAsyncQueryScheduler scheduler = createOpenSearchAsyncQueryScheduler(options); + // Check if the scheduler has access to the required index. Disable the external scheduler otherwise. 
+ if (!hasAccessToSchedulerIndex(scheduler)){ + setExternalSchedulerEnabled(sparkSession, false); + } + return scheduler; } // Attempts to instantiate AsyncQueryScheduler using reflection @@ -45,4 +64,16 @@ public static AsyncQueryScheduler build(FlintOptions options) { throw new RuntimeException("Failed to instantiate AsyncQueryScheduler: " + className, e); } } + + protected OpenSearchAsyncQueryScheduler createOpenSearchAsyncQueryScheduler(FlintOptions options) { + return new OpenSearchAsyncQueryScheduler(options); + } + + protected boolean hasAccessToSchedulerIndex(OpenSearchAsyncQueryScheduler scheduler) throws IOException { + return scheduler.hasAccessToSchedulerIndex(); + } + + protected void setExternalSchedulerEnabled(SparkSession sparkSession, boolean enabled) { + sparkSession.sqlContext().setConf(FlintSparkConf.EXTERNAL_SCHEDULER_ENABLED().key(), String.valueOf(enabled)); + } } \ No newline at end of file diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/OpenSearchAsyncQueryScheduler.java b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/OpenSearchAsyncQueryScheduler.java index 19532254b..a1ef45825 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/OpenSearchAsyncQueryScheduler.java +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/OpenSearchAsyncQueryScheduler.java @@ -9,6 +9,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ObjectNode; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMap; import org.apache.commons.io.IOUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -37,6 +38,7 @@ import org.opensearch.jobscheduler.spi.schedule.Schedule; import org.opensearch.rest.RestStatus; +import java.io.IOException; import java.io.InputStream; import java.nio.charset.StandardCharsets; import java.time.Instant; @@ -55,6 +57,11 @@ public class OpenSearchAsyncQueryScheduler implements AsyncQueryScheduler { private static final ObjectMapper mapper = new ObjectMapper(); private final FlintOptions flintOptions; + @VisibleForTesting + public OpenSearchAsyncQueryScheduler() { + this.flintOptions = new FlintOptions(ImmutableMap.of()); + } + public OpenSearchAsyncQueryScheduler(FlintOptions options) { this.flintOptions = options; } @@ -124,6 +131,28 @@ void createAsyncQuerySchedulerIndex(IRestHighLevelClient client) { } } + /** + * Checks if the current setup has access to the scheduler index. + * + * This method attempts to create a client and ensure that the scheduler index exists. + * If these operations succeed, it indicates that the user has the necessary permissions + * to access and potentially modify the scheduler index. 
+ * + * @see #createClient() + * @see #ensureIndexExists(IRestHighLevelClient) + */ + public boolean hasAccessToSchedulerIndex() throws IOException { + IRestHighLevelClient client = createClient(); + try { + ensureIndexExists(client); + return true; + } catch (Throwable e) { + LOG.error("Failed to ensure index exists", e); + return false; + } finally { + client.close(); + } + } private void ensureIndexExists(IRestHighLevelClient client) { try { if (!client.doesIndexExist(new GetIndexRequest(SCHEDULER_INDEX_NAME), RequestOptions.DEFAULT)) { diff --git a/flint-spark-integration/src/test/java/org/opensearch/flint/core/scheduler/AsyncQuerySchedulerBuilderTest.java b/flint-spark-integration/src/test/java/org/opensearch/flint/core/scheduler/AsyncQuerySchedulerBuilderTest.java index 67b5afee5..3c65a96a5 100644 --- a/flint-spark-integration/src/test/java/org/opensearch/flint/core/scheduler/AsyncQuerySchedulerBuilderTest.java +++ b/flint-spark-integration/src/test/java/org/opensearch/flint/core/scheduler/AsyncQuerySchedulerBuilderTest.java @@ -5,43 +5,80 @@ package org.opensearch.flint.core.scheduler; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.SQLContext; +import org.junit.Before; import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import org.opensearch.flint.common.scheduler.AsyncQueryScheduler; import org.opensearch.flint.common.scheduler.model.AsyncQuerySchedulerRequest; import org.opensearch.flint.core.FlintOptions; import org.opensearch.flint.spark.scheduler.AsyncQuerySchedulerBuilder; import org.opensearch.flint.spark.scheduler.OpenSearchAsyncQueryScheduler; +import java.io.IOException; + import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class AsyncQuerySchedulerBuilderTest { + @Mock + private SparkSession sparkSession; + + @Mock + private SQLContext sqlContext; + + private AsyncQuerySchedulerBuilderForLocalTest testBuilder; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + when(sparkSession.sqlContext()).thenReturn(sqlContext); + } + + @Test + public void testBuildWithEmptyClassNameAndAccessibleIndex() throws IOException { + FlintOptions options = mock(FlintOptions.class); + when(options.getCustomAsyncQuerySchedulerClass()).thenReturn(""); + OpenSearchAsyncQueryScheduler mockScheduler = mock(OpenSearchAsyncQueryScheduler.class); + + AsyncQueryScheduler scheduler = testBuilder.build(mockScheduler, true, sparkSession, options); + assertTrue(scheduler instanceof OpenSearchAsyncQueryScheduler); + verify(sqlContext, never()).setConf(anyString(), anyString()); + } @Test - public void testBuildWithEmptyClassName() { + public void testBuildWithEmptyClassNameAndInaccessibleIndex() throws IOException { FlintOptions options = mock(FlintOptions.class); when(options.getCustomAsyncQuerySchedulerClass()).thenReturn(""); + OpenSearchAsyncQueryScheduler mockScheduler = mock(OpenSearchAsyncQueryScheduler.class); - AsyncQueryScheduler scheduler = AsyncQuerySchedulerBuilder.build(options); + AsyncQueryScheduler scheduler = testBuilder.build(mockScheduler, false, sparkSession, options); assertTrue(scheduler instanceof OpenSearchAsyncQueryScheduler); + verify(sqlContext).setConf("spark.flint.job.externalScheduler.enabled", "false"); } @Test - public void testBuildWithCustomClassName() { + public void 
testBuildWithCustomClassName() throws IOException { FlintOptions options = mock(FlintOptions.class); - when(options.getCustomAsyncQuerySchedulerClass()).thenReturn("org.opensearch.flint.core.scheduler.AsyncQuerySchedulerBuilderTest$AsyncQuerySchedulerForLocalTest"); + when(options.getCustomAsyncQuerySchedulerClass()) + .thenReturn("org.opensearch.flint.core.scheduler.AsyncQuerySchedulerBuilderTest$AsyncQuerySchedulerForLocalTest"); - AsyncQueryScheduler scheduler = AsyncQuerySchedulerBuilder.build(options); + AsyncQueryScheduler scheduler = AsyncQuerySchedulerBuilder.build(sparkSession, options); assertTrue(scheduler instanceof AsyncQuerySchedulerForLocalTest); } @Test(expected = RuntimeException.class) - public void testBuildWithInvalidClassName() { + public void testBuildWithInvalidClassName() throws IOException { FlintOptions options = mock(FlintOptions.class); when(options.getCustomAsyncQuerySchedulerClass()).thenReturn("invalid.ClassName"); - AsyncQuerySchedulerBuilder.build(options); + AsyncQuerySchedulerBuilder.build(sparkSession, options); } public static class AsyncQuerySchedulerForLocalTest implements AsyncQueryScheduler { @@ -65,4 +102,35 @@ public void removeJob(AsyncQuerySchedulerRequest asyncQuerySchedulerRequest) { // Custom implementation } } + + public static class OpenSearchAsyncQuerySchedulerForLocalTest extends OpenSearchAsyncQueryScheduler { + @Override + public boolean hasAccessToSchedulerIndex() { + return true; + } + } + + public static class AsyncQuerySchedulerBuilderForLocalTest extends AsyncQuerySchedulerBuilder { + private OpenSearchAsyncQueryScheduler mockScheduler; + private Boolean mockHasAccess; + + public AsyncQuerySchedulerBuilderForLocalTest(OpenSearchAsyncQueryScheduler mockScheduler, Boolean mockHasAccess) { + this.mockScheduler = mockScheduler; + this.mockHasAccess = mockHasAccess; + } + + @Override + protected OpenSearchAsyncQueryScheduler createOpenSearchAsyncQueryScheduler(FlintOptions options) { + return mockScheduler != null ? mockScheduler : super.createOpenSearchAsyncQueryScheduler(options); + } + + @Override + protected boolean hasAccessToSchedulerIndex(OpenSearchAsyncQueryScheduler scheduler) throws IOException { + return mockHasAccess != null ? 
mockHasAccess : super.hasAccessToSchedulerIndex(scheduler); + } + + public static AsyncQueryScheduler build(OpenSearchAsyncQueryScheduler asyncQueryScheduler, Boolean hasAccess, SparkSession sparkSession, FlintOptions options) throws IOException { + return new AsyncQuerySchedulerBuilderForLocalTest(asyncQueryScheduler, hasAccess).doBuild(sparkSession, options); + } + } } \ No newline at end of file diff --git a/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala b/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala index e43b0c52c..b675265b7 100644 --- a/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala +++ b/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala @@ -9,7 +9,7 @@ import org.opensearch.flint.spark.FlintSparkExtensions import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation -import org.apache.spark.sql.flint.config.FlintConfigEntry +import org.apache.spark.sql.flint.config.{FlintConfigEntry, FlintSparkConf} import org.apache.spark.sql.flint.config.FlintSparkConf.{EXTERNAL_SCHEDULER_ENABLED, HYBRID_SCAN_ENABLED, METADATA_CACHE_WRITE} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -26,6 +26,10 @@ trait FlintSuite extends SharedSparkSession { // ConstantPropagation etc. .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName) .set("spark.sql.extensions", classOf[FlintSparkExtensions].getName) + // Override scheduler class for unit testing + .set( + FlintSparkConf.CUSTOM_FLINT_SCHEDULER_CLASS.key, + "org.opensearch.flint.core.scheduler.AsyncQuerySchedulerBuilderTest$AsyncQuerySchedulerForLocalTest") conf } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexFactorySuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexFactorySuite.scala new file mode 100644 index 000000000..07720ff24 --- /dev/null +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexFactorySuite.scala @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark + +import org.opensearch.flint.core.storage.FlintOpenSearchIndexMetadataService +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE +import org.scalatest.matchers.should.Matchers._ + +import org.apache.spark.FlintSuite + +class FlintSparkIndexFactorySuite extends FlintSuite { + + test("create mv should generate source tables if missing in metadata") { + val testTable = "spark_catalog.default.mv_build_test" + val testMvName = "spark_catalog.default.mv" + val testQuery = s"SELECT * FROM $testTable" + + val content = + s""" { + | "_meta": { + | "kind": "$MV_INDEX_TYPE", + | "indexedColumns": [ + | { + | "columnType": "int", + | "columnName": "age" + | } + | ], + | "name": "$testMvName", + | "source": "$testQuery" + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin + + val metadata = FlintOpenSearchIndexMetadataService.deserialize(content) + val index = FlintSparkIndexFactory.create(spark, metadata) + index shouldBe defined + index.get + .asInstanceOf[FlintSparkMaterializedView] + .sourceTables should contain theSameElementsAs Array(testTable) + } +} diff --git 
a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCacheSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCacheSuite.scala index c6d2cf12a..6ec6cf696 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCacheSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCacheSuite.scala @@ -7,6 +7,9 @@ package org.opensearch.flint.spark.metadatacache import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry import org.opensearch.flint.core.storage.FlintOpenSearchIndexMetadataService +import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.COVERING_INDEX_TYPE +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE +import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.SKIPPING_INDEX_TYPE import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers @@ -21,11 +24,12 @@ class FlintMetadataCacheSuite extends AnyFlatSpec with Matchers { "", Map.empty[String, Any]) - it should "construct from FlintMetadata" in { + it should "construct from skipping index FlintMetadata" in { val content = - """ { + s""" { | "_meta": { - | "kind": "test_kind", + | "kind": "$SKIPPING_INDEX_TYPE", + | "source": "spark_catalog.default.test_table", | "options": { | "auto_refresh": "true", | "refresh_interval": "10 Minutes" @@ -43,18 +47,85 @@ class FlintMetadataCacheSuite extends AnyFlatSpec with Matchers { .copy(latestLogEntry = Some(flintMetadataLogEntry)) val metadataCache = FlintMetadataCache(metadata) - metadataCache.metadataCacheVersion shouldBe "1.0" + metadataCache.metadataCacheVersion shouldBe FlintMetadataCache.metadataCacheVersion metadataCache.refreshInterval.get shouldBe 600 - metadataCache.sourceTables shouldBe Array(FlintMetadataCache.mockTableName) + metadataCache.sourceTables shouldBe Array("spark_catalog.default.test_table") + metadataCache.lastRefreshTime.get shouldBe 1234567890123L + } + + it should "construct from covering index FlintMetadata" in { + val content = + s""" { + | "_meta": { + | "kind": "$COVERING_INDEX_TYPE", + | "source": "spark_catalog.default.test_table", + | "options": { + | "auto_refresh": "true", + | "refresh_interval": "10 Minutes" + | } + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin + val metadata = FlintOpenSearchIndexMetadataService + .deserialize(content) + .copy(latestLogEntry = Some(flintMetadataLogEntry)) + + val metadataCache = FlintMetadataCache(metadata) + metadataCache.metadataCacheVersion shouldBe FlintMetadataCache.metadataCacheVersion + metadataCache.refreshInterval.get shouldBe 600 + metadataCache.sourceTables shouldBe Array("spark_catalog.default.test_table") + metadataCache.lastRefreshTime.get shouldBe 1234567890123L + } + + it should "construct from materialized view FlintMetadata" in { + val content = + s""" { + | "_meta": { + | "kind": "$MV_INDEX_TYPE", + | "source": "spark_catalog.default.wrong_table", + | "options": { + | "auto_refresh": "true", + | "refresh_interval": "10 Minutes" + | }, + | "properties": { + | "sourceTables": [ + | "spark_catalog.default.test_table", + | "spark_catalog.default.another_table" + | ] + | } + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin + val metadata = FlintOpenSearchIndexMetadataService + .deserialize(content) + .copy(latestLogEntry = 
Some(flintMetadataLogEntry)) + + val metadataCache = FlintMetadataCache(metadata) + metadataCache.metadataCacheVersion shouldBe FlintMetadataCache.metadataCacheVersion + metadataCache.refreshInterval.get shouldBe 600 + metadataCache.sourceTables shouldBe Array( + "spark_catalog.default.test_table", + "spark_catalog.default.another_table") metadataCache.lastRefreshTime.get shouldBe 1234567890123L } it should "construct from FlintMetadata excluding invalid fields" in { // Set auto_refresh = false and lastRefreshCompleteTime = 0 val content = - """ { + s""" { | "_meta": { - | "kind": "test_kind", + | "kind": "$SKIPPING_INDEX_TYPE", + | "source": "spark_catalog.default.test_table", | "options": { | "refresh_interval": "10 Minutes" | } @@ -71,9 +142,9 @@ class FlintMetadataCacheSuite extends AnyFlatSpec with Matchers { .copy(latestLogEntry = Some(flintMetadataLogEntry.copy(lastRefreshCompleteTime = 0L))) val metadataCache = FlintMetadataCache(metadata) - metadataCache.metadataCacheVersion shouldBe "1.0" + metadataCache.metadataCacheVersion shouldBe FlintMetadataCache.metadataCacheVersion metadataCache.refreshInterval shouldBe empty - metadataCache.sourceTables shouldBe Array(FlintMetadataCache.mockTableName) + metadataCache.sourceTables shouldBe Array("spark_catalog.default.test_table") metadataCache.lastRefreshTime shouldBe empty } } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala index c1df42883..1c9a9e83c 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala @@ -10,7 +10,7 @@ import scala.collection.JavaConverters.{mapAsJavaMapConverter, mapAsScalaMapConv import org.opensearch.flint.spark.FlintSparkIndexOptions import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE import org.opensearch.flint.spark.mv.FlintSparkMaterializedViewSuite.{streamingRelation, StreamingDslLogicalPlan} -import org.scalatest.matchers.should.Matchers.{contain, convertToAnyShouldWrapper, the} +import org.scalatest.matchers.should.Matchers._ import org.scalatestplus.mockito.MockitoSugar.mock import org.apache.spark.FlintSuite @@ -37,31 +37,34 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { val testQuery = "SELECT 1" test("get mv name") { - val mv = FlintSparkMaterializedView(testMvName, testQuery, Map.empty) + val mv = FlintSparkMaterializedView(testMvName, testQuery, Array.empty, Map.empty) mv.name() shouldBe "flint_spark_catalog_default_mv" } test("get mv name with dots") { val testMvNameDots = "spark_catalog.default.mv.2023.10" - val mv = FlintSparkMaterializedView(testMvNameDots, testQuery, Map.empty) + val mv = FlintSparkMaterializedView(testMvNameDots, testQuery, Array.empty, Map.empty) mv.name() shouldBe "flint_spark_catalog_default_mv.2023.10" } test("should fail if get name with unqualified MV name") { the[IllegalArgumentException] thrownBy - FlintSparkMaterializedView("mv", testQuery, Map.empty).name() + FlintSparkMaterializedView("mv", testQuery, Array.empty, Map.empty).name() the[IllegalArgumentException] thrownBy - FlintSparkMaterializedView("default.mv", testQuery, Map.empty).name() + FlintSparkMaterializedView("default.mv", testQuery, Array.empty, Map.empty).name() } test("get metadata") { - val mv = 
FlintSparkMaterializedView(testMvName, testQuery, Map("test_col" -> "integer")) + val mv = + FlintSparkMaterializedView(testMvName, testQuery, Array.empty, Map("test_col" -> "integer")) val metadata = mv.metadata() metadata.name shouldBe mv.mvName metadata.kind shouldBe MV_INDEX_TYPE metadata.source shouldBe "SELECT 1" + metadata.properties should contain key "sourceTables" + metadata.properties.get("sourceTables").asInstanceOf[Array[String]] should have size 0 metadata.indexedColumns shouldBe Array( Map("columnName" -> "test_col", "columnType" -> "integer").asJava) metadata.schema shouldBe Map("test_col" -> Map("type" -> "integer").asJava).asJava @@ -74,6 +77,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { val mv = FlintSparkMaterializedView( testMvName, testQuery, + Array.empty, Map("test_col" -> "integer"), indexOptions) @@ -83,12 +87,12 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { } test("build batch data frame") { - val mv = FlintSparkMaterializedView(testMvName, testQuery, Map.empty) + val mv = FlintSparkMaterializedView(testMvName, testQuery, Array.empty, Map.empty) mv.build(spark, None).collect() shouldBe Array(Row(1)) } test("should fail if build given other source data frame") { - val mv = FlintSparkMaterializedView(testMvName, testQuery, Map.empty) + val mv = FlintSparkMaterializedView(testMvName, testQuery, Array.empty, Map.empty) the[IllegalArgumentException] thrownBy mv.build(spark, Some(mock[DataFrame])) } @@ -103,7 +107,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { |""".stripMargin val options = Map("watermark_delay" -> "30 Seconds") - withAggregateMaterializedView(testQuery, options) { actualPlan => + withAggregateMaterializedView(testQuery, Array(testTable), options) { actualPlan => comparePlans( actualPlan, streamingRelation(testTable) @@ -128,7 +132,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { |""".stripMargin val options = Map("watermark_delay" -> "30 Seconds") - withAggregateMaterializedView(testQuery, options) { actualPlan => + withAggregateMaterializedView(testQuery, Array(testTable), options) { actualPlan => comparePlans( actualPlan, streamingRelation(testTable) @@ -144,7 +148,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { test("build stream with non-aggregate query") { val testQuery = s"SELECT name, age FROM $testTable WHERE age > 30" - withAggregateMaterializedView(testQuery, Map.empty) { actualPlan => + withAggregateMaterializedView(testQuery, Array(testTable), Map.empty) { actualPlan => comparePlans( actualPlan, streamingRelation(testTable) @@ -158,7 +162,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { val testQuery = s"SELECT name, age FROM $testTable" val options = Map("extra_options" -> s"""{"$testTable": {"maxFilesPerTrigger": "1"}}""") - withAggregateMaterializedView(testQuery, options) { actualPlan => + withAggregateMaterializedView(testQuery, Array(testTable), options) { actualPlan => comparePlans( actualPlan, streamingRelation(testTable, Map("maxFilesPerTrigger" -> "1")) @@ -175,6 +179,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { val mv = FlintSparkMaterializedView( testMvName, s"SELECT name, COUNT(*) AS count FROM $testTable GROUP BY name", + Array(testTable), Map.empty) the[IllegalStateException] thrownBy @@ -182,14 +187,20 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { } } - private def withAggregateMaterializedView(query: String, options: Map[String, String])( - codeBlock: LogicalPlan => Unit): Unit = { + private 
def withAggregateMaterializedView( + query: String, + sourceTables: Array[String], + options: Map[String, String])(codeBlock: LogicalPlan => Unit): Unit = { withTable(testTable) { sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") - val mv = - FlintSparkMaterializedView(testMvName, query, Map.empty, FlintSparkIndexOptions(options)) + FlintSparkMaterializedView( + testMvName, + query, + sourceTables, + Map.empty, + FlintSparkIndexOptions(options)) val actualPlan = mv.buildStream(spark).queryExecution.logical codeBlock(actualPlan) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala index 14d41c2bb..fc77faaea 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala @@ -17,9 +17,10 @@ import org.opensearch.flint.common.FlintVersion.current import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.storage.{FlintOpenSearchIndexMetadataService, OpenSearchClientUtils} import org.opensearch.flint.spark.FlintSparkIndex.quotedTableName -import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.getFlintIndexName +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.{extractSourceTableNames, getFlintIndexName} import org.opensearch.flint.spark.scheduler.OpenSearchAsyncQueryScheduler -import org.scalatest.matchers.must.Matchers.defined +import org.scalatest.matchers.must.Matchers._ import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper import org.apache.spark.sql.{DataFrame, Row} @@ -51,6 +52,29 @@ class FlintSparkMaterializedViewITSuite extends FlintSparkSuite { deleteTestIndex(testFlintIndex) } + test("extract source table names from materialized view source query successfully") { + val testComplexQuery = s""" + | SELECT * + | FROM ( + | SELECT 1 + | FROM table1 + | LEFT JOIN `table2` + | ) + | UNION ALL + | SELECT 1 + | FROM spark_catalog.default.`table/3` + | INNER JOIN spark_catalog.default.`table.4` + |""".stripMargin + extractSourceTableNames(flint.spark, testComplexQuery) should contain theSameElementsAs + Array( + "spark_catalog.default.table1", + "spark_catalog.default.table2", + "spark_catalog.default.`table/3`", + "spark_catalog.default.`table.4`") + + extractSourceTableNames(flint.spark, "SELECT 1") should have size 0 + } + test("create materialized view with metadata successfully") { withTempDir { checkpointDir => val indexOptions = @@ -91,7 +115,9 @@ class FlintSparkMaterializedViewITSuite extends FlintSparkSuite { | "scheduler_mode":"internal" | }, | "latestId": "$testLatestId", - | "properties": {} + | "properties": { + | "sourceTables": ["$testTable"] + | } | }, | "properties": { | "startTime": { @@ -107,6 +133,22 @@ class FlintSparkMaterializedViewITSuite extends FlintSparkSuite { } } + test("create materialized view should parse source tables successfully") { + val indexOptions = FlintSparkIndexOptions(Map.empty) + flint + .materializedView() + .name(testMvName) + .query(testQuery) + .options(indexOptions, testFlintIndex) + .create() + + val index = flint.describeIndex(testFlintIndex) + index shouldBe defined + index.get + .asInstanceOf[FlintSparkMaterializedView] + .sourceTables should contain theSameElementsAs 
Array(testTable) + } + test("create materialized view with default checkpoint location successfully") { withTempDir { checkpointDir => setFlintSparkConf( diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala index c8c902294..c53eee548 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala @@ -5,7 +5,7 @@ package org.opensearch.flint.spark -import java.nio.file.{Files, Paths} +import java.nio.file.{Files, Path, Paths} import java.util.Comparator import java.util.concurrent.{ScheduledExecutorService, ScheduledFuture} @@ -23,6 +23,7 @@ import org.scalatestplus.mockito.MockitoSugar.mock import org.apache.spark.{FlintSuite, SparkConf} import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.flint.config.FlintSparkConf.{CHECKPOINT_MANDATORY, HOST_ENDPOINT, HOST_PORT, REFRESH_POLICY} import org.apache.spark.sql.streaming.StreamTest @@ -49,6 +50,8 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit override def beforeAll(): Unit = { super.beforeAll() + // Revoke override in FlintSuite on IT + conf.unsetConf(FlintSparkConf.CUSTOM_FLINT_SCHEDULER_CLASS.key) // Replace executor to avoid impact on IT. // TODO: Currently no IT test scheduler so no need to restore it back. @@ -534,6 +537,28 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit |""".stripMargin) } + protected def createMultiValueStructTable(testTable: String): Unit = { + // CSV doesn't support struct field + sql(s""" + | CREATE TABLE $testTable + | ( + | int_col INT, + | multi_value Array> + | ) + | USING JSON + |""".stripMargin) + + sql(s""" + | INSERT INTO $testTable + | SELECT /*+ COALESCE(1) */ * + | FROM VALUES + | ( 1, array(STRUCT("1_one", 1), STRUCT(null, 11), STRUCT("1_three", null)) ), + | ( 2, array(STRUCT("2_Monday", 2), null) ), + | ( 3, array(STRUCT("3_third", 3), STRUCT("3_4th", 4)) ), + | ( 4, null ) + |""".stripMargin) + } + protected def createTableIssue112(testTable: String): Unit = { sql(s""" | CREATE TABLE $testTable ( @@ -695,4 +720,100 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit | (9, '2001:db8::ff00:12:', true, false) | """.stripMargin) } + + protected def createNestedJsonContentTable(tempFile: Path, testTable: String): Unit = { + val json = + """ + |[ + | { + | "_time": "2024-09-13T12:00:00", + | "bridges": [ + | {"name": "Tower Bridge", "length": 801}, + | {"name": "London Bridge", "length": 928} + | ], + | "city": "London", + | "country": "England", + | "coor": { + | "lat": 51.5074, + | "long": -0.1278, + | "alt": 35 + | } + | }, + | { + | "_time": "2024-09-13T12:00:00", + | "bridges": [ + | {"name": "Pont Neuf", "length": 232}, + | {"name": "Pont Alexandre III", "length": 160} + | ], + | "city": "Paris", + | "country": "France", + | "coor": { + | "lat": 48.8566, + | "long": 2.3522, + | "alt": 35 + | } + | }, + | { + | "_time": "2024-09-13T12:00:00", + | "bridges": [ + | {"name": "Rialto Bridge", "length": 48}, + | {"name": "Bridge of Sighs", "length": 11} + | ], + | "city": "Venice", + | "country": "Italy", + | "coor": { + | "lat": 45.4408, + | "long": 12.3155, + | "alt": 2 + | } + | }, + | { + | "_time": "2024-09-13T12:00:00", + | "bridges": [ + | {"name": "Charles Bridge", "length": 516}, + | 
{"name": "Legion Bridge", "length": 343} + | ], + | "city": "Prague", + | "country": "Czech Republic", + | "coor": { + | "lat": 50.0755, + | "long": 14.4378, + | "alt": 200 + | } + | }, + | { + | "_time": "2024-09-13T12:00:00", + | "bridges": [ + | {"name": "Chain Bridge", "length": 375}, + | {"name": "Liberty Bridge", "length": 333} + | ], + | "city": "Budapest", + | "country": "Hungary", + | "coor": { + | "lat": 47.4979, + | "long": 19.0402, + | "alt": 96 + | } + | }, + | { + | "_time": "1990-09-13T12:00:00", + | "bridges": null, + | "city": "Warsaw", + | "country": "Poland", + | "coor": null + | } + |] + |""".stripMargin + val tempFile = Files.createTempFile("jsonTestData", ".json") + val absolutPath = tempFile.toAbsolutePath.toString; + Files.write(tempFile, json.getBytes) + sql(s""" + | CREATE TEMPORARY VIEW $testTable + | USING org.apache.spark.sql.json + | OPTIONS ( + | path "$absolutPath", + | multiLine true + | ); + |""".stripMargin) + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkUpdateIndexITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkUpdateIndexITSuite.scala index a6f7e0ed0..c9f6c47f7 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkUpdateIndexITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkUpdateIndexITSuite.scala @@ -9,7 +9,6 @@ import scala.jdk.CollectionConverters.mapAsJavaMapConverter import com.stephenn.scalatest.jsonassert.JsonMatchers.matchJson import org.json4s.native.JsonMethods._ -import org.opensearch.OpenSearchException import org.opensearch.action.get.GetRequest import org.opensearch.client.RequestOptions import org.opensearch.flint.core.FlintOptions @@ -207,13 +206,7 @@ class FlintSparkUpdateIndexITSuite extends FlintSparkSuite { val indexInitial = flint.describeIndex(testIndex).get indexInitial.options.refreshInterval() shouldBe Some("4 Minute") - the[OpenSearchException] thrownBy { - val client = - OpenSearchClientUtils.createClient(new FlintOptions(openSearchOptions.asJava)) - client.get( - new GetRequest(OpenSearchAsyncQueryScheduler.SCHEDULER_INDEX_NAME, testIndex), - RequestOptions.DEFAULT) - } + indexInitial.options.isExternalSchedulerEnabled() shouldBe false // Update Flint index to change refresh interval val updatedIndex = flint @@ -228,6 +221,7 @@ class FlintSparkUpdateIndexITSuite extends FlintSparkSuite { val indexFinal = flint.describeIndex(testIndex).get indexFinal.options.autoRefresh() shouldBe true indexFinal.options.refreshInterval() shouldBe Some("5 Minutes") + indexFinal.options.isExternalSchedulerEnabled() shouldBe true indexFinal.options.checkpointLocation() shouldBe Some(checkpointDir.getAbsolutePath) // Verify scheduler index is updated diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriterITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriterITSuite.scala index c04209f06..c0d253fd3 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriterITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriterITSuite.scala @@ -17,7 +17,9 @@ import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.storage.{FlintOpenSearchClient, 
FlintOpenSearchIndexMetadataService} import org.opensearch.flint.spark.{FlintSparkIndexOptions, FlintSparkSuite} -import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName +import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.COVERING_INDEX_TYPE +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE +import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, SKIPPING_INDEX_TYPE} import org.scalatest.Entry import org.scalatest.matchers.should.Matchers @@ -78,9 +80,9 @@ class FlintOpenSearchMetadataCacheWriterITSuite extends FlintSparkSuite with Mat | { | "_meta": { | "version": "${current()}", - | "name": "${testFlintIndex}", + | "name": "$testFlintIndex", | "kind": "test_kind", - | "source": "test_source_table", + | "source": "$testTable", | "indexedColumns": [ | { | "test_field": "spark_type" @@ -90,12 +92,12 @@ class FlintOpenSearchMetadataCacheWriterITSuite extends FlintSparkSuite with Mat | "refresh_interval": "10 Minutes" | }, | "properties": { - | "metadataCacheVersion": "1.0", + | "metadataCacheVersion": "${FlintMetadataCache.metadataCacheVersion}", | "refreshInterval": 600, - | "sourceTables": ["${FlintMetadataCache.mockTableName}"], - | "lastRefreshTime": ${testLastRefreshCompleteTime} + | "sourceTables": ["$testTable"], + | "lastRefreshTime": $testLastRefreshCompleteTime | }, - | "latestId": "${testLatestId}" + | "latestId": "$testLatestId" | }, | "properties": { | "test_field": { @@ -107,7 +109,7 @@ class FlintOpenSearchMetadataCacheWriterITSuite extends FlintSparkSuite with Mat val builder = new FlintMetadata.Builder builder.name(testFlintIndex) builder.kind("test_kind") - builder.source("test_source_table") + builder.source(testTable) builder.addIndexedColumn(Map[String, AnyRef]("test_field" -> "spark_type").asJava) builder.options( Map("auto_refresh" -> "true", "refresh_interval" -> "10 Minutes") @@ -129,12 +131,71 @@ class FlintOpenSearchMetadataCacheWriterITSuite extends FlintSparkSuite with Mat val properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties properties should have size 3 - properties should contain allOf (Entry("metadataCacheVersion", "1.0"), + properties should contain allOf (Entry( + "metadataCacheVersion", + FlintMetadataCache.metadataCacheVersion), Entry("lastRefreshTime", testLastRefreshCompleteTime)) + } + + Seq(SKIPPING_INDEX_TYPE, COVERING_INDEX_TYPE).foreach { case kind => + test(s"write metadata cache to $kind index mappings with source tables") { + val content = + s""" { + | "_meta": { + | "kind": "$kind", + | "source": "$testTable" + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin + val metadata = FlintOpenSearchIndexMetadataService + .deserialize(content) + .copy(latestLogEntry = Some(flintMetadataLogEntry)) + flintClient.createIndex(testFlintIndex, metadata) + flintMetadataCacheWriter.updateMetadataCache(testFlintIndex, metadata) + + val properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties + properties + .get("sourceTables") + .asInstanceOf[List[String]] + .toArray should contain theSameElementsAs Array(testTable) + } + } + + test(s"write metadata cache to materialized view index mappings with source tables") { + val testTable2 = "spark_catalog.default.metadatacache_test2" + val content = + s""" { + | "_meta": { + | "kind": "$MV_INDEX_TYPE", + | "properties": { + | "sourceTables": [ + | "$testTable", "$testTable2" + | ] + | } + | }, + | 
"properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin + val metadata = FlintOpenSearchIndexMetadataService + .deserialize(content) + .copy(latestLogEntry = Some(flintMetadataLogEntry)) + flintClient.createIndex(testFlintIndex, metadata) + flintMetadataCacheWriter.updateMetadataCache(testFlintIndex, metadata) + + val properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties properties .get("sourceTables") .asInstanceOf[List[String]] - .toArray should contain theSameElementsAs Array(FlintMetadataCache.mockTableName) + .toArray should contain theSameElementsAs Array(testTable, testTable2) } test("write metadata cache to index mappings with refresh interval") { @@ -162,13 +223,11 @@ class FlintOpenSearchMetadataCacheWriterITSuite extends FlintSparkSuite with Mat val properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties properties should have size 4 - properties should contain allOf (Entry("metadataCacheVersion", "1.0"), + properties should contain allOf (Entry( + "metadataCacheVersion", + FlintMetadataCache.metadataCacheVersion), Entry("refreshInterval", 600), Entry("lastRefreshTime", testLastRefreshCompleteTime)) - properties - .get("sourceTables") - .asInstanceOf[List[String]] - .toArray should contain theSameElementsAs Array(FlintMetadataCache.mockTableName) } test("exclude refresh interval in metadata cache when auto refresh is false") { @@ -195,12 +254,10 @@ class FlintOpenSearchMetadataCacheWriterITSuite extends FlintSparkSuite with Mat val properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties properties should have size 3 - properties should contain allOf (Entry("metadataCacheVersion", "1.0"), + properties should contain allOf (Entry( + "metadataCacheVersion", + FlintMetadataCache.metadataCacheVersion), Entry("lastRefreshTime", testLastRefreshCompleteTime)) - properties - .get("sourceTables") - .asInstanceOf[List[String]] - .toArray should contain theSameElementsAs Array(FlintMetadataCache.mockTableName) } test("exclude last refresh time in metadata cache when index has not been refreshed") { @@ -212,11 +269,8 @@ class FlintOpenSearchMetadataCacheWriterITSuite extends FlintSparkSuite with Mat val properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties properties should have size 2 - properties should contain(Entry("metadataCacheVersion", "1.0")) - properties - .get("sourceTables") - .asInstanceOf[List[String]] - .toArray should contain theSameElementsAs Array(FlintMetadataCache.mockTableName) + properties should contain( + Entry("metadataCacheVersion", FlintMetadataCache.metadataCacheVersion)) } test("write metadata cache to index mappings and preserve other index metadata") { @@ -246,12 +300,10 @@ class FlintOpenSearchMetadataCacheWriterITSuite extends FlintSparkSuite with Mat flintIndexMetadataService.getIndexMetadata(testFlintIndex).schema should have size 1 var properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties properties should have size 3 - properties should contain allOf (Entry("metadataCacheVersion", "1.0"), + properties should contain allOf (Entry( + "metadataCacheVersion", + FlintMetadataCache.metadataCacheVersion), Entry("lastRefreshTime", testLastRefreshCompleteTime)) - properties - .get("sourceTables") - .asInstanceOf[List[String]] - .toArray should contain theSameElementsAs Array(FlintMetadataCache.mockTableName) val newContent = """ { @@ -278,12 +330,10 @@ class 
FlintOpenSearchMetadataCacheWriterITSuite extends FlintSparkSuite with Mat flintIndexMetadataService.getIndexMetadata(testFlintIndex).schema should have size 1 properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties properties should have size 3 - properties should contain allOf (Entry("metadataCacheVersion", "1.0"), + properties should contain allOf (Entry( + "metadataCacheVersion", + FlintMetadataCache.metadataCacheVersion), Entry("lastRefreshTime", testLastRefreshCompleteTime)) - properties - .get("sourceTables") - .asInstanceOf[List[String]] - .toArray should contain theSameElementsAs Array(FlintMetadataCache.mockTableName) } Seq( @@ -296,9 +346,9 @@ class FlintOpenSearchMetadataCacheWriterITSuite extends FlintSparkSuite with Mat "checkpoint_location" -> "s3a://test/"), s""" | { - | "metadataCacheVersion": "1.0", + | "metadataCacheVersion": "${FlintMetadataCache.metadataCacheVersion}", | "refreshInterval": 600, - | "sourceTables": ["${FlintMetadataCache.mockTableName}"] + | "sourceTables": ["$testTable"] | } |""".stripMargin), ( @@ -306,8 +356,8 @@ class FlintOpenSearchMetadataCacheWriterITSuite extends FlintSparkSuite with Mat Map.empty[String, String], s""" | { - | "metadataCacheVersion": "1.0", - | "sourceTables": ["${FlintMetadataCache.mockTableName}"] + | "metadataCacheVersion": "${FlintMetadataCache.metadataCacheVersion}", + | "sourceTables": ["$testTable"] | } |""".stripMargin), ( @@ -315,8 +365,8 @@ class FlintOpenSearchMetadataCacheWriterITSuite extends FlintSparkSuite with Mat Map("incremental_refresh" -> "true", "checkpoint_location" -> "s3a://test/"), s""" | { - | "metadataCacheVersion": "1.0", - | "sourceTables": ["${FlintMetadataCache.mockTableName}"] + | "metadataCacheVersion": "${FlintMetadataCache.metadataCacheVersion}", + | "sourceTables": ["$testTable"] | } |""".stripMargin)).foreach { case (refreshMode, optionsMap, expectedJson) => test(s"write metadata cache for $refreshMode") { @@ -389,9 +439,9 @@ class FlintOpenSearchMetadataCacheWriterITSuite extends FlintSparkSuite with Mat flintMetadataCacheWriter.serialize(index.get.metadata())) \ "_meta" \ "properties")) propertiesJson should matchJson(s""" | { - | "metadataCacheVersion": "1.0", + | "metadataCacheVersion": "${FlintMetadataCache.metadataCacheVersion}", | "refreshInterval": 600, - | "sourceTables": ["${FlintMetadataCache.mockTableName}"] + | "sourceTables": ["$testTable"] | } |""".stripMargin) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala index cbc4308b0..3bd98edf1 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala @@ -597,4 +597,68 @@ class FlintSparkPPLBasicITSuite | """.stripMargin)) assert(ex.getMessage().contains("Invalid table name")) } + + test("Search multiple tables - translated into union call with fields") { + val frame = sql(s""" + | source = $t1, $t2 + | """.stripMargin) + assertSameRows( + Seq( + Row("Hello", 30, "New York", "USA", 2023, 4), + Row("Hello", 30, "New York", "USA", 2023, 4), + Row("Jake", 70, "California", "USA", 2023, 4), + Row("Jake", 70, "California", "USA", 2023, 4), + Row("Jane", 20, "Quebec", "Canada", 2023, 4), + Row("Jane", 20, "Quebec", "Canada", 2023, 4), + Row("John", 25, "Ontario", "Canada", 2023, 4), + Row("John", 25, "Ontario", 
"Canada", 2023, 4)), + frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + + val allFields1 = UnresolvedStar(None) + val allFields2 = UnresolvedStar(None) + + val projectedTable1 = Project(Seq(allFields1), table1) + val projectedTable2 = Project(Seq(allFields2), table2) + + val expectedPlan = + Union(Seq(projectedTable1, projectedTable2), byName = true, allowMissingCol = true) + + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("Search multiple tables - with table alias") { + val frame = sql(s""" + | source = $t1, $t2 as t | where t.country = "USA" + | """.stripMargin) + assertSameRows( + Seq( + Row("Hello", 30, "New York", "USA", 2023, 4), + Row("Hello", 30, "New York", "USA", 2023, 4), + Row("Jake", 70, "California", "USA", 2023, 4), + Row("Jake", 70, "California", "USA", 2023, 4)), + frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + + val plan1 = Filter( + EqualTo(UnresolvedAttribute("t.country"), Literal("USA")), + SubqueryAlias("t", table1)) + val plan2 = Filter( + EqualTo(UnresolvedAttribute("t.country"), Literal("USA")), + SubqueryAlias("t", table2)) + + val projectedTable1 = Project(Seq(UnresolvedStar(None)), plan1) + val projectedTable2 = Project(Seq(UnresolvedStar(None)), plan2) + + val expectedPlan = + Union(Seq(projectedTable1, projectedTable2), byName = true, allowMissingCol = true) + + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFlattenITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFlattenITSuite.scala new file mode 100644 index 000000000..e714a5f7e --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFlattenITSuite.scala @@ -0,0 +1,350 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.flint.spark.ppl + +import java.nio.file.Files + +import org.opensearch.flint.spark.FlattenGenerator +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, GeneratorOuter, Literal, Or} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLFlattenITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + private val testTable = "flint_ppl_test" + private val structNestedTable = "spark_catalog.default.flint_ppl_struct_nested_test" + private val structTable = "spark_catalog.default.flint_ppl_struct_test" + private val multiValueTable = "spark_catalog.default.flint_ppl_multi_value_test" + private val tempFile = Files.createTempFile("jsonTestData", ".json") + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createNestedJsonContentTable(tempFile, testTable) + createStructNestedTable(structNestedTable) + createStructTable(structTable) + 
createMultiValueStructTable(multiValueTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + override def afterAll(): Unit = { + super.afterAll() + Files.deleteIfExists(tempFile) + } + + test("flatten for structs") { + val frame = sql(s""" + | source = $testTable + | | where country = 'England' or country = 'Poland' + | | fields coor + | | flatten coor + | """.stripMargin) + + assert(frame.columns.sameElements(Array("alt", "lat", "long"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row(35, 51.5074, -0.1278), Row(null, null, null)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](1)) + assert(results.sorted.sameElements(expectedResults.sorted)) + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("flint_ppl_test")) + val filter = Filter( + Or( + EqualTo(UnresolvedAttribute("country"), Literal("England")), + EqualTo(UnresolvedAttribute("country"), Literal("Poland"))), + table) + val projectCoor = Project(Seq(UnresolvedAttribute("coor")), filter) + val flattenCoor = flattenPlanFor("coor", projectCoor) + val expectedPlan = Project(Seq(UnresolvedStar(None)), flattenCoor) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + private def flattenPlanFor(flattenedColumn: String, parentPlan: LogicalPlan): LogicalPlan = { + val flattenGenerator = new FlattenGenerator(UnresolvedAttribute(flattenedColumn)) + val outerGenerator = GeneratorOuter(flattenGenerator) + val generate = Generate(outerGenerator, seq(), outer = true, None, seq(), parentPlan) + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute(flattenedColumn)), generate) + dropSourceColumn + } + + test("flatten for arrays") { + val frame = sql(s""" + | source = $testTable + | | fields bridges + | | flatten bridges + | """.stripMargin) + + assert(frame.columns.sameElements(Array("length", "name"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row(null, null), + Row(11L, "Bridge of Sighs"), + Row(48L, "Rialto Bridge"), + Row(160L, "Pont Alexandre III"), + Row(232L, "Pont Neuf"), + Row(801L, "Tower Bridge"), + Row(928L, "London Bridge"), + Row(343L, "Legion Bridge"), + Row(516L, "Charles Bridge"), + Row(333L, "Liberty Bridge"), + Row(375L, "Chain Bridge")) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("flint_ppl_test")) + val projectCoor = Project(Seq(UnresolvedAttribute("bridges")), table) + val flattenBridges = flattenPlanFor("bridges", projectCoor) + val expectedPlan = Project(Seq(UnresolvedStar(None)), flattenBridges) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("flatten for structs and arrays") { + val frame = sql(s""" + | source = $testTable | flatten bridges | flatten coor + | """.stripMargin) + + assert( + frame.columns.sameElements( + Array("_time", "city", "country", "length", "name", "alt", "lat", "long"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("1990-09-13T12:00:00", "Warsaw", "Poland", null, null, null, null, null), + Row( + "2024-09-13T12:00:00", 
+ "Venice", + "Italy", + 11L, + "Bridge of Sighs", + 2, + 45.4408, + 12.3155), + Row("2024-09-13T12:00:00", "Venice", "Italy", 48L, "Rialto Bridge", 2, 45.4408, 12.3155), + Row( + "2024-09-13T12:00:00", + "Paris", + "France", + 160L, + "Pont Alexandre III", + 35, + 48.8566, + 2.3522), + Row("2024-09-13T12:00:00", "Paris", "France", 232L, "Pont Neuf", 35, 48.8566, 2.3522), + Row( + "2024-09-13T12:00:00", + "London", + "England", + 801L, + "Tower Bridge", + 35, + 51.5074, + -0.1278), + Row( + "2024-09-13T12:00:00", + "London", + "England", + 928L, + "London Bridge", + 35, + 51.5074, + -0.1278), + Row( + "2024-09-13T12:00:00", + "Prague", + "Czech Republic", + 343L, + "Legion Bridge", + 200, + 50.0755, + 14.4378), + Row( + "2024-09-13T12:00:00", + "Prague", + "Czech Republic", + 516L, + "Charles Bridge", + 200, + 50.0755, + 14.4378), + Row( + "2024-09-13T12:00:00", + "Budapest", + "Hungary", + 333L, + "Liberty Bridge", + 96, + 47.4979, + 19.0402), + Row( + "2024-09-13T12:00:00", + "Budapest", + "Hungary", + 375L, + "Chain Bridge", + 96, + 47.4979, + 19.0402)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](3)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("flint_ppl_test")) + val flattenBridges = flattenPlanFor("bridges", table) + val flattenCoor = flattenPlanFor("coor", flattenBridges) + val expectedPlan = Project(Seq(UnresolvedStar(None)), flattenCoor) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test flatten and stats") { + val frame = sql(s""" + | source = $testTable + | | fields country, bridges + | | flatten bridges + | | fields country, length + | | stats avg(length) as avg by country + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row(null, "Poland"), + Row(196d, "France"), + Row(429.5, "Czech Republic"), + Row(864.5, "England"), + Row(29.5, "Italy"), + Row(354.0, "Hungary")) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("flint_ppl_test")) + val projectCountryBridges = + Project(Seq(UnresolvedAttribute("country"), UnresolvedAttribute("bridges")), table) + val flattenBridges = flattenPlanFor("bridges", projectCountryBridges) + val projectCountryLength = + Project(Seq(UnresolvedAttribute("country"), UnresolvedAttribute("length")), flattenBridges) + val average = Alias( + UnresolvedFunction( + seq("AVG"), + seq(UnresolvedAttribute("length")), + isDistinct = false, + None, + ignoreNulls = false), + "avg")() + val country = Alias(UnresolvedAttribute("country"), "country")() + val grouping = Alias(UnresolvedAttribute("country"), "country")() + val aggregate = Aggregate(Seq(grouping), Seq(average, country), projectCountryLength) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("flatten struct table") { + val frame = sql(s""" + | source = $structTable + | | flatten struct_col + | | flatten field1 + | """.stripMargin) + + assert(frame.columns.sameElements(Array("int_col", "field2", "subfield"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row(30, 123, "value1"), 
Row(40, 456, "value2"), Row(50, 789, "value3")) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_struct_test")) + val flattenStructCol = flattenPlanFor("struct_col", table) + val flattenField1 = flattenPlanFor("field1", flattenStructCol) + val expectedPlan = Project(Seq(UnresolvedStar(None)), flattenField1) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("flatten struct nested table") { + val frame = sql(s""" + | source = $structNestedTable + | | flatten struct_col + | | flatten field1 + | | flatten struct_col2 + | | flatten field1 + | """.stripMargin) + + assert( + frame.columns.sameElements(Array("int_col", "field2", "subfield", "field2", "subfield"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row(30, 123, "value1", 23, "valueA"), + Row(40, 123, "value5", 33, "valueB"), + Row(30, 823, "value4", 83, "valueC"), + Row(40, 456, "value2", 46, "valueD"), + Row(50, 789, "value3", 89, "valueE")) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_struct_nested_test")) + val flattenStructCol = flattenPlanFor("struct_col", table) + val flattenField1 = flattenPlanFor("field1", flattenStructCol) + val flattenStructCol2 = flattenPlanFor("struct_col2", flattenField1) + val flattenField1Again = flattenPlanFor("field1", flattenStructCol2) + val expectedPlan = Project(Seq(UnresolvedStar(None)), flattenField1Again) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("flatten multi value nullable") { + val frame = sql(s""" + | source = $multiValueTable + | | flatten multi_value + | """.stripMargin) + + assert(frame.columns.sameElements(Array("int_col", "name", "value"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row(1, "1_one", 1), + Row(1, null, 11), + Row(1, "1_three", null), + Row(2, "2_Monday", 2), + Row(2, null, null), + Row(3, "3_third", 3), + Row(3, "3_4th", 4), + Row(4, null, null)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_multi_value_test")) + val flattenMultiValue = flattenPlanFor("multi_value", table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), flattenMultiValue) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } +} diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJoinITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJoinITSuite.scala index 00e55d50a..3127325c8 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJoinITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJoinITSuite.scala @@ -5,7 +5,7 @@ package org.opensearch.flint.spark.ppl -import 
org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Divide, EqualTo, Floor, GreaterThan, LessThan, Literal, Multiply, Or, SortOrder} import org.apache.spark.sql.catalyst.plans.{Cross, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter} @@ -924,4 +924,271 @@ class FlintSparkPPLJoinITSuite s }.size == 13) } + + test("test multiple joins without table aliases") { + val frame = sql(s""" + | source = $testTable1 + | | JOIN ON $testTable1.name = $testTable2.name $testTable2 + | | JOIN ON $testTable2.name = $testTable3.name $testTable3 + | | fields $testTable1.name, $testTable2.name, $testTable3.name + | """.stripMargin) + assertSameRows( + Array( + Row("Jake", "Jake", "Jake"), + Row("Hello", "Hello", "Hello"), + Row("John", "John", "John"), + Row("David", "David", "David"), + Row("David", "David", "David"), + Row("Jane", "Jane", "Jane")), + frame) + + val logicalPlan = frame.queryExecution.logical + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val joinPlan1 = Join( + table1, + table2, + Inner, + Some( + EqualTo( + UnresolvedAttribute(s"$testTable1.name"), + UnresolvedAttribute(s"$testTable2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + table3, + Inner, + Some( + EqualTo( + UnresolvedAttribute(s"$testTable2.name"), + UnresolvedAttribute(s"$testTable3.name"))), + JoinHint.NONE) + val expectedPlan = Project( + Seq( + UnresolvedAttribute(s"$testTable1.name"), + UnresolvedAttribute(s"$testTable2.name"), + UnresolvedAttribute(s"$testTable3.name")), + joinPlan2) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with part subquery aliases") { + val frame = sql(s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 + | | JOIN right = t3 ON t1.name = t3.name $testTable3 + | | fields t1.name, t2.name, t3.name + | """.stripMargin) + assertSameRows( + Array( + Row("Jake", "Jake", "Jake"), + Row("Hello", "Hello", "Hello"), + Row("John", "John", "John"), + Row("David", "David", "David"), + Row("David", "David", "David"), + Row("Jane", "Jane", "Jane")), + frame) + + val logicalPlan = frame.queryExecution.logical + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t3.name"))), + JoinHint.NONE) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("t1.name"), + UnresolvedAttribute("t2.name"), + UnresolvedAttribute("t3.name")), + joinPlan2) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with self join 1") { + val frame = sql(s""" + | source = $testTable1 + | | JOIN 
left = t1 right = t2 ON t1.name = t2.name $testTable2 + | | JOIN right = t3 ON t1.name = t3.name $testTable3 + | | JOIN right = t4 ON t1.name = t4.name $testTable1 + | | fields t1.name, t2.name, t3.name, t4.name + | """.stripMargin) + assertSameRows( + Array( + Row("Jake", "Jake", "Jake", "Jake"), + Row("Hello", "Hello", "Hello", "Hello"), + Row("John", "John", "John", "John"), + Row("David", "David", "David", "David"), + Row("David", "David", "David", "David"), + Row("Jane", "Jane", "Jane", "Jane")), + frame) + + val logicalPlan = frame.queryExecution.logical + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t3.name"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + SubqueryAlias("t4", table1), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t4.name"))), + JoinHint.NONE) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("t1.name"), + UnresolvedAttribute("t2.name"), + UnresolvedAttribute("t3.name"), + UnresolvedAttribute("t4.name")), + joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with self join 2") { + val frame = sql(s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 + | | JOIN right = t3 ON t1.name = t3.name $testTable3 + | | JOIN ON t1.name = t4.name + | [ + | source = $testTable1 + | ] as t4 + | | fields t1.name, t2.name, t3.name, t4.name + | """.stripMargin) + assertSameRows( + Array( + Row("Jake", "Jake", "Jake", "Jake"), + Row("Hello", "Hello", "Hello", "Hello"), + Row("John", "John", "John", "John"), + Row("David", "David", "David", "David"), + Row("David", "David", "David", "David"), + Row("Jane", "Jane", "Jane", "Jane")), + frame) + + val logicalPlan = frame.queryExecution.logical + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t3.name"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + SubqueryAlias("t4", table1), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t4.name"))), + JoinHint.NONE) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("t1.name"), + UnresolvedAttribute("t2.name"), + UnresolvedAttribute("t3.name"), + UnresolvedAttribute("t4.name")), + joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("check access the reference by aliases") { + var frame = sql(s""" + | source = $testTable1 + | | JOIN left = t1 ON t1.name = t2.name $testTable2 as t2 + | | fields 
t1.name, t2.name + | """.stripMargin) + assert(frame.collect().length > 0) + + frame = sql(s""" + | source = $testTable1 as t1 + | | JOIN ON t1.name = t2.name $testTable2 as t2 + | | fields t1.name, t2.name + | """.stripMargin) + assert(frame.collect().length > 0) + + frame = sql(s""" + | source = $testTable1 + | | JOIN left = t1 ON t1.name = t2.name [ source = $testTable2 ] as t2 + | | fields t1.name, t2.name + | """.stripMargin) + assert(frame.collect().length > 0) + + frame = sql(s""" + | source = $testTable1 + | | JOIN left = t1 ON t1.name = t2.name [ source = $testTable2 as t2 ] + | | fields t1.name, t2.name + | """.stripMargin) + assert(frame.collect().length > 0) + } + + test("access the reference by override aliases should throw exception") { + var ex = intercept[AnalysisException](sql(s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 as tt + | | fields tt.name + | """.stripMargin)) + assert(ex.getMessage.contains("`tt`.`name` cannot be resolved")) + + ex = intercept[AnalysisException](sql(s""" + | source = $testTable1 as tt + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 + | | fields tt.name + | """.stripMargin)) + assert(ex.getMessage.contains("`tt`.`name` cannot be resolved")) + + ex = intercept[AnalysisException](sql(s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name [ source = $testTable2 as tt ] + | | fields tt.name + | """.stripMargin)) + assert(ex.getMessage.contains("`tt`.`name` cannot be resolved")) + + ex = intercept[AnalysisException](sql(s""" + | source = $testTable1 + | | JOIN left = t1 ON t1.name = t2.name [ source = $testTable2 as tt ] as t2 + | | fields tt.name + | """.stripMargin)) + assert(ex.getMessage.contains("`tt`.`name` cannot be resolved")) + + ex = intercept[AnalysisException](sql(s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name [ source = $testTable2 ] as tt + | | fields tt.name + | """.stripMargin)) + assert(ex.getMessage.contains("`tt`.`name` cannot be resolved")) + + ex = intercept[AnalysisException](sql(s""" + | source = $testTable1 as tt + | | JOIN left = t1 ON t1.name = t2.name $testTable2 as t2 + | | fields tt.name + | """.stripMargin)) + assert(ex.getMessage.contains("`tt`.`name` cannot be resolved")) + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJsonFunctionITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJsonFunctionITSuite.scala index 7cc0a221d..fca758101 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJsonFunctionITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJsonFunctionITSuite.scala @@ -163,30 +163,32 @@ class FlintSparkPPLJsonFunctionITSuite assert(ex.getMessage().contains("should all be the same type")) } - test("test json_array() with json()") { + test("test json_array() with to_json_string()") { val frame = sql(s""" - | source = $testTable | eval result = json(json_array(1,2,0,-1,1.1,-0.11)) | head 1 | fields result + | source = $testTable | eval result = to_json_string(json_array(1,2,0,-1,1.1,-0.11)) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""[1.0,2.0,0.0,-1.0,1.1,-0.11]""")), frame) } - test("test json_array_length()") { + test("test array_length()") { var frame = sql(s""" - | source = $testTable | eval result = json_array_length(json_array('this', 'is', 'a', 'string', 'array')) | head 1 | fields 
result - | """.stripMargin) + | source = $testTable| eval result = array_length(json_array('this', 'is', 'a', 'string', 'array')) | head 1 | fields result + | """.stripMargin) assertSameRows(Seq(Row(5)), frame) frame = sql(s""" - | source = $testTable | eval result = json_array_length(json_array(1, 2, 0, -1, 1.1, -0.11)) | head 1 | fields result - | """.stripMargin) + | source = $testTable| eval result = array_length(json_array(1, 2, 0, -1, 1.1, -0.11)) | head 1 | fields result + | """.stripMargin) assertSameRows(Seq(Row(6)), frame) frame = sql(s""" - | source = $testTable | eval result = json_array_length(json_array()) | head 1 | fields result - | """.stripMargin) + | source = $testTable| eval result = array_length(json_array()) | head 1 | fields result + | """.stripMargin) assertSameRows(Seq(Row(0)), frame) + } - frame = sql(s""" + test("test json_array_length()") { + var frame = sql(s""" | source = $testTable | eval result = json_array_length('[]') | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row(0)), frame) @@ -211,24 +213,24 @@ class FlintSparkPPLJsonFunctionITSuite test("test json_object()") { // test value is a string var frame = sql(s""" - | source = $testTable| eval result = json(json_object('key', 'string_value')) | head 1 | fields result + | source = $testTable| eval result = to_json_string(json_object('key', 'string_value')) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""{"key":"string_value"}""")), frame) // test value is a number frame = sql(s""" - | source = $testTable| eval result = json(json_object('key', 123.45)) | head 1 | fields result + | source = $testTable| eval result = to_json_string(json_object('key', 123.45)) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""{"key":123.45}""")), frame) // test value is a boolean frame = sql(s""" - | source = $testTable| eval result = json(json_object('key', true)) | head 1 | fields result + | source = $testTable| eval result = to_json_string(json_object('key', true)) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""{"key":true}""")), frame) frame = sql(s""" - | source = $testTable| eval result = json(json_object("a", 1, "b", 2, "c", 3)) | head 1 | fields result + | source = $testTable| eval result = to_json_string(json_object("a", 1, "b", 2, "c", 3)) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""{"a":1,"b":2,"c":3}""")), frame) } @@ -236,13 +238,13 @@ class FlintSparkPPLJsonFunctionITSuite test("test json_object() and json_array()") { // test value is an empty array var frame = sql(s""" - | source = $testTable| eval result = json(json_object('key', array())) | head 1 | fields result + | source = $testTable| eval result = to_json_string(json_object('key', array())) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""{"key":[]}""")), frame) // test value is an array frame = sql(s""" - | source = $testTable| eval result = json(json_object('key', array(1, 2, 3))) | head 1 | fields result + | source = $testTable| eval result = to_json_string(json_object('key', array(1, 2, 3))) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""{"key":[1,2,3]}""")), frame) @@ -272,14 +274,14 @@ class FlintSparkPPLJsonFunctionITSuite test("test json_object() nested") { val frame = sql(s""" - | source = $testTable | eval result = json(json_object('outer', json_object('inner', 123.45))) | head 1 | fields result + | source = $testTable | eval result = to_json_string(json_object('outer', json_object('inner', 
123.45))) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""{"outer":{"inner":123.45}}""")), frame) } test("test json_object(), json_array() and json()") { val frame = sql(s""" - | source = $testTable | eval result = json(json_object("array", json_array(1,2,0,-1,1.1,-0.11))) | head 1 | fields result + | source = $testTable | eval result = to_json_string(json_object("array", json_array(1,2,0,-1,1.1,-0.11))) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""{"array":[1.0,2.0,0.0,-1.0,1.1,-0.11]}""")), frame) } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLLambdaFunctionITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLLambdaFunctionITSuite.scala new file mode 100644 index 000000000..f86502521 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLLambdaFunctionITSuite.scala @@ -0,0 +1,132 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.functions.{col, to_json} +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLLambdaFunctionITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + // Create test table + createNullableJsonContentTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test forall()") { + val frame = sql(s""" + | source = $testTable | eval array = json_array(1,2,0,-1,1.1,-0.11), result = forall(array, x -> x > 0) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(false)), frame) + + val frame2 = sql(s""" + | source = $testTable | eval array = json_array(1,2,0,-1,1.1,-0.11), result = forall(array, x -> x > -10) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(true)), frame2) + + val frame3 = sql(s""" + | source = $testTable | eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = forall(array, x -> x.a > 0) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(false)), frame3) + + val frame4 = sql(s""" + | source = $testTable | eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = exists(array, x -> x.b < 0) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(true)), frame4) + } + + test("test exists()") { + val frame = sql(s""" + | source = $testTable | eval array = json_array(1,2,0,-1,1.1,-0.11), result = exists(array, x -> x > 0) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(true)), frame) + + val frame2 = sql(s""" + | source = $testTable | eval array = json_array(1,2,0,-1,1.1,-0.11), result = exists(array, x -> x > 10) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(false)), frame2) + + val frame3 = sql(s""" + | source = $testTable | eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = exists(array, x -> x.a > 0) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(true)), frame3) + + val frame4 = sql(s""" + | source = 
$testTable | eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = exists(array, x -> x.b > 0) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(false)), frame4) + + } + + test("test filter()") { + val frame = sql(s""" + | source = $testTable| eval array = json_array(1,2,0,-1,1.1,-0.11), result = filter(array, x -> x > 0) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(Seq(1, 2, 1.1))), frame) + + val frame2 = sql(s""" + | source = $testTable| eval array = json_array(1,2,0,-1,1.1,-0.11), result = filter(array, x -> x > 10) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(Seq())), frame2) + + val frame3 = sql(s""" + | source = $testTable| eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = filter(array, x -> x.a > 0) | head 1 | fields result + | """.stripMargin) + + assertSameRows(Seq(Row("""[{"a":1,"b":-1}]""")), frame3.select(to_json(col("result")))) + + val frame4 = sql(s""" + | source = $testTable| eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = filter(array, x -> x.b > 0) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row("""[]""")), frame4.select(to_json(col("result")))) + } + + test("test transform()") { + val frame = sql(s""" + | source = $testTable | eval array = json_array(1,2,3), result = transform(array, x -> x + 1) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(Seq(2, 3, 4))), frame) + + val frame2 = sql(s""" + | source = $testTable | eval array = json_array(1,2,3), result = transform(array, (x, y) -> x + y) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(Seq(1, 3, 5))), frame2) + } + + test("test reduce()") { + val frame = sql(s""" + | source = $testTable | eval array = json_array(1,2,3), result = reduce(array, 0, (acc, x) -> acc + x) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(6)), frame) + + val frame2 = sql(s""" + | source = $testTable | eval array = json_array(1,2,3), result = reduce(array, 1, (acc, x) -> acc + x) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(7)), frame2) + + val frame3 = sql(s""" + | source = $testTable | eval array = json_array(1,2,3), result = reduce(array, 0, (acc, x) -> acc + x, acc -> acc * 10) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(60)), frame3) + } +} diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala new file mode 100644 index 000000000..bc4463537 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala @@ -0,0 +1,247 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CaseWhen, CurrentRow, Descending, LessThan, Literal, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLTrendlineITSuite + extends QueryTest + with 
LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createPartitionedStateCountryTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test trendline sma command without fields command and without alias") { + val frame = sql(s""" + | source = $testTable | sort - age | trendline sma(2, age) + | """.stripMargin) + + assert( + frame.columns.sameElements( + Array("name", "age", "state", "country", "year", "month", "age_trendline"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jake", 70, "California", "USA", 2023, 4, null), + Row("Hello", 30, "New York", "USA", 2023, 4, 50.0), + Row("John", 25, "Ontario", "Canada", 2023, 4, 27.5), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 22.5)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val ageField = UnresolvedAttribute("age") + val sort = Sort(Seq(SortOrder(ageField, Descending)), global = true, table) + val countWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val smaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val caseWhen = CaseWhen(Seq((LessThan(countWindow, Literal(2)), Literal(null))), smaWindow) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(caseWhen, "age_trendline")()) + val expectedPlan = Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sort)) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline sma command with fields command") { + val frame = sql(s""" + | source = $testTable | trendline sort - age sma(3, age) as age_sma | fields name, age, age_sma + | """.stripMargin) + + assert(frame.columns.sameElements(Array("name", "age", "age_sma"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jake", 70, null), + Row("Hello", 30, null), + Row("John", 25, 41.666666666666664), + Row("Jane", 20, 25)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val nameField = UnresolvedAttribute("name") + val ageField = UnresolvedAttribute("age") + val ageSmaField = UnresolvedAttribute("age_sma") + val sort = Sort(Seq(SortOrder(ageField, Descending)), global = true, table) + val countWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + 
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val smaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val caseWhen = CaseWhen(Seq((LessThan(countWindow, Literal(3)), Literal(null))), smaWindow) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(caseWhen, "age_sma")()) + val expectedPlan = + Project(Seq(nameField, ageField, ageSmaField), Project(trendlineProjectList, sort)) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test multiple trendline sma commands") { + val frame = sql(s""" + | source = $testTable | trendline sort + age sma(2, age) as two_points_sma sma(3, age) as three_points_sma | fields name, age, two_points_sma, three_points_sma + | """.stripMargin) + + assert(frame.columns.sameElements(Array("name", "age", "two_points_sma", "three_points_sma"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jane", 20, null, null), + Row("John", 25, 22.5, null), + Row("Hello", 30, 27.5, 25.0), + Row("Jake", 70, 50.0, 41.666666666666664)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val nameField = UnresolvedAttribute("name") + val ageField = UnresolvedAttribute("age") + val ageTwoPointsSmaField = UnresolvedAttribute("two_points_sma") + val ageThreePointsSmaField = UnresolvedAttribute("three_points_sma") + val sort = Sort(Seq(SortOrder(ageField, Ascending)), global = true, table) + val twoPointsCountWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val twoPointsSmaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val threePointsCountWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val threePointsSmaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val twoPointsCaseWhen = CaseWhen( + Seq((LessThan(twoPointsCountWindow, Literal(2)), Literal(null))), + twoPointsSmaWindow) + val threePointsCaseWhen = CaseWhen( + Seq((LessThan(threePointsCountWindow, Literal(3)), Literal(null))), + threePointsSmaWindow) + val trendlineProjectList = Seq( + UnresolvedStar(None), + Alias(twoPointsCaseWhen, "two_points_sma")(), + Alias(threePointsCaseWhen, "three_points_sma")()) + val expectedPlan = Project( + Seq(nameField, ageField, ageTwoPointsSmaField, ageThreePointsSmaField), + Project(trendlineProjectList, sort)) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline sma command on evaluated column") { + val frame = sql(s""" + | source = $testTable | eval doubled_age = age * 2 | trendline sort + age sma(2, 
doubled_age) as doubled_age_sma | fields name, doubled_age, doubled_age_sma + | """.stripMargin) + + assert(frame.columns.sameElements(Array("name", "doubled_age", "doubled_age_sma"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jane", 40, null), + Row("John", 50, 45.0), + Row("Hello", 60, 55.0), + Row("Jake", 140, 100.0)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val nameField = UnresolvedAttribute("name") + val ageField = UnresolvedAttribute("age") + val doubledAgeField = UnresolvedAttribute("doubled_age") + val doubledAgeSmaField = UnresolvedAttribute("doubled_age_sma") + val evalProject = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction("*", Seq(ageField, Literal(2)), isDistinct = false), + "doubled_age")()), + table) + val sort = Sort(Seq(SortOrder(ageField, Ascending)), global = true, evalProject) + val countWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val doubleAgeSmaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(doubledAgeField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val caseWhen = + CaseWhen(Seq((LessThan(countWindow, Literal(2)), Literal(null))), doubleAgeSmaWindow) + val trendlineProjectList = + Seq(UnresolvedStar(None), Alias(caseWhen, "doubled_age_sma")()) + val expectedPlan = Project( + Seq(nameField, doubledAgeField, doubledAgeSmaField), + Project(trendlineProjectList, sort)) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline sma command chaining") { + val frame = sql(s""" + | source = $testTable | eval age_1 = age, age_2 = age | trendline sort - age_1 sma(3, age_1) | trendline sort + age_2 sma(3, age_2) + | """.stripMargin) + + assert( + frame.columns.sameElements( + Array( + "name", + "age", + "state", + "country", + "year", + "month", + "age_1", + "age_2", + "age_1_trendline", + "age_2_trendline"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Hello", 30, "New York", "USA", 2023, 4, 30, 30, null, 25.0), + Row("Jake", 70, "California", "USA", 2023, 4, 70, 70, null, 41.666666666666664), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 20, 20, 25.0, null), + Row("John", 25, "Ontario", "Canada", 2023, 4, 25, 25, 41.666666666666664, null)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + } +} diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index bf6989b7c..93efb2df1 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -37,6 +37,8 @@ KMEANS: 'KMEANS'; AD: 'AD'; ML: 'ML'; FILLNULL: 'FILLNULL'; +FLATTEN: 'FLATTEN'; +TRENDLINE: 'TRENDLINE'; //Native JOIN KEYWORDS JOIN: 'JOIN'; @@ -89,6 +91,9 @@ FIELDSUMMARY: 
'FIELDSUMMARY'; INCLUDEFIELDS: 'INCLUDEFIELDS'; NULLS: 'NULLS'; +//TRENDLINE KEYWORDS +SMA: 'SMA'; + // ARGUMENT KEYWORDS KEEPEMPTY: 'KEEPEMPTY'; CONSECUTIVE: 'CONSECUTIVE'; @@ -198,6 +203,7 @@ RT_SQR_PRTHS: ']'; SINGLE_QUOTE: '\''; DOUBLE_QUOTE: '"'; BACKTICK: '`'; +ARROW: '->'; // Operators. Bit @@ -372,6 +378,7 @@ JSON: 'JSON'; JSON_OBJECT: 'JSON_OBJECT'; JSON_ARRAY: 'JSON_ARRAY'; JSON_ARRAY_LENGTH: 'JSON_ARRAY_LENGTH'; +TO_JSON_STRING: 'TO_JSON_STRING'; JSON_EXTRACT: 'JSON_EXTRACT'; JSON_KEYS: 'JSON_KEYS'; JSON_VALID: 'JSON_VALID'; @@ -379,14 +386,22 @@ JSON_VALID: 'JSON_VALID'; //JSON_DELETE: 'JSON_DELETE'; //JSON_EXTEND: 'JSON_EXTEND'; //JSON_SET: 'JSON_SET'; -//JSON_ARRAY_ALL_MATCH: 'JSON_ALL_MATCH'; -//JSON_ARRAY_ANY_MATCH: 'JSON_ANY_MATCH'; -//JSON_ARRAY_FILTER: 'JSON_FILTER'; +//JSON_ARRAY_ALL_MATCH: 'JSON_ARRAY_ALL_MATCH'; +//JSON_ARRAY_ANY_MATCH: 'JSON_ARRAY_ANY_MATCH'; +//JSON_ARRAY_FILTER: 'JSON_ARRAY_FILTER'; //JSON_ARRAY_MAP: 'JSON_ARRAY_MAP'; //JSON_ARRAY_REDUCE: 'JSON_ARRAY_REDUCE'; // COLLECTION FUNCTIONS ARRAY: 'ARRAY'; +ARRAY_LENGTH: 'ARRAY_LENGTH'; + +// LAMBDA FUNCTIONS +//EXISTS: 'EXISTS'; +FORALL: 'FORALL'; +FILTER: 'FILTER'; +TRANSFORM: 'TRANSFORM'; +REDUCE: 'REDUCE'; // BOOL FUNCTIONS LIKE: 'LIKE'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index aaf807a7b..123d1e15a 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -53,6 +53,8 @@ commands | renameCommand | fillnullCommand | fieldsummaryCommand + | flattenCommand + | trendlineCommand ; commandName @@ -82,6 +84,8 @@ commandName | RENAME | FILLNULL | FIELDSUMMARY + | FLATTEN + | TRENDLINE ; searchCommand @@ -89,7 +93,7 @@ searchCommand | (SEARCH)? fromClause logicalExpression # searchFromFilter | (SEARCH)? logicalExpression fromClause # searchFilterFrom ; - + fieldsummaryCommand : FIELDSUMMARY (fieldsummaryParameter)* ; @@ -246,6 +250,21 @@ fillnullCommand : expression ; +flattenCommand + : FLATTEN fieldExpression + ; + +trendlineCommand + : TRENDLINE (SORT sortField)? trendlineClause (trendlineClause)* + ; + +trendlineClause + : trendlineType LT_PRTHS numberOfDataPoints = integerLiteral COMMA field = fieldExpression RT_PRTHS (AS alias = qualifiedName)? + ; + +trendlineType + : SMA + ; kmeansCommand : KMEANS (kmeansParameter)* @@ -320,7 +339,7 @@ joinType ; sideAlias - : LEFT EQUAL leftAlias = ident COMMA? RIGHT EQUAL rightAlias = ident + : (LEFT EQUAL leftAlias = ident)? COMMA? (RIGHT EQUAL rightAlias = ident)? 
; joinCriteria @@ -424,6 +443,8 @@ valueExpression | timestampFunction # timestampFunctionCall | LT_PRTHS valueExpression RT_PRTHS # parentheticValueExpr | LT_SQR_PRTHS subSearch RT_SQR_PRTHS # scalarSubqueryExpr + | ident ARROW expression # lambda + | LT_PRTHS ident (COMMA ident)+ RT_PRTHS ARROW expression # lambda ; primaryExpression @@ -549,6 +570,7 @@ evalFunctionName | cryptographicFunctionName | jsonFunctionName | collectionFunctionName + | lambdaFunctionName ; functionArgs @@ -838,6 +860,7 @@ jsonFunctionName | JSON_OBJECT | JSON_ARRAY | JSON_ARRAY_LENGTH + | TO_JSON_STRING | JSON_EXTRACT | JSON_KEYS | JSON_VALID @@ -854,6 +877,15 @@ jsonFunctionName collectionFunctionName : ARRAY + | ARRAY_LENGTH + ; + +lambdaFunctionName + : FORALL + | EXISTS + | FILTER + | TRANSFORM + | REDUCE ; positionFunctionName @@ -1125,4 +1157,5 @@ keywordsCanBeId | ANTI | BETWEEN | CIDRMATCH + | trendlineType ; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index 03c40fcd2..189d9084a 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -18,6 +18,7 @@ import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.FieldList; +import org.opensearch.sql.ast.expression.LambdaFunction; import org.opensearch.sql.ast.tree.FieldSummary; import org.opensearch.sql.ast.expression.FieldsMapping; import org.opensearch.sql.ast.expression.Function; @@ -111,6 +112,10 @@ public T visitLookup(Lookup node, C context) { return visitChildren(node, context); } + public T visitTrendline(Trendline node, C context) { + return visitChildren(node, context); + } + public T visitCorrelation(Correlation node, C context) { return visitChildren(node, context); } @@ -179,6 +184,10 @@ public T visitFunction(Function node, C context) { return visitChildren(node, context); } + public T visitLambdaFunction(LambdaFunction node, C context) { + return visitChildren(node, context); + } + public T visitIsEmpty(IsEmpty node, C context) { return visitChildren(node, context); } @@ -326,4 +335,8 @@ public T visitWindow(Window node, C context) { public T visitCidr(Cidr node, C context) { return visitChildren(node, context); } + + public T visitFlatten(Flatten flatten, C context) { + return visitChildren(flatten, context); + } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/LambdaFunction.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/LambdaFunction.java new file mode 100644 index 000000000..e1ee755b8 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/LambdaFunction.java @@ -0,0 +1,48 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.ast.AbstractNodeVisitor; + +/** + * Expression node of lambda function. 
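+ * For example, the PPL expression `x -> x > 0` (as used in `filter(array, x -> x > 0)`) is a lambda whose single argument is `x` and whose body is `x > 0`.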
Params include function name (@funcName) and function + * arguments (@funcArgs) + */ +@Getter +@EqualsAndHashCode(callSuper = false) +@RequiredArgsConstructor +public class LambdaFunction extends UnresolvedExpression { + private final UnresolvedExpression function; + private final List funcArgs; + + @Override + public List getChild() { + List children = new ArrayList<>(); + children.add(function); + children.addAll(funcArgs); + return children; + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitLambdaFunction(this, context); + } + + @Override + public String toString() { + return String.format( + "(%s) -> %s", + funcArgs.stream().map(Object::toString).collect(Collectors.joining(", ")), + function.toString() + ); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/DescribeRelation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/DescribeRelation.java index b513d01bf..dd9947329 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/DescribeRelation.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/DescribeRelation.java @@ -8,12 +8,14 @@ import lombok.ToString; import org.opensearch.sql.ast.expression.UnresolvedExpression; +import java.util.Collections; + /** * Extend Relation to describe the table itself */ @ToString public class DescribeRelation extends Relation{ public DescribeRelation(UnresolvedExpression tableName) { - super(tableName); + super(Collections.singletonList(tableName)); } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Flatten.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Flatten.java new file mode 100644 index 000000000..e31fbb6e3 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Flatten.java @@ -0,0 +1,34 @@ +package org.opensearch.sql.ast.tree; + +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; +import org.opensearch.sql.ast.expression.Field; + +import java.util.List; + +@RequiredArgsConstructor +public class Flatten extends UnresolvedPlan { + + private UnresolvedPlan child; + + @Getter + private final Field fieldToBeFlattened; + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public List getChild() { + return child == null ? List.of() : List.of(child); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitFlatten(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Join.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Join.java index 89f787d34..176902911 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Join.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Join.java @@ -25,15 +25,15 @@ public class Join extends UnresolvedPlan { private UnresolvedPlan left; private final UnresolvedPlan right; - private final String leftAlias; - private final String rightAlias; + private final Optional leftAlias; + private final Optional rightAlias; private final JoinType joinType; private final Optional joinCondition; private final JoinHint joinHint; @Override public UnresolvedPlan attach(UnresolvedPlan child) { - this.left = new SubqueryAlias(leftAlias, child); + this.left = leftAlias.isEmpty() ? 
child : new SubqueryAlias(leftAlias.get(), child); return this; } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java index 1b30a7998..d8ea104a4 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java @@ -6,53 +6,34 @@ package org.opensearch.sql.ast.tree; import com.google.common.collect.ImmutableList; -import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; -import lombok.Setter; import lombok.ToString; import org.opensearch.sql.ast.AbstractNodeVisitor; import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.expression.UnresolvedExpression; -import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; /** Logical plan node of Relation, the interface for building the searching sources. */ -@AllArgsConstructor @ToString +@Getter @EqualsAndHashCode(callSuper = false) @RequiredArgsConstructor public class Relation extends UnresolvedPlan { private static final String COMMA = ","; - private final List tableName; - - public Relation(UnresolvedExpression tableName) { - this(tableName, null); - } - - public Relation(UnresolvedExpression tableName, String alias) { - this.tableName = Arrays.asList(tableName); - this.alias = alias; - } - - /** Optional alias name for the relation. */ - @Setter @Getter private String alias; - - /** - * Return table name. - * - * @return table name - */ - public List getTableName() { - return tableName.stream().map(Object::toString).collect(Collectors.toList()); - } + // A relation could contain more than one table/index names, such as + // source=account1, account2 + // source=`account1`,`account2` + // source=`account*` + // They translated into union call with fields. + private final List tableNames; public List getQualifiedNames() { - return tableName.stream().map(t -> (QualifiedName) t).collect(Collectors.toList()); + return tableNames.stream().map(t -> (QualifiedName) t).collect(Collectors.toList()); } /** @@ -63,11 +44,11 @@ public List getQualifiedNames() { * @return TableQualifiedName. 
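* e.g. `source = account1, account2` is preserved as the single comma-joined name `account1,account2`.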
*/ public QualifiedName getTableQualifiedName() { - if (tableName.size() == 1) { - return (QualifiedName) tableName.get(0); + if (tableNames.size() == 1) { + return (QualifiedName) tableNames.get(0); } else { return new QualifiedName( - tableName.stream() + tableNames.stream() .map(UnresolvedExpression::toString) .collect(Collectors.joining(COMMA))); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/SubqueryAlias.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/SubqueryAlias.java index 29c3d4b90..ba66cca80 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/SubqueryAlias.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/SubqueryAlias.java @@ -6,19 +6,14 @@ package org.opensearch.sql.ast.tree; import com.google.common.collect.ImmutableList; -import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; -import lombok.RequiredArgsConstructor; import lombok.ToString; import org.opensearch.sql.ast.AbstractNodeVisitor; import java.util.List; -import java.util.Objects; -@AllArgsConstructor @EqualsAndHashCode(callSuper = false) -@RequiredArgsConstructor @ToString public class SubqueryAlias extends UnresolvedPlan { @Getter private final String alias; @@ -32,6 +27,11 @@ public SubqueryAlias(UnresolvedPlan child, String suffix) { this.child = child; } + public SubqueryAlias(String alias, UnresolvedPlan child) { + this.alias = alias; + this.child = child; + } + public List getChild() { return ImmutableList.of(child); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java new file mode 100644 index 000000000..9fa1ae81d --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; +import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +import java.util.List; +import java.util.Optional; + +@ToString +@Getter +@RequiredArgsConstructor +@EqualsAndHashCode(callSuper = false) +public class Trendline extends UnresolvedPlan { + + private UnresolvedPlan child; + private final Optional sortByField; + private final List computations; + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public List getChild() { + return ImmutableList.of(child); + } + + @Override + public T accept(AbstractNodeVisitor visitor, C context) { + return visitor.visitTrendline(this, context); + } + + @Getter + public static class TrendlineComputation { + + private final Integer numberOfDataPoints; + private final UnresolvedExpression dataField; + private final String alias; + private final TrendlineType computationType; + + public TrendlineComputation(Integer numberOfDataPoints, UnresolvedExpression dataField, String alias, Trendline.TrendlineType computationType) { + this.numberOfDataPoints = numberOfDataPoints; + this.dataField = dataField; + this.alias = alias; + this.computationType = computationType; + } + + } + + public 
enum TrendlineType { + SMA + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index d81dc7ce4..1959d0f6d 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -213,6 +213,7 @@ public enum BuiltinFunctionName { JSON_OBJECT(FunctionName.of("json_object")), JSON_ARRAY(FunctionName.of("json_array")), JSON_ARRAY_LENGTH(FunctionName.of("json_array_length")), + TO_JSON_STRING(FunctionName.of("to_json_string")), JSON_EXTRACT(FunctionName.of("json_extract")), JSON_KEYS(FunctionName.of("json_keys")), JSON_VALID(FunctionName.of("json_valid")), @@ -228,6 +229,14 @@ public enum BuiltinFunctionName { /** COLLECTION Functions **/ ARRAY(FunctionName.of("array")), + ARRAY_LENGTH(FunctionName.of("array_length")), + + /** LAMBDA Functions **/ + ARRAY_FORALL(FunctionName.of("forall")), + ARRAY_EXISTS(FunctionName.of("exists")), + ARRAY_FILTER(FunctionName.of("filter")), + ARRAY_TRANSFORM(FunctionName.of("transform")), + ARRAY_AGGREGATE(FunctionName.of("reduce")), /** NULL Test. */ IS_NULL(FunctionName.of("is null")), diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java new file mode 100644 index 000000000..4c8d117b3 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java @@ -0,0 +1,475 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl; + +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute; +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$; +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; +import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; +import org.apache.spark.sql.catalyst.expressions.CaseWhen; +import org.apache.spark.sql.catalyst.expressions.CurrentRow$; +import org.apache.spark.sql.catalyst.expressions.Exists$; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual; +import org.apache.spark.sql.catalyst.expressions.In$; +import org.apache.spark.sql.catalyst.expressions.InSubquery$; +import org.apache.spark.sql.catalyst.expressions.LambdaFunction$; +import org.apache.spark.sql.catalyst.expressions.LessThan; +import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual; +import org.apache.spark.sql.catalyst.expressions.ListQuery$; +import org.apache.spark.sql.catalyst.expressions.MakeInterval$; +import org.apache.spark.sql.catalyst.expressions.NamedExpression; +import org.apache.spark.sql.catalyst.expressions.Predicate; +import org.apache.spark.sql.catalyst.expressions.RowFrame$; +import org.apache.spark.sql.catalyst.expressions.ScalaUDF; +import org.apache.spark.sql.catalyst.expressions.ScalarSubquery$; +import org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable; +import org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable$; +import org.apache.spark.sql.catalyst.expressions.SpecifiedWindowFrame; +import org.apache.spark.sql.catalyst.expressions.WindowExpression; +import 
org.apache.spark.sql.catalyst.expressions.WindowSpecDefinition; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.types.DataTypes; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.AggregateFunction; +import org.opensearch.sql.ast.expression.Alias; +import org.opensearch.sql.ast.expression.AllFields; +import org.opensearch.sql.ast.expression.And; +import org.opensearch.sql.ast.expression.Between; +import org.opensearch.sql.ast.expression.BinaryExpression; +import org.opensearch.sql.ast.expression.Case; +import org.opensearch.sql.ast.expression.Compare; +import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.FieldsMapping; +import org.opensearch.sql.ast.expression.Function; +import org.opensearch.sql.ast.expression.In; +import org.opensearch.sql.ast.expression.Interval; +import org.opensearch.sql.ast.expression.IsEmpty; +import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.ast.expression.Not; +import org.opensearch.sql.ast.expression.Or; +import org.opensearch.sql.ast.expression.LambdaFunction; +import org.opensearch.sql.ast.expression.QualifiedName; +import org.opensearch.sql.ast.expression.Span; +import org.opensearch.sql.ast.expression.UnresolvedExpression; +import org.opensearch.sql.ast.expression.When; +import org.opensearch.sql.ast.expression.WindowFunction; +import org.opensearch.sql.ast.expression.Xor; +import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; +import org.opensearch.sql.ast.expression.subquery.InSubquery; +import org.opensearch.sql.ast.expression.subquery.ScalarSubquery; +import org.opensearch.sql.ast.tree.Dedupe; +import org.opensearch.sql.ast.tree.Eval; +import org.opensearch.sql.ast.tree.FillNull; +import org.opensearch.sql.ast.tree.Kmeans; +import org.opensearch.sql.ast.tree.RareTopN; +import org.opensearch.sql.ast.tree.Trendline; +import org.opensearch.sql.ast.tree.UnresolvedPlan; +import org.opensearch.sql.expression.function.BuiltinFunctionName; +import org.opensearch.sql.expression.function.SerializableUdf; +import org.opensearch.sql.ppl.utils.AggregatorTransformer; +import org.opensearch.sql.ppl.utils.BuiltinFunctionTransformer; +import org.opensearch.sql.ppl.utils.ComparatorTransformer; +import org.opensearch.sql.ppl.utils.JavaToScalaTransformer; +import scala.Option; +import scala.PartialFunction; +import scala.Tuple2; +import scala.collection.Seq; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.Stack; +import java.util.function.BiFunction; +import java.util.stream.Collectors; + +import static java.util.Collections.emptyList; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.EQUAL; +import static org.opensearch.sql.ppl.CatalystPlanContext.findRelation; +import static org.opensearch.sql.ppl.utils.BuiltinFunctionTransformer.createIntervalArgs; +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.translate; +import static org.opensearch.sql.ppl.utils.RelationUtils.resolveField; +import static org.opensearch.sql.ppl.utils.WindowSpecTransformer.window; + +/** + * Class of building catalyst AST Expression nodes. 
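+ * Plan-level constructs nested inside expressions (scalar, exists and in-subqueries) are delegated back to the plan visitor supplied at construction time.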
+ */ +public class CatalystExpressionVisitor extends AbstractNodeVisitor { + + private final AbstractNodeVisitor planVisitor; + + public CatalystExpressionVisitor(AbstractNodeVisitor planVisitor) { + this.planVisitor = planVisitor; + } + + public Expression analyze(UnresolvedExpression unresolved, CatalystPlanContext context) { + return unresolved.accept(this, context); + } + + /** This method is only for analyze the join condition expression */ + public Expression analyzeJoinCondition(UnresolvedExpression unresolved, CatalystPlanContext context) { + return context.resolveJoinCondition(unresolved, this::analyze); + } + + @Override + public Expression visitLiteral(Literal node, CatalystPlanContext context) { + return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Literal( + translate(node.getValue(), node.getType()), translate(node.getType()))); + } + + /** + * generic binary (And, Or, Xor , ...) arithmetic expression resolver + * + * @param node + * @param transformer + * @param context + * @return + */ + public Expression visitBinaryArithmetic(BinaryExpression node, BiFunction transformer, CatalystPlanContext context) { + node.getLeft().accept(this, context); + Optional left = context.popNamedParseExpressions(); + node.getRight().accept(this, context); + Optional right = context.popNamedParseExpressions(); + if (left.isPresent() && right.isPresent()) { + return transformer.apply(left.get(), right.get()); + } else if (left.isPresent()) { + return context.getNamedParseExpressions().push(left.get()); + } else if (right.isPresent()) { + return context.getNamedParseExpressions().push(right.get()); + } + return null; + + } + + @Override + public Expression visitAnd(And node, CatalystPlanContext context) { + return visitBinaryArithmetic(node, + (left, right) -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.And(left, right)), context); + } + + @Override + public Expression visitOr(Or node, CatalystPlanContext context) { + return visitBinaryArithmetic(node, + (left, right) -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Or(left, right)), context); + } + + @Override + public Expression visitXor(Xor node, CatalystPlanContext context) { + return visitBinaryArithmetic(node, + (left, right) -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.BitwiseXor(left, right)), context); + } + + @Override + public Expression visitNot(Not node, CatalystPlanContext context) { + node.getExpression().accept(this, context); + Optional arg = context.popNamedParseExpressions(); + return arg.map(expression -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Not(expression))).orElse(null); + } + + @Override + public Expression visitSpan(Span node, CatalystPlanContext context) { + node.getField().accept(this, context); + Expression field = (Expression) context.popNamedParseExpressions().get(); + node.getValue().accept(this, context); + Expression value = (Expression) context.popNamedParseExpressions().get(); + return context.getNamedParseExpressions().push(window(field, value, node.getUnit())); + } + + @Override + public Expression visitAggregateFunction(AggregateFunction node, CatalystPlanContext context) { + node.getField().accept(this, context); + Expression arg = (Expression) context.popNamedParseExpressions().get(); + Expression aggregator = AggregatorTransformer.aggregator(node, arg); + return 
context.getNamedParseExpressions().push(aggregator); + } + + @Override + public Expression visitCompare(Compare node, CatalystPlanContext context) { + analyze(node.getLeft(), context); + Optional left = context.popNamedParseExpressions(); + analyze(node.getRight(), context); + Optional right = context.popNamedParseExpressions(); + if (left.isPresent() && right.isPresent()) { + Predicate comparator = ComparatorTransformer.comparator(node, left.get(), right.get()); + return context.getNamedParseExpressions().push((org.apache.spark.sql.catalyst.expressions.Expression) comparator); + } + return null; + } + + @Override + public Expression visitQualifiedName(QualifiedName node, CatalystPlanContext context) { + // When the qualified name is part of join condition, for example: table1.id = table2.id + // findRelation(context.traversalContext() only returns relation table1 which cause table2.id fail to resolve + if (context.isResolvingJoinCondition()) { + return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts()))); + } + List relation = findRelation(context.traversalContext()); + if (!relation.isEmpty()) { + Optional resolveField = resolveField(relation, node, context.getRelations()); + return resolveField.map(qualifiedName -> context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(qualifiedName.getParts())))) + .orElse(resolveQualifiedNameWithSubqueryAlias(node, context)); + } + return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts()))); + } + + /** + * Resolve the qualified name with subquery alias:
+ * - subqueryAlias1.joinKey = subqueryAlias2.joinKey
+ * - tableName1.joinKey = subqueryAlias2.joinKey
+ * - subqueryAlias1.joinKey = tableName2.joinKey
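+ * Returns an unresolved attribute when the qualifier matches a known subquery alias or relation name, otherwise null.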
+ */ + private Expression resolveQualifiedNameWithSubqueryAlias(QualifiedName node, CatalystPlanContext context) { + if (node.getPrefix().isPresent() && + context.traversalContext().peek() instanceof org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias) { + if (context.getSubqueryAlias().stream().map(p -> (org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias) p) + .anyMatch(a -> a.alias().equalsIgnoreCase(node.getPrefix().get().toString()))) { + return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts()))); + } else if (context.getRelations().stream().map(p -> (UnresolvedRelation) p) + .anyMatch(a -> a.tableName().equalsIgnoreCase(node.getPrefix().get().toString()))) { + return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts()))); + } + } + return null; + } + + @Override + public Expression visitCorrelationMapping(FieldsMapping node, CatalystPlanContext context) { + return node.getChild().stream().map(expression -> + visitCompare((Compare) expression, context) + ).reduce(org.apache.spark.sql.catalyst.expressions.And::new).orElse(null); + } + + @Override + public Expression visitAllFields(AllFields node, CatalystPlanContext context) { + context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); + return context.getNamedParseExpressions().peek(); + } + + @Override + public Expression visitAlias(Alias node, CatalystPlanContext context) { + node.getDelegated().accept(this, context); + Expression arg = context.popNamedParseExpressions().get(); + return context.getNamedParseExpressions().push( + org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(arg, + node.getAlias() != null ? node.getAlias() : node.getName(), + NamedExpression.newExprId(), + seq(new java.util.ArrayList()), + Option.empty(), + seq(new java.util.ArrayList()))); + } + + @Override + public Expression visitEval(Eval node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : Eval"); + } + + @Override + public Expression visitFunction(Function node, CatalystPlanContext context) { + List arguments = + node.getFuncArgs().stream() + .map( + unresolvedExpression -> { + var ret = analyze(unresolvedExpression, context); + if (ret == null) { + throw new UnsupportedOperationException( + String.format("Invalid use of expression %s", unresolvedExpression)); + } else { + return context.popNamedParseExpressions().get(); + } + }) + .collect(Collectors.toList()); + Expression function = BuiltinFunctionTransformer.builtinFunction(node, arguments); + return context.getNamedParseExpressions().push(function); + } + + @Override + public Expression visitIsEmpty(IsEmpty node, CatalystPlanContext context) { + Stack namedParseExpressions = new Stack<>(); + namedParseExpressions.addAll(context.getNamedParseExpressions()); + Expression expression = visitCase(node.getCaseValue(), context); + namedParseExpressions.add(expression); + context.setNamedParseExpressions(namedParseExpressions); + return expression; + } + + @Override + public Expression visitFillNull(FillNull fillNull, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : FillNull"); + } + + @Override + public Expression visitInterval(Interval node, CatalystPlanContext context) { + node.getValue().accept(this, context); + Expression value = context.getNamedParseExpressions().pop(); + Expression[] intervalArgs = createIntervalArgs(node.getUnit(), value); + Expression interval = 
MakeInterval$.MODULE$.apply( + intervalArgs[0], intervalArgs[1], intervalArgs[2], intervalArgs[3], + intervalArgs[4], intervalArgs[5], intervalArgs[6], true); + return context.getNamedParseExpressions().push(interval); + } + + @Override + public Expression visitDedupe(Dedupe node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : Dedupe"); + } + + @Override + public Expression visitIn(In node, CatalystPlanContext context) { + node.getField().accept(this, context); + Expression value = context.popNamedParseExpressions().get(); + List list = node.getValueList().stream().map( expression -> { + expression.accept(this, context); + return context.popNamedParseExpressions().get(); + }).collect(Collectors.toList()); + return context.getNamedParseExpressions().push(In$.MODULE$.apply(value, seq(list))); + } + + @Override + public Expression visitKmeans(Kmeans node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : Kmeans"); + } + + @Override + public Expression visitCase(Case node, CatalystPlanContext context) { + Stack initialNameExpressions = new Stack<>(); + initialNameExpressions.addAll(context.getNamedParseExpressions()); + analyze(node.getElseClause(), context); + Expression elseValue = context.getNamedParseExpressions().pop(); + List> whens = new ArrayList<>(); + for (When when : node.getWhenClauses()) { + if (node.getCaseValue() == null) { + whens.add( + new Tuple2<>( + analyze(when.getCondition(), context), + analyze(when.getResult(), context) + ) + ); + } else { + // Merge case value and condition (compare value) into a single equal condition + Compare compare = new Compare(EQUAL.getName().getFunctionName(), node.getCaseValue(), when.getCondition()); + whens.add( + new Tuple2<>( + analyze(compare, context), analyze(when.getResult(), context) + ) + ); + } + context.retainAllNamedParseExpressions(e -> e); + } + context.setNamedParseExpressions(initialNameExpressions); + return context.getNamedParseExpressions().push(new CaseWhen(seq(whens), Option.apply(elseValue))); + } + + @Override + public Expression visitRareTopN(RareTopN node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : RareTopN"); + } + + @Override + public Expression visitWindowFunction(WindowFunction node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : WindowFunction"); + } + + @Override + public Expression visitInSubquery(InSubquery node, CatalystPlanContext outerContext) { + CatalystPlanContext innerContext = new CatalystPlanContext(); + visitExpressionList(node.getChild(), innerContext); + Seq values = innerContext.retainAllNamedParseExpressions(p -> p); + UnresolvedPlan outerPlan = node.getQuery(); + LogicalPlan subSearch = outerPlan.accept(planVisitor, innerContext); + Expression inSubQuery = InSubquery$.MODULE$.apply( + values, + ListQuery$.MODULE$.apply( + subSearch, + seq(new java.util.ArrayList()), + NamedExpression.newExprId(), + -1, + seq(new java.util.ArrayList()), + Option.empty())); + return outerContext.getNamedParseExpressions().push(inSubQuery); + } + + @Override + public Expression visitScalarSubquery(ScalarSubquery node, CatalystPlanContext context) { + CatalystPlanContext innerContext = new CatalystPlanContext(); + UnresolvedPlan outerPlan = node.getQuery(); + LogicalPlan subSearch = outerPlan.accept(planVisitor, innerContext); + Expression scalarSubQuery = ScalarSubquery$.MODULE$.apply( + subSearch, + seq(new java.util.ArrayList()), + 
NamedExpression.newExprId(), + seq(new java.util.ArrayList()), + Option.empty(), + Option.empty()); + return context.getNamedParseExpressions().push(scalarSubQuery); + } + + @Override + public Expression visitExistsSubquery(ExistsSubquery node, CatalystPlanContext context) { + CatalystPlanContext innerContext = new CatalystPlanContext(); + UnresolvedPlan outerPlan = node.getQuery(); + LogicalPlan subSearch = outerPlan.accept(planVisitor, innerContext); + Expression existsSubQuery = Exists$.MODULE$.apply( + subSearch, + seq(new java.util.ArrayList()), + NamedExpression.newExprId(), + seq(new java.util.ArrayList()), + Option.empty()); + return context.getNamedParseExpressions().push(existsSubQuery); + } + + @Override + public Expression visitBetween(Between node, CatalystPlanContext context) { + Expression value = analyze(node.getValue(), context); + Expression lower = analyze(node.getLowerBound(), context); + Expression upper = analyze(node.getUpperBound(), context); + context.retainAllNamedParseExpressions(p -> p); + return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.And(new GreaterThanOrEqual(value, lower), new LessThanOrEqual(value, upper))); + } + + @Override + public Expression visitCidr(org.opensearch.sql.ast.expression.Cidr node, CatalystPlanContext context) { + analyze(node.getIpAddress(), context); + Expression ipAddressExpression = context.getNamedParseExpressions().pop(); + analyze(node.getCidrBlock(), context); + Expression cidrBlockExpression = context.getNamedParseExpressions().pop(); + + ScalaUDF udf = new ScalaUDF(SerializableUdf.cidrFunction, + DataTypes.BooleanType, + seq(ipAddressExpression,cidrBlockExpression), + seq(), + Option.empty(), + Option.apply("cidr"), + false, + true); + + return context.getNamedParseExpressions().push(udf); + } + + @Override + public Expression visitLambdaFunction(LambdaFunction node, CatalystPlanContext context) { + PartialFunction transformer = JavaToScalaTransformer.toPartialFunction( + expr -> expr instanceof UnresolvedAttribute, + expr -> { + UnresolvedAttribute attr = (UnresolvedAttribute) expr; + return new UnresolvedNamedLambdaVariable(attr.nameParts()); + } + ); + Expression functionResult = node.getFunction().accept(this, context).transformUp(transformer); + context.popNamedParseExpressions(); + List argsResult = node.getFuncArgs().stream() + .map(arg -> UnresolvedNamedLambdaVariable$.MODULE$.apply(seq(arg.getParts()))) + .collect(Collectors.toList()); + return context.getNamedParseExpressions().push(LambdaFunction$.MODULE$.apply(functionResult, seq(argsResult), false)); + } + + private List visitExpressionList(List expressionList, CatalystPlanContext context) { + return expressionList.isEmpty() + ? 
emptyList() + : expressionList.stream().map(field -> analyze(field, context)) + .collect(Collectors.toList()); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java index 61762f616..53dc17576 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java @@ -5,6 +5,7 @@ package org.opensearch.sql.ppl; +import lombok.Getter; import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; import org.apache.spark.sql.catalyst.expressions.AttributeReference; import org.apache.spark.sql.catalyst.expressions.Expression; @@ -39,19 +40,19 @@ public class CatalystPlanContext { /** * Catalyst relations list **/ - private List projectedFields = new ArrayList<>(); + @Getter private List projectedFields = new ArrayList<>(); /** * Catalyst relations list **/ - private List relations = new ArrayList<>(); + @Getter private List relations = new ArrayList<>(); /** * Catalyst SubqueryAlias list **/ - private List subqueryAlias = new ArrayList<>(); + @Getter private List subqueryAlias = new ArrayList<>(); /** * Catalyst evolving logical plan **/ - private Stack planBranches = new Stack<>(); + @Getter private Stack planBranches = new Stack<>(); /** * The current traversal context the visitor is going threw */ @@ -60,28 +61,12 @@ public class CatalystPlanContext { /** * NamedExpression contextual parameters **/ - private final Stack namedParseExpressions = new Stack<>(); + @Getter private final Stack namedParseExpressions = new Stack<>(); /** * Grouping NamedExpression contextual parameters **/ - private final Stack groupingParseExpressions = new Stack<>(); - - public Stack getPlanBranches() { - return planBranches; - } - - public List getRelations() { - return relations; - } - - public List getSubqueryAlias() { - return subqueryAlias; - } - - public List getProjectedFields() { - return projectedFields; - } + @Getter private final Stack groupingParseExpressions = new Stack<>(); public LogicalPlan getPlan() { if (this.planBranches.isEmpty()) return null; @@ -101,10 +86,6 @@ public Stack traversalContext() { return planTraversalContext; } - public Stack getNamedParseExpressions() { - return namedParseExpressions; - } - public void setNamedParseExpressions(Stack namedParseExpressions) { this.namedParseExpressions.clear(); this.namedParseExpressions.addAll(namedParseExpressions); @@ -114,10 +95,6 @@ public Optional popNamedParseExpressions() { return namedParseExpressions.isEmpty() ? Optional.empty() : Optional.of(namedParseExpressions.pop()); } - public Stack getGroupingParseExpressions() { - return groupingParseExpressions; - } - /** * define new field * @@ -154,13 +131,13 @@ public LogicalPlan withProjectedFields(List projectedField this.projectedFields.addAll(projectedFields); return getPlan(); } - + public LogicalPlan applyBranches(List> plans) { plans.forEach(plan -> with(plan.apply(planBranches.get(0)))); planBranches.remove(0); return getPlan(); - } - + } + /** * append plan with evolving plans branches * @@ -288,4 +265,21 @@ public static Optional findRelation(LogicalPlan plan) { return Optional.empty(); } + @Getter private boolean isResolvingJoinCondition = false; + + /** + * Resolve the join condition with the given function. + * A flag will be set to true ahead expression resolving, then false after resolving. 
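+ * While the flag is set, qualified names such as `table1.id` and `table2.id` in the condition are kept as plain unresolved attributes instead of being resolved against a single relation, so fields from both join sides survive to Spark's analyzer.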
+ * @param expr + * @param transformFunction + * @return + */ + public Expression resolveJoinCondition( + UnresolvedExpression expr, + BiFunction transformFunction) { + isResolvingJoinCondition = true; + Expression result = transformFunction.apply(expr, this); + isResolvingJoinCondition = false; + return result; + } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 90df01e66..a43378480 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -6,63 +6,47 @@ package org.opensearch.sql.ppl; import org.apache.spark.sql.catalyst.TableIdentifier; -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$; import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction; import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; import org.apache.spark.sql.catalyst.expressions.Ascending$; -import org.apache.spark.sql.catalyst.expressions.CaseWhen; import org.apache.spark.sql.catalyst.expressions.Descending$; -import org.apache.spark.sql.catalyst.expressions.Exists$; import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.GeneratorOuter; import org.apache.spark.sql.catalyst.expressions.In$; import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual; import org.apache.spark.sql.catalyst.expressions.InSubquery$; +import org.apache.spark.sql.catalyst.expressions.LessThan; import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual; import org.apache.spark.sql.catalyst.expressions.ListQuery$; import org.apache.spark.sql.catalyst.expressions.MakeInterval$; import org.apache.spark.sql.catalyst.expressions.NamedExpression; -import org.apache.spark.sql.catalyst.expressions.Predicate; -import org.apache.spark.sql.catalyst.expressions.ScalarSubquery$; -import org.apache.spark.sql.catalyst.expressions.ScalaUDF; import org.apache.spark.sql.catalyst.expressions.SortDirection; import org.apache.spark.sql.catalyst.expressions.SortOrder; -import org.apache.spark.sql.catalyst.plans.logical.*; +import org.apache.spark.sql.catalyst.plans.logical.Aggregate; +import org.apache.spark.sql.catalyst.plans.logical.DataFrameDropColumns$; +import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation$; +import org.apache.spark.sql.catalyst.plans.logical.Generate; +import org.apache.spark.sql.catalyst.plans.logical.Limit; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.catalyst.plans.logical.Project$; import org.apache.spark.sql.execution.ExplainMode; import org.apache.spark.sql.execution.command.DescribeTableCommand; import org.apache.spark.sql.execution.command.ExplainCommand; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.opensearch.flint.spark.FlattenGenerator; import org.opensearch.sql.ast.AbstractNodeVisitor; -import org.opensearch.sql.ast.expression.AggregateFunction; import org.opensearch.sql.ast.expression.Alias; -import org.opensearch.sql.ast.expression.AllFields; -import org.opensearch.sql.ast.expression.And; import org.opensearch.sql.ast.expression.Argument; -import org.opensearch.sql.ast.expression.Between; -import 
org.opensearch.sql.ast.expression.BinaryExpression; -import org.opensearch.sql.ast.expression.Case; -import org.opensearch.sql.ast.expression.Compare; import org.opensearch.sql.ast.expression.Field; -import org.opensearch.sql.ast.expression.FieldsMapping; import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.In; -import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; -import org.opensearch.sql.ast.expression.subquery.InSubquery; -import org.opensearch.sql.ast.expression.Interval; -import org.opensearch.sql.ast.expression.IsEmpty; import org.opensearch.sql.ast.expression.Let; import org.opensearch.sql.ast.expression.Literal; -import org.opensearch.sql.ast.expression.Not; -import org.opensearch.sql.ast.expression.Or; import org.opensearch.sql.ast.expression.ParseMethod; -import org.opensearch.sql.ast.expression.QualifiedName; -import org.opensearch.sql.ast.expression.subquery.ScalarSubquery; -import org.opensearch.sql.ast.expression.Span; import org.opensearch.sql.ast.expression.UnresolvedExpression; -import org.opensearch.sql.ast.expression.When; import org.opensearch.sql.ast.expression.WindowFunction; -import org.opensearch.sql.ast.expression.Xor; import org.opensearch.sql.ast.statement.Explain; import org.opensearch.sql.ast.statement.Query; import org.opensearch.sql.ast.statement.Statement; @@ -74,6 +58,7 @@ import org.opensearch.sql.ast.tree.FieldSummary; import org.opensearch.sql.ast.tree.FillNull; import org.opensearch.sql.ast.tree.Filter; +import org.opensearch.sql.ast.tree.Flatten; import org.opensearch.sql.ast.tree.Head; import org.opensearch.sql.ast.tree.Join; import org.opensearch.sql.ast.tree.Kmeans; @@ -87,19 +72,16 @@ import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.ast.tree.SubqueryAlias; import org.opensearch.sql.ast.tree.TopAggregation; -import org.opensearch.sql.ast.tree.UnresolvedPlan; +import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.ast.tree.Window; import org.opensearch.sql.common.antlr.SyntaxCheckException; -import org.opensearch.sql.expression.function.SerializableUdf; -import org.opensearch.sql.ppl.utils.AggregatorTranslator; -import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator; -import org.opensearch.sql.ppl.utils.ComparatorTransformer; import org.opensearch.sql.ppl.utils.FieldSummaryTransformer; -import org.opensearch.sql.ppl.utils.ParseStrategy; +import org.opensearch.sql.ppl.utils.ParseTransformer; import org.opensearch.sql.ppl.utils.SortUtils; +import org.opensearch.sql.ppl.utils.TrendlineCatalystUtils; import org.opensearch.sql.ppl.utils.WindowSpecTransformer; +import scala.None$; import scala.Option; -import scala.Tuple2; import scala.collection.IterableLike; import scala.collection.Seq; @@ -107,17 +89,11 @@ import java.util.List; import java.util.Objects; import java.util.Optional; -import java.util.Stack; -import java.util.function.BiFunction; import java.util.stream.Collectors; import static java.util.Collections.emptyList; import static java.util.List.of; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.EQUAL; -import static org.opensearch.sql.ppl.CatalystPlanContext.findRelation; -import static org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator.createIntervalArgs; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; -import static org.opensearch.sql.ppl.utils.DataTypeTransformer.translate; import static org.opensearch.sql.ppl.utils.DedupeTransformer.retainMultipleDuplicateEvents; import static 
org.opensearch.sql.ppl.utils.DedupeTransformer.retainMultipleDuplicateEventsAndKeepEmpty; import static org.opensearch.sql.ppl.utils.DedupeTransformer.retainOneDuplicateEvent; @@ -129,8 +105,6 @@ import static org.opensearch.sql.ppl.utils.LookupTransformer.buildOutputProjectList; import static org.opensearch.sql.ppl.utils.LookupTransformer.buildProjectListFromFields; import static org.opensearch.sql.ppl.utils.RelationUtils.getTableIdentifier; -import static org.opensearch.sql.ppl.utils.RelationUtils.resolveField; -import static org.opensearch.sql.ppl.utils.WindowSpecTransformer.window; import static scala.collection.JavaConverters.seqAsJavaList; /** @@ -138,20 +112,16 @@ */ public class CatalystQueryPlanVisitor extends AbstractNodeVisitor { - private final ExpressionAnalyzer expressionAnalyzer; + private final CatalystExpressionVisitor expressionAnalyzer; public CatalystQueryPlanVisitor() { - this.expressionAnalyzer = new ExpressionAnalyzer(); + this.expressionAnalyzer = new CatalystExpressionVisitor(this); } public LogicalPlan visit(Statement plan, CatalystPlanContext context) { return plan.accept(this, context); } - - public LogicalPlan visitSubSearch(UnresolvedPlan plan, CatalystPlanContext context) { - return plan.accept(this, context); - } - + /** * Handle Query Statement. */ @@ -256,6 +226,30 @@ public LogicalPlan visitLookup(Lookup node, CatalystPlanContext context) { }); } + @Override + public LogicalPlan visitTrendline(Trendline node, CatalystPlanContext context) { + node.getChild().get(0).accept(this, context); + + node.getSortByField() + .ifPresent(sortField -> { + Expression sortFieldExpression = visitExpression(sortField, context); + Seq sortOrder = context + .retainAllNamedParseExpressions(exp -> SortUtils.sortOrder(sortFieldExpression, SortUtils.isSortedAscending(sortField))); + context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Sort(sortOrder, true, p)); + }); + + List trendlineProjectExpressions = new ArrayList<>(); + + if (context.getNamedParseExpressions().isEmpty()) { + // Create an UnresolvedStar for all-fields projection + trendlineProjectExpressions.add(UnresolvedStar$.MODULE$.apply(Option.empty())); + } + + trendlineProjectExpressions.addAll(TrendlineCatalystUtils.visitTrendlineComputations(expressionAnalyzer, node.getComputations(), context)); + + return context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(seq(trendlineProjectExpressions), p)); + } + @Override public LogicalPlan visitCorrelation(Correlation node, CatalystPlanContext context) { node.getChild().get(0).accept(this, context); @@ -279,7 +273,8 @@ public LogicalPlan visitJoin(Join node, CatalystPlanContext context) { node.getChild().get(0).accept(this, context); return context.apply(left -> { LogicalPlan right = node.getRight().accept(this, context); - Optional joinCondition = node.getJoinCondition().map(c -> visitExpression(c, context)); + Optional joinCondition = node.getJoinCondition() + .map(c -> expressionAnalyzer.analyzeJoinCondition(c, context)); context.retainAllNamedParseExpressions(p -> p); context.retainAllPlans(p -> p); return join(left, right, node.getJoinType(), joinCondition, node.getJoinHint()); @@ -458,6 +453,20 @@ public LogicalPlan visitFillNull(FillNull fillNull, CatalystPlanContext context) return Objects.requireNonNull(resultWithoutDuplicatedColumns, "FillNull operation failed"); } + @Override + public LogicalPlan visitFlatten(Flatten flatten, CatalystPlanContext context) { + flatten.getChild().get(0).accept(this, context); + if 
(context.getNamedParseExpressions().isEmpty()) { + // Create an UnresolvedStar for all-fields projection + context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); + } + Expression field = visitExpression(flatten.getFieldToBeFlattened(), context); + context.retainAllNamedParseExpressions(p -> (NamedExpression) p); + FlattenGenerator flattenGenerator = new FlattenGenerator(field); + context.apply(p -> new Generate(new GeneratorOuter(flattenGenerator), seq(), true, (Option) None$.MODULE$, seq(), p)); + return context.apply(logicalPlan -> DataFrameDropColumns$.MODULE$.apply(seq(field), logicalPlan)); + } + private void visitFieldList(List fieldList, CatalystPlanContext context) { fieldList.forEach(field -> visitExpression(field, context)); } @@ -480,7 +489,7 @@ public LogicalPlan visitParse(Parse node, CatalystPlanContext context) { ParseMethod parseMethod = node.getParseMethod(); java.util.Map arguments = node.getArguments(); String pattern = (String) node.getPattern().getValue(); - return ParseStrategy.visitParseCommand(node, sourceField, parseMethod, arguments, pattern, context); + return ParseTransformer.visitParseCommand(node, sourceField, parseMethod, arguments, pattern, context); } @Override @@ -574,343 +583,4 @@ public LogicalPlan visitDedupe(Dedupe node, CatalystPlanContext context) { } } } - - /** - * Expression Analyzer. - */ - public class ExpressionAnalyzer extends AbstractNodeVisitor { - - public Expression analyze(UnresolvedExpression unresolved, CatalystPlanContext context) { - return unresolved.accept(this, context); - } - - @Override - public Expression visitLiteral(Literal node, CatalystPlanContext context) { - return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Literal( - translate(node.getValue(), node.getType()), translate(node.getType()))); - } - - /** - * generic binary (And, Or, Xor , ...) 
arithmetic expression resolver - * - * @param node - * @param transformer - * @param context - * @return - */ - public Expression visitBinaryArithmetic(BinaryExpression node, BiFunction transformer, CatalystPlanContext context) { - node.getLeft().accept(this, context); - Optional left = context.popNamedParseExpressions(); - node.getRight().accept(this, context); - Optional right = context.popNamedParseExpressions(); - if (left.isPresent() && right.isPresent()) { - return transformer.apply(left.get(), right.get()); - } else if (left.isPresent()) { - return context.getNamedParseExpressions().push(left.get()); - } else if (right.isPresent()) { - return context.getNamedParseExpressions().push(right.get()); - } - return null; - - } - - @Override - public Expression visitAnd(And node, CatalystPlanContext context) { - return visitBinaryArithmetic(node, - (left, right) -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.And(left, right)), context); - } - - @Override - public Expression visitOr(Or node, CatalystPlanContext context) { - return visitBinaryArithmetic(node, - (left, right) -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Or(left, right)), context); - } - - @Override - public Expression visitXor(Xor node, CatalystPlanContext context) { - return visitBinaryArithmetic(node, - (left, right) -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.BitwiseXor(left, right)), context); - } - - @Override - public Expression visitNot(Not node, CatalystPlanContext context) { - node.getExpression().accept(this, context); - Optional arg = context.popNamedParseExpressions(); - return arg.map(expression -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Not(expression))).orElse(null); - } - - @Override - public Expression visitSpan(Span node, CatalystPlanContext context) { - node.getField().accept(this, context); - Expression field = (Expression) context.popNamedParseExpressions().get(); - node.getValue().accept(this, context); - Expression value = (Expression) context.popNamedParseExpressions().get(); - return context.getNamedParseExpressions().push(window(field, value, node.getUnit())); - } - - @Override - public Expression visitAggregateFunction(AggregateFunction node, CatalystPlanContext context) { - node.getField().accept(this, context); - Expression arg = (Expression) context.popNamedParseExpressions().get(); - Expression aggregator = AggregatorTranslator.aggregator(node, arg); - return context.getNamedParseExpressions().push(aggregator); - } - - @Override - public Expression visitCompare(Compare node, CatalystPlanContext context) { - analyze(node.getLeft(), context); - Optional left = context.popNamedParseExpressions(); - analyze(node.getRight(), context); - Optional right = context.popNamedParseExpressions(); - if (left.isPresent() && right.isPresent()) { - Predicate comparator = ComparatorTransformer.comparator(node, left.get(), right.get()); - return context.getNamedParseExpressions().push((org.apache.spark.sql.catalyst.expressions.Expression) comparator); - } - return null; - } - - @Override - public Expression visitQualifiedName(QualifiedName node, CatalystPlanContext context) { - List relation = findRelation(context.traversalContext()); - if (!relation.isEmpty()) { - Optional resolveField = resolveField(relation, node, context.getRelations()); - return resolveField.map(qualifiedName -> 
context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(qualifiedName.getParts())))) - .orElse(resolveQualifiedNameWithSubqueryAlias(node, context)); - } - return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts()))); - } - - /** - * Resolve the qualified name with subquery alias:
- * - subqueryAlias1.joinKey = subqueryAlias2.joinKey
- * - tableName1.joinKey = subqueryAlias2.joinKey
- * - subqueryAlias1.joinKey = tableName2.joinKey
- */ - private Expression resolveQualifiedNameWithSubqueryAlias(QualifiedName node, CatalystPlanContext context) { - if (node.getPrefix().isPresent() && - context.traversalContext().peek() instanceof org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias) { - if (context.getSubqueryAlias().stream().map(p -> (org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias) p) - .anyMatch(a -> a.alias().equalsIgnoreCase(node.getPrefix().get().toString()))) { - return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts()))); - } else if (context.getRelations().stream().map(p -> (UnresolvedRelation) p) - .anyMatch(a -> a.tableName().equalsIgnoreCase(node.getPrefix().get().toString()))) { - return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts()))); - } - } - return null; - } - - @Override - public Expression visitCorrelationMapping(FieldsMapping node, CatalystPlanContext context) { - return node.getChild().stream().map(expression -> - visitCompare((Compare) expression, context) - ).reduce(org.apache.spark.sql.catalyst.expressions.And::new).orElse(null); - } - - @Override - public Expression visitAllFields(AllFields node, CatalystPlanContext context) { - context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); - return context.getNamedParseExpressions().peek(); - } - - @Override - public Expression visitAlias(Alias node, CatalystPlanContext context) { - node.getDelegated().accept(this, context); - Expression arg = context.popNamedParseExpressions().get(); - return context.getNamedParseExpressions().push( - org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(arg, - node.getAlias() != null ? node.getAlias() : node.getName(), - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty(), - seq(new java.util.ArrayList()))); - } - - @Override - public Expression visitEval(Eval node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : Eval"); - } - - @Override - public Expression visitFunction(Function node, CatalystPlanContext context) { - List arguments = - node.getFuncArgs().stream() - .map( - unresolvedExpression -> { - var ret = analyze(unresolvedExpression, context); - if (ret == null) { - throw new UnsupportedOperationException( - String.format("Invalid use of expression %s", unresolvedExpression)); - } else { - return context.popNamedParseExpressions().get(); - } - }) - .collect(Collectors.toList()); - Expression function = BuiltinFunctionTranslator.builtinFunction(node, arguments); - return context.getNamedParseExpressions().push(function); - } - - @Override - public Expression visitIsEmpty(IsEmpty node, CatalystPlanContext context) { - Stack namedParseExpressions = new Stack<>(); - namedParseExpressions.addAll(context.getNamedParseExpressions()); - Expression expression = visitCase(node.getCaseValue(), context); - namedParseExpressions.add(expression); - context.setNamedParseExpressions(namedParseExpressions); - return expression; - } - - @Override - public Expression visitFillNull(FillNull fillNull, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : FillNull"); - } - - @Override - public Expression visitInterval(Interval node, CatalystPlanContext context) { - node.getValue().accept(this, context); - Expression value = context.getNamedParseExpressions().pop(); - Expression[] intervalArgs = createIntervalArgs(node.getUnit(), value); - Expression interval = 
MakeInterval$.MODULE$.apply( - intervalArgs[0], intervalArgs[1], intervalArgs[2], intervalArgs[3], - intervalArgs[4], intervalArgs[5], intervalArgs[6], true); - return context.getNamedParseExpressions().push(interval); - } - - @Override - public Expression visitDedupe(Dedupe node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : Dedupe"); - } - - @Override - public Expression visitIn(In node, CatalystPlanContext context) { - node.getField().accept(this, context); - Expression value = context.popNamedParseExpressions().get(); - List list = node.getValueList().stream().map( expression -> { - expression.accept(this, context); - return context.popNamedParseExpressions().get(); - }).collect(Collectors.toList()); - return context.getNamedParseExpressions().push(In$.MODULE$.apply(value, seq(list))); - } - - @Override - public Expression visitKmeans(Kmeans node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : Kmeans"); - } - - @Override - public Expression visitCase(Case node, CatalystPlanContext context) { - Stack initialNameExpressions = new Stack<>(); - initialNameExpressions.addAll(context.getNamedParseExpressions()); - analyze(node.getElseClause(), context); - Expression elseValue = context.getNamedParseExpressions().pop(); - List> whens = new ArrayList<>(); - for (When when : node.getWhenClauses()) { - if (node.getCaseValue() == null) { - whens.add( - new Tuple2<>( - analyze(when.getCondition(), context), - analyze(when.getResult(), context) - ) - ); - } else { - // Merge case value and condition (compare value) into a single equal condition - Compare compare = new Compare(EQUAL.getName().getFunctionName(), node.getCaseValue(), when.getCondition()); - whens.add( - new Tuple2<>( - analyze(compare, context), analyze(when.getResult(), context) - ) - ); - } - context.retainAllNamedParseExpressions(e -> e); - } - context.setNamedParseExpressions(initialNameExpressions); - return context.getNamedParseExpressions().push(new CaseWhen(seq(whens), Option.apply(elseValue))); - } - - @Override - public Expression visitRareTopN(RareTopN node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : RareTopN"); - } - - @Override - public Expression visitWindowFunction(WindowFunction node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : WindowFunction"); - } - - @Override - public Expression visitInSubquery(InSubquery node, CatalystPlanContext outerContext) { - CatalystPlanContext innerContext = new CatalystPlanContext(); - visitExpressionList(node.getChild(), innerContext); - Seq values = innerContext.retainAllNamedParseExpressions(p -> p); - UnresolvedPlan outerPlan = node.getQuery(); - LogicalPlan subSearch = CatalystQueryPlanVisitor.this.visitSubSearch(outerPlan, innerContext); - Expression inSubQuery = InSubquery$.MODULE$.apply( - values, - ListQuery$.MODULE$.apply( - subSearch, - seq(new java.util.ArrayList()), - NamedExpression.newExprId(), - -1, - seq(new java.util.ArrayList()), - Option.empty())); - return outerContext.getNamedParseExpressions().push(inSubQuery); - } - - @Override - public Expression visitScalarSubquery(ScalarSubquery node, CatalystPlanContext context) { - CatalystPlanContext innerContext = new CatalystPlanContext(); - UnresolvedPlan outerPlan = node.getQuery(); - LogicalPlan subSearch = CatalystQueryPlanVisitor.this.visitSubSearch(outerPlan, innerContext); - Expression scalarSubQuery = ScalarSubquery$.MODULE$.apply( 
- subSearch, - seq(new java.util.ArrayList()), - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty(), - Option.empty()); - return context.getNamedParseExpressions().push(scalarSubQuery); - } - - @Override - public Expression visitExistsSubquery(ExistsSubquery node, CatalystPlanContext context) { - CatalystPlanContext innerContext = new CatalystPlanContext(); - UnresolvedPlan outerPlan = node.getQuery(); - LogicalPlan subSearch = CatalystQueryPlanVisitor.this.visitSubSearch(outerPlan, innerContext); - Expression existsSubQuery = Exists$.MODULE$.apply( - subSearch, - seq(new java.util.ArrayList()), - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty()); - return context.getNamedParseExpressions().push(existsSubQuery); - } - - @Override - public Expression visitBetween(Between node, CatalystPlanContext context) { - Expression value = analyze(node.getValue(), context); - Expression lower = analyze(node.getLowerBound(), context); - Expression upper = analyze(node.getUpperBound(), context); - context.retainAllNamedParseExpressions(p -> p); - return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.And(new GreaterThanOrEqual(value, lower), new LessThanOrEqual(value, upper))); - } - - @Override - public Expression visitCidr(org.opensearch.sql.ast.expression.Cidr node, CatalystPlanContext context) { - analyze(node.getIpAddress(), context); - Expression ipAddressExpression = context.getNamedParseExpressions().pop(); - analyze(node.getCidrBlock(), context); - Expression cidrBlockExpression = context.getNamedParseExpressions().pop(); - - ScalaUDF udf = new ScalaUDF(SerializableUdf.cidrFunction, - DataTypes.BooleanType, - seq(ipAddressExpression,cidrBlockExpression), - seq(), - Option.empty(), - Option.apply("cidr"), - false, - true); - - return context.getNamedParseExpressions().push(udf); - } - } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index ed7717188..36a34cd06 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -82,8 +82,8 @@ public class AstBuilder extends OpenSearchPPLParserBaseVisitor { */ private String query; - public AstBuilder(AstExpressionBuilder expressionBuilder, String query) { - this.expressionBuilder = expressionBuilder; + public AstBuilder(String query) { + this.expressionBuilder = new AstExpressionBuilder(this); this.query = query; } @@ -155,14 +155,25 @@ public UnresolvedPlan visitJoinCommand(OpenSearchPPLParser.JoinCommandContext ct joinType = Join.JoinType.CROSS; } Join.JoinHint joinHint = getJoinHint(ctx.joinHintList()); - String leftAlias = ctx.sideAlias().leftAlias.getText(); - String rightAlias = ctx.sideAlias().rightAlias.getText(); + Optional leftAlias = ctx.sideAlias().leftAlias != null ? Optional.of(ctx.sideAlias().leftAlias.getText()) : Optional.empty(); + Optional rightAlias = Optional.empty(); if (ctx.tableOrSubqueryClause().alias != null) { - // left and right aliases are required in join syntax. 
Setting by 'AS' causes ambiguous - throw new SyntaxCheckException("'AS' is not allowed in right subquery, use right= instead"); + rightAlias = Optional.of(ctx.tableOrSubqueryClause().alias.getText()); } + if (ctx.sideAlias().rightAlias != null) { + rightAlias = Optional.of(ctx.sideAlias().rightAlias.getText()); + } + UnresolvedPlan rightRelation = visit(ctx.tableOrSubqueryClause()); - UnresolvedPlan right = new SubqueryAlias(rightAlias, rightRelation); + // Add a SubqueryAlias to the right plan when the right alias is present and no duplicated alias existing in right. + UnresolvedPlan right; + if (rightAlias.isEmpty() || + (rightRelation instanceof SubqueryAlias && + rightAlias.get().equals(((SubqueryAlias) rightRelation).getAlias()))) { + right = rightRelation; + } else { + right = new SubqueryAlias(rightAlias.get(), rightRelation); + } Optional joinCondition = ctx.joinCriteria() == null ? Optional.empty() : Optional.of(expressionBuilder.visitJoinCriteria(ctx.joinCriteria())); @@ -370,7 +381,7 @@ public UnresolvedPlan visitPatternsCommand(OpenSearchPPLParser.PatternsCommandCo /** Lookup command */ @Override public UnresolvedPlan visitLookupCommand(OpenSearchPPLParser.LookupCommandContext ctx) { - Relation lookupRelation = new Relation(this.internalVisitExpression(ctx.tableSource())); + Relation lookupRelation = new Relation(Collections.singletonList(this.internalVisitExpression(ctx.tableSource()))); Lookup.OutputStrategy strategy = ctx.APPEND() != null ? Lookup.OutputStrategy.APPEND : Lookup.OutputStrategy.REPLACE; java.util.Map lookupMappingList = buildLookupPair(ctx.lookupMappingList().lookupPair()); @@ -386,6 +397,30 @@ private java.util.Map buildLookupPair(List (Alias) and.getLeft(), and -> (Field) and.getRight(), (x, y) -> y, LinkedHashMap::new)); } + @Override + public UnresolvedPlan visitTrendlineCommand(OpenSearchPPLParser.TrendlineCommandContext ctx) { + List trendlineComputations = ctx.trendlineClause() + .stream() + .map(this::toTrendlineComputation) + .collect(Collectors.toList()); + return Optional.ofNullable(ctx.sortField()) + .map(this::internalVisitExpression) + .map(Field.class::cast) + .map(sort -> new Trendline(Optional.of(sort), trendlineComputations)) + .orElse(new Trendline(Optional.empty(), trendlineComputations)); + } + + private Trendline.TrendlineComputation toTrendlineComputation(OpenSearchPPLParser.TrendlineClauseContext ctx) { + int numberOfDataPoints = Integer.parseInt(ctx.numberOfDataPoints.getText()); + if (numberOfDataPoints < 1) { + throw new SyntaxCheckException("Number of trendline data-points must be greater than or equal to 1"); + } + Field dataField = (Field) expressionBuilder.visitFieldExpression(ctx.field); + String alias = ctx.alias == null?dataField.getField().toString()+"_trendline":ctx.alias.getText(); + String computationType = ctx.trendlineType().getText(); + return new Trendline.TrendlineComputation(numberOfDataPoints, dataField, alias, Trendline.TrendlineType.valueOf(computationType.toUpperCase())); + } + /** Top command. */ @Override public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) { @@ -485,9 +520,8 @@ public UnresolvedPlan visitTableOrSubqueryClause(OpenSearchPPLParser.TableOrSubq @Override public UnresolvedPlan visitTableSourceClause(OpenSearchPPLParser.TableSourceClauseContext ctx) { - return ctx.alias == null - ? 
new Relation(ctx.tableSource().stream().map(this::internalVisitExpression).collect(Collectors.toList())) - : new Relation(ctx.tableSource().stream().map(this::internalVisitExpression).collect(Collectors.toList()), ctx.alias.getText()); + Relation relation = new Relation(ctx.tableSource().stream().map(this::internalVisitExpression).collect(Collectors.toList())); + return ctx.alias != null ? new SubqueryAlias(ctx.alias.getText(), relation) : relation; } @Override @@ -562,6 +596,12 @@ public UnresolvedPlan visitFillnullCommand(OpenSearchPPLParser.FillnullCommandCo } } + @Override + public UnresolvedPlan visitFlattenCommand(OpenSearchPPLParser.FlattenCommandContext ctx) { + Field unresolvedExpression = (Field) internalVisitExpression(ctx.fieldExpression()); + return new Flatten(unresolvedExpression); + } + /** AD command. */ @Override public UnresolvedPlan visitAdCommand(OpenSearchPPLParser.AdCommandContext ctx) { diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 6eb72c91e..4b7c8a1c1 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -29,6 +29,7 @@ import org.opensearch.sql.ast.expression.Interval; import org.opensearch.sql.ast.expression.IntervalUnit; import org.opensearch.sql.ast.expression.IsEmpty; +import org.opensearch.sql.ast.expression.LambdaFunction; import org.opensearch.sql.ast.expression.Let; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.Not; @@ -43,6 +44,8 @@ import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; import org.opensearch.sql.ast.expression.subquery.InSubquery; import org.opensearch.sql.ast.expression.subquery.ScalarSubquery; +import org.opensearch.sql.ast.tree.Trendline; +import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.ppl.utils.ArgumentFactory; @@ -67,7 +70,6 @@ */ public class AstExpressionBuilder extends OpenSearchPPLParserBaseVisitor { - private static final int DEFAULT_TAKE_FUNCTION_SIZE_VALUE = 10; /** * The function name mapping between fronted and core engine. 
*/ @@ -79,16 +81,10 @@ public class AstExpressionBuilder extends OpenSearchPPLParserBaseVisitor arguments = ctx.ident().stream().map(x -> this.visitIdentifiers(Collections.singletonList(x))).collect( + Collectors.toList()); + UnresolvedExpression function = visitExpression(ctx.expression()); + return new LambdaFunction(function, arguments); + } + private List timestampFunctionArguments( OpenSearchPPLParser.TimestampFunctionCallContext ctx) { List args = diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java index 6a545f091..44610f3a4 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java @@ -56,6 +56,7 @@ public StatementBuilderContext getContext() { } public static class StatementBuilderContext { + public static final int FETCH_SIZE = 1000; private int fetchSize; public StatementBuilderContext(int fetchSize) { @@ -63,8 +64,7 @@ public StatementBuilderContext(int fetchSize) { } public static StatementBuilderContext builder() { - //todo set the default statement builder init params configurable - return new StatementBuilderContext(1000); + return new StatementBuilderContext(FETCH_SIZE); } public int getFetchSize() { diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTransformer.java similarity index 97% rename from ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java rename to ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTransformer.java index a01b38a80..9788ac1bc 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTransformer.java @@ -12,12 +12,10 @@ import org.opensearch.sql.ast.expression.AggregateFunction; import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.DataType; -import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.expression.function.BuiltinFunctionName; import java.util.List; -import java.util.Optional; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; import static scala.Option.empty; @@ -27,7 +25,7 @@ * * @return */ -public interface AggregatorTranslator { +public interface AggregatorTransformer { static Expression aggregator(org.opensearch.sql.ast.expression.AggregateFunction aggregateFunction, Expression arg) { if (BuiltinFunctionName.ofAggregation(aggregateFunction.getFuncName()).isEmpty()) diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java similarity index 90% rename from ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java rename to ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java index d954b04b9..0b0fb8314 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java +++ 
b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java @@ -28,6 +28,7 @@ import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADD; import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADDDATE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.ARRAY_LENGTH; import static org.opensearch.sql.expression.function.BuiltinFunctionName.DATEDIFF; import static org.opensearch.sql.expression.function.BuiltinFunctionName.DATE_ADD; import static org.opensearch.sql.expression.function.BuiltinFunctionName.DATE_SUB; @@ -58,6 +59,7 @@ import static org.opensearch.sql.expression.function.BuiltinFunctionName.SYSDATE; import static org.opensearch.sql.expression.function.BuiltinFunctionName.TIMESTAMPADD; import static org.opensearch.sql.expression.function.BuiltinFunctionName.TIMESTAMPDIFF; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.TO_JSON_STRING; import static org.opensearch.sql.expression.function.BuiltinFunctionName.TRIM; import static org.opensearch.sql.expression.function.BuiltinFunctionName.UTC_TIMESTAMP; import static org.opensearch.sql.expression.function.BuiltinFunctionName.WEEK; @@ -65,7 +67,7 @@ import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; import static scala.Option.empty; -public interface BuiltinFunctionTranslator { +public interface BuiltinFunctionTransformer { /** * The name mapping between PPL builtin functions to Spark builtin functions. @@ -102,7 +104,9 @@ public interface BuiltinFunctionTranslator { .put(COALESCE, "coalesce") .put(LENGTH, "length") .put(TRIM, "trim") + .put(ARRAY_LENGTH, "array_size") // json functions + .put(TO_JSON_STRING, "to_json") .put(JSON_KEYS, "json_object_keys") .put(JSON_EXTRACT, "get_json_object") .build(); @@ -126,26 +130,12 @@ public interface BuiltinFunctionTranslator { .put( JSON_ARRAY_LENGTH, args -> { - // Check if the input is an array (from json_array()) or a JSON string - if (args.get(0) instanceof UnresolvedFunction) { - // Input is a JSON array - return UnresolvedFunction$.MODULE$.apply("json_array_length", - seq(UnresolvedFunction$.MODULE$.apply("to_json", seq(args), false)), false); - } else { - // Input is a JSON string - return UnresolvedFunction$.MODULE$.apply("json_array_length", seq(args.get(0)), false); - } + return UnresolvedFunction$.MODULE$.apply("json_array_length", seq(args.get(0)), false); }) .put( JSON, args -> { - // Check if the input is a named_struct (from json_object()) or a JSON string - if (args.get(0) instanceof UnresolvedFunction) { - return UnresolvedFunction$.MODULE$.apply("to_json", seq(args.get(0)), false); - } else { - return UnresolvedFunction$.MODULE$.apply("get_json_object", - seq(args.get(0), Literal$.MODULE$.apply("$")), false); - } + return UnresolvedFunction$.MODULE$.apply("get_json_object", seq(args.get(0), Literal$.MODULE$.apply("$")), false); }) .put( JSON_VALID, diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java index 62eef90ed..e4defad52 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java @@ -14,16 +14,14 @@ import org.apache.spark.sql.types.FloatType$; import org.apache.spark.sql.types.IntegerType$; import org.apache.spark.sql.types.LongType$; +import 
org.apache.spark.sql.types.NullType$; import org.apache.spark.sql.types.ShortType$; import org.apache.spark.sql.types.StringType$; import org.apache.spark.unsafe.types.UTF8String; import org.opensearch.sql.ast.expression.SpanUnit; import scala.collection.mutable.Seq; -import java.util.Arrays; import java.util.List; -import java.util.Objects; -import java.util.stream.Collectors; import static org.opensearch.sql.ast.expression.SpanUnit.DAY; import static org.opensearch.sql.ast.expression.SpanUnit.HOUR; @@ -67,6 +65,8 @@ static DataType translate(org.opensearch.sql.ast.expression.DataType source) { return ShortType$.MODULE$; case BYTE: return ByteType$.MODULE$; + case UNDEFINED: + return NullType$.MODULE$; default: return StringType$.MODULE$; } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DedupeTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DedupeTransformer.java index 0866ca7e9..a34fc0184 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DedupeTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DedupeTransformer.java @@ -16,8 +16,8 @@ import org.apache.spark.sql.catalyst.plans.logical.Union; import org.apache.spark.sql.types.DataTypes; import org.opensearch.sql.ast.tree.Dedupe; +import org.opensearch.sql.ppl.CatalystExpressionVisitor; import org.opensearch.sql.ppl.CatalystPlanContext; -import org.opensearch.sql.ppl.CatalystQueryPlanVisitor.ExpressionAnalyzer; import scala.collection.Seq; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; @@ -38,7 +38,7 @@ public interface DedupeTransformer { static LogicalPlan retainOneDuplicateEventAndKeepEmpty( Dedupe node, Seq dedupeFields, - ExpressionAnalyzer expressionAnalyzer, + CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { context.apply(p -> { Expression isNullExpr = buildIsNullFilterExpression(node, expressionAnalyzer, context); @@ -63,7 +63,7 @@ static LogicalPlan retainOneDuplicateEventAndKeepEmpty( static LogicalPlan retainOneDuplicateEvent( Dedupe node, Seq dedupeFields, - ExpressionAnalyzer expressionAnalyzer, + CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { Expression isNotNullExpr = buildIsNotNullFilterExpression(node, expressionAnalyzer, context); context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Filter(isNotNullExpr, p)); @@ -87,7 +87,7 @@ static LogicalPlan retainOneDuplicateEvent( static LogicalPlan retainMultipleDuplicateEventsAndKeepEmpty( Dedupe node, Integer allowedDuplication, - ExpressionAnalyzer expressionAnalyzer, + CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { context.apply(p -> { // Build isnull Filter for right @@ -137,7 +137,7 @@ static LogicalPlan retainMultipleDuplicateEventsAndKeepEmpty( static LogicalPlan retainMultipleDuplicateEvents( Dedupe node, Integer allowedDuplication, - ExpressionAnalyzer expressionAnalyzer, + CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { // Build isnotnull Filter Expression isNotNullExpr = buildIsNotNullFilterExpression(node, expressionAnalyzer, context); @@ -163,7 +163,7 @@ static LogicalPlan retainMultipleDuplicateEvents( return context.apply(p -> new DataFrameDropColumns(seq(rowNumber.toAttribute()), p)); } - private static Expression buildIsNotNullFilterExpression(Dedupe node, ExpressionAnalyzer expressionAnalyzer, CatalystPlanContext context) { + private static Expression 
buildIsNotNullFilterExpression(Dedupe node, CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { node.getFields().forEach(field -> expressionAnalyzer.analyze(field, context)); Seq isNotNullExpressions = context.retainAllNamedParseExpressions( @@ -180,7 +180,7 @@ private static Expression buildIsNotNullFilterExpression(Dedupe node, Expression return isNotNullExpr; } - private static Expression buildIsNullFilterExpression(Dedupe node, ExpressionAnalyzer expressionAnalyzer, CatalystPlanContext context) { + private static Expression buildIsNullFilterExpression(Dedupe node, CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { node.getFields().forEach(field -> expressionAnalyzer.analyze(field, context)); Seq isNullExpressions = context.retainAllNamedParseExpressions( diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/JavaToScalaTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/JavaToScalaTransformer.java new file mode 100644 index 000000000..40246d7c9 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/JavaToScalaTransformer.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.utils; + + +import scala.PartialFunction; +import scala.runtime.AbstractPartialFunction; + +public interface JavaToScalaTransformer { + static PartialFunction toPartialFunction( + java.util.function.Predicate isDefinedAt, + java.util.function.Function apply) { + return new AbstractPartialFunction() { + @Override + public boolean isDefinedAt(T t) { + return isDefinedAt.test(t); + } + + @Override + public T apply(T t) { + if (isDefinedAt.test(t)) return apply.apply(t); + else return t; + } + }; + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/LookupTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/LookupTransformer.java index 58ef15ea9..3673d96d6 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/LookupTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/LookupTransformer.java @@ -15,6 +15,7 @@ import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.tree.Lookup; +import org.opensearch.sql.ppl.CatalystExpressionVisitor; import org.opensearch.sql.ppl.CatalystPlanContext; import org.opensearch.sql.ppl.CatalystQueryPlanVisitor; import scala.Option; @@ -32,7 +33,7 @@ public interface LookupTransformer { /** lookup mapping fields + input fields*/ static List buildLookupRelationProjectList( Lookup node, - CatalystQueryPlanVisitor.ExpressionAnalyzer expressionAnalyzer, + CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { List inputFields = new ArrayList<>(node.getInputFieldList()); if (inputFields.isEmpty()) { @@ -45,7 +46,7 @@ static List buildLookupRelationProjectList( static List buildProjectListFromFields( List fields, - CatalystQueryPlanVisitor.ExpressionAnalyzer expressionAnalyzer, + CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { return fields.stream().map(field -> expressionAnalyzer.visitField(field, context)) .map(NamedExpression.class::cast) @@ -54,7 +55,7 @@ static List buildProjectListFromFields( static Expression buildLookupMappingCondition( Lookup node, - CatalystQueryPlanVisitor.ExpressionAnalyzer expressionAnalyzer, + 
CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { // only equi-join conditions are accepted in lookup command List equiConditions = new ArrayList<>(); @@ -81,7 +82,7 @@ static Expression buildLookupMappingCondition( static List buildOutputProjectList( Lookup node, Lookup.OutputStrategy strategy, - CatalystQueryPlanVisitor.ExpressionAnalyzer expressionAnalyzer, + CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { List outputProjectList = new ArrayList<>(); for (Map.Entry entry : node.getOutputCandidateMap().entrySet()) { diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseStrategy.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseTransformer.java similarity index 97% rename from ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseStrategy.java rename to ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseTransformer.java index 8775d077b..eed7db228 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseStrategy.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseTransformer.java @@ -8,7 +8,6 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.expressions.NamedExpression; -import org.apache.spark.sql.catalyst.expressions.RegExpExtract; import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; import org.opensearch.sql.ast.expression.AllFields; import org.opensearch.sql.ast.expression.Field; @@ -27,7 +26,7 @@ import static org.apache.spark.sql.types.DataTypes.StringType; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; -public interface ParseStrategy { +public interface ParseTransformer { /** * transform the parse/grok/patterns command into a standard catalyst RegExpExtract expression * Since spark's RegExpExtract cant accept actual regExp group name we need to translate the group's name into its corresponding index diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/SortUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/SortUtils.java index 83603b031..803daea8b 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/SortUtils.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/SortUtils.java @@ -38,7 +38,7 @@ static SortOrder getSortDirection(Sort node, NamedExpression expression) { .findAny(); return field.map(value -> sortOrder((Expression) expression, - (Boolean) value.getFieldArgs().get(0).getValue().getValue())) + isSortedAscending(value))) .orElse(null); } @@ -51,4 +51,8 @@ static SortOrder sortOrder(Expression expression, boolean ascending) { seq(new ArrayList()) ); } + + static boolean isSortedAscending(Field field) { + return (Boolean) field.getFieldArgs().get(0).getValue().getValue(); + } } \ No newline at end of file diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java new file mode 100644 index 000000000..67603ccc7 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java @@ -0,0 +1,87 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.utils; + +import 
org.apache.spark.sql.catalyst.expressions.*;
+import org.opensearch.sql.ast.expression.AggregateFunction;
+import org.opensearch.sql.ast.expression.DataType;
+import org.opensearch.sql.ast.expression.Literal;
+import org.opensearch.sql.ast.tree.Trendline;
+import org.opensearch.sql.expression.function.BuiltinFunctionName;
+import org.opensearch.sql.ppl.CatalystExpressionVisitor;
+import org.opensearch.sql.ppl.CatalystPlanContext;
+import scala.Option;
+import scala.Tuple2;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq;
+
+public interface TrendlineCatalystUtils {
+
+    static List<NamedExpression> visitTrendlineComputations(CatalystExpressionVisitor expressionVisitor, List<Trendline.TrendlineComputation> computations, CatalystPlanContext context) {
+        return computations.stream()
+            .map(computation -> visitTrendlineComputation(expressionVisitor, computation, context))
+            .collect(Collectors.toList());
+    }
+
+    static NamedExpression visitTrendlineComputation(CatalystExpressionVisitor expressionVisitor, Trendline.TrendlineComputation node, CatalystPlanContext context) {
+        //window lower boundary
+        expressionVisitor.visitLiteral(new Literal(Math.negateExact(node.getNumberOfDataPoints() - 1), DataType.INTEGER), context);
+        Expression windowLowerBoundary = context.popNamedParseExpressions().get();
+
+        //window definition
+        WindowSpecDefinition windowDefinition = new WindowSpecDefinition(
+            seq(),
+            seq(),
+            new SpecifiedWindowFrame(RowFrame$.MODULE$, windowLowerBoundary, CurrentRow$.MODULE$));
+
+        if (node.getComputationType() == Trendline.TrendlineType.SMA) {
+            //calculate avg value of the data field
+            expressionVisitor.visitAggregateFunction(new AggregateFunction(BuiltinFunctionName.AVG.name(), node.getDataField()), context);
+            Expression avgFunction = context.popNamedParseExpressions().get();
+
+            //sma window
+            WindowExpression sma = new WindowExpression(
+                avgFunction,
+                windowDefinition);
+
+            CaseWhen smaOrNull = trendlineOrNullWhenThereAreTooFewDataPoints(expressionVisitor, sma, node, context);
+
+            return org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(smaOrNull,
+                node.getAlias(),
+                NamedExpression.newExprId(),
+                seq(new java.util.ArrayList()),
+                Option.empty(),
+                seq(new java.util.ArrayList()));
+        } else {
+            throw new IllegalArgumentException(node.getComputationType() + " is not supported");
+        }
+    }
+
+    private static CaseWhen trendlineOrNullWhenThereAreTooFewDataPoints(CatalystExpressionVisitor expressionVisitor, WindowExpression trendlineWindow, Trendline.TrendlineComputation node, CatalystPlanContext context) {
+        //required number of data points
+        expressionVisitor.visitLiteral(new Literal(node.getNumberOfDataPoints(), DataType.INTEGER), context);
+        Expression requiredNumberOfDataPoints = context.popNamedParseExpressions().get();
+
+        //count data points function
+        expressionVisitor.visitAggregateFunction(new AggregateFunction(BuiltinFunctionName.COUNT.name(), new Literal(1, DataType.INTEGER)), context);
+        Expression countDataPointsFunction = context.popNamedParseExpressions().get();
+        //count data points window
+        WindowExpression countDataPointsWindow = new WindowExpression(
+            countDataPointsFunction,
+            trendlineWindow.windowSpec());
+
+        expressionVisitor.visitLiteral(new Literal(null, DataType.NULL), context);
+        Expression nullLiteral = context.popNamedParseExpressions().get();
+        Tuple2<Expression, Expression> nullWhenNumberOfDataPointsLessThanRequired = new Tuple2<>(
+            new LessThan(countDataPointsWindow, requiredNumberOfDataPoints),
+            nullLiteral
+        );
+        return new CaseWhen(seq(nullWhenNumberOfDataPointsLessThanRequired), Option.apply(trendlineWindow));
+    }
+}
diff --git a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/FlattenGenerator.scala b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/FlattenGenerator.scala
new file mode 100644
index 000000000..23b545826
--- /dev/null
+++ b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/FlattenGenerator.scala
@@ -0,0 +1,33 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.flint.spark
+
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.{CollectionGenerator, CreateArray, Expression, GenericInternalRow, Inline, UnaryExpression}
+import org.apache.spark.sql.types.{ArrayType, StructType}
+
+class FlattenGenerator(override val child: Expression)
+    extends Inline(child)
+    with CollectionGenerator {
+
+  override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
+    case st: StructType => TypeCheckResult.TypeCheckSuccess
+    case _ => super.checkInputDataTypes()
+  }
+
+  override def elementSchema: StructType = child.dataType match {
+    case st: StructType => st
+    case _ => super.elementSchema
+  }
+
+  override protected def withNewChildInternal(newChild: Expression): FlattenGenerator = {
+    newChild.dataType match {
+      case ArrayType(st: StructType, _) => new FlattenGenerator(newChild)
+      case st: StructType => withNewChildInternal(CreateArray(Seq(newChild), false))
+      case _ =>
+        throw new IllegalArgumentException(s"Unexpected input type ${newChild.dataType}")
+    }
+  }
+}
diff --git a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLSyntaxParser.scala b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLSyntaxParser.scala
index c435af53d..ed498e98b 100644
--- a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLSyntaxParser.scala
+++ b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLSyntaxParser.scala
@@ -29,11 +29,8 @@ class PPLSyntaxParser extends Parser {
 
 object PlaneUtils {
   def plan(parser: PPLSyntaxParser, query: String): Statement = {
-    val astExpressionBuilder = new AstExpressionBuilder()
-    val astBuilder = new AstBuilder(astExpressionBuilder, query)
-    astExpressionBuilder.setAstBuilder(astBuilder)
-    val builder =
-      new AstStatementBuilder(astBuilder, AstStatementBuilder.StatementBuilderContext.builder())
-    builder.visit(parser.parse(query))
+    new AstStatementBuilder(
+      new AstBuilder(query),
+      AstStatementBuilder.StatementBuilderContext.builder()).visit(parser.parse(query))
   }
 }
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala
index 2da93d5d8..50ef985d6 100644
--- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala
@@ -13,7 +13,7 @@ import org.scalatest.matchers.should.Matchers
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar}
-import org.apache.spark.sql.catalyst.expressions.{Ascending,
AttributeReference, Descending, GreaterThan, Literal, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, EqualTo, GreaterThan, Literal, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.command.DescribeTableCommand @@ -292,6 +292,44 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + test("Search multiple tables - with table alias") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + """ + | source=table1, table2, table3 as t + | | where t.name = 'Molly' + |""".stripMargin), + context) + + val table1 = UnresolvedRelation(Seq("table1")) + val table2 = UnresolvedRelation(Seq("table2")) + val table3 = UnresolvedRelation(Seq("table3")) + val star = UnresolvedStar(None) + val plan1 = Project( + Seq(star), + Filter( + EqualTo(UnresolvedAttribute("t.name"), Literal("Molly")), + SubqueryAlias("t", table1))) + val plan2 = Project( + Seq(star), + Filter( + EqualTo(UnresolvedAttribute("t.name"), Literal("Molly")), + SubqueryAlias("t", table2))) + val plan3 = Project( + Seq(star), + Filter( + EqualTo(UnresolvedAttribute("t.name"), Literal("Molly")), + SubqueryAlias("t", table3))) + + val expectedPlan = + Union(Seq(plan1, plan2, plan3), byName = true, allowMissingCol = true) + + comparePlans(expectedPlan, logPlan, false) + } + test("test fields + field list") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFlattenCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFlattenCommandTranslatorTestSuite.scala new file mode 100644 index 000000000..58a6c04b3 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFlattenCommandTranslatorTestSuite.scala @@ -0,0 +1,156 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.FlattenGenerator +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Descending, GeneratorOuter, Literal, NullsLast, RegExpExtract, SortOrder} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, DataFrameDropColumns, Generate, GlobalLimit, LocalLimit, Project, Sort} +import org.apache.spark.sql.types.IntegerType + +class PPLLogicalPlanFlattenCommandTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test flatten only field") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=relation | flatten field_with_array"), + context) + + val relation = UnresolvedRelation(Seq("relation")) + 
val flattenGenerator = new FlattenGenerator(UnresolvedAttribute("field_with_array")) + val outerGenerator = GeneratorOuter(flattenGenerator) + val generate = Generate(outerGenerator, seq(), true, None, seq(), relation) + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("field_with_array")), generate) + val expectedPlan = Project(seq(UnresolvedStar(None)), dropSourceColumn) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test flatten and stats") { + val context = new CatalystPlanContext + val query = + "source = relation | fields state, company, employee | flatten employee | fields state, company, salary | stats max(salary) as max by state, company" + val logPlan = + planTransformer.visit(plan(pplParser, query), context) + val table = UnresolvedRelation(Seq("relation")) + val projectStateCompanyEmployee = + Project( + Seq( + UnresolvedAttribute("state"), + UnresolvedAttribute("company"), + UnresolvedAttribute("employee")), + table) + val generate = Generate( + GeneratorOuter(new FlattenGenerator(UnresolvedAttribute("employee"))), + seq(), + true, + None, + seq(), + projectStateCompanyEmployee) + val dropSourceColumn = DataFrameDropColumns(Seq(UnresolvedAttribute("employee")), generate) + val projectStateCompanySalary = Project( + Seq( + UnresolvedAttribute("state"), + UnresolvedAttribute("company"), + UnresolvedAttribute("salary")), + dropSourceColumn) + val average = Alias( + UnresolvedFunction(seq("MAX"), seq(UnresolvedAttribute("salary")), false, None, false), + "max")() + val state = Alias(UnresolvedAttribute("state"), "state")() + val company = Alias(UnresolvedAttribute("company"), "company")() + val groupingState = Alias(UnresolvedAttribute("state"), "state")() + val groupingCompany = Alias(UnresolvedAttribute("company"), "company")() + val aggregate = Aggregate( + Seq(groupingState, groupingCompany), + Seq(average, state, company), + projectStateCompanySalary) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test flatten and eval") { + val context = new CatalystPlanContext + val query = "source = relation | flatten employee | eval bonus = salary * 3" + val logPlan = planTransformer.visit(plan(pplParser, query), context) + val table = UnresolvedRelation(Seq("relation")) + val generate = Generate( + GeneratorOuter(new FlattenGenerator(UnresolvedAttribute("employee"))), + seq(), + true, + None, + seq(), + table) + val dropSourceColumn = DataFrameDropColumns(Seq(UnresolvedAttribute("employee")), generate) + val bonusProject = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "*", + Seq(UnresolvedAttribute("salary"), Literal(3, IntegerType)), + isDistinct = false), + "bonus")()), + dropSourceColumn) + val expectedPlan = Project(Seq(UnresolvedStar(None)), bonusProject) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test flatten and parse and flatten") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=relation | flatten employee | parse description '(?.+@.+)' | flatten roles"), + context) + val table = UnresolvedRelation(Seq("relation")) + val generateEmployee = Generate( + GeneratorOuter(new FlattenGenerator(UnresolvedAttribute("employee"))), + seq(), + true, + None, + seq(), + table) + val dropSourceColumnEmployee = + DataFrameDropColumns(Seq(UnresolvedAttribute("employee")), generateEmployee) + val emailAlias = + Alias( + 
RegExpExtract(UnresolvedAttribute("description"), Literal("(?<email>.+@.+)"), Literal(1)), + "email")() + val parseProject = Project( + Seq(UnresolvedAttribute("description"), emailAlias, UnresolvedStar(None)), + dropSourceColumnEmployee) + val generateRoles = Generate( + GeneratorOuter(new FlattenGenerator(UnresolvedAttribute("roles"))), + seq(), + true, + None, + seq(), + parseProject) + val dropSourceColumnRoles = + DataFrameDropColumns(Seq(UnresolvedAttribute("roles")), generateRoles) + val expectedPlan = Project(Seq(UnresolvedStar(None)), dropSourceColumnRoles) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala index 3ceff7735..f4ed397e3 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala @@ -271,9 +271,9 @@ class PPLLogicalPlanJoinTranslatorTestSuite pplParser, s""" | source = $testTable1 - | | inner JOIN left = l,right = r ON l.id = r.id $testTable2 - | | left JOIN left = l,right = r ON l.name = r.name $testTable3 - | | cross JOIN left = l,right = r $testTable4 + | | inner JOIN left = l right = r ON l.id = r.id $testTable2 + | | left JOIN left = l right = r ON l.name = r.name $testTable3 + | | cross JOIN left = l right = r $testTable4 | """.stripMargin) val logicalPlan = planTransformer.visit(logPlan, context) val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) @@ -443,17 +443,17 @@ class PPLLogicalPlanJoinTranslatorTestSuite s""" | source = $testTable1 | | head 10 - | | inner JOIN left = l,right = r ON l.id = r.id + | | inner JOIN left = l right = r ON l.id = r.id | [ | source = $testTable2 | | where id > 10 | ] - | | left JOIN left = l,right = r ON l.name = r.name + | | left JOIN left = l right = r ON l.name = r.name | [ | source = $testTable3 | | fields id | ] - | | cross JOIN left = l,right = r + | | cross JOIN left = l right = r | [ | source = $testTable4 | | sort id @@ -565,4 +565,284 @@ class PPLLogicalPlanJoinTranslatorTestSuite val expectedPlan = Project(Seq(UnresolvedStar(None)), sort) comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } + + test("test multiple joins with table alias") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = table1 as t1 + | | JOIN ON t1.id = t2.id + | [ + | source = table2 as t2 + | ] + | | JOIN ON t2.id = t3.id + | [ + | source = table3 as t3 + | ] + | | JOIN ON t3.id = t4.id + | [ + | source = table4 as t4 + | ] + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("table1")) + val table2 = UnresolvedRelation(Seq("table2")) + val table3 = UnresolvedRelation(Seq("table3")) + val table4 = UnresolvedRelation(Seq("table4")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.id"), UnresolvedAttribute("t2.id"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + Some(EqualTo(UnresolvedAttribute("t2.id"), UnresolvedAttribute("t3.id"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + SubqueryAlias("t4", table4), + Inner, + 
Some(EqualTo(UnresolvedAttribute("t3.id"), UnresolvedAttribute("t4.id"))), + JoinHint.NONE) + val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with table and subquery alias") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = table1 as t1 + | | JOIN left = l right = r ON t1.id = t2.id + | [ + | source = table2 as t2 + | ] + | | JOIN left = l right = r ON t2.id = t3.id + | [ + | source = table3 as t3 + | ] + | | JOIN left = l right = r ON t3.id = t4.id + | [ + | source = table4 as t4 + | ] + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("table1")) + val table2 = UnresolvedRelation(Seq("table2")) + val table3 = UnresolvedRelation(Seq("table3")) + val table4 = UnresolvedRelation(Seq("table4")) + val joinPlan1 = Join( + SubqueryAlias("l", SubqueryAlias("t1", table1)), + SubqueryAlias("r", SubqueryAlias("t2", table2)), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.id"), UnresolvedAttribute("t2.id"))), + JoinHint.NONE) + val joinPlan2 = Join( + SubqueryAlias("l", joinPlan1), + SubqueryAlias("r", SubqueryAlias("t3", table3)), + Inner, + Some(EqualTo(UnresolvedAttribute("t2.id"), UnresolvedAttribute("t3.id"))), + JoinHint.NONE) + val joinPlan3 = Join( + SubqueryAlias("l", joinPlan2), + SubqueryAlias("r", SubqueryAlias("t4", table4)), + Inner, + Some(EqualTo(UnresolvedAttribute("t3.id"), UnresolvedAttribute("t4.id"))), + JoinHint.NONE) + val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins without table aliases") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = table1 + | | JOIN ON table1.id = table2.id table2 + | | JOIN ON table1.id = table3.id table3 + | | JOIN ON table2.id = table4.id table4 + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("table1")) + val table2 = UnresolvedRelation(Seq("table2")) + val table3 = UnresolvedRelation(Seq("table3")) + val table4 = UnresolvedRelation(Seq("table4")) + val joinPlan1 = Join( + table1, + table2, + Inner, + Some(EqualTo(UnresolvedAttribute("table1.id"), UnresolvedAttribute("table2.id"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + table3, + Inner, + Some(EqualTo(UnresolvedAttribute("table1.id"), UnresolvedAttribute("table3.id"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + table4, + Inner, + Some(EqualTo(UnresolvedAttribute("table2.id"), UnresolvedAttribute("table4.id"))), + JoinHint.NONE) + val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with part subquery aliases") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = table1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name table2 + | | JOIN right = t3 ON t1.name = t3.name table3 + | | JOIN right = t4 ON t2.name = t4.name table4 + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("table1")) + val table2 = UnresolvedRelation(Seq("table2")) + val table3 = UnresolvedRelation(Seq("table3")) + val table4 = UnresolvedRelation(Seq("table4")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + 
SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t3.name"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + SubqueryAlias("t4", table4), + Inner, + Some(EqualTo(UnresolvedAttribute("t2.name"), UnresolvedAttribute("t4.name"))), + JoinHint.NONE) + val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with self join 1") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 + | | JOIN right = t3 ON t1.name = t3.name $testTable3 + | | JOIN right = t4 ON t1.name = t4.name $testTable1 + | | fields t1.name, t2.name, t3.name, t4.name + | """.stripMargin) + + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t3.name"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + SubqueryAlias("t4", table1), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t4.name"))), + JoinHint.NONE) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("t1.name"), + UnresolvedAttribute("t2.name"), + UnresolvedAttribute("t3.name"), + UnresolvedAttribute("t4.name")), + joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with self join 2") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 + | | JOIN right = t3 ON t1.name = t3.name $testTable3 + | | JOIN ON t1.name = t4.name + | [ + | source = $testTable1 + | ] as t4 + | | fields t1.name, t2.name, t3.name, t4.name + | """.stripMargin) + + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t3.name"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + SubqueryAlias("t4", table1), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t4.name"))), + JoinHint.NONE) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("t1.name"), + UnresolvedAttribute("t2.name"), + 
UnresolvedAttribute("t3.name"), + UnresolvedAttribute("t4.name")), + joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test side alias will override the subquery alias") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name [ source = $testTable2 as ttt ] as tt + | | fields t1.name, t2.name + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", SubqueryAlias("tt", SubqueryAlias("ttt", table2))), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val expectedPlan = + Project(Seq(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name")), joinPlan1) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala index f5dfc4ec8..6193bc43f 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala @@ -11,7 +11,7 @@ import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Literal} +import org.apache.spark.sql.catalyst.expressions.{EqualTo, Literal} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project} @@ -48,7 +48,7 @@ class PPLLogicalPlanJsonFunctionsTranslatorTestSuite val context = new CatalystPlanContext val logPlan = planTransformer.visit( - plan(pplParser, """source=t a = json(json_object('key', array(1, 2, 3)))"""), + plan(pplParser, """source=t a = to_json_string(json_object('key', array(1, 2, 3)))"""), context) val table = UnresolvedRelation(Seq("t")) @@ -97,7 +97,9 @@ class PPLLogicalPlanJsonFunctionsTranslatorTestSuite val context = new CatalystPlanContext val logPlan = planTransformer.visit( - plan(pplParser, """source=t a = json(json_object('key', json_array(1, 2, 3)))"""), + plan( + pplParser, + """source=t a = to_json_string(json_object('key', json_array(1, 2, 3)))"""), context) val table = UnresolvedRelation(Seq("t")) @@ -139,25 +141,21 @@ class PPLLogicalPlanJsonFunctionsTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } - test("test json_array_length(json_array())") { + test("test array_length(json_array())") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( - plan(pplParser, """source=t a = json_array_length(json_array(1,2,3))"""), + plan(pplParser, """source=t a = array_length(json_array(1,2,3))"""), context) val table = UnresolvedRelation(Seq("t")) val jsonFunc = UnresolvedFunction( - "json_array_length", + "array_size", Seq( UnresolvedFunction( - "to_json", - Seq( - UnresolvedFunction( - "array", - Seq(Literal(1), 
Literal(2), Literal(3)), - isDistinct = false)), + "array", + Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false)), isDistinct = false) val filterExpr = EqualTo(UnresolvedAttribute("a"), jsonFunc) diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanLambdaFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanLambdaFunctionsTranslatorTestSuite.scala new file mode 100644 index 000000000..9c3c1c8a0 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanLambdaFunctionsTranslatorTestSuite.scala @@ -0,0 +1,211 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, GreaterThan, LambdaFunction, Literal, UnresolvedNamedLambdaVariable} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.Project + +class PPLLogicalPlanLambdaFunctionsTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test forall()") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + """source=t | eval a = json_array(1, 2, 3), b = forall(a, x -> x > 0)""".stripMargin), + context) + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "a")() + val lambda = LambdaFunction( + GreaterThan(UnresolvedNamedLambdaVariable(seq("x")), Literal(0)), + Seq(UnresolvedNamedLambdaVariable(seq("x")))) + val aliasB = + Alias(UnresolvedFunction("forall", Seq(UnresolvedAttribute("a"), lambda), false), "b")() + val evalProject = Project(Seq(UnresolvedStar(None), aliasA, aliasB), table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, evalProject) + comparePlans(expectedPlan, logPlan, false) + } + + test("test exists()") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + """source=t | eval a = json_array(1, 2, 3), b = exists(a, x -> x > 0)""".stripMargin), + context) + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "a")() + val lambda = LambdaFunction( + GreaterThan(UnresolvedNamedLambdaVariable(seq("x")), Literal(0)), + Seq(UnresolvedNamedLambdaVariable(seq("x")))) + val aliasB = + Alias(UnresolvedFunction("exists", Seq(UnresolvedAttribute("a"), lambda), false), "b")() + val evalProject = Project(Seq(UnresolvedStar(None), aliasA, aliasB), table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, evalProject) + comparePlans(expectedPlan, logPlan, false) + } + + test("test filter()") { + val context = 
new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + """source=t | eval a = json_array(1, 2, 3), b = filter(a, x -> x > 0)""".stripMargin), + context) + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "a")() + val lambda = LambdaFunction( + GreaterThan(UnresolvedNamedLambdaVariable(seq("x")), Literal(0)), + Seq(UnresolvedNamedLambdaVariable(seq("x")))) + val aliasB = + Alias(UnresolvedFunction("filter", Seq(UnresolvedAttribute("a"), lambda), false), "b")() + val evalProject = Project(Seq(UnresolvedStar(None), aliasA, aliasB), table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, evalProject) + comparePlans(expectedPlan, logPlan, false) + } + + test("test transform()") { + val context = new CatalystPlanContext + // test single argument of lambda + val logPlan = + planTransformer.visit( + plan( + pplParser, + """source=t | eval a = json_array(1, 2, 3), b = transform(a, x -> x + 1)""".stripMargin), + context) + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "a")() + val lambda = LambdaFunction( + UnresolvedFunction("+", Seq(UnresolvedNamedLambdaVariable(seq("x")), Literal(1)), false), + Seq(UnresolvedNamedLambdaVariable(seq("x")))) + val aliasB = + Alias(UnresolvedFunction("transform", Seq(UnresolvedAttribute("a"), lambda), false), "b")() + val evalProject = Project(Seq(UnresolvedStar(None), aliasA, aliasB), table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, evalProject) + comparePlans(expectedPlan, logPlan, false) + } + + test("test transform() - test binary lambda") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + """source=t | eval a = json_array(1, 2, 3), b = transform(a, (x, y) -> x + y)""".stripMargin), + context) + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "a")() + val lambda = LambdaFunction( + UnresolvedFunction( + "+", + Seq(UnresolvedNamedLambdaVariable(seq("x")), UnresolvedNamedLambdaVariable(seq("y"))), + false), + Seq(UnresolvedNamedLambdaVariable(seq("x")), UnresolvedNamedLambdaVariable(seq("y")))) + val aliasB = + Alias(UnresolvedFunction("transform", Seq(UnresolvedAttribute("a"), lambda), false), "b")() + val expectedPlan = Project( + Seq(UnresolvedStar(None)), + Project(Seq(UnresolvedStar(None), aliasA, aliasB), table)) + comparePlans(expectedPlan, logPlan, false) + } + + test("test reduce() - without finish lambda") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + """source=t | eval a = json_array(1, 2, 3), b = reduce(a, 0, (x, y) -> x + y)""".stripMargin), + context) + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "a")() + val mergeLambda = LambdaFunction( + UnresolvedFunction( + "+", + Seq(UnresolvedNamedLambdaVariable(seq("x")), UnresolvedNamedLambdaVariable(seq("y"))), + false), + Seq(UnresolvedNamedLambdaVariable(seq("x")), UnresolvedNamedLambdaVariable(seq("y")))) + val aliasB = + Alias( + UnresolvedFunction( + 
"reduce", + Seq(UnresolvedAttribute("a"), Literal(0), mergeLambda), + false), + "b")() + val expectedPlan = Project( + Seq(UnresolvedStar(None)), + Project(Seq(UnresolvedStar(None), aliasA, aliasB), table)) + comparePlans(expectedPlan, logPlan, false) + } + + test("test reduce() - with finish lambda") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + """source=t | eval a = json_array(1, 2, 3), b = reduce(a, 0, (x, y) -> x + y, x -> x * 10)""".stripMargin), + context) + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "a")() + val mergeLambda = LambdaFunction( + UnresolvedFunction( + "+", + Seq(UnresolvedNamedLambdaVariable(seq("x")), UnresolvedNamedLambdaVariable(seq("y"))), + false), + Seq(UnresolvedNamedLambdaVariable(seq("x")), UnresolvedNamedLambdaVariable(seq("y")))) + val finishLambda = LambdaFunction( + UnresolvedFunction("*", Seq(UnresolvedNamedLambdaVariable(seq("x")), Literal(10)), false), + Seq(UnresolvedNamedLambdaVariable(seq("x")))) + val aliasB = + Alias( + UnresolvedFunction( + "reduce", + Seq(UnresolvedAttribute("a"), Literal(0), mergeLambda, finishLambda), + false), + "b")() + val expectedPlan = Project( + Seq(UnresolvedStar(None)), + Project(Seq(UnresolvedStar(None), aliasA, aliasB), table)) + comparePlans(expectedPlan, logPlan, false) + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseCidrmatchTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseCidrmatchTestSuite.scala new file mode 100644 index 000000000..213f201cc --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseCidrmatchTestSuite.scala @@ -0,0 +1,155 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.expression.function.SerializableUdf +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, EqualTo, GreaterThan, Literal, NullsFirst, NullsLast, RegExpExtract, ScalaUDF, SortOrder} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types.DataTypes + +class PPLLogicalPlanParseCidrmatchTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test cidrmatch for ipv4 for 192.168.1.0/24") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | where isV6 = false and isValid = true and cidrmatch(ipAddress, '192.168.1.0/24')"), + context) + + val ipAddress = UnresolvedAttribute("ipAddress") + val cidrExpression = Literal("192.168.1.0/24") + + val filterIpv6 = EqualTo(UnresolvedAttribute("isV6"), Literal(false)) + val 
filterIsValid = EqualTo(UnresolvedAttribute("isValid"), Literal(true)) + val cidr = ScalaUDF( + SerializableUdf.cidrFunction, + DataTypes.BooleanType, + seq(ipAddress, cidrExpression), + seq(), + Option.empty, + Option.apply("cidr"), + false, + true) + + val expectedPlan = Project( + Seq(UnresolvedStar(None)), + Filter(And(And(filterIpv6, filterIsValid), cidr), UnresolvedRelation(Seq("t")))) + assert(compareByString(expectedPlan) === compareByString(logPlan)) + } + + test("test cidrmatch for ipv6 for 2003:db8::/32") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | where isV6 = true and isValid = false and cidrmatch(ipAddress, '2003:db8::/32')"), + context) + + val ipAddress = UnresolvedAttribute("ipAddress") + val cidrExpression = Literal("2003:db8::/32") + + val filterIpv6 = EqualTo(UnresolvedAttribute("isV6"), Literal(true)) + val filterIsValid = EqualTo(UnresolvedAttribute("isValid"), Literal(false)) + val cidr = ScalaUDF( + SerializableUdf.cidrFunction, + DataTypes.BooleanType, + seq(ipAddress, cidrExpression), + seq(), + Option.empty, + Option.apply("cidr"), + false, + true) + + val expectedPlan = Project( + Seq(UnresolvedStar(None)), + Filter(And(And(filterIpv6, filterIsValid), cidr), UnresolvedRelation(Seq("t")))) + assert(compareByString(expectedPlan) === compareByString(logPlan)) + } + + test("test cidrmatch for ipv6 for 2003:db8::/32 with ip field projected") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | where isV6 = true and cidrmatch(ipAddress, '2003:db8::/32') | fields ip"), + context) + + val ipAddress = UnresolvedAttribute("ipAddress") + val cidrExpression = Literal("2003:db8::/32") + + val filterIpv6 = EqualTo(UnresolvedAttribute("isV6"), Literal(true)) + val cidr = ScalaUDF( + SerializableUdf.cidrFunction, + DataTypes.BooleanType, + seq(ipAddress, cidrExpression), + seq(), + Option.empty, + Option.apply("cidr"), + false, + true) + + val expectedPlan = Project( + Seq(UnresolvedAttribute("ip")), + Filter(And(filterIpv6, cidr), UnresolvedRelation(Seq("t")))) + assert(compareByString(expectedPlan) === compareByString(logPlan)) + } + + test("test cidrmatch for ipv6 for 2003:db8::/32 with ip field bool respond for each ip") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | where isV6 = true | eval inRange = case(cidrmatch(ipAddress, '2003:db8::/32'), 'in' else 'out') | fields ip, inRange"), + context) + + val ipAddress = UnresolvedAttribute("ipAddress") + val cidrExpression = Literal("2003:db8::/32") + + val filterIpv6 = EqualTo(UnresolvedAttribute("isV6"), Literal(true)) + val filterClause = Filter(filterIpv6, UnresolvedRelation(Seq("t"))) + val cidr = ScalaUDF( + SerializableUdf.cidrFunction, + DataTypes.BooleanType, + seq(ipAddress, cidrExpression), + seq(), + Option.empty, + Option.apply("cidr"), + false, + true) + + val equalTo = EqualTo(Literal(true), cidr) + val caseFunction = CaseWhen(Seq((equalTo, Literal("in"))), Literal("out")) + val aliasStatusCategory = Alias(caseFunction, "inRange")() + val evalProjectList = Seq(UnresolvedStar(None), aliasStatusCategory) + val evalProject = Project(evalProjectList, filterClause) + + val expectedPlan = + Project(Seq(UnresolvedAttribute("ip"), UnresolvedAttribute("inRange")), evalProject) + + assert(compareByString(expectedPlan) === compareByString(logPlan)) + } + +} diff --git 
a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala new file mode 100644 index 000000000..d22750ee0 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala @@ -0,0 +1,135 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CaseWhen, CurrentRow, Descending, LessThan, Literal, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Project, Sort} + +class PPLLogicalPlanTrendlineCommandTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test trendline") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=relation | trendline sma(3, age)"), context) + + val table = UnresolvedRelation(Seq("relation")) + val ageField = UnresolvedAttribute("age") + val countWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val smaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val caseWhen = CaseWhen(Seq((LessThan(countWindow, Literal(3)), Literal(null))), smaWindow) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(caseWhen, "age_trendline")()) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, table)) + comparePlans(logPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline with sort") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=relation | trendline sort age sma(3, age)"), + context) + + val table = UnresolvedRelation(Seq("relation")) + val ageField = UnresolvedAttribute("age") + val sort = Sort(Seq(SortOrder(ageField, Ascending)), global = true, table) + val countWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val smaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val caseWhen = CaseWhen(Seq((LessThan(countWindow, Literal(3)), Literal(null))), smaWindow) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(caseWhen, "age_trendline")()) + val expectedPlan = + 
Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sort)) + comparePlans(logPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline with sort and alias") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=relation | trendline sort - age sma(3, age) as age_sma"), + context) + + val table = UnresolvedRelation(Seq("relation")) + val ageField = UnresolvedAttribute("age") + val sort = Sort(Seq(SortOrder(ageField, Descending)), global = true, table) + val countWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val smaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val caseWhen = CaseWhen(Seq((LessThan(countWindow, Literal(3)), Literal(null))), smaWindow) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(caseWhen, "age_sma")()) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sort)) + comparePlans(logPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline with multiple trendline sma commands") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=relation | trendline sort + age sma(2, age) as two_points_sma sma(3, age) | fields name, age, two_points_sma, age_trendline"), + context) + + val table = UnresolvedRelation(Seq("relation")) + val nameField = UnresolvedAttribute("name") + val ageField = UnresolvedAttribute("age") + val ageTwoPointsSmaField = UnresolvedAttribute("two_points_sma") + val ageTrendlineField = UnresolvedAttribute("age_trendline") + val sort = Sort(Seq(SortOrder(ageField, Ascending)), global = true, table) + val twoPointsCountWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val twoPointsSmaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val threePointsCountWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val threePointsSmaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val twoPointsCaseWhen = CaseWhen( + Seq((LessThan(twoPointsCountWindow, Literal(2)), Literal(null))), + twoPointsSmaWindow) + val threePointsCaseWhen = CaseWhen( + Seq((LessThan(threePointsCountWindow, Literal(3)), Literal(null))), + threePointsSmaWindow) + val trendlineProjectList = Seq( + UnresolvedStar(None), + Alias(twoPointsCaseWhen, "two_points_sma")(), + Alias(threePointsCaseWhen, "age_trendline")()) + val expectedPlan = Project( + Seq(nameField, ageField, ageTwoPointsSmaField, ageTrendlineField), + Project(trendlineProjectList, sort)) + comparePlans(logPlan, expectedPlan, checkAnalysis = false) + } +}