diff --git a/build.sbt b/build.sbt index 30858e8d6..1300f68a0 100644 --- a/build.sbt +++ b/build.sbt @@ -154,6 +154,7 @@ lazy val pplSparkIntegration = (project in file("ppl-spark-integration")) "com.stephenn" %% "scalatest-json-jsonassert" % "0.2.5" % "test", "com.github.sbt" % "junit-interface" % "0.13.3" % "test", "org.projectlombok" % "lombok" % "1.18.30", + "com.github.seancfoley" % "ipaddress" % "5.5.1", ), libraryDependencies ++= deps(sparkVersion), // ANTLR settings @@ -237,7 +238,8 @@ lazy val integtest = (project in file("integ-test")) inConfig(IntegrationTest)(Defaults.testSettings ++ Seq( IntegrationTest / javaSource := baseDirectory.value / "src/integration/java", IntegrationTest / scalaSource := baseDirectory.value / "src/integration/scala", - IntegrationTest / parallelExecution := false, + IntegrationTest / resourceDirectory := baseDirectory.value / "src/integration/resources", + IntegrationTest / parallelExecution := false, IntegrationTest / fork := true, )), inConfig(AwsIntegrationTest)(Defaults.testSettings ++ Seq( diff --git a/docs/index.md b/docs/index.md index bb3121ba6..e76cb387a 100644 --- a/docs/index.md +++ b/docs/index.md @@ -549,6 +549,7 @@ In the index mapping, the `_meta` and `properties`field stores meta and schema i - `spark.flint.monitor.initialDelaySeconds`: Initial delay in seconds before starting the monitoring task. Default value is 15. - `spark.flint.monitor.intervalSeconds`: Interval in seconds for scheduling the monitoring task. Default value is 60. - `spark.flint.monitor.maxErrorCount`: Maximum number of consecutive errors allowed before stopping the monitoring task. Default value is 5. +- `spark.flint.metadataCacheWrite.enabled`: default is false. enable writing metadata to index mappings _meta as read cache for frontend user to access. Do not use in production, this setting will be removed in later version. #### Data Type Mapping diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md index 01c3f1619..4ea564111 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -58,6 +58,16 @@ _- **Limitation: new field added by eval command with a function cannot be dropp - `source = table | where a not in (1, 2, 3) | fields a,b,c` - `source = table | where a between 1 and 4` - Note: This returns a >= 1 and a <= 4, i.e. 
[1, 4] - `source = table | where b not between '2024-09-10' and '2025-09-10'` - Note: This returns b >= '2024-09-10' and b <= '2025-09-10' +- `source = table | where cidrmatch(ip, '192.169.1.0/24')` +- `source = table | where cidrmatch(ipv6, '2003:db8::/32')` +- `source = table | trendline sma(2, temperature) as temp_trend` + +#### **IP related queries** +[See additional command details](functions/ppl-ip.md) + +- `source = table | where cidrmatch(ip, '192.169.1.0/24')` +- `source = table | where isV6 = false and isValid = true and cidrmatch(ipAddress, '192.168.1.0/24')` +- `source = table | where isV6 = true | eval inRange = case(cidrmatch(ipAddress, '2003:db8::/32'), 'in' else 'out') | fields ip, inRange` ```sql source = table | eval status_category = @@ -120,6 +130,15 @@ Assumptions: `a`, `b`, `c`, `d`, `e` are existing fields in `table` - `source = table | fillnull using a = 101, b = 102` - `source = table | fillnull using a = concat(b, c), d = 2 * pi() * e` +### Flatten +[See additional command details](ppl-flatten-command.md) +Assumptions: `bridges`, `coor` are existing fields in `table`, and the field's types are `struct` or `array>` +- `source = table | flatten bridges` +- `source = table | flatten coor` +- `source = table | flatten bridges | flatten coor` +- `source = table | fields bridges | flatten bridges` +- `source = table | fields country, bridges | flatten bridges | fields country, length | stats avg(length) as avg by country` + ```sql source = table | eval e = eval status_category = case(a >= 200 AND a < 300, 'Success', @@ -287,7 +306,11 @@ source = table | where ispresent(a) | - `source = table1 | left semi join left = l right = r on l.a = r.a table2` - `source = table1 | left anti join left = l right = r on l.a = r.a table2` - `source = table1 | join left = l right = r [ source = table2 | where d > 10 | head 5 ]` - +- `source = table1 | inner join on table1.a = table2.a table2 | fields table1.a, table2.a, table1.b, table1.c` (directly refer table name) +- `source = table1 | inner join on a = c table2 | fields a, b, c, d` (ignore side aliases as long as no ambiguous) +- `source = table1 as t1 | join left = l right = r on l.a = r.a table2 as t2 | fields l.a, r.a` (side alias overrides table alias) +- `source = table1 as t1 | join left = l right = r on l.a = r.a table2 as t2 | fields t1.a, t2.a` (error, side alias overrides table alias) +- `source = table1 | join left = l right = r on l.a = r.a [ source = table2 ] as s | fields l.a, s.a` (error, side alias overrides subquery alias) #### **Lookup** [See additional command details](ppl-lookup-command.md) @@ -418,8 +441,30 @@ Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table in _- **Limitation: another command usage of (relation) subquery is in `appendcols` commands which is unsupported**_ ---- -#### Experimental Commands: + +#### **fillnull** +[See additional command details](ppl-fillnull-command.md) +```sql + - `source=accounts | fillnull fields status_code=101` + - `source=accounts | fillnull fields request_path='/not_found', timestamp='*'` + - `source=accounts | fillnull using field1=101` + - `source=accounts | fillnull using field1=concat(field2, field3), field4=2*pi()*field5` + - `source=accounts | fillnull using field1=concat(field2, field3), field4=2*pi()*field5, field6 = 'N/A'` +``` + +#### **expand** +[See additional command details](ppl-expand-command.md) +```sql + - `source = table | expand field_with_array as array_list` + - `source = table | expand employee | stats max(salary) as max by 
state, company` + - `source = table | expand employee as worker | stats max(salary) as max by state, company` + - `source = table | expand employee as worker | eval bonus = salary * 3 | fields worker, bonus` + - `source = table | expand employee | parse description '(?.+@.+)' | fields employee, email` + - `source = table | eval array=json_array(1, 2, 3) | expand array as uid | fields name, occupation, uid` + - `source = table | expand multi_valueA as multiA | expand multi_valueB as multiB` +``` + +#### Correlation Commands: [See additional command details](ppl-correlation-command.md) ```sql @@ -431,14 +476,3 @@ _- **Limitation: another command usage of (relation) subquery is in `appendcols` > ppl-correlation-command is an experimental command - it may be removed in future versions --- -### Planned Commands: - -#### **fillnull** -[See additional command details](ppl-fillnull-command.md) -```sql - - `source=accounts | fillnull fields status_code=101` - - `source=accounts | fillnull fields request_path='/not_found', timestamp='*'` - - `source=accounts | fillnull using field1=101` - - `source=accounts | fillnull using field1=concat(field2, field3), field4=2*pi()*field5` - - `source=accounts | fillnull using field1=concat(field2, field3), field4=2*pi()*field5, field6 = 'N/A'` -``` diff --git a/docs/ppl-lang/README.md b/docs/ppl-lang/README.md index 9cb5f118e..d72c973be 100644 --- a/docs/ppl-lang/README.md +++ b/docs/ppl-lang/README.md @@ -31,6 +31,8 @@ For additional examples see the next [documentation](PPL-Example-Commands.md). - [`describe command`](PPL-Example-Commands.md/#describe) - [`fillnull command`](ppl-fillnull-command.md) + + - [`flatten command`](ppl-flatten-command.md) - [`eval command`](ppl-eval-command.md) @@ -67,7 +69,10 @@ For additional examples see the next [documentation](PPL-Example-Commands.md). - [`subquery commands`](ppl-subquery-command.md) - [`correlation commands`](ppl-correlation-command.md) - + + - [`trendline commands`](ppl-trendline-command.md) + + - [`expand commands`](ppl-expand-command.md) * **Functions** @@ -87,6 +92,9 @@ For additional examples see the next [documentation](PPL-Example-Commands.md). - [`Cryptographic Functions`](functions/ppl-cryptographic.md) + - [`IP Address Functions`](functions/ppl-ip.md) + + - [`Lambda Functions`](functions/ppl-lambda.md) --- ### PPL On Spark @@ -98,6 +106,10 @@ For additional examples see the next [documentation](PPL-Example-Commands.md). 
### Example PPL Queries See samples of [PPL queries](PPL-Example-Commands.md) +--- +### TPC-H PPL Query Rewriting +See samples of [TPC-H PPL query rewriting](ppl-tpch.md) + --- ### Planned PPL Commands @@ -105,4 +117,4 @@ See samples of [PPL queries](PPL-Example-Commands.md) --- ### PPL Project Roadmap -[PPL Github Project Roadmap](https://github.com/orgs/opensearch-project/projects/214) \ No newline at end of file +[PPL Github Project Roadmap](https://github.com/orgs/opensearch-project/projects/214) diff --git a/docs/ppl-lang/functions/ppl-datetime.md b/docs/ppl-lang/functions/ppl-datetime.md index e7b423d41..e479176a4 100644 --- a/docs/ppl-lang/functions/ppl-datetime.md +++ b/docs/ppl-lang/functions/ppl-datetime.md @@ -14,7 +14,7 @@ Argument type: DATE, LONG (DATE, LONG) -> DATE -Antonyms: `SUBDATE`_ +Antonyms: `SUBDATE` Example: @@ -795,7 +795,7 @@ Argument type: DATE/TIMESTAMP, LONG (DATE, LONG) -> DATE -Antonyms: `ADDDATE`_ +Antonyms: `ADDDATE` Example: @@ -982,3 +982,134 @@ Example: +----------------------------+ +### `DATE_ADD` + +**Description:** + +Usage: date_add(date, INTERVAL expr unit) adds the interval expr to date. + +Argument type: DATE, INTERVAL + +Return type: DATE + +Antonyms: `DATE_SUB` + +Example:: + + os> source=people | eval `'2020-08-26' + 1d` = DATE_ADD(DATE('2020-08-26'), INTERVAL 1 DAY) | fields `'2020-08-26' + 1d` + fetched rows / total rows = 1/1 + +---------------------+ + | '2020-08-26' + 1d | + |---------------------+ + | 2020-08-27 | + +---------------------+ + + +### `DATE_SUB` + +**Description:** + +Usage: date_sub(date, INTERVAL expr unit) subtracts the interval expr from date. + +Argument type: DATE, INTERVAL + +Return type: DATE + +Antonyms: `DATE_ADD` + +Example:: + + os> source=people | eval `'2008-01-02' - 31d` = DATE_SUB(DATE('2008-01-02'), INTERVAL 31 DAY) | fields `'2008-01-02' - 31d` + fetched rows / total rows = 1/1 + +---------------------+ + | '2008-01-02' - 31d | + |---------------------+ + | 2007-12-02 | + +---------------------+ + + +### `TIMESTAMPADD` + +**Description:** + +Usage: Returns a TIMESTAMP value based on a passed in DATE/TIMESTAMP/STRING argument and an INTERVAL and INTEGER argument which determine the amount of time to be added. +If the third argument is a STRING, it must be formatted as a valid TIMESTAMP. +If the third argument is a DATE, it will be automatically converted to a TIMESTAMP. 
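+
+For instance, a DATE third argument should behave as if its time portion were midnight before the interval is applied. The sketch below only illustrates that description; it is not taken from the test suite, the alias `date_plus_1h` is used purely for illustration, and the exact console formatting may differ:
+
+    os> source=people | eval `date_plus_1h` = TIMESTAMPADD(HOUR, 1, DATE('2000-01-01')) | fields `date_plus_1h`
+    fetched rows / total rows = 1/1
+    +----------------------+
+    | date_plus_1h         |
+    |----------------------|
+    | 2000-01-01 01:00:00  |
+    +----------------------+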
+ +Argument type: INTERVAL, INTEGER, DATE/TIMESTAMP/STRING + +INTERVAL must be one of the following tokens: [SECOND, MINUTE, HOUR, DAY, WEEK, MONTH, QUARTER, YEAR] + +Examples:: + + os> source=people | eval `TIMESTAMPADD(DAY, 17, '2000-01-01 00:00:00')` = TIMESTAMPADD(DAY, 17, '2000-01-01 00:00:00') | eval `TIMESTAMPADD(QUARTER, -1, '2000-01-01 00:00:00')` = TIMESTAMPADD(QUARTER, -1, '2000-01-01 00:00:00') | fields `TIMESTAMPADD(DAY, 17, '2000-01-01 00:00:00')`, `TIMESTAMPADD(QUARTER, -1, '2000-01-01 00:00:00')` + fetched rows / total rows = 1/1 + +----------------------------------------------+--------------------------------------------------+ + | TIMESTAMPADD(DAY, 17, '2000-01-01 00:00:00') | TIMESTAMPADD(QUARTER, -1, '2000-01-01 00:00:00') | + |----------------------------------------------+--------------------------------------------------| + | 2000-01-18 00:00:00 | 1999-10-01 00:00:00 | + +----------------------------------------------+--------------------------------------------------+ + + +### `TIMESTAMPDIFF` + +**Description:** + +Usage: TIMESTAMPDIFF(interval, start, end) returns the difference between the start and end date/times in interval units. +Arguments will be automatically converted to a ]TIMESTAMP when appropriate. +Any argument that is a STRING must be formatted as a valid TIMESTAMP. + +Argument type: INTERVAL, DATE/TIMESTAMP/STRING, DATE/TIMESTAMP/STRING + +INTERVAL must be one of the following tokens: [SECOND, MINUTE, HOUR, DAY, WEEK, MONTH, QUARTER, YEAR] + +Examples:: + + os> source=people | eval `TIMESTAMPDIFF(YEAR, '1997-01-01 00:00:00', '2001-03-06 00:00:00')` = TIMESTAMPDIFF(YEAR, '1997-01-01 00:00:00', '2001-03-06 00:00:00') | eval `TIMESTAMPDIFF(SECOND, timestamp('1997-01-01 00:00:23'), timestamp('1997-01-01 00:00:00'))` = TIMESTAMPDIFF(SECOND, timestamp('1997-01-01 00:00:23'), timestamp('1997-01-01 00:00:00')) | fields `TIMESTAMPDIFF(YEAR, '1997-01-01 00:00:00', '2001-03-06 00:00:00')`, `TIMESTAMPDIFF(SECOND, timestamp('1997-01-01 00:00:23'), timestamp('1997-01-01 00:00:00'))` + fetched rows / total rows = 1/1 + +-------------------------------------------------------------------+-------------------------------------------------------------------------------------------+ + | TIMESTAMPDIFF(YEAR, '1997-01-01 00:00:00', '2001-03-06 00:00:00') | TIMESTAMPDIFF(SECOND, timestamp('1997-01-01 00:00:23'), timestamp('1997-01-01 00:00:00')) | + |-------------------------------------------------------------------+-------------------------------------------------------------------------------------------| + | 4 | -23 | + +-------------------------------------------------------------------+-------------------------------------------------------------------------------------------+ + + +### `UTC_TIMESTAMP` + +**Description:** + +Returns the current UTC timestamp as a value in 'YYYY-MM-DD hh:mm:ss'. + +Return type: TIMESTAMP + +Specification: UTC_TIMESTAMP() -> TIMESTAMP + +Example:: + + > source=people | eval `UTC_TIMESTAMP()` = UTC_TIMESTAMP() | fields `UTC_TIMESTAMP()` + fetched rows / total rows = 1/1 + +---------------------+ + | UTC_TIMESTAMP() | + |---------------------| + | 2022-10-03 17:54:28 | + +---------------------+ + + +### `CURRENT_TIMEZONE` + +**Description:** + +Returns the current local timezone. 
+ +Return type: STRING + +Example:: + + > source=people | eval `CURRENT_TIMEZONE()` = CURRENT_TIMEZONE() | fields `CURRENT_TIMEZONE()` + fetched rows / total rows = 1/1 + +------------------------+ + | CURRENT_TIMEZONE() | + |------------------------| + | America/Chicago | + +------------------------+ + diff --git a/docs/ppl-lang/functions/ppl-ip.md b/docs/ppl-lang/functions/ppl-ip.md new file mode 100644 index 000000000..fb0b468ba --- /dev/null +++ b/docs/ppl-lang/functions/ppl-ip.md @@ -0,0 +1,35 @@ +## PPL IP Address Functions + +### `CIDRMATCH` + +**Description** + +`CIDRMATCH(ip, cidr)` checks if ip is within the specified cidr range. + +**Argument type:** + - STRING, STRING + - Return type: **BOOLEAN** + +Example: + + os> source=ips | where cidrmatch(ip, '192.169.1.0/24') | fields ip + fetched rows / total rows = 1/1 + +--------------+ + | ip | + |--------------| + | 192.169.1.5 | + +--------------+ + + os> source=ipsv6 | where cidrmatch(ip, '2003:db8::/32') | fields ip + fetched rows / total rows = 1/1 + +-----------------------------------------+ + | ip | + |-----------------------------------------| + | 2003:0db8:0000:0000:0000:0000:0000:0000 | + +-----------------------------------------+ + +Note: + - `ip` can be an IPv4 or an IPv6 address + - `cidr` can be an IPv4 or an IPv6 block + - `ip` and `cidr` must be either both IPv4 or both IPv6 + - `ip` and `cidr` must both be valid and non-empty/non-null \ No newline at end of file diff --git a/docs/ppl-lang/functions/ppl-json.md b/docs/ppl-lang/functions/ppl-json.md index 1953e8c70..5b26ee427 100644 --- a/docs/ppl-lang/functions/ppl-json.md +++ b/docs/ppl-lang/functions/ppl-json.md @@ -4,11 +4,11 @@ **Description** -`json(value)` Evaluates whether a value can be parsed as JSON. Returns the json string if valid, null otherwise. +`json(value)` Evaluates whether a string can be parsed as JSON format. Returns the string value if valid, null otherwise. -**Argument type:** STRING/JSON_ARRAY/JSON_OBJECT +**Argument type:** STRING -**Return type:** STRING +**Return type:** STRING/NULL A STRING expression of a valid JSON object format. @@ -47,7 +47,7 @@ A StructType expression of a valid JSON object. 
Example: - os> source=people | eval result = json(json_object('key', 123.45)) | fields result + os> source=people | eval result = json_object('key', 123.45) | fields result fetched rows / total rows = 1/1 +------------------+ | result | @@ -55,7 +55,7 @@ Example: | {"key":123.45} | +------------------+ - os> source=people | eval result = json(json_object('outer', json_object('inner', 123.45))) | fields result + os> source=people | eval result = json_object('outer', json_object('inner', 123.45)) | fields result fetched rows / total rows = 1/1 +------------------------------+ | result | @@ -81,13 +81,13 @@ Example: os> source=people | eval `json_array` = json_array(1, 2, 0, -1, 1.1, -0.11) fetched rows / total rows = 1/1 - +----------------------------+ - | json_array | - +----------------------------+ - | 1.0,2.0,0.0,-1.0,1.1,-0.11 | - +----------------------------+ + +------------------------------+ + | json_array | + +------------------------------+ + | [1.0,2.0,0.0,-1.0,1.1,-0.11] | + +------------------------------+ - os> source=people | eval `json_array_object` = json(json_object("array", json_array(1, 2, 0, -1, 1.1, -0.11))) + os> source=people | eval `json_array_object` = json_object("array", json_array(1, 2, 0, -1, 1.1, -0.11)) fetched rows / total rows = 1/1 +----------------------------------------+ | json_array_object | @@ -95,15 +95,44 @@ Example: | {"array":[1.0,2.0,0.0,-1.0,1.1,-0.11]} | +----------------------------------------+ +### `TO_JSON_STRING` + +**Description** + +`to_json_string(jsonObject)` Returns a JSON string with a given json object value. + +**Argument type:** JSON_OBJECT (Spark StructType/ArrayType) + +**Return type:** STRING + +Example: + + os> source=people | eval `json_string` = to_json_string(json_array(1, 2, 0, -1, 1.1, -0.11)) | fields json_string + fetched rows / total rows = 1/1 + +--------------------------------+ + | json_string | + +--------------------------------+ + | [1.0,2.0,0.0,-1.0,1.1,-0.11] | + +--------------------------------+ + + os> source=people | eval `json_string` = to_json_string(json_object('key', 123.45)) | fields json_string + fetched rows / total rows = 1/1 + +-----------------+ + | json_string | + +-----------------+ + | {'key', 123.45} | + +-----------------+ + + ### `JSON_ARRAY_LENGTH` **Description** -`json_array_length(jsonArray)` Returns the number of elements in the outermost JSON array. +`json_array_length(jsonArrayString)` Returns the number of elements in the outermost JSON array string. -**Argument type:** STRING/JSON_ARRAY +**Argument type:** STRING -A STRING expression of a valid JSON array format, or JSON_ARRAY object. +A STRING expression of a valid JSON array format. **Return type:** INTEGER @@ -119,6 +148,21 @@ Example: | 4 | 5 | null | +-----------+-----------+-------------+ + +### `ARRAY_LENGTH` + +**Description** + +`array_length(jsonArray)` Returns the number of elements in the outermost array. + +**Argument type:** ARRAY + +ARRAY or JSON_ARRAY object. 
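+
+As an illustration, `array_length` can be applied directly to the output of `json_array`. This is a sketch based on the description above; the alias `array_count` is only for illustration and the exact console formatting may differ:
+
+    os> source=people | eval `array_count` = array_length(json_array(1, 2, 3, 4)) | fields `array_count`
+    fetched rows / total rows = 1/1
+    +---------------+
+    | array_count   |
+    +---------------+
+    | 4             |
+    +---------------+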
+ +**Return type:** INTEGER + +Example: + os> source=people | eval `json_array` = json_array_length(json_array(1,2,3,4)), `empty_array` = json_array_length(json_array()) fetched rows / total rows = 1/1 +--------------+---------------+ @@ -127,6 +171,7 @@ Example: | 4 | 0 | +--------------+---------------+ + ### `JSON_EXTRACT` **Description** diff --git a/docs/ppl-lang/functions/ppl-lambda.md b/docs/ppl-lang/functions/ppl-lambda.md new file mode 100644 index 000000000..cdb6f9e8f --- /dev/null +++ b/docs/ppl-lang/functions/ppl-lambda.md @@ -0,0 +1,187 @@ +## Lambda Functions + +### `FORALL` + +**Description** + +`forall(array, lambda)` Evaluates whether a lambda predicate holds for all elements in the array. + +**Argument type:** ARRAY, LAMBDA + +**Return type:** BOOLEAN + +Returns `TRUE` if all elements in the array satisfy the lambda predicate, otherwise `FALSE`. + +Example: + + os> source=people | eval array = json_array(1, -1, 2), result = forall(array, x -> x > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | false | + +-----------+ + + os> source=people | eval array = json_array(1, 3, 2), result = forall(array, x -> x > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | true | + +-----------+ + + **Note:** The lambda expression can access the nested fields of the array elements. This applies to all lambda functions introduced in this document. + +Consider constructing the following array: + + array = [ + {"a":1, "b":1}, + {"a":-1, "b":2} + ] + +and perform lambda functions against the nested fields `a` or `b`. See the examples: + + os> source=people | eval array = json_array(json_object("a", 1, "b", 1), json_object("a" , -1, "b", 2)), result = forall(array, x -> x.a > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | false | + +-----------+ + + os> source=people | eval array = json_array(json_object("a", 1, "b", 1), json_object("a" , -1, "b", 2)), result = forall(array, x -> x.b > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | true | + +-----------+ + +### `EXISTS` + +**Description** + +`exists(array, lambda)` Evaluates whether a lambda predicate holds for one or more elements in the array. + +**Argument type:** ARRAY, LAMBDA + +**Return type:** BOOLEAN + +Returns `TRUE` if at least one element in the array satisfies the lambda predicate, otherwise `FALSE`. + +Example: + + os> source=people | eval array = json_array(1, -1, 2), result = exists(array, x -> x > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | true | + +-----------+ + + os> source=people | eval array = json_array(-1, -3, -2), result = exists(array, x -> x > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | false | + +-----------+ + + +### `FILTER` + +**Description** + +`filter(array, lambda)` Filters the input array using the given lambda function. + +**Argument type:** ARRAY, LAMBDA + +**Return type:** ARRAY + +An ARRAY that contains all elements in the input array that satisfy the lambda predicate. 
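+
+As with `forall`, the lambda passed to `filter` may reference nested fields of struct elements (see the note in the `FORALL` section above). A hypothetical sketch reusing the same array construction, keeping only the elements whose field `a` is positive:
+
+    source=people | eval array = json_array(json_object("a", 1, "b", 1), json_object("a", -1, "b", 2)), result = filter(array, x -> x.a > 0) | fields result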
+ +Example: + + os> source=people | eval array = json_array(1, -1, 2), result = filter(array, x -> x > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | [1, 2] | + +-----------+ + + os> source=people | eval array = json_array(-1, -3, -2), result = filter(array, x -> x > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | [] | + +-----------+ + +### `TRANSFORM` + +**Description** + +`transform(array, lambda)` Transform elements in an array using the lambda transform function. The second argument implies the index of the element if using binary lambda function. This is similar to a `map` in functional programming. + +**Argument type:** ARRAY, LAMBDA + +**Return type:** ARRAY + +An ARRAY that contains the result of applying the lambda transform function to each element in the input array. + +Example: + + os> source=people | eval array = json_array(1, 2, 3), result = transform(array, x -> x + 1) | fields result + fetched rows / total rows = 1/1 + +--------------+ + | result | + +--------------+ + | [2, 3, 4] | + +--------------+ + + os> source=people | eval array = json_array(1, 2, 3), result = transform(array, (x, i) -> x + i) | fields result + fetched rows / total rows = 1/1 + +--------------+ + | result | + +--------------+ + | [1, 3, 5] | + +--------------+ + +### `REDUCE` + +**Description** + +`reduce(array, start, merge_lambda, finish_lambda)` Applies a binary merge lambda function to a start value and all elements in the array, and reduces this to a single state. The final state is converted into the final result by applying a finish lambda function. + +**Argument type:** ARRAY, ANY, LAMBDA, LAMBDA + +**Return type:** ANY + +The final result of applying the lambda functions to the start value and the input array. + +Example: + + os> source=people | eval array = json_array(1, 2, 3), result = reduce(array, 0, (acc, x) -> acc + x) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | 6 | + +-----------+ + + os> source=people | eval array = json_array(1, 2, 3), result = reduce(array, 10, (acc, x) -> acc + x) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | 16 | + +-----------+ + + os> source=people | eval array = json_array(1, 2, 3), result = reduce(array, 0, (acc, x) -> acc + x, acc -> acc * 10) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | 60 | + +-----------+ diff --git a/docs/ppl-lang/ppl-expand-command.md b/docs/ppl-lang/ppl-expand-command.md new file mode 100644 index 000000000..144c0aafa --- /dev/null +++ b/docs/ppl-lang/ppl-expand-command.md @@ -0,0 +1,45 @@ +## PPL `expand` command + +### Description +Using `expand` command to flatten a field of type: +- `Array` +- `Map` + + +### Syntax +`expand [As alias]` + +* field: to be expanded (exploded). The field must be of supported type. +* alias: Optional to be expanded as the name to be used instead of the original field name + +### Usage Guidelines +The expand command produces a row for each element in the specified array or map field, where: +- Array elements become individual rows. +- Map key-value pairs are broken into separate rows, with each key-value represented as a row. + +- When an alias is provided, the exploded values are represented under the alias instead of the original field name. 
+- This can be used in combination with other commands, such as stats, eval, and parse to manipulate or extract data post-expansion. + +### Examples: +- `source = table | expand employee | stats max(salary) as max by state, company` +- `source = table | expand employee as worker | stats max(salary) as max by state, company` +- `source = table | expand employee as worker | eval bonus = salary * 3 | fields worker, bonus` +- `source = table | expand employee | parse description '(?.+@.+)' | fields employee, email` +- `source = table | eval array=json_array(1, 2, 3) | expand array as uid | fields name, occupation, uid` +- `source = table | expand multi_valueA as multiA | expand multi_valueB as multiB` + +- Expand command can be used in combination with other commands such as `eval`, `stats` and more +- Using multiple expand commands will create a cartesian product of all the internal elements within each composite array or map + +### Effective SQL push-down query +The expand command is translated into an equivalent SQL operation using LATERAL VIEW explode, allowing for efficient exploding of arrays or maps at the SQL query level. + +```sql +SELECT customer exploded_productId +FROM table +LATERAL VIEW explode(productId) AS exploded_productId +``` +Where the `explode` command offers the following functionality: +- it is a column operation that returns a new column +- it creates a new row for every element in the exploded column +- internal `null`s are ignored as part of the exploded field (no row is created/exploded for null) diff --git a/docs/ppl-lang/ppl-flatten-command.md b/docs/ppl-lang/ppl-flatten-command.md new file mode 100644 index 000000000..4c1ae5d0d --- /dev/null +++ b/docs/ppl-lang/ppl-flatten-command.md @@ -0,0 +1,90 @@ +## PPL `flatten` command + +### Description +Using `flatten` command to flatten a field of type: +- `struct` +- `array>` + + +### Syntax +`flatten ` + +* field: to be flattened. The field must be of supported type. + +### Test table +#### Schema +| col\_name | data\_type | +|-----------|-------------------------------------------------| +| \_time | string | +| bridges | array\\> | +| city | string | +| coor | struct\ | +| country | string | +#### Data +| \_time | bridges | city | coor | country | +|---------------------|----------------------------------------------|---------|------------------------|---------------| +| 2024-09-13T12:00:00 | [{801, Tower Bridge}, {928, London Bridge}] | London | {35, 51.5074, -0.1278} | England | +| 2024-09-13T12:00:00 | [{232, Pont Neuf}, {160, Pont Alexandre III}]| Paris | {35, 48.8566, 2.3522} | France | +| 2024-09-13T12:00:00 | [{48, Rialto Bridge}, {11, Bridge of Sighs}] | Venice | {2, 45.4408, 12.3155} | Italy | +| 2024-09-13T12:00:00 | [{516, Charles Bridge}, {343, Legion Bridge}]| Prague | {200, 50.0755, 14.4378}| Czech Republic| +| 2024-09-13T12:00:00 | [{375, Chain Bridge}, {333, Liberty Bridge}] | Budapest| {96, 47.4979, 19.0402} | Hungary | +| 1990-09-13T12:00:00 | NULL | Warsaw | NULL | Poland | + + + +### Example 1: flatten struct +This example shows how to flatten a struct field. 
+PPL query: + - `source=table | flatten coor` + +| \_time | bridges | city | country | alt | lat | long | +|---------------------|----------------------------------------------|---------|---------------|-----|--------|--------| +| 2024-09-13T12:00:00 | [{801, Tower Bridge}, {928, London Bridge}] | London | England | 35 | 51.5074| -0.1278| +| 2024-09-13T12:00:00 | [{232, Pont Neuf}, {160, Pont Alexandre III}]| Paris | France | 35 | 48.8566| 2.3522 | +| 2024-09-13T12:00:00 | [{48, Rialto Bridge}, {11, Bridge of Sighs}] | Venice | Italy | 2 | 45.4408| 12.3155| +| 2024-09-13T12:00:00 | [{516, Charles Bridge}, {343, Legion Bridge}]| Prague | Czech Republic| 200 | 50.0755| 14.4378| +| 2024-09-13T12:00:00 | [{375, Chain Bridge}, {333, Liberty Bridge}] | Budapest| Hungary | 96 | 47.4979| 19.0402| +| 1990-09-13T12:00:00 | NULL | Warsaw | Poland | NULL| NULL | NULL | + + + +### Example 2: flatten array + +The example shows how to flatten an array of struct fields. + +PPL query: + - `source=table | flatten bridges` + +| \_time | city | coor | country | length | name | +|---------------------|---------|------------------------|---------------|--------|-------------------| +| 2024-09-13T12:00:00 | London | {35, 51.5074, -0.1278} | England | 801 | Tower Bridge | +| 2024-09-13T12:00:00 | London | {35, 51.5074, -0.1278} | England | 928 | London Bridge | +| 2024-09-13T12:00:00 | Paris | {35, 48.8566, 2.3522} | France | 232 | Pont Neuf | +| 2024-09-13T12:00:00 | Paris | {35, 48.8566, 2.3522} | France | 160 | Pont Alexandre III| +| 2024-09-13T12:00:00 | Venice | {2, 45.4408, 12.3155} | Italy | 48 | Rialto Bridge | +| 2024-09-13T12:00:00 | Venice | {2, 45.4408, 12.3155} | Italy | 11 | Bridge of Sighs | +| 2024-09-13T12:00:00 | Prague | {200, 50.0755, 14.4378}| Czech Republic| 516 | Charles Bridge | +| 2024-09-13T12:00:00 | Prague | {200, 50.0755, 14.4378}| Czech Republic| 343 | Legion Bridge | +| 2024-09-13T12:00:00 | Budapest| {96, 47.4979, 19.0402} | Hungary | 375 | Chain Bridge | +| 2024-09-13T12:00:00 | Budapest| {96, 47.4979, 19.0402} | Hungary | 333 | Liberty Bridge | +| 1990-09-13T12:00:00 | Warsaw | NULL | Poland | NULL | NULL | + + +### Example 3: flatten array and struct +This example shows how to flatten multiple fields. 
+PPL query: + - `source=table | flatten bridges | flatten coor` + +| \_time | city | country | length | name | alt | lat | long | +|---------------------|---------|---------------|--------|-------------------|------|--------|--------| +| 2024-09-13T12:00:00 | London | England | 801 | Tower Bridge | 35 | 51.5074| -0.1278| +| 2024-09-13T12:00:00 | London | England | 928 | London Bridge | 35 | 51.5074| -0.1278| +| 2024-09-13T12:00:00 | Paris | France | 232 | Pont Neuf | 35 | 48.8566| 2.3522 | +| 2024-09-13T12:00:00 | Paris | France | 160 | Pont Alexandre III| 35 | 48.8566| 2.3522 | +| 2024-09-13T12:00:00 | Venice | Italy | 48 | Rialto Bridge | 2 | 45.4408| 12.3155| +| 2024-09-13T12:00:00 | Venice | Italy | 11 | Bridge of Sighs | 2 | 45.4408| 12.3155| +| 2024-09-13T12:00:00 | Prague | Czech Republic| 516 | Charles Bridge | 200 | 50.0755| 14.4378| +| 2024-09-13T12:00:00 | Prague | Czech Republic| 343 | Legion Bridge | 200 | 50.0755| 14.4378| +| 2024-09-13T12:00:00 | Budapest| Hungary | 375 | Chain Bridge | 96 | 47.4979| 19.0402| +| 2024-09-13T12:00:00 | Budapest| Hungary | 333 | Liberty Bridge | 96 | 47.4979| 19.0402| +| 1990-09-13T12:00:00 | Warsaw | Poland | NULL | NULL | NULL | NULL | NULL | \ No newline at end of file diff --git a/docs/ppl-lang/ppl-join-command.md b/docs/ppl-lang/ppl-join-command.md index 525373f7c..b374bce5f 100644 --- a/docs/ppl-lang/ppl-join-command.md +++ b/docs/ppl-lang/ppl-join-command.md @@ -65,8 +65,8 @@ WHERE t1.serviceName = `order` SEARCH source= | | [joinType] JOIN - leftAlias - rightAlias + [leftAlias] + [rightAlias] [joinHints] ON joinCriteria @@ -79,12 +79,12 @@ SEARCH source= **leftAlias** - Syntax: `left = ` -- Required +- Optional - Description: The subquery alias to use with the left join side, to avoid ambiguous naming. **rightAlias** - Syntax: `right = ` -- Required +- Optional - Description: The subquery alias to use with the right join side, to avoid ambiguous naming. **joinHints** @@ -138,11 +138,11 @@ Rewritten by PPL Join query: ```sql SEARCH source=customer | FIELDS c_custkey -| LEFT OUTER JOIN left = c, right = o - ON c.c_custkey = o.o_custkey AND o_comment NOT LIKE '%unusual%packages%' +| LEFT OUTER JOIN + ON c_custkey = o_custkey AND o_comment NOT LIKE '%unusual%packages%' orders -| STATS count(o_orderkey) AS c_count BY c.c_custkey -| STATS count(1) AS custdist BY c_count +| STATS count(o_orderkey) AS c_count BY c_custkey +| STATS count() AS custdist BY c_count | SORT - custdist, - c_count ``` _- **Limitation: sub-searches is unsupported in join right side**_ @@ -151,14 +151,15 @@ If sub-searches is supported, above ppl query could be rewritten as: ```sql SEARCH source=customer | FIELDS c_custkey -| LEFT OUTER JOIN left = c, right = o ON c.c_custkey = o.o_custkey +| LEFT OUTER JOIN + ON c_custkey = o_custkey [ SEARCH source=orders | WHERE o_comment NOT LIKE '%unusual%packages%' | FIELDS o_orderkey, o_custkey ] -| STATS count(o_orderkey) AS c_count BY c.c_custkey -| STATS count(1) AS custdist BY c_count +| STATS count(o_orderkey) AS c_count BY c_custkey +| STATS count() AS custdist BY c_count | SORT - custdist, - c_count ``` diff --git a/docs/ppl-lang/ppl-tpch.md b/docs/ppl-lang/ppl-tpch.md new file mode 100644 index 000000000..ef5846ce0 --- /dev/null +++ b/docs/ppl-lang/ppl-tpch.md @@ -0,0 +1,102 @@ +## TPC-H Benchmark + +TPC-H is a decision support benchmark designed to evaluate the performance of database systems in handling complex business-oriented queries and concurrent data modifications. 
The benchmark utilizes a dataset that is broadly representative of various industries, making it widely applicable. TPC-H simulates a decision support environment where large volumes of data are analyzed, intricate queries are executed, and critical business questions are answered. + +### Test PPL Queries + +TPC-H 22 test query statements: [TPCH-Query-PPL](https://github.com/opensearch-project/opensearch-spark/blob/main/integ-test/src/integration/resources/tpch) + +### Data Preparation + +#### Option 1 - from PyPi + +``` +# Create the virtual environment +python3 -m venv .venv + +# Activate the virtual environment +. .venv/bin/activate + +pip install tpch-datagen +``` + +#### Option 2 - from source + +``` +git clone https://github.com/gizmodata/tpch-datagen + +cd tpch-datagen + +# Create the virtual environment +python3 -m venv .venv + +# Activate the virtual environment +. .venv/bin/activate + +# Upgrade pip, setuptools, and wheel +pip install --upgrade pip setuptools wheel + +# Install TPC-H Datagen - in editable mode with client and dev dependencies +pip install --editable .[dev] +``` + +#### Usage + +Here are the options for the tpch-datagen command: +``` +tpch-datagen --help +Usage: tpch-datagen [OPTIONS] + +Options: + --version / --no-version Prints the TPC-H Datagen package version and + exits. [required] + --scale-factor INTEGER The TPC-H Scale Factor to use for data + generation. + --data-directory TEXT The target output data directory to put the + files into [default: data; required] + --work-directory TEXT The work directory to use for data + generation. [default: /tmp; required] + --overwrite / --no-overwrite Can we overwrite the target directory if it + already exists... [default: no-overwrite; + required] + --num-chunks INTEGER The number of chunks that will be generated + - more chunks equals smaller memory + requirements, but more files generated. + [default: 10; required] + --num-processes INTEGER The maximum number of processes for the + multi-processing pool to use for data + generation. [default: 10; required] + --duckdb-threads INTEGER The number of DuckDB threads to use for data + generation (within each job process). + [default: 1; required] + --per-thread-output / --no-per-thread-output + Controls whether to write the output to a + single file or multiple files (for each + process). [default: per-thread-output; + required] + --compression-method [none|snappy|gzip|zstd] + The compression method to use for the + parquet files generated. [default: zstd; + required] + --file-size-bytes TEXT The target file size for the parquet files + generated. [default: 100m; required] + --help Show this message and exit. +``` + +### Generate 1 GB data with zstd (by default) compression + +``` +tpch-datagen --scale-factor 1 +``` + +### Generate 10 GB data with snappy compression + +``` +tpch-datagen --scale-factor 10 --compression-method snappy +``` + +### Query Test + +All TPC-H PPL Queries located in `integ-test/src/integration/resources/tpch` folder. + +To test all queries, run `org.opensearch.flint.spark.ppl.tpch.TPCHQueryITSuite`. \ No newline at end of file diff --git a/docs/ppl-lang/ppl-trendline-command.md b/docs/ppl-lang/ppl-trendline-command.md new file mode 100644 index 000000000..393a9dd59 --- /dev/null +++ b/docs/ppl-lang/ppl-trendline-command.md @@ -0,0 +1,60 @@ +## PPL trendline Command + +**Description** +Using ``trendline`` command to calculate moving averages of fields. 
+ + +### Syntax +`TRENDLINE [sort <[+|-] sort-field>] SMA(number-of-datapoints, field) [AS alias] [SMA(number-of-datapoints, field) [AS alias]]...` + +* [+|-]: optional. The plus [+] stands for ascending order and NULL/MISSING first and a minus [-] stands for descending order and NULL/MISSING last. **Default:** ascending order and NULL/MISSING first. +* sort-field: mandatory when sorting is used. The field used to sort. +* number-of-datapoints: mandatory. number of datapoints to calculate the moving average (must be greater than zero). +* field: mandatory. the name of the field the moving average should be calculated for. +* alias: optional. the name of the resulting column containing the moving average. + +And the moment only the Simple Moving Average (SMA) type is supported. + +It is calculated like + + f[i]: The value of field 'f' in the i-th data-point + n: The number of data-points in the moving window (period) + t: The current time index + + SMA(t) = (1/n) * Σ(f[i]), where i = t-n+1 to t + +### Example 1: Calculate simple moving average for a timeseries of temperatures + +The example calculates the simple moving average over temperatures using two datapoints. + +PPL query: + + os> source=t | trendline sma(2, temperature) as temp_trend; + fetched rows / total rows = 5/5 + +-----------+---------+--------------------+----------+ + |temperature|device-id| timestamp|temp_trend| + +-----------+---------+--------------------+----------+ + | 12| 1492|2023-04-06 17:07:...| NULL| + | 12| 1492|2023-04-06 17:07:...| 12.0| + | 13| 256|2023-04-06 17:07:...| 12.5| + | 14| 257|2023-04-06 17:07:...| 13.5| + | 15| 258|2023-04-06 17:07:...| 14.5| + +-----------+---------+--------------------+----------+ + +### Example 2: Calculate simple moving averages for a timeseries of temperatures with sorting + +The example calculates two simple moving average over temperatures using two and three datapoints sorted descending by device-id. + +PPL query: + + os> source=t | trendline sort - device-id sma(2, temperature) as temp_trend_2 sma(3, temperature) as temp_trend_3; + fetched rows / total rows = 5/5 + +-----------+---------+--------------------+------------+------------------+ + |temperature|device-id| timestamp|temp_trend_2| temp_trend_3| + +-----------+---------+--------------------+------------+------------------+ + | 15| 258|2023-04-06 17:07:...| NULL| NULL| + | 14| 257|2023-04-06 17:07:...| 14.5| NULL| + | 13| 256|2023-04-06 17:07:...| 13.5| 14.0| + | 12| 1492|2023-04-06 17:07:...| 12.5| 13.0| + | 12| 1492|2023-04-06 17:07:...| 12.0|12.333333333333334| + +-----------+---------+--------------------+------------+------------------+ diff --git a/docs/ppl-lang/ppl-where-command.md b/docs/ppl-lang/ppl-where-command.md index 89a7e61fa..c954623c3 100644 --- a/docs/ppl-lang/ppl-where-command.md +++ b/docs/ppl-lang/ppl-where-command.md @@ -43,6 +43,8 @@ PPL query: - `source = table | where case(length(a) > 6, 'True' else 'False') = 'True'` - `source = table | where a between 1 and 4` - Note: This returns a >= 1 and a <= 4, i.e. 
[1, 4] - `source = table | where b not between '2024-09-10' and '2025-09-10'` - Note: This returns b >= '2024-09-10' and b <= '2025-09-10' +- `source = table | where cidrmatch(ip, '192.169.1.0/24')` +- `source = table | where cidrmatch(ipv6, '2003:db8::/32')` - `source = table | eval status_category = case(a >= 200 AND a < 300, 'Success', diff --git a/flint-commons/src/main/scala/org/opensearch/flint/common/metadata/log/FlintMetadataLogEntry.scala b/flint-commons/src/main/scala/org/opensearch/flint/common/metadata/log/FlintMetadataLogEntry.scala index 982b7df23..1cae64b83 100644 --- a/flint-commons/src/main/scala/org/opensearch/flint/common/metadata/log/FlintMetadataLogEntry.scala +++ b/flint-commons/src/main/scala/org/opensearch/flint/common/metadata/log/FlintMetadataLogEntry.scala @@ -18,6 +18,10 @@ import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry.IndexState * log entry id * @param state * Flint index state + * @param lastRefreshStartTime + * timestamp when last refresh started for manual or external scheduler refresh + * @param lastRefreshCompleteTime + * timestamp when last refresh completed for manual or external scheduler refresh * @param entryVersion * entry version fields for consistency control * @param error @@ -28,10 +32,12 @@ import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry.IndexState case class FlintMetadataLogEntry( id: String, /** - * This is currently used as streaming job start time. In future, this should represent the - * create timestamp of the log entry + * This is currently used as streaming job start time for internal scheduler. In future, this + * should represent the create timestamp of the log entry */ createTime: Long, + lastRefreshStartTime: Long, + lastRefreshCompleteTime: Long, state: IndexState, entryVersion: Map[String, Any], error: String, @@ -40,26 +46,48 @@ case class FlintMetadataLogEntry( def this( id: String, createTime: Long, + lastRefreshStartTime: Long, + lastRefreshCompleteTime: Long, state: IndexState, entryVersion: JMap[String, Any], error: String, properties: JMap[String, Any]) = { - this(id, createTime, state, entryVersion.asScala.toMap, error, properties.asScala.toMap) + this( + id, + createTime, + lastRefreshStartTime, + lastRefreshCompleteTime, + state, + entryVersion.asScala.toMap, + error, + properties.asScala.toMap) } def this( id: String, createTime: Long, + lastRefreshStartTime: Long, + lastRefreshCompleteTime: Long, state: IndexState, entryVersion: JMap[String, Any], error: String, properties: Map[String, Any]) = { - this(id, createTime, state, entryVersion.asScala.toMap, error, properties) + this( + id, + createTime, + lastRefreshStartTime, + lastRefreshCompleteTime, + state, + entryVersion.asScala.toMap, + error, + properties) } } object FlintMetadataLogEntry { + val EMPTY_TIMESTAMP = 0L + /** * Flint index state enum. 
*/ diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java index 35297de6a..ef4d01652 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java @@ -160,6 +160,21 @@ public final class MetricConstants { */ public static final String CHECKPOINT_DELETE_TIME_METRIC = "checkpoint.delete.processingTime"; + /** + * Prefix for externalScheduler metrics + */ + public static final String EXTERNAL_SCHEDULER_METRIC_PREFIX = "externalScheduler."; + + /** + * Metric prefix for tracking the index state transitions + */ + public static final String INDEX_STATE_UPDATED_TO_PREFIX = "indexState.updatedTo."; + + /** + * Metric for tracking the index state transitions + */ + public static final String INITIAL_CONDITION_CHECK_FAILED_PREFIX = "initialConditionCheck.failed."; + private MetricConstants() { // Private constructor to prevent instantiation } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/metadata/log/DefaultOptimisticTransaction.java b/flint-core/src/main/scala/org/opensearch/flint/core/metadata/log/DefaultOptimisticTransaction.java index e6fed4126..466787c81 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/metadata/log/DefaultOptimisticTransaction.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/metadata/log/DefaultOptimisticTransaction.java @@ -16,6 +16,8 @@ import org.opensearch.flint.common.metadata.log.FlintMetadataLog; import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry; import org.opensearch.flint.common.metadata.log.OptimisticTransaction; +import org.opensearch.flint.core.metrics.MetricConstants; +import org.opensearch.flint.core.metrics.MetricsUtil; /** * Default optimistic transaction implementation that captures the basic workflow for @@ -73,6 +75,7 @@ public T commit(Function operation) { // Perform initial log check if (!initialCondition.test(latest)) { LOG.warning("Initial log entry doesn't satisfy precondition " + latest); + emitConditionCheckFailedMetric(latest); throw new IllegalStateException( String.format("Index state [%s] doesn't satisfy precondition", latest.state())); } @@ -86,6 +89,8 @@ public T commit(Function operation) { initialLog = initialLog.copy( initialLog.id(), initialLog.createTime(), + initialLog.lastRefreshStartTime(), + initialLog.lastRefreshCompleteTime(), initialLog.state(), latest.entryVersion(), initialLog.error(), @@ -102,6 +107,7 @@ public T commit(Function operation) { metadataLog.purge(); } else { metadataLog.add(finalLog); + emitFinalLogStateMetric(finalLog); } return result; } catch (Exception e) { @@ -117,4 +123,12 @@ public T commit(Function operation) { throw new IllegalStateException("Failed to commit transaction operation", e); } } + + private void emitConditionCheckFailedMetric(FlintMetadataLogEntry latest) { + MetricsUtil.addHistoricGauge(MetricConstants.INITIAL_CONDITION_CHECK_FAILED_PREFIX + latest.state() + ".count", 1); + } + + private void emitFinalLogStateMetric(FlintMetadataLogEntry finalLog) { + MetricsUtil.addHistoricGauge(MetricConstants.INDEX_STATE_UPDATED_TO_PREFIX + finalLog.state() + ".count", 1); + } } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintMetadataLogEntryOpenSearchConverter.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintMetadataLogEntryOpenSearchConverter.java index 
0b78304d2..f90dda9a0 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintMetadataLogEntryOpenSearchConverter.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintMetadataLogEntryOpenSearchConverter.java @@ -101,7 +101,7 @@ public static String toJson(FlintMetadataLogEntry logEntry) throws JsonProcessin ObjectMapper mapper = new ObjectMapper(); ObjectNode json = mapper.createObjectNode(); - json.put("version", "1.0"); + json.put("version", "1.1"); json.put("latestId", logEntry.id()); json.put("type", "flintindexstate"); json.put("state", logEntry.state().toString()); @@ -109,6 +109,8 @@ public static String toJson(FlintMetadataLogEntry logEntry) throws JsonProcessin json.put("jobId", jobId); json.put("dataSourceName", logEntry.properties().get("dataSourceName").get().toString()); json.put("jobStartTime", logEntry.createTime()); + json.put("lastRefreshStartTime", logEntry.lastRefreshStartTime()); + json.put("lastRefreshCompleteTime", logEntry.lastRefreshCompleteTime()); json.put("lastUpdateTime", lastUpdateTime); json.put("error", logEntry.error()); @@ -138,6 +140,8 @@ public static FlintMetadataLogEntry constructLogEntry( id, /* sourceMap may use Integer or Long even though it's always long in index mapping */ ((Number) sourceMap.get("jobStartTime")).longValue(), + ((Number) sourceMap.get("lastRefreshStartTime")).longValue(), + ((Number) sourceMap.get("lastRefreshCompleteTime")).longValue(), FlintMetadataLogEntry.IndexState$.MODULE$.from((String) sourceMap.get("state")), Map.of("seqNo", seqNo, "primaryTerm", primaryTerm), (String) sourceMap.get("error"), diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java index 8c327b664..24c9df492 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java @@ -132,7 +132,9 @@ public void purge() { public FlintMetadataLogEntry emptyLogEntry() { return new FlintMetadataLogEntry( "", - 0L, + FlintMetadataLogEntry.EMPTY_TIMESTAMP(), + FlintMetadataLogEntry.EMPTY_TIMESTAMP(), + FlintMetadataLogEntry.EMPTY_TIMESTAMP(), FlintMetadataLogEntry.IndexState$.MODULE$.EMPTY(), Map.of("seqNo", UNASSIGNED_SEQ_NO, "primaryTerm", UNASSIGNED_PRIMARY_TERM), "", @@ -146,6 +148,8 @@ private FlintMetadataLogEntry createLogEntry(FlintMetadataLogEntry logEntry) { logEntry.copy( latestId, logEntry.createTime(), + logEntry.lastRefreshStartTime(), + logEntry.lastRefreshCompleteTime(), logEntry.state(), logEntry.entryVersion(), logEntry.error(), @@ -184,6 +188,8 @@ private FlintMetadataLogEntry writeLogEntry( logEntry = new FlintMetadataLogEntry( logEntry.id(), logEntry.createTime(), + logEntry.lastRefreshStartTime(), + logEntry.lastRefreshCompleteTime(), logEntry.state(), Map.of("seqNo", response.getSeqNo(), "primaryTerm", response.getPrimaryTerm()), logEntry.error(), diff --git a/flint-core/src/test/java/org/opensearch/flint/core/metadata/log/DefaultOptimisticTransactionMetricTest.java b/flint-core/src/test/java/org/opensearch/flint/core/metadata/log/DefaultOptimisticTransactionMetricTest.java new file mode 100644 index 000000000..838e9978c --- /dev/null +++ b/flint-core/src/test/java/org/opensearch/flint/core/metadata/log/DefaultOptimisticTransactionMetricTest.java @@ -0,0 +1,68 @@ +/* + * Copyright OpenSearch Contributors + * 
SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.metadata.log; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.flint.common.metadata.log.FlintMetadataLog; +import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry; + +import java.util.Optional; +import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry.IndexState$; +import org.opensearch.flint.core.metrics.MetricsTestUtil; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +@ExtendWith(MockitoExtension.class) +class DefaultOptimisticTransactionMetricTest { + + @Mock + private FlintMetadataLog metadataLog; + + @Mock + private FlintMetadataLogEntry logEntry; + + @InjectMocks + private DefaultOptimisticTransaction transaction; + + @Test + void testCommitWithValidInitialCondition() throws Exception { + MetricsTestUtil.withMetricEnv(verifier -> { + when(metadataLog.getLatest()).thenReturn(Optional.of(logEntry)); + when(metadataLog.add(any(FlintMetadataLogEntry.class))).thenReturn(logEntry); + when(logEntry.state()).thenReturn(IndexState$.MODULE$.ACTIVE()); + + transaction.initialLog(entry -> true) + .transientLog(entry -> logEntry) + .finalLog(entry -> logEntry) + .commit(entry -> "Success"); + + verify(metadataLog, times(2)).add(logEntry); + verifier.assertHistoricGauge("indexState.updatedTo.active.count", 1); + }); + } + + @Test + void testConditionCheckFailed() throws Exception { + MetricsTestUtil.withMetricEnv(verifier -> { + when(metadataLog.getLatest()).thenReturn(Optional.of(logEntry)); + when(logEntry.state()).thenReturn(IndexState$.MODULE$.DELETED()); + + transaction.initialLog(entry -> false) + .finalLog(entry -> logEntry); + + assertThrows(IllegalStateException.class, () -> { + transaction.commit(entry -> "Should Fail"); + }); + verifier.assertHistoricGauge("initialConditionCheck.failed.deleted.count", 1); + }); + } +} diff --git a/flint-core/src/test/scala/org/opensearch/flint/core/storage/FlintMetadataLogEntryOpenSearchConverterSuite.scala b/flint-core/src/test/scala/org/opensearch/flint/core/storage/FlintMetadataLogEntryOpenSearchConverterSuite.scala index 577dfc5fc..2708d48e8 100644 --- a/flint-core/src/test/scala/org/opensearch/flint/core/storage/FlintMetadataLogEntryOpenSearchConverterSuite.scala +++ b/flint-core/src/test/scala/org/opensearch/flint/core/storage/FlintMetadataLogEntryOpenSearchConverterSuite.scala @@ -25,6 +25,10 @@ class FlintMetadataLogEntryOpenSearchConverterTest val sourceMap = JMap.of( "jobStartTime", 1234567890123L.asInstanceOf[Object], + "lastRefreshStartTime", + 1234567890123L.asInstanceOf[Object], + "lastRefreshCompleteTime", + 1234567890123L.asInstanceOf[Object], "state", "active".asInstanceOf[Object], "dataSourceName", @@ -36,6 +40,8 @@ class FlintMetadataLogEntryOpenSearchConverterTest when(mockLogEntry.id).thenReturn("id") when(mockLogEntry.state).thenReturn(FlintMetadataLogEntry.IndexState.ACTIVE) when(mockLogEntry.createTime).thenReturn(1234567890123L) + when(mockLogEntry.lastRefreshStartTime).thenReturn(1234567890123L) + when(mockLogEntry.lastRefreshCompleteTime).thenReturn(1234567890123L) when(mockLogEntry.error).thenReturn("") when(mockLogEntry.properties).thenReturn(Map("dataSourceName" -> "testDataSource")) } @@ -45,7 +51,7 @@ class FlintMetadataLogEntryOpenSearchConverterTest 
val expectedJsonWithoutLastUpdateTime = s""" |{ - | "version": "1.0", + | "version": "1.1", | "latestId": "id", | "type": "flintindexstate", | "state": "active", @@ -53,6 +59,8 @@ class FlintMetadataLogEntryOpenSearchConverterTest | "jobId": "unknown", | "dataSourceName": "testDataSource", | "jobStartTime": 1234567890123, + | "lastRefreshStartTime": 1234567890123, + | "lastRefreshCompleteTime": 1234567890123, | "error": "" |} |""".stripMargin @@ -67,15 +75,22 @@ class FlintMetadataLogEntryOpenSearchConverterTest logEntry shouldBe a[FlintMetadataLogEntry] logEntry.id shouldBe "id" logEntry.createTime shouldBe 1234567890123L + logEntry.lastRefreshStartTime shouldBe 1234567890123L + logEntry.lastRefreshCompleteTime shouldBe 1234567890123L logEntry.state shouldBe FlintMetadataLogEntry.IndexState.ACTIVE logEntry.error shouldBe "" logEntry.properties.get("dataSourceName").get shouldBe "testDataSource" } - it should "construct log entry with integer jobStartTime value" in { + it should "construct log entry with integer timestamp value" in { + // Use Integer instead of Long for timestamps val testSourceMap = JMap.of( "jobStartTime", - 1234567890.asInstanceOf[Object], // Integer instead of Long + 1234567890.asInstanceOf[Object], + "lastRefreshStartTime", + 1234567890.asInstanceOf[Object], + "lastRefreshCompleteTime", + 1234567890.asInstanceOf[Object], "state", "active".asInstanceOf[Object], "dataSourceName", @@ -87,6 +102,8 @@ class FlintMetadataLogEntryOpenSearchConverterTest logEntry shouldBe a[FlintMetadataLogEntry] logEntry.id shouldBe "id" logEntry.createTime shouldBe 1234567890 + logEntry.lastRefreshStartTime shouldBe 1234567890 + logEntry.lastRefreshCompleteTime shouldBe 1234567890 logEntry.state shouldBe FlintMetadataLogEntry.IndexState.ACTIVE logEntry.error shouldBe "" logEntry.properties.get("dataSourceName").get shouldBe "testDataSource" diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala index 68721d235..bdcc120c0 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala @@ -253,6 +253,10 @@ object FlintSparkConf { FlintConfig("spark.metadata.accessAWSCredentialsProvider") .doc("AWS credentials provider for metadata access permission") .createOptional() + val METADATA_CACHE_WRITE = FlintConfig("spark.flint.metadataCacheWrite.enabled") + .doc("Enable Flint metadata cache write to Flint index mappings") + .createWithDefault("false") + val CUSTOM_SESSION_MANAGER = FlintConfig("spark.flint.job.customSessionManager") .createOptional() @@ -309,6 +313,8 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable def monitorMaxErrorCount(): Int = MONITOR_MAX_ERROR_COUNT.readFrom(reader).toInt + def isMetadataCacheWriteEnabled: Boolean = METADATA_CACHE_WRITE.readFrom(reader).toBoolean + /** * spark.sql.session.timeZone */ diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala index 532bd8e60..68d2409ee 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala @@ -20,6 +20,7 @@ import 
org.opensearch.flint.core.metadata.log.FlintMetadataLogServiceBuilder import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN import org.opensearch.flint.spark.FlintSparkIndexOptions.OptionName._ import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex +import org.opensearch.flint.spark.metadatacache.FlintMetadataCacheWriterBuilder import org.opensearch.flint.spark.mv.FlintSparkMaterializedView import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.RefreshMode._ @@ -56,8 +57,10 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w FlintIndexMetadataServiceBuilder.build(flintSparkConf.flintOptions()) } + private val flintMetadataCacheWriter = FlintMetadataCacheWriterBuilder.build(flintSparkConf) + private val flintAsyncQueryScheduler: AsyncQueryScheduler = { - AsyncQuerySchedulerBuilder.build(flintSparkConf.flintOptions()) + AsyncQuerySchedulerBuilder.build(spark, flintSparkConf.flintOptions()) } override protected val flintMetadataLogService: FlintMetadataLogService = { @@ -117,7 +120,6 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w throw new IllegalStateException(s"Flint index $indexName already exists") } } else { - val metadata = index.metadata() val jobSchedulingService = FlintSparkJobSchedulingService.create( index, spark, @@ -129,15 +131,18 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w .transientLog(latest => latest.copy(state = CREATING)) .finalLog(latest => latest.copy(state = ACTIVE)) .commit(latest => { - if (latest == null) { // in case transaction capability is disabled - flintClient.createIndex(indexName, metadata) - flintIndexMetadataService.updateIndexMetadata(indexName, metadata) - } else { - logInfo(s"Creating index with metadata log entry ID ${latest.id}") - flintClient.createIndex(indexName, metadata.copy(latestId = Some(latest.id))) - flintIndexMetadataService - .updateIndexMetadata(indexName, metadata.copy(latestId = Some(latest.id))) + val metadata = latest match { + case null => // in case transaction capability is disabled + index.metadata() + case latestEntry => + logInfo(s"Creating index with metadata log entry ID ${latestEntry.id}") + index + .metadata() + .copy(latestId = Some(latestEntry.id), latestLogEntry = Some(latest)) } + flintClient.createIndex(indexName, metadata) + flintIndexMetadataService.updateIndexMetadata(indexName, metadata) + flintMetadataCacheWriter.updateMetadataCache(indexName, metadata) jobSchedulingService.handleJob(index, AsyncQuerySchedulerAction.SCHEDULE) }) } @@ -156,22 +161,10 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w val index = describeIndex(indexName) .getOrElse(throw new IllegalStateException(s"Index $indexName doesn't exist")) val indexRefresh = FlintSparkIndexRefresh.create(indexName, index) - tx - .initialLog(latest => latest.state == ACTIVE) - .transientLog(latest => - latest.copy(state = REFRESHING, createTime = System.currentTimeMillis())) - .finalLog(latest => { - // Change state to active if full, otherwise update index state regularly - if (indexRefresh.refreshMode == AUTO) { - logInfo("Scheduling index state monitor") - flintIndexMonitor.startMonitor(indexName) - latest - } else { - logInfo("Updating index state to active") - latest.copy(state = ACTIVE) - } - }) - .commit(_ => indexRefresh.start(spark, flintSparkConf)) + indexRefresh.refreshMode match { + case AUTO => refreshIndexAuto(index, 
indexRefresh, tx) + case FULL | INCREMENTAL => refreshIndexManual(index, indexRefresh, tx) + } }.flatten /** @@ -190,7 +183,7 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w attachLatestLogEntry(indexName, metadata) } .toList - .flatMap(FlintSparkIndexFactory.create) + .flatMap(metadata => FlintSparkIndexFactory.create(spark, metadata)) } else { Seq.empty } @@ -209,7 +202,7 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w if (flintClient.exists(indexName)) { val metadata = flintIndexMetadataService.getIndexMetadata(indexName) val metadataWithEntry = attachLatestLogEntry(indexName, metadata) - FlintSparkIndexFactory.create(metadataWithEntry) + FlintSparkIndexFactory.create(spark, metadataWithEntry) } else { Option.empty } @@ -334,7 +327,7 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w val index = describeIndex(indexName) if (index.exists(_.options.autoRefresh())) { - val updatedIndex = FlintSparkIndexFactory.createWithDefaultOptions(index.get).get + val updatedIndex = FlintSparkIndexFactory.createWithDefaultOptions(spark, index.get).get FlintSparkIndexRefresh .create(updatedIndex.name(), updatedIndex) .validate(spark) @@ -520,6 +513,63 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w updatedOptions.isExternalSchedulerEnabled() != originalOptions.isExternalSchedulerEnabled() } + /** + * Handles refresh for refresh mode AUTO, which is used exclusively by auto refresh index with + * internal scheduler. Refresh start time and complete time aren't tracked for streaming job. + * TODO: in future, track MicroBatchExecution time for streaming job and update as well + */ + private def refreshIndexAuto( + index: FlintSparkIndex, + indexRefresh: FlintSparkIndexRefresh, + tx: OptimisticTransaction[Option[String]]): Option[String] = { + val indexName = index.name + tx + .initialLog(latest => latest.state == ACTIVE) + .transientLog(latest => + latest.copy(state = REFRESHING, createTime = System.currentTimeMillis())) + .finalLog(latest => { + logInfo("Scheduling index state monitor") + flintIndexMonitor.startMonitor(indexName) + latest + }) + .commit(_ => indexRefresh.start(spark, flintSparkConf)) + } + + /** + * Handles refresh for refresh mode FULL and INCREMENTAL, which is used by full refresh index, + * incremental refresh index, and auto refresh index with external scheduler. Stores refresh + * start time and complete time. 
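A minimal sketch of the bookkeeping this path performs on the metadata log entry, assuming the `copy` fields used in the code below; timestamps are epoch milliseconds:

```scala
// Illustrative only: how the two transaction callbacks stamp the refresh timestamps.
// `latest` stands for the current FlintMetadataLogEntry of the index.
val started = latest.copy(
  state = REFRESHING,
  lastRefreshStartTime = System.currentTimeMillis()) // written by transientLog, cached immediately
// ... indexRefresh.start(spark, flintSparkConf) performs the actual refresh ...
val completed = started.copy(
  state = ACTIVE,
  lastRefreshCompleteTime = System.currentTimeMillis()) // written by finalLog, cached again
```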
+ */ + private def refreshIndexManual( + index: FlintSparkIndex, + indexRefresh: FlintSparkIndexRefresh, + tx: OptimisticTransaction[Option[String]]): Option[String] = { + val indexName = index.name + tx + .initialLog(latest => latest.state == ACTIVE) + .transientLog(latest => { + val currentTime = System.currentTimeMillis() + val updatedLatest = latest + .copy(state = REFRESHING, createTime = currentTime, lastRefreshStartTime = currentTime) + flintMetadataCacheWriter + .updateMetadataCache( + indexName, + index.metadata.copy(latestLogEntry = Some(updatedLatest))) + updatedLatest + }) + .finalLog(latest => { + logInfo("Updating index state to active") + val updatedLatest = + latest.copy(state = ACTIVE, lastRefreshCompleteTime = System.currentTimeMillis()) + flintMetadataCacheWriter + .updateMetadataCache( + indexName, + index.metadata.copy(latestLogEntry = Some(updatedLatest))) + updatedLatest + }) + .commit(_ => indexRefresh.start(spark, flintSparkConf)) + } + private def updateIndexAutoToManual( index: FlintSparkIndex, tx: OptimisticTransaction[Option[String]]): Option[String] = { @@ -539,8 +589,10 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w .transientLog(latest => latest.copy(state = UPDATING)) .finalLog(latest => latest.copy(state = jobSchedulingService.stateTransitions.finalStateForUnschedule)) - .commit(_ => { + .commit(latest => { flintIndexMetadataService.updateIndexMetadata(indexName, index.metadata) + flintMetadataCacheWriter + .updateMetadataCache(indexName, index.metadata.copy(latestLogEntry = Some(latest))) logInfo("Update index options complete") jobSchedulingService.handleJob(index, AsyncQuerySchedulerAction.UNSCHEDULE) None @@ -566,8 +618,10 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w .finalLog(latest => { latest.copy(state = jobSchedulingService.stateTransitions.finalStateForUpdate) }) - .commit(_ => { + .commit(latest => { flintIndexMetadataService.updateIndexMetadata(indexName, index.metadata) + flintMetadataCacheWriter + .updateMetadataCache(indexName, index.metadata.copy(latestLogEntry = Some(latest))) logInfo("Update index options complete") jobSchedulingService.handleJob(index, AsyncQuerySchedulerAction.UPDATE) }) diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala index 0391741cf..2ff2883a9 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala @@ -92,7 +92,7 @@ abstract class FlintSparkIndexBuilder(flint: FlintSpark) { val updatedMetadata = index .metadata() .copy(options = updatedOptions.options.mapValues(_.asInstanceOf[AnyRef]).asJava) - validateIndex(FlintSparkIndexFactory.create(updatedMetadata).get) + validateIndex(FlintSparkIndexFactory.create(flint.spark, updatedMetadata).get) } /** diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala index 78636d992..ca659550d 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala @@ -25,6 +25,7 @@ import 
org.opensearch.flint.spark.skipping.partition.PartitionSkippingStrategy import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession /** * Flint Spark index factory that encapsulates specific Flint index instance creation. This is for @@ -35,14 +36,16 @@ object FlintSparkIndexFactory extends Logging { /** * Creates Flint index from generic Flint metadata. * + * @param spark + * Spark session * @param metadata * Flint metadata * @return * Flint index instance, or None if any error during creation */ - def create(metadata: FlintMetadata): Option[FlintSparkIndex] = { + def create(spark: SparkSession, metadata: FlintMetadata): Option[FlintSparkIndex] = { try { - Some(doCreate(metadata)) + Some(doCreate(spark, metadata)) } catch { case e: Exception => logWarning(s"Failed to create Flint index from metadata $metadata", e) @@ -53,24 +56,26 @@ object FlintSparkIndexFactory extends Logging { /** * Creates Flint index with default options. * + * @param spark + * Spark session * @param index * Flint index - * @param metadata - * Flint metadata * @return * Flint index with default options */ - def createWithDefaultOptions(index: FlintSparkIndex): Option[FlintSparkIndex] = { + def createWithDefaultOptions( + spark: SparkSession, + index: FlintSparkIndex): Option[FlintSparkIndex] = { val originalOptions = index.options val updatedOptions = FlintSparkIndexOptions.updateOptionsWithDefaults(index.name(), originalOptions) val updatedMetadata = index .metadata() .copy(options = updatedOptions.options.mapValues(_.asInstanceOf[AnyRef]).asJava) - this.create(updatedMetadata) + this.create(spark, updatedMetadata) } - private def doCreate(metadata: FlintMetadata): FlintSparkIndex = { + private def doCreate(spark: SparkSession, metadata: FlintMetadata): FlintSparkIndex = { val indexOptions = FlintSparkIndexOptions( metadata.options.asScala.mapValues(_.asInstanceOf[String]).toMap) val latestLogEntry = metadata.latestLogEntry @@ -118,6 +123,7 @@ object FlintSparkIndexFactory extends Logging { FlintSparkMaterializedView( metadata.name, metadata.source, + getMvSourceTables(spark, metadata), metadata.indexedColumns.map { colInfo => getString(colInfo, "columnName") -> getString(colInfo, "columnType") }.toMap, @@ -134,6 +140,15 @@ object FlintSparkIndexFactory extends Logging { .toMap } + private def getMvSourceTables(spark: SparkSession, metadata: FlintMetadata): Array[String] = { + val sourceTables = getArrayString(metadata.properties, "sourceTables") + if (sourceTables.isEmpty) { + FlintSparkMaterializedView.extractSourceTableNames(spark, metadata.source) + } else { + sourceTables + } + } + private def getString(map: java.util.Map[String, AnyRef], key: String): String = { map.get(key).asInstanceOf[String] } @@ -146,4 +161,12 @@ object FlintSparkIndexFactory extends Logging { Some(value.asInstanceOf[String]) } } + + private def getArrayString(map: java.util.Map[String, AnyRef], key: String): Array[String] = { + map.get(key) match { + case list: java.util.ArrayList[_] => + list.toArray.map(_.toString) + case _ => Array.empty[String] + } + } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkValidationHelper.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkValidationHelper.scala index 1aaa85075..7e9922655 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkValidationHelper.scala +++ 
b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkValidationHelper.scala @@ -11,9 +11,8 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.command.DDLUtils -import org.apache.spark.sql.flint.{loadTable, parseTableName, qualifyTableName} +import org.apache.spark.sql.flint.{loadTable, parseTableName} /** * Flint Spark validation helper. @@ -31,16 +30,10 @@ trait FlintSparkValidationHelper extends Logging { * true if all non Hive, otherwise false */ def isTableProviderSupported(spark: SparkSession, index: FlintSparkIndex): Boolean = { - // Extract source table name (possibly more than one for MV query) val tableNames = index match { case skipping: FlintSparkSkippingIndex => Seq(skipping.tableName) case covering: FlintSparkCoveringIndex => Seq(covering.tableName) - case mv: FlintSparkMaterializedView => - spark.sessionState.sqlParser - .parsePlan(mv.query) - .collect { case relation: UnresolvedRelation => - qualifyTableName(spark, relation.tableName) - } + case mv: FlintSparkMaterializedView => mv.sourceTables.toSeq } // Validate if any source table is not supported (currently Hive only) diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintDisabledMetadataCacheWriter.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintDisabledMetadataCacheWriter.scala new file mode 100644 index 000000000..4099da3ff --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintDisabledMetadataCacheWriter.scala @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.metadatacache + +import org.opensearch.flint.common.metadata.FlintMetadata + +/** + * Default implementation of {@link FlintMetadataCacheWriter} that does nothing + */ +class FlintDisabledMetadataCacheWriter extends FlintMetadataCacheWriter { + override def updateMetadataCache(indexName: String, metadata: FlintMetadata): Unit = { + // Do nothing + } +} diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCache.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCache.scala new file mode 100644 index 000000000..e1c0f318c --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCache.scala @@ -0,0 +1,81 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.metadatacache + +import scala.collection.JavaConverters.mapAsScalaMapConverter + +import org.opensearch.flint.common.metadata.FlintMetadata +import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry +import org.opensearch.flint.spark.FlintSparkIndexOptions +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE +import org.opensearch.flint.spark.scheduler.util.IntervalSchedulerParser + +/** + * Flint metadata cache defines metadata required to store in read cache for frontend user to + * access. + */ +case class FlintMetadataCache( + metadataCacheVersion: String, + /** Refresh interval for Flint index with auto refresh. 
Unit: seconds */ + refreshInterval: Option[Int], + /** Source table names for building the Flint index. */ + sourceTables: Array[String], + /** Timestamp when Flint index is last refreshed. Unit: milliseconds */ + lastRefreshTime: Option[Long]) { + + /** + * Convert FlintMetadataCache to a map. Skips a field if its value is not defined. + */ + def toMap: Map[String, AnyRef] = { + val fieldNames = getClass.getDeclaredFields.map(_.getName) + val fieldValues = productIterator.toList + + fieldNames + .zip(fieldValues) + .flatMap { + case (_, None) => List.empty + case (name, Some(value)) => List((name, value)) + case (name, value) => List((name, value)) + } + .toMap + .mapValues(_.asInstanceOf[AnyRef]) + } +} + +object FlintMetadataCache { + + val metadataCacheVersion = "1.0" + + def apply(metadata: FlintMetadata): FlintMetadataCache = { + val indexOptions = FlintSparkIndexOptions( + metadata.options.asScala.mapValues(_.asInstanceOf[String]).toMap) + val refreshInterval = if (indexOptions.autoRefresh()) { + indexOptions + .refreshInterval() + .map(IntervalSchedulerParser.parseAndConvertToMillis) + .map(millis => (millis / 1000).toInt) // convert to seconds + } else { + None + } + val sourceTables = metadata.kind match { + case MV_INDEX_TYPE => + metadata.properties.get("sourceTables") match { + case list: java.util.ArrayList[_] => + list.toArray.map(_.toString) + case _ => Array.empty[String] + } + case _ => Array(metadata.source) + } + val lastRefreshTime: Option[Long] = metadata.latestLogEntry.flatMap { entry => + entry.lastRefreshCompleteTime match { + case FlintMetadataLogEntry.EMPTY_TIMESTAMP => None + case timestamp => Some(timestamp) + } + } + + FlintMetadataCache(metadataCacheVersion, refreshInterval, sourceTables, lastRefreshTime) + } +} diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCacheWriter.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCacheWriter.scala new file mode 100644 index 000000000..c256463c3 --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCacheWriter.scala @@ -0,0 +1,27 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.metadatacache + +import org.opensearch.flint.common.metadata.{FlintIndexMetadataService, FlintMetadata} + +/** + * Writes {@link FlintMetadataCache} to a storage of choice. This is different from {@link + * FlintIndexMetadataService} which persists the full index metadata to a storage for single + * source of truth. + */ +trait FlintMetadataCacheWriter { + + /** + * Update metadata cache for a Flint index. 
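For illustration, a minimal sketch of the cache defined above for an auto-refresh index; the field values are hypothetical:

```scala
// Hypothetical cache entry for an index refreshed every 10 minutes.
val cache = FlintMetadataCache(
  metadataCacheVersion = "1.0",
  refreshInterval = Some(600), // seconds
  sourceTables = Array("spark_catalog.default.test_table"),
  lastRefreshTime = Some(1234567890123L)) // epoch milliseconds

cache.toMap
// -> Map("metadataCacheVersion" -> "1.0", "refreshInterval" -> 600,
//        "sourceTables" -> Array("spark_catalog.default.test_table"),
//        "lastRefreshTime" -> 1234567890123L)
// Fields whose Option value is None are omitted from the map.
```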
+ * + * @param indexName + * index name + * @param metadata + * index metadata to update the cache + */ + def updateMetadataCache(indexName: String, metadata: FlintMetadata): Unit + +} diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCacheWriterBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCacheWriterBuilder.scala new file mode 100644 index 000000000..be821ae25 --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCacheWriterBuilder.scala @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.metadatacache + +import org.apache.spark.sql.flint.config.FlintSparkConf + +object FlintMetadataCacheWriterBuilder { + def build(flintSparkConf: FlintSparkConf): FlintMetadataCacheWriter = { + if (flintSparkConf.isMetadataCacheWriteEnabled) { + new FlintOpenSearchMetadataCacheWriter(flintSparkConf.flintOptions()) + } else { + new FlintDisabledMetadataCacheWriter + } + } +} diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriter.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriter.scala new file mode 100644 index 000000000..2bc373792 --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriter.scala @@ -0,0 +1,106 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.metadatacache + +import java.util + +import scala.collection.JavaConverters._ + +import org.opensearch.client.RequestOptions +import org.opensearch.client.indices.PutMappingRequest +import org.opensearch.common.xcontent.XContentType +import org.opensearch.flint.common.metadata.{FlintIndexMetadataService, FlintMetadata} +import org.opensearch.flint.core.{FlintOptions, IRestHighLevelClient} +import org.opensearch.flint.core.metadata.FlintIndexMetadataServiceBuilder +import org.opensearch.flint.core.metadata.FlintJsonHelper._ +import org.opensearch.flint.core.storage.{FlintOpenSearchIndexMetadataService, OpenSearchClientUtils} + +import org.apache.spark.internal.Logging + +/** + * Writes {@link FlintMetadataCache} to index mappings `_meta` field for frontend user to access. + */ +class FlintOpenSearchMetadataCacheWriter(options: FlintOptions) + extends FlintMetadataCacheWriter + with Logging { + + /** + * Since metadata cache shares the index mappings _meta field with OpenSearch index metadata + * storage, this flag is to allow for preserving index metadata that is already stored in _meta + * when updating metadata cache. 
+ */ + private val includeSpec: Boolean = + FlintIndexMetadataServiceBuilder + .build(options) + .isInstanceOf[FlintOpenSearchIndexMetadataService] + + override def updateMetadataCache(indexName: String, metadata: FlintMetadata): Unit = { + logInfo(s"Updating metadata cache for $indexName"); + val osIndexName = OpenSearchClientUtils.sanitizeIndexName(indexName) + var client: IRestHighLevelClient = null + try { + client = OpenSearchClientUtils.createClient(options) + val request = new PutMappingRequest(osIndexName) + request.source(serialize(metadata), XContentType.JSON) + client.updateIndexMapping(request, RequestOptions.DEFAULT) + } catch { + case e: Exception => + throw new IllegalStateException( + s"Failed to update metadata cache for Flint index $osIndexName", + e) + } finally + if (client != null) { + client.close() + } + } + + /** + * Serialize FlintMetadataCache from FlintMetadata. Modified from {@link + * FlintOpenSearchIndexMetadataService} + */ + private[metadatacache] def serialize(metadata: FlintMetadata): String = { + try { + buildJson(builder => { + objectField(builder, "_meta") { + // If _meta is used as index metadata storage, preserve them. + if (includeSpec) { + builder + .field("version", metadata.version.version) + .field("name", metadata.name) + .field("kind", metadata.kind) + .field("source", metadata.source) + .field("indexedColumns", metadata.indexedColumns) + + if (metadata.latestId.isDefined) { + builder.field("latestId", metadata.latestId.get) + } + optionalObjectField(builder, "options", metadata.options) + } + + optionalObjectField(builder, "properties", buildPropertiesMap(metadata)) + } + builder.field("properties", metadata.schema) + }) + } catch { + case e: Exception => + throw new IllegalStateException("Failed to jsonify cache metadata", e) + } + } + + /** + * Since _meta.properties is shared by both index metadata and metadata cache, here we merge the + * two maps. 
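To make the merge concrete, a hypothetical example of the resulting `_meta.properties` content for a materialized view; the table names are illustrative:

```scala
// Sketch only: cache fields merged with existing index properties
// (the right-hand map wins on key collisions).
val merged: Map[String, AnyRef] = Map(
  "metadataCacheVersion" -> "1.0",
  "refreshInterval" -> Int.box(600),
  "lastRefreshTime" -> Long.box(1234567890123L),
  "sourceTables" -> Array(
    "spark_catalog.default.test_table",
    "spark_catalog.default.another_table"))
```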
+ */ + private def buildPropertiesMap(metadata: FlintMetadata): util.Map[String, AnyRef] = { + val metadataCacheProperties = FlintMetadataCache(metadata).toMap + + if (includeSpec) { + (metadataCacheProperties ++ metadata.properties.asScala).asJava + } else { + metadataCacheProperties.asJava + } + } +} diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala index 2125c6878..e2a64d183 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala @@ -34,6 +34,8 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap * MV name * @param query * source query that generates MV data + * @param sourceTables + * source table names * @param outputSchema * output schema * @param options @@ -44,6 +46,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap case class FlintSparkMaterializedView( mvName: String, query: String, + sourceTables: Array[String], outputSchema: Map[String, String], override val options: FlintSparkIndexOptions = empty, override val latestLogEntry: Option[FlintMetadataLogEntry] = None) @@ -64,6 +67,7 @@ case class FlintSparkMaterializedView( metadataBuilder(this) .name(mvName) .source(query) + .addProperty("sourceTables", sourceTables) .indexedColumns(indexColumnMaps) .schema(schema) .build() @@ -171,10 +175,30 @@ object FlintSparkMaterializedView { flintIndexNamePrefix(mvName) } + /** + * Extract source table names (possibly more than one) from the query. + * + * @param spark + * Spark session + * @param query + * source query that generates MV data + * @return + * source table names + */ + def extractSourceTableNames(spark: SparkSession, query: String): Array[String] = { + spark.sessionState.sqlParser + .parsePlan(query) + .collect { case relation: UnresolvedRelation => + qualifyTableName(spark, relation.tableName) + } + .toArray + } + /** Builder class for MV build */ class Builder(flint: FlintSpark) extends FlintSparkIndexBuilder(flint) { private var mvName: String = "" private var query: String = "" + private var sourceTables: Array[String] = Array.empty[String] /** * Set MV name. 
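Since the builder below now derives source tables from the query, a brief usage sketch of `extractSourceTableNames` defined above; the query and table names are hypothetical:

```scala
// Table names in the query are qualified against the current catalog and database.
val sources = FlintSparkMaterializedView.extractSourceTableNames(
  spark,
  "SELECT o.id, c.name FROM orders AS o JOIN customers AS c ON o.customer_id = c.id")
// e.g. Array("spark_catalog.default.orders", "spark_catalog.default.customers")
```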
@@ -199,6 +223,7 @@ object FlintSparkMaterializedView { */ def query(query: String): Builder = { this.query = query + this.sourceTables = extractSourceTableNames(flint.spark, query) this } @@ -227,7 +252,7 @@ object FlintSparkMaterializedView { field.name -> field.dataType.simpleString } .toMap - FlintSparkMaterializedView(mvName, query, outputSchema, indexOptions) + FlintSparkMaterializedView(mvName, query, sourceTables, outputSchema, indexOptions) } } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/AsyncQuerySchedulerBuilder.java b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/AsyncQuerySchedulerBuilder.java index 3620608b0..330b38f02 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/AsyncQuerySchedulerBuilder.java +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/AsyncQuerySchedulerBuilder.java @@ -7,9 +7,12 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.flint.config.FlintSparkConf; import org.opensearch.flint.common.scheduler.AsyncQueryScheduler; import org.opensearch.flint.core.FlintOptions; +import java.io.IOException; import java.lang.reflect.Constructor; /** @@ -28,11 +31,27 @@ public enum AsyncQuerySchedulerAction { REMOVE } - public static AsyncQueryScheduler build(FlintOptions options) { + public static AsyncQueryScheduler build(SparkSession sparkSession, FlintOptions options) throws IOException { + return new AsyncQuerySchedulerBuilder().doBuild(sparkSession, options); + } + + /** + * Builds an AsyncQueryScheduler based on the provided options. + * + * @param sparkSession The SparkSession to be used. + * @param options The FlintOptions containing configuration details. + * @return An instance of AsyncQueryScheduler. + */ + protected AsyncQueryScheduler doBuild(SparkSession sparkSession, FlintOptions options) throws IOException { String className = options.getCustomAsyncQuerySchedulerClass(); if (className.isEmpty()) { - return new OpenSearchAsyncQueryScheduler(options); + OpenSearchAsyncQueryScheduler scheduler = createOpenSearchAsyncQueryScheduler(options); + // Check if the scheduler has access to the required index. Disable the external scheduler otherwise. 
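A minimal sketch of the resulting fallback from the caller's point of view, assuming a `FlintSparkConf` in scope and the configuration key used in this change:

```scala
// Sketch only: when the OpenSearch scheduler index is not accessible, the builder disables
// the external scheduler for this session so refreshes fall back to the internal scheduler.
val scheduler = AsyncQuerySchedulerBuilder.build(spark, flintSparkConf.flintOptions())
spark.conf.get("spark.flint.job.externalScheduler.enabled") // "false" after a failed access check
```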
+ if (!hasAccessToSchedulerIndex(scheduler)){ + setExternalSchedulerEnabled(sparkSession, false); + } + return scheduler; } // Attempts to instantiate AsyncQueryScheduler using reflection @@ -45,4 +64,16 @@ public static AsyncQueryScheduler build(FlintOptions options) { throw new RuntimeException("Failed to instantiate AsyncQueryScheduler: " + className, e); } } + + protected OpenSearchAsyncQueryScheduler createOpenSearchAsyncQueryScheduler(FlintOptions options) { + return new OpenSearchAsyncQueryScheduler(options); + } + + protected boolean hasAccessToSchedulerIndex(OpenSearchAsyncQueryScheduler scheduler) throws IOException { + return scheduler.hasAccessToSchedulerIndex(); + } + + protected void setExternalSchedulerEnabled(SparkSession sparkSession, boolean enabled) { + sparkSession.sqlContext().setConf(FlintSparkConf.EXTERNAL_SCHEDULER_ENABLED().key(), String.valueOf(enabled)); + } } \ No newline at end of file diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/FlintSparkJobExternalSchedulingService.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/FlintSparkJobExternalSchedulingService.scala index d043746c0..197b2f8c7 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/FlintSparkJobExternalSchedulingService.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/FlintSparkJobExternalSchedulingService.scala @@ -10,6 +10,7 @@ import java.time.Instant import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry.IndexState import org.opensearch.flint.common.scheduler.AsyncQueryScheduler import org.opensearch.flint.common.scheduler.model.{AsyncQuerySchedulerRequest, LangType} +import org.opensearch.flint.core.metrics.{MetricConstants, MetricsUtil} import org.opensearch.flint.core.storage.OpenSearchClientUtils import org.opensearch.flint.spark.FlintSparkIndex import org.opensearch.flint.spark.refresh.util.RefreshMetricsAspect @@ -83,8 +84,16 @@ class FlintSparkJobExternalSchedulingService( case AsyncQuerySchedulerAction.REMOVE => flintAsyncQueryScheduler.removeJob(request) case _ => throw new IllegalArgumentException(s"Unsupported action: $action") } + addExternalSchedulerMetrics(action) None // Return None for all cases } } + + private def addExternalSchedulerMetrics(action: AsyncQuerySchedulerAction): Unit = { + val actionName = action.name().toLowerCase() + MetricsUtil.addHistoricGauge( + MetricConstants.EXTERNAL_SCHEDULER_METRIC_PREFIX + actionName + ".count", + 1) + } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/OpenSearchAsyncQueryScheduler.java b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/OpenSearchAsyncQueryScheduler.java index 19532254b..a1ef45825 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/OpenSearchAsyncQueryScheduler.java +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/OpenSearchAsyncQueryScheduler.java @@ -9,6 +9,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ObjectNode; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMap; import org.apache.commons.io.IOUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -37,6 +38,7 @@ import org.opensearch.jobscheduler.spi.schedule.Schedule; import org.opensearch.rest.RestStatus; +import 
java.io.IOException; import java.io.InputStream; import java.nio.charset.StandardCharsets; import java.time.Instant; @@ -55,6 +57,11 @@ public class OpenSearchAsyncQueryScheduler implements AsyncQueryScheduler { private static final ObjectMapper mapper = new ObjectMapper(); private final FlintOptions flintOptions; + @VisibleForTesting + public OpenSearchAsyncQueryScheduler() { + this.flintOptions = new FlintOptions(ImmutableMap.of()); + } + public OpenSearchAsyncQueryScheduler(FlintOptions options) { this.flintOptions = options; } @@ -124,6 +131,28 @@ void createAsyncQuerySchedulerIndex(IRestHighLevelClient client) { } } + /** + * Checks if the current setup has access to the scheduler index. + * + * This method attempts to create a client and ensure that the scheduler index exists. + * If these operations succeed, it indicates that the user has the necessary permissions + * to access and potentially modify the scheduler index. + * + * @see #createClient() + * @see #ensureIndexExists(IRestHighLevelClient) + */ + public boolean hasAccessToSchedulerIndex() throws IOException { + IRestHighLevelClient client = createClient(); + try { + ensureIndexExists(client); + return true; + } catch (Throwable e) { + LOG.error("Failed to ensure index exists", e); + return false; + } finally { + client.close(); + } + } private void ensureIndexExists(IRestHighLevelClient client) { try { if (!client.doesIndexExist(new GetIndexRequest(SCHEDULER_INDEX_NAME), RequestOptions.DEFAULT)) { diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/util/IntervalSchedulerParser.java b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/util/IntervalSchedulerParser.java index 8745681b9..9622b4c64 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/util/IntervalSchedulerParser.java +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/util/IntervalSchedulerParser.java @@ -18,21 +18,30 @@ public class IntervalSchedulerParser { /** - * Parses a schedule string into an IntervalSchedule. + * Parses a schedule string and converts it to a duration in milliseconds. * * @param scheduleStr the schedule string to parse - * @return the parsed IntervalSchedule + * @return the parsed duration in milliseconds * @throws IllegalArgumentException if the schedule string is invalid */ - public static IntervalSchedule parse(String scheduleStr) { + public static Long parseAndConvertToMillis(String scheduleStr) { if (Strings.isNullOrEmpty(scheduleStr)) { throw new IllegalArgumentException("Schedule string must not be null or empty."); } - Long millis = Triggers.convert(scheduleStr); + return Triggers.convert(scheduleStr); + } + /** + * Parses a schedule string into an IntervalSchedule.
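A short usage sketch of the two parser entry points, with the millisecond-to-seconds conversion applied by the metadata cache earlier in this change; the input string is an example:

```scala
// parseAndConvertToMillis returns the interval length in milliseconds; parse still returns an IntervalSchedule.
val millis: Long = IntervalSchedulerParser.parseAndConvertToMillis("10 minutes") // 600000
val seconds: Int = (millis / 1000).toInt // 600, the value stored as refreshInterval in the cache
val schedule = IntervalSchedulerParser.parse("10 minutes") // IntervalSchedule of 10 MINUTES
```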
+ * + * @param scheduleStr the schedule string to parse + * @return the parsed IntervalSchedule + * @throws IllegalArgumentException if the schedule string is invalid + */ + public static IntervalSchedule parse(String scheduleStr) { // Convert milliseconds to minutes (rounding down) - int minutes = (int) (millis / (60 * 1000)); + int minutes = (int) (parseAndConvertToMillis(scheduleStr) / (60 * 1000)); // Use the current time as the start time Instant startTime = Instant.now(); diff --git a/flint-spark-integration/src/test/java/org/opensearch/flint/core/scheduler/AsyncQuerySchedulerBuilderTest.java b/flint-spark-integration/src/test/java/org/opensearch/flint/core/scheduler/AsyncQuerySchedulerBuilderTest.java index 67b5afee5..3c65a96a5 100644 --- a/flint-spark-integration/src/test/java/org/opensearch/flint/core/scheduler/AsyncQuerySchedulerBuilderTest.java +++ b/flint-spark-integration/src/test/java/org/opensearch/flint/core/scheduler/AsyncQuerySchedulerBuilderTest.java @@ -5,43 +5,80 @@ package org.opensearch.flint.core.scheduler; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.SQLContext; +import org.junit.Before; import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import org.opensearch.flint.common.scheduler.AsyncQueryScheduler; import org.opensearch.flint.common.scheduler.model.AsyncQuerySchedulerRequest; import org.opensearch.flint.core.FlintOptions; import org.opensearch.flint.spark.scheduler.AsyncQuerySchedulerBuilder; import org.opensearch.flint.spark.scheduler.OpenSearchAsyncQueryScheduler; +import java.io.IOException; + import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class AsyncQuerySchedulerBuilderTest { + @Mock + private SparkSession sparkSession; + + @Mock + private SQLContext sqlContext; + + private AsyncQuerySchedulerBuilderForLocalTest testBuilder; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + when(sparkSession.sqlContext()).thenReturn(sqlContext); + } + + @Test + public void testBuildWithEmptyClassNameAndAccessibleIndex() throws IOException { + FlintOptions options = mock(FlintOptions.class); + when(options.getCustomAsyncQuerySchedulerClass()).thenReturn(""); + OpenSearchAsyncQueryScheduler mockScheduler = mock(OpenSearchAsyncQueryScheduler.class); + + AsyncQueryScheduler scheduler = testBuilder.build(mockScheduler, true, sparkSession, options); + assertTrue(scheduler instanceof OpenSearchAsyncQueryScheduler); + verify(sqlContext, never()).setConf(anyString(), anyString()); + } @Test - public void testBuildWithEmptyClassName() { + public void testBuildWithEmptyClassNameAndInaccessibleIndex() throws IOException { FlintOptions options = mock(FlintOptions.class); when(options.getCustomAsyncQuerySchedulerClass()).thenReturn(""); + OpenSearchAsyncQueryScheduler mockScheduler = mock(OpenSearchAsyncQueryScheduler.class); - AsyncQueryScheduler scheduler = AsyncQuerySchedulerBuilder.build(options); + AsyncQueryScheduler scheduler = testBuilder.build(mockScheduler, false, sparkSession, options); assertTrue(scheduler instanceof OpenSearchAsyncQueryScheduler); + verify(sqlContext).setConf("spark.flint.job.externalScheduler.enabled", "false"); } @Test - public void testBuildWithCustomClassName() { + public void testBuildWithCustomClassName() throws IOException { FlintOptions 
options = mock(FlintOptions.class); - when(options.getCustomAsyncQuerySchedulerClass()).thenReturn("org.opensearch.flint.core.scheduler.AsyncQuerySchedulerBuilderTest$AsyncQuerySchedulerForLocalTest"); + when(options.getCustomAsyncQuerySchedulerClass()) + .thenReturn("org.opensearch.flint.core.scheduler.AsyncQuerySchedulerBuilderTest$AsyncQuerySchedulerForLocalTest"); - AsyncQueryScheduler scheduler = AsyncQuerySchedulerBuilder.build(options); + AsyncQueryScheduler scheduler = AsyncQuerySchedulerBuilder.build(sparkSession, options); assertTrue(scheduler instanceof AsyncQuerySchedulerForLocalTest); } @Test(expected = RuntimeException.class) - public void testBuildWithInvalidClassName() { + public void testBuildWithInvalidClassName() throws IOException { FlintOptions options = mock(FlintOptions.class); when(options.getCustomAsyncQuerySchedulerClass()).thenReturn("invalid.ClassName"); - AsyncQuerySchedulerBuilder.build(options); + AsyncQuerySchedulerBuilder.build(sparkSession, options); } public static class AsyncQuerySchedulerForLocalTest implements AsyncQueryScheduler { @@ -65,4 +102,35 @@ public void removeJob(AsyncQuerySchedulerRequest asyncQuerySchedulerRequest) { // Custom implementation } } + + public static class OpenSearchAsyncQuerySchedulerForLocalTest extends OpenSearchAsyncQueryScheduler { + @Override + public boolean hasAccessToSchedulerIndex() { + return true; + } + } + + public static class AsyncQuerySchedulerBuilderForLocalTest extends AsyncQuerySchedulerBuilder { + private OpenSearchAsyncQueryScheduler mockScheduler; + private Boolean mockHasAccess; + + public AsyncQuerySchedulerBuilderForLocalTest(OpenSearchAsyncQueryScheduler mockScheduler, Boolean mockHasAccess) { + this.mockScheduler = mockScheduler; + this.mockHasAccess = mockHasAccess; + } + + @Override + protected OpenSearchAsyncQueryScheduler createOpenSearchAsyncQueryScheduler(FlintOptions options) { + return mockScheduler != null ? mockScheduler : super.createOpenSearchAsyncQueryScheduler(options); + } + + @Override + protected boolean hasAccessToSchedulerIndex(OpenSearchAsyncQueryScheduler scheduler) throws IOException { + return mockHasAccess != null ? 
mockHasAccess : super.hasAccessToSchedulerIndex(scheduler); + } + + public static AsyncQueryScheduler build(OpenSearchAsyncQueryScheduler asyncQueryScheduler, Boolean hasAccess, SparkSession sparkSession, FlintOptions options) throws IOException { + return new AsyncQuerySchedulerBuilderForLocalTest(asyncQueryScheduler, hasAccess).doBuild(sparkSession, options); + } + } } \ No newline at end of file diff --git a/flint-spark-integration/src/test/java/org/opensearch/flint/core/scheduler/util/IntervalSchedulerParserTest.java b/flint-spark-integration/src/test/java/org/opensearch/flint/core/scheduler/util/IntervalSchedulerParserTest.java index 2ad1fea9c..731e1ae5c 100644 --- a/flint-spark-integration/src/test/java/org/opensearch/flint/core/scheduler/util/IntervalSchedulerParserTest.java +++ b/flint-spark-integration/src/test/java/org/opensearch/flint/core/scheduler/util/IntervalSchedulerParserTest.java @@ -23,53 +23,92 @@ public void testParseNull() { IntervalSchedulerParser.parse(null); } + @Test(expected = IllegalArgumentException.class) + public void testParseMillisNull() { + IntervalSchedulerParser.parseAndConvertToMillis(null); + } + @Test(expected = IllegalArgumentException.class) public void testParseEmptyString() { IntervalSchedulerParser.parse(""); } + @Test(expected = IllegalArgumentException.class) + public void testParseMillisEmptyString() { + IntervalSchedulerParser.parseAndConvertToMillis(""); + } + @Test public void testParseString() { - Schedule result = IntervalSchedulerParser.parse("10 minutes"); - assertTrue(result instanceof IntervalSchedule); - IntervalSchedule intervalSchedule = (IntervalSchedule) result; + Schedule schedule = IntervalSchedulerParser.parse("10 minutes"); + assertTrue(schedule instanceof IntervalSchedule); + IntervalSchedule intervalSchedule = (IntervalSchedule) schedule; assertEquals(10, intervalSchedule.getInterval()); assertEquals(ChronoUnit.MINUTES, intervalSchedule.getUnit()); } + @Test + public void testParseMillisString() { + Long millis = IntervalSchedulerParser.parseAndConvertToMillis("10 minutes"); + assertEquals(600000, millis.longValue()); + } + @Test(expected = IllegalArgumentException.class) public void testParseInvalidFormat() { IntervalSchedulerParser.parse("invalid format"); } + @Test(expected = IllegalArgumentException.class) + public void testParseMillisInvalidFormat() { + IntervalSchedulerParser.parseAndConvertToMillis("invalid format"); + } + @Test public void testParseStringScheduleMinutes() { - IntervalSchedule result = IntervalSchedulerParser.parse("5 minutes"); - assertEquals(5, result.getInterval()); - assertEquals(ChronoUnit.MINUTES, result.getUnit()); + IntervalSchedule schedule = IntervalSchedulerParser.parse("5 minutes"); + assertEquals(5, schedule.getInterval()); + assertEquals(ChronoUnit.MINUTES, schedule.getUnit()); + } + + @Test + public void testParseMillisStringScheduleMinutes() { + Long millis = IntervalSchedulerParser.parseAndConvertToMillis("5 minutes"); + assertEquals(300000, millis.longValue()); } @Test public void testParseStringScheduleHours() { - IntervalSchedule result = IntervalSchedulerParser.parse("2 hours"); - assertEquals(120, result.getInterval()); - assertEquals(ChronoUnit.MINUTES, result.getUnit()); + IntervalSchedule schedule = IntervalSchedulerParser.parse("2 hours"); + assertEquals(120, schedule.getInterval()); + assertEquals(ChronoUnit.MINUTES, schedule.getUnit()); + } + + @Test + public void testParseMillisStringScheduleHours() { + Long millis = IntervalSchedulerParser.parseAndConvertToMillis("2 
hours"); + assertEquals(7200000, millis.longValue()); } @Test public void testParseStringScheduleDays() { - IntervalSchedule result = IntervalSchedulerParser.parse("1 day"); - assertEquals(1440, result.getInterval()); - assertEquals(ChronoUnit.MINUTES, result.getUnit()); + IntervalSchedule schedule = IntervalSchedulerParser.parse("1 day"); + assertEquals(1440, schedule.getInterval()); + assertEquals(ChronoUnit.MINUTES, schedule.getUnit()); + } + + @Test + public void testParseMillisStringScheduleDays() { + Long millis = IntervalSchedulerParser.parseAndConvertToMillis("1 day"); + assertEquals(86400000, millis.longValue()); } @Test public void testParseStringScheduleStartTime() { Instant before = Instant.now(); - IntervalSchedule result = IntervalSchedulerParser.parse("30 minutes"); + IntervalSchedule schedule = IntervalSchedulerParser.parse("30 minutes"); Instant after = Instant.now(); - assertTrue(result.getStartTime().isAfter(before) || result.getStartTime().equals(before)); - assertTrue(result.getStartTime().isBefore(after) || result.getStartTime().equals(after)); + assertTrue(schedule.getStartTime().isAfter(before) || schedule.getStartTime().equals(before)); + assertTrue(schedule.getStartTime().isBefore(after) || schedule.getStartTime().equals(after)); } } \ No newline at end of file diff --git a/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala b/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala index ee8a52d96..b675265b7 100644 --- a/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala +++ b/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala @@ -9,8 +9,8 @@ import org.opensearch.flint.spark.FlintSparkExtensions import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation -import org.apache.spark.sql.flint.config.FlintConfigEntry -import org.apache.spark.sql.flint.config.FlintSparkConf.HYBRID_SCAN_ENABLED +import org.apache.spark.sql.flint.config.{FlintConfigEntry, FlintSparkConf} +import org.apache.spark.sql.flint.config.FlintSparkConf.{EXTERNAL_SCHEDULER_ENABLED, HYBRID_SCAN_ENABLED, METADATA_CACHE_WRITE} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -26,6 +26,10 @@ trait FlintSuite extends SharedSparkSession { // ConstantPropagation etc. 
.set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName) .set("spark.sql.extensions", classOf[FlintSparkExtensions].getName) + // Override scheduler class for unit testing + .set( + FlintSparkConf.CUSTOM_FLINT_SCHEDULER_CLASS.key, + "org.opensearch.flint.core.scheduler.AsyncQuerySchedulerBuilderTest$AsyncQuerySchedulerForLocalTest") conf } @@ -44,4 +48,22 @@ trait FlintSuite extends SharedSparkSession { setFlintSparkConf(HYBRID_SCAN_ENABLED, "false") } } + + protected def withExternalSchedulerEnabled(block: => Unit): Unit = { + setFlintSparkConf(EXTERNAL_SCHEDULER_ENABLED, "true") + try { + block + } finally { + setFlintSparkConf(EXTERNAL_SCHEDULER_ENABLED, "false") + } + } + + protected def withMetadataCacheWriteEnabled(block: => Unit): Unit = { + setFlintSparkConf(METADATA_CACHE_WRITE, "true") + try { + block + } finally { + setFlintSparkConf(METADATA_CACHE_WRITE, "false") + } + } } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexFactorySuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexFactorySuite.scala new file mode 100644 index 000000000..07720ff24 --- /dev/null +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexFactorySuite.scala @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark + +import org.opensearch.flint.core.storage.FlintOpenSearchIndexMetadataService +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE +import org.scalatest.matchers.should.Matchers._ + +import org.apache.spark.FlintSuite + +class FlintSparkIndexFactorySuite extends FlintSuite { + + test("create mv should generate source tables if missing in metadata") { + val testTable = "spark_catalog.default.mv_build_test" + val testMvName = "spark_catalog.default.mv" + val testQuery = s"SELECT * FROM $testTable" + + val content = + s""" { + | "_meta": { + | "kind": "$MV_INDEX_TYPE", + | "indexedColumns": [ + | { + | "columnType": "int", + | "columnName": "age" + | } + | ], + | "name": "$testMvName", + | "source": "$testQuery" + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin + + val metadata = FlintOpenSearchIndexMetadataService.deserialize(content) + val index = FlintSparkIndexFactory.create(spark, metadata) + index shouldBe defined + index.get + .asInstanceOf[FlintSparkMaterializedView] + .sourceTables should contain theSameElementsAs Array(testTable) + } +} diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndexSuite.scala index 2c5518778..96c71d94b 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndexSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndexSuite.scala @@ -246,6 +246,8 @@ class ApplyFlintSparkCoveringIndexSuite extends FlintSuite with Matchers { new FlintMetadataLogEntry( "id", 0, + 0, + 0, state, Map("seqNo" -> 0, "primaryTerm" -> 0), "", diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCacheSuite.scala 
b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCacheSuite.scala new file mode 100644 index 000000000..6ec6cf696 --- /dev/null +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCacheSuite.scala @@ -0,0 +1,150 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.metadatacache + +import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry +import org.opensearch.flint.core.storage.FlintOpenSearchIndexMetadataService +import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.COVERING_INDEX_TYPE +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE +import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.SKIPPING_INDEX_TYPE +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +class FlintMetadataCacheSuite extends AnyFlatSpec with Matchers { + val flintMetadataLogEntry = FlintMetadataLogEntry( + "id", + 0L, + 0L, + 1234567890123L, + FlintMetadataLogEntry.IndexState.ACTIVE, + Map.empty[String, Any], + "", + Map.empty[String, Any]) + + it should "construct from skipping index FlintMetadata" in { + val content = + s""" { + | "_meta": { + | "kind": "$SKIPPING_INDEX_TYPE", + | "source": "spark_catalog.default.test_table", + | "options": { + | "auto_refresh": "true", + | "refresh_interval": "10 Minutes" + | } + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin + val metadata = FlintOpenSearchIndexMetadataService + .deserialize(content) + .copy(latestLogEntry = Some(flintMetadataLogEntry)) + + val metadataCache = FlintMetadataCache(metadata) + metadataCache.metadataCacheVersion shouldBe FlintMetadataCache.metadataCacheVersion + metadataCache.refreshInterval.get shouldBe 600 + metadataCache.sourceTables shouldBe Array("spark_catalog.default.test_table") + metadataCache.lastRefreshTime.get shouldBe 1234567890123L + } + + it should "construct from covering index FlintMetadata" in { + val content = + s""" { + | "_meta": { + | "kind": "$COVERING_INDEX_TYPE", + | "source": "spark_catalog.default.test_table", + | "options": { + | "auto_refresh": "true", + | "refresh_interval": "10 Minutes" + | } + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin + val metadata = FlintOpenSearchIndexMetadataService + .deserialize(content) + .copy(latestLogEntry = Some(flintMetadataLogEntry)) + + val metadataCache = FlintMetadataCache(metadata) + metadataCache.metadataCacheVersion shouldBe FlintMetadataCache.metadataCacheVersion + metadataCache.refreshInterval.get shouldBe 600 + metadataCache.sourceTables shouldBe Array("spark_catalog.default.test_table") + metadataCache.lastRefreshTime.get shouldBe 1234567890123L + } + + it should "construct from materialized view FlintMetadata" in { + val content = + s""" { + | "_meta": { + | "kind": "$MV_INDEX_TYPE", + | "source": "spark_catalog.default.wrong_table", + | "options": { + | "auto_refresh": "true", + | "refresh_interval": "10 Minutes" + | }, + | "properties": { + | "sourceTables": [ + | "spark_catalog.default.test_table", + | "spark_catalog.default.another_table" + | ] + | } + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin + val metadata = FlintOpenSearchIndexMetadataService + .deserialize(content) + .copy(latestLogEntry = Some(flintMetadataLogEntry)) + + val metadataCache = 
FlintMetadataCache(metadata) + metadataCache.metadataCacheVersion shouldBe FlintMetadataCache.metadataCacheVersion + metadataCache.refreshInterval.get shouldBe 600 + metadataCache.sourceTables shouldBe Array( + "spark_catalog.default.test_table", + "spark_catalog.default.another_table") + metadataCache.lastRefreshTime.get shouldBe 1234567890123L + } + + it should "construct from FlintMetadata excluding invalid fields" in { + // Set auto_refresh = false and lastRefreshCompleteTime = 0 + val content = + s""" { + | "_meta": { + | "kind": "$SKIPPING_INDEX_TYPE", + | "source": "spark_catalog.default.test_table", + | "options": { + | "refresh_interval": "10 Minutes" + | } + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin + val metadata = FlintOpenSearchIndexMetadataService + .deserialize(content) + .copy(latestLogEntry = Some(flintMetadataLogEntry.copy(lastRefreshCompleteTime = 0L))) + + val metadataCache = FlintMetadataCache(metadata) + metadataCache.metadataCacheVersion shouldBe FlintMetadataCache.metadataCacheVersion + metadataCache.refreshInterval shouldBe empty + metadataCache.sourceTables shouldBe Array("spark_catalog.default.test_table") + metadataCache.lastRefreshTime shouldBe empty + } +} diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala index c1df42883..1c9a9e83c 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala @@ -10,7 +10,7 @@ import scala.collection.JavaConverters.{mapAsJavaMapConverter, mapAsScalaMapConv import org.opensearch.flint.spark.FlintSparkIndexOptions import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE import org.opensearch.flint.spark.mv.FlintSparkMaterializedViewSuite.{streamingRelation, StreamingDslLogicalPlan} -import org.scalatest.matchers.should.Matchers.{contain, convertToAnyShouldWrapper, the} +import org.scalatest.matchers.should.Matchers._ import org.scalatestplus.mockito.MockitoSugar.mock import org.apache.spark.FlintSuite @@ -37,31 +37,34 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { val testQuery = "SELECT 1" test("get mv name") { - val mv = FlintSparkMaterializedView(testMvName, testQuery, Map.empty) + val mv = FlintSparkMaterializedView(testMvName, testQuery, Array.empty, Map.empty) mv.name() shouldBe "flint_spark_catalog_default_mv" } test("get mv name with dots") { val testMvNameDots = "spark_catalog.default.mv.2023.10" - val mv = FlintSparkMaterializedView(testMvNameDots, testQuery, Map.empty) + val mv = FlintSparkMaterializedView(testMvNameDots, testQuery, Array.empty, Map.empty) mv.name() shouldBe "flint_spark_catalog_default_mv.2023.10" } test("should fail if get name with unqualified MV name") { the[IllegalArgumentException] thrownBy - FlintSparkMaterializedView("mv", testQuery, Map.empty).name() + FlintSparkMaterializedView("mv", testQuery, Array.empty, Map.empty).name() the[IllegalArgumentException] thrownBy - FlintSparkMaterializedView("default.mv", testQuery, Map.empty).name() + FlintSparkMaterializedView("default.mv", testQuery, Array.empty, Map.empty).name() } test("get metadata") { - val mv = FlintSparkMaterializedView(testMvName, testQuery, Map("test_col" -> "integer")) + val mv = + 
FlintSparkMaterializedView(testMvName, testQuery, Array.empty, Map("test_col" -> "integer")) val metadata = mv.metadata() metadata.name shouldBe mv.mvName metadata.kind shouldBe MV_INDEX_TYPE metadata.source shouldBe "SELECT 1" + metadata.properties should contain key "sourceTables" + metadata.properties.get("sourceTables").asInstanceOf[Array[String]] should have size 0 metadata.indexedColumns shouldBe Array( Map("columnName" -> "test_col", "columnType" -> "integer").asJava) metadata.schema shouldBe Map("test_col" -> Map("type" -> "integer").asJava).asJava @@ -74,6 +77,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { val mv = FlintSparkMaterializedView( testMvName, testQuery, + Array.empty, Map("test_col" -> "integer"), indexOptions) @@ -83,12 +87,12 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { } test("build batch data frame") { - val mv = FlintSparkMaterializedView(testMvName, testQuery, Map.empty) + val mv = FlintSparkMaterializedView(testMvName, testQuery, Array.empty, Map.empty) mv.build(spark, None).collect() shouldBe Array(Row(1)) } test("should fail if build given other source data frame") { - val mv = FlintSparkMaterializedView(testMvName, testQuery, Map.empty) + val mv = FlintSparkMaterializedView(testMvName, testQuery, Array.empty, Map.empty) the[IllegalArgumentException] thrownBy mv.build(spark, Some(mock[DataFrame])) } @@ -103,7 +107,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { |""".stripMargin val options = Map("watermark_delay" -> "30 Seconds") - withAggregateMaterializedView(testQuery, options) { actualPlan => + withAggregateMaterializedView(testQuery, Array(testTable), options) { actualPlan => comparePlans( actualPlan, streamingRelation(testTable) @@ -128,7 +132,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { |""".stripMargin val options = Map("watermark_delay" -> "30 Seconds") - withAggregateMaterializedView(testQuery, options) { actualPlan => + withAggregateMaterializedView(testQuery, Array(testTable), options) { actualPlan => comparePlans( actualPlan, streamingRelation(testTable) @@ -144,7 +148,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { test("build stream with non-aggregate query") { val testQuery = s"SELECT name, age FROM $testTable WHERE age > 30" - withAggregateMaterializedView(testQuery, Map.empty) { actualPlan => + withAggregateMaterializedView(testQuery, Array(testTable), Map.empty) { actualPlan => comparePlans( actualPlan, streamingRelation(testTable) @@ -158,7 +162,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { val testQuery = s"SELECT name, age FROM $testTable" val options = Map("extra_options" -> s"""{"$testTable": {"maxFilesPerTrigger": "1"}}""") - withAggregateMaterializedView(testQuery, options) { actualPlan => + withAggregateMaterializedView(testQuery, Array(testTable), options) { actualPlan => comparePlans( actualPlan, streamingRelation(testTable, Map("maxFilesPerTrigger" -> "1")) @@ -175,6 +179,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { val mv = FlintSparkMaterializedView( testMvName, s"SELECT name, COUNT(*) AS count FROM $testTable GROUP BY name", + Array(testTable), Map.empty) the[IllegalStateException] thrownBy @@ -182,14 +187,20 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { } } - private def withAggregateMaterializedView(query: String, options: Map[String, String])( - codeBlock: LogicalPlan => Unit): Unit = { + private def withAggregateMaterializedView( + query: String, + sourceTables: Array[String], + options: 
Map[String, String])(codeBlock: LogicalPlan => Unit): Unit = { withTable(testTable) { sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") - val mv = - FlintSparkMaterializedView(testMvName, query, Map.empty, FlintSparkIndexOptions(options)) + FlintSparkMaterializedView( + testMvName, + query, + sourceTables, + Map.empty, + FlintSparkIndexOptions(options)) val actualPlan = mv.buildStream(spark).queryExecution.logical codeBlock(actualPlan) diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndexSuite.scala index f03116de9..d56c4e66f 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndexSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndexSuite.scala @@ -132,6 +132,8 @@ class ApplyFlintSparkSkippingIndexSuite extends FlintSuite with Matchers { new FlintMetadataLogEntry( "id", 0L, + 0L, + 0L, indexState, Map.empty[String, Any], "", diff --git a/integ-test/src/integration/resources/tpch/q1.ppl b/integ-test/src/integration/resources/tpch/q1.ppl new file mode 100644 index 000000000..885ce35c6 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q1.ppl @@ -0,0 +1,35 @@ +/* +select + l_returnflag, + l_linestatus, + sum(l_quantity) as sum_qty, + sum(l_extendedprice) as sum_base_price, + sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, + sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, + avg(l_quantity) as avg_qty, + avg(l_extendedprice) as avg_price, + avg(l_discount) as avg_disc, + count(*) as count_order +from + lineitem +where + l_shipdate <= date '1998-12-01' - interval '90' day +group by + l_returnflag, + l_linestatus +order by + l_returnflag, + l_linestatus +*/ + +source = lineitem +| where l_shipdate <= subdate(date('1998-12-01'), 90) +| stats sum(l_quantity) as sum_qty, + sum(l_extendedprice) as sum_base_price, + sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, + sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, + avg(l_quantity) as avg_qty, avg(l_extendedprice) as avg_price, + avg(l_discount) as avg_disc, + count() as count_order + by l_returnflag, l_linestatus +| sort l_returnflag, l_linestatus \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q10.ppl b/integ-test/src/integration/resources/tpch/q10.ppl new file mode 100644 index 000000000..10a050785 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q10.ppl @@ -0,0 +1,45 @@ +/* +select + c_custkey, + c_name, + sum(l_extendedprice * (1 - l_discount)) as revenue, + c_acctbal, + n_name, + c_address, + c_phone, + c_comment +from + customer, + orders, + lineitem, + nation +where + c_custkey = o_custkey + and l_orderkey = o_orderkey + and o_orderdate >= date '1993-10-01' + and o_orderdate < date '1993-10-01' + interval '3' month + and l_returnflag = 'R' + and c_nationkey = n_nationkey +group by + c_custkey, + c_name, + c_acctbal, + c_phone, + n_name, + c_address, + c_comment +order by + revenue desc +limit 20 +*/ + +source = customer +| join ON c_custkey = o_custkey orders +| join ON l_orderkey = o_orderkey lineitem +| join ON c_nationkey = n_nationkey nation +| where o_orderdate >= date('1993-10-01') + AND o_orderdate < date_add(date('1993-10-01'), interval 3 month) + AND l_returnflag = 'R' +| stats 
sum(l_extendedprice * (1 - l_discount)) as revenue by c_custkey, c_name, c_acctbal, c_phone, n_name, c_address, c_comment +| sort - revenue +| head 20 \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q11.ppl b/integ-test/src/integration/resources/tpch/q11.ppl new file mode 100644 index 000000000..3a55d986e --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q11.ppl @@ -0,0 +1,45 @@ +/* +select + ps_partkey, + sum(ps_supplycost * ps_availqty) as value +from + partsupp, + supplier, + nation +where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'GERMANY' +group by + ps_partkey having + sum(ps_supplycost * ps_availqty) > ( + select + sum(ps_supplycost * ps_availqty) * 0.0001000000 + from + partsupp, + supplier, + nation + where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'GERMANY' + ) +order by + value desc +*/ + +source = partsupp +| join ON ps_suppkey = s_suppkey supplier +| join ON s_nationkey = n_nationkey nation +| where n_name = 'GERMANY' +| stats sum(ps_supplycost * ps_availqty) as value by ps_partkey +| where value > [ + source = partsupp + | join ON ps_suppkey = s_suppkey supplier + | join ON s_nationkey = n_nationkey nation + | where n_name = 'GERMANY' + | stats sum(ps_supplycost * ps_availqty) as check + | eval threshold = check * 0.0001000000 + | fields threshold + ] +| sort - value \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q12.ppl b/integ-test/src/integration/resources/tpch/q12.ppl new file mode 100644 index 000000000..79672d844 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q12.ppl @@ -0,0 +1,42 @@ +/* +select + l_shipmode, + sum(case + when o_orderpriority = '1-URGENT' + or o_orderpriority = '2-HIGH' + then 1 + else 0 + end) as high_line_count, + sum(case + when o_orderpriority <> '1-URGENT' + and o_orderpriority <> '2-HIGH' + then 1 + else 0 + end) as low_line_count +from + orders, + lineitem +where + o_orderkey = l_orderkey + and l_shipmode in ('MAIL', 'SHIP') + and l_commitdate < l_receiptdate + and l_shipdate < l_commitdate + and l_receiptdate >= date '1994-01-01' + and l_receiptdate < date '1994-01-01' + interval '1' year +group by + l_shipmode +order by + l_shipmode +*/ + +source = orders +| join ON o_orderkey = l_orderkey lineitem +| where l_commitdate < l_receiptdate + and l_shipdate < l_commitdate + and l_shipmode in ('MAIL', 'SHIP') + and l_receiptdate >= date('1994-01-01') + and l_receiptdate < date_add(date('1994-01-01'), interval 1 year) +| stats sum(case(o_orderpriority = '1-URGENT' or o_orderpriority = '2-HIGH', 1 else 0)) as high_line_count, + sum(case(o_orderpriority != '1-URGENT' and o_orderpriority != '2-HIGH', 1 else 0)) as low_line_count + by l_shipmode +| sort l_shipmode \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q13.ppl b/integ-test/src/integration/resources/tpch/q13.ppl new file mode 100644 index 000000000..6e77c9b0a --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q13.ppl @@ -0,0 +1,31 @@ +/* +select + c_count, + count(*) as custdist +from + ( + select + c_custkey, + count(o_orderkey) as c_count + from + customer left outer join orders on + c_custkey = o_custkey + and o_comment not like '%special%requests%' + group by + c_custkey + ) as c_orders +group by + c_count +order by + custdist desc, + c_count desc +*/ + +source = [ + source = customer + | left outer join ON c_custkey = o_custkey AND not like(o_comment, '%special%requests%') + orders + |
stats count(o_orderkey) as c_count by c_custkey + ] as c_orders +| stats count() as custdist by c_count +| sort - custdist, - c_count \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q14.ppl b/integ-test/src/integration/resources/tpch/q14.ppl new file mode 100644 index 000000000..553f1e549 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q14.ppl @@ -0,0 +1,25 @@ +/* +select + 100.00 * sum(case + when p_type like 'PROMO%' + then l_extendedprice * (1 - l_discount) + else 0 + end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue +from + lineitem, + part +where + l_partkey = p_partkey + and l_shipdate >= date '1995-09-01' + and l_shipdate < date '1995-09-01' + interval '1' month +*/ + +source = lineitem +| join ON l_partkey = p_partkey + AND l_shipdate >= date('1995-09-01') + AND l_shipdate < date_add(date('1995-09-01'), interval 1 month) + part +| stats sum(case(like(p_type, 'PROMO%'), l_extendedprice * (1 - l_discount) else 0)) as sum1, + sum(l_extendedprice * (1 - l_discount)) as sum2 +| eval promo_revenue = 100.00 * sum1 / sum2 // Stats and Eval commands can combine when issues/819 resolved +| fields promo_revenue \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q15.ppl b/integ-test/src/integration/resources/tpch/q15.ppl new file mode 100644 index 000000000..96f5ecea2 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q15.ppl @@ -0,0 +1,52 @@ +/* +with revenue0 as + (select + l_suppkey as supplier_no, + sum(l_extendedprice * (1 - l_discount)) as total_revenue + from + lineitem + where + l_shipdate >= date '1996-01-01' + and l_shipdate < date '1996-01-01' + interval '3' month + group by + l_suppkey) +select + s_suppkey, + s_name, + s_address, + s_phone, + total_revenue +from + supplier, + revenue0 +where + s_suppkey = supplier_no + and total_revenue = ( + select + max(total_revenue) + from + revenue0 + ) +order by + s_suppkey +*/ + +// CTE is unsupported in PPL +source = supplier +| join right = revenue0 ON s_suppkey = supplier_no [ + source = lineitem + | where l_shipdate >= date('1996-01-01') AND l_shipdate < date_add(date('1996-01-01'), interval 3 month) + | eval supplier_no = l_suppkey + | stats sum(l_extendedprice * (1 - l_discount)) as total_revenue by supplier_no + ] +| where total_revenue = [ + source = [ + source = lineitem + | where l_shipdate >= date('1996-01-01') AND l_shipdate < date_add(date('1996-01-01'), interval 3 month) + | eval supplier_no = l_suppkey + | stats sum(l_extendedprice * (1 - l_discount)) as total_revenue by supplier_no + ] + | stats max(total_revenue) + ] +| sort s_suppkey +| fields s_suppkey, s_name, s_address, s_phone, total_revenue diff --git a/integ-test/src/integration/resources/tpch/q16.ppl b/integ-test/src/integration/resources/tpch/q16.ppl new file mode 100644 index 000000000..4c5765f04 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q16.ppl @@ -0,0 +1,45 @@ +/* +select + p_brand, + p_type, + p_size, + count(distinct ps_suppkey) as supplier_cnt +from + partsupp, + part +where + p_partkey = ps_partkey + and p_brand <> 'Brand#45' + and p_type not like 'MEDIUM POLISHED%' + and p_size in (49, 14, 23, 45, 19, 3, 36, 9) + and ps_suppkey not in ( + select + s_suppkey + from + supplier + where + s_comment like '%Customer%Complaints%' + ) +group by + p_brand, + p_type, + p_size +order by + supplier_cnt desc, + p_brand, + p_type, + p_size +*/ + +source = partsupp +| join ON p_partkey = ps_partkey part +| where p_brand != 'Brand#45' + and not 
like(p_type, 'MEDIUM POLISHED%') + and p_size in (49, 14, 23, 45, 19, 3, 36, 9) + and ps_suppkey not in [ + source = supplier + | where like(s_comment, '%Customer%Complaints%') + | fields s_suppkey + ] +| stats distinct_count(ps_suppkey) as supplier_cnt by p_brand, p_type, p_size +| sort - supplier_cnt, p_brand, p_type, p_size \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q17.ppl b/integ-test/src/integration/resources/tpch/q17.ppl new file mode 100644 index 000000000..994b7ee18 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q17.ppl @@ -0,0 +1,34 @@ +/* +select + sum(l_extendedprice) / 7.0 as avg_yearly +from + lineitem, + part +where + p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container = 'MED BOX' + and l_quantity < ( + select + 0.2 * avg(l_quantity) + from + lineitem + where + l_partkey = p_partkey + ) +*/ + +source = lineitem +| join ON p_partkey = l_partkey part +| where p_brand = 'Brand#23' + and p_container = 'MED BOX' + and l_quantity < [ + source = lineitem + | where l_partkey = p_partkey + | stats avg(l_quantity) as avg + | eval `0.2 * avg` = 0.2 * avg // Stats and Eval commands can combine when issues/819 resolved + | fields `0.2 * avg` + ] +| stats sum(l_extendedprice) as sum +| eval avg_yearly = sum / 7.0 // Stats and Eval commands can combine when issues/819 resolved +| fields avg_yearly \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q18.ppl b/integ-test/src/integration/resources/tpch/q18.ppl new file mode 100644 index 000000000..1dab3d473 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q18.ppl @@ -0,0 +1,48 @@ +/* +select + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice, + sum(l_quantity) +from + customer, + orders, + lineitem +where + o_orderkey in ( + select + l_orderkey + from + lineitem + group by + l_orderkey having + sum(l_quantity) > 300 + ) + and c_custkey = o_custkey + and o_orderkey = l_orderkey +group by + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice +order by + o_totalprice desc, + o_orderdate +limit 100 +*/ + +source = customer +| join ON c_custkey = o_custkey orders +| join ON o_orderkey = l_orderkey lineitem +| where o_orderkey in [ + source = lineitem + | stats sum(l_quantity) as sum by l_orderkey + | where sum > 300 + | fields l_orderkey + ] +| stats sum(l_quantity) by c_name, c_custkey, o_orderkey, o_orderdate, o_totalprice +| sort - o_totalprice, o_orderdate +| head 100 \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q19.ppl b/integ-test/src/integration/resources/tpch/q19.ppl new file mode 100644 index 000000000..630d63bcc --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q19.ppl @@ -0,0 +1,61 @@ +/* +select + sum(l_extendedprice* (1 - l_discount)) as revenue +from + lineitem, + part +where + ( + p_partkey = l_partkey + and p_brand = 'Brand#12' + and p_container in ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') + and l_quantity >= 1 and l_quantity <= 1 + 10 + and p_size between 1 and 5 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) + or + ( + p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container in ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') + and l_quantity >= 10 and l_quantity <= 10 + 10 + and p_size between 1 and 10 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) + or + ( + p_partkey = l_partkey + and p_brand = 'Brand#34' + and p_container in ('LG CASE', 'LG BOX', 'LG 
PACK', 'LG PKG') + and l_quantity >= 20 and l_quantity <= 20 + 10 + and p_size between 1 and 15 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) +*/ + +source = lineitem +| join ON p_partkey = l_partkey + and p_brand = 'Brand#12' + and p_container in ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') + and l_quantity >= 1 and l_quantity <= 1 + 10 + and p_size between 1 and 5 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + OR p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container in ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') + and l_quantity >= 10 and l_quantity <= 10 + 10 + and p_size between 1 and 10 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + OR p_partkey = l_partkey + and p_brand = 'Brand#34' + and p_container in ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') + and l_quantity >= 20 and l_quantity <= 20 + 10 + and p_size between 1 and 15 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + part \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q2.ppl b/integ-test/src/integration/resources/tpch/q2.ppl new file mode 100644 index 000000000..aa95d9d14 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q2.ppl @@ -0,0 +1,62 @@ +/* +select + s_acctbal, + s_name, + n_name, + p_partkey, + p_mfgr, + s_address, + s_phone, + s_comment +from + part, + supplier, + partsupp, + nation, + region +where + p_partkey = ps_partkey + and s_suppkey = ps_suppkey + and p_size = 15 + and p_type like '%BRASS' + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'EUROPE' + and ps_supplycost = ( + select + min(ps_supplycost) + from + partsupp, + supplier, + nation, + region + where + p_partkey = ps_partkey + and s_suppkey = ps_suppkey + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'EUROPE' + ) +order by + s_acctbal desc, + n_name, + s_name, + p_partkey +limit 100 +*/ + +source = part +| join ON p_partkey = ps_partkey partsupp +| join ON s_suppkey = ps_suppkey supplier +| join ON s_nationkey = n_nationkey nation +| join ON n_regionkey = r_regionkey region +| where p_size = 15 AND like(p_type, '%BRASS') AND r_name = 'EUROPE' AND ps_supplycost = [ + source = partsupp + | join ON s_suppkey = ps_suppkey supplier + | join ON s_nationkey = n_nationkey nation + | join ON n_regionkey = r_regionkey region + | where r_name = 'EUROPE' + | stats MIN(ps_supplycost) + ] +| sort - s_acctbal, n_name, s_name, p_partkey +| head 100 \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q20.ppl b/integ-test/src/integration/resources/tpch/q20.ppl new file mode 100644 index 000000000..08bd21277 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q20.ppl @@ -0,0 +1,62 @@ +/* +select + s_name, + s_address +from + supplier, + nation +where + s_suppkey in ( + select + ps_suppkey + from + partsupp + where + ps_partkey in ( + select + p_partkey + from + part + where + p_name like 'forest%' + ) + and ps_availqty > ( + select + 0.5 * sum(l_quantity) + from + lineitem + where + l_partkey = ps_partkey + and l_suppkey = ps_suppkey + and l_shipdate >= date '1994-01-01' + and l_shipdate < date '1994-01-01' + interval '1' year + ) + ) + and s_nationkey = n_nationkey + and n_name = 'CANADA' +order by + s_name +*/ + +source = supplier +| join ON s_nationkey = n_nationkey nation +| where n_name = 'CANADA' + and s_suppkey in [ + source = partsupp + | where ps_partkey in [ + 
source = part + | where like(p_name, 'forest%') + | fields p_partkey + ] + and ps_availqty > [ + source = lineitem + | where l_partkey = ps_partkey + and l_suppkey = ps_suppkey + and l_shipdate >= date('1994-01-01') + and l_shipdate < date_add(date('1994-01-01'), interval 1 year) + | stats sum(l_quantity) as sum_l_quantity + | eval half_sum_l_quantity = 0.5 * sum_l_quantity // Stats and Eval commands can combine when issues/819 resolved + | fields half_sum_l_quantity + ] + | fields ps_suppkey + ] \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q21.ppl b/integ-test/src/integration/resources/tpch/q21.ppl new file mode 100644 index 000000000..0eb7149f6 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q21.ppl @@ -0,0 +1,64 @@ +/* +select + s_name, + count(*) as numwait +from + supplier, + lineitem l1, + orders, + nation +where + s_suppkey = l1.l_suppkey + and o_orderkey = l1.l_orderkey + and o_orderstatus = 'F' + and l1.l_receiptdate > l1.l_commitdate + and exists ( + select + * + from + lineitem l2 + where + l2.l_orderkey = l1.l_orderkey + and l2.l_suppkey <> l1.l_suppkey + ) + and not exists ( + select + * + from + lineitem l3 + where + l3.l_orderkey = l1.l_orderkey + and l3.l_suppkey <> l1.l_suppkey + and l3.l_receiptdate > l3.l_commitdate + ) + and s_nationkey = n_nationkey + and n_name = 'SAUDI ARABIA' +group by + s_name +order by + numwait desc, + s_name +limit 100 +*/ + +source = supplier +| join ON s_suppkey = l1.l_suppkey lineitem as l1 +| join ON o_orderkey = l1.l_orderkey orders +| join ON s_nationkey = n_nationkey nation +| where o_orderstatus = 'F' + and l1.l_receiptdate > l1.l_commitdate + and exists [ + source = lineitem as l2 + | where l2.l_orderkey = l1.l_orderkey + and l2.l_suppkey != l1.l_suppkey + ] + and not exists [ + source = lineitem as l3 + | where l3.l_orderkey = l1.l_orderkey + and l3.l_suppkey != l1.l_suppkey + and l3.l_receiptdate > l3.l_commitdate + ] + and n_name = 'SAUDI ARABIA' +| stats count() as numwait by s_name +| sort - numwait, s_name +| head 100 \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q22.ppl b/integ-test/src/integration/resources/tpch/q22.ppl new file mode 100644 index 000000000..811308cb0 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q22.ppl @@ -0,0 +1,58 @@ +/* +select + cntrycode, + count(*) as numcust, + sum(c_acctbal) as totacctbal +from + ( + select + substring(c_phone, 1, 2) as cntrycode, + c_acctbal + from + customer + where + substring(c_phone, 1, 2) in + ('13', '31', '23', '29', '30', '18', '17') + and c_acctbal > ( + select + avg(c_acctbal) + from + customer + where + c_acctbal > 0.00 + and substring(c_phone, 1, 2) in + ('13', '31', '23', '29', '30', '18', '17') + ) + and not exists ( + select + * + from + orders + where + o_custkey = c_custkey + ) + ) as custsale +group by + cntrycode +order by + cntrycode +*/ + +source = [ + source = customer + | where substring(c_phone, 1, 2) in ('13', '31', '23', '29', '30', '18', '17') + and c_acctbal > [ + source = customer + | where c_acctbal > 0.00 + and substring(c_phone, 1, 2) in ('13', '31', '23', '29', '30', '18', '17') + | stats avg(c_acctbal) + ] + and not exists [ + source = orders + | where o_custkey = c_custkey + ] + | eval cntrycode = substring(c_phone, 1, 2) + | fields cntrycode, c_acctbal + ] as custsale +| stats count() as numcust, sum(c_acctbal) as totacctbal by cntrycode +| sort cntrycode \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q3.ppl 
b/integ-test/src/integration/resources/tpch/q3.ppl new file mode 100644 index 000000000..0ece358ab --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q3.ppl @@ -0,0 +1,33 @@ +/* +select + l_orderkey, + sum(l_extendedprice * (1 - l_discount)) as revenue, + o_orderdate, + o_shippriority +from + customer, + orders, + lineitem +where + c_mktsegment = 'BUILDING' + and c_custkey = o_custkey + and l_orderkey = o_orderkey + and o_orderdate < date '1995-03-15' + and l_shipdate > date '1995-03-15' +group by + l_orderkey, + o_orderdate, + o_shippriority +order by + revenue desc, + o_orderdate +limit 10 +*/ + +source = customer +| join ON c_custkey = o_custkey orders +| join ON l_orderkey = o_orderkey lineitem +| where c_mktsegment = 'BUILDING' AND o_orderdate < date('1995-03-15') AND l_shipdate > date('1995-03-15') +| stats sum(l_extendedprice * (1 - l_discount)) as revenue by l_orderkey, o_orderdate, o_shippriority +| sort - revenue, o_orderdate +| head 10 \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q4.ppl b/integ-test/src/integration/resources/tpch/q4.ppl new file mode 100644 index 000000000..cc01bda7d --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q4.ppl @@ -0,0 +1,33 @@ +/* +select + o_orderpriority, + count(*) as order_count +from + orders +where + o_orderdate >= date '1993-07-01' + and o_orderdate < date '1993-07-01' + interval '3' month + and exists ( + select + * + from + lineitem + where + l_orderkey = o_orderkey + and l_commitdate < l_receiptdate + ) +group by + o_orderpriority +order by + o_orderpriority +*/ + +source = orders +| where o_orderdate >= date('1993-07-01') + and o_orderdate < date_add(date('1993-07-01'), interval 3 month) + and exists [ + source = lineitem + | where l_orderkey = o_orderkey and l_commitdate < l_receiptdate + ] +| stats count() as order_count by o_orderpriority +| sort o_orderpriority \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q5.ppl b/integ-test/src/integration/resources/tpch/q5.ppl new file mode 100644 index 000000000..4761b0365 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q5.ppl @@ -0,0 +1,36 @@ +/* +select + n_name, + sum(l_extendedprice * (1 - l_discount)) as revenue +from + customer, + orders, + lineitem, + supplier, + nation, + region +where + c_custkey = o_custkey + and l_orderkey = o_orderkey + and l_suppkey = s_suppkey + and c_nationkey = s_nationkey + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'ASIA' + and o_orderdate >= date '1994-01-01' + and o_orderdate < date '1994-01-01' + interval '1' year +group by + n_name +order by + revenue desc +*/ + +source = customer +| join ON c_custkey = o_custkey orders +| join ON l_orderkey = o_orderkey lineitem +| join ON l_suppkey = s_suppkey AND c_nationkey = s_nationkey supplier +| join ON s_nationkey = n_nationkey nation +| join ON n_regionkey = r_regionkey region +| where r_name = 'ASIA' AND o_orderdate >= date('1994-01-01') AND o_orderdate < date_add(date('1994-01-01'), interval 1 year) +| stats sum(l_extendedprice * (1 - l_discount)) as revenue by n_name +| sort - revenue \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q6.ppl b/integ-test/src/integration/resources/tpch/q6.ppl new file mode 100644 index 000000000..6a77877c3 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q6.ppl @@ -0,0 +1,18 @@ +/* +select + sum(l_extendedprice * l_discount) as revenue +from + lineitem +where + l_shipdate >= date '1994-01-01' 
+ and l_shipdate < date '1994-01-01' + interval '1' year + and l_discount between .06 - 0.01 and .06 + 0.01 + and l_quantity < 24 +*/ + +source = lineitem +| where l_shipdate >= date('1994-01-01') + and l_shipdate < adddate(date('1994-01-01'), 365) + and l_discount between .06 - 0.01 and .06 + 0.01 + and l_quantity < 24 +| stats sum(l_extendedprice * l_discount) as revenue \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q7.ppl b/integ-test/src/integration/resources/tpch/q7.ppl new file mode 100644 index 000000000..ceda602b3 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q7.ppl @@ -0,0 +1,56 @@ +/* +select + supp_nation, + cust_nation, + l_year, + sum(volume) as revenue +from + ( + select + n1.n_name as supp_nation, + n2.n_name as cust_nation, + year(l_shipdate) as l_year, + l_extendedprice * (1 - l_discount) as volume + from + supplier, + lineitem, + orders, + customer, + nation n1, + nation n2 + where + s_suppkey = l_suppkey + and o_orderkey = l_orderkey + and c_custkey = o_custkey + and s_nationkey = n1.n_nationkey + and c_nationkey = n2.n_nationkey + and ( + (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY') + or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE') + ) + and l_shipdate between date '1995-01-01' and date '1996-12-31' + ) as shipping +group by + supp_nation, + cust_nation, + l_year +order by + supp_nation, + cust_nation, + l_year +*/ + +source = [ + source = supplier + | join ON s_suppkey = l_suppkey lineitem + | join ON o_orderkey = l_orderkey orders + | join ON c_custkey = o_custkey customer + | join ON s_nationkey = n1.n_nationkey nation as n1 + | join ON c_nationkey = n2.n_nationkey nation as n2 + | where l_shipdate between date('1995-01-01') and date('1996-12-31') + and n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY' or n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE' + | eval supp_nation = n1.n_name, cust_nation = n2.n_name, l_year = year(l_shipdate), volume = l_extendedprice * (1 - l_discount) + | fields supp_nation, cust_nation, l_year, volume + ] as shipping +| stats sum(volume) as revenue by supp_nation, cust_nation, l_year +| sort supp_nation, cust_nation, l_year \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q8.ppl b/integ-test/src/integration/resources/tpch/q8.ppl new file mode 100644 index 000000000..a73c7f7c3 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q8.ppl @@ -0,0 +1,60 @@ +/* +select + o_year, + sum(case + when nation = 'BRAZIL' then volume + else 0 + end) / sum(volume) as mkt_share +from + ( + select + year(o_orderdate) as o_year, + l_extendedprice * (1 - l_discount) as volume, + n2.n_name as nation + from + part, + supplier, + lineitem, + orders, + customer, + nation n1, + nation n2, + region + where + p_partkey = l_partkey + and s_suppkey = l_suppkey + and l_orderkey = o_orderkey + and o_custkey = c_custkey + and c_nationkey = n1.n_nationkey + and n1.n_regionkey = r_regionkey + and r_name = 'AMERICA' + and s_nationkey = n2.n_nationkey + and o_orderdate between date '1995-01-01' and date '1996-12-31' + and p_type = 'ECONOMY ANODIZED STEEL' + ) as all_nations +group by + o_year +order by + o_year +*/ + +source = [ + source = part + | join ON p_partkey = l_partkey lineitem + | join ON s_suppkey = l_suppkey supplier + | join ON l_orderkey = o_orderkey orders + | join ON o_custkey = c_custkey customer + | join ON c_nationkey = n1.n_nationkey nation as n1 + | join ON s_nationkey = n2.n_nationkey nation as n2 + | join ON n1.n_regionkey = r_regionkey region + | 
where r_name = 'AMERICA' AND p_type = 'ECONOMY ANODIZED STEEL' + and o_orderdate between date('1995-01-01') and date('1996-12-31') + | eval o_year = year(o_orderdate) + | eval volume = l_extendedprice * (1 - l_discount) + | eval nation = n2.n_name + | fields o_year, volume, nation + ] as all_nations +| stats sum(case(nation = 'BRAZIL', volume else 0)) as sum_case, sum(volume) as sum_volume by o_year +| eval mkt_share = sum_case / sum_volume +| fields mkt_share, o_year +| sort o_year \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q9.ppl b/integ-test/src/integration/resources/tpch/q9.ppl new file mode 100644 index 000000000..7692afd74 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q9.ppl @@ -0,0 +1,50 @@ +/* +select + nation, + o_year, + sum(amount) as sum_profit +from + ( + select + n_name as nation, + year(o_orderdate) as o_year, + l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity as amount + from + part, + supplier, + lineitem, + partsupp, + orders, + nation + where + s_suppkey = l_suppkey + and ps_suppkey = l_suppkey + and ps_partkey = l_partkey + and p_partkey = l_partkey + and o_orderkey = l_orderkey + and s_nationkey = n_nationkey + and p_name like '%green%' + ) as profit +group by + nation, + o_year +order by + nation, + o_year desc +*/ + +source = [ + source = part + | join ON p_partkey = l_partkey lineitem + | join ON s_suppkey = l_suppkey supplier + | join ON ps_partkey = l_partkey and ps_suppkey = l_suppkey partsupp + | join ON o_orderkey = l_orderkey orders + | join ON s_nationkey = n_nationkey nation + | where like(p_name, '%green%') + | eval nation = n_name + | eval o_year = year(o_orderdate) + | eval amount = l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity + | fields nation, o_year, amount + ] as profit +| stats sum(amount) as sum_profit by nation, o_year +| sort nation, - o_year \ No newline at end of file diff --git a/integ-test/src/integration/scala/org/opensearch/flint/core/FlintMetadataLogITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/core/FlintMetadataLogITSuite.scala index 9aeba7512..33702f23f 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/core/FlintMetadataLogITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/core/FlintMetadataLogITSuite.scala @@ -25,10 +25,12 @@ class FlintMetadataLogITSuite extends OpenSearchTransactionSuite with Matchers { val testFlintIndex = "flint_test_index" val testLatestId: String = Base64.getEncoder.encodeToString(testFlintIndex.getBytes) - val testCreateTime = 1234567890123L + val testTimestamp = 1234567890123L val flintMetadataLogEntry = FlintMetadataLogEntry( testLatestId, - testCreateTime, + testTimestamp, + testTimestamp, + testTimestamp, ACTIVE, Map("seqNo" -> UNASSIGNED_SEQ_NO, "primaryTerm" -> UNASSIGNED_PRIMARY_TERM), "", @@ -85,8 +87,10 @@ class FlintMetadataLogITSuite extends OpenSearchTransactionSuite with Matchers { val latest = metadataLog.get.getLatest latest.isPresent shouldBe true latest.get.id shouldBe testLatestId - latest.get.createTime shouldBe testCreateTime + latest.get.createTime shouldBe testTimestamp latest.get.error shouldBe "" + latest.get.lastRefreshStartTime shouldBe testTimestamp + latest.get.lastRefreshCompleteTime shouldBe testTimestamp latest.get.properties.get("dataSourceName").get shouldBe testDataSourceName } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/core/FlintTransactionITSuite.scala 
b/integ-test/src/integration/scala/org/opensearch/flint/core/FlintTransactionITSuite.scala index 605e8e7fd..df5f4eec2 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/core/FlintTransactionITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/core/FlintTransactionITSuite.scala @@ -24,6 +24,7 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { val testFlintIndex = "flint_test_index" val testLatestId: String = Base64.getEncoder.encodeToString(testFlintIndex.getBytes) + val testTimestamp = 1234567890123L var flintMetadataLogService: FlintMetadataLogService = _ override def beforeAll(): Unit = { @@ -40,6 +41,8 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { latest.id shouldBe testLatestId latest.state shouldBe EMPTY latest.createTime shouldBe 0L + latest.lastRefreshStartTime shouldBe 0L + latest.lastRefreshCompleteTime shouldBe 0L latest.error shouldBe "" latest.properties.get("dataSourceName").get shouldBe testDataSourceName true @@ -49,11 +52,12 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { } test("should preserve original values when transition") { - val testCreateTime = 1234567890123L createLatestLogEntry( FlintMetadataLogEntry( id = testLatestId, - createTime = testCreateTime, + createTime = testTimestamp, + lastRefreshStartTime = testTimestamp, + lastRefreshCompleteTime = testTimestamp, state = ACTIVE, Map("seqNo" -> UNASSIGNED_SEQ_NO, "primaryTerm" -> UNASSIGNED_PRIMARY_TERM), error = "", @@ -63,8 +67,10 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { .startTransaction(testFlintIndex) .initialLog(latest => { latest.id shouldBe testLatestId - latest.createTime shouldBe testCreateTime + latest.createTime shouldBe testTimestamp latest.error shouldBe "" + latest.lastRefreshStartTime shouldBe testTimestamp + latest.lastRefreshCompleteTime shouldBe testTimestamp latest.properties.get("dataSourceName").get shouldBe testDataSourceName true }) @@ -72,8 +78,10 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { .finalLog(latest => latest.copy(state = DELETED)) .commit(latest => { latest.id shouldBe testLatestId - latest.createTime shouldBe testCreateTime + latest.createTime shouldBe testTimestamp latest.error shouldBe "" + latest.lastRefreshStartTime shouldBe testTimestamp + latest.lastRefreshCompleteTime shouldBe testTimestamp latest.properties.get("dataSourceName").get shouldBe testDataSourceName }) @@ -112,7 +120,9 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { createLatestLogEntry( FlintMetadataLogEntry( id = testLatestId, - createTime = 1234567890123L, + createTime = testTimestamp, + lastRefreshStartTime = testTimestamp, + lastRefreshCompleteTime = testTimestamp, state = ACTIVE, Map("seqNo" -> UNASSIGNED_SEQ_NO, "primaryTerm" -> UNASSIGNED_PRIMARY_TERM), error = "", @@ -198,7 +208,9 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { createLatestLogEntry( FlintMetadataLogEntry( id = testLatestId, - createTime = 1234567890123L, + createTime = testTimestamp, + lastRefreshStartTime = testTimestamp, + lastRefreshCompleteTime = testTimestamp, state = ACTIVE, Map("seqNo" -> UNASSIGNED_SEQ_NO, "primaryTerm" -> UNASSIGNED_PRIMARY_TERM), error = "", @@ -240,6 +252,8 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { latest.id shouldBe testLatestId latest.state shouldBe EMPTY latest.createTime 
shouldBe 0L + latest.lastRefreshStartTime shouldBe 0L + latest.lastRefreshCompleteTime shouldBe 0L latest.error shouldBe "" latest.properties.get("dataSourceName").get shouldBe testDataSourceName true diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala index 14d41c2bb..fc77faaea 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala @@ -17,9 +17,10 @@ import org.opensearch.flint.common.FlintVersion.current import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.storage.{FlintOpenSearchIndexMetadataService, OpenSearchClientUtils} import org.opensearch.flint.spark.FlintSparkIndex.quotedTableName -import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.getFlintIndexName +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.{extractSourceTableNames, getFlintIndexName} import org.opensearch.flint.spark.scheduler.OpenSearchAsyncQueryScheduler -import org.scalatest.matchers.must.Matchers.defined +import org.scalatest.matchers.must.Matchers._ import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper import org.apache.spark.sql.{DataFrame, Row} @@ -51,6 +52,29 @@ class FlintSparkMaterializedViewITSuite extends FlintSparkSuite { deleteTestIndex(testFlintIndex) } + test("extract source table names from materialized view source query successfully") { + val testComplexQuery = s""" + | SELECT * + | FROM ( + | SELECT 1 + | FROM table1 + | LEFT JOIN `table2` + | ) + | UNION ALL + | SELECT 1 + | FROM spark_catalog.default.`table/3` + | INNER JOIN spark_catalog.default.`table.4` + |""".stripMargin + extractSourceTableNames(flint.spark, testComplexQuery) should contain theSameElementsAs + Array( + "spark_catalog.default.table1", + "spark_catalog.default.table2", + "spark_catalog.default.`table/3`", + "spark_catalog.default.`table.4`") + + extractSourceTableNames(flint.spark, "SELECT 1") should have size 0 + } + test("create materialized view with metadata successfully") { withTempDir { checkpointDir => val indexOptions = @@ -91,7 +115,9 @@ class FlintSparkMaterializedViewITSuite extends FlintSparkSuite { | "scheduler_mode":"internal" | }, | "latestId": "$testLatestId", - | "properties": {} + | "properties": { + | "sourceTables": ["$testTable"] + | } | }, | "properties": { | "startTime": { @@ -107,6 +133,22 @@ class FlintSparkMaterializedViewITSuite extends FlintSparkSuite { } } + test("create materialized view should parse source tables successfully") { + val indexOptions = FlintSparkIndexOptions(Map.empty) + flint + .materializedView() + .name(testMvName) + .query(testQuery) + .options(indexOptions, testFlintIndex) + .create() + + val index = flint.describeIndex(testFlintIndex) + index shouldBe defined + index.get + .asInstanceOf[FlintSparkMaterializedView] + .sourceTables should contain theSameElementsAs Array(testTable) + } + test("create materialized view with default checkpoint location successfully") { withTempDir { checkpointDir => setFlintSparkConf( diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala index 23a336b4c..68d370791 100644 
--- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala @@ -5,7 +5,7 @@ package org.opensearch.flint.spark -import java.nio.file.{Files, Paths} +import java.nio.file.{Files, Path, Paths} import java.util.Comparator import java.util.concurrent.{ScheduledExecutorService, ScheduledFuture} @@ -23,6 +23,7 @@ import org.scalatestplus.mockito.MockitoSugar.mock import org.apache.spark.{FlintSuite, SparkConf} import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.flint.config.FlintSparkConf.{CHECKPOINT_MANDATORY, HOST_ENDPOINT, HOST_PORT, REFRESH_POLICY} import org.apache.spark.sql.streaming.StreamTest @@ -49,6 +50,8 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit override def beforeAll(): Unit = { super.beforeAll() + // Revoke override in FlintSuite on IT + conf.unsetConf(FlintSparkConf.CUSTOM_FLINT_SCHEDULER_CLASS.key) // Replace executor to avoid impact on IT. // TODO: Currently no IT test scheduler so no need to restore it back. @@ -534,6 +537,50 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit |""".stripMargin) } + protected def createMultiValueStructTable(testTable: String): Unit = { + // CSV doesn't support struct field + sql(s""" + | CREATE TABLE $testTable + | ( + | int_col INT, + | multi_value Array> + | ) + | USING JSON + |""".stripMargin) + + sql(s""" + | INSERT INTO $testTable + | SELECT /*+ COALESCE(1) */ * + | FROM VALUES + | ( 1, array(STRUCT("1_one", 1), STRUCT(null, 11), STRUCT("1_three", null)) ), + | ( 2, array(STRUCT("2_Monday", 2), null) ), + | ( 3, array(STRUCT("3_third", 3), STRUCT("3_4th", 4)) ), + | ( 4, null ) + |""".stripMargin) + } + + protected def createMultiColumnArrayTable(testTable: String): Unit = { + // CSV doesn't support struct field + sql(s""" + | CREATE TABLE $testTable + | ( + | int_col INT, + | multi_valueA Array>, + | multi_valueB Array> + | ) + | USING JSON + |""".stripMargin) + + sql(s""" + | INSERT INTO $testTable + | VALUES + | ( 1, array(STRUCT("1_one", 1), STRUCT(null, 11), STRUCT("1_three", null)), array(STRUCT("2_Monday", 2), null) ), + | ( 2, array(STRUCT("2_Monday", 2), null) , array(STRUCT("3_third", 3), STRUCT("3_4th", 4)) ), + | ( 3, array(STRUCT("3_third", 3), STRUCT("3_4th", 4)) , array(STRUCT("1_one", 1))), + | ( 4, null, array(STRUCT("1_one", 1))) + |""".stripMargin) + } + protected def createTableIssue112(testTable: String): Unit = { sql(s""" | CREATE TABLE $testTable ( @@ -669,4 +716,126 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit | (11, null, false) | """.stripMargin) } + + protected def createIpAddressTable(testTable: String): Unit = { + sql(s""" + | CREATE TABLE $testTable + | ( + | id INT, + | ipAddress STRING, + | isV6 BOOLEAN, + | isValid BOOLEAN + | ) + | USING $tableType $tableOptions + |""".stripMargin) + + sql(s""" + | INSERT INTO $testTable + | VALUES (1, '127.0.0.1', false, true), + | (2, '192.168.1.0', false, true), + | (3, '192.168.1.1', false, true), + | (4, '192.168.2.1', false, true), + | (5, '192.168.2.', false, false), + | (6, '2001:db8::ff00:12:3455', true, true), + | (7, '2001:db8::ff00:12:3456', true, true), + | (8, '2001:db8::ff00:13:3457', true, true), + | (9, '2001:db8::ff00:12:', true, false) + | """.stripMargin) + } + + protected def createNestedJsonContentTable(tempFile: Path, testTable: String): Unit = { + val json 
= + """ + |[ + | { + | "_time": "2024-09-13T12:00:00", + | "bridges": [ + | {"name": "Tower Bridge", "length": 801}, + | {"name": "London Bridge", "length": 928} + | ], + | "city": "London", + | "country": "England", + | "coor": { + | "lat": 51.5074, + | "long": -0.1278, + | "alt": 35 + | } + | }, + | { + | "_time": "2024-09-13T12:00:00", + | "bridges": [ + | {"name": "Pont Neuf", "length": 232}, + | {"name": "Pont Alexandre III", "length": 160} + | ], + | "city": "Paris", + | "country": "France", + | "coor": { + | "lat": 48.8566, + | "long": 2.3522, + | "alt": 35 + | } + | }, + | { + | "_time": "2024-09-13T12:00:00", + | "bridges": [ + | {"name": "Rialto Bridge", "length": 48}, + | {"name": "Bridge of Sighs", "length": 11} + | ], + | "city": "Venice", + | "country": "Italy", + | "coor": { + | "lat": 45.4408, + | "long": 12.3155, + | "alt": 2 + | } + | }, + | { + | "_time": "2024-09-13T12:00:00", + | "bridges": [ + | {"name": "Charles Bridge", "length": 516}, + | {"name": "Legion Bridge", "length": 343} + | ], + | "city": "Prague", + | "country": "Czech Republic", + | "coor": { + | "lat": 50.0755, + | "long": 14.4378, + | "alt": 200 + | } + | }, + | { + | "_time": "2024-09-13T12:00:00", + | "bridges": [ + | {"name": "Chain Bridge", "length": 375}, + | {"name": "Liberty Bridge", "length": 333} + | ], + | "city": "Budapest", + | "country": "Hungary", + | "coor": { + | "lat": 47.4979, + | "long": 19.0402, + | "alt": 96 + | } + | }, + | { + | "_time": "1990-09-13T12:00:00", + | "bridges": null, + | "city": "Warsaw", + | "country": "Poland", + | "coor": null + | } + |] + |""".stripMargin + val tempFile = Files.createTempFile("jsonTestData", ".json") + val absolutPath = tempFile.toAbsolutePath.toString; + Files.write(tempFile, json.getBytes) + sql(s""" + | CREATE TEMPORARY VIEW $testTable + | USING org.apache.spark.sql.json + | OPTIONS ( + | path "$absolutPath", + | multiLine true + | ); + |""".stripMargin) + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkTransactionITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkTransactionITSuite.scala index debb95370..e16d40f2a 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkTransactionITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkTransactionITSuite.scala @@ -55,6 +55,8 @@ class FlintSparkTransactionITSuite extends OpenSearchTransactionSuite with Match latestLogEntry(testLatestId) should (contain("latestId" -> testLatestId) and contain("state" -> "active") and contain("jobStartTime" -> 0) + and contain("lastRefreshStartTime" -> 0) + and contain("lastRefreshCompleteTime" -> 0) and contain("dataSourceName" -> testDataSourceName)) implicit val formats: Formats = Serialization.formats(NoTypeHints) @@ -77,9 +79,25 @@ class FlintSparkTransactionITSuite extends OpenSearchTransactionSuite with Match .create() flint.refreshIndex(testFlintIndex) - val latest = latestLogEntry(testLatestId) + var latest = latestLogEntry(testLatestId) + val prevJobStartTime = latest("jobStartTime").asInstanceOf[Number].longValue() + val prevLastRefreshStartTime = latest("lastRefreshStartTime").asInstanceOf[Number].longValue() + val prevLastRefreshCompleteTime = + latest("lastRefreshCompleteTime").asInstanceOf[Number].longValue() latest should contain("state" -> "active") - latest("jobStartTime").asInstanceOf[Number].longValue() should be > 0L + prevJobStartTime should be > 0L + prevLastRefreshStartTime should be > 0L + 
prevLastRefreshCompleteTime should be > prevLastRefreshStartTime + + flint.refreshIndex(testFlintIndex) + latest = latestLogEntry(testLatestId) + val jobStartTime = latest("jobStartTime").asInstanceOf[Number].longValue() + val lastRefreshStartTime = latest("lastRefreshStartTime").asInstanceOf[Number].longValue() + val lastRefreshCompleteTime = + latest("lastRefreshCompleteTime").asInstanceOf[Number].longValue() + jobStartTime should be > prevLastRefreshCompleteTime + lastRefreshStartTime should be > prevLastRefreshCompleteTime + lastRefreshCompleteTime should be > lastRefreshStartTime } test("incremental refresh index") { @@ -97,9 +115,26 @@ class FlintSparkTransactionITSuite extends OpenSearchTransactionSuite with Match .create() flint.refreshIndex(testFlintIndex) - val latest = latestLogEntry(testLatestId) + var latest = latestLogEntry(testLatestId) + val prevJobStartTime = latest("jobStartTime").asInstanceOf[Number].longValue() + val prevLastRefreshStartTime = + latest("lastRefreshStartTime").asInstanceOf[Number].longValue() + val prevLastRefreshCompleteTime = + latest("lastRefreshCompleteTime").asInstanceOf[Number].longValue() latest should contain("state" -> "active") - latest("jobStartTime").asInstanceOf[Number].longValue() should be > 0L + prevJobStartTime should be > 0L + prevLastRefreshStartTime should be > 0L + prevLastRefreshCompleteTime should be > prevLastRefreshStartTime + + flint.refreshIndex(testFlintIndex) + latest = latestLogEntry(testLatestId) + val jobStartTime = latest("jobStartTime").asInstanceOf[Number].longValue() + val lastRefreshStartTime = latest("lastRefreshStartTime").asInstanceOf[Number].longValue() + val lastRefreshCompleteTime = + latest("lastRefreshCompleteTime").asInstanceOf[Number].longValue() + jobStartTime should be > prevLastRefreshCompleteTime + lastRefreshStartTime should be > prevLastRefreshCompleteTime + lastRefreshCompleteTime should be > lastRefreshStartTime } } @@ -142,6 +177,8 @@ class FlintSparkTransactionITSuite extends OpenSearchTransactionSuite with Match val latest = latestLogEntry(testLatestId) latest should contain("state" -> "refreshing") latest("jobStartTime").asInstanceOf[Number].longValue() should be > 0L + latest("lastRefreshStartTime").asInstanceOf[Number].longValue() shouldBe 0L + latest("lastRefreshCompleteTime").asInstanceOf[Number].longValue() shouldBe 0L } test("update auto refresh index to full refresh index") { @@ -153,13 +190,24 @@ class FlintSparkTransactionITSuite extends OpenSearchTransactionSuite with Match .create() flint.refreshIndex(testFlintIndex) + var latest = latestLogEntry(testLatestId) + val prevLastRefreshStartTime = latest("lastRefreshStartTime").asInstanceOf[Number].longValue() + val prevLastRefreshCompleteTime = + latest("lastRefreshCompleteTime").asInstanceOf[Number].longValue() + val index = flint.describeIndex(testFlintIndex).get val updatedIndex = flint .skippingIndex() .copyWithUpdate(index, FlintSparkIndexOptions(Map("auto_refresh" -> "false"))) flint.updateIndex(updatedIndex) - val latest = latestLogEntry(testLatestId) + latest = latestLogEntry(testLatestId) latest should contain("state" -> "active") + latest("lastRefreshStartTime") + .asInstanceOf[Number] + .longValue() shouldBe prevLastRefreshStartTime + latest("lastRefreshCompleteTime") + .asInstanceOf[Number] + .longValue() shouldBe prevLastRefreshCompleteTime } test("delete and vacuum index") { diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkUpdateIndexITSuite.scala 
b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkUpdateIndexITSuite.scala index a6f7e0ed0..c9f6c47f7 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkUpdateIndexITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkUpdateIndexITSuite.scala @@ -9,7 +9,6 @@ import scala.jdk.CollectionConverters.mapAsJavaMapConverter import com.stephenn.scalatest.jsonassert.JsonMatchers.matchJson import org.json4s.native.JsonMethods._ -import org.opensearch.OpenSearchException import org.opensearch.action.get.GetRequest import org.opensearch.client.RequestOptions import org.opensearch.flint.core.FlintOptions @@ -207,13 +206,7 @@ class FlintSparkUpdateIndexITSuite extends FlintSparkSuite { val indexInitial = flint.describeIndex(testIndex).get indexInitial.options.refreshInterval() shouldBe Some("4 Minute") - the[OpenSearchException] thrownBy { - val client = - OpenSearchClientUtils.createClient(new FlintOptions(openSearchOptions.asJava)) - client.get( - new GetRequest(OpenSearchAsyncQueryScheduler.SCHEDULER_INDEX_NAME, testIndex), - RequestOptions.DEFAULT) - } + indexInitial.options.isExternalSchedulerEnabled() shouldBe false // Update Flint index to change refresh interval val updatedIndex = flint @@ -228,6 +221,7 @@ class FlintSparkUpdateIndexITSuite extends FlintSparkSuite { val indexFinal = flint.describeIndex(testIndex).get indexFinal.options.autoRefresh() shouldBe true indexFinal.options.refreshInterval() shouldBe Some("5 Minutes") + indexFinal.options.isExternalSchedulerEnabled() shouldBe true indexFinal.options.checkpointLocation() shouldBe Some(checkpointDir.getAbsolutePath) // Verify scheduler index is updated diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriterITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriterITSuite.scala new file mode 100644 index 000000000..c0d253fd3 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriterITSuite.scala @@ -0,0 +1,457 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.metadatacache + +import java.util.{Base64, List} + +import scala.collection.JavaConverters._ + +import com.stephenn.scalatest.jsonassert.JsonMatchers.matchJson +import org.json4s.native.JsonMethods._ +import org.opensearch.flint.common.FlintVersion.current +import org.opensearch.flint.common.metadata.FlintMetadata +import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry +import org.opensearch.flint.core.FlintOptions +import org.opensearch.flint.core.storage.{FlintOpenSearchClient, FlintOpenSearchIndexMetadataService} +import org.opensearch.flint.spark.{FlintSparkIndexOptions, FlintSparkSuite} +import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.COVERING_INDEX_TYPE +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE +import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, SKIPPING_INDEX_TYPE} +import org.scalatest.Entry +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.sql.flint.config.FlintSparkConf + +class FlintOpenSearchMetadataCacheWriterITSuite extends FlintSparkSuite with Matchers { + + /** Lazy initialize after container started. 
*/ + lazy val options = new FlintOptions(openSearchOptions.asJava) + lazy val flintClient = new FlintOpenSearchClient(options) + lazy val flintMetadataCacheWriter = new FlintOpenSearchMetadataCacheWriter(options) + lazy val flintIndexMetadataService = new FlintOpenSearchIndexMetadataService(options) + + private val testTable = "spark_catalog.default.metadatacache_test" + private val testFlintIndex = getSkippingIndexName(testTable) + private val testLatestId: String = Base64.getEncoder.encodeToString(testFlintIndex.getBytes) + private val testLastRefreshCompleteTime = 1234567890123L + private val flintMetadataLogEntry = FlintMetadataLogEntry( + testLatestId, + 0L, + 0L, + testLastRefreshCompleteTime, + FlintMetadataLogEntry.IndexState.ACTIVE, + Map.empty[String, Any], + "", + Map.empty[String, Any]) + + override def beforeAll(): Unit = { + super.beforeAll() + createPartitionedMultiRowAddressTable(testTable) + } + + override def afterAll(): Unit = { + sql(s"DROP TABLE $testTable") + super.afterAll() + } + + override def afterEach(): Unit = { + deleteTestIndex(testFlintIndex) + super.afterEach() + } + + test("build disabled metadata cache writer") { + FlintMetadataCacheWriterBuilder + .build(FlintSparkConf()) shouldBe a[FlintDisabledMetadataCacheWriter] + } + + test("build opensearch metadata cache writer") { + setFlintSparkConf(FlintSparkConf.METADATA_CACHE_WRITE, "true") + withMetadataCacheWriteEnabled { + FlintMetadataCacheWriterBuilder + .build(FlintSparkConf()) shouldBe a[FlintOpenSearchMetadataCacheWriter] + } + } + + test("serialize metadata cache to JSON") { + val expectedMetadataJson: String = s""" + | { + | "_meta": { + | "version": "${current()}", + | "name": "$testFlintIndex", + | "kind": "test_kind", + | "source": "$testTable", + | "indexedColumns": [ + | { + | "test_field": "spark_type" + | }], + | "options": { + | "auto_refresh": "true", + | "refresh_interval": "10 Minutes" + | }, + | "properties": { + | "metadataCacheVersion": "${FlintMetadataCache.metadataCacheVersion}", + | "refreshInterval": 600, + | "sourceTables": ["$testTable"], + | "lastRefreshTime": $testLastRefreshCompleteTime + | }, + | "latestId": "$testLatestId" + | }, + | "properties": { + | "test_field": { + | "type": "os_type" + | } + | } + | } + |""".stripMargin + val builder = new FlintMetadata.Builder + builder.name(testFlintIndex) + builder.kind("test_kind") + builder.source(testTable) + builder.addIndexedColumn(Map[String, AnyRef]("test_field" -> "spark_type").asJava) + builder.options( + Map("auto_refresh" -> "true", "refresh_interval" -> "10 Minutes") + .mapValues(_.asInstanceOf[AnyRef]) + .asJava) + builder.schema(Map[String, AnyRef]("test_field" -> Map("type" -> "os_type").asJava).asJava) + builder.latestLogEntry(flintMetadataLogEntry) + + val metadata = builder.build() + flintMetadataCacheWriter.serialize(metadata) should matchJson(expectedMetadataJson) + } + + test("write metadata cache to index mappings") { + val metadata = FlintOpenSearchIndexMetadataService + .deserialize("{}") + .copy(latestLogEntry = Some(flintMetadataLogEntry)) + flintClient.createIndex(testFlintIndex, metadata) + flintMetadataCacheWriter.updateMetadataCache(testFlintIndex, metadata) + + val properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties + properties should have size 3 + properties should contain allOf (Entry( + "metadataCacheVersion", + FlintMetadataCache.metadataCacheVersion), + Entry("lastRefreshTime", testLastRefreshCompleteTime)) + } + + Seq(SKIPPING_INDEX_TYPE, 
COVERING_INDEX_TYPE).foreach { case kind => + test(s"write metadata cache to $kind index mappings with source tables") { + val content = + s""" { + | "_meta": { + | "kind": "$kind", + | "source": "$testTable" + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin + val metadata = FlintOpenSearchIndexMetadataService + .deserialize(content) + .copy(latestLogEntry = Some(flintMetadataLogEntry)) + flintClient.createIndex(testFlintIndex, metadata) + flintMetadataCacheWriter.updateMetadataCache(testFlintIndex, metadata) + + val properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties + properties + .get("sourceTables") + .asInstanceOf[List[String]] + .toArray should contain theSameElementsAs Array(testTable) + } + } + + test(s"write metadata cache to materialized view index mappings with source tables") { + val testTable2 = "spark_catalog.default.metadatacache_test2" + val content = + s""" { + | "_meta": { + | "kind": "$MV_INDEX_TYPE", + | "properties": { + | "sourceTables": [ + | "$testTable", "$testTable2" + | ] + | } + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin + val metadata = FlintOpenSearchIndexMetadataService + .deserialize(content) + .copy(latestLogEntry = Some(flintMetadataLogEntry)) + flintClient.createIndex(testFlintIndex, metadata) + flintMetadataCacheWriter.updateMetadataCache(testFlintIndex, metadata) + + val properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties + properties + .get("sourceTables") + .asInstanceOf[List[String]] + .toArray should contain theSameElementsAs Array(testTable, testTable2) + } + + test("write metadata cache to index mappings with refresh interval") { + val content = + """ { + | "_meta": { + | "kind": "test_kind", + | "options": { + | "auto_refresh": "true", + | "refresh_interval": "10 Minutes" + | } + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin + val metadata = FlintOpenSearchIndexMetadataService + .deserialize(content) + .copy(latestLogEntry = Some(flintMetadataLogEntry)) + flintClient.createIndex(testFlintIndex, metadata) + flintMetadataCacheWriter.updateMetadataCache(testFlintIndex, metadata) + + val properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties + properties should have size 4 + properties should contain allOf (Entry( + "metadataCacheVersion", + FlintMetadataCache.metadataCacheVersion), + Entry("refreshInterval", 600), + Entry("lastRefreshTime", testLastRefreshCompleteTime)) + } + + test("exclude refresh interval in metadata cache when auto refresh is false") { + val content = + """ { + | "_meta": { + | "kind": "test_kind", + | "options": { + | "refresh_interval": "10 Minutes" + | } + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin + val metadata = FlintOpenSearchIndexMetadataService + .deserialize(content) + .copy(latestLogEntry = Some(flintMetadataLogEntry)) + flintClient.createIndex(testFlintIndex, metadata) + flintMetadataCacheWriter.updateMetadataCache(testFlintIndex, metadata) + + val properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties + properties should have size 3 + properties should contain allOf (Entry( + "metadataCacheVersion", + FlintMetadataCache.metadataCacheVersion), + Entry("lastRefreshTime", testLastRefreshCompleteTime)) + } + + test("exclude last refresh time in metadata cache when index has not been refreshed") { + 
val metadata = FlintOpenSearchIndexMetadataService + .deserialize("{}") + .copy(latestLogEntry = Some(flintMetadataLogEntry.copy(lastRefreshCompleteTime = 0L))) + flintClient.createIndex(testFlintIndex, metadata) + flintMetadataCacheWriter.updateMetadataCache(testFlintIndex, metadata) + + val properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties + properties should have size 2 + properties should contain( + Entry("metadataCacheVersion", FlintMetadataCache.metadataCacheVersion)) + } + + test("write metadata cache to index mappings and preserve other index metadata") { + val content = + """ { + | "_meta": { + | "kind": "test_kind" + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin + + val metadata = FlintOpenSearchIndexMetadataService + .deserialize(content) + .copy(latestLogEntry = Some(flintMetadataLogEntry)) + flintClient.createIndex(testFlintIndex, metadata) + + flintIndexMetadataService.updateIndexMetadata(testFlintIndex, metadata) + flintMetadataCacheWriter.updateMetadataCache(testFlintIndex, metadata) + + flintIndexMetadataService.getIndexMetadata(testFlintIndex).kind shouldBe "test_kind" + flintIndexMetadataService.getIndexMetadata(testFlintIndex).name shouldBe empty + flintIndexMetadataService.getIndexMetadata(testFlintIndex).schema should have size 1 + var properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties + properties should have size 3 + properties should contain allOf (Entry( + "metadataCacheVersion", + FlintMetadataCache.metadataCacheVersion), + Entry("lastRefreshTime", testLastRefreshCompleteTime)) + + val newContent = + """ { + | "_meta": { + | "kind": "test_kind", + | "name": "test_name" + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin + + val newMetadata = FlintOpenSearchIndexMetadataService + .deserialize(newContent) + .copy(latestLogEntry = Some(flintMetadataLogEntry)) + flintIndexMetadataService.updateIndexMetadata(testFlintIndex, newMetadata) + flintMetadataCacheWriter.updateMetadataCache(testFlintIndex, newMetadata) + + flintIndexMetadataService.getIndexMetadata(testFlintIndex).kind shouldBe "test_kind" + flintIndexMetadataService.getIndexMetadata(testFlintIndex).name shouldBe "test_name" + flintIndexMetadataService.getIndexMetadata(testFlintIndex).schema should have size 1 + properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties + properties should have size 3 + properties should contain allOf (Entry( + "metadataCacheVersion", + FlintMetadataCache.metadataCacheVersion), + Entry("lastRefreshTime", testLastRefreshCompleteTime)) + } + + Seq( + ( + "auto refresh index with external scheduler", + Map( + "auto_refresh" -> "true", + "scheduler_mode" -> "external", + "refresh_interval" -> "10 Minute", + "checkpoint_location" -> "s3a://test/"), + s""" + | { + | "metadataCacheVersion": "${FlintMetadataCache.metadataCacheVersion}", + | "refreshInterval": 600, + | "sourceTables": ["$testTable"] + | } + |""".stripMargin), + ( + "full refresh index", + Map.empty[String, String], + s""" + | { + | "metadataCacheVersion": "${FlintMetadataCache.metadataCacheVersion}", + | "sourceTables": ["$testTable"] + | } + |""".stripMargin), + ( + "incremental refresh index", + Map("incremental_refresh" -> "true", "checkpoint_location" -> "s3a://test/"), + s""" + | { + | "metadataCacheVersion": "${FlintMetadataCache.metadataCacheVersion}", + | "sourceTables": ["$testTable"] + | } + |""".stripMargin)).foreach { case 
(refreshMode, optionsMap, expectedJson) => + test(s"write metadata cache for $refreshMode") { + withExternalSchedulerEnabled { + withMetadataCacheWriteEnabled { + withTempDir { checkpointDir => + // update checkpoint_location if available in optionsMap + val indexOptions = FlintSparkIndexOptions( + optionsMap + .get("checkpoint_location") + .map(_ => + optionsMap.updated("checkpoint_location", checkpointDir.getAbsolutePath)) + .getOrElse(optionsMap)) + + flint + .skippingIndex() + .onTable(testTable) + .addMinMax("age") + .options(indexOptions, testFlintIndex) + .create() + + var index = flint.describeIndex(testFlintIndex) + index shouldBe defined + val propertiesJson = + compact( + render( + parse( + flintMetadataCacheWriter.serialize( + index.get.metadata())) \ "_meta" \ "properties")) + propertiesJson should matchJson(expectedJson) + + flint.refreshIndex(testFlintIndex) + index = flint.describeIndex(testFlintIndex) + index shouldBe defined + val lastRefreshTime = + compact( + render( + parse( + flintMetadataCacheWriter.serialize( + index.get.metadata())) \ "_meta" \ "properties" \ "lastRefreshTime")).toLong + lastRefreshTime should be > 0L + } + } + } + } + } + + test("write metadata cache for auto refresh index with internal scheduler") { + withMetadataCacheWriteEnabled { + withTempDir { checkpointDir => + flint + .skippingIndex() + .onTable(testTable) + .addMinMax("age") + .options( + FlintSparkIndexOptions( + Map( + "auto_refresh" -> "true", + "scheduler_mode" -> "internal", + "refresh_interval" -> "10 Minute", + "checkpoint_location" -> checkpointDir.getAbsolutePath)), + testFlintIndex) + .create() + + var index = flint.describeIndex(testFlintIndex) + index shouldBe defined + val propertiesJson = + compact( + render(parse( + flintMetadataCacheWriter.serialize(index.get.metadata())) \ "_meta" \ "properties")) + propertiesJson should matchJson(s""" + | { + | "metadataCacheVersion": "${FlintMetadataCache.metadataCacheVersion}", + | "refreshInterval": 600, + | "sourceTables": ["$testTable"] + | } + |""".stripMargin) + + flint.refreshIndex(testFlintIndex) + index = flint.describeIndex(testFlintIndex) + index shouldBe defined + compact(render(parse( + flintMetadataCacheWriter.serialize( + index.get.metadata())) \ "_meta" \ "properties")) should not include "lastRefreshTime" + } + } + } +} diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala index cbc4308b0..3bd98edf1 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala @@ -597,4 +597,68 @@ class FlintSparkPPLBasicITSuite | """.stripMargin)) assert(ex.getMessage().contains("Invalid table name")) } + + test("Search multiple tables - translated into union call with fields") { + val frame = sql(s""" + | source = $t1, $t2 + | """.stripMargin) + assertSameRows( + Seq( + Row("Hello", 30, "New York", "USA", 2023, 4), + Row("Hello", 30, "New York", "USA", 2023, 4), + Row("Jake", 70, "California", "USA", 2023, 4), + Row("Jake", 70, "California", "USA", 2023, 4), + Row("Jane", 20, "Quebec", "Canada", 2023, 4), + Row("Jane", 20, "Quebec", "Canada", 2023, 4), + Row("John", 25, "Ontario", "Canada", 2023, 4), + Row("John", 25, "Ontario", "Canada", 2023, 4)), + frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table1 = 
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + + val allFields1 = UnresolvedStar(None) + val allFields2 = UnresolvedStar(None) + + val projectedTable1 = Project(Seq(allFields1), table1) + val projectedTable2 = Project(Seq(allFields2), table2) + + val expectedPlan = + Union(Seq(projectedTable1, projectedTable2), byName = true, allowMissingCol = true) + + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("Search multiple tables - with table alias") { + val frame = sql(s""" + | source = $t1, $t2 as t | where t.country = "USA" + | """.stripMargin) + assertSameRows( + Seq( + Row("Hello", 30, "New York", "USA", 2023, 4), + Row("Hello", 30, "New York", "USA", 2023, 4), + Row("Jake", 70, "California", "USA", 2023, 4), + Row("Jake", 70, "California", "USA", 2023, 4)), + frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + + val plan1 = Filter( + EqualTo(UnresolvedAttribute("t.country"), Literal("USA")), + SubqueryAlias("t", table1)) + val plan2 = Filter( + EqualTo(UnresolvedAttribute("t.country"), Literal("USA")), + SubqueryAlias("t", table2)) + + val projectedTable1 = Project(Seq(UnresolvedStar(None)), plan1) + val projectedTable2 = Project(Seq(UnresolvedStar(None)), plan2) + + val expectedPlan = + Union(Seq(projectedTable1, projectedTable2), byName = true, allowMissingCol = true) + + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBuiltInDateTimeFunctionITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBuiltInDateTimeFunctionITSuite.scala index 71ed72814..8001a690d 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBuiltInDateTimeFunctionITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBuiltInDateTimeFunctionITSuite.scala @@ -218,6 +218,117 @@ class FlintSparkPPLBuiltInDateTimeFunctionITSuite comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } + test("test DATE_ADD") { + val frame1 = sql(s""" + | source = $testTable | eval `'2020-08-26' + 2d` = DATE_ADD(DATE('2020-08-26'), INTERVAL 2 DAY) + | | fields `'2020-08-26' + 2d` | head 1 + | """.stripMargin) + assertSameRows(Seq(Row(Date.valueOf("2020-08-28"))), frame1) + + val frame2 = sql(s""" + | source = $testTable | eval `'2020-08-26' - 2d` = DATE_ADD(DATE('2020-08-26'), INTERVAL -2 DAY) + | | fields `'2020-08-26' - 2d` | head 1 + | """.stripMargin) + assertSameRows(Seq(Row(Date.valueOf("2020-08-24"))), frame2) + + val frame3 = sql(s""" + | source = $testTable | eval `'2020-08-26' + 2m` = DATE_ADD(DATE('2020-08-26'), INTERVAL 2 MONTH) + | | fields `'2020-08-26' + 2m` | head 1 + | """.stripMargin) + assertSameRows(Seq(Row(Date.valueOf("2020-10-26"))), frame3) + + val frame4 = sql(s""" + | source = $testTable | eval `'2020-08-26' + 2y` = DATE_ADD(DATE('2020-08-26'), INTERVAL 2 YEAR) + | | fields `'2020-08-26' + 2y` | head 1 + | """.stripMargin) + assertSameRows(Seq(Row(Date.valueOf("2022-08-26"))), frame4) + + val ex = intercept[AnalysisException](sql(s""" + | source = $testTable | eval `'2020-08-26 01:01:01' + 2h` = DATE_ADD(TIMESTAMP('2020-08-26 01:01:01'), INTERVAL 2 HOUR) + 
| | fields `'2020-08-26 01:01:01' + 2h` | head 1 + | """.stripMargin)) + assert(ex.getMessage.contains("""Parameter 1 requires the "DATE" type""")) + } + + test("test DATE_SUB") { + val frame1 = sql(s""" + | source = $testTable | eval `'2020-08-26' - 2d` = DATE_SUB(DATE('2020-08-26'), INTERVAL 2 DAY) + | | fields `'2020-08-26' - 2d` | head 1 + | """.stripMargin) + assertSameRows(Seq(Row(Date.valueOf("2020-08-24"))), frame1) + + val frame2 = sql(s""" + | source = $testTable | eval `'2020-08-26' + 2d` = DATE_SUB(DATE('2020-08-26'), INTERVAL -2 DAY) + | | fields `'2020-08-26' + 2d` | head 1 + | """.stripMargin) + assertSameRows(Seq(Row(Date.valueOf("2020-08-28"))), frame2) + + val frame3 = sql(s""" + | source = $testTable | eval `'2020-08-26' - 2m` = DATE_SUB(DATE('2020-08-26'), INTERVAL 12 MONTH) + | | fields `'2020-08-26' - 2m` | head 1 + | """.stripMargin) + assertSameRows(Seq(Row(Date.valueOf("2019-08-26"))), frame3) + + val frame4 = sql(s""" + | source = $testTable | eval `'2020-08-26' - 2y` = DATE_SUB(DATE('2020-08-26'), INTERVAL 2 YEAR) + | | fields `'2020-08-26' - 2y` | head 1 + | """.stripMargin) + assertSameRows(Seq(Row(Date.valueOf("2018-08-26"))), frame4) + + val ex = intercept[AnalysisException](sql(s""" + | source = $testTable | eval `'2020-08-26 01:01:01' - 2h` = DATE_SUB(TIMESTAMP('2020-08-26 01:01:01'), INTERVAL 2 HOUR) + | | fields `'2020-08-26 01:01:01' - 2h` | head 1 + | """.stripMargin)) + assert(ex.getMessage.contains("""Parameter 1 requires the "DATE" type""")) + } + + test("test TIMESTAMPADD") { + val frame = sql(s""" + | source = $testTable + | | eval `TIMESTAMPADD(DAY, 17, '2000-01-01 00:00:00')` = TIMESTAMPADD(DAY, 17, '2000-01-01 00:00:00') + | | eval `TIMESTAMPADD(DAY, 17, TIMESTAMP('2000-01-01 00:00:00'))` = TIMESTAMPADD(DAY, 17, TIMESTAMP('2000-01-01 00:00:00')) + | | eval `TIMESTAMPADD(QUARTER, -1, '2000-01-01 00:00:00')` = TIMESTAMPADD(QUARTER, -1, '2000-01-01 00:00:00') + | | fields `TIMESTAMPADD(DAY, 17, '2000-01-01 00:00:00')`, `TIMESTAMPADD(DAY, 17, TIMESTAMP('2000-01-01 00:00:00'))`, `TIMESTAMPADD(QUARTER, -1, '2000-01-01 00:00:00')` + | | head 1 + | """.stripMargin) + assertSameRows( + Seq( + Row( + Timestamp.valueOf("2000-01-18 00:00:00"), + Timestamp.valueOf("2000-01-18 00:00:00"), + Timestamp.valueOf("1999-10-01 00:00:00"))), + frame) + } + + test("test TIMESTAMPDIFF") { + val frame = sql(s""" + | source = $testTable + | | eval `TIMESTAMPDIFF(YEAR, '1997-01-01 00:00:00', '2001-03-06 00:00:00')` = TIMESTAMPDIFF(YEAR, '1997-01-01 00:00:00', '2001-03-06 00:00:00') + | | eval `TIMESTAMPDIFF(SECOND, TIMESTAMP('2000-01-01 00:00:23'), TIMESTAMP('2000-01-01 00:00:00'))` = TIMESTAMPDIFF(SECOND, TIMESTAMP('2000-01-01 00:00:23'), TIMESTAMP('2000-01-01 00:00:00')) + | | fields `TIMESTAMPDIFF(YEAR, '1997-01-01 00:00:00', '2001-03-06 00:00:00')`, `TIMESTAMPDIFF(SECOND, TIMESTAMP('2000-01-01 00:00:23'), TIMESTAMP('2000-01-01 00:00:00'))` + | | head 1 + | """.stripMargin) + assertSameRows(Seq(Row(4, -23)), frame) + } + + test("test CURRENT_TIMEZONE") { + val frame = sql(s""" + | source = $testTable + | | eval `CURRENT_TIMEZONE` = CURRENT_TIMEZONE() + | | fields `CURRENT_TIMEZONE` + | """.stripMargin) + assert(frame.collect().length > 0) + } + + test("test UTC_TIMESTAMP") { + val frame = sql(s""" + | source = $testTable + | | eval `UTC_TIMESTAMP` = UTC_TIMESTAMP() + | | fields `UTC_TIMESTAMP` + | """.stripMargin) + assert(frame.collect().length > 0) + } + test("test hour, minute, second, HOUR_OF_DAY, MINUTE_OF_HOUR") { val frame = sql(s""" | source = $testTable @@ 
-284,24 +395,6 @@ class FlintSparkPPLBuiltInDateTimeFunctionITSuite assert(ex.getMessage.contains("ADDTIME is not a builtin function of PPL")) } - test("test DATE_ADD is not supported") { - val ex = intercept[UnsupportedOperationException](sql(s""" - | source = $testTable - | | eval `DATE_ADD` = DATE_ADD() - | | fields DATE_ADD | head 1 - | """.stripMargin)) - assert(ex.getMessage.contains("DATE_ADD is not a builtin function of PPL")) - } - - test("test DATE_SUB is not supported") { - val ex = intercept[UnsupportedOperationException](sql(s""" - | source = $testTable - | | eval `DATE_SUB` = DATE_SUB() - | | fields DATE_SUB | head 1 - | """.stripMargin)) - assert(ex.getMessage.contains("DATE_SUB is not a builtin function of PPL")) - } - test("test DATETIME is not supported") { val ex = intercept[UnsupportedOperationException](sql(s""" | source = $testTable @@ -445,22 +538,6 @@ class FlintSparkPPLBuiltInDateTimeFunctionITSuite assert(ex.getMessage.contains("TIMEDIFF is not a builtin function of PPL")) } - test("test TIMESTAMPADD is not supported") { - intercept[Exception](sql(s""" - | source = $testTable - | | eval `TIMESTAMPADD` = TIMESTAMPADD(DAY, 17, '2000-01-01 00:00:00') - | | fields TIMESTAMPADD | head 1 - | """.stripMargin)) - } - - test("test TIMESTAMPDIFF is not supported") { - intercept[Exception](sql(s""" - | source = $testTable - | | eval `TIMESTAMPDIFF_1` = TIMESTAMPDIFF(YEAR, '1997-01-01 00:00:00', '2001-03-06 00:00:00') - | | fields TIMESTAMPDIFF_1 | head 1 - | """.stripMargin)) - } - test("test TO_DAYS is not supported") { val ex = intercept[UnsupportedOperationException](sql(s""" | source = $testTable @@ -497,15 +574,6 @@ class FlintSparkPPLBuiltInDateTimeFunctionITSuite assert(ex.getMessage.contains("UTC_TIME is not a builtin function of PPL")) } - test("test UTC_TIMESTAMP is not supported") { - val ex = intercept[UnsupportedOperationException](sql(s""" - | source = $testTable - | | eval `UTC_TIMESTAMP` = UTC_TIMESTAMP() - | | fields UTC_TIMESTAMP | head 1 - | """.stripMargin)) - assert(ex.getMessage.contains("UTC_TIMESTAMP is not a builtin function of PPL")) - } - test("test YEARWEEK is not supported") { val ex = intercept[UnsupportedOperationException](sql(s""" | source = $testTable diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCidrmatchITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCidrmatchITSuite.scala new file mode 100644 index 000000000..d9cf8968b --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCidrmatchITSuite.scala @@ -0,0 +1,96 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.SparkException +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLCidrmatchITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createIpAddressTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test cidrmatch for ipv4 for 192.168.1.0/24") { + val frame = sql(s""" + | source = $testTable | where 
isV6 = false and isValid = true and cidrmatch(ipAddress, '192.168.1.0/24') + | """.stripMargin) + + val results = frame.collect() + assert(results.length == 2) + } + + test("test cidrmatch for ipv4 for 192.169.1.0/24") { + val frame = sql(s""" + | source = $testTable | where isV6 = false and isValid = true and cidrmatch(ipAddress, '192.169.1.0/24') + | """.stripMargin) + + val results = frame.collect() + assert(results.length == 0) + } + + test("test cidrmatch for ipv6 for 2001:db8::/32") { + val frame = sql(s""" + | source = $testTable | where isV6 = true and isValid = true and cidrmatch(ipAddress, '2001:db8::/32') + | """.stripMargin) + + val results = frame.collect() + assert(results.length == 3) + } + + test("test cidrmatch for ipv6 for 2003:db8::/32") { + val frame = sql(s""" + | source = $testTable | where isV6 = true and isValid = true and cidrmatch(ipAddress, '2003:db8::/32') + | """.stripMargin) + + val results = frame.collect() + assert(results.length == 0) + } + + test("test cidrmatch for ipv6 with ipv4 cidr") { + val frame = sql(s""" + | source = $testTable | where isV6 = true and isValid = true and cidrmatch(ipAddress, '192.169.1.0/24') + | """.stripMargin) + + assertThrows[SparkException](frame.collect()) + } + + test("test cidrmatch for invalid ipv4 addresses") { + val frame = sql(s""" + | source = $testTable | where isV6 = false and isValid = false and cidrmatch(ipAddress, '192.169.1.0/24') + | """.stripMargin) + + assertThrows[SparkException](frame.collect()) + } + + test("test cidrmatch for invalid ipv6 addresses") { + val frame = sql(s""" + | source = $testTable | where isV6 = true and isValid = false and cidrmatch(ipAddress, '2003:db8::/32') + | """.stripMargin) + + assertThrows[SparkException](frame.collect()) + } +} diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExpandITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExpandITSuite.scala new file mode 100644 index 000000000..f0404bf7b --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExpandITSuite.scala @@ -0,0 +1,255 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.flint.spark.ppl + +import java.nio.file.Files + +import scala.collection.mutable + +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Explode, GeneratorOuter, Literal, Or} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLExpandITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + private val testTable = "flint_ppl_test" + private val occupationTable = "spark_catalog.default.flint_ppl_flat_table_test" + private val structNestedTable = "spark_catalog.default.flint_ppl_struct_nested_test" + private val structTable = "spark_catalog.default.flint_ppl_struct_test" + private val multiValueTable = "spark_catalog.default.flint_ppl_multi_value_test" + private val multiArraysTable = "spark_catalog.default.flint_ppl_multi_array_test" + private val tempFile = Files.createTempFile("jsonTestData", ".json") + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + 
createNestedJsonContentTable(tempFile, testTable) + createOccupationTable(occupationTable) + createStructNestedTable(structNestedTable) + createStructTable(structTable) + createMultiValueStructTable(multiValueTable) + createMultiColumnArrayTable(multiArraysTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + override def afterAll(): Unit = { + super.afterAll() + Files.deleteIfExists(tempFile) + } + + test("expand for eval field of an array") { + val frame = sql( + s""" source = $occupationTable | eval array=json_array(1, 2, 3) | expand array as uid | fields name, occupation, uid + """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row("Jake", "Engineer", 1), + Row("Jake", "Engineer", 2), + Row("Jake", "Engineer", 3), + Row("Hello", "Artist", 1), + Row("Hello", "Artist", 2), + Row("Hello", "Artist", 3), + Row("John", "Doctor", 1), + Row("John", "Doctor", 2), + Row("John", "Doctor", 3), + Row("David", "Doctor", 1), + Row("David", "Doctor", 2), + Row("David", "Doctor", 3), + Row("David", "Unemployed", 1), + Row("David", "Unemployed", 2), + Row("David", "Unemployed", 3), + Row("Jane", "Scientist", 1), + Row("Jane", "Scientist", 2), + Row("Jane", "Scientist", 3)) + + // Compare the results + assert(results.toSet == expectedResults.toSet) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // expected plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_flat_table_test")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "array")() + val project = Project(seq(UnresolvedStar(None), aliasA), table) + val generate = Generate( + Explode(UnresolvedAttribute("array")), + seq(), + false, + None, + seq(UnresolvedAttribute("uid")), + project) + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("array")), generate) + val expectedPlan = Project( + seq( + UnresolvedAttribute("name"), + UnresolvedAttribute("occupation"), + UnresolvedAttribute("uid")), + dropSourceColumn) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("expand for structs") { + val frame = sql( + s""" source = $multiValueTable | expand multi_value AS exploded_multi_value | fields exploded_multi_value + """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(Row("1_one", 1)), + Row(Row(null, 11)), + Row(Row("1_three", null)), + Row(Row("2_Monday", 2)), + Row(null), + Row(Row("3_third", 3)), + Row(Row("3_4th", 4)), + Row(null)) + // Compare the results + assert(results.toSet == expectedResults.toSet) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // expected plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_multi_value_test")) + val generate = Generate( + Explode(UnresolvedAttribute("multi_value")), + seq(), + outer = false, + None, + seq(UnresolvedAttribute("exploded_multi_value")), + table) + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("multi_value")), generate) + val expectedPlan = Project(Seq(UnresolvedAttribute("exploded_multi_value")), dropSourceColumn) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("expand for array of structs") { + val frame = sql(s""" + | source = $testTable + | | where 
country = 'England' or country = 'Poland' + | | expand bridges + | | fields bridges + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(mutable.WrappedArray.make(Array(Row(801, "Tower Bridge"), Row(928, "London Bridge")))), + Row(mutable.WrappedArray.make(Array(Row(801, "Tower Bridge"), Row(928, "London Bridge")))) + // Row(null)) -> in case of outerGenerator = GeneratorOuter(Explode(UnresolvedAttribute("bridges"))) it will include the `null` row + ) + + // Compare the results + assert(results.toSet == expectedResults.toSet) + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("flint_ppl_test")) + val filter = Filter( + Or( + EqualTo(UnresolvedAttribute("country"), Literal("England")), + EqualTo(UnresolvedAttribute("country"), Literal("Poland"))), + table) + val generate = + Generate(Explode(UnresolvedAttribute("bridges")), seq(), outer = false, None, seq(), filter) + val expectedPlan = Project(Seq(UnresolvedAttribute("bridges")), generate) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("expand for array of structs with alias") { + val frame = sql(s""" + | source = $testTable + | | where country = 'England' + | | expand bridges as britishBridges + | | fields britishBridges + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(Row(801, "Tower Bridge")), + Row(Row(928, "London Bridge")), + Row(Row(801, "Tower Bridge")), + Row(Row(928, "London Bridge"))) + // Compare the results + assert(results.toSet == expectedResults.toSet) + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("flint_ppl_test")) + val filter = Filter(EqualTo(UnresolvedAttribute("country"), Literal("England")), table) + val generate = Generate( + Explode(UnresolvedAttribute("bridges")), + seq(), + outer = false, + None, + seq(UnresolvedAttribute("britishBridges")), + filter) + val dropColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("bridges")), generate) + val expectedPlan = Project(Seq(UnresolvedAttribute("britishBridges")), dropColumn) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("expand multi columns array table") { + val frame = sql(s""" + | source = $multiArraysTable + | | expand multi_valueA as multiA + | | expand multi_valueB as multiB + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(1, Row("1_one", 1), Row("2_Monday", 2)), + Row(1, Row("1_one", 1), null), + Row(1, Row(null, 11), Row("2_Monday", 2)), + Row(1, Row(null, 11), null), + Row(1, Row("1_three", null), Row("2_Monday", 2)), + Row(1, Row("1_three", null), null), + Row(2, Row("2_Monday", 2), Row("3_third", 3)), + Row(2, Row("2_Monday", 2), Row("3_4th", 4)), + Row(2, null, Row("3_third", 3)), + Row(2, null, Row("3_4th", 4)), + Row(3, Row("3_third", 3), Row("1_one", 1)), + Row(3, Row("3_4th", 4), Row("1_one", 1))) + // Compare the results + assert(results.toSet == expectedResults.toSet) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_multi_array_test")) + val generatorA = Explode(UnresolvedAttribute("multi_valueA")) + val generateA = + Generate(generatorA, seq(), false, None, seq(UnresolvedAttribute("multiA")), table) + val dropSourceColumnA = + DataFrameDropColumns(Seq(UnresolvedAttribute("multi_valueA")), generateA) + 
val generatorB = Explode(UnresolvedAttribute("multi_valueB")) + val generateB = Generate( + generatorB, + seq(), + false, + None, + seq(UnresolvedAttribute("multiB")), + dropSourceColumnA) + val dropSourceColumnB = + DataFrameDropColumns(Seq(UnresolvedAttribute("multi_valueB")), generateB) + val expectedPlan = Project(seq(UnresolvedStar(None)), dropSourceColumnB) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + + } +} diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFlattenITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFlattenITSuite.scala new file mode 100644 index 000000000..e714a5f7e --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFlattenITSuite.scala @@ -0,0 +1,350 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.flint.spark.ppl + +import java.nio.file.Files + +import org.opensearch.flint.spark.FlattenGenerator +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, GeneratorOuter, Literal, Or} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLFlattenITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + private val testTable = "flint_ppl_test" + private val structNestedTable = "spark_catalog.default.flint_ppl_struct_nested_test" + private val structTable = "spark_catalog.default.flint_ppl_struct_test" + private val multiValueTable = "spark_catalog.default.flint_ppl_multi_value_test" + private val tempFile = Files.createTempFile("jsonTestData", ".json") + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createNestedJsonContentTable(tempFile, testTable) + createStructNestedTable(structNestedTable) + createStructTable(structTable) + createMultiValueStructTable(multiValueTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + override def afterAll(): Unit = { + super.afterAll() + Files.deleteIfExists(tempFile) + } + + test("flatten for structs") { + val frame = sql(s""" + | source = $testTable + | | where country = 'England' or country = 'Poland' + | | fields coor + | | flatten coor + | """.stripMargin) + + assert(frame.columns.sameElements(Array("alt", "lat", "long"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row(35, 51.5074, -0.1278), Row(null, null, null)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](1)) + assert(results.sorted.sameElements(expectedResults.sorted)) + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("flint_ppl_test")) + val filter = Filter( + Or( + EqualTo(UnresolvedAttribute("country"), Literal("England")), + EqualTo(UnresolvedAttribute("country"), Literal("Poland"))), + table) + val projectCoor = Project(Seq(UnresolvedAttribute("coor")), filter) + val flattenCoor = flattenPlanFor("coor", projectCoor) + val expectedPlan = Project(Seq(UnresolvedStar(None)), 
flattenCoor) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + private def flattenPlanFor(flattenedColumn: String, parentPlan: LogicalPlan): LogicalPlan = { + val flattenGenerator = new FlattenGenerator(UnresolvedAttribute(flattenedColumn)) + val outerGenerator = GeneratorOuter(flattenGenerator) + val generate = Generate(outerGenerator, seq(), outer = true, None, seq(), parentPlan) + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute(flattenedColumn)), generate) + dropSourceColumn + } + + test("flatten for arrays") { + val frame = sql(s""" + | source = $testTable + | | fields bridges + | | flatten bridges + | """.stripMargin) + + assert(frame.columns.sameElements(Array("length", "name"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row(null, null), + Row(11L, "Bridge of Sighs"), + Row(48L, "Rialto Bridge"), + Row(160L, "Pont Alexandre III"), + Row(232L, "Pont Neuf"), + Row(801L, "Tower Bridge"), + Row(928L, "London Bridge"), + Row(343L, "Legion Bridge"), + Row(516L, "Charles Bridge"), + Row(333L, "Liberty Bridge"), + Row(375L, "Chain Bridge")) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("flint_ppl_test")) + val projectCoor = Project(Seq(UnresolvedAttribute("bridges")), table) + val flattenBridges = flattenPlanFor("bridges", projectCoor) + val expectedPlan = Project(Seq(UnresolvedStar(None)), flattenBridges) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("flatten for structs and arrays") { + val frame = sql(s""" + | source = $testTable | flatten bridges | flatten coor + | """.stripMargin) + + assert( + frame.columns.sameElements( + Array("_time", "city", "country", "length", "name", "alt", "lat", "long"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("1990-09-13T12:00:00", "Warsaw", "Poland", null, null, null, null, null), + Row( + "2024-09-13T12:00:00", + "Venice", + "Italy", + 11L, + "Bridge of Sighs", + 2, + 45.4408, + 12.3155), + Row("2024-09-13T12:00:00", "Venice", "Italy", 48L, "Rialto Bridge", 2, 45.4408, 12.3155), + Row( + "2024-09-13T12:00:00", + "Paris", + "France", + 160L, + "Pont Alexandre III", + 35, + 48.8566, + 2.3522), + Row("2024-09-13T12:00:00", "Paris", "France", 232L, "Pont Neuf", 35, 48.8566, 2.3522), + Row( + "2024-09-13T12:00:00", + "London", + "England", + 801L, + "Tower Bridge", + 35, + 51.5074, + -0.1278), + Row( + "2024-09-13T12:00:00", + "London", + "England", + 928L, + "London Bridge", + 35, + 51.5074, + -0.1278), + Row( + "2024-09-13T12:00:00", + "Prague", + "Czech Republic", + 343L, + "Legion Bridge", + 200, + 50.0755, + 14.4378), + Row( + "2024-09-13T12:00:00", + "Prague", + "Czech Republic", + 516L, + "Charles Bridge", + 200, + 50.0755, + 14.4378), + Row( + "2024-09-13T12:00:00", + "Budapest", + "Hungary", + 333L, + "Liberty Bridge", + 96, + 47.4979, + 19.0402), + Row( + "2024-09-13T12:00:00", + "Budapest", + "Hungary", + 375L, + "Chain Bridge", + 96, + 47.4979, + 19.0402)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](3)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("flint_ppl_test")) + val 
flattenBridges = flattenPlanFor("bridges", table) + val flattenCoor = flattenPlanFor("coor", flattenBridges) + val expectedPlan = Project(Seq(UnresolvedStar(None)), flattenCoor) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test flatten and stats") { + val frame = sql(s""" + | source = $testTable + | | fields country, bridges + | | flatten bridges + | | fields country, length + | | stats avg(length) as avg by country + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row(null, "Poland"), + Row(196d, "France"), + Row(429.5, "Czech Republic"), + Row(864.5, "England"), + Row(29.5, "Italy"), + Row(354.0, "Hungary")) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("flint_ppl_test")) + val projectCountryBridges = + Project(Seq(UnresolvedAttribute("country"), UnresolvedAttribute("bridges")), table) + val flattenBridges = flattenPlanFor("bridges", projectCountryBridges) + val projectCountryLength = + Project(Seq(UnresolvedAttribute("country"), UnresolvedAttribute("length")), flattenBridges) + val average = Alias( + UnresolvedFunction( + seq("AVG"), + seq(UnresolvedAttribute("length")), + isDistinct = false, + None, + ignoreNulls = false), + "avg")() + val country = Alias(UnresolvedAttribute("country"), "country")() + val grouping = Alias(UnresolvedAttribute("country"), "country")() + val aggregate = Aggregate(Seq(grouping), Seq(average, country), projectCountryLength) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("flatten struct table") { + val frame = sql(s""" + | source = $structTable + | | flatten struct_col + | | flatten field1 + | """.stripMargin) + + assert(frame.columns.sameElements(Array("int_col", "field2", "subfield"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row(30, 123, "value1"), Row(40, 456, "value2"), Row(50, 789, "value3")) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_struct_test")) + val flattenStructCol = flattenPlanFor("struct_col", table) + val flattenField1 = flattenPlanFor("field1", flattenStructCol) + val expectedPlan = Project(Seq(UnresolvedStar(None)), flattenField1) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("flatten struct nested table") { + val frame = sql(s""" + | source = $structNestedTable + | | flatten struct_col + | | flatten field1 + | | flatten struct_col2 + | | flatten field1 + | """.stripMargin) + + assert( + frame.columns.sameElements(Array("int_col", "field2", "subfield", "field2", "subfield"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row(30, 123, "value1", 23, "valueA"), + Row(40, 123, "value5", 33, "valueB"), + Row(30, 823, "value4", 83, "valueC"), + Row(40, 456, "value2", 46, "valueD"), + Row(50, 789, "value3", 89, "valueE")) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) + 
assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_struct_nested_test")) + val flattenStructCol = flattenPlanFor("struct_col", table) + val flattenField1 = flattenPlanFor("field1", flattenStructCol) + val flattenStructCol2 = flattenPlanFor("struct_col2", flattenField1) + val flattenField1Again = flattenPlanFor("field1", flattenStructCol2) + val expectedPlan = Project(Seq(UnresolvedStar(None)), flattenField1Again) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("flatten multi value nullable") { + val frame = sql(s""" + | source = $multiValueTable + | | flatten multi_value + | """.stripMargin) + + assert(frame.columns.sameElements(Array("int_col", "name", "value"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row(1, "1_one", 1), + Row(1, null, 11), + Row(1, "1_three", null), + Row(2, "2_Monday", 2), + Row(2, null, null), + Row(3, "3_third", 3), + Row(3, "3_4th", 4), + Row(4, null, null)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_multi_value_test")) + val flattenMultiValue = flattenPlanFor("multi_value", table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), flattenMultiValue) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } +} diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJoinITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJoinITSuite.scala index 00e55d50a..3127325c8 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJoinITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJoinITSuite.scala @@ -5,7 +5,7 @@ package org.opensearch.flint.spark.ppl -import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Divide, EqualTo, Floor, GreaterThan, LessThan, Literal, Multiply, Or, SortOrder} import org.apache.spark.sql.catalyst.plans.{Cross, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter} @@ -924,4 +924,271 @@ class FlintSparkPPLJoinITSuite s }.size == 13) } + + test("test multiple joins without table aliases") { + val frame = sql(s""" + | source = $testTable1 + | | JOIN ON $testTable1.name = $testTable2.name $testTable2 + | | JOIN ON $testTable2.name = $testTable3.name $testTable3 + | | fields $testTable1.name, $testTable2.name, $testTable3.name + | """.stripMargin) + assertSameRows( + Array( + Row("Jake", "Jake", "Jake"), + Row("Hello", "Hello", "Hello"), + Row("John", "John", "John"), + Row("David", "David", "David"), + Row("David", "David", "David"), + Row("Jane", "Jane", "Jane")), + frame) + + val logicalPlan = frame.queryExecution.logical + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val table3 = UnresolvedRelation(Seq("spark_catalog", 
"default", "flint_ppl_test3")) + val joinPlan1 = Join( + table1, + table2, + Inner, + Some( + EqualTo( + UnresolvedAttribute(s"$testTable1.name"), + UnresolvedAttribute(s"$testTable2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + table3, + Inner, + Some( + EqualTo( + UnresolvedAttribute(s"$testTable2.name"), + UnresolvedAttribute(s"$testTable3.name"))), + JoinHint.NONE) + val expectedPlan = Project( + Seq( + UnresolvedAttribute(s"$testTable1.name"), + UnresolvedAttribute(s"$testTable2.name"), + UnresolvedAttribute(s"$testTable3.name")), + joinPlan2) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with part subquery aliases") { + val frame = sql(s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 + | | JOIN right = t3 ON t1.name = t3.name $testTable3 + | | fields t1.name, t2.name, t3.name + | """.stripMargin) + assertSameRows( + Array( + Row("Jake", "Jake", "Jake"), + Row("Hello", "Hello", "Hello"), + Row("John", "John", "John"), + Row("David", "David", "David"), + Row("David", "David", "David"), + Row("Jane", "Jane", "Jane")), + frame) + + val logicalPlan = frame.queryExecution.logical + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t3.name"))), + JoinHint.NONE) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("t1.name"), + UnresolvedAttribute("t2.name"), + UnresolvedAttribute("t3.name")), + joinPlan2) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with self join 1") { + val frame = sql(s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 + | | JOIN right = t3 ON t1.name = t3.name $testTable3 + | | JOIN right = t4 ON t1.name = t4.name $testTable1 + | | fields t1.name, t2.name, t3.name, t4.name + | """.stripMargin) + assertSameRows( + Array( + Row("Jake", "Jake", "Jake", "Jake"), + Row("Hello", "Hello", "Hello", "Hello"), + Row("John", "John", "John", "John"), + Row("David", "David", "David", "David"), + Row("David", "David", "David", "David"), + Row("Jane", "Jane", "Jane", "Jane")), + frame) + + val logicalPlan = frame.queryExecution.logical + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t3.name"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + SubqueryAlias("t4", table1), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t4.name"))), + JoinHint.NONE) + val 
expectedPlan = Project( + Seq( + UnresolvedAttribute("t1.name"), + UnresolvedAttribute("t2.name"), + UnresolvedAttribute("t3.name"), + UnresolvedAttribute("t4.name")), + joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with self join 2") { + val frame = sql(s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 + | | JOIN right = t3 ON t1.name = t3.name $testTable3 + | | JOIN ON t1.name = t4.name + | [ + | source = $testTable1 + | ] as t4 + | | fields t1.name, t2.name, t3.name, t4.name + | """.stripMargin) + assertSameRows( + Array( + Row("Jake", "Jake", "Jake", "Jake"), + Row("Hello", "Hello", "Hello", "Hello"), + Row("John", "John", "John", "John"), + Row("David", "David", "David", "David"), + Row("David", "David", "David", "David"), + Row("Jane", "Jane", "Jane", "Jane")), + frame) + + val logicalPlan = frame.queryExecution.logical + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t3.name"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + SubqueryAlias("t4", table1), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t4.name"))), + JoinHint.NONE) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("t1.name"), + UnresolvedAttribute("t2.name"), + UnresolvedAttribute("t3.name"), + UnresolvedAttribute("t4.name")), + joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("check access the reference by aliases") { + var frame = sql(s""" + | source = $testTable1 + | | JOIN left = t1 ON t1.name = t2.name $testTable2 as t2 + | | fields t1.name, t2.name + | """.stripMargin) + assert(frame.collect().length > 0) + + frame = sql(s""" + | source = $testTable1 as t1 + | | JOIN ON t1.name = t2.name $testTable2 as t2 + | | fields t1.name, t2.name + | """.stripMargin) + assert(frame.collect().length > 0) + + frame = sql(s""" + | source = $testTable1 + | | JOIN left = t1 ON t1.name = t2.name [ source = $testTable2 ] as t2 + | | fields t1.name, t2.name + | """.stripMargin) + assert(frame.collect().length > 0) + + frame = sql(s""" + | source = $testTable1 + | | JOIN left = t1 ON t1.name = t2.name [ source = $testTable2 as t2 ] + | | fields t1.name, t2.name + | """.stripMargin) + assert(frame.collect().length > 0) + } + + test("access the reference by override aliases should throw exception") { + var ex = intercept[AnalysisException](sql(s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 as tt + | | fields tt.name + | """.stripMargin)) + assert(ex.getMessage.contains("`tt`.`name` cannot be resolved")) + + ex = intercept[AnalysisException](sql(s""" + | source = $testTable1 as tt + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 + | | fields tt.name + | """.stripMargin)) + assert(ex.getMessage.contains("`tt`.`name` cannot be resolved")) + + ex = intercept[AnalysisException](sql(s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON 
t1.name = t2.name [ source = $testTable2 as tt ] + | | fields tt.name + | """.stripMargin)) + assert(ex.getMessage.contains("`tt`.`name` cannot be resolved")) + + ex = intercept[AnalysisException](sql(s""" + | source = $testTable1 + | | JOIN left = t1 ON t1.name = t2.name [ source = $testTable2 as tt ] as t2 + | | fields tt.name + | """.stripMargin)) + assert(ex.getMessage.contains("`tt`.`name` cannot be resolved")) + + ex = intercept[AnalysisException](sql(s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name [ source = $testTable2 ] as tt + | | fields tt.name + | """.stripMargin)) + assert(ex.getMessage.contains("`tt`.`name` cannot be resolved")) + + ex = intercept[AnalysisException](sql(s""" + | source = $testTable1 as tt + | | JOIN left = t1 ON t1.name = t2.name $testTable2 as t2 + | | fields tt.name + | """.stripMargin)) + assert(ex.getMessage.contains("`tt`.`name` cannot be resolved")) + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJsonFunctionITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJsonFunctionITSuite.scala index 7cc0a221d..fca758101 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJsonFunctionITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJsonFunctionITSuite.scala @@ -163,30 +163,32 @@ class FlintSparkPPLJsonFunctionITSuite assert(ex.getMessage().contains("should all be the same type")) } - test("test json_array() with json()") { + test("test json_array() with to_json_string()") { val frame = sql(s""" - | source = $testTable | eval result = json(json_array(1,2,0,-1,1.1,-0.11)) | head 1 | fields result + | source = $testTable | eval result = to_json_string(json_array(1,2,0,-1,1.1,-0.11)) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""[1.0,2.0,0.0,-1.0,1.1,-0.11]""")), frame) } - test("test json_array_length()") { + test("test array_length()") { var frame = sql(s""" - | source = $testTable | eval result = json_array_length(json_array('this', 'is', 'a', 'string', 'array')) | head 1 | fields result - | """.stripMargin) + | source = $testTable| eval result = array_length(json_array('this', 'is', 'a', 'string', 'array')) | head 1 | fields result + | """.stripMargin) assertSameRows(Seq(Row(5)), frame) frame = sql(s""" - | source = $testTable | eval result = json_array_length(json_array(1, 2, 0, -1, 1.1, -0.11)) | head 1 | fields result - | """.stripMargin) + | source = $testTable| eval result = array_length(json_array(1, 2, 0, -1, 1.1, -0.11)) | head 1 | fields result + | """.stripMargin) assertSameRows(Seq(Row(6)), frame) frame = sql(s""" - | source = $testTable | eval result = json_array_length(json_array()) | head 1 | fields result - | """.stripMargin) + | source = $testTable| eval result = array_length(json_array()) | head 1 | fields result + | """.stripMargin) assertSameRows(Seq(Row(0)), frame) + } - frame = sql(s""" + test("test json_array_length()") { + var frame = sql(s""" | source = $testTable | eval result = json_array_length('[]') | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row(0)), frame) @@ -211,24 +213,24 @@ class FlintSparkPPLJsonFunctionITSuite test("test json_object()") { // test value is a string var frame = sql(s""" - | source = $testTable| eval result = json(json_object('key', 'string_value')) | head 1 | fields result + | source = $testTable| eval result = to_json_string(json_object('key', 'string_value')) |
head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""{"key":"string_value"}""")), frame) // test value is a number frame = sql(s""" - | source = $testTable| eval result = json(json_object('key', 123.45)) | head 1 | fields result + | source = $testTable| eval result = to_json_string(json_object('key', 123.45)) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""{"key":123.45}""")), frame) // test value is a boolean frame = sql(s""" - | source = $testTable| eval result = json(json_object('key', true)) | head 1 | fields result + | source = $testTable| eval result = to_json_string(json_object('key', true)) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""{"key":true}""")), frame) frame = sql(s""" - | source = $testTable| eval result = json(json_object("a", 1, "b", 2, "c", 3)) | head 1 | fields result + | source = $testTable| eval result = to_json_string(json_object("a", 1, "b", 2, "c", 3)) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""{"a":1,"b":2,"c":3}""")), frame) } @@ -236,13 +238,13 @@ class FlintSparkPPLJsonFunctionITSuite test("test json_object() and json_array()") { // test value is an empty array var frame = sql(s""" - | source = $testTable| eval result = json(json_object('key', array())) | head 1 | fields result + | source = $testTable| eval result = to_json_string(json_object('key', array())) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""{"key":[]}""")), frame) // test value is an array frame = sql(s""" - | source = $testTable| eval result = json(json_object('key', array(1, 2, 3))) | head 1 | fields result + | source = $testTable| eval result = to_json_string(json_object('key', array(1, 2, 3))) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""{"key":[1,2,3]}""")), frame) @@ -272,14 +274,14 @@ class FlintSparkPPLJsonFunctionITSuite test("test json_object() nested") { val frame = sql(s""" - | source = $testTable | eval result = json(json_object('outer', json_object('inner', 123.45))) | head 1 | fields result + | source = $testTable | eval result = to_json_string(json_object('outer', json_object('inner', 123.45))) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""{"outer":{"inner":123.45}}""")), frame) } test("test json_object(), json_array() and json()") { val frame = sql(s""" - | source = $testTable | eval result = json(json_object("array", json_array(1,2,0,-1,1.1,-0.11))) | head 1 | fields result + | source = $testTable | eval result = to_json_string(json_object("array", json_array(1,2,0,-1,1.1,-0.11))) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""{"array":[1.0,2.0,0.0,-1.0,1.1,-0.11]}""")), frame) } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLLambdaFunctionITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLLambdaFunctionITSuite.scala new file mode 100644 index 000000000..f86502521 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLLambdaFunctionITSuite.scala @@ -0,0 +1,132 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.functions.{col, to_json} +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLLambdaFunctionITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + 
/** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + // Create test table + createNullableJsonContentTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test forall()") { + val frame = sql(s""" + | source = $testTable | eval array = json_array(1,2,0,-1,1.1,-0.11), result = forall(array, x -> x > 0) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(false)), frame) + + val frame2 = sql(s""" + | source = $testTable | eval array = json_array(1,2,0,-1,1.1,-0.11), result = forall(array, x -> x > -10) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(true)), frame2) + + val frame3 = sql(s""" + | source = $testTable | eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = forall(array, x -> x.a > 0) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(false)), frame3) + + val frame4 = sql(s""" + | source = $testTable | eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = exists(array, x -> x.b < 0) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(true)), frame4) + } + + test("test exists()") { + val frame = sql(s""" + | source = $testTable | eval array = json_array(1,2,0,-1,1.1,-0.11), result = exists(array, x -> x > 0) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(true)), frame) + + val frame2 = sql(s""" + | source = $testTable | eval array = json_array(1,2,0,-1,1.1,-0.11), result = exists(array, x -> x > 10) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(false)), frame2) + + val frame3 = sql(s""" + | source = $testTable | eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = exists(array, x -> x.a > 0) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(true)), frame3) + + val frame4 = sql(s""" + | source = $testTable | eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = exists(array, x -> x.b > 0) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(false)), frame4) + + } + + test("test filter()") { + val frame = sql(s""" + | source = $testTable| eval array = json_array(1,2,0,-1,1.1,-0.11), result = filter(array, x -> x > 0) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(Seq(1, 2, 1.1))), frame) + + val frame2 = sql(s""" + | source = $testTable| eval array = json_array(1,2,0,-1,1.1,-0.11), result = filter(array, x -> x > 10) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(Seq())), frame2) + + val frame3 = sql(s""" + | source = $testTable| eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = filter(array, x -> x.a > 0) | head 1 | fields result + | """.stripMargin) + + assertSameRows(Seq(Row("""[{"a":1,"b":-1}]""")), frame3.select(to_json(col("result")))) + + val frame4 = sql(s""" + | source = $testTable| eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = filter(array, x -> x.b > 0) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row("""[]""")), frame4.select(to_json(col("result")))) + } + + test("test transform()") { + val frame = sql(s""" + | source = $testTable | eval array = 
json_array(1,2,3), result = transform(array, x -> x + 1) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(Seq(2, 3, 4))), frame) + + val frame2 = sql(s""" + | source = $testTable | eval array = json_array(1,2,3), result = transform(array, (x, y) -> x + y) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(Seq(1, 3, 5))), frame2) + } + + test("test reduce()") { + val frame = sql(s""" + | source = $testTable | eval array = json_array(1,2,3), result = reduce(array, 0, (acc, x) -> acc + x) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(6)), frame) + + val frame2 = sql(s""" + | source = $testTable | eval array = json_array(1,2,3), result = reduce(array, 1, (acc, x) -> acc + x) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(7)), frame2) + + val frame3 = sql(s""" + | source = $testTable | eval array = json_array(1,2,3), result = reduce(array, 0, (acc, x) -> acc + x, acc -> acc * 10) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(60)), frame3) + } +} diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala new file mode 100644 index 000000000..bc4463537 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala @@ -0,0 +1,247 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CaseWhen, CurrentRow, Descending, LessThan, Literal, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLTrendlineITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createPartitionedStateCountryTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test trendline sma command without fields command and without alias") { + val frame = sql(s""" + | source = $testTable | sort - age | trendline sma(2, age) + | """.stripMargin) + + assert( + frame.columns.sameElements( + Array("name", "age", "state", "country", "year", "month", "age_trendline"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jake", 70, "California", "USA", 2023, 4, null), + Row("Hello", 30, "New York", "USA", 2023, 4, 50.0), + Row("John", 25, "Ontario", "Canada", 2023, 4, 27.5), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 22.5)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + 
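+ // The expected plan assembled below encodes sma(2, age) roughly as (a sketch, not exact Catalyst SQL):
+ //   CASE WHEN COUNT(1) OVER (ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) < 2 THEN NULL
+ //        ELSE AVG(age) OVER (ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) END AS age_trendline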
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val ageField = UnresolvedAttribute("age") + val sort = Sort(Seq(SortOrder(ageField, Descending)), global = true, table) + val countWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val smaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val caseWhen = CaseWhen(Seq((LessThan(countWindow, Literal(2)), Literal(null))), smaWindow) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(caseWhen, "age_trendline")()) + val expectedPlan = Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sort)) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline sma command with fields command") { + val frame = sql(s""" + | source = $testTable | trendline sort - age sma(3, age) as age_sma | fields name, age, age_sma + | """.stripMargin) + + assert(frame.columns.sameElements(Array("name", "age", "age_sma"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jake", 70, null), + Row("Hello", 30, null), + Row("John", 25, 41.666666666666664), + Row("Jane", 20, 25)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val nameField = UnresolvedAttribute("name") + val ageField = UnresolvedAttribute("age") + val ageSmaField = UnresolvedAttribute("age_sma") + val sort = Sort(Seq(SortOrder(ageField, Descending)), global = true, table) + val countWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val smaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val caseWhen = CaseWhen(Seq((LessThan(countWindow, Literal(3)), Literal(null))), smaWindow) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(caseWhen, "age_sma")()) + val expectedPlan = + Project(Seq(nameField, ageField, ageSmaField), Project(trendlineProjectList, sort)) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test multiple trendline sma commands") { + val frame = sql(s""" + | source = $testTable | trendline sort + age sma(2, age) as two_points_sma sma(3, age) as three_points_sma | fields name, age, two_points_sma, three_points_sma + | """.stripMargin) + + assert(frame.columns.sameElements(Array("name", "age", "two_points_sma", "three_points_sma"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jane", 20, null, null), + Row("John", 25, 22.5, null), + Row("Hello", 30, 27.5, 25.0), + Row("Jake", 70, 50.0, 41.666666666666664)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + 
assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val nameField = UnresolvedAttribute("name") + val ageField = UnresolvedAttribute("age") + val ageTwoPointsSmaField = UnresolvedAttribute("two_points_sma") + val ageThreePointsSmaField = UnresolvedAttribute("three_points_sma") + val sort = Sort(Seq(SortOrder(ageField, Ascending)), global = true, table) + val twoPointsCountWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val twoPointsSmaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val threePointsCountWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val threePointsSmaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val twoPointsCaseWhen = CaseWhen( + Seq((LessThan(twoPointsCountWindow, Literal(2)), Literal(null))), + twoPointsSmaWindow) + val threePointsCaseWhen = CaseWhen( + Seq((LessThan(threePointsCountWindow, Literal(3)), Literal(null))), + threePointsSmaWindow) + val trendlineProjectList = Seq( + UnresolvedStar(None), + Alias(twoPointsCaseWhen, "two_points_sma")(), + Alias(threePointsCaseWhen, "three_points_sma")()) + val expectedPlan = Project( + Seq(nameField, ageField, ageTwoPointsSmaField, ageThreePointsSmaField), + Project(trendlineProjectList, sort)) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline sma command on evaluated column") { + val frame = sql(s""" + | source = $testTable | eval doubled_age = age * 2 | trendline sort + age sma(2, doubled_age) as doubled_age_sma | fields name, doubled_age, doubled_age_sma + | """.stripMargin) + + assert(frame.columns.sameElements(Array("name", "doubled_age", "doubled_age_sma"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jane", 40, null), + Row("John", 50, 45.0), + Row("Hello", 60, 55.0), + Row("Jake", 140, 100.0)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val nameField = UnresolvedAttribute("name") + val ageField = UnresolvedAttribute("age") + val doubledAgeField = UnresolvedAttribute("doubled_age") + val doubledAgeSmaField = UnresolvedAttribute("doubled_age_sma") + val evalProject = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction("*", Seq(ageField, Literal(2)), isDistinct = false), + "doubled_age")()), + table) + val sort = Sort(Seq(SortOrder(ageField, Ascending)), global = true, evalProject) + val countWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), 
SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val doubleAgeSmaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(doubledAgeField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val caseWhen = + CaseWhen(Seq((LessThan(countWindow, Literal(2)), Literal(null))), doubleAgeSmaWindow) + val trendlineProjectList = + Seq(UnresolvedStar(None), Alias(caseWhen, "doubled_age_sma")()) + val expectedPlan = Project( + Seq(nameField, doubledAgeField, doubledAgeSmaField), + Project(trendlineProjectList, sort)) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline sma command chaining") { + val frame = sql(s""" + | source = $testTable | eval age_1 = age, age_2 = age | trendline sort - age_1 sma(3, age_1) | trendline sort + age_2 sma(3, age_2) + | """.stripMargin) + + assert( + frame.columns.sameElements( + Array( + "name", + "age", + "state", + "country", + "year", + "month", + "age_1", + "age_2", + "age_1_trendline", + "age_2_trendline"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Hello", 30, "New York", "USA", 2023, 4, 30, 30, null, 25.0), + Row("Jake", 70, "California", "USA", 2023, 4, 70, 70, null, 41.666666666666664), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 20, 20, 25.0, null), + Row("John", 25, "Ontario", "Canada", 2023, 4, 25, 25, 41.666666666666664, null)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + } +} diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/tpch/TPCHQueryBase.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/tpch/TPCHQueryBase.scala new file mode 100644 index 000000000..fb14210e9 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/tpch/TPCHQueryBase.scala @@ -0,0 +1,177 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl.tpch + +import org.opensearch.flint.spark.ppl.FlintPPLSuite + +import org.apache.spark.SparkConf +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.expressions.codegen.{ByteCodeStats, CodeFormatter, CodeGenerator} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_SECOND +import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec} +import org.apache.spark.sql.internal.SQLConf + +trait TPCHQueryBase extends FlintPPLSuite { + + override protected def sparkConf: SparkConf = { + super.sparkConf.set(SQLConf.MAX_TO_STRING_FIELDS.key, Int.MaxValue.toString) + } + + override def beforeAll(): Unit = { + super.beforeAll() + RuleExecutor.resetMetrics() + CodeGenerator.resetCompileTime() + WholeStageCodegenExec.resetCodeGenTime() + tpchCreateTable.values.foreach { ppl => + sql(ppl) + } + } + + override def afterAll(): Unit = { + try { + tpchCreateTable.keys.foreach { tableName => + spark.sessionState.catalog.dropTable(TableIdentifier(tableName), true, true) + } + // For debugging dump some statistics about how much time was spent in various optimizer rules + // code generation, and compilation. 
+ logWarning(RuleExecutor.dumpTimeSpent()) + val codeGenTime = WholeStageCodegenExec.codeGenTime.toDouble / NANOS_PER_SECOND + val compileTime = CodeGenerator.compileTime.toDouble / NANOS_PER_SECOND + val codegenInfo = + s""" + |=== Metrics of Whole-stage Codegen === + |Total code generation time: $codeGenTime seconds + |Total compile time: $compileTime seconds + """.stripMargin + logWarning(codegenInfo) + spark.sessionState.catalog.reset() + } finally { + super.afterAll() + } + } + + def checkGeneratedCode(plan: SparkPlan, checkMethodCodeSize: Boolean = true): Unit = { + val codegenSubtrees = new collection.mutable.HashSet[WholeStageCodegenExec]() + + def findSubtrees(plan: SparkPlan): Unit = { + plan foreach { + case s: WholeStageCodegenExec => + codegenSubtrees += s + case s => + s.subqueries.foreach(findSubtrees) + } + } + + findSubtrees(plan) + codegenSubtrees.toSeq.foreach { subtree => + val code = subtree.doCodeGen()._2 + val (_, ByteCodeStats(maxMethodCodeSize, _, _)) = + try { + // Just check the generated code can be properly compiled + CodeGenerator.compile(code) + } catch { + case e: Exception => + val msg = + s""" + |failed to compile: + |Subtree: + |$subtree + |Generated code: + |${CodeFormatter.format(code)} + """.stripMargin + throw new Exception(msg, e) + } + + assert( + !checkMethodCodeSize || + maxMethodCodeSize <= CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT, + s"too long generated codes found in the WholeStageCodegenExec subtree (id=${subtree.id}) " + + s"and JIT optimization might not work:\n${subtree.treeString}") + } + } + + val tpchCreateTable = Map( + "orders" -> + """ + |CREATE TABLE `orders` ( + |`o_orderkey` BIGINT, `o_custkey` BIGINT, `o_orderstatus` STRING, + |`o_totalprice` DECIMAL(10,0), `o_orderdate` DATE, `o_orderpriority` STRING, + |`o_clerk` STRING, `o_shippriority` INT, `o_comment` STRING) + |USING parquet + """.stripMargin, + "nation" -> + """ + |CREATE TABLE `nation` ( + |`n_nationkey` BIGINT, `n_name` STRING, `n_regionkey` BIGINT, `n_comment` STRING) + |USING parquet + """.stripMargin, + "region" -> + """ + |CREATE TABLE `region` ( + |`r_regionkey` BIGINT, `r_name` STRING, `r_comment` STRING) + |USING parquet + """.stripMargin, + "part" -> + """ + |CREATE TABLE `part` (`p_partkey` BIGINT, `p_name` STRING, `p_mfgr` STRING, + |`p_brand` STRING, `p_type` STRING, `p_size` INT, `p_container` STRING, + |`p_retailprice` DECIMAL(10,0), `p_comment` STRING) + |USING parquet + """.stripMargin, + "partsupp" -> + """ + |CREATE TABLE `partsupp` (`ps_partkey` BIGINT, `ps_suppkey` BIGINT, + |`ps_availqty` INT, `ps_supplycost` DECIMAL(10,0), `ps_comment` STRING) + |USING parquet + """.stripMargin, + "customer" -> + """ + |CREATE TABLE `customer` (`c_custkey` BIGINT, `c_name` STRING, `c_address` STRING, + |`c_nationkey` BIGINT, `c_phone` STRING, `c_acctbal` DECIMAL(10,0), + |`c_mktsegment` STRING, `c_comment` STRING) + |USING parquet + """.stripMargin, + "supplier" -> + """ + |CREATE TABLE `supplier` (`s_suppkey` BIGINT, `s_name` STRING, `s_address` STRING, + |`s_nationkey` BIGINT, `s_phone` STRING, `s_acctbal` DECIMAL(10,0), `s_comment` STRING) + |USING parquet + """.stripMargin, + "lineitem" -> + """ + |CREATE TABLE `lineitem` (`l_orderkey` BIGINT, `l_partkey` BIGINT, `l_suppkey` BIGINT, + |`l_linenumber` INT, `l_quantity` DECIMAL(10,0), `l_extendedprice` DECIMAL(10,0), + |`l_discount` DECIMAL(10,0), `l_tax` DECIMAL(10,0), `l_returnflag` STRING, + |`l_linestatus` STRING, `l_shipdate` DATE, `l_commitdate` DATE, `l_receiptdate` DATE, + |`l_shipinstruct` STRING, 
`l_shipmode` STRING, `l_comment` STRING) + |USING parquet + """.stripMargin) + + val tpchQueries = Seq( + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15", + "q16", + "q17", + "q18", + "q19", + "q20", + "q21", + "q22") +} diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/tpch/TPCHQueryITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/tpch/TPCHQueryITSuite.scala new file mode 100644 index 000000000..1b9681618 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/tpch/TPCHQueryITSuite.scala @@ -0,0 +1,43 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl.tpch + +import org.opensearch.flint.spark.ppl.LogicalPlanTestUtils + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.util.resourceToString +import org.apache.spark.sql.streaming.StreamTest + +class TPCHQueryITSuite + extends QueryTest + with LogicalPlanTestUtils + with TPCHQueryBase + with StreamTest { + + override def beforeAll(): Unit = { + super.beforeAll() + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + tpchQueries.foreach { name => + val queryString = resourceToString( + s"tpch/$name.ppl", + classLoader = Thread.currentThread().getContextClassLoader) + test(name) { + // check the plans can be properly generated + val plan = sql(queryString).queryExecution.executedPlan + checkGeneratedCode(plan) + } + } +} diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index 82fdeb42f..2c3344b3c 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -37,6 +37,9 @@ KMEANS: 'KMEANS'; AD: 'AD'; ML: 'ML'; FILLNULL: 'FILLNULL'; +EXPAND: 'EXPAND'; +FLATTEN: 'FLATTEN'; +TRENDLINE: 'TRENDLINE'; //Native JOIN KEYWORDS JOIN: 'JOIN'; @@ -89,6 +92,9 @@ FIELDSUMMARY: 'FIELDSUMMARY'; INCLUDEFIELDS: 'INCLUDEFIELDS'; NULLS: 'NULLS'; +//TRENDLINE KEYWORDS +SMA: 'SMA'; + // ARGUMENT KEYWORDS KEEPEMPTY: 'KEEPEMPTY'; CONSECUTIVE: 'CONSECUTIVE'; @@ -198,6 +204,7 @@ RT_SQR_PRTHS: ']'; SINGLE_QUOTE: '\''; DOUBLE_QUOTE: '"'; BACKTICK: '`'; +ARROW: '->'; // Operators. 
Bit @@ -295,6 +302,7 @@ CURDATE: 'CURDATE'; CURRENT_DATE: 'CURRENT_DATE'; CURRENT_TIME: 'CURRENT_TIME'; CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP'; +CURRENT_TIMEZONE: 'CURRENT_TIMEZONE'; CURTIME: 'CURTIME'; DATE: 'DATE'; DATEDIFF: 'DATEDIFF'; @@ -371,6 +379,7 @@ JSON: 'JSON'; JSON_OBJECT: 'JSON_OBJECT'; JSON_ARRAY: 'JSON_ARRAY'; JSON_ARRAY_LENGTH: 'JSON_ARRAY_LENGTH'; +TO_JSON_STRING: 'TO_JSON_STRING'; JSON_EXTRACT: 'JSON_EXTRACT'; JSON_KEYS: 'JSON_KEYS'; JSON_VALID: 'JSON_VALID'; @@ -378,14 +387,22 @@ JSON_VALID: 'JSON_VALID'; //JSON_DELETE: 'JSON_DELETE'; //JSON_EXTEND: 'JSON_EXTEND'; //JSON_SET: 'JSON_SET'; -//JSON_ARRAY_ALL_MATCH: 'JSON_ALL_MATCH'; -//JSON_ARRAY_ANY_MATCH: 'JSON_ANY_MATCH'; -//JSON_ARRAY_FILTER: 'JSON_FILTER'; +//JSON_ARRAY_ALL_MATCH: 'JSON_ARRAY_ALL_MATCH'; +//JSON_ARRAY_ANY_MATCH: 'JSON_ARRAY_ANY_MATCH'; +//JSON_ARRAY_FILTER: 'JSON_ARRAY_FILTER'; //JSON_ARRAY_MAP: 'JSON_ARRAY_MAP'; //JSON_ARRAY_REDUCE: 'JSON_ARRAY_REDUCE'; // COLLECTION FUNCTIONS ARRAY: 'ARRAY'; +ARRAY_LENGTH: 'ARRAY_LENGTH'; + +// LAMBDA FUNCTIONS +//EXISTS: 'EXISTS'; +FORALL: 'FORALL'; +FILTER: 'FILTER'; +TRANSFORM: 'TRANSFORM'; +REDUCE: 'REDUCE'; // BOOL FUNCTIONS LIKE: 'LIKE'; @@ -393,6 +410,7 @@ ISNULL: 'ISNULL'; ISNOTNULL: 'ISNOTNULL'; ISPRESENT: 'ISPRESENT'; BETWEEN: 'BETWEEN'; +CIDRMATCH: 'CIDRMATCH'; // FLOWCONTROL FUNCTIONS IFNULL: 'IFNULL'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 48984b3a5..1cfd172f7 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -53,6 +53,9 @@ commands | renameCommand | fillnullCommand | fieldsummaryCommand + | flattenCommand + | expandCommand + | trendlineCommand ; commandName @@ -80,8 +83,11 @@ commandName | PATTERNS | LOOKUP | RENAME + | EXPAND | FILLNULL | FIELDSUMMARY + | FLATTEN + | TRENDLINE ; searchCommand @@ -89,7 +95,7 @@ searchCommand | (SEARCH)? fromClause logicalExpression # searchFromFilter | (SEARCH)? logicalExpression fromClause # searchFilterFrom ; - + fieldsummaryCommand : FIELDSUMMARY (fieldsummaryParameter)* ; @@ -246,6 +252,25 @@ fillnullCommand : expression ; +expandCommand + : EXPAND fieldExpression (AS alias = qualifiedName)? + ; + +flattenCommand + : FLATTEN fieldExpression + ; + +trendlineCommand + : TRENDLINE (SORT sortField)? trendlineClause (trendlineClause)* + ; + +trendlineClause + : trendlineType LT_PRTHS numberOfDataPoints = integerLiteral COMMA field = fieldExpression RT_PRTHS (AS alias = qualifiedName)? + ; + +trendlineType + : SMA + ; kmeansCommand : KMEANS (kmeansParameter)* @@ -320,7 +345,7 @@ joinType ; sideAlias - : LEFT EQUAL leftAlias = ident COMMA? RIGHT EQUAL rightAlias = ident + : (LEFT EQUAL leftAlias = ident)? COMMA? (RIGHT EQUAL rightAlias = ident)? ; joinCriteria @@ -421,8 +446,11 @@ valueExpression | primaryExpression # valueExpressionDefault | positionFunction # positionFunctionCall | caseFunction # caseExpr + | timestampFunction # timestampFunctionCall | LT_PRTHS valueExpression RT_PRTHS # parentheticValueExpr | LT_SQR_PRTHS subSearch RT_SQR_PRTHS # scalarSubqueryExpr + | ident ARROW expression # lambda + | LT_PRTHS ident (COMMA ident)+ RT_PRTHS ARROW expression # lambda ; primaryExpression @@ -440,6 +468,7 @@ booleanExpression | isEmptyExpression # isEmptyExpr | valueExpressionList NOT? 
IN LT_SQR_PRTHS subSearch RT_SQR_PRTHS # inSubqueryExpr | EXISTS LT_SQR_PRTHS subSearch RT_SQR_PRTHS # existsSubqueryExpr + | cidrMatchFunctionCall # cidrFunctionCallExpr ; isEmptyExpression @@ -519,6 +548,10 @@ booleanFunctionCall : conditionFunctionBase LT_PRTHS functionArgs RT_PRTHS ; +cidrMatchFunctionCall + : CIDRMATCH LT_PRTHS ipAddress = functionArg COMMA cidrBlock = functionArg RT_PRTHS + ; + convertedDataType : typeName = DATE | typeName = TIME @@ -543,6 +576,7 @@ evalFunctionName | cryptographicFunctionName | jsonFunctionName | collectionFunctionName + | lambdaFunctionName ; functionArgs @@ -672,6 +706,7 @@ dateTimeFunctionName | CURRENT_DATE | CURRENT_TIME | CURRENT_TIMESTAMP + | CURRENT_TIMEZONE | CURTIME | DATE | DATEDIFF @@ -831,6 +866,7 @@ jsonFunctionName | JSON_OBJECT | JSON_ARRAY | JSON_ARRAY_LENGTH + | TO_JSON_STRING | JSON_EXTRACT | JSON_KEYS | JSON_VALID @@ -847,6 +883,15 @@ jsonFunctionName collectionFunctionName : ARRAY + | ARRAY_LENGTH + ; + +lambdaFunctionName + : FORALL + | EXISTS + | FILTER + | TRANSFORM + | REDUCE ; positionFunctionName @@ -888,6 +933,7 @@ literalValue | decimalLiteral | booleanLiteral | datetimeLiteral //#datetime + | intervalLiteral ; intervalLiteral @@ -1116,4 +1162,6 @@ keywordsCanBeId | SEMI | ANTI | BETWEEN + | CIDRMATCH + | trendlineType ; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index e1397a754..54e1205cb 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -13,10 +13,12 @@ import org.opensearch.sql.ast.expression.AttributeList; import org.opensearch.sql.ast.expression.Between; import org.opensearch.sql.ast.expression.Case; +import org.opensearch.sql.ast.expression.Cidr; import org.opensearch.sql.ast.expression.Compare; import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.FieldList; +import org.opensearch.sql.ast.expression.LambdaFunction; import org.opensearch.sql.ast.tree.FieldSummary; import org.opensearch.sql.ast.expression.FieldsMapping; import org.opensearch.sql.ast.expression.Function; @@ -106,10 +108,18 @@ public T visitFilter(Filter node, C context) { return visitChildren(node, context); } + public T visitExpand(Expand node, C context) { + return visitChildren(node, context); + } + public T visitLookup(Lookup node, C context) { return visitChildren(node, context); } + public T visitTrendline(Trendline node, C context) { + return visitChildren(node, context); + } + public T visitCorrelation(Correlation node, C context) { return visitChildren(node, context); } @@ -178,6 +188,10 @@ public T visitFunction(Function node, C context) { return visitChildren(node, context); } + public T visitLambdaFunction(LambdaFunction node, C context) { + return visitChildren(node, context); + } + public T visitIsEmpty(IsEmpty node, C context) { return visitChildren(node, context); } @@ -322,4 +336,11 @@ public T visitExistsSubquery(ExistsSubquery node, C context) { public T visitWindow(Window node, C context) { return visitChildren(node, context); } + public T visitCidr(Cidr node, C context) { + return visitChildren(node, context); + } + + public T visitFlatten(Flatten flatten, C context) { + return visitChildren(flatten, context); + } } diff --git 
a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Cidr.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Cidr.java new file mode 100644 index 000000000..fdbb3ef65 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Cidr.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import lombok.AllArgsConstructor; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; + +import java.util.Arrays; +import java.util.List; + +/** AST node that represents CIDR function. */ +@AllArgsConstructor +@Getter +@EqualsAndHashCode(callSuper = false) +@ToString +public class Cidr extends UnresolvedExpression { + private UnresolvedExpression ipAddress; + private UnresolvedExpression cidrBlock; + + @Override + public List getChild() { + return Arrays.asList(ipAddress, cidrBlock); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitCidr(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Interval.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Interval.java index fc09ec2f5..bf00b2106 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Interval.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Interval.java @@ -23,11 +23,6 @@ public class Interval extends UnresolvedExpression { private final UnresolvedExpression value; private final IntervalUnit unit; - public Interval(UnresolvedExpression value, String unit) { - this.value = value; - this.unit = IntervalUnit.of(unit); - } - @Override public List getChild() { return Collections.singletonList(value); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/IntervalUnit.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/IntervalUnit.java index a7e983473..6e1e0712c 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/IntervalUnit.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/IntervalUnit.java @@ -17,6 +17,7 @@ public enum IntervalUnit { UNKNOWN, MICROSECOND, + MILLISECOND, SECOND, MINUTE, HOUR, diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/LambdaFunction.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/LambdaFunction.java new file mode 100644 index 000000000..e1ee755b8 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/LambdaFunction.java @@ -0,0 +1,48 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.ast.AbstractNodeVisitor; + +/** + * Expression node of lambda function. 
Params include function name (@funcName) and function + * arguments (@funcArgs) + */ +@Getter +@EqualsAndHashCode(callSuper = false) +@RequiredArgsConstructor +public class LambdaFunction extends UnresolvedExpression { + private final UnresolvedExpression function; + private final List funcArgs; + + @Override + public List getChild() { + List children = new ArrayList<>(); + children.add(function); + children.addAll(funcArgs); + return children; + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitLambdaFunction(this, context); + } + + @Override + public String toString() { + return String.format( + "(%s) -> %s", + funcArgs.stream().map(Object::toString).collect(Collectors.joining(", ")), + function.toString() + ); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/DescribeRelation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/DescribeRelation.java index b513d01bf..dd9947329 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/DescribeRelation.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/DescribeRelation.java @@ -8,12 +8,14 @@ import lombok.ToString; import org.opensearch.sql.ast.expression.UnresolvedExpression; +import java.util.Collections; + /** * Extend Relation to describe the table itself */ @ToString public class DescribeRelation extends Relation{ public DescribeRelation(UnresolvedExpression tableName) { - super(tableName); + super(Collections.singletonList(tableName)); } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Expand.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Expand.java new file mode 100644 index 000000000..0e164ccd7 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Expand.java @@ -0,0 +1,44 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; +import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.UnresolvedAttribute; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +import java.util.List; +import java.util.Optional; + +/** Logical plan node of Expand */ +@RequiredArgsConstructor +public class Expand extends UnresolvedPlan { + private UnresolvedPlan child; + + @Getter + private final Field field; + @Getter + private final Optional alias; + + @Override + public Expand attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public List getChild() { + return child == null ? 
List.of() : List.of(child); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitExpand(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Flatten.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Flatten.java new file mode 100644 index 000000000..9c57d2adf --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Flatten.java @@ -0,0 +1,34 @@ +package org.opensearch.sql.ast.tree; + +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; +import org.opensearch.sql.ast.expression.Field; + +import java.util.List; + +@RequiredArgsConstructor +public class Flatten extends UnresolvedPlan { + + private UnresolvedPlan child; + + @Getter + private final Field field; + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public List getChild() { + return child == null ? List.of() : List.of(child); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitFlatten(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Join.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Join.java index 89f787d34..176902911 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Join.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Join.java @@ -25,15 +25,15 @@ public class Join extends UnresolvedPlan { private UnresolvedPlan left; private final UnresolvedPlan right; - private final String leftAlias; - private final String rightAlias; + private final Optional leftAlias; + private final Optional rightAlias; private final JoinType joinType; private final Optional joinCondition; private final JoinHint joinHint; @Override public UnresolvedPlan attach(UnresolvedPlan child) { - this.left = new SubqueryAlias(leftAlias, child); + this.left = leftAlias.isEmpty() ? child : new SubqueryAlias(leftAlias.get(), child); return this; } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java index 1b30a7998..d8ea104a4 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java @@ -6,53 +6,34 @@ package org.opensearch.sql.ast.tree; import com.google.common.collect.ImmutableList; -import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; -import lombok.Setter; import lombok.ToString; import org.opensearch.sql.ast.AbstractNodeVisitor; import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.expression.UnresolvedExpression; -import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; /** Logical plan node of Relation, the interface for building the searching sources. 
*/ -@AllArgsConstructor @ToString +@Getter @EqualsAndHashCode(callSuper = false) @RequiredArgsConstructor public class Relation extends UnresolvedPlan { private static final String COMMA = ","; - private final List tableName; - - public Relation(UnresolvedExpression tableName) { - this(tableName, null); - } - - public Relation(UnresolvedExpression tableName, String alias) { - this.tableName = Arrays.asList(tableName); - this.alias = alias; - } - - /** Optional alias name for the relation. */ - @Setter @Getter private String alias; - - /** - * Return table name. - * - * @return table name - */ - public List getTableName() { - return tableName.stream().map(Object::toString).collect(Collectors.toList()); - } + // A relation could contain more than one table/index names, such as + // source=account1, account2 + // source=`account1`,`account2` + // source=`account*` + // They translated into union call with fields. + private final List tableNames; public List getQualifiedNames() { - return tableName.stream().map(t -> (QualifiedName) t).collect(Collectors.toList()); + return tableNames.stream().map(t -> (QualifiedName) t).collect(Collectors.toList()); } /** @@ -63,11 +44,11 @@ public List getQualifiedNames() { * @return TableQualifiedName. */ public QualifiedName getTableQualifiedName() { - if (tableName.size() == 1) { - return (QualifiedName) tableName.get(0); + if (tableNames.size() == 1) { + return (QualifiedName) tableNames.get(0); } else { return new QualifiedName( - tableName.stream() + tableNames.stream() .map(UnresolvedExpression::toString) .collect(Collectors.joining(COMMA))); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/SubqueryAlias.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/SubqueryAlias.java index 29c3d4b90..ba66cca80 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/SubqueryAlias.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/SubqueryAlias.java @@ -6,19 +6,14 @@ package org.opensearch.sql.ast.tree; import com.google.common.collect.ImmutableList; -import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; -import lombok.RequiredArgsConstructor; import lombok.ToString; import org.opensearch.sql.ast.AbstractNodeVisitor; import java.util.List; -import java.util.Objects; -@AllArgsConstructor @EqualsAndHashCode(callSuper = false) -@RequiredArgsConstructor @ToString public class SubqueryAlias extends UnresolvedPlan { @Getter private final String alias; @@ -32,6 +27,11 @@ public SubqueryAlias(UnresolvedPlan child, String suffix) { this.child = child; } + public SubqueryAlias(String alias, UnresolvedPlan child) { + this.alias = alias; + this.child = child; + } + public List getChild() { return ImmutableList.of(child); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java new file mode 100644 index 000000000..9fa1ae81d --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; +import 
org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +import java.util.List; +import java.util.Optional; + +@ToString +@Getter +@RequiredArgsConstructor +@EqualsAndHashCode(callSuper = false) +public class Trendline extends UnresolvedPlan { + + private UnresolvedPlan child; + private final Optional sortByField; + private final List computations; + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public List getChild() { + return ImmutableList.of(child); + } + + @Override + public T accept(AbstractNodeVisitor visitor, C context) { + return visitor.visitTrendline(this, context); + } + + @Getter + public static class TrendlineComputation { + + private final Integer numberOfDataPoints; + private final UnresolvedExpression dataField; + private final String alias; + private final TrendlineType computationType; + + public TrendlineComputation(Integer numberOfDataPoints, UnresolvedExpression dataField, String alias, Trendline.TrendlineType computationType) { + this.numberOfDataPoints = numberOfDataPoints; + this.dataField = dataField; + this.alias = alias; + this.computationType = computationType; + } + + } + + public enum TrendlineType { + SMA + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index 9e1a9a743..1959d0f6d 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -64,9 +64,9 @@ public enum BuiltinFunctionName { DATE(FunctionName.of("date")), DATEDIFF(FunctionName.of("datediff")), // DATETIME(FunctionName.of("datetime")), -// DATE_ADD(FunctionName.of("date_add")), + DATE_ADD(FunctionName.of("date_add")), DATE_FORMAT(FunctionName.of("date_format")), -// DATE_SUB(FunctionName.of("date_sub")), + DATE_SUB(FunctionName.of("date_sub")), DAY(FunctionName.of("day")), // DAYNAME(FunctionName.of("dayname")), DAYOFMONTH(FunctionName.of("dayofmonth")), @@ -105,14 +105,15 @@ public enum BuiltinFunctionName { // TIMEDIFF(FunctionName.of("timediff")), // TIME_TO_SEC(FunctionName.of("time_to_sec")), TIMESTAMP(FunctionName.of("timestamp")), -// TIMESTAMPADD(FunctionName.of("timestampadd")), -// TIMESTAMPDIFF(FunctionName.of("timestampdiff")), + TIMESTAMPADD(FunctionName.of("timestampadd")), + TIMESTAMPDIFF(FunctionName.of("timestampdiff")), // TIME_FORMAT(FunctionName.of("time_format")), // TO_DAYS(FunctionName.of("to_days")), // TO_SECONDS(FunctionName.of("to_seconds")), // UTC_DATE(FunctionName.of("utc_date")), // UTC_TIME(FunctionName.of("utc_time")), -// UTC_TIMESTAMP(FunctionName.of("utc_timestamp")), + UTC_TIMESTAMP(FunctionName.of("utc_timestamp")), + CURRENT_TIMEZONE(FunctionName.of("current_timezone")), UNIX_TIMESTAMP(FunctionName.of("unix_timestamp")), WEEK(FunctionName.of("week")), WEEKDAY(FunctionName.of("weekday")), @@ -212,6 +213,7 @@ public enum BuiltinFunctionName { JSON_OBJECT(FunctionName.of("json_object")), JSON_ARRAY(FunctionName.of("json_array")), JSON_ARRAY_LENGTH(FunctionName.of("json_array_length")), + TO_JSON_STRING(FunctionName.of("to_json_string")), JSON_EXTRACT(FunctionName.of("json_extract")), JSON_KEYS(FunctionName.of("json_keys")), JSON_VALID(FunctionName.of("json_valid")), @@ -227,6 +229,14 @@ public enum 
BuiltinFunctionName { /** COLLECTION Functions **/ ARRAY(FunctionName.of("array")), + ARRAY_LENGTH(FunctionName.of("array_length")), + + /** LAMBDA Functions **/ + ARRAY_FORALL(FunctionName.of("forall")), + ARRAY_EXISTS(FunctionName.of("exists")), + ARRAY_FILTER(FunctionName.of("filter")), + ARRAY_TRANSFORM(FunctionName.of("transform")), + ARRAY_AGGREGATE(FunctionName.of("reduce")), /** NULL Test. */ IS_NULL(FunctionName.of("is null")), diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java new file mode 100644 index 000000000..2541b3743 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import inet.ipaddr.AddressStringException; +import inet.ipaddr.IPAddressString; +import inet.ipaddr.IPAddressStringParameters; +import scala.Function2; +import scala.Serializable; +import scala.runtime.AbstractFunction2; + + +public interface SerializableUdf { + + Function2 cidrFunction = new SerializableAbstractFunction2<>() { + + IPAddressStringParameters valOptions = new IPAddressStringParameters.Builder() + .allowEmpty(false) + .setEmptyAsLoopback(false) + .allow_inet_aton(false) + .allowSingleSegment(false) + .toParams(); + + @Override + public Boolean apply(String ipAddress, String cidrBlock) { + + IPAddressString parsedIpAddress = new IPAddressString(ipAddress, valOptions); + + try { + parsedIpAddress.validate(); + } catch (AddressStringException e) { + throw new RuntimeException("The given ipAddress '"+ipAddress+"' is invalid. It must be a valid IPv4 or IPv6 address. Error details: "+e.getMessage()); + } + + IPAddressString parsedCidrBlock = new IPAddressString(cidrBlock, valOptions); + + try { + parsedCidrBlock.validate(); + } catch (AddressStringException e) { + throw new RuntimeException("The given cidrBlock '"+cidrBlock+"' is invalid. It must be a valid CIDR or netmask. Error details: "+e.getMessage()); + } + + if(parsedIpAddress.isIPv4() && parsedCidrBlock.isIPv6() || parsedIpAddress.isIPv6() && parsedCidrBlock.isIPv4()) { + throw new RuntimeException("The given ipAddress '"+ipAddress+"' and cidrBlock '"+cidrBlock+"' are not compatible. 
Both must be either IPv4 or IPv6."); + } + + return parsedCidrBlock.contains(parsedIpAddress); + } + }; + + abstract class SerializableAbstractFunction2 extends AbstractFunction2 + implements Serializable { + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java new file mode 100644 index 000000000..4c8d117b3 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java @@ -0,0 +1,475 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl; + +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute; +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$; +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; +import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; +import org.apache.spark.sql.catalyst.expressions.CaseWhen; +import org.apache.spark.sql.catalyst.expressions.CurrentRow$; +import org.apache.spark.sql.catalyst.expressions.Exists$; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual; +import org.apache.spark.sql.catalyst.expressions.In$; +import org.apache.spark.sql.catalyst.expressions.InSubquery$; +import org.apache.spark.sql.catalyst.expressions.LambdaFunction$; +import org.apache.spark.sql.catalyst.expressions.LessThan; +import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual; +import org.apache.spark.sql.catalyst.expressions.ListQuery$; +import org.apache.spark.sql.catalyst.expressions.MakeInterval$; +import org.apache.spark.sql.catalyst.expressions.NamedExpression; +import org.apache.spark.sql.catalyst.expressions.Predicate; +import org.apache.spark.sql.catalyst.expressions.RowFrame$; +import org.apache.spark.sql.catalyst.expressions.ScalaUDF; +import org.apache.spark.sql.catalyst.expressions.ScalarSubquery$; +import org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable; +import org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable$; +import org.apache.spark.sql.catalyst.expressions.SpecifiedWindowFrame; +import org.apache.spark.sql.catalyst.expressions.WindowExpression; +import org.apache.spark.sql.catalyst.expressions.WindowSpecDefinition; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.types.DataTypes; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.AggregateFunction; +import org.opensearch.sql.ast.expression.Alias; +import org.opensearch.sql.ast.expression.AllFields; +import org.opensearch.sql.ast.expression.And; +import org.opensearch.sql.ast.expression.Between; +import org.opensearch.sql.ast.expression.BinaryExpression; +import org.opensearch.sql.ast.expression.Case; +import org.opensearch.sql.ast.expression.Compare; +import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.FieldsMapping; +import org.opensearch.sql.ast.expression.Function; +import org.opensearch.sql.ast.expression.In; +import org.opensearch.sql.ast.expression.Interval; +import org.opensearch.sql.ast.expression.IsEmpty; +import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.ast.expression.Not; +import org.opensearch.sql.ast.expression.Or; +import org.opensearch.sql.ast.expression.LambdaFunction; +import 
org.opensearch.sql.ast.expression.QualifiedName; +import org.opensearch.sql.ast.expression.Span; +import org.opensearch.sql.ast.expression.UnresolvedExpression; +import org.opensearch.sql.ast.expression.When; +import org.opensearch.sql.ast.expression.WindowFunction; +import org.opensearch.sql.ast.expression.Xor; +import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; +import org.opensearch.sql.ast.expression.subquery.InSubquery; +import org.opensearch.sql.ast.expression.subquery.ScalarSubquery; +import org.opensearch.sql.ast.tree.Dedupe; +import org.opensearch.sql.ast.tree.Eval; +import org.opensearch.sql.ast.tree.FillNull; +import org.opensearch.sql.ast.tree.Kmeans; +import org.opensearch.sql.ast.tree.RareTopN; +import org.opensearch.sql.ast.tree.Trendline; +import org.opensearch.sql.ast.tree.UnresolvedPlan; +import org.opensearch.sql.expression.function.BuiltinFunctionName; +import org.opensearch.sql.expression.function.SerializableUdf; +import org.opensearch.sql.ppl.utils.AggregatorTransformer; +import org.opensearch.sql.ppl.utils.BuiltinFunctionTransformer; +import org.opensearch.sql.ppl.utils.ComparatorTransformer; +import org.opensearch.sql.ppl.utils.JavaToScalaTransformer; +import scala.Option; +import scala.PartialFunction; +import scala.Tuple2; +import scala.collection.Seq; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.Stack; +import java.util.function.BiFunction; +import java.util.stream.Collectors; + +import static java.util.Collections.emptyList; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.EQUAL; +import static org.opensearch.sql.ppl.CatalystPlanContext.findRelation; +import static org.opensearch.sql.ppl.utils.BuiltinFunctionTransformer.createIntervalArgs; +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.translate; +import static org.opensearch.sql.ppl.utils.RelationUtils.resolveField; +import static org.opensearch.sql.ppl.utils.WindowSpecTransformer.window; + +/** + * Class of building catalyst AST Expression nodes. + */ +public class CatalystExpressionVisitor extends AbstractNodeVisitor { + + private final AbstractNodeVisitor planVisitor; + + public CatalystExpressionVisitor(AbstractNodeVisitor planVisitor) { + this.planVisitor = planVisitor; + } + + public Expression analyze(UnresolvedExpression unresolved, CatalystPlanContext context) { + return unresolved.accept(this, context); + } + + /** This method is only for analyze the join condition expression */ + public Expression analyzeJoinCondition(UnresolvedExpression unresolved, CatalystPlanContext context) { + return context.resolveJoinCondition(unresolved, this::analyze); + } + + @Override + public Expression visitLiteral(Literal node, CatalystPlanContext context) { + return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Literal( + translate(node.getValue(), node.getType()), translate(node.getType()))); + } + + /** + * generic binary (And, Or, Xor , ...) 
arithmetic expression resolver + * + * @param node + * @param transformer + * @param context + * @return + */ + public Expression visitBinaryArithmetic(BinaryExpression node, BiFunction transformer, CatalystPlanContext context) { + node.getLeft().accept(this, context); + Optional left = context.popNamedParseExpressions(); + node.getRight().accept(this, context); + Optional right = context.popNamedParseExpressions(); + if (left.isPresent() && right.isPresent()) { + return transformer.apply(left.get(), right.get()); + } else if (left.isPresent()) { + return context.getNamedParseExpressions().push(left.get()); + } else if (right.isPresent()) { + return context.getNamedParseExpressions().push(right.get()); + } + return null; + + } + + @Override + public Expression visitAnd(And node, CatalystPlanContext context) { + return visitBinaryArithmetic(node, + (left, right) -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.And(left, right)), context); + } + + @Override + public Expression visitOr(Or node, CatalystPlanContext context) { + return visitBinaryArithmetic(node, + (left, right) -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Or(left, right)), context); + } + + @Override + public Expression visitXor(Xor node, CatalystPlanContext context) { + return visitBinaryArithmetic(node, + (left, right) -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.BitwiseXor(left, right)), context); + } + + @Override + public Expression visitNot(Not node, CatalystPlanContext context) { + node.getExpression().accept(this, context); + Optional arg = context.popNamedParseExpressions(); + return arg.map(expression -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Not(expression))).orElse(null); + } + + @Override + public Expression visitSpan(Span node, CatalystPlanContext context) { + node.getField().accept(this, context); + Expression field = (Expression) context.popNamedParseExpressions().get(); + node.getValue().accept(this, context); + Expression value = (Expression) context.popNamedParseExpressions().get(); + return context.getNamedParseExpressions().push(window(field, value, node.getUnit())); + } + + @Override + public Expression visitAggregateFunction(AggregateFunction node, CatalystPlanContext context) { + node.getField().accept(this, context); + Expression arg = (Expression) context.popNamedParseExpressions().get(); + Expression aggregator = AggregatorTransformer.aggregator(node, arg); + return context.getNamedParseExpressions().push(aggregator); + } + + @Override + public Expression visitCompare(Compare node, CatalystPlanContext context) { + analyze(node.getLeft(), context); + Optional left = context.popNamedParseExpressions(); + analyze(node.getRight(), context); + Optional right = context.popNamedParseExpressions(); + if (left.isPresent() && right.isPresent()) { + Predicate comparator = ComparatorTransformer.comparator(node, left.get(), right.get()); + return context.getNamedParseExpressions().push((org.apache.spark.sql.catalyst.expressions.Expression) comparator); + } + return null; + } + + @Override + public Expression visitQualifiedName(QualifiedName node, CatalystPlanContext context) { + // When the qualified name is part of join condition, for example: table1.id = table2.id + // findRelation(context.traversalContext() only returns relation table1 which cause table2.id fail to resolve + if (context.isResolvingJoinCondition()) { + return 
context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts()))); + } + List relation = findRelation(context.traversalContext()); + if (!relation.isEmpty()) { + Optional resolveField = resolveField(relation, node, context.getRelations()); + return resolveField.map(qualifiedName -> context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(qualifiedName.getParts())))) + .orElse(resolveQualifiedNameWithSubqueryAlias(node, context)); + } + return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts()))); + } + + /** + * Resolve the qualified name with subquery alias:
+ * - subqueryAlias1.joinKey = subqueryAlias2.joinKey
+ * - tableName1.joinKey = subqueryAlias2.joinKey
+ * - subqueryAlias1.joinKey = tableName2.joinKey
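As a rough illustration of the three prefix forms listed above: a qualified name such as `t2.joinKey` is pushed as an unresolved attribute whenever its prefix matches a known subquery alias or a relation's table name (case-insensitively); otherwise resolution falls through and returns null. The snippet below is hypothetical, simplified Java; the real code inspects `SubqueryAlias.alias()` and `UnresolvedRelation.tableName()` on the plan context.

```java
import java.util.List;

// Illustrative only: does a qualified-name prefix (e.g. "t2" in `t2.joinKey`) refer to a
// known subquery alias or to a relation name? Case-insensitive, like the real checks.
public class PrefixResolutionSketch {
    static boolean prefixResolves(String prefix, List<String> subqueryAliases, List<String> tableNames) {
        return subqueryAliases.stream().anyMatch(a -> a.equalsIgnoreCase(prefix))
                || tableNames.stream().anyMatch(t -> t.equalsIgnoreCase(prefix));
    }

    public static void main(String[] args) {
        List<String> aliases = List.of("t1");               // e.g. an aliased left-side table
        List<String> tables = List.of("table1", "table2");
        System.out.println(prefixResolves("t1", aliases, tables));     // true  -> keep as unresolved attribute
        System.out.println(prefixResolves("table2", aliases, tables)); // true  -> keep as unresolved attribute
        System.out.println(prefixResolves("t3", aliases, tables));     // false -> fall through (null)
    }
}
```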
+ */ + private Expression resolveQualifiedNameWithSubqueryAlias(QualifiedName node, CatalystPlanContext context) { + if (node.getPrefix().isPresent() && + context.traversalContext().peek() instanceof org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias) { + if (context.getSubqueryAlias().stream().map(p -> (org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias) p) + .anyMatch(a -> a.alias().equalsIgnoreCase(node.getPrefix().get().toString()))) { + return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts()))); + } else if (context.getRelations().stream().map(p -> (UnresolvedRelation) p) + .anyMatch(a -> a.tableName().equalsIgnoreCase(node.getPrefix().get().toString()))) { + return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts()))); + } + } + return null; + } + + @Override + public Expression visitCorrelationMapping(FieldsMapping node, CatalystPlanContext context) { + return node.getChild().stream().map(expression -> + visitCompare((Compare) expression, context) + ).reduce(org.apache.spark.sql.catalyst.expressions.And::new).orElse(null); + } + + @Override + public Expression visitAllFields(AllFields node, CatalystPlanContext context) { + context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); + return context.getNamedParseExpressions().peek(); + } + + @Override + public Expression visitAlias(Alias node, CatalystPlanContext context) { + node.getDelegated().accept(this, context); + Expression arg = context.popNamedParseExpressions().get(); + return context.getNamedParseExpressions().push( + org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(arg, + node.getAlias() != null ? node.getAlias() : node.getName(), + NamedExpression.newExprId(), + seq(new java.util.ArrayList()), + Option.empty(), + seq(new java.util.ArrayList()))); + } + + @Override + public Expression visitEval(Eval node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : Eval"); + } + + @Override + public Expression visitFunction(Function node, CatalystPlanContext context) { + List arguments = + node.getFuncArgs().stream() + .map( + unresolvedExpression -> { + var ret = analyze(unresolvedExpression, context); + if (ret == null) { + throw new UnsupportedOperationException( + String.format("Invalid use of expression %s", unresolvedExpression)); + } else { + return context.popNamedParseExpressions().get(); + } + }) + .collect(Collectors.toList()); + Expression function = BuiltinFunctionTransformer.builtinFunction(node, arguments); + return context.getNamedParseExpressions().push(function); + } + + @Override + public Expression visitIsEmpty(IsEmpty node, CatalystPlanContext context) { + Stack namedParseExpressions = new Stack<>(); + namedParseExpressions.addAll(context.getNamedParseExpressions()); + Expression expression = visitCase(node.getCaseValue(), context); + namedParseExpressions.add(expression); + context.setNamedParseExpressions(namedParseExpressions); + return expression; + } + + @Override + public Expression visitFillNull(FillNull fillNull, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : FillNull"); + } + + @Override + public Expression visitInterval(Interval node, CatalystPlanContext context) { + node.getValue().accept(this, context); + Expression value = context.getNamedParseExpressions().pop(); + Expression[] intervalArgs = createIntervalArgs(node.getUnit(), value); + Expression interval = 
MakeInterval$.MODULE$.apply( + intervalArgs[0], intervalArgs[1], intervalArgs[2], intervalArgs[3], + intervalArgs[4], intervalArgs[5], intervalArgs[6], true); + return context.getNamedParseExpressions().push(interval); + } + + @Override + public Expression visitDedupe(Dedupe node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : Dedupe"); + } + + @Override + public Expression visitIn(In node, CatalystPlanContext context) { + node.getField().accept(this, context); + Expression value = context.popNamedParseExpressions().get(); + List list = node.getValueList().stream().map( expression -> { + expression.accept(this, context); + return context.popNamedParseExpressions().get(); + }).collect(Collectors.toList()); + return context.getNamedParseExpressions().push(In$.MODULE$.apply(value, seq(list))); + } + + @Override + public Expression visitKmeans(Kmeans node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : Kmeans"); + } + + @Override + public Expression visitCase(Case node, CatalystPlanContext context) { + Stack initialNameExpressions = new Stack<>(); + initialNameExpressions.addAll(context.getNamedParseExpressions()); + analyze(node.getElseClause(), context); + Expression elseValue = context.getNamedParseExpressions().pop(); + List> whens = new ArrayList<>(); + for (When when : node.getWhenClauses()) { + if (node.getCaseValue() == null) { + whens.add( + new Tuple2<>( + analyze(when.getCondition(), context), + analyze(when.getResult(), context) + ) + ); + } else { + // Merge case value and condition (compare value) into a single equal condition + Compare compare = new Compare(EQUAL.getName().getFunctionName(), node.getCaseValue(), when.getCondition()); + whens.add( + new Tuple2<>( + analyze(compare, context), analyze(when.getResult(), context) + ) + ); + } + context.retainAllNamedParseExpressions(e -> e); + } + context.setNamedParseExpressions(initialNameExpressions); + return context.getNamedParseExpressions().push(new CaseWhen(seq(whens), Option.apply(elseValue))); + } + + @Override + public Expression visitRareTopN(RareTopN node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : RareTopN"); + } + + @Override + public Expression visitWindowFunction(WindowFunction node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : WindowFunction"); + } + + @Override + public Expression visitInSubquery(InSubquery node, CatalystPlanContext outerContext) { + CatalystPlanContext innerContext = new CatalystPlanContext(); + visitExpressionList(node.getChild(), innerContext); + Seq values = innerContext.retainAllNamedParseExpressions(p -> p); + UnresolvedPlan outerPlan = node.getQuery(); + LogicalPlan subSearch = outerPlan.accept(planVisitor, innerContext); + Expression inSubQuery = InSubquery$.MODULE$.apply( + values, + ListQuery$.MODULE$.apply( + subSearch, + seq(new java.util.ArrayList()), + NamedExpression.newExprId(), + -1, + seq(new java.util.ArrayList()), + Option.empty())); + return outerContext.getNamedParseExpressions().push(inSubQuery); + } + + @Override + public Expression visitScalarSubquery(ScalarSubquery node, CatalystPlanContext context) { + CatalystPlanContext innerContext = new CatalystPlanContext(); + UnresolvedPlan outerPlan = node.getQuery(); + LogicalPlan subSearch = outerPlan.accept(planVisitor, innerContext); + Expression scalarSubQuery = ScalarSubquery$.MODULE$.apply( + subSearch, + seq(new java.util.ArrayList()), + 
NamedExpression.newExprId(), + seq(new java.util.ArrayList()), + Option.empty(), + Option.empty()); + return context.getNamedParseExpressions().push(scalarSubQuery); + } + + @Override + public Expression visitExistsSubquery(ExistsSubquery node, CatalystPlanContext context) { + CatalystPlanContext innerContext = new CatalystPlanContext(); + UnresolvedPlan outerPlan = node.getQuery(); + LogicalPlan subSearch = outerPlan.accept(planVisitor, innerContext); + Expression existsSubQuery = Exists$.MODULE$.apply( + subSearch, + seq(new java.util.ArrayList()), + NamedExpression.newExprId(), + seq(new java.util.ArrayList()), + Option.empty()); + return context.getNamedParseExpressions().push(existsSubQuery); + } + + @Override + public Expression visitBetween(Between node, CatalystPlanContext context) { + Expression value = analyze(node.getValue(), context); + Expression lower = analyze(node.getLowerBound(), context); + Expression upper = analyze(node.getUpperBound(), context); + context.retainAllNamedParseExpressions(p -> p); + return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.And(new GreaterThanOrEqual(value, lower), new LessThanOrEqual(value, upper))); + } + + @Override + public Expression visitCidr(org.opensearch.sql.ast.expression.Cidr node, CatalystPlanContext context) { + analyze(node.getIpAddress(), context); + Expression ipAddressExpression = context.getNamedParseExpressions().pop(); + analyze(node.getCidrBlock(), context); + Expression cidrBlockExpression = context.getNamedParseExpressions().pop(); + + ScalaUDF udf = new ScalaUDF(SerializableUdf.cidrFunction, + DataTypes.BooleanType, + seq(ipAddressExpression,cidrBlockExpression), + seq(), + Option.empty(), + Option.apply("cidr"), + false, + true); + + return context.getNamedParseExpressions().push(udf); + } + + @Override + public Expression visitLambdaFunction(LambdaFunction node, CatalystPlanContext context) { + PartialFunction transformer = JavaToScalaTransformer.toPartialFunction( + expr -> expr instanceof UnresolvedAttribute, + expr -> { + UnresolvedAttribute attr = (UnresolvedAttribute) expr; + return new UnresolvedNamedLambdaVariable(attr.nameParts()); + } + ); + Expression functionResult = node.getFunction().accept(this, context).transformUp(transformer); + context.popNamedParseExpressions(); + List argsResult = node.getFuncArgs().stream() + .map(arg -> UnresolvedNamedLambdaVariable$.MODULE$.apply(seq(arg.getParts()))) + .collect(Collectors.toList()); + return context.getNamedParseExpressions().push(LambdaFunction$.MODULE$.apply(functionResult, seq(argsResult), false)); + } + + private List visitExpressionList(List expressionList, CatalystPlanContext context) { + return expressionList.isEmpty() + ? 
emptyList() + : expressionList.stream().map(field -> analyze(field, context)) + .collect(Collectors.toList()); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java index 61762f616..53dc17576 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java @@ -5,6 +5,7 @@ package org.opensearch.sql.ppl; +import lombok.Getter; import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; import org.apache.spark.sql.catalyst.expressions.AttributeReference; import org.apache.spark.sql.catalyst.expressions.Expression; @@ -39,19 +40,19 @@ public class CatalystPlanContext { /** * Catalyst relations list **/ - private List projectedFields = new ArrayList<>(); + @Getter private List projectedFields = new ArrayList<>(); /** * Catalyst relations list **/ - private List relations = new ArrayList<>(); + @Getter private List relations = new ArrayList<>(); /** * Catalyst SubqueryAlias list **/ - private List subqueryAlias = new ArrayList<>(); + @Getter private List subqueryAlias = new ArrayList<>(); /** * Catalyst evolving logical plan **/ - private Stack planBranches = new Stack<>(); + @Getter private Stack planBranches = new Stack<>(); /** * The current traversal context the visitor is going threw */ @@ -60,28 +61,12 @@ public class CatalystPlanContext { /** * NamedExpression contextual parameters **/ - private final Stack namedParseExpressions = new Stack<>(); + @Getter private final Stack namedParseExpressions = new Stack<>(); /** * Grouping NamedExpression contextual parameters **/ - private final Stack groupingParseExpressions = new Stack<>(); - - public Stack getPlanBranches() { - return planBranches; - } - - public List getRelations() { - return relations; - } - - public List getSubqueryAlias() { - return subqueryAlias; - } - - public List getProjectedFields() { - return projectedFields; - } + @Getter private final Stack groupingParseExpressions = new Stack<>(); public LogicalPlan getPlan() { if (this.planBranches.isEmpty()) return null; @@ -101,10 +86,6 @@ public Stack traversalContext() { return planTraversalContext; } - public Stack getNamedParseExpressions() { - return namedParseExpressions; - } - public void setNamedParseExpressions(Stack namedParseExpressions) { this.namedParseExpressions.clear(); this.namedParseExpressions.addAll(namedParseExpressions); @@ -114,10 +95,6 @@ public Optional popNamedParseExpressions() { return namedParseExpressions.isEmpty() ? Optional.empty() : Optional.of(namedParseExpressions.pop()); } - public Stack getGroupingParseExpressions() { - return groupingParseExpressions; - } - /** * define new field * @@ -154,13 +131,13 @@ public LogicalPlan withProjectedFields(List projectedField this.projectedFields.addAll(projectedFields); return getPlan(); } - + public LogicalPlan applyBranches(List> plans) { plans.forEach(plan -> with(plan.apply(planBranches.get(0)))); planBranches.remove(0); return getPlan(); - } - + } + /** * append plan with evolving plans branches * @@ -288,4 +265,21 @@ public static Optional findRelation(LogicalPlan plan) { return Optional.empty(); } + @Getter private boolean isResolvingJoinCondition = false; + + /** + * Resolve the join condition with the given function. + * A flag will be set to true ahead expression resolving, then false after resolving. 
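A minimal sketch of this flag-guarded resolution, with hypothetical types rather than the real context class. The try/finally form shown here is just one way to guarantee the flag is cleared even if analysis throws; the patch itself resets it with a plain assignment after the call.

```java
import java.util.function.BiFunction;

// Hypothetical miniature of resolveJoinCondition: set the flag, run the transform with
// this context, then clear the flag so later expressions resolve normally again.
public class JoinConditionGuardSketch {
    private boolean resolvingJoinCondition = false;

    boolean isResolvingJoinCondition() {
        return resolvingJoinCondition;
    }

    <E, R> R resolveJoinCondition(E expr, BiFunction<E, JoinConditionGuardSketch, R> transform) {
        resolvingJoinCondition = true;
        try {
            // e.g. qualified-name resolution sees the flag and keeps join keys like `r.a` unresolved
            return transform.apply(expr, this);
        } finally {
            resolvingJoinCondition = false;   // cleared even if the transform fails
        }
    }

    public static void main(String[] args) {
        JoinConditionGuardSketch ctx = new JoinConditionGuardSketch();
        String resolved = ctx.resolveJoinCondition("l.a = r.a",
                (e, c) -> e + " [resolvingJoinCondition=" + c.isResolvingJoinCondition() + "]");
        System.out.println(resolved);                        // flag was true during the transform
        System.out.println(ctx.isResolvingJoinCondition());  // false afterwards
    }
}
```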
+ * @param expr + * @param transformFunction + * @return + */ + public Expression resolveJoinCondition( + UnresolvedExpression expr, + BiFunction transformFunction) { + isResolvingJoinCondition = true; + Expression result = transformFunction.apply(expr, this); + isResolvingJoinCondition = false; + return result; + } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 441287ddb..d2ee46ae6 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -6,61 +6,48 @@ package org.opensearch.sql.ppl; import org.apache.spark.sql.catalyst.TableIdentifier; -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$; import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction; import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; import org.apache.spark.sql.catalyst.expressions.Ascending$; -import org.apache.spark.sql.catalyst.expressions.CaseWhen; import org.apache.spark.sql.catalyst.expressions.Descending$; -import org.apache.spark.sql.catalyst.expressions.Exists$; +import org.apache.spark.sql.catalyst.expressions.Explode; import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.GeneratorOuter; import org.apache.spark.sql.catalyst.expressions.In$; import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual; import org.apache.spark.sql.catalyst.expressions.InSubquery$; +import org.apache.spark.sql.catalyst.expressions.LessThan; import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual; import org.apache.spark.sql.catalyst.expressions.ListQuery$; +import org.apache.spark.sql.catalyst.expressions.MakeInterval$; import org.apache.spark.sql.catalyst.expressions.NamedExpression; -import org.apache.spark.sql.catalyst.expressions.Predicate; -import org.apache.spark.sql.catalyst.expressions.ScalarSubquery$; import org.apache.spark.sql.catalyst.expressions.SortDirection; import org.apache.spark.sql.catalyst.expressions.SortOrder; -import org.apache.spark.sql.catalyst.plans.logical.*; +import org.apache.spark.sql.catalyst.plans.logical.Aggregate; +import org.apache.spark.sql.catalyst.plans.logical.DataFrameDropColumns$; +import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation$; +import org.apache.spark.sql.catalyst.plans.logical.Generate; +import org.apache.spark.sql.catalyst.plans.logical.Limit; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.catalyst.plans.logical.Project$; import org.apache.spark.sql.execution.ExplainMode; import org.apache.spark.sql.execution.command.DescribeTableCommand; import org.apache.spark.sql.execution.command.ExplainCommand; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.opensearch.flint.spark.FlattenGenerator; import org.opensearch.sql.ast.AbstractNodeVisitor; -import org.opensearch.sql.ast.expression.AggregateFunction; import org.opensearch.sql.ast.expression.Alias; -import org.opensearch.sql.ast.expression.AllFields; -import org.opensearch.sql.ast.expression.And; import org.opensearch.sql.ast.expression.Argument; -import org.opensearch.sql.ast.expression.Between; -import 
org.opensearch.sql.ast.expression.BinaryExpression; -import org.opensearch.sql.ast.expression.Case; -import org.opensearch.sql.ast.expression.Compare; import org.opensearch.sql.ast.expression.Field; -import org.opensearch.sql.ast.expression.FieldsMapping; import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.In; -import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; -import org.opensearch.sql.ast.expression.subquery.InSubquery; -import org.opensearch.sql.ast.expression.Interval; -import org.opensearch.sql.ast.expression.IsEmpty; import org.opensearch.sql.ast.expression.Let; import org.opensearch.sql.ast.expression.Literal; -import org.opensearch.sql.ast.expression.Not; -import org.opensearch.sql.ast.expression.Or; import org.opensearch.sql.ast.expression.ParseMethod; -import org.opensearch.sql.ast.expression.QualifiedName; -import org.opensearch.sql.ast.expression.subquery.ScalarSubquery; -import org.opensearch.sql.ast.expression.Span; import org.opensearch.sql.ast.expression.UnresolvedExpression; -import org.opensearch.sql.ast.expression.When; import org.opensearch.sql.ast.expression.WindowFunction; -import org.opensearch.sql.ast.expression.Xor; import org.opensearch.sql.ast.statement.Explain; import org.opensearch.sql.ast.statement.Query; import org.opensearch.sql.ast.statement.Statement; @@ -72,6 +59,7 @@ import org.opensearch.sql.ast.tree.FieldSummary; import org.opensearch.sql.ast.tree.FillNull; import org.opensearch.sql.ast.tree.Filter; +import org.opensearch.sql.ast.tree.Flatten; import org.opensearch.sql.ast.tree.Head; import org.opensearch.sql.ast.tree.Join; import org.opensearch.sql.ast.tree.Kmeans; @@ -85,31 +73,29 @@ import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.ast.tree.SubqueryAlias; import org.opensearch.sql.ast.tree.TopAggregation; -import org.opensearch.sql.ast.tree.UnresolvedPlan; +import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.ast.tree.Window; import org.opensearch.sql.common.antlr.SyntaxCheckException; -import org.opensearch.sql.ppl.utils.AggregatorTranslator; -import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator; -import org.opensearch.sql.ppl.utils.ComparatorTransformer; import org.opensearch.sql.ppl.utils.FieldSummaryTransformer; -import org.opensearch.sql.ppl.utils.ParseStrategy; +import org.opensearch.sql.ppl.utils.ParseTransformer; import org.opensearch.sql.ppl.utils.SortUtils; +import org.opensearch.sql.ppl.utils.TrendlineCatalystUtils; import org.opensearch.sql.ppl.utils.WindowSpecTransformer; +import scala.None$; import scala.Option; -import scala.Tuple2; import scala.collection.IterableLike; import scala.collection.Seq; -import java.util.*; -import java.util.function.BiFunction; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.Optional; import java.util.stream.Collectors; import static java.util.Collections.emptyList; import static java.util.List.of; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.EQUAL; import static org.opensearch.sql.ppl.CatalystPlanContext.findRelation; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; -import static org.opensearch.sql.ppl.utils.DataTypeTransformer.translate; import static org.opensearch.sql.ppl.utils.DedupeTransformer.retainMultipleDuplicateEvents; import static org.opensearch.sql.ppl.utils.DedupeTransformer.retainMultipleDuplicateEventsAndKeepEmpty; import static 
org.opensearch.sql.ppl.utils.DedupeTransformer.retainOneDuplicateEvent; @@ -121,8 +107,6 @@ import static org.opensearch.sql.ppl.utils.LookupTransformer.buildOutputProjectList; import static org.opensearch.sql.ppl.utils.LookupTransformer.buildProjectListFromFields; import static org.opensearch.sql.ppl.utils.RelationUtils.getTableIdentifier; -import static org.opensearch.sql.ppl.utils.RelationUtils.resolveField; -import static org.opensearch.sql.ppl.utils.WindowSpecTransformer.window; import static scala.collection.JavaConverters.seqAsJavaList; /** @@ -130,20 +114,16 @@ */ public class CatalystQueryPlanVisitor extends AbstractNodeVisitor { - private final ExpressionAnalyzer expressionAnalyzer; + private final CatalystExpressionVisitor expressionAnalyzer; public CatalystQueryPlanVisitor() { - this.expressionAnalyzer = new ExpressionAnalyzer(); + this.expressionAnalyzer = new CatalystExpressionVisitor(this); } public LogicalPlan visit(Statement plan, CatalystPlanContext context) { return plan.accept(this, context); } - - public LogicalPlan visitSubSearch(UnresolvedPlan plan, CatalystPlanContext context) { - return plan.accept(this, context); - } - + /** * Handle Query Statement. */ @@ -248,6 +228,30 @@ public LogicalPlan visitLookup(Lookup node, CatalystPlanContext context) { }); } + @Override + public LogicalPlan visitTrendline(Trendline node, CatalystPlanContext context) { + node.getChild().get(0).accept(this, context); + + node.getSortByField() + .ifPresent(sortField -> { + Expression sortFieldExpression = visitExpression(sortField, context); + Seq sortOrder = context + .retainAllNamedParseExpressions(exp -> SortUtils.sortOrder(sortFieldExpression, SortUtils.isSortedAscending(sortField))); + context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Sort(sortOrder, true, p)); + }); + + List trendlineProjectExpressions = new ArrayList<>(); + + if (context.getNamedParseExpressions().isEmpty()) { + // Create an UnresolvedStar for all-fields projection + trendlineProjectExpressions.add(UnresolvedStar$.MODULE$.apply(Option.empty())); + } + + trendlineProjectExpressions.addAll(TrendlineCatalystUtils.visitTrendlineComputations(expressionAnalyzer, node.getComputations(), context)); + + return context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(seq(trendlineProjectExpressions), p)); + } + @Override public LogicalPlan visitCorrelation(Correlation node, CatalystPlanContext context) { node.getChild().get(0).accept(this, context); @@ -271,7 +275,8 @@ public LogicalPlan visitJoin(Join node, CatalystPlanContext context) { node.getChild().get(0).accept(this, context); return context.apply(left -> { LogicalPlan right = node.getRight().accept(this, context); - Optional joinCondition = node.getJoinCondition().map(c -> visitExpression(c, context)); + Optional joinCondition = node.getJoinCondition() + .map(c -> expressionAnalyzer.analyzeJoinCondition(c, context)); context.retainAllNamedParseExpressions(p -> p); context.retainAllPlans(p -> p); return join(left, right, node.getJoinType(), joinCondition, node.getJoinHint()); @@ -450,6 +455,41 @@ public LogicalPlan visitFillNull(FillNull fillNull, CatalystPlanContext context) return Objects.requireNonNull(resultWithoutDuplicatedColumns, "FillNull operation failed"); } + @Override + public LogicalPlan visitFlatten(Flatten flatten, CatalystPlanContext context) { + flatten.getChild().get(0).accept(this, context); + if (context.getNamedParseExpressions().isEmpty()) { + // Create an UnresolvedStar for all-fields projection + 
context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); + } + Expression field = visitExpression(flatten.getField(), context); + context.retainAllNamedParseExpressions(p -> (NamedExpression) p); + FlattenGenerator flattenGenerator = new FlattenGenerator(field); + context.apply(p -> new Generate(new GeneratorOuter(flattenGenerator), seq(), true, (Option) None$.MODULE$, seq(), p)); + return context.apply(logicalPlan -> DataFrameDropColumns$.MODULE$.apply(seq(field), logicalPlan)); + } + + @Override + public LogicalPlan visitExpand(org.opensearch.sql.ast.tree.Expand node, CatalystPlanContext context) { + node.getChild().get(0).accept(this, context); + if (context.getNamedParseExpressions().isEmpty()) { + // Create an UnresolvedStar for all-fields projection + context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); + } + Expression field = visitExpression(node.getField(), context); + Optional alias = node.getAlias().map(aliasNode -> visitExpression(aliasNode, context)); + context.retainAllNamedParseExpressions(p -> (NamedExpression) p); + Explode explodeGenerator = new Explode(field); + scala.collection.mutable.Seq outputs = alias.isEmpty() ? seq() : seq(alias.get()); + if(alias.isEmpty()) + return context.apply(p -> new Generate(explodeGenerator, seq(), false, (Option) None$.MODULE$, outputs, p)); + else { + //in case an alias does appear - remove the original field from the returning columns + context.apply(p -> new Generate(explodeGenerator, seq(), false, (Option) None$.MODULE$, outputs, p)); + return context.apply(logicalPlan -> DataFrameDropColumns$.MODULE$.apply(seq(field), logicalPlan)); + } + } + private void visitFieldList(List fieldList, CatalystPlanContext context) { fieldList.forEach(field -> visitExpression(field, context)); } @@ -472,7 +512,7 @@ public LogicalPlan visitParse(Parse node, CatalystPlanContext context) { ParseMethod parseMethod = node.getParseMethod(); java.util.Map arguments = node.getArguments(); String pattern = (String) node.getPattern().getValue(); - return ParseStrategy.visitParseCommand(node, sourceField, parseMethod, arguments, pattern, context); + return ParseTransformer.visitParseCommand(node, sourceField, parseMethod, arguments, pattern, context); } @Override @@ -566,318 +606,4 @@ public LogicalPlan visitDedupe(Dedupe node, CatalystPlanContext context) { } } } - - /** - * Expression Analyzer. - */ - public class ExpressionAnalyzer extends AbstractNodeVisitor { - - public Expression analyze(UnresolvedExpression unresolved, CatalystPlanContext context) { - return unresolved.accept(this, context); - } - - @Override - public Expression visitLiteral(Literal node, CatalystPlanContext context) { - return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Literal( - translate(node.getValue(), node.getType()), translate(node.getType()))); - } - - /** - * generic binary (And, Or, Xor , ...) 
arithmetic expression resolver - * - * @param node - * @param transformer - * @param context - * @return - */ - public Expression visitBinaryArithmetic(BinaryExpression node, BiFunction transformer, CatalystPlanContext context) { - node.getLeft().accept(this, context); - Optional left = context.popNamedParseExpressions(); - node.getRight().accept(this, context); - Optional right = context.popNamedParseExpressions(); - if (left.isPresent() && right.isPresent()) { - return transformer.apply(left.get(), right.get()); - } else if (left.isPresent()) { - return context.getNamedParseExpressions().push(left.get()); - } else if (right.isPresent()) { - return context.getNamedParseExpressions().push(right.get()); - } - return null; - - } - - @Override - public Expression visitAnd(And node, CatalystPlanContext context) { - return visitBinaryArithmetic(node, - (left, right) -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.And(left, right)), context); - } - - @Override - public Expression visitOr(Or node, CatalystPlanContext context) { - return visitBinaryArithmetic(node, - (left, right) -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Or(left, right)), context); - } - - @Override - public Expression visitXor(Xor node, CatalystPlanContext context) { - return visitBinaryArithmetic(node, - (left, right) -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.BitwiseXor(left, right)), context); - } - - @Override - public Expression visitNot(Not node, CatalystPlanContext context) { - node.getExpression().accept(this, context); - Optional arg = context.popNamedParseExpressions(); - return arg.map(expression -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Not(expression))).orElse(null); - } - - @Override - public Expression visitSpan(Span node, CatalystPlanContext context) { - node.getField().accept(this, context); - Expression field = (Expression) context.popNamedParseExpressions().get(); - node.getValue().accept(this, context); - Expression value = (Expression) context.popNamedParseExpressions().get(); - return context.getNamedParseExpressions().push(window(field, value, node.getUnit())); - } - - @Override - public Expression visitAggregateFunction(AggregateFunction node, CatalystPlanContext context) { - node.getField().accept(this, context); - Expression arg = (Expression) context.popNamedParseExpressions().get(); - Expression aggregator = AggregatorTranslator.aggregator(node, arg); - return context.getNamedParseExpressions().push(aggregator); - } - - @Override - public Expression visitCompare(Compare node, CatalystPlanContext context) { - analyze(node.getLeft(), context); - Optional left = context.popNamedParseExpressions(); - analyze(node.getRight(), context); - Optional right = context.popNamedParseExpressions(); - if (left.isPresent() && right.isPresent()) { - Predicate comparator = ComparatorTransformer.comparator(node, left.get(), right.get()); - return context.getNamedParseExpressions().push((org.apache.spark.sql.catalyst.expressions.Expression) comparator); - } - return null; - } - - @Override - public Expression visitQualifiedName(QualifiedName node, CatalystPlanContext context) { - List relation = findRelation(context.traversalContext()); - if (!relation.isEmpty()) { - Optional resolveField = resolveField(relation, node, context.getRelations()); - return resolveField.map(qualifiedName -> 
context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(qualifiedName.getParts())))) - .orElse(resolveQualifiedNameWithSubqueryAlias(node, context)); - } - return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts()))); - } - - /** - * Resolve the qualified name with subquery alias:
- * - subqueryAlias1.joinKey = subqueryAlias2.joinKey
- * - tableName1.joinKey = subqueryAlias2.joinKey
- * - subqueryAlias1.joinKey = tableName2.joinKey
- */ - private Expression resolveQualifiedNameWithSubqueryAlias(QualifiedName node, CatalystPlanContext context) { - if (node.getPrefix().isPresent() && - context.traversalContext().peek() instanceof org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias) { - if (context.getSubqueryAlias().stream().map(p -> (org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias) p) - .anyMatch(a -> a.alias().equalsIgnoreCase(node.getPrefix().get().toString()))) { - return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts()))); - } else if (context.getRelations().stream().map(p -> (UnresolvedRelation) p) - .anyMatch(a -> a.tableName().equalsIgnoreCase(node.getPrefix().get().toString()))) { - return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts()))); - } - } - return null; - } - - @Override - public Expression visitCorrelationMapping(FieldsMapping node, CatalystPlanContext context) { - return node.getChild().stream().map(expression -> - visitCompare((Compare) expression, context) - ).reduce(org.apache.spark.sql.catalyst.expressions.And::new).orElse(null); - } - - @Override - public Expression visitAllFields(AllFields node, CatalystPlanContext context) { - context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); - return context.getNamedParseExpressions().peek(); - } - - @Override - public Expression visitAlias(Alias node, CatalystPlanContext context) { - node.getDelegated().accept(this, context); - Expression arg = context.popNamedParseExpressions().get(); - return context.getNamedParseExpressions().push( - org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(arg, - node.getAlias() != null ? node.getAlias() : node.getName(), - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty(), - seq(new java.util.ArrayList()))); - } - - @Override - public Expression visitEval(Eval node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : Eval"); - } - - @Override - public Expression visitFunction(Function node, CatalystPlanContext context) { - List arguments = - node.getFuncArgs().stream() - .map( - unresolvedExpression -> { - var ret = analyze(unresolvedExpression, context); - if (ret == null) { - throw new UnsupportedOperationException( - String.format("Invalid use of expression %s", unresolvedExpression)); - } else { - return context.popNamedParseExpressions().get(); - } - }) - .collect(Collectors.toList()); - Expression function = BuiltinFunctionTranslator.builtinFunction(node, arguments); - return context.getNamedParseExpressions().push(function); - } - - @Override - public Expression visitIsEmpty(IsEmpty node, CatalystPlanContext context) { - Stack namedParseExpressions = new Stack<>(); - namedParseExpressions.addAll(context.getNamedParseExpressions()); - Expression expression = visitCase(node.getCaseValue(), context); - namedParseExpressions.add(expression); - context.setNamedParseExpressions(namedParseExpressions); - return expression; - } - - @Override - public Expression visitFillNull(FillNull fillNull, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : FillNull"); - } - - @Override - public Expression visitInterval(Interval node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : Interval"); - } - - @Override - public Expression visitDedupe(Dedupe node, CatalystPlanContext context) { - throw new IllegalStateException("Not 
Supported operation : Dedupe"); - } - - @Override - public Expression visitIn(In node, CatalystPlanContext context) { - node.getField().accept(this, context); - Expression value = context.popNamedParseExpressions().get(); - List list = node.getValueList().stream().map( expression -> { - expression.accept(this, context); - return context.popNamedParseExpressions().get(); - }).collect(Collectors.toList()); - return context.getNamedParseExpressions().push(In$.MODULE$.apply(value, seq(list))); - } - - @Override - public Expression visitKmeans(Kmeans node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : Kmeans"); - } - - @Override - public Expression visitCase(Case node, CatalystPlanContext context) { - Stack initialNameExpressions = new Stack<>(); - initialNameExpressions.addAll(context.getNamedParseExpressions()); - analyze(node.getElseClause(), context); - Expression elseValue = context.getNamedParseExpressions().pop(); - List> whens = new ArrayList<>(); - for (When when : node.getWhenClauses()) { - if (node.getCaseValue() == null) { - whens.add( - new Tuple2<>( - analyze(when.getCondition(), context), - analyze(when.getResult(), context) - ) - ); - } else { - // Merge case value and condition (compare value) into a single equal condition - Compare compare = new Compare(EQUAL.getName().getFunctionName(), node.getCaseValue(), when.getCondition()); - whens.add( - new Tuple2<>( - analyze(compare, context), analyze(when.getResult(), context) - ) - ); - } - context.retainAllNamedParseExpressions(e -> e); - } - context.setNamedParseExpressions(initialNameExpressions); - return context.getNamedParseExpressions().push(new CaseWhen(seq(whens), Option.apply(elseValue))); - } - - @Override - public Expression visitRareTopN(RareTopN node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : RareTopN"); - } - - @Override - public Expression visitWindowFunction(WindowFunction node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : WindowFunction"); - } - - @Override - public Expression visitInSubquery(InSubquery node, CatalystPlanContext outerContext) { - CatalystPlanContext innerContext = new CatalystPlanContext(); - visitExpressionList(node.getChild(), innerContext); - Seq values = innerContext.retainAllNamedParseExpressions(p -> p); - UnresolvedPlan outerPlan = node.getQuery(); - LogicalPlan subSearch = CatalystQueryPlanVisitor.this.visitSubSearch(outerPlan, innerContext); - Expression inSubQuery = InSubquery$.MODULE$.apply( - values, - ListQuery$.MODULE$.apply( - subSearch, - seq(new java.util.ArrayList()), - NamedExpression.newExprId(), - -1, - seq(new java.util.ArrayList()), - Option.empty())); - return outerContext.getNamedParseExpressions().push(inSubQuery); - } - - @Override - public Expression visitScalarSubquery(ScalarSubquery node, CatalystPlanContext context) { - CatalystPlanContext innerContext = new CatalystPlanContext(); - UnresolvedPlan outerPlan = node.getQuery(); - LogicalPlan subSearch = CatalystQueryPlanVisitor.this.visitSubSearch(outerPlan, innerContext); - Expression scalarSubQuery = ScalarSubquery$.MODULE$.apply( - subSearch, - seq(new java.util.ArrayList()), - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty(), - Option.empty()); - return context.getNamedParseExpressions().push(scalarSubQuery); - } - - @Override - public Expression visitExistsSubquery(ExistsSubquery node, CatalystPlanContext context) { - CatalystPlanContext 
innerContext = new CatalystPlanContext(); - UnresolvedPlan outerPlan = node.getQuery(); - LogicalPlan subSearch = CatalystQueryPlanVisitor.this.visitSubSearch(outerPlan, innerContext); - Expression existsSubQuery = Exists$.MODULE$.apply( - subSearch, - seq(new java.util.ArrayList()), - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty()); - return context.getNamedParseExpressions().push(existsSubQuery); - } - - @Override - public Expression visitBetween(Between node, CatalystPlanContext context) { - Expression value = analyze(node.getValue(), context); - Expression lower = analyze(node.getLowerBound(), context); - Expression upper = analyze(node.getUpperBound(), context); - context.retainAllNamedParseExpressions(p -> p); - return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.And(new GreaterThanOrEqual(value, lower), new LessThanOrEqual(value, upper))); - } - } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index ed7717188..f6581016f 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -82,8 +82,8 @@ public class AstBuilder extends OpenSearchPPLParserBaseVisitor { */ private String query; - public AstBuilder(AstExpressionBuilder expressionBuilder, String query) { - this.expressionBuilder = expressionBuilder; + public AstBuilder(String query) { + this.expressionBuilder = new AstExpressionBuilder(this); this.query = query; } @@ -131,6 +131,12 @@ public UnresolvedPlan visitWhereCommand(OpenSearchPPLParser.WhereCommandContext return new Filter(internalVisitExpression(ctx.logicalExpression())); } + @Override + public UnresolvedPlan visitExpandCommand(OpenSearchPPLParser.ExpandCommandContext ctx) { + return new Expand((Field) internalVisitExpression(ctx.fieldExpression()), + ctx.alias!=null ? Optional.of(internalVisitExpression(ctx.alias)) : Optional.empty()); + } + @Override public UnresolvedPlan visitCorrelateCommand(OpenSearchPPLParser.CorrelateCommandContext ctx) { return new Correlation(ctx.correlationType().getText(), @@ -155,14 +161,25 @@ public UnresolvedPlan visitJoinCommand(OpenSearchPPLParser.JoinCommandContext ct joinType = Join.JoinType.CROSS; } Join.JoinHint joinHint = getJoinHint(ctx.joinHintList()); - String leftAlias = ctx.sideAlias().leftAlias.getText(); - String rightAlias = ctx.sideAlias().rightAlias.getText(); + Optional leftAlias = ctx.sideAlias().leftAlias != null ? Optional.of(ctx.sideAlias().leftAlias.getText()) : Optional.empty(); + Optional rightAlias = Optional.empty(); if (ctx.tableOrSubqueryClause().alias != null) { - // left and right aliases are required in join syntax. Setting by 'AS' causes ambiguous - throw new SyntaxCheckException("'AS' is not allowed in right subquery, use right= instead"); + rightAlias = Optional.of(ctx.tableOrSubqueryClause().alias.getText()); } + if (ctx.sideAlias().rightAlias != null) { + rightAlias = Optional.of(ctx.sideAlias().rightAlias.getText()); + } + UnresolvedPlan rightRelation = visit(ctx.tableOrSubqueryClause()); - UnresolvedPlan right = new SubqueryAlias(rightAlias, rightRelation); + // Add a SubqueryAlias to the right plan when the right alias is present and no duplicated alias existing in right. 
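A hypothetical, simplified sketch of the wrapping decision described in the comment above (not the real AST classes): the right side is reused as-is when no alias is present or when it is already a subquery alias carrying the same name, and is wrapped otherwise. Note also that in this hunk the explicit `right =` side alias, when present, takes precedence over an `AS` alias on the right table.

```java
import java.util.Optional;

// Simplified stand-ins; only the wrapping decision matters here.
public class RightAliasSketch {
    static class AliasedPlan {
        final String alias;
        final Object child;
        AliasedPlan(String alias, Object child) { this.alias = alias; this.child = child; }
        @Override public String toString() { return "SubqueryAlias(" + alias + ", " + child + ")"; }
    }

    static Object wrapRight(Optional<String> rightAlias, Object rightRelation) {
        if (rightAlias.isEmpty()
                || (rightRelation instanceof AliasedPlan
                    && rightAlias.get().equals(((AliasedPlan) rightRelation).alias))) {
            return rightRelation;                                // alias absent or already applied: reuse as-is
        }
        return new AliasedPlan(rightAlias.get(), rightRelation); // otherwise wrap with the right-side alias
    }

    public static void main(String[] args) {
        System.out.println(wrapRight(Optional.of("r"), "table2"));                          // wrapped
        System.out.println(wrapRight(Optional.of("t2"), new AliasedPlan("t2", "table2")));  // reused
        System.out.println(wrapRight(Optional.empty(), "table2"));                          // untouched
    }
}
```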
+ UnresolvedPlan right; + if (rightAlias.isEmpty() || + (rightRelation instanceof SubqueryAlias && + rightAlias.get().equals(((SubqueryAlias) rightRelation).getAlias()))) { + right = rightRelation; + } else { + right = new SubqueryAlias(rightAlias.get(), rightRelation); + } Optional joinCondition = ctx.joinCriteria() == null ? Optional.empty() : Optional.of(expressionBuilder.visitJoinCriteria(ctx.joinCriteria())); @@ -370,7 +387,7 @@ public UnresolvedPlan visitPatternsCommand(OpenSearchPPLParser.PatternsCommandCo /** Lookup command */ @Override public UnresolvedPlan visitLookupCommand(OpenSearchPPLParser.LookupCommandContext ctx) { - Relation lookupRelation = new Relation(this.internalVisitExpression(ctx.tableSource())); + Relation lookupRelation = new Relation(Collections.singletonList(this.internalVisitExpression(ctx.tableSource()))); Lookup.OutputStrategy strategy = ctx.APPEND() != null ? Lookup.OutputStrategy.APPEND : Lookup.OutputStrategy.REPLACE; java.util.Map lookupMappingList = buildLookupPair(ctx.lookupMappingList().lookupPair()); @@ -386,6 +403,30 @@ private java.util.Map buildLookupPair(List (Alias) and.getLeft(), and -> (Field) and.getRight(), (x, y) -> y, LinkedHashMap::new)); } + @Override + public UnresolvedPlan visitTrendlineCommand(OpenSearchPPLParser.TrendlineCommandContext ctx) { + List trendlineComputations = ctx.trendlineClause() + .stream() + .map(this::toTrendlineComputation) + .collect(Collectors.toList()); + return Optional.ofNullable(ctx.sortField()) + .map(this::internalVisitExpression) + .map(Field.class::cast) + .map(sort -> new Trendline(Optional.of(sort), trendlineComputations)) + .orElse(new Trendline(Optional.empty(), trendlineComputations)); + } + + private Trendline.TrendlineComputation toTrendlineComputation(OpenSearchPPLParser.TrendlineClauseContext ctx) { + int numberOfDataPoints = Integer.parseInt(ctx.numberOfDataPoints.getText()); + if (numberOfDataPoints < 1) { + throw new SyntaxCheckException("Number of trendline data-points must be greater than or equal to 1"); + } + Field dataField = (Field) expressionBuilder.visitFieldExpression(ctx.field); + String alias = ctx.alias == null?dataField.getField().toString()+"_trendline":ctx.alias.getText(); + String computationType = ctx.trendlineType().getText(); + return new Trendline.TrendlineComputation(numberOfDataPoints, dataField, alias, Trendline.TrendlineType.valueOf(computationType.toUpperCase())); + } + /** Top command. */ @Override public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) { @@ -485,9 +526,8 @@ public UnresolvedPlan visitTableOrSubqueryClause(OpenSearchPPLParser.TableOrSubq @Override public UnresolvedPlan visitTableSourceClause(OpenSearchPPLParser.TableSourceClauseContext ctx) { - return ctx.alias == null - ? new Relation(ctx.tableSource().stream().map(this::internalVisitExpression).collect(Collectors.toList())) - : new Relation(ctx.tableSource().stream().map(this::internalVisitExpression).collect(Collectors.toList()), ctx.alias.getText()); + Relation relation = new Relation(ctx.tableSource().stream().map(this::internalVisitExpression).collect(Collectors.toList())); + return ctx.alias != null ? 
new SubqueryAlias(ctx.alias.getText(), relation) : relation; } @Override @@ -562,6 +602,12 @@ public UnresolvedPlan visitFillnullCommand(OpenSearchPPLParser.FillnullCommandCo } } + @Override + public UnresolvedPlan visitFlattenCommand(OpenSearchPPLParser.FlattenCommandContext ctx) { + Field unresolvedExpression = (Field) internalVisitExpression(ctx.fieldExpression()); + return new Flatten(unresolvedExpression); + } + /** AD command. */ @Override public UnresolvedPlan visitAdCommand(OpenSearchPPLParser.AdCommandContext ctx) { diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 6a0c80c16..4b7c8a1c1 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -19,29 +19,33 @@ import org.opensearch.sql.ast.expression.AttributeList; import org.opensearch.sql.ast.expression.Between; import org.opensearch.sql.ast.expression.Case; +import org.opensearch.sql.ast.expression.Cidr; import org.opensearch.sql.ast.expression.Compare; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.In; -import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; -import org.opensearch.sql.ast.expression.subquery.InSubquery; import org.opensearch.sql.ast.expression.Interval; import org.opensearch.sql.ast.expression.IntervalUnit; import org.opensearch.sql.ast.expression.IsEmpty; +import org.opensearch.sql.ast.expression.LambdaFunction; import org.opensearch.sql.ast.expression.Let; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.Not; import org.opensearch.sql.ast.expression.Or; import org.opensearch.sql.ast.expression.QualifiedName; -import org.opensearch.sql.ast.expression.subquery.ScalarSubquery; import org.opensearch.sql.ast.expression.Span; import org.opensearch.sql.ast.expression.SpanUnit; import org.opensearch.sql.ast.expression.UnresolvedArgument; import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.expression.When; import org.opensearch.sql.ast.expression.Xor; +import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; +import org.opensearch.sql.ast.expression.subquery.InSubquery; +import org.opensearch.sql.ast.expression.subquery.ScalarSubquery; +import org.opensearch.sql.ast.tree.Trendline; +import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.ppl.utils.ArgumentFactory; @@ -66,15 +70,6 @@ */ public class AstExpressionBuilder extends OpenSearchPPLParserBaseVisitor { - private static final int DEFAULT_TAKE_FUNCTION_SIZE_VALUE = 10; - - private AstBuilder astBuilder; - - /** Set AstBuilder back to AstExpressionBuilder for resolving the subquery plan in subquery expression */ - public void setAstBuilder(AstBuilder astBuilder) { - this.astBuilder = astBuilder; - } - /** * The function name mapping between fronted and core engine. 
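Aside on the builder wiring changed in this hunk: `AstExpressionBuilder` now receives its `AstBuilder` through the constructor (it needs it to build sub-search plans for subquery expressions), and `AstBuilder` constructs the expression builder with `new AstExpressionBuilder(this)`, replacing the old `setAstBuilder(...)` two-step initialization. A rough sketch of that mutual wiring with hypothetical minimal classes:

```java
// Hypothetical miniature of the two builders; only the construction order is the point.
class ExpressionBuilderSketch {
    private final PlanBuilderSketch planBuilder;   // used to build sub-search plans for subqueries

    ExpressionBuilderSketch(PlanBuilderSketch planBuilder) {
        this.planBuilder = planBuilder;            // injected at construction, no setter needed
    }

    String buildInSubquery(String subSearch) {
        return "InSubquery(" + planBuilder.buildSubSearch(subSearch) + ")";
    }
}

class PlanBuilderSketch {
    private final ExpressionBuilderSketch expressionBuilder = new ExpressionBuilderSketch(this);

    String buildSubSearch(String subSearch) {
        return "Plan(" + subSearch + ")";
    }

    String buildWhereCondition(String condition) {
        return expressionBuilder.buildInSubquery(condition);
    }
}

public class BuilderWiringSketch {
    public static void main(String[] args) {
        System.out.println(new PlanBuilderSketch().buildWhereCondition("source = inner_table | fields id"));
    }
}
```

One benefit of constructor injection is that there is no window in which the expression builder exists without its plan builder, which the old setter-based wiring allowed.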
*/ @@ -84,7 +79,12 @@ public void setAstBuilder(AstBuilder astBuilder) { .put("isnotnull", IS_NOT_NULL.getName().getFunctionName()) .put("ispresent", IS_NOT_NULL.getName().getFunctionName()) .build(); + private AstBuilder astBuilder; + public AstExpressionBuilder(AstBuilder astBuilder) { + this.astBuilder = astBuilder; + } + @Override public UnresolvedExpression visitMappingCompareExpr(OpenSearchPPLParser.MappingCompareExprContext ctx) { return new Compare(ctx.comparisonOperator().getText(), visit(ctx.left), visit(ctx.right)); @@ -154,7 +154,7 @@ public UnresolvedExpression visitCompareExpr(OpenSearchPPLParser.CompareExprCont @Override public UnresolvedExpression visitBinaryArithmetic(OpenSearchPPLParser.BinaryArithmeticContext ctx) { return new Function( - ctx.binaryOperator.getText(), Arrays.asList(visit(ctx.left), visit(ctx.right))); + ctx.binaryOperator.getText(), Arrays.asList(visit(ctx.left), visit(ctx.right))); } @Override @@ -245,7 +245,7 @@ public UnresolvedExpression visitCaseExpr(OpenSearchPPLParser.CaseExprContext ct }) .collect(Collectors.toList()); UnresolvedExpression elseValue = new Literal(null, DataType.NULL); - if(ctx.caseFunction().valueExpression().size() > ctx.caseFunction().logicalExpression().size()) { + if (ctx.caseFunction().valueExpression().size() > ctx.caseFunction().logicalExpression().size()) { // else value is present elseValue = visit(ctx.caseFunction().valueExpression(ctx.caseFunction().valueExpression().size() - 1)); } @@ -290,9 +290,6 @@ private Function buildFunction( functionName, args.stream().map(this::visitFunctionArg).collect(Collectors.toList())); } - public AstExpressionBuilder() { - } - @Override public UnresolvedExpression visitMultiFieldRelevanceFunction( OpenSearchPPLParser.MultiFieldRelevanceFunctionContext ctx) { @@ -306,7 +303,7 @@ public UnresolvedExpression visitTableSource(OpenSearchPPLParser.TableSourceCont if (ctx.getChild(0) instanceof OpenSearchPPLParser.IdentsAsTableQualifiedNameContext) { return visitIdentsAsTableQualifiedName((OpenSearchPPLParser.IdentsAsTableQualifiedNameContext) ctx.getChild(0)); } else { - return visitIdentifiers(Arrays.asList(ctx)); + return visitIdentifiers(List.of(ctx)); } } @@ -398,9 +395,9 @@ public UnresolvedExpression visitRightHint(OpenSearchPPLParser.RightHintContext @Override public UnresolvedExpression visitInSubqueryExpr(OpenSearchPPLParser.InSubqueryExprContext ctx) { UnresolvedExpression expr = new InSubquery( - ctx.valueExpressionList().valueExpression().stream() - .map(this::visit).collect(Collectors.toList()), - astBuilder.visitSubSearch(ctx.subSearch())); + ctx.valueExpressionList().valueExpression().stream() + .map(this::visit).collect(Collectors.toList()), + astBuilder.visitSubSearch(ctx.subSearch())); return ctx.NOT() != null ? new Not(expr) : expr; } @@ -421,6 +418,37 @@ public UnresolvedExpression visitInExpr(OpenSearchPPLParser.InExprContext ctx) { return ctx.NOT() != null ? 
new Not(expr) : expr; } + @Override + public UnresolvedExpression visitCidrMatchFunctionCall(OpenSearchPPLParser.CidrMatchFunctionCallContext ctx) { + return new Cidr(visit(ctx.ipAddress), visit(ctx.cidrBlock)); + } + + @Override + public UnresolvedExpression visitTimestampFunctionCall( + OpenSearchPPLParser.TimestampFunctionCallContext ctx) { + return new Function( + ctx.timestampFunction().timestampFunctionName().getText(), timestampFunctionArguments(ctx)); + } + + @Override + public UnresolvedExpression visitLambda(OpenSearchPPLParser.LambdaContext ctx) { + + List arguments = ctx.ident().stream().map(x -> this.visitIdentifiers(Collections.singletonList(x))).collect( + Collectors.toList()); + UnresolvedExpression function = visitExpression(ctx.expression()); + return new LambdaFunction(function, arguments); + } + + private List timestampFunctionArguments( + OpenSearchPPLParser.TimestampFunctionCallContext ctx) { + List args = + Arrays.asList( + new Literal(ctx.timestampFunction().simpleDateTimePart().getText(), DataType.STRING), + visitFunctionArg(ctx.timestampFunction().firstArg), + visitFunctionArg(ctx.timestampFunction().secondArg)); + return args; + } + private QualifiedName visitIdentifiers(List ctx) { return new QualifiedName( ctx.stream() diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java index 6a545f091..44610f3a4 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java @@ -56,6 +56,7 @@ public StatementBuilderContext getContext() { } public static class StatementBuilderContext { + public static final int FETCH_SIZE = 1000; private int fetchSize; public StatementBuilderContext(int fetchSize) { @@ -63,8 +64,7 @@ public StatementBuilderContext(int fetchSize) { } public static StatementBuilderContext builder() { - //todo set the default statement builder init params configurable - return new StatementBuilderContext(1000); + return new StatementBuilderContext(FETCH_SIZE); } public int getFetchSize() { diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTransformer.java similarity index 97% rename from ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java rename to ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTransformer.java index a01b38a80..9788ac1bc 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTransformer.java @@ -12,12 +12,10 @@ import org.opensearch.sql.ast.expression.AggregateFunction; import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.DataType; -import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.expression.function.BuiltinFunctionName; import java.util.List; -import java.util.Optional; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; import static scala.Option.empty; @@ -27,7 +25,7 @@ * * @return */ -public interface AggregatorTranslator { +public interface AggregatorTransformer { static Expression 
aggregator(org.opensearch.sql.ast.expression.AggregateFunction aggregateFunction, Expression arg) { if (BuiltinFunctionName.ofAggregation(aggregateFunction.getFuncName()).isEmpty()) diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java similarity index 68% rename from ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java rename to ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java index 8982fe859..0b0fb8314 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java @@ -8,17 +8,30 @@ import com.google.common.collect.ImmutableMap; import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction; import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction$; +import org.apache.spark.sql.catalyst.expressions.CurrentTimeZone$; +import org.apache.spark.sql.catalyst.expressions.CurrentTimestamp$; +import org.apache.spark.sql.catalyst.expressions.DateAddInterval$; import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.expressions.Literal$; +import org.apache.spark.sql.catalyst.expressions.TimestampAdd$; +import org.apache.spark.sql.catalyst.expressions.TimestampDiff$; +import org.apache.spark.sql.catalyst.expressions.ToUTCTimestamp$; +import org.apache.spark.sql.catalyst.expressions.UnaryMinus$; +import org.opensearch.sql.ast.expression.IntervalUnit; import org.opensearch.sql.expression.function.BuiltinFunctionName; +import scala.Option; +import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.function.Function; import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADD; import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADDDATE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.ARRAY_LENGTH; import static org.opensearch.sql.expression.function.BuiltinFunctionName.DATEDIFF; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.DATE_ADD; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.DATE_SUB; import static org.opensearch.sql.expression.function.BuiltinFunctionName.DAY_OF_MONTH; import static org.opensearch.sql.expression.function.BuiltinFunctionName.COALESCE; import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON; @@ -44,13 +57,17 @@ import static org.opensearch.sql.expression.function.BuiltinFunctionName.SECOND_OF_MINUTE; import static org.opensearch.sql.expression.function.BuiltinFunctionName.SUBDATE; import static org.opensearch.sql.expression.function.BuiltinFunctionName.SYSDATE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.TIMESTAMPADD; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.TIMESTAMPDIFF; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.TO_JSON_STRING; import static org.opensearch.sql.expression.function.BuiltinFunctionName.TRIM; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.UTC_TIMESTAMP; import static org.opensearch.sql.expression.function.BuiltinFunctionName.WEEK; import static org.opensearch.sql.expression.function.BuiltinFunctionName.WEEK_OF_YEAR; import static 
org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; import static scala.Option.empty; -public interface BuiltinFunctionTranslator { +public interface BuiltinFunctionTransformer { /** * The name mapping between PPL builtin functions to Spark builtin functions. @@ -87,7 +104,9 @@ public interface BuiltinFunctionTranslator { .put(COALESCE, "coalesce") .put(LENGTH, "length") .put(TRIM, "trim") + .put(ARRAY_LENGTH, "array_size") // json functions + .put(TO_JSON_STRING, "to_json") .put(JSON_KEYS, "json_object_keys") .put(JSON_EXTRACT, "get_json_object") .build(); @@ -95,8 +114,8 @@ public interface BuiltinFunctionTranslator { /** * The name mapping between PPL builtin functions to Spark builtin functions. */ - static final Map, UnresolvedFunction>> PPL_TO_SPARK_FUNC_MAPPING - = ImmutableMap., UnresolvedFunction>>builder() + static final Map, Expression>> PPL_TO_SPARK_FUNC_MAPPING + = ImmutableMap., Expression>>builder() // json functions .put( JSON_ARRAY, @@ -111,26 +130,12 @@ public interface BuiltinFunctionTranslator { .put( JSON_ARRAY_LENGTH, args -> { - // Check if the input is an array (from json_array()) or a JSON string - if (args.get(0) instanceof UnresolvedFunction) { - // Input is a JSON array - return UnresolvedFunction$.MODULE$.apply("json_array_length", - seq(UnresolvedFunction$.MODULE$.apply("to_json", seq(args), false)), false); - } else { - // Input is a JSON string - return UnresolvedFunction$.MODULE$.apply("json_array_length", seq(args.get(0)), false); - } + return UnresolvedFunction$.MODULE$.apply("json_array_length", seq(args.get(0)), false); }) .put( JSON, args -> { - // Check if the input is a named_struct (from json_object()) or a JSON string - if (args.get(0) instanceof UnresolvedFunction) { - return UnresolvedFunction$.MODULE$.apply("to_json", seq(args.get(0)), false); - } else { - return UnresolvedFunction$.MODULE$.apply("get_json_object", - seq(args.get(0), Literal$.MODULE$.apply("$")), false); - } + return UnresolvedFunction$.MODULE$.apply("get_json_object", seq(args.get(0), Literal$.MODULE$.apply("$")), false); }) .put( JSON_VALID, @@ -139,6 +144,31 @@ public interface BuiltinFunctionTranslator { seq(UnresolvedFunction$.MODULE$.apply("get_json_object", seq(args.get(0), Literal$.MODULE$.apply("$")), false)), false); }) + .put( + DATE_ADD, + args -> { + return DateAddInterval$.MODULE$.apply(args.get(0), args.get(1), Option.empty(), false); + }) + .put( + DATE_SUB, + args -> { + return DateAddInterval$.MODULE$.apply(args.get(0), UnaryMinus$.MODULE$.apply(args.get(1), true), Option.empty(), true); + }) + .put( + TIMESTAMPADD, + args -> { + return TimestampAdd$.MODULE$.apply(args.get(0).toString(), args.get(1), args.get(2), Option.empty()); + }) + .put( + TIMESTAMPDIFF, + args -> { + return TimestampDiff$.MODULE$.apply(args.get(0).toString(), args.get(1), args.get(2), Option.empty()); + }) + .put( + UTC_TIMESTAMP, + args -> { + return ToUTCTimestamp$.MODULE$.apply(CurrentTimestamp$.MODULE$.apply(), CurrentTimeZone$.MODULE$.apply()); + }) .build(); static Expression builtinFunction(org.opensearch.sql.ast.expression.Function function, List args) { @@ -153,7 +183,7 @@ static Expression builtinFunction(org.opensearch.sql.ast.expression.Function fun // there is a Spark builtin function mapping with the PPL builtin function return new UnresolvedFunction(seq(name), seq(args), false, empty(),false); } - Function, UnresolvedFunction> alternative = PPL_TO_SPARK_FUNC_MAPPING.get(builtin); + Function, Expression> alternative = PPL_TO_SPARK_FUNC_MAPPING.get(builtin); if (alternative 
!= null) { return alternative.apply(args); } @@ -161,4 +191,21 @@ static Expression builtinFunction(org.opensearch.sql.ast.expression.Function fun return new UnresolvedFunction(seq(name), seq(args), false, empty(),false); } } + + static Expression[] createIntervalArgs(IntervalUnit unit, Expression value) { + Expression[] args = new Expression[7]; + Arrays.fill(args, Literal$.MODULE$.apply(0)); + switch (unit) { + case YEAR: args[0] = value; break; + case MONTH: args[1] = value; break; + case WEEK: args[2] = value; break; + case DAY: args[3] = value; break; + case HOUR: args[4] = value; break; + case MINUTE: args[5] = value; break; + case SECOND: args[6] = value; break; + default: + throw new IllegalArgumentException("Unsupported Interval unit: " + unit); + } + return args; + } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java index 62eef90ed..e4defad52 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java @@ -14,16 +14,14 @@ import org.apache.spark.sql.types.FloatType$; import org.apache.spark.sql.types.IntegerType$; import org.apache.spark.sql.types.LongType$; +import org.apache.spark.sql.types.NullType$; import org.apache.spark.sql.types.ShortType$; import org.apache.spark.sql.types.StringType$; import org.apache.spark.unsafe.types.UTF8String; import org.opensearch.sql.ast.expression.SpanUnit; import scala.collection.mutable.Seq; -import java.util.Arrays; import java.util.List; -import java.util.Objects; -import java.util.stream.Collectors; import static org.opensearch.sql.ast.expression.SpanUnit.DAY; import static org.opensearch.sql.ast.expression.SpanUnit.HOUR; @@ -67,6 +65,8 @@ static DataType translate(org.opensearch.sql.ast.expression.DataType source) { return ShortType$.MODULE$; case BYTE: return ByteType$.MODULE$; + case UNDEFINED: + return NullType$.MODULE$; default: return StringType$.MODULE$; } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DedupeTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DedupeTransformer.java index 0866ca7e9..a34fc0184 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DedupeTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DedupeTransformer.java @@ -16,8 +16,8 @@ import org.apache.spark.sql.catalyst.plans.logical.Union; import org.apache.spark.sql.types.DataTypes; import org.opensearch.sql.ast.tree.Dedupe; +import org.opensearch.sql.ppl.CatalystExpressionVisitor; import org.opensearch.sql.ppl.CatalystPlanContext; -import org.opensearch.sql.ppl.CatalystQueryPlanVisitor.ExpressionAnalyzer; import scala.collection.Seq; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; @@ -38,7 +38,7 @@ public interface DedupeTransformer { static LogicalPlan retainOneDuplicateEventAndKeepEmpty( Dedupe node, Seq dedupeFields, - ExpressionAnalyzer expressionAnalyzer, + CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { context.apply(p -> { Expression isNullExpr = buildIsNullFilterExpression(node, expressionAnalyzer, context); @@ -63,7 +63,7 @@ static LogicalPlan retainOneDuplicateEventAndKeepEmpty( static LogicalPlan retainOneDuplicateEvent( Dedupe node, Seq dedupeFields, - ExpressionAnalyzer expressionAnalyzer, + 
CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { Expression isNotNullExpr = buildIsNotNullFilterExpression(node, expressionAnalyzer, context); context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Filter(isNotNullExpr, p)); @@ -87,7 +87,7 @@ static LogicalPlan retainOneDuplicateEvent( static LogicalPlan retainMultipleDuplicateEventsAndKeepEmpty( Dedupe node, Integer allowedDuplication, - ExpressionAnalyzer expressionAnalyzer, + CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { context.apply(p -> { // Build isnull Filter for right @@ -137,7 +137,7 @@ static LogicalPlan retainMultipleDuplicateEventsAndKeepEmpty( static LogicalPlan retainMultipleDuplicateEvents( Dedupe node, Integer allowedDuplication, - ExpressionAnalyzer expressionAnalyzer, + CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { // Build isnotnull Filter Expression isNotNullExpr = buildIsNotNullFilterExpression(node, expressionAnalyzer, context); @@ -163,7 +163,7 @@ static LogicalPlan retainMultipleDuplicateEvents( return context.apply(p -> new DataFrameDropColumns(seq(rowNumber.toAttribute()), p)); } - private static Expression buildIsNotNullFilterExpression(Dedupe node, ExpressionAnalyzer expressionAnalyzer, CatalystPlanContext context) { + private static Expression buildIsNotNullFilterExpression(Dedupe node, CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { node.getFields().forEach(field -> expressionAnalyzer.analyze(field, context)); Seq isNotNullExpressions = context.retainAllNamedParseExpressions( @@ -180,7 +180,7 @@ private static Expression buildIsNotNullFilterExpression(Dedupe node, Expression return isNotNullExpr; } - private static Expression buildIsNullFilterExpression(Dedupe node, ExpressionAnalyzer expressionAnalyzer, CatalystPlanContext context) { + private static Expression buildIsNullFilterExpression(Dedupe node, CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { node.getFields().forEach(field -> expressionAnalyzer.analyze(field, context)); Seq isNullExpressions = context.retainAllNamedParseExpressions( diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/JavaToScalaTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/JavaToScalaTransformer.java new file mode 100644 index 000000000..40246d7c9 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/JavaToScalaTransformer.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.utils; + + +import scala.PartialFunction; +import scala.runtime.AbstractPartialFunction; + +public interface JavaToScalaTransformer { + static <T> PartialFunction<T, T> toPartialFunction( + java.util.function.Predicate<T> isDefinedAt, + java.util.function.Function<T, T> apply) { + return new AbstractPartialFunction<T, T>() { + @Override + public boolean isDefinedAt(T t) { + return isDefinedAt.test(t); + } + + @Override + public T apply(T t) { + if (isDefinedAt.test(t)) return apply.apply(t); + else return t; + } + }; + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/LookupTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/LookupTransformer.java index 58ef15ea9..3673d96d6 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/LookupTransformer.java +++
b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/LookupTransformer.java @@ -15,6 +15,7 @@ import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.tree.Lookup; +import org.opensearch.sql.ppl.CatalystExpressionVisitor; import org.opensearch.sql.ppl.CatalystPlanContext; import org.opensearch.sql.ppl.CatalystQueryPlanVisitor; import scala.Option; @@ -32,7 +33,7 @@ public interface LookupTransformer { /** lookup mapping fields + input fields*/ static List buildLookupRelationProjectList( Lookup node, - CatalystQueryPlanVisitor.ExpressionAnalyzer expressionAnalyzer, + CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { List inputFields = new ArrayList<>(node.getInputFieldList()); if (inputFields.isEmpty()) { @@ -45,7 +46,7 @@ static List buildLookupRelationProjectList( static List buildProjectListFromFields( List fields, - CatalystQueryPlanVisitor.ExpressionAnalyzer expressionAnalyzer, + CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { return fields.stream().map(field -> expressionAnalyzer.visitField(field, context)) .map(NamedExpression.class::cast) @@ -54,7 +55,7 @@ static List buildProjectListFromFields( static Expression buildLookupMappingCondition( Lookup node, - CatalystQueryPlanVisitor.ExpressionAnalyzer expressionAnalyzer, + CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { // only equi-join conditions are accepted in lookup command List equiConditions = new ArrayList<>(); @@ -81,7 +82,7 @@ static Expression buildLookupMappingCondition( static List buildOutputProjectList( Lookup node, Lookup.OutputStrategy strategy, - CatalystQueryPlanVisitor.ExpressionAnalyzer expressionAnalyzer, + CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { List outputProjectList = new ArrayList<>(); for (Map.Entry entry : node.getOutputCandidateMap().entrySet()) { diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseStrategy.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseTransformer.java similarity index 97% rename from ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseStrategy.java rename to ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseTransformer.java index 8775d077b..eed7db228 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseStrategy.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseTransformer.java @@ -8,7 +8,6 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.expressions.NamedExpression; -import org.apache.spark.sql.catalyst.expressions.RegExpExtract; import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; import org.opensearch.sql.ast.expression.AllFields; import org.opensearch.sql.ast.expression.Field; @@ -27,7 +26,7 @@ import static org.apache.spark.sql.types.DataTypes.StringType; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; -public interface ParseStrategy { +public interface ParseTransformer { /** * transform the parse/grok/patterns command into a standard catalyst RegExpExtract expression * Since spark's RegExpExtract cant accept actual regExp group name we need to translate the group's name into its corresponding index diff --git 
a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/SortUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/SortUtils.java index 83603b031..803daea8b 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/SortUtils.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/SortUtils.java @@ -38,7 +38,7 @@ static SortOrder getSortDirection(Sort node, NamedExpression expression) { .findAny(); return field.map(value -> sortOrder((Expression) expression, - (Boolean) value.getFieldArgs().get(0).getValue().getValue())) + isSortedAscending(value))) .orElse(null); } @@ -51,4 +51,8 @@ static SortOrder sortOrder(Expression expression, boolean ascending) { seq(new ArrayList()) ); } + + static boolean isSortedAscending(Field field) { + return (Boolean) field.getFieldArgs().get(0).getValue().getValue(); + } } \ No newline at end of file diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java new file mode 100644 index 000000000..67603ccc7 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java @@ -0,0 +1,87 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.utils; + +import org.apache.spark.sql.catalyst.expressions.*; +import org.opensearch.sql.ast.expression.AggregateFunction; +import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.ast.tree.Trendline; +import org.opensearch.sql.expression.function.BuiltinFunctionName; +import org.opensearch.sql.ppl.CatalystExpressionVisitor; +import org.opensearch.sql.ppl.CatalystPlanContext; +import scala.Option; +import scala.Tuple2; + +import java.util.List; +import java.util.stream.Collectors; + +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; + +public interface TrendlineCatalystUtils { + + static List visitTrendlineComputations(CatalystExpressionVisitor expressionVisitor, List computations, CatalystPlanContext context) { + return computations.stream() + .map(computation -> visitTrendlineComputation(expressionVisitor, computation, context)) + .collect(Collectors.toList()); + } + + static NamedExpression visitTrendlineComputation(CatalystExpressionVisitor expressionVisitor, Trendline.TrendlineComputation node, CatalystPlanContext context) { + //window lower boundary + expressionVisitor.visitLiteral(new Literal(Math.negateExact(node.getNumberOfDataPoints() - 1), DataType.INTEGER), context); + Expression windowLowerBoundary = context.popNamedParseExpressions().get(); + + //window definition + WindowSpecDefinition windowDefinition = new WindowSpecDefinition( + seq(), + seq(), + new SpecifiedWindowFrame(RowFrame$.MODULE$, windowLowerBoundary, CurrentRow$.MODULE$)); + + if (node.getComputationType() == Trendline.TrendlineType.SMA) { + //calculate avg value of the data field + expressionVisitor.visitAggregateFunction(new AggregateFunction(BuiltinFunctionName.AVG.name(), node.getDataField()), context); + Expression avgFunction = context.popNamedParseExpressions().get(); + + //sma window + WindowExpression sma = new WindowExpression( + avgFunction, + windowDefinition); + + CaseWhen smaOrNull = trendlineOrNullWhenThereAreTooFewDataPoints(expressionVisitor, sma, node, context); + + return 
org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(smaOrNull, + node.getAlias(), + NamedExpression.newExprId(), + seq(new java.util.ArrayList()), + Option.empty(), + seq(new java.util.ArrayList())); + } else { + throw new IllegalArgumentException(node.getComputationType()+" is not supported"); + } + } + + private static CaseWhen trendlineOrNullWhenThereAreTooFewDataPoints(CatalystExpressionVisitor expressionVisitor, WindowExpression trendlineWindow, Trendline.TrendlineComputation node, CatalystPlanContext context) { + //required number of data points + expressionVisitor.visitLiteral(new Literal(node.getNumberOfDataPoints(), DataType.INTEGER), context); + Expression requiredNumberOfDataPoints = context.popNamedParseExpressions().get(); + + //count data points function + expressionVisitor.visitAggregateFunction(new AggregateFunction(BuiltinFunctionName.COUNT.name(), new Literal(1, DataType.INTEGER)), context); + Expression countDataPointsFunction = context.popNamedParseExpressions().get(); + //count data points window + WindowExpression countDataPointsWindow = new WindowExpression( + countDataPointsFunction, + trendlineWindow.windowSpec()); + + expressionVisitor.visitLiteral(new Literal(null, DataType.NULL), context); + Expression nullLiteral = context.popNamedParseExpressions().get(); + Tuple2 nullWhenNumberOfDataPointsLessThenRequired = new Tuple2<>( + new LessThan(countDataPointsWindow, requiredNumberOfDataPoints), + nullLiteral + ); + return new CaseWhen(seq(nullWhenNumberOfDataPointsLessThenRequired), Option.apply(trendlineWindow)); + } +} diff --git a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/FlattenGenerator.scala b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/FlattenGenerator.scala new file mode 100644 index 000000000..23b545826 --- /dev/null +++ b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/FlattenGenerator.scala @@ -0,0 +1,33 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.flint.spark + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.{CollectionGenerator, CreateArray, Expression, GenericInternalRow, Inline, UnaryExpression} +import org.apache.spark.sql.types.{ArrayType, StructType} + +class FlattenGenerator(override val child: Expression) + extends Inline(child) + with CollectionGenerator { + + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + case st: StructType => TypeCheckResult.TypeCheckSuccess + case _ => super.checkInputDataTypes() + } + + override def elementSchema: StructType = child.dataType match { + case st: StructType => st + case _ => super.elementSchema + } + + override protected def withNewChildInternal(newChild: Expression): FlattenGenerator = { + newChild.dataType match { + case ArrayType(st: StructType, _) => new FlattenGenerator(newChild) + case st: StructType => withNewChildInternal(CreateArray(Seq(newChild), false)) + case _ => + throw new IllegalArgumentException(s"Unexpected input type ${newChild.dataType}") + } + } +} diff --git a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLSyntaxParser.scala b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLSyntaxParser.scala index c435af53d..ed498e98b 100644 --- a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLSyntaxParser.scala +++ 
b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLSyntaxParser.scala @@ -29,11 +29,8 @@ class PPLSyntaxParser extends Parser { object PlaneUtils { def plan(parser: PPLSyntaxParser, query: String): Statement = { - val astExpressionBuilder = new AstExpressionBuilder() - val astBuilder = new AstBuilder(astExpressionBuilder, query) - astExpressionBuilder.setAstBuilder(astBuilder) - val builder = - new AstStatementBuilder(astBuilder, AstStatementBuilder.StatementBuilderContext.builder()) - builder.visit(parser.parse(query)) + new AstStatementBuilder( + new AstBuilder(query), + AstStatementBuilder.StatementBuilderContext.builder()).visit(parser.parse(query)) } } diff --git a/ppl-spark-integration/src/test/java/org/opensearch/sql/expression/function/SerializableUdfTest.java b/ppl-spark-integration/src/test/java/org/opensearch/sql/expression/function/SerializableUdfTest.java new file mode 100644 index 000000000..3d3940730 --- /dev/null +++ b/ppl-spark-integration/src/test/java/org/opensearch/sql/expression/function/SerializableUdfTest.java @@ -0,0 +1,61 @@ +package org.opensearch.sql.expression.function; + +import org.junit.Assert; +import org.junit.Test; + +public class SerializableUdfTest { + + @Test(expected = RuntimeException.class) + public void cidrNullIpTest() { + SerializableUdf.cidrFunction.apply(null, "192.168.0.0/24"); + } + + @Test(expected = RuntimeException.class) + public void cidrEmptyIpTest() { + SerializableUdf.cidrFunction.apply("", "192.168.0.0/24"); + } + + @Test(expected = RuntimeException.class) + public void cidrNullCidrTest() { + SerializableUdf.cidrFunction.apply("192.168.0.0", null); + } + + @Test(expected = RuntimeException.class) + public void cidrEmptyCidrTest() { + SerializableUdf.cidrFunction.apply("192.168.0.0", ""); + } + + @Test(expected = RuntimeException.class) + public void cidrInvalidIpTest() { + SerializableUdf.cidrFunction.apply("xxx", "192.168.0.0/24"); + } + + @Test(expected = RuntimeException.class) + public void cidrInvalidCidrTest() { + SerializableUdf.cidrFunction.apply("192.168.0.0", "xxx"); + } + + @Test(expected = RuntimeException.class) + public void cidrMixedIpVersionTest() { + SerializableUdf.cidrFunction.apply("2001:0db8:85a3:0000:0000:8a2e:0370:7334", "192.168.0.0/24"); + SerializableUdf.cidrFunction.apply("192.168.0.0", "2001:db8::/324"); + } + + @Test(expected = RuntimeException.class) + public void cidrMixedIpVersionTestV6V4() { + SerializableUdf.cidrFunction.apply("2001:0db8:85a3:0000:0000:8a2e:0370:7334", "192.168.0.0/24"); + } + + @Test(expected = RuntimeException.class) + public void cidrMixedIpVersionTestV4V6() { + SerializableUdf.cidrFunction.apply("192.168.0.0", "2001:db8::/324"); + } + + @Test + public void cidrBasicTest() { + Assert.assertTrue(SerializableUdf.cidrFunction.apply("192.168.0.0", "192.168.0.0/24")); + Assert.assertFalse(SerializableUdf.cidrFunction.apply("10.10.0.0", "192.168.0.0/24")); + Assert.assertTrue(SerializableUdf.cidrFunction.apply("2001:0db8:85a3:0000:0000:8a2e:0370:7334", "2001:db8::/32")); + Assert.assertFalse(SerializableUdf.cidrFunction.apply("2001:0db7:85a3:0000:0000:8a2e:0370:7334", "2001:0db8::/32")); + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala index 2da93d5d8..50ef985d6 100644 ---
a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala @@ -13,7 +13,7 @@ import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Descending, GreaterThan, Literal, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, EqualTo, GreaterThan, Literal, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.command.DescribeTableCommand @@ -292,6 +292,44 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + test("Search multiple tables - with table alias") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + """ + | source=table1, table2, table3 as t + | | where t.name = 'Molly' + |""".stripMargin), + context) + + val table1 = UnresolvedRelation(Seq("table1")) + val table2 = UnresolvedRelation(Seq("table2")) + val table3 = UnresolvedRelation(Seq("table3")) + val star = UnresolvedStar(None) + val plan1 = Project( + Seq(star), + Filter( + EqualTo(UnresolvedAttribute("t.name"), Literal("Molly")), + SubqueryAlias("t", table1))) + val plan2 = Project( + Seq(star), + Filter( + EqualTo(UnresolvedAttribute("t.name"), Literal("Molly")), + SubqueryAlias("t", table2))) + val plan3 = Project( + Seq(star), + Filter( + EqualTo(UnresolvedAttribute("t.name"), Literal("Molly")), + SubqueryAlias("t", table3))) + + val expectedPlan = + Union(Seq(plan1, plan2, plan3), byName = true, allowMissingCol = true) + + comparePlans(expectedPlan, logPlan, false) + } + test("test fields + field list") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanDateTimeFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanDateTimeFunctionsTranslatorTestSuite.scala new file mode 100644 index 000000000..308b038bb --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanDateTimeFunctionsTranslatorTestSuite.scala @@ -0,0 +1,231 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentTimestamp, CurrentTimeZone, DateAddInterval, Literal, MakeInterval, NamedExpression, TimestampAdd, TimestampDiff, ToUTCTimestamp, UnaryMinus} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.Project + +class PPLLogicalPlanDateTimeFunctionsTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with 
LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test DATE_ADD") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval a = DATE_ADD(DATE('2020-08-26'), INTERVAL 2 DAY)"), + context) + + val table = UnresolvedRelation(Seq("t")) + val evalProjectList: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias( + DateAddInterval( + UnresolvedFunction("date", Seq(Literal("2020-08-26")), isDistinct = false), + MakeInterval( + Literal(0), + Literal(0), + Literal(0), + Literal(2), + Literal(0), + Literal(0), + Literal(0), + failOnError = true)), + "a")()) + val eval = Project(evalProjectList, table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), eval) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test DATE_ADD for year") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval a = DATE_ADD(DATE('2020-08-26'), INTERVAL 2 YEAR)"), + context) + + val table = UnresolvedRelation(Seq("t")) + val evalProjectList: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias( + DateAddInterval( + UnresolvedFunction("date", Seq(Literal("2020-08-26")), isDistinct = false), + MakeInterval( + Literal(2), + Literal(0), + Literal(0), + Literal(0), + Literal(0), + Literal(0), + Literal(0), + failOnError = true)), + "a")()) + val eval = Project(evalProjectList, table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), eval) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test DATE_SUB") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval a = DATE_SUB(DATE('2020-08-26'), INTERVAL 2 DAY)"), + context) + + val table = UnresolvedRelation(Seq("t")) + val evalProjectList: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias( + DateAddInterval( + UnresolvedFunction("date", Seq(Literal("2020-08-26")), isDistinct = false), + UnaryMinus( + MakeInterval( + Literal(0), + Literal(0), + Literal(0), + Literal(2), + Literal(0), + Literal(0), + Literal(0), + failOnError = true), + failOnError = true)), + "a")()) + val eval = Project(evalProjectList, table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), eval) + assert(compareByString(expectedPlan) === compareByString(logPlan)) + } + + test("test TIMESTAMPADD") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval a = TIMESTAMPADD(DAY, 17, '2000-01-01 00:00:00')"), + context) + + val table = UnresolvedRelation(Seq("t")) + val evalProjectList: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias(TimestampAdd("DAY", Literal(17), Literal("2000-01-01 00:00:00")), "a")()) + val eval = Project(evalProjectList, table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), eval) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test TIMESTAMPADD with timestamp") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | eval a = TIMESTAMPADD(DAY, 17, TIMESTAMP('2000-01-01 00:00:00'))"), + context) + + val table = UnresolvedRelation(Seq("t")) + val evalProjectList: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias( + TimestampAdd( + "DAY", + Literal(17), + UnresolvedFunction( + "timestamp", + Seq(Literal("2000-01-01 00:00:00")), + isDistinct = 
false)), + "a")()) + val eval = Project(evalProjectList, table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), eval) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test TIMESTAMPDIFF") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | eval a = TIMESTAMPDIFF(YEAR, '1997-01-01 00:00:00', '2001-03-06 00:00:00')"), + context) + + val table = UnresolvedRelation(Seq("t")) + val evalProjectList: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias( + TimestampDiff("YEAR", Literal("1997-01-01 00:00:00"), Literal("2001-03-06 00:00:00")), + "a")()) + val eval = Project(evalProjectList, table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), eval) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test TIMESTAMPDIFF with timestamp") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | eval a = TIMESTAMPDIFF(YEAR, TIMESTAMP('1997-01-01 00:00:00'), TIMESTAMP('2001-03-06 00:00:00'))"), + context) + + val table = UnresolvedRelation(Seq("t")) + val evalProjectList: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias( + TimestampDiff( + "YEAR", + UnresolvedFunction( + "timestamp", + Seq(Literal("1997-01-01 00:00:00")), + isDistinct = false), + UnresolvedFunction( + "timestamp", + Seq(Literal("2001-03-06 00:00:00")), + isDistinct = false)), + "a")()) + val eval = Project(evalProjectList, table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), eval) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test UTC_TIMESTAMP") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=t | eval a = UTC_TIMESTAMP()"), context) + + val table = UnresolvedRelation(Seq("t")) + val evalProjectList: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias(ToUTCTimestamp(CurrentTimestamp(), CurrentTimeZone()), "a")()) + val eval = Project(evalProjectList, table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), eval) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test CURRENT_TIMEZONE") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=t | eval a = CURRENT_TIMEZONE()"), context) + + val table = UnresolvedRelation(Seq("t")) + val evalProjectList: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias(UnresolvedFunction("current_timezone", Seq.empty, isDistinct = false), "a")()) + val eval = Project(evalProjectList, table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), eval) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanExpandCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanExpandCommandTranslatorTestSuite.scala new file mode 100644 index 000000000..2acaac529 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanExpandCommandTranslatorTestSuite.scala @@ -0,0 +1,281 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.FlattenGenerator +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import 
org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Explode, GeneratorOuter, Literal, RegExpExtract} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, DataFrameDropColumns, Generate, Project} +import org.apache.spark.sql.types.IntegerType + +class PPLLogicalPlanExpandCommandTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test expand only field") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=relation | expand field_with_array"), context) + + val relation = UnresolvedRelation(Seq("relation")) + val generator = Explode(UnresolvedAttribute("field_with_array")) + val generate = Generate(generator, seq(), false, None, seq(), relation) + val expectedPlan = Project(seq(UnresolvedStar(None)), generate) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("expand multi columns array table") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + s""" + | source = table + | | expand multi_valueA as multiA + | | expand multi_valueB as multiB + | """.stripMargin), + context) + + val relation = UnresolvedRelation(Seq("table")) + val generatorA = Explode(UnresolvedAttribute("multi_valueA")) + val generateA = + Generate(generatorA, seq(), false, None, seq(UnresolvedAttribute("multiA")), relation) + val dropSourceColumnA = + DataFrameDropColumns(Seq(UnresolvedAttribute("multi_valueA")), generateA) + val generatorB = Explode(UnresolvedAttribute("multi_valueB")) + val generateB = Generate( + generatorB, + seq(), + false, + None, + seq(UnresolvedAttribute("multiB")), + dropSourceColumnA) + val dropSourceColumnB = + DataFrameDropColumns(Seq(UnresolvedAttribute("multi_valueB")), generateB) + val expectedPlan = Project(seq(UnresolvedStar(None)), dropSourceColumnB) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test expand on array field which is eval array=json_array") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = table | eval array=json_array(1, 2, 3) | expand array as uid | fields uid"), + context) + + val relation = UnresolvedRelation(Seq("table")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "array")() + val project = Project(seq(UnresolvedStar(None), aliasA), relation) + val generate = Generate( + Explode(UnresolvedAttribute("array")), + seq(), + false, + None, + seq(UnresolvedAttribute("uid")), + project) + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("array")), generate) + val expectedPlan = Project(seq(UnresolvedAttribute("uid")), dropSourceColumn) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test expand only field with alias") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=relation | expand field_with_array as array_list "), + context) + + val relation = 
UnresolvedRelation(Seq("relation")) + val generate = Generate( + Explode(UnresolvedAttribute("field_with_array")), + seq(), + false, + None, + seq(UnresolvedAttribute("array_list")), + relation) + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("field_with_array")), generate) + val expectedPlan = Project(seq(UnresolvedStar(None)), dropSourceColumn) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test expand and stats") { + val context = new CatalystPlanContext + val query = + "source = table | expand employee | stats max(salary) as max by state, company" + val logPlan = + planTransformer.visit(plan(pplParser, query), context) + val table = UnresolvedRelation(Seq("table")) + val generate = + Generate(Explode(UnresolvedAttribute("employee")), seq(), false, None, seq(), table) + val average = Alias( + UnresolvedFunction(seq("MAX"), seq(UnresolvedAttribute("salary")), false, None, false), + "max")() + val state = Alias(UnresolvedAttribute("state"), "state")() + val company = Alias(UnresolvedAttribute("company"), "company")() + val groupingState = Alias(UnresolvedAttribute("state"), "state")() + val groupingCompany = Alias(UnresolvedAttribute("company"), "company")() + val aggregate = + Aggregate(Seq(groupingState, groupingCompany), Seq(average, state, company), generate) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test expand and stats with alias") { + val context = new CatalystPlanContext + val query = + "source = table | expand employee as workers | stats max(salary) as max by state, company" + val logPlan = + planTransformer.visit(plan(pplParser, query), context) + val table = UnresolvedRelation(Seq("table")) + val generate = Generate( + Explode(UnresolvedAttribute("employee")), + seq(), + false, + None, + seq(UnresolvedAttribute("workers")), + table) + val dropSourceColumn = DataFrameDropColumns(Seq(UnresolvedAttribute("employee")), generate) + val average = Alias( + UnresolvedFunction(seq("MAX"), seq(UnresolvedAttribute("salary")), false, None, false), + "max")() + val state = Alias(UnresolvedAttribute("state"), "state")() + val company = Alias(UnresolvedAttribute("company"), "company")() + val groupingState = Alias(UnresolvedAttribute("state"), "state")() + val groupingCompany = Alias(UnresolvedAttribute("company"), "company")() + val aggregate = Aggregate( + Seq(groupingState, groupingCompany), + Seq(average, state, company), + dropSourceColumn) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test expand and eval") { + val context = new CatalystPlanContext + val query = "source = table | expand employee | eval bonus = salary * 3" + val logPlan = planTransformer.visit(plan(pplParser, query), context) + val table = UnresolvedRelation(Seq("table")) + val generate = + Generate(Explode(UnresolvedAttribute("employee")), seq(), false, None, seq(), table) + val bonusProject = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "*", + Seq(UnresolvedAttribute("salary"), Literal(3, IntegerType)), + isDistinct = false), + "bonus")()), + generate) + val expectedPlan = Project(Seq(UnresolvedStar(None)), bonusProject) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test expand and eval with fields and alias") { + val context = new CatalystPlanContext + val query = + "source = table | expand employee as worker | 
eval bonus = salary * 3 | fields worker, bonus " + val logPlan = planTransformer.visit(plan(pplParser, query), context) + val table = UnresolvedRelation(Seq("table")) + val generate = Generate( + Explode(UnresolvedAttribute("employee")), + seq(), + false, + None, + seq(UnresolvedAttribute("worker")), + table) + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("employee")), generate) + val bonusProject = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "*", + Seq(UnresolvedAttribute("salary"), Literal(3, IntegerType)), + isDistinct = false), + "bonus")()), + dropSourceColumn) + val expectedPlan = + Project(Seq(UnresolvedAttribute("worker"), UnresolvedAttribute("bonus")), bonusProject) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test expand and parse and fields") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=table | expand employee | parse description '(?<email>.+@.+)' | fields employee, email"), + context) + val table = UnresolvedRelation(Seq("table")) + val generator = + Generate(Explode(UnresolvedAttribute("employee")), seq(), false, None, seq(), table) + val emailAlias = + Alias( + RegExpExtract(UnresolvedAttribute("description"), Literal("(?<email>.+@.+)"), Literal(1)), + "email")() + val parseProject = Project( + Seq(UnresolvedAttribute("description"), emailAlias, UnresolvedStar(None)), + generator) + val expectedPlan = + Project(Seq(UnresolvedAttribute("employee"), UnresolvedAttribute("email")), parseProject) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test expand and parse and flatten ") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=relation | expand employee | parse description '(?<email>.+@.+)' | flatten roles "), + context) + val table = UnresolvedRelation(Seq("relation")) + val generateEmployee = + Generate(Explode(UnresolvedAttribute("employee")), seq(), false, None, seq(), table) + val emailAlias = + Alias( + RegExpExtract(UnresolvedAttribute("description"), Literal("(?<email>.+@.+)"), Literal(1)), + "email")() + val parseProject = Project( + Seq(UnresolvedAttribute("description"), emailAlias, UnresolvedStar(None)), + generateEmployee) + val generateRoles = Generate( + GeneratorOuter(new FlattenGenerator(UnresolvedAttribute("roles"))), + seq(), + true, + None, + seq(), + parseProject) + val dropSourceColumnRoles = + DataFrameDropColumns(Seq(UnresolvedAttribute("roles")), generateRoles) + val expectedPlan = Project(Seq(UnresolvedStar(None)), dropSourceColumnRoles) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFlattenCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFlattenCommandTranslatorTestSuite.scala new file mode 100644 index 000000000..58a6c04b3 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFlattenCommandTranslatorTestSuite.scala @@ -0,0 +1,156 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.FlattenGenerator +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq
+import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Descending, GeneratorOuter, Literal, NullsLast, RegExpExtract, SortOrder} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, DataFrameDropColumns, Generate, GlobalLimit, LocalLimit, Project, Sort} +import org.apache.spark.sql.types.IntegerType + +class PPLLogicalPlanFlattenCommandTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test flatten only field") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=relation | flatten field_with_array"), + context) + + val relation = UnresolvedRelation(Seq("relation")) + val flattenGenerator = new FlattenGenerator(UnresolvedAttribute("field_with_array")) + val outerGenerator = GeneratorOuter(flattenGenerator) + val generate = Generate(outerGenerator, seq(), true, None, seq(), relation) + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("field_with_array")), generate) + val expectedPlan = Project(seq(UnresolvedStar(None)), dropSourceColumn) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test flatten and stats") { + val context = new CatalystPlanContext + val query = + "source = relation | fields state, company, employee | flatten employee | fields state, company, salary | stats max(salary) as max by state, company" + val logPlan = + planTransformer.visit(plan(pplParser, query), context) + val table = UnresolvedRelation(Seq("relation")) + val projectStateCompanyEmployee = + Project( + Seq( + UnresolvedAttribute("state"), + UnresolvedAttribute("company"), + UnresolvedAttribute("employee")), + table) + val generate = Generate( + GeneratorOuter(new FlattenGenerator(UnresolvedAttribute("employee"))), + seq(), + true, + None, + seq(), + projectStateCompanyEmployee) + val dropSourceColumn = DataFrameDropColumns(Seq(UnresolvedAttribute("employee")), generate) + val projectStateCompanySalary = Project( + Seq( + UnresolvedAttribute("state"), + UnresolvedAttribute("company"), + UnresolvedAttribute("salary")), + dropSourceColumn) + val average = Alias( + UnresolvedFunction(seq("MAX"), seq(UnresolvedAttribute("salary")), false, None, false), + "max")() + val state = Alias(UnresolvedAttribute("state"), "state")() + val company = Alias(UnresolvedAttribute("company"), "company")() + val groupingState = Alias(UnresolvedAttribute("state"), "state")() + val groupingCompany = Alias(UnresolvedAttribute("company"), "company")() + val aggregate = Aggregate( + Seq(groupingState, groupingCompany), + Seq(average, state, company), + projectStateCompanySalary) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test flatten and eval") { + val context = new CatalystPlanContext + val query = "source = relation | flatten employee | eval bonus = salary * 3" + val logPlan = planTransformer.visit(plan(pplParser, query), context) + val table = UnresolvedRelation(Seq("relation")) + val generate = Generate( + GeneratorOuter(new FlattenGenerator(UnresolvedAttribute("employee"))), + seq(), 
+ true, + None, + seq(), + table) + val dropSourceColumn = DataFrameDropColumns(Seq(UnresolvedAttribute("employee")), generate) + val bonusProject = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "*", + Seq(UnresolvedAttribute("salary"), Literal(3, IntegerType)), + isDistinct = false), + "bonus")()), + dropSourceColumn) + val expectedPlan = Project(Seq(UnresolvedStar(None)), bonusProject) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test flatten and parse and flatten") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=relation | flatten employee | parse description '(?.+@.+)' | flatten roles"), + context) + val table = UnresolvedRelation(Seq("relation")) + val generateEmployee = Generate( + GeneratorOuter(new FlattenGenerator(UnresolvedAttribute("employee"))), + seq(), + true, + None, + seq(), + table) + val dropSourceColumnEmployee = + DataFrameDropColumns(Seq(UnresolvedAttribute("employee")), generateEmployee) + val emailAlias = + Alias( + RegExpExtract(UnresolvedAttribute("description"), Literal("(?.+@.+)"), Literal(1)), + "email")() + val parseProject = Project( + Seq(UnresolvedAttribute("description"), emailAlias, UnresolvedStar(None)), + dropSourceColumnEmployee) + val generateRoles = Generate( + GeneratorOuter(new FlattenGenerator(UnresolvedAttribute("roles"))), + seq(), + true, + None, + seq(), + parseProject) + val dropSourceColumnRoles = + DataFrameDropColumns(Seq(UnresolvedAttribute("roles")), generateRoles) + val expectedPlan = Project(Seq(UnresolvedStar(None)), dropSourceColumnRoles) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala index 3ceff7735..f4ed397e3 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala @@ -271,9 +271,9 @@ class PPLLogicalPlanJoinTranslatorTestSuite pplParser, s""" | source = $testTable1 - | | inner JOIN left = l,right = r ON l.id = r.id $testTable2 - | | left JOIN left = l,right = r ON l.name = r.name $testTable3 - | | cross JOIN left = l,right = r $testTable4 + | | inner JOIN left = l right = r ON l.id = r.id $testTable2 + | | left JOIN left = l right = r ON l.name = r.name $testTable3 + | | cross JOIN left = l right = r $testTable4 | """.stripMargin) val logicalPlan = planTransformer.visit(logPlan, context) val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) @@ -443,17 +443,17 @@ class PPLLogicalPlanJoinTranslatorTestSuite s""" | source = $testTable1 | | head 10 - | | inner JOIN left = l,right = r ON l.id = r.id + | | inner JOIN left = l right = r ON l.id = r.id | [ | source = $testTable2 | | where id > 10 | ] - | | left JOIN left = l,right = r ON l.name = r.name + | | left JOIN left = l right = r ON l.name = r.name | [ | source = $testTable3 | | fields id | ] - | | cross JOIN left = l,right = r + | | cross JOIN left = l right = r | [ | source = $testTable4 | | sort id @@ -565,4 +565,284 @@ class PPLLogicalPlanJoinTranslatorTestSuite val expectedPlan = Project(Seq(UnresolvedStar(None)), sort) comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } + + 
test("test multiple joins with table alias") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = table1 as t1 + | | JOIN ON t1.id = t2.id + | [ + | source = table2 as t2 + | ] + | | JOIN ON t2.id = t3.id + | [ + | source = table3 as t3 + | ] + | | JOIN ON t3.id = t4.id + | [ + | source = table4 as t4 + | ] + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("table1")) + val table2 = UnresolvedRelation(Seq("table2")) + val table3 = UnresolvedRelation(Seq("table3")) + val table4 = UnresolvedRelation(Seq("table4")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.id"), UnresolvedAttribute("t2.id"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + Some(EqualTo(UnresolvedAttribute("t2.id"), UnresolvedAttribute("t3.id"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + SubqueryAlias("t4", table4), + Inner, + Some(EqualTo(UnresolvedAttribute("t3.id"), UnresolvedAttribute("t4.id"))), + JoinHint.NONE) + val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with table and subquery alias") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = table1 as t1 + | | JOIN left = l right = r ON t1.id = t2.id + | [ + | source = table2 as t2 + | ] + | | JOIN left = l right = r ON t2.id = t3.id + | [ + | source = table3 as t3 + | ] + | | JOIN left = l right = r ON t3.id = t4.id + | [ + | source = table4 as t4 + | ] + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("table1")) + val table2 = UnresolvedRelation(Seq("table2")) + val table3 = UnresolvedRelation(Seq("table3")) + val table4 = UnresolvedRelation(Seq("table4")) + val joinPlan1 = Join( + SubqueryAlias("l", SubqueryAlias("t1", table1)), + SubqueryAlias("r", SubqueryAlias("t2", table2)), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.id"), UnresolvedAttribute("t2.id"))), + JoinHint.NONE) + val joinPlan2 = Join( + SubqueryAlias("l", joinPlan1), + SubqueryAlias("r", SubqueryAlias("t3", table3)), + Inner, + Some(EqualTo(UnresolvedAttribute("t2.id"), UnresolvedAttribute("t3.id"))), + JoinHint.NONE) + val joinPlan3 = Join( + SubqueryAlias("l", joinPlan2), + SubqueryAlias("r", SubqueryAlias("t4", table4)), + Inner, + Some(EqualTo(UnresolvedAttribute("t3.id"), UnresolvedAttribute("t4.id"))), + JoinHint.NONE) + val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins without table aliases") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = table1 + | | JOIN ON table1.id = table2.id table2 + | | JOIN ON table1.id = table3.id table3 + | | JOIN ON table2.id = table4.id table4 + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("table1")) + val table2 = UnresolvedRelation(Seq("table2")) + val table3 = UnresolvedRelation(Seq("table3")) + val table4 = UnresolvedRelation(Seq("table4")) + val joinPlan1 = Join( + table1, + table2, + Inner, + Some(EqualTo(UnresolvedAttribute("table1.id"), UnresolvedAttribute("table2.id"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + table3, 
+ Inner, + Some(EqualTo(UnresolvedAttribute("table1.id"), UnresolvedAttribute("table3.id"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + table4, + Inner, + Some(EqualTo(UnresolvedAttribute("table2.id"), UnresolvedAttribute("table4.id"))), + JoinHint.NONE) + val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with part subquery aliases") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = table1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name table2 + | | JOIN right = t3 ON t1.name = t3.name table3 + | | JOIN right = t4 ON t2.name = t4.name table4 + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("table1")) + val table2 = UnresolvedRelation(Seq("table2")) + val table3 = UnresolvedRelation(Seq("table3")) + val table4 = UnresolvedRelation(Seq("table4")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t3.name"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + SubqueryAlias("t4", table4), + Inner, + Some(EqualTo(UnresolvedAttribute("t2.name"), UnresolvedAttribute("t4.name"))), + JoinHint.NONE) + val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with self join 1") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 + | | JOIN right = t3 ON t1.name = t3.name $testTable3 + | | JOIN right = t4 ON t1.name = t4.name $testTable1 + | | fields t1.name, t2.name, t3.name, t4.name + | """.stripMargin) + + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t3.name"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + SubqueryAlias("t4", table1), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t4.name"))), + JoinHint.NONE) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("t1.name"), + UnresolvedAttribute("t2.name"), + UnresolvedAttribute("t3.name"), + UnresolvedAttribute("t4.name")), + joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with self join 2") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 + | | JOIN right = t3 ON t1.name = t3.name $testTable3 + | | JOIN ON t1.name = t4.name + | [ + | source = $testTable1 + | ] as 
t4 + | | fields t1.name, t2.name, t3.name, t4.name + | """.stripMargin) + + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t3.name"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + SubqueryAlias("t4", table1), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t4.name"))), + JoinHint.NONE) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("t1.name"), + UnresolvedAttribute("t2.name"), + UnresolvedAttribute("t3.name"), + UnresolvedAttribute("t4.name")), + joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test side alias will override the subquery alias") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name [ source = $testTable2 as ttt ] as tt + | | fields t1.name, t2.name + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", SubqueryAlias("tt", SubqueryAlias("ttt", table2))), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val expectedPlan = + Project(Seq(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name")), joinPlan1) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala index f5dfc4ec8..6193bc43f 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala @@ -11,7 +11,7 @@ import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Literal} +import org.apache.spark.sql.catalyst.expressions.{EqualTo, Literal} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project} @@ -48,7 +48,7 @@ class PPLLogicalPlanJsonFunctionsTranslatorTestSuite val context = new CatalystPlanContext val logPlan = planTransformer.visit( - plan(pplParser, """source=t a = json(json_object('key', array(1, 2, 3)))"""), + plan(pplParser, """source=t a = to_json_string(json_object('key', array(1, 2, 3)))"""), context) val table = 
UnresolvedRelation(Seq("t")) @@ -97,7 +97,9 @@ class PPLLogicalPlanJsonFunctionsTranslatorTestSuite val context = new CatalystPlanContext val logPlan = planTransformer.visit( - plan(pplParser, """source=t a = json(json_object('key', json_array(1, 2, 3)))"""), + plan( + pplParser, + """source=t a = to_json_string(json_object('key', json_array(1, 2, 3)))"""), context) val table = UnresolvedRelation(Seq("t")) @@ -139,25 +141,21 @@ class PPLLogicalPlanJsonFunctionsTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } - test("test json_array_length(json_array())") { + test("test array_length(json_array())") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( - plan(pplParser, """source=t a = json_array_length(json_array(1,2,3))"""), + plan(pplParser, """source=t a = array_length(json_array(1,2,3))"""), context) val table = UnresolvedRelation(Seq("t")) val jsonFunc = UnresolvedFunction( - "json_array_length", + "array_size", Seq( UnresolvedFunction( - "to_json", - Seq( - UnresolvedFunction( - "array", - Seq(Literal(1), Literal(2), Literal(3)), - isDistinct = false)), + "array", + Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false)), isDistinct = false) val filterExpr = EqualTo(UnresolvedAttribute("a"), jsonFunc) diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanLambdaFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanLambdaFunctionsTranslatorTestSuite.scala new file mode 100644 index 000000000..9c3c1c8a0 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanLambdaFunctionsTranslatorTestSuite.scala @@ -0,0 +1,211 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, GreaterThan, LambdaFunction, Literal, UnresolvedNamedLambdaVariable} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.Project + +class PPLLogicalPlanLambdaFunctionsTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test forall()") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + """source=t | eval a = json_array(1, 2, 3), b = forall(a, x -> x > 0)""".stripMargin), + context) + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "a")() + val lambda = LambdaFunction( + GreaterThan(UnresolvedNamedLambdaVariable(seq("x")), Literal(0)), + Seq(UnresolvedNamedLambdaVariable(seq("x")))) + val aliasB = + Alias(UnresolvedFunction("forall", Seq(UnresolvedAttribute("a"), lambda), false), "b")() + val evalProject = Project(Seq(UnresolvedStar(None), aliasA, aliasB), table) + val projectList = 
Seq(UnresolvedStar(None))
+ val expectedPlan = Project(projectList, evalProject)
+ comparePlans(expectedPlan, logPlan, false)
+ }
+
+ test("test exists()") {
+ val context = new CatalystPlanContext
+ val logPlan =
+ planTransformer.visit(
+ plan(
+ pplParser,
+ """source=t | eval a = json_array(1, 2, 3), b = exists(a, x -> x > 0)""".stripMargin),
+ context)
+ val table = UnresolvedRelation(Seq("t"))
+ val jsonFunc =
+ UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false)
+ val aliasA = Alias(jsonFunc, "a")()
+ val lambda = LambdaFunction(
+ GreaterThan(UnresolvedNamedLambdaVariable(seq("x")), Literal(0)),
+ Seq(UnresolvedNamedLambdaVariable(seq("x"))))
+ val aliasB =
+ Alias(UnresolvedFunction("exists", Seq(UnresolvedAttribute("a"), lambda), false), "b")()
+ val evalProject = Project(Seq(UnresolvedStar(None), aliasA, aliasB), table)
+ val projectList = Seq(UnresolvedStar(None))
+ val expectedPlan = Project(projectList, evalProject)
+ comparePlans(expectedPlan, logPlan, false)
+ }
+
+ test("test filter()") {
+ val context = new CatalystPlanContext
+ val logPlan =
+ planTransformer.visit(
+ plan(
+ pplParser,
+ """source=t | eval a = json_array(1, 2, 3), b = filter(a, x -> x > 0)""".stripMargin),
+ context)
+ val table = UnresolvedRelation(Seq("t"))
+ val jsonFunc =
+ UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false)
+ val aliasA = Alias(jsonFunc, "a")()
+ val lambda = LambdaFunction(
+ GreaterThan(UnresolvedNamedLambdaVariable(seq("x")), Literal(0)),
+ Seq(UnresolvedNamedLambdaVariable(seq("x"))))
+ val aliasB =
+ Alias(UnresolvedFunction("filter", Seq(UnresolvedAttribute("a"), lambda), false), "b")()
+ val evalProject = Project(Seq(UnresolvedStar(None), aliasA, aliasB), table)
+ val projectList = Seq(UnresolvedStar(None))
+ val expectedPlan = Project(projectList, evalProject)
+ comparePlans(expectedPlan, logPlan, false)
+ }
+
+ test("test transform()") {
+ val context = new CatalystPlanContext
+ // test single argument of lambda
+ val logPlan =
+ planTransformer.visit(
+ plan(
+ pplParser,
+ """source=t | eval a = json_array(1, 2, 3), b = transform(a, x -> x + 1)""".stripMargin),
+ context)
+ val table = UnresolvedRelation(Seq("t"))
+ val jsonFunc =
+ UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false)
+ val aliasA = Alias(jsonFunc, "a")()
+ val lambda = LambdaFunction(
+ UnresolvedFunction("+", Seq(UnresolvedNamedLambdaVariable(seq("x")), Literal(1)), false),
+ Seq(UnresolvedNamedLambdaVariable(seq("x"))))
+ val aliasB =
+ Alias(UnresolvedFunction("transform", Seq(UnresolvedAttribute("a"), lambda), false), "b")()
+ val evalProject = Project(Seq(UnresolvedStar(None), aliasA, aliasB), table)
+ val projectList = Seq(UnresolvedStar(None))
+ val expectedPlan = Project(projectList, evalProject)
+ comparePlans(expectedPlan, logPlan, false)
+ }
+
+ test("test transform() - test binary lambda") {
+ val context = new CatalystPlanContext
+ val logPlan =
+ planTransformer.visit(
+ plan(
+ pplParser,
+ """source=t | eval a = json_array(1, 2, 3), b = transform(a, (x, y) -> x + y)""".stripMargin),
+ context)
+ val table = UnresolvedRelation(Seq("t"))
+ val jsonFunc =
+ UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false)
+ val aliasA = Alias(jsonFunc, "a")()
+ val lambda = LambdaFunction(
+ UnresolvedFunction(
+ "+",
+ Seq(UnresolvedNamedLambdaVariable(seq("x")), UnresolvedNamedLambdaVariable(seq("y"))),
+ false),
+ Seq(UnresolvedNamedLambdaVariable(seq("x")), 
UnresolvedNamedLambdaVariable(seq("y")))) + val aliasB = + Alias(UnresolvedFunction("transform", Seq(UnresolvedAttribute("a"), lambda), false), "b")() + val expectedPlan = Project( + Seq(UnresolvedStar(None)), + Project(Seq(UnresolvedStar(None), aliasA, aliasB), table)) + comparePlans(expectedPlan, logPlan, false) + } + + test("test reduce() - without finish lambda") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + """source=t | eval a = json_array(1, 2, 3), b = reduce(a, 0, (x, y) -> x + y)""".stripMargin), + context) + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "a")() + val mergeLambda = LambdaFunction( + UnresolvedFunction( + "+", + Seq(UnresolvedNamedLambdaVariable(seq("x")), UnresolvedNamedLambdaVariable(seq("y"))), + false), + Seq(UnresolvedNamedLambdaVariable(seq("x")), UnresolvedNamedLambdaVariable(seq("y")))) + val aliasB = + Alias( + UnresolvedFunction( + "reduce", + Seq(UnresolvedAttribute("a"), Literal(0), mergeLambda), + false), + "b")() + val expectedPlan = Project( + Seq(UnresolvedStar(None)), + Project(Seq(UnresolvedStar(None), aliasA, aliasB), table)) + comparePlans(expectedPlan, logPlan, false) + } + + test("test reduce() - with finish lambda") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + """source=t | eval a = json_array(1, 2, 3), b = reduce(a, 0, (x, y) -> x + y, x -> x * 10)""".stripMargin), + context) + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "a")() + val mergeLambda = LambdaFunction( + UnresolvedFunction( + "+", + Seq(UnresolvedNamedLambdaVariable(seq("x")), UnresolvedNamedLambdaVariable(seq("y"))), + false), + Seq(UnresolvedNamedLambdaVariable(seq("x")), UnresolvedNamedLambdaVariable(seq("y")))) + val finishLambda = LambdaFunction( + UnresolvedFunction("*", Seq(UnresolvedNamedLambdaVariable(seq("x")), Literal(10)), false), + Seq(UnresolvedNamedLambdaVariable(seq("x")))) + val aliasB = + Alias( + UnresolvedFunction( + "reduce", + Seq(UnresolvedAttribute("a"), Literal(0), mergeLambda, finishLambda), + false), + "b")() + val expectedPlan = Project( + Seq(UnresolvedStar(None)), + Project(Seq(UnresolvedStar(None), aliasA, aliasB), table)) + comparePlans(expectedPlan, logPlan, false) + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseCidrmatchTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseCidrmatchTestSuite.scala new file mode 100644 index 000000000..213f201cc --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseCidrmatchTestSuite.scala @@ -0,0 +1,155 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.expression.function.SerializableUdf +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, 
UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, EqualTo, GreaterThan, Literal, NullsFirst, NullsLast, RegExpExtract, ScalaUDF, SortOrder} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types.DataTypes + +class PPLLogicalPlanParseCidrmatchTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test cidrmatch for ipv4 for 192.168.1.0/24") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | where isV6 = false and isValid = true and cidrmatch(ipAddress, '192.168.1.0/24')"), + context) + + val ipAddress = UnresolvedAttribute("ipAddress") + val cidrExpression = Literal("192.168.1.0/24") + + val filterIpv6 = EqualTo(UnresolvedAttribute("isV6"), Literal(false)) + val filterIsValid = EqualTo(UnresolvedAttribute("isValid"), Literal(true)) + val cidr = ScalaUDF( + SerializableUdf.cidrFunction, + DataTypes.BooleanType, + seq(ipAddress, cidrExpression), + seq(), + Option.empty, + Option.apply("cidr"), + false, + true) + + val expectedPlan = Project( + Seq(UnresolvedStar(None)), + Filter(And(And(filterIpv6, filterIsValid), cidr), UnresolvedRelation(Seq("t")))) + assert(compareByString(expectedPlan) === compareByString(logPlan)) + } + + test("test cidrmatch for ipv6 for 2003:db8::/32") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | where isV6 = true and isValid = false and cidrmatch(ipAddress, '2003:db8::/32')"), + context) + + val ipAddress = UnresolvedAttribute("ipAddress") + val cidrExpression = Literal("2003:db8::/32") + + val filterIpv6 = EqualTo(UnresolvedAttribute("isV6"), Literal(true)) + val filterIsValid = EqualTo(UnresolvedAttribute("isValid"), Literal(false)) + val cidr = ScalaUDF( + SerializableUdf.cidrFunction, + DataTypes.BooleanType, + seq(ipAddress, cidrExpression), + seq(), + Option.empty, + Option.apply("cidr"), + false, + true) + + val expectedPlan = Project( + Seq(UnresolvedStar(None)), + Filter(And(And(filterIpv6, filterIsValid), cidr), UnresolvedRelation(Seq("t")))) + assert(compareByString(expectedPlan) === compareByString(logPlan)) + } + + test("test cidrmatch for ipv6 for 2003:db8::/32 with ip field projected") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | where isV6 = true and cidrmatch(ipAddress, '2003:db8::/32') | fields ip"), + context) + + val ipAddress = UnresolvedAttribute("ipAddress") + val cidrExpression = Literal("2003:db8::/32") + + val filterIpv6 = EqualTo(UnresolvedAttribute("isV6"), Literal(true)) + val cidr = ScalaUDF( + SerializableUdf.cidrFunction, + DataTypes.BooleanType, + seq(ipAddress, cidrExpression), + seq(), + Option.empty, + Option.apply("cidr"), + false, + true) + + val expectedPlan = Project( + Seq(UnresolvedAttribute("ip")), + Filter(And(filterIpv6, cidr), UnresolvedRelation(Seq("t")))) + assert(compareByString(expectedPlan) === compareByString(logPlan)) + } + + test("test cidrmatch for ipv6 for 2003:db8::/32 with ip field bool respond for each ip") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | where isV6 = true | eval inRange = 
case(cidrmatch(ipAddress, '2003:db8::/32'), 'in' else 'out') | fields ip, inRange"), + context) + + val ipAddress = UnresolvedAttribute("ipAddress") + val cidrExpression = Literal("2003:db8::/32") + + val filterIpv6 = EqualTo(UnresolvedAttribute("isV6"), Literal(true)) + val filterClause = Filter(filterIpv6, UnresolvedRelation(Seq("t"))) + val cidr = ScalaUDF( + SerializableUdf.cidrFunction, + DataTypes.BooleanType, + seq(ipAddress, cidrExpression), + seq(), + Option.empty, + Option.apply("cidr"), + false, + true) + + val equalTo = EqualTo(Literal(true), cidr) + val caseFunction = CaseWhen(Seq((equalTo, Literal("in"))), Literal("out")) + val aliasStatusCategory = Alias(caseFunction, "inRange")() + val evalProjectList = Seq(UnresolvedStar(None), aliasStatusCategory) + val evalProject = Project(evalProjectList, filterClause) + + val expectedPlan = + Project(Seq(UnresolvedAttribute("ip"), UnresolvedAttribute("inRange")), evalProject) + + assert(compareByString(expectedPlan) === compareByString(logPlan)) + } + +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala new file mode 100644 index 000000000..d22750ee0 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala @@ -0,0 +1,135 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CaseWhen, CurrentRow, Descending, LessThan, Literal, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Project, Sort} + +class PPLLogicalPlanTrendlineCommandTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test trendline") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=relation | trendline sma(3, age)"), context) + + val table = UnresolvedRelation(Seq("relation")) + val ageField = UnresolvedAttribute("age") + val countWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val smaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val caseWhen = CaseWhen(Seq((LessThan(countWindow, Literal(3)), Literal(null))), smaWindow) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(caseWhen, "age_trendline")()) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, table)) + comparePlans(logPlan, expectedPlan, 
checkAnalysis = false) + } + + test("test trendline with sort") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=relation | trendline sort age sma(3, age)"), + context) + + val table = UnresolvedRelation(Seq("relation")) + val ageField = UnresolvedAttribute("age") + val sort = Sort(Seq(SortOrder(ageField, Ascending)), global = true, table) + val countWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val smaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val caseWhen = CaseWhen(Seq((LessThan(countWindow, Literal(3)), Literal(null))), smaWindow) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(caseWhen, "age_trendline")()) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sort)) + comparePlans(logPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline with sort and alias") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=relation | trendline sort - age sma(3, age) as age_sma"), + context) + + val table = UnresolvedRelation(Seq("relation")) + val ageField = UnresolvedAttribute("age") + val sort = Sort(Seq(SortOrder(ageField, Descending)), global = true, table) + val countWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val smaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val caseWhen = CaseWhen(Seq((LessThan(countWindow, Literal(3)), Literal(null))), smaWindow) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(caseWhen, "age_sma")()) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sort)) + comparePlans(logPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline with multiple trendline sma commands") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=relation | trendline sort + age sma(2, age) as two_points_sma sma(3, age) | fields name, age, two_points_sma, age_trendline"), + context) + + val table = UnresolvedRelation(Seq("relation")) + val nameField = UnresolvedAttribute("name") + val ageField = UnresolvedAttribute("age") + val ageTwoPointsSmaField = UnresolvedAttribute("two_points_sma") + val ageTrendlineField = UnresolvedAttribute("age_trendline") + val sort = Sort(Seq(SortOrder(ageField, Ascending)), global = true, table) + val twoPointsCountWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val twoPointsSmaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val threePointsCountWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), 
SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val threePointsSmaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val twoPointsCaseWhen = CaseWhen( + Seq((LessThan(twoPointsCountWindow, Literal(2)), Literal(null))), + twoPointsSmaWindow) + val threePointsCaseWhen = CaseWhen( + Seq((LessThan(threePointsCountWindow, Literal(3)), Literal(null))), + threePointsSmaWindow) + val trendlineProjectList = Seq( + UnresolvedStar(None), + Alias(twoPointsCaseWhen, "two_points_sma")(), + Alias(threePointsCaseWhen, "age_trendline")()) + val expectedPlan = Project( + Seq(nameField, ageField, ageTwoPointsSmaField, ageTrendlineField), + Project(trendlineProjectList, sort)) + comparePlans(logPlan, expectedPlan, checkAnalysis = false) + } +} diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 0978e6898..ef0e76557 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -314,7 +314,13 @@ object FlintREPL extends Logging with FlintJobExecutor { val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-query", 1) implicit val executionContext = ExecutionContext.fromExecutor(threadPool) val queryResultWriter = instantiateQueryResultWriter(spark, commandContext) - var futurePrepareQueryExecution: Future[Either[String, Unit]] = null + + val statementsExecutionManager = + instantiateStatementExecutionManager(commandContext) + + var futurePrepareQueryExecution: Future[Either[String, Unit]] = Future { + statementsExecutionManager.prepareStatementExecution() + } try { logInfo(s"""Executing session with sessionId: ${sessionId}""") @@ -324,12 +330,6 @@ object FlintREPL extends Logging with FlintJobExecutor { var lastCanPickCheckTime = 0L while (currentTimeProvider .currentEpochMillis() - lastActivityTime <= commandContext.inactivityLimitMillis && canPickUpNextStatement) { - val statementsExecutionManager = - instantiateStatementExecutionManager(commandContext) - - futurePrepareQueryExecution = Future { - statementsExecutionManager.prepareStatementExecution() - } try { val commandState = CommandState( diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala index 432d6df11..09a1b3c1e 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala @@ -6,7 +6,7 @@ package org.apache.spark.sql import org.opensearch.flint.common.model.FlintStatement -import org.opensearch.flint.core.storage.OpenSearchUpdater +import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} import org.opensearch.search.sort.SortOrder import org.apache.spark.internal.Logging @@ -29,8 +29,8 @@ class StatementExecutionManagerImpl(commandContext: CommandContext) context("flintSessionIndexUpdater").asInstanceOf[OpenSearchUpdater] // Using one reader client within same session will cause concurrency issue. 
- // To resolve this move the reader creation and getNextStatement method to mirco-batch level
- private val flintReader = createOpenSearchQueryReader()
+ // To resolve this, move the reader creation into the getNextStatement method at the micro-batch level
+ private var currentReader: Option[FlintReader] = None
 override def prepareStatementExecution(): Either[String, Unit] = {
 checkAndCreateIndex(osClient, resultIndex)
@@ -39,12 +39,17 @@ class StatementExecutionManagerImpl(commandContext: CommandContext)
 flintSessionIndexUpdater.update(statement.statementId, FlintStatement.serialize(statement))
 }
 override def terminateStatementExecution(): Unit = {
- flintReader.close()
+ currentReader.foreach(_.close())
+ currentReader = None
 }
 override def getNextStatement(): Option[FlintStatement] = {
- if (flintReader.hasNext) {
- val rawStatement = flintReader.next()
+ if (currentReader.isEmpty) {
+ currentReader = Some(createOpenSearchQueryReader())
+ }
+
+ if (currentReader.get.hasNext) {
+ val rawStatement = currentReader.get.next()
 val flintStatement = FlintStatement.deserialize(rawStatement)
 logInfo(s"Next statement to execute: $flintStatement")
 Some(flintStatement)
@@ -100,7 +105,6 @@ class StatementExecutionManagerImpl(commandContext: CommandContext)
 |          ]
 |        }
 |}""".stripMargin
- val flintReader = osClient.createQueryReader(sessionIndex, dsl, "submitTime", SortOrder.ASC)
- flintReader
+ osClient.createQueryReader(sessionIndex, dsl, "submitTime", SortOrder.ASC)
 }
 }
diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
index 5eeccce73..07ed94bdc 100644
--- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
+++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
@@ -1387,7 +1387,8 @@ class FlintREPLTest
 val expectedCalls =
 Math.ceil(inactivityLimit.toDouble / DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY).toInt
- verify(mockOSClient, Mockito.atMost(expectedCalls)).getIndexMetadata(*)
+ verify(mockOSClient, times(1)).getIndexMetadata(*)
+ verify(mockOSClient, Mockito.atMost(expectedCalls)).createQueryReader(*, *, *, *)
 }
 val testCases = Table(