diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 431ff1199..ad02ac78a 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,2 +1,2 @@ # This should match the owning team set up in https://github.com/orgs/opensearch-project/teams -* @joshuali925 @dai-chen @rupal-bq @mengweieric @vamsi-amazon @penghuo @seankao-az @anirudha @kaituo @YANG-DB @LantaoJin +* @joshuali925 @dai-chen @rupal-bq @mengweieric @vamsi-amazon @penghuo @seankao-az @anirudha @kaituo @YANG-DB @noCharger @LantaoJin @ykmr1224 diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 000000000..5c7240c8d --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,15 @@ +### Description +_Describe what this change achieves._ + +### Related Issues +_List any issues this PR will resolve, e.g. Resolves [...]._ + +### Check List +- [ ] Updated documentation (docs/ppl-lang/README.md) +- [ ] Implemented unit tests +- [ ] Implemented tests for combination with other commands +- [ ] New added source code should include a copyright header +- [ ] Commits are signed per the DCO using `--signoff` + +By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license. +For more information on following Developer Certificate of Origin and signing off your commits, please check [here](https://github.com/opensearch-project/sql/blob/main/CONTRIBUTING.md#developer-certificate-of-origin). diff --git a/.github/workflows/test-and-build-workflow.yml b/.github/workflows/test-and-build-workflow.yml index 501b2b737..f8d9bd682 100644 --- a/.github/workflows/test-and-build-workflow.yml +++ b/.github/workflows/test-and-build-workflow.yml @@ -22,8 +22,8 @@ jobs: distribution: 'temurin' java-version: 11 - - name: Integ Test - run: sbt integtest/integration - - name: Style check run: sbt scalafmtCheckAll + + - name: Integ Test + run: sbt integtest/integration diff --git a/MAINTAINERS.md b/MAINTAINERS.md index 1a35fcc69..b7ead807d 100644 --- a/MAINTAINERS.md +++ b/MAINTAINERS.md @@ -12,9 +12,10 @@ This document contains a list of maintainers in this repo. 
See [opensearch-proje | Chen Dai | [dai-chen](https://github.com/dai-chen) | Amazon | | Vamsi Manohar | [vamsi-amazon](https://github.com/vamsi-amazon) | Amazon | | Peng Huo | [penghuo](https://github.com/penghuo) | Amazon | -| Lior Perry | [YANG-DB](https://github.com/YANG-DB) | Amazon | +| Lior Perry | [YANG-DB](https://github.com/YANG-DB) | Amazon | | Sean Kao | [seankao-az](https://github.com/seankao-az) | Amazon | | Anirudha Jadhav | [anirudha](https://github.com/anirudha) | Amazon | | Kaituo Li | [kaituo](https://github.com/kaituo) | Amazon | | Louis Chu | [noCharger](https://github.com/noCharger) | Amazon | | Lantao Jin | [LantaoJin](https://github.com/LantaoJin) | Amazon | +| Tomoyuki Morita | [ykmr1224](https://github.com/ykmr1224) | Amazon | diff --git a/build.sbt b/build.sbt index f7653c50c..66b06d6be 100644 --- a/build.sbt +++ b/build.sbt @@ -89,6 +89,7 @@ lazy val flintCore = (project in file("flint-core")) "com.amazonaws" % "aws-java-sdk-cloudwatch" % "1.12.593" exclude("com.fasterxml.jackson.core", "jackson-databind"), "software.amazon.awssdk" % "auth-crt" % "2.28.10", + "org.projectlombok" % "lombok" % "1.18.30" % "provided", "org.scalactic" %% "scalactic" % "3.2.15" % "test", "org.scalatest" %% "scalatest" % "3.2.15" % "test", "org.scalatest" %% "scalatest-flatspec" % "3.2.15" % "test", @@ -153,6 +154,7 @@ lazy val pplSparkIntegration = (project in file("ppl-spark-integration")) "com.stephenn" %% "scalatest-json-jsonassert" % "0.2.5" % "test", "com.github.sbt" % "junit-interface" % "0.13.3" % "test", "org.projectlombok" % "lombok" % "1.18.30", + "com.github.seancfoley" % "ipaddress" % "5.5.1", ), libraryDependencies ++= deps(sparkVersion), // ANTLR settings diff --git a/docs/index.md b/docs/index.md index bb3121ba6..e76cb387a 100644 --- a/docs/index.md +++ b/docs/index.md @@ -549,6 +549,7 @@ In the index mapping, the `_meta` and `properties`field stores meta and schema i - `spark.flint.monitor.initialDelaySeconds`: Initial delay in seconds before starting the monitoring task. Default value is 15. - `spark.flint.monitor.intervalSeconds`: Interval in seconds for scheduling the monitoring task. Default value is 60. - `spark.flint.monitor.maxErrorCount`: Maximum number of consecutive errors allowed before stopping the monitoring task. Default value is 5. +- `spark.flint.metadataCacheWrite.enabled`: default is false. enable writing metadata to index mappings _meta as read cache for frontend user to access. Do not use in production, this setting will be removed in later version. 
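A minimal sketch of how this flag could be supplied to a Spark session. Only the configuration key `spark.flint.metadataCacheWrite.enabled` comes from the bullet above; the app name and the choice to set it at session-build time are illustrative assumptions, not part of Flint's documented setup.

```scala
// Minimal sketch: passing the Flint metadata-cache flag as a Spark SQL configuration.
// Only the key "spark.flint.metadataCacheWrite.enabled" comes from the docs above;
// the app name and setting it at session-build time are illustrative assumptions.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("flint-metadata-cache-demo")                      // hypothetical app name
  .config("spark.flint.metadataCacheWrite.enabled", "true")  // experimental; not for production use
  .getOrCreate()
```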
#### Data Type Mapping diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md index 8e6cbaae9..e780f688d 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -1,5 +1,10 @@ ## Example PPL Queries +#### **Comment** +[See additional command details](ppl-comment.md) +- `source=accounts | top gender // finds most common gender of all the accounts` (line comment) +- `source=accounts | dedup 2 gender /* dedup the document with gender field keep 2 duplication */ | fields account_number, gender` (block comment) + #### **Describe** - `describe table` This command is equal to the `DESCRIBE EXTENDED table` SQL command - `describe schema.table` @@ -28,6 +33,12 @@ _- **Limitation: new field added by eval command with a function cannot be dropp - `source = table | eval b1 = b + 1 | fields - b1,c` (Field `b1` cannot be dropped caused by SPARK-49782) - `source = table | eval b1 = lower(b) | fields - b1,c` (Field `b1` cannot be dropped caused by SPARK-49782) +**Field-Summary** +[See additional command details](ppl-fieldsummary-command.md) +- `source = t | fieldsummary includefields=status_code nulls=false` +- `source = t | fieldsummary includefields= id, status_code, request_path nulls=true` +- `source = t | where status_code != 200 | fieldsummary includefields= status_code nulls=true` + **Nested-Fields** - `source = catalog.schema.table1, catalog.schema.table2 | fields A.nested1, B.nested1` - `source = catalog.table | where struct_col2.field1.subfield > 'valueA' | sort int_col | fields int_col, struct_col.field1.subfield, struct_col2.field1.subfield` @@ -44,6 +55,19 @@ _- **Limitation: new field added by eval command with a function cannot be dropp - `source = table | where isempty(a)` - `source = table | where isblank(a)` - `source = table | where case(length(a) > 6, 'True' else 'False') = 'True'` +- `source = table | where a not in (1, 2, 3) | fields a,b,c` +- `source = table | where a between 1 and 4` - Note: This returns a >= 1 and a <= 4, i.e. 
[1, 4] +- `source = table | where b not between '2024-09-10' and '2025-09-10'` - Note: This returns b >= '2024-09-10' and b <= '2025-09-10' +- `source = table | where cidrmatch(ip, '192.169.1.0/24')` +- `source = table | where cidrmatch(ipv6, '2003:db8::/32')` +- `source = table | trendline sma(2, temperature) as temp_trend` + +#### **IP related queries** +[See additional command details](functions/ppl-ip.md) + +- `source = table | where cidrmatch(ip, '192.169.1.0/24')` +- `source = table | where isV6 = false and isValid = true and cidrmatch(ipAddress, '192.168.1.0/24')` +- `source = table | where isV6 = true | eval inRange = case(cidrmatch(ipAddress, '2003:db8::/32'), 'in' else 'out') | fields ip, inRange` ```sql source = table | eval status_category = @@ -92,6 +116,10 @@ Assumptions: `a`, `b`, `c` are existing fields in `table` - `source = table | eval f = case(a = 0, 'zero', a = 1, 'one', a = 2, 'two', a = 3, 'three', a = 4, 'four', a = 5, 'five', a = 6, 'six', a = 7, 'se7en', a = 8, 'eight', a = 9, 'nine')` - `source = table | eval f = case(a = 0, 'zero', a = 1, 'one' else 'unknown')` - `source = table | eval f = case(a = 0, 'zero', a = 1, 'one' else concat(a, ' is an incorrect binary digit'))` +- `source = table | eval digest = md5(fieldName) | fields digest` +- `source = table | eval digest = sha1(fieldName) | fields digest` +- `source = table | eval digest = sha2(fieldName,256) | fields digest` +- `source = table | eval digest = sha2(fieldName,512) | fields digest` #### Fillnull Assumptions: `a`, `b`, `c`, `d`, `e` are existing fields in `table` @@ -102,6 +130,15 @@ Assumptions: `a`, `b`, `c`, `d`, `e` are existing fields in `table` - `source = table | fillnull using a = 101, b = 102` - `source = table | fillnull using a = concat(b, c), d = 2 * pi() * e` +### Flatten +[See additional command details](ppl-flatten-command.md) +Assumptions: `bridges`, `coor` are existing fields in `table`, and the field's types are `struct` or `array>` +- `source = table | flatten bridges` +- `source = table | flatten coor` +- `source = table | flatten bridges | flatten coor` +- `source = table | fields bridges | flatten bridges` +- `source = table | fields country, bridges | flatten bridges | fields country, length | stats avg(length) as avg by country` + ```sql source = table | eval e = eval status_category = case(a >= 200 AND a < 300, 'Success', @@ -158,6 +195,34 @@ source = table | where ispresent(a) | - `source = table | stats avg(age) as avg_state_age by country, state | stats avg(avg_state_age) as avg_country_age by country` - `source = table | stats avg(age) as avg_city_age by country, state, city | eval new_avg_city_age = avg_city_age - 1 | stats avg(new_avg_city_age) as avg_state_age by country, state | where avg_state_age > 18 | stats avg(avg_state_age) as avg_adult_country_age by country` +#### **Event Aggregations** +[See additional command details](ppl-eventstats-command.md) + +- `source = table | eventstats avg(a) ` +- `source = table | where a < 50 | eventstats avg(c) ` +- `source = table | eventstats max(c) by b` +- `source = table | eventstats count(c) by b | head 5` +- `source = table | eventstats stddev_samp(c)` +- `source = table | eventstats stddev_pop(c)` +- `source = table | eventstats percentile(c, 90)` +- `source = table | eventstats percentile_approx(c, 99)` + +**Limitation: distinct aggregation could not used in `eventstats`:**_ +- `source = table | eventstats distinct_count(c)` (throw exception) + +**Aggregations With Span** +- `source = table | eventstats count(a) by 
span(a, 10) as a_span` +- `source = table | eventstats sum(age) by span(age, 5) as age_span | head 2` +- `source = table | eventstats avg(age) by span(age, 20) as age_span, country | sort - age_span | head 2` + +**Aggregations With TimeWindow Span (tumble windowing function)** +- `source = table | eventstats sum(productsAmount) by span(transactionDate, 1d) as age_date | sort age_date` +- `source = table | eventstats sum(productsAmount) by span(transactionDate, 1w) as age_date, productId` + +**Aggregations Group by Multiple Times** +- `source = table | eventstats avg(age) as avg_state_age by country, state | eventstats avg(avg_state_age) as avg_country_age by country` +- `source = table | eventstats avg(age) as avg_city_age by country, state, city | eval new_avg_city_age = avg_city_age - 1 | eventstats avg(new_avg_city_age) as avg_state_age by country, state | where avg_state_age > 18 | eventstats avg(avg_state_age) as avg_adult_country_age by country` + #### **Dedup** [See additional command details](ppl-dedup-command.md) @@ -240,8 +305,7 @@ source = table | where ispresent(a) | - `source = table1 | cross join left = l right = r table2` - `source = table1 | left semi join left = l right = r on l.a = r.a table2` - `source = table1 | left anti join left = l right = r on l.a = r.a table2` - -_- **Limitation: sub-searches is unsupported in join right side now**_ +- `source = table1 | join left = l right = r [ source = table2 | where d > 10 | head 5 ]` #### **Lookup** @@ -268,6 +332,8 @@ _- **Limitation: "REPLACE" or "APPEND" clause must contain "AS"**_ - `source = outer | where a not in [ source = inner | fields b ]` - `source = outer | where (a) not in [ source = inner | fields b ]` - `source = outer | where (a,b,c) not in [ source = inner | fields d,e,f ]` +- `source = outer a in [ source = inner | fields b ]` (search filtering with subquery) +- `source = outer a not in [ source = inner | fields b ]` (search filtering with subquery) - `source = outer | where a in [ source = inner1 | where b not in [ source = inner2 | fields c ] | fields b ]` (nested) - `source = table1 | inner join left = l right = r on l.a = r.a AND r.a in [ source = inner | fields d ] | fields l.a, r.a, b, c` (as join filter) @@ -317,6 +383,9 @@ Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table in - `source = outer | where not exists [ source = inner | where a = c ]` - `source = outer | where exists [ source = inner | where a = c and b = d ]` - `source = outer | where not exists [ source = inner | where a = c and b = d ]` +- `source = outer exists [ source = inner | where a = c ]` (search filtering with subquery) +- `source = outer not exists [ source = inner | where a = c ]` (search filtering with subquery) +- `source = table as t1 exists [ source = table as t2 | where t1.a = t2.a ]` (table alias is useful in exists subquery) - `source = outer | where exists [ source = inner1 | where a = c and exists [ source = inner2 | where c = e ] ]` (nested) - `source = outer | where exists [ source = inner1 | where a = c | where exists [ source = inner2 | where c = e ] ]` (nested) - `source = outer | where exists [ source = inner | where c > 10 ]` (uncorrelated exists) @@ -332,8 +401,13 @@ Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table in - `source = outer | eval m = [ source = inner | stats max(c) ] | fields m, a` - `source = outer | eval m = [ source = inner | stats max(c) ] + b | fields m, a` -**Uncorrelated scalar subquery in Select and Where** -- `source = outer | where 
a > [ source = inner | stats min(c) ] | eval m = [ source = inner | stats max(c) ] | fields m, a` +**Uncorrelated scalar subquery in Where** +- `source = outer | where a > [ source = inner | stats min(c) ] | fields a` +- `source = outer | where [ source = inner | stats min(c) ] > 0 | fields a` + +**Uncorrelated scalar subquery in Search filter** +- `source = outer a > [ source = inner | stats min(c) ] | fields a` +- `source = outer [ source = inner | stats min(c) ] > 0 | fields a` **Correlated scalar subquery in Select** - `source = outer | eval m = [ source = inner | where outer.b = inner.d | stats max(c) ] | fields m, a` @@ -345,10 +419,23 @@ Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table in - `source = outer | where a = [ source = inner | where b = d | stats max(c) ]` - `source = outer | where [ source = inner | where outer.b = inner.d OR inner.d = 1 | stats count() ] > 0 | fields a` +**Correlated scalar subquery in Search filter** +- `source = outer a = [ source = inner | where b = d | stats max(c) ]` +- `source = outer [ source = inner | where outer.b = inner.d OR inner.d = 1 | stats count() ] > 0 | fields a` + **Nested scalar subquery** - `source = outer | where a = [ source = inner | stats max(c) | sort c ] OR b = [ source = inner | where c = 1 | stats min(d) | sort d ]` - `source = outer | where a = [ source = inner | where c = [ source = nested | stats max(e) by f | sort f ] | stats max(d) by c | sort c | head 1 ]` +#### **(Relation) Subquery** +[See additional command details](ppl-subquery-command.md) + +`InSubquery`, `ExistsSubquery` and `ScalarSubquery` are all subquery expressions. But `RelationSubquery` is not a subquery expression, it is a subquery plan which is common used in Join or Search clause. + +- `source = table1 | join left = l right = r [ source = table2 | where d > 10 | head 5 ]` (subquery in join right side) +- `source = [ source = table1 | join left = l right = r [ source = table2 | where d > 10 | head 5 ] | stats count(a) by b ] as outer | head 1` + +_- **Limitation: another command usage of (relation) subquery is in `appendcols` commands which is unsupported**_ --- #### Experimental Commands: @@ -366,7 +453,7 @@ Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table in ### Planned Commands: #### **fillnull** - +[See additional command details](ppl-fillnull-command.md) ```sql - `source=accounts | fillnull fields status_code=101` - `source=accounts | fillnull fields request_path='/not_found', timestamp='*'` @@ -374,4 +461,3 @@ Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table in - `source=accounts | fillnull using field1=concat(field2, field3), field4=2*pi()*field5` - `source=accounts | fillnull using field1=concat(field2, field3), field4=2*pi()*field5, field6 = 'N/A'` ``` -[See additional command details](planning/ppl-fillnull-command.md) diff --git a/docs/ppl-lang/README.md b/docs/ppl-lang/README.md index 2ddceca0a..d78f4c030 100644 --- a/docs/ppl-lang/README.md +++ b/docs/ppl-lang/README.md @@ -22,6 +22,8 @@ For additional examples see the next [documentation](PPL-Example-Commands.md). * **Commands** + - [`comment`](ppl-comment.md) + - [`explain command `](PPL-Example-Commands.md/#explain) - [`dedup command `](ppl-dedup-command.md) @@ -29,6 +31,8 @@ For additional examples see the next [documentation](PPL-Example-Commands.md). 
- [`describe command`](PPL-Example-Commands.md/#describe) - [`fillnull command`](ppl-fillnull-command.md) + + - [`flatten command`](ppl-flatten-command.md) - [`eval command`](ppl-eval-command.md) @@ -48,6 +52,8 @@ For additional examples see the next [documentation](PPL-Example-Commands.md). - [`stats command`](ppl-stats-command.md) + - [`eventstats command`](ppl-eventstats-command.md) + - [`where command`](ppl-where-command.md) - [`head command`](ppl-head-command.md) @@ -63,7 +69,8 @@ For additional examples see the next [documentation](PPL-Example-Commands.md). - [`subquery commands`](ppl-subquery-command.md) - [`correlation commands`](ppl-correlation-command.md) - + + - [`trendline commands`](ppl-trendline-command.md) * **Functions** @@ -75,10 +82,17 @@ For additional examples see the next [documentation](PPL-Example-Commands.md). - [`String Functions`](functions/ppl-string.md) + - [`JSON Functions`](functions/ppl-json.md) + - [`Condition Functions`](functions/ppl-condition.md) - [`Type Conversion Functions`](functions/ppl-conversion.md) + - [`Cryptographic Functions`](functions/ppl-cryptographic.md) + + - [`IP Address Functions`](functions/ppl-ip.md) + + - [`Lambda Functions`](functions/ppl-lambda.md) --- ### PPL On Spark @@ -97,4 +111,4 @@ See samples of [PPL queries](PPL-Example-Commands.md) --- ### PPL Project Roadmap -[PPL Github Project Roadmap](https://github.com/orgs/opensearch-project/projects/214) \ No newline at end of file +[PPL Github Project Roadmap](https://github.com/orgs/opensearch-project/projects/214) diff --git a/docs/ppl-lang/functions/ppl-cryptographic.md b/docs/ppl-lang/functions/ppl-cryptographic.md new file mode 100644 index 000000000..ecabc624c --- /dev/null +++ b/docs/ppl-lang/functions/ppl-cryptographic.md @@ -0,0 +1,77 @@ +## PPL Cryptographic Functions + +### `MD5` + +**Description** + +Calculates the MD5 digest and returns the value as a 32 character hex string. + +Usage: `md5('hello')` + +**Argument type:** +- STRING +- Return type: **STRING** + +Example: + + os> source=people | eval `MD5('hello')` = MD5('hello') | fields `MD5('hello')` + fetched rows / total rows = 1/1 + +----------------------------------+ + | MD5('hello') | + |----------------------------------| + | 5d41402abc4b2a76b9719d911017c592 | + +----------------------------------+ + +### `SHA1` + +**Description** + +Returns the hex string result of SHA-1 + +Usage: `sha1('hello')` + +**Argument type:** +- STRING +- Return type: **STRING** + +Example: + + os> source=people | eval `SHA1('hello')` = SHA1('hello') | fields `SHA1('hello')` + fetched rows / total rows = 1/1 + +------------------------------------------+ + | SHA1('hello') | + |------------------------------------------| + | aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d | + +------------------------------------------+ + +### `SHA2` + +**Description** + +Returns the hex string result of SHA-2 family of hash functions (SHA-224, SHA-256, SHA-384, and SHA-512). 
The numBits indicates the desired bit length of the result, which must have a value of 224, 256, 384, 512 + +Usage: `sha2('hello',256)` + +Usage: `sha2('hello',512)` + +**Argument type:** +- STRING, INTEGER +- Return type: **STRING** + +Example: + + os> source=people | eval `SHA2('hello',256)` = SHA2('hello',256) | fields `SHA2('hello',256)` + fetched rows / total rows = 1/1 + +------------------------------------------------------------------+ + | SHA2('hello',256) | + |------------------------------------------------------------------| + | 2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824 | + +------------------------------------------------------------------+ + + os> source=people | eval `SHA2('hello',512)` = SHA2('hello',512) | fields `SHA2('hello',512)` + fetched rows / total rows = 1/1 + +----------------------------------------------------------------------------------------------------------------------------------+ + | SHA2('hello',512) | + |----------------------------------------------------------------------------------------------------------------------------------| + | 9b71d224bd62f3785d96d46ad3ea3d73319bfbc2890caadae2dff72519673ca72323c3d99ba5c11d7c7acc6e14b8c5da0c4663475c2e5c3adef46f73bcdec043 | + +----------------------------------------------------------------------------------------------------------------------------------+ \ No newline at end of file diff --git a/docs/ppl-lang/functions/ppl-datetime.md b/docs/ppl-lang/functions/ppl-datetime.md index d3ca272e3..e479176a4 100644 --- a/docs/ppl-lang/functions/ppl-datetime.md +++ b/docs/ppl-lang/functions/ppl-datetime.md @@ -5,244 +5,32 @@ **Description:** -**Usage:** adddate(date, INTERVAL expr unit) / adddate(date, days) adds the interval of second argument to date; adddate(date, days) adds the second argument as integer number of days to date. -If first argument is TIME, today's date is used; if first argument is DATE, time at midnight is used. +**Usage:** adddate(date, days) adds the second argument as integer number of days to date. +If days is negative abs(days) are subtracted from date. -Argument type: DATE/TIMESTAMP/TIME, INTERVAL/LONG +Argument type: DATE, LONG **Return type map:** -(DATE/TIMESTAMP/TIME, INTERVAL) -> TIMESTAMP - (DATE, LONG) -> DATE -(TIMESTAMP/TIME, LONG) -> TIMESTAMP - -Synonyms: `DATE_ADD`_ when invoked with the INTERVAL form of the second argument. - -Antonyms: `SUBDATE`_ - -Example: - - os> source=people | eval `'2020-08-26' + 1h` = ADDDATE(DATE('2020-08-26'), INTERVAL 1 HOUR), `'2020-08-26' + 1` = ADDDATE(DATE('2020-08-26'), 1), `ts '2020-08-26 01:01:01' + 1` = ADDDATE(TIMESTAMP('2020-08-26 01:01:01'), 1) | fields `'2020-08-26' + 1h`, `'2020-08-26' + 1`, `ts '2020-08-26 01:01:01' + 1` - fetched rows / total rows = 1/1 - +---------------------+--------------------+--------------------------------+ - | '2020-08-26' + 1h | '2020-08-26' + 1 | ts '2020-08-26 01:01:01' + 1 | - |---------------------+--------------------+--------------------------------| - | 2020-08-26 01:00:00 | 2020-08-27 | 2020-08-27 01:01:01 | - +---------------------+--------------------+--------------------------------+ - - - -### `ADDTIME` - -**Description:** - - -**Usage:** addtime(expr1, expr2) adds expr2 to expr1 and returns the result. If argument is TIME, today's date is used; if argument is DATE, time at midnight is used. 
- -Argument type: DATE/TIMESTAMP/TIME, DATE/TIMESTAMP/TIME - -**Return type map:** - -(DATE/TIMESTAMP, DATE/TIMESTAMP/TIME) -> TIMESTAMP - -(TIME, DATE/TIMESTAMP/TIME) -> TIME - -Antonyms: `SUBTIME`_ - -Example: - - os> source=people | eval `'2008-12-12' + 0` = ADDTIME(DATE('2008-12-12'), DATE('2008-11-15')) | fields `'2008-12-12' + 0` - fetched rows / total rows = 1/1 - +---------------------+ - | '2008-12-12' + 0 | - |---------------------| - | 2008-12-12 00:00:00 | - +---------------------+ - - os> source=people | eval `'23:59:59' + 0` = ADDTIME(TIME('23:59:59'), DATE('2004-01-01')) | fields `'23:59:59' + 0` - fetched rows / total rows = 1/1 - +------------------+ - | '23:59:59' + 0 | - |------------------| - | 23:59:59 | - +------------------+ - - os> source=people | eval `'2004-01-01' + '23:59:59'` = ADDTIME(DATE('2004-01-01'), TIME('23:59:59')) | fields `'2004-01-01' + '23:59:59'` - fetched rows / total rows = 1/1 - +-----------------------------+ - | '2004-01-01' + '23:59:59' | - |-----------------------------| - | 2004-01-01 23:59:59 | - +-----------------------------+ - - os> source=people | eval `'10:20:30' + '00:05:42'` = ADDTIME(TIME('10:20:30'), TIME('00:05:42')) | fields `'10:20:30' + '00:05:42'` - fetched rows / total rows = 1/1 - +---------------------------+ - | '10:20:30' + '00:05:42' | - |---------------------------| - | 10:26:12 | - +---------------------------+ - - os> source=people | eval `'2007-02-28 10:20:30' + '20:40:50'` = ADDTIME(TIMESTAMP('2007-02-28 10:20:30'), TIMESTAMP('2002-03-04 20:40:50')) | fields `'2007-02-28 10:20:30' + '20:40:50'` - fetched rows / total rows = 1/1 - +--------------------------------------+ - | '2007-02-28 10:20:30' + '20:40:50' | - |--------------------------------------| - | 2007-03-01 07:01:20 | - +--------------------------------------+ - - -### `CONVERT_TZ` - - -**Description:** - - -**Usage:** convert_tz(timestamp, from_timezone, to_timezone) constructs a local timestamp converted from the from_timezone to the to_timezone. CONVERT_TZ returns null when any of the three function arguments are invalid, i.e. timestamp is not in the format yyyy-MM-dd HH:mm:ss or the timeszone is not in (+/-)HH:mm. It also is invalid for invalid dates, such as February 30th and invalid timezones, which are ones outside of -13:59 and +14:00. - -Argument type: TIMESTAMP, STRING, STRING - -Return type: TIMESTAMP - -Conversion from +00:00 timezone to +10:00 timezone. Returns the timestamp argument converted from +00:00 to +10:00 - -Example: - - os> source=people | eval `convert_tz('2008-05-15 12:00:00','+00:00','+10:00')` = convert_tz('2008-05-15 12:00:00','+00:00','+10:00') | fields `convert_tz('2008-05-15 12:00:00','+00:00','+10:00')` - fetched rows / total rows = 1/1 - +-------------------------------------------------------+ - | convert_tz('2008-05-15 12:00:00','+00:00','+10:00') | - |-------------------------------------------------------| - | 2008-05-15 22:00:00 | - +-------------------------------------------------------+ - -The valid timezone range for convert_tz is (-13:59, +14:00) inclusive. Timezones outside of the range, such as +15:00 in this example will return null. 
- -Example: - - os> source=people | eval `convert_tz('2008-05-15 12:00:00','+00:00','+15:00')` = convert_tz('2008-05-15 12:00:00','+00:00','+15:00')| fields `convert_tz('2008-05-15 12:00:00','+00:00','+15:00')` - fetched rows / total rows = 1/1 - +-------------------------------------------------------+ - | convert_tz('2008-05-15 12:00:00','+00:00','+15:00') | - |-------------------------------------------------------| - | null | - +-------------------------------------------------------+ - -Conversion from a positive timezone to a negative timezone that goes over date line. - -Example: - - os> source=people | eval `convert_tz('2008-05-15 12:00:00','+03:30','-10:00')` = convert_tz('2008-05-15 12:00:00','+03:30','-10:00') | fields `convert_tz('2008-05-15 12:00:00','+03:30','-10:00')` - fetched rows / total rows = 1/1 - +-------------------------------------------------------+ - | convert_tz('2008-05-15 12:00:00','+03:30','-10:00') | - |-------------------------------------------------------| - | 2008-05-14 22:30:00 | - +-------------------------------------------------------+ - -Valid dates are required in convert_tz, invalid dates such as April 31st (not a date in the Gregorian calendar) will result in null. - -Example: - - os> source=people | eval `convert_tz('2008-04-31 12:00:00','+03:30','-10:00')` = convert_tz('2008-04-31 12:00:00','+03:30','-10:00') | fields `convert_tz('2008-04-31 12:00:00','+03:30','-10:00')` - fetched rows / total rows = 1/1 - +-------------------------------------------------------+ - | convert_tz('2008-04-31 12:00:00','+03:30','-10:00') | - |-------------------------------------------------------| - | null | - +-------------------------------------------------------+ - -Valid dates are required in convert_tz, invalid dates such as February 30th (not a date in the Gregorian calendar) will result in null. - -Example: - - os> source=people | eval `convert_tz('2008-02-30 12:00:00','+03:30','-10:00')` = convert_tz('2008-02-30 12:00:00','+03:30','-10:00') | fields `convert_tz('2008-02-30 12:00:00','+03:30','-10:00')` - fetched rows / total rows = 1/1 - +-------------------------------------------------------+ - | convert_tz('2008-02-30 12:00:00','+03:30','-10:00') | - |-------------------------------------------------------| - | null | - +-------------------------------------------------------+ - -February 29th 2008 is a valid date because it is a leap year. - -Example: - - os> source=people | eval `convert_tz('2008-02-29 12:00:00','+03:30','-10:00')` = convert_tz('2008-02-29 12:00:00','+03:30','-10:00') | fields `convert_tz('2008-02-29 12:00:00','+03:30','-10:00')` - fetched rows / total rows = 1/1 - +-------------------------------------------------------+ - | convert_tz('2008-02-29 12:00:00','+03:30','-10:00') | - |-------------------------------------------------------| - | 2008-02-28 22:30:00 | - +-------------------------------------------------------+ - -Valid dates are required in convert_tz, invalid dates such as February 29th 2007 (2007 is not a leap year) will result in null. 
- -Example: - - os> source=people | eval `convert_tz('2007-02-29 12:00:00','+03:30','-10:00')` = convert_tz('2007-02-29 12:00:00','+03:30','-10:00') | fields `convert_tz('2007-02-29 12:00:00','+03:30','-10:00')` - fetched rows / total rows = 1/1 - +-------------------------------------------------------+ - | convert_tz('2007-02-29 12:00:00','+03:30','-10:00') | - |-------------------------------------------------------| - | null | - +-------------------------------------------------------+ - -The valid timezone range for convert_tz is (-13:59, +14:00) inclusive. Timezones outside of the range, such as +14:01 in this example will return null. - -Example: - - os> source=people | eval `convert_tz('2008-02-01 12:00:00','+14:01','+00:00')` = convert_tz('2008-02-01 12:00:00','+14:01','+00:00') | fields `convert_tz('2008-02-01 12:00:00','+14:01','+00:00')` - fetched rows / total rows = 1/1 - +-------------------------------------------------------+ - | convert_tz('2008-02-01 12:00:00','+14:01','+00:00') | - |-------------------------------------------------------| - | null | - +-------------------------------------------------------+ - -The valid timezone range for convert_tz is (-13:59, +14:00) inclusive. Timezones outside of the range, such as +14:00 in this example will return a correctly converted date time object. - -Example: - - os> source=people | eval `convert_tz('2008-02-01 12:00:00','+14:00','+00:00')` = convert_tz('2008-02-01 12:00:00','+14:00','+00:00') | fields `convert_tz('2008-02-01 12:00:00','+14:00','+00:00')` - fetched rows / total rows = 1/1 - +-------------------------------------------------------+ - | convert_tz('2008-02-01 12:00:00','+14:00','+00:00') | - |-------------------------------------------------------| - | 2008-01-31 22:00:00 | - +-------------------------------------------------------+ - -The valid timezone range for convert_tz is (-13:59, +14:00) inclusive. Timezones outside of the range, such as -14:00 will result in null - -Example: - - os> source=people | eval `convert_tz('2008-02-01 12:00:00','-14:00','+00:00')` = convert_tz('2008-02-01 12:00:00','-14:00','+00:00') | fields `convert_tz('2008-02-01 12:00:00','-14:00','+00:00')` - fetched rows / total rows = 1/1 - +-------------------------------------------------------+ - | convert_tz('2008-02-01 12:00:00','-14:00','+00:00') | - |-------------------------------------------------------| - | null | - +-------------------------------------------------------+ - -The valid timezone range for convert_tz is (-13:59, +14:00) inclusive. This timezone is within range so it is valid and will convert the time. +Antonyms: `SUBDATE` Example: - os> source=people | eval `convert_tz('2008-02-01 12:00:00','-13:59','+00:00')` = convert_tz('2008-02-01 12:00:00','-13:59','+00:00') | fields `convert_tz('2008-02-01 12:00:00','-13:59','+00:00')` + os> source=people | eval `'2020-08-26' + 1` = ADDDATE(DATE('2020-08-26'), 1) | fields `'2020-08-26' + 1` fetched rows / total rows = 1/1 - +-------------------------------------------------------+ - | convert_tz('2008-02-01 12:00:00','-13:59','+00:00') | - |-------------------------------------------------------| - | 2008-02-02 01:59:00 | - +-------------------------------------------------------+ - + +--------------------+ + | '2020-08-26' + 1 | + +--------------------+ + | 2020-08-27 | + +--------------------+ ### `CURDATE` **Description:** +This function requires Spark 3.4.0+, if you use old Spark version, use `CURRENT_DATE` instead. Returns the current time as a value in 'YYYY-MM-DD'. 
`CURDATE()` returns the time at which it executes as `SYSDATE() <#sysdate>`_ does. @@ -280,25 +68,6 @@ Example: | 2022-08-02 | +------------------+ - -### `CURRENT_TIME` - -**Description:** - - -`CURRENT_TIME()` are synonyms for `CURTIME() <#curtime>`_. - -Example: - - > source=people | eval `CURRENT_TIME()` = CURRENT_TIME() | fields `CURRENT_TIME()` - fetched rows / total rows = 1/1 - +------------------+ - | CURRENT_TIME() | - |------------------+ - | 15:39:05 | - +------------------+ - - ### `CURRENT_TIMESTAMP` **Description:** @@ -317,29 +86,6 @@ Example: +-----------------------+ -### `CURTIME` - -**Description:** - - -Returns the current time as a value in 'hh:mm:ss'. -`CURTIME()` returns the time at which the statement began to execute as `NOW() <#now>`_ does. - -Return type: TIME - -Specification: CURTIME() -> TIME - -Example: - - > source=people | eval `value_1` = CURTIME(), `value_2` = CURTIME() | fields `value_1`, `value_2` - fetched rows / total rows = 1/1 - +-----------+-----------+ - | value_1 | value_2 | - |-----------+-----------| - | 15:39:05 | 15:39:05 | - +-----------+-----------+ - - ### `DATE` **Description:** @@ -377,40 +123,6 @@ Example: | 2020-08-26 | +----------------------------+ - os> source=people | eval `DATE('2020-08-26 13:49')` = DATE('2020-08-26 13:49') | fields `DATE('2020-08-26 13:49')` - fetched rows / total rows = 1/1 - +----------------------------+ - | DATE('2020-08-26 13:49') | - |----------------------------| - | 2020-08-26 | - +----------------------------+ - - -### `DATE_ADD` - -**Description:** - - -**Usage:** date_add(date, INTERVAL expr unit) adds the interval expr to date. If first argument is TIME, today's date is used; if first argument is DATE, time at midnight is used. - -Argument type: DATE/TIMESTAMP/TIME, INTERVAL - -Return type: TIMESTAMP - -Synonyms: `ADDDATE`_ - -Antonyms: `DATE_SUB`_ - -Example: - - os> source=people | eval `'2020-08-26' + 1h` = DATE_ADD(DATE('2020-08-26'), INTERVAL 1 HOUR), `ts '2020-08-26 01:01:01' + 1d` = DATE_ADD(TIMESTAMP('2020-08-26 01:01:01'), INTERVAL 1 DAY) | fields `'2020-08-26' + 1h`, `ts '2020-08-26 01:01:01' + 1d` - fetched rows / total rows = 1/1 - +---------------------+---------------------------------+ - | '2020-08-26' + 1h | ts '2020-08-26 01:01:01' + 1d | - |---------------------+---------------------------------| - | 2020-08-26 01:00:00 | 2020-08-27 01:01:01 | - +---------------------+---------------------------------+ - ### `DATE_FORMAT` @@ -421,42 +133,33 @@ Example: **Usage:** date_format(date, format) formats the date argument using the specifiers in the format argument. If an argument of type TIME is provided, the local date is used. -| Specifier | **Description:** | -|-----------|------------------| -| %a | Abbreviated weekday name (Sun..Sat) | -| %b | Abbreviated month name (Jan..Dec) | -| %c | Month, numeric (0..12) | -| %D | Day of the month with English suffix (0th, 1st, 2nd, 3rd, ...) 
| -| %d | Day of the month, numeric (00..31) | -| %e | Day of the month, numeric (0..31) | -| %f | Microseconds (000000..999999) | -| %H | Hour (00..23) | -| %h | Hour (01..12) | -| %I | Hour (01..12) | -| %i | Minutes, numeric (00..59) | -| %j | Day of year (001..366) | -| %k | Hour (0..23) | -| %l | Hour (1..12) | -| %M | Month name (January..December) | -| %m | Month, numeric (00..12) | -| %p | AM or PM | -| %r | Time, 12-hour (hh:mm:ss followed by AM or PM) | -| %S | Seconds (00..59) | -| %s | Seconds (00..59) | -| %T | Time, 24-hour (hh:mm:ss) | -| %U | Week (00..53), where Sunday is the first day of the week; WEEK() mode 0 | -| %u | Week (00..53), where Monday is the first day of the week; WEEK() mode 1 | -| %V | Week (01..53), where Sunday is the first day of the week; WEEK() mode 2; used with %X | -| %v | Week (01..53), where Monday is the first day of the week; WEEK() mode 3; used with %x | -| %W | Weekday name (Sunday..Saturday) | -| %w | Day of the week (0=Sunday..6=Saturday) | -| %X | Year for the week where Sunday is the first day of the week, numeric, four digits; used with %V | -| %x | Year for the week, where Monday is the first day of the week, numeric, four digits; used with %v | -| %Y | Year, numeric, four digits | -| %y | Year, numeric (two digits) | -| %% | A literal % character | -| %x | x, for any “x” not listed above | -| x | x, for any smallcase/uppercase alphabet except [aydmshiHIMYDSEL] | +| Symbol | Meaning | Presentation | Examples | +|--------|-------------------------------|----------------|---------------------------------------------| +| G | era | text | AD; Anno Domini | +| y | year | year | 2020; 20 | +| D | day-of-year | number(3) | 189 | +| M/L | month-of-year | month | 7; 07; Jul; July | +| d | day-of-month | number(3) | 28 | +| Q/q | quarter-of-year | number/text | 3; 03; Q3; 3rd quarter | +| E | day-of-week | text | Tue; Tuesday | +| F | aligned day of week in month | number(1) | 3 | +| a | am-pm-of-day | am-pm | PM | +| h | clock-hour-of-am-pm (1-12) | number(2) | 12 | +| K | hour-of-am-pm (0-11) | number(2) | 0 | +| k | clock-hour-of-day (1-24) | number(2) | 0 | +| H | hour-of-day (0-23) | number(2) | 0 | +| m | minute-of-hour | number(2) | 30 | +| s | second-of-minute | number(2) | 55 | +| S | fraction-of-second | fraction | 978 | +| V | time-zone ID | zone-id | America/Los_Angeles; Z; -08:30 | +| z | time-zone name | zone-name | Pacific Standard Time; PST | +| O | localized zone-offset | offset-O | GMT+8; GMT+08:00; UTC-08:00 | +| X | zone-offset 'Z' for zero | offset-X | Z; -08; -0830; -08:30; -083015; -08:30:15 | +| x | zone-offset | offset-x | +0000; -08; -0830; -08:30; -083015; -08:30:15 | +| Z | zone-offset | offset-Z | +0000; -0800; -08:00 | +| [ | optional section start | | | +| ] | optional section end | | | + Argument type: STRING/DATE/TIME/TIMESTAMP, STRING @@ -464,89 +167,20 @@ Return type: STRING Example: - os> source=people | eval `DATE_FORMAT('1998-01-31 13:14:15.012345', '%T.%f')` = DATE_FORMAT('1998-01-31 13:14:15.012345', '%T.%f'), `DATE_FORMAT(TIMESTAMP('1998-01-31 13:14:15.012345'), '%Y-%b-%D %r')` = DATE_FORMAT(TIMESTAMP('1998-01-31 13:14:15.012345'), '%Y-%b-%D %r') | fields `DATE_FORMAT('1998-01-31 13:14:15.012345', '%T.%f')`, `DATE_FORMAT(TIMESTAMP('1998-01-31 13:14:15.012345'), '%Y-%b-%D %r')` - fetched rows / total rows = 1/1 - +------------------------------------------------------+-----------------------------------------------------------------------+ - | DATE_FORMAT('1998-01-31 13:14:15.012345', '%T.%f') | 
DATE_FORMAT(TIMESTAMP('1998-01-31 13:14:15.012345'), '%Y-%b-%D %r') | - |------------------------------------------------------+-----------------------------------------------------------------------| - | 13:14:15.012345 | 1998-Jan-31st 01:14:15 PM | - +------------------------------------------------------+-----------------------------------------------------------------------+ - - -### `DATETIME` - - -**Description:** - -**Usage:** `DATETIME(timestamp)/ DATETIME(date, to_timezone)` Converts the datetime to a new timezone - -Argument type: timestamp/STRING - -**Return type map:** - -(TIMESTAMP, STRING) -> TIMESTAMP - -(TIMESTAMP) -> TIMESTAMP - - -Converting timestamp with timezone to the second argument timezone. - -Example: - - os> source=people | eval `DATETIME('2004-02-28 23:00:00-10:00', '+10:00')` = DATETIME('2004-02-28 23:00:00-10:00', '+10:00') | fields `DATETIME('2004-02-28 23:00:00-10:00', '+10:00')` - fetched rows / total rows = 1/1 - +---------------------------------------------------+ - | DATETIME('2004-02-28 23:00:00-10:00', '+10:00') | - |---------------------------------------------------| - | 2004-02-29 19:00:00 | - +---------------------------------------------------+ - - -The valid timezone range for convert_tz is (-13:59, +14:00) inclusive. Timezones outside of the range will result in null. - -Example: - - os> source=people | eval `DATETIME('2008-01-01 02:00:00', '-14:00')` = DATETIME('2008-01-01 02:00:00', '-14:00') | fields `DATETIME('2008-01-01 02:00:00', '-14:00')` - fetched rows / total rows = 1/1 - +---------------------------------------------+ - | DATETIME('2008-01-01 02:00:00', '-14:00') | - |---------------------------------------------| - | null | - +---------------------------------------------+ - - -### `DATE_SUB` - - -**Description:** - - -**Usage:** date_sub(date, INTERVAL expr unit) subtracts the interval expr from date. If first argument is TIME, today's date is used; if first argument is DATE, time at midnight is used. 
- -Argument type: DATE/TIMESTAMP/TIME, INTERVAL - -Return type: TIMESTAMP - -Synonyms: `SUBDATE`_ - -Antonyms: `DATE_ADD`_ - -Example: - - os> source=people | eval `'2008-01-02' - 31d` = DATE_SUB(DATE('2008-01-02'), INTERVAL 31 DAY), `ts '2020-08-26 01:01:01' + 1h` = DATE_SUB(TIMESTAMP('2020-08-26 01:01:01'), INTERVAL 1 HOUR) | fields `'2008-01-02' - 31d`, `ts '2020-08-26 01:01:01' + 1h` + os> source=people | eval `DATE_FORMAT('1998-01-31 13:14:15.012345', 'HH:mm:ss.SSSSSS')` = DATE_FORMAT('1998-01-31 13:14:15.012345', 'HH:mm:ss.SSSSSS'), `DATE_FORMAT(TIMESTAMP('1998-01-31 13:14:15.012345'), 'yyyy-MMM-dd hh:mm:ss a')` = DATE_FORMAT(TIMESTAMP('1998-01-31 13:14:15.012345'), 'yyyy-MMM-dd hh:mm:ss a') | fields `DATE_FORMAT('1998-01-31 13:14:15.012345', 'HH:mm:ss.SSSSSS')`, `DATE_FORMAT(TIMESTAMP('1998-01-31 13:14:15.012345'), 'yyyy-MMM-dd hh:mm:ss a')` fetched rows / total rows = 1/1 - +----------------------+---------------------------------+ - | '2008-01-02' - 31d | ts '2020-08-26 01:01:01' + 1h | - |----------------------+---------------------------------| - | 2007-12-02 00:00:00 | 2020-08-26 00:01:01 | - +----------------------+---------------------------------+ + +------------------------------------------------------------------+------------------------------------------------------------------------------------+ + | `DATE_FORMAT('1998-01-31 13:14:15.012345', 'HH:mm:ss.SSSSSS')` | `DATE_FORMAT(TIMESTAMP('1998-01-31 13:14:15.012345'), 'yyyy-MMM-dd hh:mm:ss a')` | + |------------------------------------------------------------------+------------------------------------------------------------------------------------| + | 13:14:15.012345 | 1998-Jan-31st 01:14:15 PM | + +------------------------------------------------------+------------------------------------------------------------------------------------------------+ ### `DATEDIFF` **Usage:** Calculates the difference of date parts of given values. If the first argument is time, today's date is used. -Argument type: DATE/TIMESTAMP/TIME, DATE/TIMESTAMP/TIME +Argument type: DATE/TIMESTAMP, DATE/TIMESTAMP Return type: LONG @@ -585,30 +219,6 @@ Example: +---------------------------+ -### `DAYNAME` - -**Description:** - - -**Usage:** - -`dayname(date)` returns the name of the weekday for date, including Monday, Tuesday, Wednesday, Thursday, Friday, Saturday and Sunday. - -Argument type: STRING/DATE/TIMESTAMP - -Return type: STRING - -Example: - - os> source=people | eval `DAYNAME(DATE('2020-08-26'))` = DAYNAME(DATE('2020-08-26')) | fields `DAYNAME(DATE('2020-08-26'))` - fetched rows / total rows = 1/1 - +-------------------------------+ - | DAYNAME(DATE('2020-08-26')) | - |-------------------------------| - | Wednesday | - +-------------------------------+ - - ### `DAYOFMONTH` **Description:** @@ -762,76 +372,29 @@ Example: +-----------------------------------+ -### `EXTRACT` +### `DAYNAME` **Description:** +This function requires Spark 4.0.0+. -**Usage:** - -extract(part FROM date) returns a LONG with digits in order according to the given 'part' arguments. -The specific format of the returned long is determined by the table below. - -Argument type: PART, where PART is one of the following tokens in the table below. - -The format specifiers found in this table are the same as those found in the `DATE_FORMAT`_ function. 
- -| Part | Format | -|----------------------|---------------| -| MICROSECOND | %f | -| SECOND | %s | -| MINUTE | %i | -| HOUR | %H | -| DAY | %d | -| WEEK | %X | -| MONTH | %m | -| YEAR | %V | -| SECOND_MICROSECOND | %s%f | -| MINUTE_MICROSECOND | %i%s%f | -| MINUTE_SECOND | %i%s | -| HOUR_MICROSECOND | %H%i%s%f | -| HOUR_SECOND | %H%i%s | -| HOUR_MINUTE | %H%i | -| DAY_MICROSECOND | %d%H%i%s%f | -| DAY_SECOND | %d%H%i%s | -| DAY_MINUTE | %d%H%i | -| DAY_HOUR | %d%H% | -| YEAR_MONTH | %V%m | - - -Return type: LONG - -Example: - - os> source=people | eval `extract(YEAR_MONTH FROM "2023-02-07 10:11:12")` = extract(YEAR_MONTH FROM "2023-02-07 10:11:12") | fields `extract(YEAR_MONTH FROM "2023-02-07 10:11:12")` - fetched rows / total rows = 1/1 - +--------------------------------------------------+ - | extract(YEAR_MONTH FROM "2023-02-07 10:11:12") | - |--------------------------------------------------| - | 202302 | - +--------------------------------------------------+ - - -### `FROM_DAYS` - -**Description:** - +**Usage:** -**Usage:** from_days(N) returns the date value given the day number N. +`dayname(date)` returns the name of the weekday for date, including Monday, Tuesday, Wednesday, Thursday, Friday, Saturday and Sunday. -Argument type: INTEGER/LONG +Argument type: STRING/DATE/TIMESTAMP -Return type: DATE +Return type: STRING Example: - os> source=people | eval `FROM_DAYS(733687)` = FROM_DAYS(733687) | fields `FROM_DAYS(733687)` + os> source=people | eval `DAYNAME(DATE('2020-08-26'))` = DAYNAME(DATE('2020-08-26')) | fields `DAYNAME(DATE('2020-08-26'))` fetched rows / total rows = 1/1 - +---------------------+ - | FROM_DAYS(733687) | - |---------------------| - | 2008-10-07 | - +---------------------+ + +-------------------------------+ + | DAYNAME(DATE('2020-08-26')) | + |-------------------------------| + | Wednesday | + +-------------------------------+ ### `FROM_UNIXTIME` @@ -862,37 +425,13 @@ Examples: | 2008-09-01 06:12:27 | +-----------------------------+ - os> source=people | eval `FROM_UNIXTIME(1220249547, '%T')` = FROM_UNIXTIME(1220249547, '%T') | fields `FROM_UNIXTIME(1220249547, '%T')` + os> source=people | eval `FROM_UNIXTIME(1220249547, 'HH:mm:ss')` = FROM_UNIXTIME(1220249547, 'HH:mm:ss') | fields `FROM_UNIXTIME(1220249547, 'HH:mm:ss')` fetched rows / total rows = 1/1 - +-----------------------------------+ - | FROM_UNIXTIME(1220249547, '%T') | - |-----------------------------------| - | 06:12:27 | - +-----------------------------------+ - - -### `GET_FORMAT` - - -**Description:** - - -**Usage:** - -Returns a string value containing string format specifiers based on the input arguments. - -Argument type: TYPE, STRING, where TYPE must be one of the following tokens: [DATE, TIME, TIMESTAMP], and -STRING must be one of the following tokens: ["USA", "JIS", "ISO", "EUR", "INTERNAL"] (" can be replaced by '). - -Examples: - - os> source=people | eval `GET_FORMAT(DATE, 'USA')` = GET_FORMAT(DATE, 'USA') | fields `GET_FORMAT(DATE, 'USA')` - fetched rows / total rows = 1/1 - +---------------------------+ - | GET_FORMAT(DATE, 'USA') | - |---------------------------| - | %m.%d.%Y | - +---------------------------+ + +-----------------------------------------+ + | FROM_UNIXTIME(1220249547, 'HH:mm:ss') | + |-----------------------------------------| + | 06:12:27 | + +-----------------------------------------+ ### `HOUR` @@ -1003,92 +542,32 @@ Example: +---------------------+ -### `MAKEDATE` +### `MAKE_DATE` **Description:** -Returns a date, given `year` and `day-of-year` values. 
`dayofyear` must be greater than 0 or the result is `NULL`. The result is also `NULL` if either argument is `NULL`. +Returns a date, given `year`, `month` and `day` values. Arguments are rounded to an integer. -**Limitations**: -- Zero `year` interpreted as 2000; -- Negative `year` is not accepted; -- `day-of-year` should be greater than zero; -- `day-of-year` could be greater than 365/366, calculation switches to the next year(s) (see example). - **Specifications**: -1. MAKEDATE(DOUBLE, DOUBLE) -> DATE +1. MAKE_DATE(INTEGER, INTEGER, INTEGER) -> DATE -Argument type: DOUBLE +Argument type: INTEGER, INTEGER, INTEGER Return type: DATE Example: - os> source=people | eval `MAKEDATE(1945, 5.9)` = MAKEDATE(1945, 5.9), `MAKEDATE(1984, 1984)` = MAKEDATE(1984, 1984) | fields `MAKEDATE(1945, 5.9)`, `MAKEDATE(1984, 1984)` - fetched rows / total rows = 1/1 - +-----------------------+------------------------+ - | MAKEDATE(1945, 5.9) | MAKEDATE(1984, 1984) | - |-----------------------+------------------------| - | 1945-01-06 | 1989-06-06 | - +-----------------------+------------------------+ - - -### `MAKETIME` - - -**Description:** - - -Returns a time value calculated from the hour, minute, and second arguments. Returns `NULL` if any of its arguments are `NULL`. -The second argument can have a fractional part, rest arguments are rounded to an integer. - -**Limitations**: -- 24-hour clock is used, available time range is [00:00:00.0 - 23:59:59.(9)]; -- Up to 9 digits of second fraction part is taken (nanosecond precision). - -**Specifications**: - -1. `MAKETIME(DOUBLE, DOUBLE, DOUBLE)` -> TIME - -Argument type: DOUBLE - -Return type: TIME - -Example: - - os> source=people | eval `MAKETIME(20, 30, 40)` = MAKETIME(20, 30, 40), `MAKETIME(20.2, 49.5, 42.100502)` = MAKETIME(20.2, 49.5, 42.100502) | fields `MAKETIME(20, 30, 40)`, `MAKETIME(20.2, 49.5, 42.100502)` + os> source=people | eval `MAKE_DATE(1945, 5, 9)` = MAKEDATE(1945, 5, 9) | fields `MAKEDATE(1945, 5, 9)` fetched rows / total rows = 1/1 - +------------------------+-----------------------------------+ - | MAKETIME(20, 30, 40) | MAKETIME(20.2, 49.5, 42.100502) | - |------------------------+-----------------------------------| - | 20:30:40 | 20:50:42.100502 | - +------------------------+-----------------------------------+ - - -### `MICROSECOND` - -**Description:** - - -**Usage:** microsecond(expr) returns the microseconds from the time or timestamp expression expr as a number in the range from 0 to 999999. - -Argument type: STRING/TIME/TIMESTAMP - -Return type: INTEGER - -Example: - - os> source=people | eval `MICROSECOND(TIME('01:02:03.123456'))` = MICROSECOND(TIME('01:02:03.123456')) | fields `MICROSECOND(TIME('01:02:03.123456'))` - fetched rows / total rows = 1/1 - +----------------------------------------+ - | MICROSECOND(TIME('01:02:03.123456')) | - |----------------------------------------| - | 123456 | - +----------------------------------------+ + +------------------------+ + | MAKEDATE(1945, 5, 9) | + |------------------------+ + | 1945-05-09 | + +------------------------+ ### `MINUTE` @@ -1115,28 +594,6 @@ Example: +----------------------------+ -### `MINUTE_OF_DAY` - -**Description:** - - -**Usage:** minute(time) returns the amount of minutes in the day, in the range of 0 to 1439. 
- -Argument type: STRING/TIME/TIMESTAMP - -Return type: INTEGER - -Example: - - os> source=people | eval `MINUTE_OF_DAY(TIME('01:02:03'))` = MINUTE_OF_DAY(TIME('01:02:03')) | fields `MINUTE_OF_DAY(TIME('01:02:03'))` - fetched rows / total rows = 1/1 - +-----------------------------------+ - | MINUTE_OF_DAY(TIME('01:02:03')) | - |-----------------------------------| - | 62 | - +-----------------------------------+ - - ### `MINUTE_OF_HOUR` **Description:** @@ -1210,6 +667,8 @@ Example: ### `MONTHNAME` +This function requires Spark 4.0.0+. + **Description:** @@ -1253,51 +712,6 @@ Example: +---------------------+---------------------+ -### `PERIOD_ADD` - - -**Description:** - - -**Usage:** period_add(P, N) add N months to period P (in the format YYMM or YYYYMM). Returns a value in the format YYYYMM. - -Argument type: INTEGER, INTEGER - -Return type: INTEGER - -Example: - - os> source=people | eval `PERIOD_ADD(200801, 2)` = PERIOD_ADD(200801, 2), `PERIOD_ADD(200801, -12)` = PERIOD_ADD(200801, -12) | fields `PERIOD_ADD(200801, 2)`, `PERIOD_ADD(200801, -12)` - fetched rows / total rows = 1/1 - +-------------------------+---------------------------+ - | PERIOD_ADD(200801, 2) | PERIOD_ADD(200801, -12) | - |-------------------------+---------------------------| - | 200803 | 200701 | - +-------------------------+---------------------------+ - - -### `PERIOD_DIFF` - -**Description:** - - -**Usage:** period_diff(P1, P2) returns the number of months between periods P1 and P2 given in the format YYMM or YYYYMM. - -Argument type: INTEGER, INTEGER - -Return type: INTEGER - -Example: - - os> source=people | eval `PERIOD_DIFF(200802, 200703)` = PERIOD_DIFF(200802, 200703), `PERIOD_DIFF(200802, 201003)` = PERIOD_DIFF(200802, 201003) | fields `PERIOD_DIFF(200802, 200703)`, `PERIOD_DIFF(200802, 201003)` - fetched rows / total rows = 1/1 - +-------------------------------+-------------------------------+ - | PERIOD_DIFF(200802, 200703) | PERIOD_DIFF(200802, 201003) | - |-------------------------------+-------------------------------| - | 11 | -25 | - +-------------------------------+-------------------------------+ - - ### `QUARTER` **Description:** @@ -1320,33 +734,6 @@ Example: +-------------------------------+ -### `SEC_TO_TIME` - -**Description:** - - -**Usage:** - -sec_to_time(number) returns the time in HH:mm:ssss[.nnnnnn] format. -Note that the function returns a time between 00:00:00 and 23:59:59. -If an input value is too large (greater than 86399), the function will wrap around and begin returning outputs starting from 00:00:00. -If an input value is too small (less than 0), the function will wrap around and begin returning outputs counting down from 23:59:59. - -Argument type: INTEGER, LONG, DOUBLE, FLOAT - -Return type: TIME - -Example: - - os> source=people | eval `SEC_TO_TIME(3601)` = SEC_TO_TIME(3601) | eval `SEC_TO_TIME(1234.123)` = SEC_TO_TIME(1234.123) | fields `SEC_TO_TIME(3601)`, `SEC_TO_TIME(1234.123)` - fetched rows / total rows = 1/1 - +---------------------+-------------------------+ - | SEC_TO_TIME(3601) | SEC_TO_TIME(1234.123) | - |---------------------+-------------------------| - | 01:00:01 | 00:20:34.123 | - +---------------------+-------------------------+ - - ### `SECOND` **Description:** @@ -1395,56 +782,24 @@ Example: +--------------------------------------+ -### `STR_TO_DATE` - -**Description:** - - -**Usage:** str_to_date(string, string) is used to extract a TIMESTAMP from the first argument string using the formats specified in the second argument string. 
-The input argument must have enough information to be parsed as a DATE, TIMESTAMP, or TIME. -Acceptable string format specifiers are the same as those used in the `DATE_FORMAT`_ function. -It returns NULL when a statement cannot be parsed due to an invalid pair of arguments, and when 0 is provided for any DATE field. Otherwise, it will return a TIMESTAMP with the parsed values (as well as default values for any field that was not parsed). - -Argument type: STRING, STRING - -Return type: TIMESTAMP - -Example: - - OS> source=people | eval `str_to_date("01,5,2013", "%d,%m,%Y")` = str_to_date("01,5,2013", "%d,%m,%Y") | fields = `str_to_date("01,5,2013", "%d,%m,%Y")` - fetched rows / total rows = 1/1 - +----------------------------------------+ - | str_to_date("01,5,2013", "%d,%m,%Y") | - |----------------------------------------| - | 2013-05-01 00:00:00 | - +----------------------------------------+ - - ### `SUBDATE` **Description:** -**Usage:** subdate(date, INTERVAL expr unit) / subdate(date, days) subtracts the interval expr from date; subdate(date, days) subtracts the second argument as integer number of days from date. -If first argument is TIME, today's date is used; if first argument is DATE, time at midnight is used. +**Usage:** subdate(date, days) subtracts the second argument as integer number of days from date. -Argument type: DATE/TIMESTAMP/TIME, INTERVAL/LONG +Argument type: DATE/TIMESTAMP, LONG **Return type map:** -(DATE/TIMESTAMP/TIME, INTERVAL) -> TIMESTAMP - (DATE, LONG) -> DATE -(TIMESTAMP/TIME, LONG) -> TIMESTAMP - -Synonyms: `DATE_SUB`_ when invoked with the INTERVAL form of the second argument. - -Antonyms: `ADDDATE`_ +Antonyms: `ADDDATE` Example: - os> source=people | eval `'2008-01-02' - 31d` = SUBDATE(DATE('2008-01-02'), INTERVAL 31 DAY), `'2020-08-26' - 1` = SUBDATE(DATE('2020-08-26'), 1), `ts '2020-08-26 01:01:01' - 1` = SUBDATE(TIMESTAMP('2020-08-26 01:01:01'), 1) | fields `'2008-01-02' - 31d`, `'2020-08-26' - 1`, `ts '2020-08-26 01:01:01' - 1` + os> source=people | eval `'2008-01-02' - 31d` = SUBDATE(DATE('2008-01-02'), 31), `'2020-08-26' - 1` = SUBDATE(DATE('2020-08-26'), 1), `ts '2020-08-26 01:01:01' - 1` = SUBDATE(TIMESTAMP('2020-08-26 01:01:01'), 1) | fields `'2008-01-02' - 31d`, `'2020-08-26' - 1`, `ts '2020-08-26 01:01:01' - 1` fetched rows / total rows = 1/1 +----------------------+--------------------+--------------------------------+ | '2008-01-02' - 31d | '2020-08-26' - 1 | ts '2020-08-26 01:01:01' - 1 | @@ -1453,72 +808,12 @@ Example: +----------------------+--------------------+--------------------------------+ -### `SUBTIME` - -**Description:** - - -**Usage:** subtime(expr1, expr2) subtracts expr2 from expr1 and returns the result. If argument is TIME, today's date is used; if argument is DATE, time at midnight is used. 
- -Argument type: DATE/TIMESTAMP/TIME, DATE/TIMESTAMP/TIME - -**Return type map:** - -(DATE/TIMESTAMP, DATE/TIMESTAMP/TIME) -> TIMESTAMP - -(TIME, DATE/TIMESTAMP/TIME) -> TIME - -Antonyms: `ADDTIME`_ - -Example: - - os> source=people | eval `'2008-12-12' - 0` = SUBTIME(DATE('2008-12-12'), DATE('2008-11-15')) | fields `'2008-12-12' - 0` - fetched rows / total rows = 1/1 - +---------------------+ - | '2008-12-12' - 0 | - |---------------------| - | 2008-12-12 00:00:00 | - +---------------------+ - - os> source=people | eval `'23:59:59' - 0` = SUBTIME(TIME('23:59:59'), DATE('2004-01-01')) | fields `'23:59:59' - 0` - fetched rows / total rows = 1/1 - +------------------+ - | '23:59:59' - 0 | - |------------------| - | 23:59:59 | - +------------------+ - - os> source=people | eval `'2004-01-01' - '23:59:59'` = SUBTIME(DATE('2004-01-01'), TIME('23:59:59')) | fields `'2004-01-01' - '23:59:59'` - fetched rows / total rows = 1/1 - +-----------------------------+ - | '2004-01-01' - '23:59:59' | - |-----------------------------| - | 2003-12-31 00:00:01 | - +-----------------------------+ - - os> source=people | eval `'10:20:30' - '00:05:42'` = SUBTIME(TIME('10:20:30'), TIME('00:05:42')) | fields `'10:20:30' - '00:05:42'` - fetched rows / total rows = 1/1 - +---------------------------+ - | '10:20:30' - '00:05:42' | - |---------------------------| - | 10:14:48 | - +---------------------------+ - - os> source=people | eval `'2007-03-01 10:20:30' - '20:40:50'` = SUBTIME(TIMESTAMP('2007-03-01 10:20:30'), TIMESTAMP('2002-03-04 20:40:50')) | fields `'2007-03-01 10:20:30' - '20:40:50'` - fetched rows / total rows = 1/1 - +--------------------------------------+ - | '2007-03-01 10:20:30' - '20:40:50' | - |--------------------------------------| - | 2007-02-28 13:39:40 | - +--------------------------------------+ - - ### `SYSDATE` **Description:** -Returns the current date and time as a value in 'YYYY-MM-DD hh:mm:ss[.nnnnnn]'. +Returns the current date and time as a value in 'YYYY-MM-DD hh:mm:ss.nnnnnn'. SYSDATE() returns the time at which it executes. This differs from the behavior for `NOW() <#now>`_, which returns a constant time that indicates the time at which the statement began to execute. If the argument is given, it specifies a fractional seconds precision from 0 to 6, the return value includes a fractional seconds part of that many digits. @@ -1526,152 +821,17 @@ Optional argument type: INTEGER Return type: TIMESTAMP -Specification: SYSDATE([INTEGER]) -> TIMESTAMP - Example: - > source=people | eval `value_1` = SYSDATE(), `value_2` = SYSDATE(6) | fields `value_1`, `value_2` - fetched rows / total rows = 1/1 - +---------------------+----------------------------+ - | value_1 | value_2 | - |---------------------+----------------------------| - | 2022-08-02 15:39:05 | 2022-08-02 15:39:05.123456 | - +---------------------+----------------------------+ - - -### `TIME` - -**Description:** - - -**Usage:** time(expr) constructs a time type with the input string expr as a time. If the argument is of date/time/timestamp, it extracts the time value part from the expression. 
- -Argument type: STRING/DATE/TIME/TIMESTAMP - -Return type: TIME - -Example: - - os> source=people | eval `TIME('13:49:00')` = TIME('13:49:00') | fields `TIME('13:49:00')` - fetched rows / total rows = 1/1 - +--------------------+ - | TIME('13:49:00') | - |--------------------| - | 13:49:00 | - +--------------------+ - - os> source=people | eval `TIME('13:49')` = TIME('13:49') | fields `TIME('13:49')` - fetched rows / total rows = 1/1 - +-----------------+ - | TIME('13:49') | - |-----------------| - | 13:49:00 | - +-----------------+ - - os> source=people | eval `TIME('2020-08-26 13:49:00')` = TIME('2020-08-26 13:49:00') | fields `TIME('2020-08-26 13:49:00')` - fetched rows / total rows = 1/1 - +-------------------------------+ - | TIME('2020-08-26 13:49:00') | - |-------------------------------| - | 13:49:00 | - +-------------------------------+ - - os> source=people | eval `TIME('2020-08-26 13:49')` = TIME('2020-08-26 13:49') | fields `TIME('2020-08-26 13:49')` + > source=people | eval `SYSDATE()` = SYSDATE() | fields `SYSDATE()` fetched rows / total rows = 1/1 +----------------------------+ - | TIME('2020-08-26 13:49') | - |----------------------------| - | 13:49:00 | + | SYSDATE() | + |----------------------------+ + | 2022-08-02 15:39:05.123456 | +----------------------------+ -### `TIME_FORMAT` - - -**Description:** - - -**Usage:** - -time_format(time, format) formats the time argument using the specifiers in the format argument. -This supports a subset of the time format specifiers available for the `date_format`_ function. -Using date format specifiers supported by `date_format`_ will return 0 or null. -Acceptable format specifiers are listed in the table below. -If an argument of type DATE is passed in, it is treated as a TIMESTAMP at midnight (i.e., 00:00:00). - -| Specifier | **Description** | -|-----------|-----------------| -| %f | Microseconds (000000..999999) | -| %H | Hour (00..23) | -| %h | Hour (01..12) | -| %I | Hour (01..12) | -| %i | Minutes, numeric (00..59) | -| %p | AM or PM | -| %r | Time, 12-hour (hh:mm:ss followed by AM or PM) | -| %S | Seconds (00..59) | -| %s | Seconds (00..59) | -| %T | Time, 24-hour (hh:mm:ss) | - - -Argument type: STRING/DATE/TIME/TIMESTAMP, STRING - -Return type: STRING - -Example: - - os> source=people | eval `TIME_FORMAT('1998-01-31 13:14:15.012345', '%f %H %h %I %i %p %r %S %s %T')` = TIME_FORMAT('1998-01-31 13:14:15.012345', '%f %H %h %I %i %p %r %S %s %T') | fields `TIME_FORMAT('1998-01-31 13:14:15.012345', '%f %H %h %I %i %p %r %S %s %T')` - fetched rows / total rows = 1/1 - +------------------------------------------------------------------------------+ - | TIME_FORMAT('1998-01-31 13:14:15.012345', '%f %H %h %I %i %p %r %S %s %T') | - |------------------------------------------------------------------------------| - | 012345 13 01 01 14 PM 01:14:15 PM 15 15 13:14:15 | - +------------------------------------------------------------------------------+ - - -### `TIME_TO_SEC` - -**Description:** - - -**Usage:** time_to_sec(time) returns the time argument, converted to seconds. 
- -Argument type: STRING/TIME/TIMESTAMP - -Return type: LONG - -Example: - - os> source=people | eval `TIME_TO_SEC(TIME('22:23:00'))` = TIME_TO_SEC(TIME('22:23:00')) | fields `TIME_TO_SEC(TIME('22:23:00'))` - fetched rows / total rows = 1/1 - +---------------------------------+ - | TIME_TO_SEC(TIME('22:23:00')) | - |---------------------------------| - | 80580 | - +---------------------------------+ - - -### `TIMEDIFF` - -**Description:** - - -**Usage:** returns the difference between two time expressions as a time. - -Argument type: TIME, TIME - -Return type: TIME - -Example: - - os> source=people | eval `TIMEDIFF('23:59:59', '13:00:00')` = TIMEDIFF('23:59:59', '13:00:00') | fields `TIMEDIFF('23:59:59', '13:00:00')` - fetched rows / total rows = 1/1 - +------------------------------------+ - | TIMEDIFF('23:59:59', '13:00:00') | - |------------------------------------| - | 10:59:59 | - +------------------------------------+ - - ### `TIMESTAMP` **Description:** @@ -1699,335 +859,257 @@ Example: +------------------------------------+------------------------------------------------------+ -### `TIMESTAMPADD` - -**Description:** - - -**Usage:** Returns a TIMESTAMP value based on a passed in DATE/TIME/TIMESTAMP/STRING argument and an INTERVAL and INTEGER argument which determine the amount of time to be added. -If the third argument is a STRING, it must be formatted as a valid TIMESTAMP. If only a TIME is provided, a TIMESTAMP is still returned with the DATE portion filled in using the current date. -If the third argument is a DATE, it will be automatically converted to a TIMESTAMP. - -Argument type: INTERVAL, INTEGER, DATE/TIME/TIMESTAMP/STRING - -INTERVAL must be one of the following tokens: `[MICROSECOND, SECOND, MINUTE, HOUR, DAY, WEEK, MONTH, QUARTER, YEAR]` - -Examples: - - os> source=people | eval `TIMESTAMPADD(DAY, 17, '2000-01-01 00:00:00')` = TIMESTAMPADD(DAY, 17, '2000-01-01 00:00:00') | eval `TIMESTAMPADD(QUARTER, -1, '2000-01-01 00:00:00')` = TIMESTAMPADD(QUARTER, -1, '2000-01-01 00:00:00') | fields `TIMESTAMPADD(DAY, 17, '2000-01-01 00:00:00')`, `TIMESTAMPADD(QUARTER, -1, '2000-01-01 00:00:00')` - fetched rows / total rows = 1/1 - +------------------------------------------------+----------------------------------------------------+ - | TIMESTAMPADD(DAY, 17, '2000-01-01 00:00:00') | TIMESTAMPADD(QUARTER, -1, '2000-01-01 00:00:00') | - |------------------------------------------------+----------------------------------------------------| - | 2000-01-18 00:00:00 | 1999-10-01 00:00:00 | - +------------------------------------------------+----------------------------------------------------+ - +### `UNIX_TIMESTAMP` -### `TIMESTAMPDIFF` **Description:** -**Usage:** - -`TIMESTAMPDIFF(interval, start, end)` returns the difference between the start and end date/times in interval units. -If a TIME is provided as an argument, it will be converted to a TIMESTAMP with the DATE portion filled in using the current date. -Arguments will be automatically converted to a TIME/TIMESTAMP when appropriate. -Any argument that is a STRING must be formatted as a valid TIMESTAMP. 
- -Argument type: INTERVAL, DATE/TIME/TIMESTAMP/STRING, DATE/TIME/TIMESTAMP/STRING - -INTERVAL must be one of the following tokens: [MICROSECOND, SECOND, MINUTE, HOUR, DAY, WEEK, MONTH, QUARTER, YEAR] - -Examples: - - os> source=people | eval `TIMESTAMPDIFF(YEAR, '1997-01-01 00:00:00', '2001-03-06 00:00:00')` = TIMESTAMPDIFF(YEAR, '1997-01-01 00:00:00', '2001-03-06 00:00:00') | eval `TIMESTAMPDIFF(SECOND, time('00:00:23'), time('00:00:00'))` = TIMESTAMPDIFF(SECOND, time('00:00:23'), time('00:00:00')) | fields `TIMESTAMPDIFF(YEAR, '1997-01-01 00:00:00', '2001-03-06 00:00:00')`, `TIMESTAMPDIFF(SECOND, time('00:00:23'), time('00:00:00'))` - fetched rows / total rows = 1/1 - +---------------------------------------------------------------------+-------------------------------------------------------------+ - | TIMESTAMPDIFF(YEAR, '1997-01-01 00:00:00', '2001-03-06 00:00:00') | TIMESTAMPDIFF(SECOND, time('00:00:23'), time('00:00:00')) | - |---------------------------------------------------------------------+-------------------------------------------------------------| - | 4 | -23 | - +---------------------------------------------------------------------+-------------------------------------------------------------+ - - -### `TO_DAYS` - -**Description:** - +**Usage**: -**Usage:** to_days(date) returns the day number (the number of days since year 0) of the given date. Returns NULL if date is invalid. +Converts given argument to Unix time (seconds since Epoch - very beginning of year 1970). If no argument given, it returns the current Unix time. +The date argument may be a DATE, or TIMESTAMP string, or a number in YYMMDD, YYMMDDhhmmss, YYYYMMDD, or YYYYMMDDhhmmss format. If the argument includes a time part, it may optionally include a fractional seconds part. +If argument is in invalid format or outside of range 1970-01-01 00:00:00 - 3001-01-18 23:59:59.999999 (0 to 32536771199.999999 epoch time), function returns NULL. +You can use `FROM_UNIXTIME`_ to do reverse conversion. -Argument type: STRING/DATE/TIMESTAMP +Argument type: /DOUBLE/DATE/TIMESTAMP -Return type: LONG +Return type: DOUBLE Example: - os> source=people | eval `TO_DAYS(DATE('2008-10-07'))` = TO_DAYS(DATE('2008-10-07')) | fields `TO_DAYS(DATE('2008-10-07'))` + os> source=people | eval `UNIX_TIMESTAMP(double)` = UNIX_TIMESTAMP(20771122143845), `UNIX_TIMESTAMP(timestamp)` = UNIX_TIMESTAMP(TIMESTAMP('1996-11-15 17:05:42')) | fields `UNIX_TIMESTAMP(double)`, `UNIX_TIMESTAMP(timestamp)` fetched rows / total rows = 1/1 - +-------------------------------+ - | TO_DAYS(DATE('2008-10-07')) | - |-------------------------------| - | 733687 | - +-------------------------------+ - + +--------------------------+-----------------------------+ + | UNIX_TIMESTAMP(double) | UNIX_TIMESTAMP(timestamp) | + |--------------------------+-----------------------------| + | 3404817525.0 | 848077542.0 | + +--------------------------+-----------------------------+ -### `TO_SECONDS` +### `WEEK` **Description:** +**Usage:** week(date) returns the week number for date. + -**Usage:** to_seconds(date) returns the number of seconds since the year 0 of the given value. Returns NULL if value is invalid. -An argument of a LONG type can be used. It must be formatted as YMMDD, YYMMDD, YYYMMDD or YYYYMMDD. Note that a LONG type argument cannot have leading 0s as it will be parsed using an octal numbering system. 
+Argument type: DATE/TIMESTAMP/STRING -Argument type: STRING/LONG/DATE/TIME/TIMESTAMP +Return type: INTEGER -Return type: LONG +Synonyms: `WEEK_OF_YEAR`_ Example: - os> source=people | eval `TO_SECONDS(DATE('2008-10-07'))` = TO_SECONDS(DATE('2008-10-07')) | eval `TO_SECONDS(950228)` = TO_SECONDS(950228) | fields `TO_SECONDS(DATE('2008-10-07'))`, `TO_SECONDS(950228)` + os> source=people | eval `WEEK(DATE('2008-02-20'))` = WEEK(DATE('2008-02-20')) | fields `WEEK(DATE('2008-02-20'))` fetched rows / total rows = 1/1 - +----------------------------------+----------------------+ - | TO_SECONDS(DATE('2008-10-07')) | TO_SECONDS(950228) | - |----------------------------------+----------------------| - | 63390556800 | 62961148800 | - +----------------------------------+----------------------+ + +----------------------------+ + | WEEK(DATE('2008-02-20')) | + |----------------------------+ + | 8 | + +----------------------------+ -### `UNIX_TIMESTAMP` - +### `WEEKDAY` **Description:** -**Usage**: +**Usage:** weekday(date) returns the weekday index for date (0 = Monday, 1 = Tuesday, ..., 6 = Sunday). -Converts given argument to Unix time (seconds since Epoch - very beginning of year 1970). If no argument given, it returns the current Unix time. -The date argument may be a DATE, or TIMESTAMP string, or a number in YYMMDD, YYMMDDhhmmss, YYYYMMDD, or YYYYMMDDhhmmss format. If the argument includes a time part, it may optionally include a fractional seconds part. -If argument is in invalid format or outside of range 1970-01-01 00:00:00 - 3001-01-18 23:59:59.999999 (0 to 32536771199.999999 epoch time), function returns NULL. -You can use `FROM_UNIXTIME`_ to do reverse conversion. +It is similar to the `dayofweek`_ function, but returns different indexes for each day. -Argument type: /DOUBLE/DATE/TIMESTAMP +Argument type: STRING/DATE/TIME/TIMESTAMP -Return type: DOUBLE +Return type: INTEGER Example: - os> source=people | eval `UNIX_TIMESTAMP(double)` = UNIX_TIMESTAMP(20771122143845), `UNIX_TIMESTAMP(timestamp)` = UNIX_TIMESTAMP(TIMESTAMP('1996-11-15 17:05:42')) | fields `UNIX_TIMESTAMP(double)`, `UNIX_TIMESTAMP(timestamp)` + os> source=people | eval `weekday(DATE('2020-08-26'))` = weekday(DATE('2020-08-26')) | eval `weekday(DATE('2020-08-27'))` = weekday(DATE('2020-08-27')) | fields `weekday(DATE('2020-08-26'))`, `weekday(DATE('2020-08-27'))` fetched rows / total rows = 1/1 - +--------------------------+-----------------------------+ - | UNIX_TIMESTAMP(double) | UNIX_TIMESTAMP(timestamp) | - |--------------------------+-----------------------------| - | 3404817525.0 | 848077542.0 | - +--------------------------+-----------------------------+ + +-------------------------------+-------------------------------+ + | weekday(DATE('2020-08-26')) | weekday(DATE('2020-08-27')) | + |-------------------------------+-------------------------------| + | 2 | 3 | + +-------------------------------+-------------------------------+ -### `UTC_DATE` +### `WEEK_OF_YEAR` **Description:** -Returns the current UTC date as a value in 'YYYY-MM-DD'. +**Usage:** week_of_year(date) returns the week number for date. 
-Return type: DATE -Specification: UTC_DATE() -> DATE +Argument type: DATE/TIMESTAMP/STRING + +Return type: INTEGER + +Synonyms: `WEEK`_ Example: - > source=people | eval `UTC_DATE()` = UTC_DATE() | fields `UTC_DATE()` + os> source=people | eval `WEEK_OF_YEAR(DATE('2008-02-20'))` = WEEK(DATE('2008-02-20'))| fields `WEEK_OF_YEAR(DATE('2008-02-20'))` fetched rows / total rows = 1/1 - +--------------+ - | UTC_DATE() | - |--------------| - | 2022-10-03 | - +--------------+ - + +------------------------------------+ + | WEEK_OF_YEAR(DATE('2008-02-20')) | + |------------------------------------+ + | 8 | + +------------------------------------+ -### `UTC_TIME` +### `YEAR` **Description:** -Returns the current UTC time as a value in 'hh:mm:ss'. +**Usage:** year(date) returns the year for date, in the range 1000 to 9999, or 0 for the “zero” date. -Return type: TIME +Argument type: STRING/DATE/TIMESTAMP -Specification: UTC_TIME() -> TIME +Return type: INTEGER Example: - > source=people | eval `UTC_TIME()` = UTC_TIME() | fields `UTC_TIME()` + os> source=people | eval `YEAR(DATE('2020-08-26'))` = YEAR(DATE('2020-08-26')) | fields `YEAR(DATE('2020-08-26'))` fetched rows / total rows = 1/1 - +--------------+ - | UTC_TIME() | - |--------------| - | 17:54:27 | - +--------------+ + +----------------------------+ + | YEAR(DATE('2020-08-26')) | + |----------------------------| + | 2020 | + +----------------------------+ -### `UTC_TIMESTAMP` +### `DATE_ADD` **Description:** +Usage: date_add(date, INTERVAL expr unit) adds the interval expr to date. -Returns the current UTC timestamp as a value in 'YYYY-MM-DD hh:mm:ss'. +Argument type: DATE, INTERVAL -Return type: TIMESTAMP +Return type: DATE -Specification: UTC_TIMESTAMP() -> TIMESTAMP +Antonyms: `DATE_SUB` -Example: +Example:: - > source=people | eval `UTC_TIMESTAMP()` = UTC_TIMESTAMP() | fields `UTC_TIMESTAMP()` + os> source=people | eval `'2020-08-26' + 1d` = DATE_ADD(DATE('2020-08-26'), INTERVAL 1 DAY) | fields `'2020-08-26' + 1d` fetched rows / total rows = 1/1 +---------------------+ - | UTC_TIMESTAMP() | - |---------------------| - | 2022-10-03 17:54:28 | + | '2020-08-26' + 1d | + |---------------------+ + | 2020-08-27 | +---------------------+ -### `WEEK` +### `DATE_SUB` **Description:** -**Usage:** week(date[, mode]) returns the week number for date. If the mode argument is omitted, the default mode 0 is used. +Usage: date_sub(date, INTERVAL expr unit) subtracts the interval expr from date. -| Mode | First day of week | Range | Week 1 is the first week... 
| -|------|-------------------|-------|-----------------------------| -| 0 | Sunday | 0-53 | with a Sunday in this year | -| 1 | Monday | 0-53 | with 4 or more days this year | -| 2 | Sunday | 1-53 | with a Sunday in this year | -| 3 | Monday | 1-53 | with 4 or more days this year | -| 4 | Sunday | 0-53 | with 4 or more days this year | -| 5 | Monday | 0-53 | with a Monday in this year | -| 6 | Sunday | 1-53 | with 4 or more days this year | -| 7 | Monday | 1-53 | with a Monday in this year | +Argument type: DATE, INTERVAL +Return type: DATE -Argument type: DATE/TIMESTAMP/STRING - -Return type: INTEGER +Antonyms: `DATE_ADD` -Synonyms: `WEEK_OF_YEAR`_ +Example:: -Example: - - os> source=people | eval `WEEK(DATE('2008-02-20'))` = WEEK(DATE('2008-02-20')), `WEEK(DATE('2008-02-20'), 1)` = WEEK(DATE('2008-02-20'), 1) | fields `WEEK(DATE('2008-02-20'))`, `WEEK(DATE('2008-02-20'), 1)` + os> source=people | eval `'2008-01-02' - 31d` = DATE_SUB(DATE('2008-01-02'), INTERVAL 31 DAY) | fields `'2008-01-02' - 31d` fetched rows / total rows = 1/1 - +----------------------------+-------------------------------+ - | WEEK(DATE('2008-02-20')) | WEEK(DATE('2008-02-20'), 1) | - |----------------------------+-------------------------------| - | 7 | 8 | - +----------------------------+-------------------------------+ + +---------------------+ + | '2008-01-02' - 31d | + |---------------------+ + | 2007-12-02 | + +---------------------+ -### `WEEKDAY` +### `TIMESTAMPADD` **Description:** +Usage: Returns a TIMESTAMP value based on a passed in DATE/TIMESTAMP/STRING argument and an INTERVAL and INTEGER argument which determine the amount of time to be added. +If the third argument is a STRING, it must be formatted as a valid TIMESTAMP. +If the third argument is a DATE, it will be automatically converted to a TIMESTAMP. -**Usage:** weekday(date) returns the weekday index for date (0 = Monday, 1 = Tuesday, ..., 6 = Sunday). - -It is similar to the `dayofweek`_ function, but returns different indexes for each day. 
+Argument type: INTERVAL, INTEGER, DATE/TIMESTAMP/STRING -Argument type: STRING/DATE/TIME/TIMESTAMP +INTERVAL must be one of the following tokens: [SECOND, MINUTE, HOUR, DAY, WEEK, MONTH, QUARTER, YEAR] -Return type: INTEGER +Examples:: -Example: - - os> source=people | eval `weekday(DATE('2020-08-26'))` = weekday(DATE('2020-08-26')) | eval `weekday(DATE('2020-08-27'))` = weekday(DATE('2020-08-27')) | fields `weekday(DATE('2020-08-26'))`, `weekday(DATE('2020-08-27'))` + os> source=people | eval `TIMESTAMPADD(DAY, 17, '2000-01-01 00:00:00')` = TIMESTAMPADD(DAY, 17, '2000-01-01 00:00:00') | eval `TIMESTAMPADD(QUARTER, -1, '2000-01-01 00:00:00')` = TIMESTAMPADD(QUARTER, -1, '2000-01-01 00:00:00') | fields `TIMESTAMPADD(DAY, 17, '2000-01-01 00:00:00')`, `TIMESTAMPADD(QUARTER, -1, '2000-01-01 00:00:00')` fetched rows / total rows = 1/1 - +-------------------------------+-------------------------------+ - | weekday(DATE('2020-08-26')) | weekday(DATE('2020-08-27')) | - |-------------------------------+-------------------------------| - | 2 | 3 | - +-------------------------------+-------------------------------+ + +----------------------------------------------+--------------------------------------------------+ + | TIMESTAMPADD(DAY, 17, '2000-01-01 00:00:00') | TIMESTAMPADD(QUARTER, -1, '2000-01-01 00:00:00') | + |----------------------------------------------+--------------------------------------------------| + | 2000-01-18 00:00:00 | 1999-10-01 00:00:00 | + +----------------------------------------------+--------------------------------------------------+ -### `WEEK_OF_YEAR` +### `TIMESTAMPDIFF` **Description:** +Usage: TIMESTAMPDIFF(interval, start, end) returns the difference between the start and end date/times in interval units. +Arguments will be automatically converted to a ]TIMESTAMP when appropriate. +Any argument that is a STRING must be formatted as a valid TIMESTAMP. -**Usage:** week_of_year(date[, mode]) returns the week number for date. If the mode argument is omitted, the default mode 0 is used. - -| Mode | First day of week | Range | Week 1 is the first week ... 
| -|------|-------------------|-------|------------------------------| -| 0 | Sunday | 0-53 | with a Sunday in this year | -| 1 | Monday | 0-53 | with 4 or more days this year| -| 2 | Sunday | 1-53 | with a Sunday in this year | -| 3 | Monday | 1-53 | with 4 or more days this year| -| 4 | Sunday | 0-53 | with 4 or more days this year| -| 5 | Monday | 0-53 | with a Monday in this year | -| 6 | Sunday | 1-53 | with 4 or more days this year| -| 7 | Monday | 1-53 | with a Monday in this year | - - -Argument type: DATE/TIMESTAMP/STRING +Argument type: INTERVAL, DATE/TIMESTAMP/STRING, DATE/TIMESTAMP/STRING -Return type: INTEGER +INTERVAL must be one of the following tokens: [SECOND, MINUTE, HOUR, DAY, WEEK, MONTH, QUARTER, YEAR] -Synonyms: `WEEK`_ +Examples:: -Example: - - os> source=people | eval `WEEK_OF_YEAR(DATE('2008-02-20'))` = WEEK(DATE('2008-02-20')), `WEEK_OF_YEAR(DATE('2008-02-20'), 1)` = WEEK_OF_YEAR(DATE('2008-02-20'), 1) | fields `WEEK_OF_YEAR(DATE('2008-02-20'))`, `WEEK_OF_YEAR(DATE('2008-02-20'), 1)` + os> source=people | eval `TIMESTAMPDIFF(YEAR, '1997-01-01 00:00:00', '2001-03-06 00:00:00')` = TIMESTAMPDIFF(YEAR, '1997-01-01 00:00:00', '2001-03-06 00:00:00') | eval `TIMESTAMPDIFF(SECOND, timestamp('1997-01-01 00:00:23'), timestamp('1997-01-01 00:00:00'))` = TIMESTAMPDIFF(SECOND, timestamp('1997-01-01 00:00:23'), timestamp('1997-01-01 00:00:00')) | fields `TIMESTAMPDIFF(YEAR, '1997-01-01 00:00:00', '2001-03-06 00:00:00')`, `TIMESTAMPDIFF(SECOND, timestamp('1997-01-01 00:00:23'), timestamp('1997-01-01 00:00:00'))` fetched rows / total rows = 1/1 - +------------------------------------+---------------------------------------+ - | WEEK_OF_YEAR(DATE('2008-02-20')) | WEEK_OF_YEAR(DATE('2008-02-20'), 1) | - |------------------------------------+---------------------------------------| - | 7 | 8 | - +------------------------------------+---------------------------------------+ + +-------------------------------------------------------------------+-------------------------------------------------------------------------------------------+ + | TIMESTAMPDIFF(YEAR, '1997-01-01 00:00:00', '2001-03-06 00:00:00') | TIMESTAMPDIFF(SECOND, timestamp('1997-01-01 00:00:23'), timestamp('1997-01-01 00:00:00')) | + |-------------------------------------------------------------------+-------------------------------------------------------------------------------------------| + | 4 | -23 | + +-------------------------------------------------------------------+-------------------------------------------------------------------------------------------+ -### `YEAR` +### `UTC_TIMESTAMP` **Description:** +Returns the current UTC timestamp as a value in 'YYYY-MM-DD hh:mm:ss'. -**Usage:** year(date) returns the year for date, in the range 1000 to 9999, or 0 for the “zero” date. 
- -Argument type: STRING/DATE/TIMESTAMP +Return type: TIMESTAMP -Return type: INTEGER +Specification: UTC_TIMESTAMP() -> TIMESTAMP -Example: +Example:: - os> source=people | eval `YEAR(DATE('2020-08-26'))` = YEAR(DATE('2020-08-26')) | fields `YEAR(DATE('2020-08-26'))` + > source=people | eval `UTC_TIMESTAMP()` = UTC_TIMESTAMP() | fields `UTC_TIMESTAMP()` fetched rows / total rows = 1/1 - +----------------------------+ - | YEAR(DATE('2020-08-26')) | - |----------------------------| - | 2020 | - +----------------------------+ - + +---------------------+ + | UTC_TIMESTAMP() | + |---------------------| + | 2022-10-03 17:54:28 | + +---------------------+ -### `YEARWEEK` +### `CURRENT_TIMEZONE` **Description:** +Returns the current local timezone. -**Usage:** yearweek(date) returns the year and week for date as an integer. It accepts and optional mode arguments aligned with those available for the `WEEK`_ function. - -Argument type: STRING/DATE/TIME/TIMESTAMP - -Return type: INTEGER +Return type: STRING -Example: +Example:: - os> source=people | eval `YEARWEEK('2020-08-26')` = YEARWEEK('2020-08-26') | eval `YEARWEEK('2019-01-05', 1)` = YEARWEEK('2019-01-05', 1) | fields `YEARWEEK('2020-08-26')`, `YEARWEEK('2019-01-05', 1)` + > source=people | eval `CURRENT_TIMEZONE()` = CURRENT_TIMEZONE() | fields `CURRENT_TIMEZONE()` fetched rows / total rows = 1/1 - +--------------------------+-----------------------------+ - | YEARWEEK('2020-08-26') | YEARWEEK('2019-01-05', 1) | - |--------------------------+-----------------------------| - | 202034 | 201901 | - +--------------------------+-----------------------------+ - + +------------------------+ + | CURRENT_TIMEZONE() | + |------------------------| + | America/Chicago | + +------------------------+ diff --git a/docs/ppl-lang/functions/ppl-expressions.md b/docs/ppl-lang/functions/ppl-expressions.md index 6315663c2..171f97385 100644 --- a/docs/ppl-lang/functions/ppl-expressions.md +++ b/docs/ppl-lang/functions/ppl-expressions.md @@ -127,7 +127,7 @@ OR operator : NOT operator : - os> source=accounts | where not age in (32, 33) | fields age ; + os> source=accounts | where age not in (32, 33) | fields age ; fetched rows / total rows = 2/2 +-------+ | age | diff --git a/docs/ppl-lang/functions/ppl-ip.md b/docs/ppl-lang/functions/ppl-ip.md new file mode 100644 index 000000000..fb0b468ba --- /dev/null +++ b/docs/ppl-lang/functions/ppl-ip.md @@ -0,0 +1,35 @@ +## PPL IP Address Functions + +### `CIDRMATCH` + +**Description** + +`CIDRMATCH(ip, cidr)` checks if ip is within the specified cidr range. 
+ +**Argument type:** + - STRING, STRING + - Return type: **BOOLEAN** + +Example: + + os> source=ips | where cidrmatch(ip, '192.169.1.0/24') | fields ip + fetched rows / total rows = 1/1 + +--------------+ + | ip | + |--------------| + | 192.169.1.5 | + +--------------+ + + os> source=ipsv6 | where cidrmatch(ip, '2003:db8::/32') | fields ip + fetched rows / total rows = 1/1 + +-----------------------------------------+ + | ip | + |-----------------------------------------| + | 2003:0db8:0000:0000:0000:0000:0000:0000 | + +-----------------------------------------+ + +Note: + - `ip` can be an IPv4 or an IPv6 address + - `cidr` can be an IPv4 or an IPv6 block + - `ip` and `cidr` must be either both IPv4 or both IPv6 + - `ip` and `cidr` must both be valid and non-empty/non-null \ No newline at end of file diff --git a/docs/ppl-lang/functions/ppl-json.md b/docs/ppl-lang/functions/ppl-json.md new file mode 100644 index 000000000..1953e8c70 --- /dev/null +++ b/docs/ppl-lang/functions/ppl-json.md @@ -0,0 +1,237 @@ +## PPL JSON Functions + +### `JSON` + +**Description** + +`json(value)` Evaluates whether a value can be parsed as JSON. Returns the json string if valid, null otherwise. + +**Argument type:** STRING/JSON_ARRAY/JSON_OBJECT + +**Return type:** STRING + +A STRING expression of a valid JSON object format. + +Example: + + os> source=people | eval `valid_json()` = json('[1,2,3,{"f1":1,"f2":[5,6]},4]') | fields valid_json + fetched rows / total rows = 1/1 + +---------------------------------+ + | valid_json | + +---------------------------------+ + | [1,2,3,{"f1":1,"f2":[5,6]},4] | + +---------------------------------+ + + os> source=people | eval `invalid_json()` = json('{"invalid": "json"') | fields invalid_json + fetched rows / total rows = 1/1 + +----------------+ + | invalid_json | + +----------------+ + | null | + +----------------+ + + +### `JSON_OBJECT` + +**Description** + +`json_object(, [, , ]...)` returns a JSON object from members of key-value pairs. + +**Argument type:** +- A \ must be STRING. +- A \ can be any data types. + +**Return type:** JSON_OBJECT (Spark StructType) + +A StructType expression of a valid JSON object. + +Example: + + os> source=people | eval result = json(json_object('key', 123.45)) | fields result + fetched rows / total rows = 1/1 + +------------------+ + | result | + +------------------+ + | {"key":123.45} | + +------------------+ + + os> source=people | eval result = json(json_object('outer', json_object('inner', 123.45))) | fields result + fetched rows / total rows = 1/1 + +------------------------------+ + | result | + +------------------------------+ + | {"outer":{"inner":123.45}} | + +------------------------------+ + + +### `JSON_ARRAY` + +**Description** + +`json_array(...)` Creates a JSON ARRAY using a list of values. + +**Argument type:** +- A \ can be any kind of value such as string, number, or boolean. + +**Return type:** ARRAY (Spark ArrayType) + +An array of any supported data type for a valid JSON array. 
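+
+As an additional illustrative sketch (query only, output not verified here), the values can also be string literals rather than numbers:
+
+    source=people | eval `json_array_str` = json_array('a', 'b', 'c') | fields `json_array_str` // array built from string values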
+ +Example: + + os> source=people | eval `json_array` = json_array(1, 2, 0, -1, 1.1, -0.11) + fetched rows / total rows = 1/1 + +----------------------------+ + | json_array | + +----------------------------+ + | 1.0,2.0,0.0,-1.0,1.1,-0.11 | + +----------------------------+ + + os> source=people | eval `json_array_object` = json(json_object("array", json_array(1, 2, 0, -1, 1.1, -0.11))) + fetched rows / total rows = 1/1 + +----------------------------------------+ + | json_array_object | + +----------------------------------------+ + | {"array":[1.0,2.0,0.0,-1.0,1.1,-0.11]} | + +----------------------------------------+ + +### `JSON_ARRAY_LENGTH` + +**Description** + +`json_array_length(jsonArray)` Returns the number of elements in the outermost JSON array. + +**Argument type:** STRING/JSON_ARRAY + +A STRING expression of a valid JSON array format, or JSON_ARRAY object. + +**Return type:** INTEGER + +`NULL` is returned in case of any other valid JSON string, `NULL` or an invalid JSON. + +Example: + + os> source=people | eval `lenght1` = json_array_length('[1,2,3,4]'), `lenght2` = json_array_length('[1,2,3,{"f1":1,"f2":[5,6]},4]'), `not_array` = json_array_length('{"key": 1}') + fetched rows / total rows = 1/1 + +-----------+-----------+-------------+ + | lenght1 | lenght2 | not_array | + +-----------+-----------+-------------+ + | 4 | 5 | null | + +-----------+-----------+-------------+ + + os> source=people | eval `json_array` = json_array_length(json_array(1,2,3,4)), `empty_array` = json_array_length(json_array()) + fetched rows / total rows = 1/1 + +--------------+---------------+ + | json_array | empty_array | + +--------------+---------------+ + | 4 | 0 | + +--------------+---------------+ + +### `JSON_EXTRACT` + +**Description** + +`json_extract(jsonStr, path)` Extracts json object from a json string based on json path specified. Return null if the input json string is invalid. + +**Argument type:** STRING, STRING + +**Return type:** STRING + +A STRING expression of a valid JSON object format. + +`NULL` is returned in case of an invalid JSON. 
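+
+As an additional illustrative sketch (the `events` source and its `doc` field are hypothetical, and the output is not verified here), the extracted value can also feed a later filter:
+
+    source=events | eval a = json_extract(doc, '$.a') | where a = 'b' | fields a // extract "$.a" from a JSON string field, then filter on it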
+ +Example: + + os> source=people | eval `json_extract('{"a":"b"}', '$.a')` = json_extract('{"a":"b"}', '$a') + fetched rows / total rows = 1/1 + +----------------------------------+ + | json_extract('{"a":"b"}', 'a') | + +----------------------------------+ + | b | + +----------------------------------+ + + os> source=people | eval `json_extract('{"a":[{"b":1},{"b":2}]}', '$.a[1].b')` = json_extract('{"a":[{"b":1},{"b":2}]}', '$.a[1].b') + fetched rows / total rows = 1/1 + +-----------------------------------------------------------+ + | json_extract('{"a":[{"b":1.0},{"b":2.0}]}', '$.a[1].b') | + +-----------------------------------------------------------+ + | 2.0 | + +-----------------------------------------------------------+ + + os> source=people | eval `json_extract('{"a":[{"b":1},{"b":2}]}', '$.a[*].b')` = json_extract('{"a":[{"b":1},{"b":2}]}', '$.a[*].b') + fetched rows / total rows = 1/1 + +-----------------------------------------------------------+ + | json_extract('{"a":[{"b":1.0},{"b":2.0}]}', '$.a[*].b') | + +-----------------------------------------------------------+ + | [1.0,2.0] | + +-----------------------------------------------------------+ + + os> source=people | eval `invalid_json` = json_extract('{"invalid": "json"') + fetched rows / total rows = 1/1 + +----------------+ + | invalid_json | + +----------------+ + | null | + +----------------+ + + +### `JSON_KEYS` + +**Description** + +`json_keys(jsonStr)` Returns all the keys of the outermost JSON object as an array. + +**Argument type:** STRING + +A STRING expression of a valid JSON object format. + +**Return type:** ARRAY[STRING] + +`NULL` is returned in case of any other valid JSON string, or an empty string, or an invalid JSON. + +Example: + + os> source=people | eval `keys` = json_keys('{"f1":"abc","f2":{"f3":"a","f4":"b"}}') + fetched rows / total rows = 1/1 + +------------+ + | keus | + +------------+ + | [f1, f2] | + +------------+ + + os> source=people | eval `keys` = json_keys('[1,2,3,{"f1":1,"f2":[5,6]},4]') + fetched rows / total rows = 1/1 + +--------+ + | keys | + +--------+ + | null | + +--------+ + +### `JSON_VALID` + +**Description** + +`json_valid(jsonStr)` Evaluates whether a JSON string uses valid JSON syntax and returns TRUE or FALSE. + +**Argument type:** STRING + +**Return type:** BOOLEAN + +Example: + + os> source=people | eval `valid_json` = json_valid('[1,2,3,4]'), `invalid_json` = json_valid('{"invalid": "json"') | feilds `valid_json`, `invalid_json` + fetched rows / total rows = 1/1 + +--------------+----------------+ + | valid_json | invalid_json | + +--------------+----------------+ + | True | False | + +--------------+----------------+ + + os> source=accounts | where json_valid('[1,2,3,4]') and isnull(email) | fields account_number, email + fetched rows / total rows = 1/1 + +------------------+---------+ + | account_number | email | + |------------------+---------| + | 13 | null | + +------------------+---------+ diff --git a/docs/ppl-lang/functions/ppl-lambda.md b/docs/ppl-lang/functions/ppl-lambda.md new file mode 100644 index 000000000..cdb6f9e8f --- /dev/null +++ b/docs/ppl-lang/functions/ppl-lambda.md @@ -0,0 +1,187 @@ +## Lambda Functions + +### `FORALL` + +**Description** + +`forall(array, lambda)` Evaluates whether a lambda predicate holds for all elements in the array. + +**Argument type:** ARRAY, LAMBDA + +**Return type:** BOOLEAN + +Returns `TRUE` if all elements in the array satisfy the lambda predicate, otherwise `FALSE`. 
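+
+As an additional illustrative sketch (query only, output not verified here), the predicate can be any boolean lambda over the element, for example a non-negativity check:
+
+    source=people | eval array = json_array(0, 1, 2), result = forall(array, x -> x >= 0) | fields result // true: every element is >= 0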
+ +Example: + + os> source=people | eval array = json_array(1, -1, 2), result = forall(array, x -> x > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | false | + +-----------+ + + os> source=people | eval array = json_array(1, 3, 2), result = forall(array, x -> x > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | true | + +-----------+ + + **Note:** The lambda expression can access the nested fields of the array elements. This applies to all lambda functions introduced in this document. + +Consider constructing the following array: + + array = [ + {"a":1, "b":1}, + {"a":-1, "b":2} + ] + +and perform lambda functions against the nested fields `a` or `b`. See the examples: + + os> source=people | eval array = json_array(json_object("a", 1, "b", 1), json_object("a" , -1, "b", 2)), result = forall(array, x -> x.a > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | false | + +-----------+ + + os> source=people | eval array = json_array(json_object("a", 1, "b", 1), json_object("a" , -1, "b", 2)), result = forall(array, x -> x.b > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | true | + +-----------+ + +### `EXISTS` + +**Description** + +`exists(array, lambda)` Evaluates whether a lambda predicate holds for one or more elements in the array. + +**Argument type:** ARRAY, LAMBDA + +**Return type:** BOOLEAN + +Returns `TRUE` if at least one element in the array satisfies the lambda predicate, otherwise `FALSE`. + +Example: + + os> source=people | eval array = json_array(1, -1, 2), result = exists(array, x -> x > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | true | + +-----------+ + + os> source=people | eval array = json_array(-1, -3, -2), result = exists(array, x -> x > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | false | + +-----------+ + + +### `FILTER` + +**Description** + +`filter(array, lambda)` Filters the input array using the given lambda function. + +**Argument type:** ARRAY, LAMBDA + +**Return type:** ARRAY + +An ARRAY that contains all elements in the input array that satisfy the lambda predicate. + +Example: + + os> source=people | eval array = json_array(1, -1, 2), result = filter(array, x -> x > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | [1, 2] | + +-----------+ + + os> source=people | eval array = json_array(-1, -3, -2), result = filter(array, x -> x > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | [] | + +-----------+ + +### `TRANSFORM` + +**Description** + +`transform(array, lambda)` Transform elements in an array using the lambda transform function. The second argument implies the index of the element if using binary lambda function. This is similar to a `map` in functional programming. + +**Argument type:** ARRAY, LAMBDA + +**Return type:** ARRAY + +An ARRAY that contains the result of applying the lambda transform function to each element in the input array. 
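+
+As an additional illustrative sketch (query only, output not verified here), the lambda can apply any scalar expression to each element, for example doubling every value:
+
+    source=people | eval array = json_array(1, 2, 3), result = transform(array, x -> x * 2) | fields result // each element is multiplied by 2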
+ +Example: + + os> source=people | eval array = json_array(1, 2, 3), result = transform(array, x -> x + 1) | fields result + fetched rows / total rows = 1/1 + +--------------+ + | result | + +--------------+ + | [2, 3, 4] | + +--------------+ + + os> source=people | eval array = json_array(1, 2, 3), result = transform(array, (x, i) -> x + i) | fields result + fetched rows / total rows = 1/1 + +--------------+ + | result | + +--------------+ + | [1, 3, 5] | + +--------------+ + +### `REDUCE` + +**Description** + +`reduce(array, start, merge_lambda, finish_lambda)` Applies a binary merge lambda function to a start value and all elements in the array, and reduces this to a single state. The final state is converted into the final result by applying a finish lambda function. + +**Argument type:** ARRAY, ANY, LAMBDA, LAMBDA + +**Return type:** ANY + +The final result of applying the lambda functions to the start value and the input array. + +Example: + + os> source=people | eval array = json_array(1, 2, 3), result = reduce(array, 0, (acc, x) -> acc + x) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | 6 | + +-----------+ + + os> source=people | eval array = json_array(1, 2, 3), result = reduce(array, 10, (acc, x) -> acc + x) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | 16 | + +-----------+ + + os> source=people | eval array = json_array(1, 2, 3), result = reduce(array, 0, (acc, x) -> acc + x, acc -> acc * 10) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | 60 | + +-----------+ diff --git a/docs/ppl-lang/planning/ppl-between.md b/docs/ppl-lang/planning/ppl-between.md new file mode 100644 index 000000000..6c8e300e8 --- /dev/null +++ b/docs/ppl-lang/planning/ppl-between.md @@ -0,0 +1,17 @@ +## between syntax proposal + +1. **Proposed syntax** + - `... | where expr1 [NOT] BETWEEN expr2 AND expr3` + - evaluate if expr1 is [not] in between expr2 and expr3 + - `... | where a between 1 and 4` - Note: This returns a >= 1 and a <= 4, i.e. [1, 4] + - `... | where b not between '2024-09-10' and '2025-09-10'` - Note: This returns b >= '2024-09-10' and b <= '2025-09-10' + +### New syntax definition in ANTLR + +```ANTLR + +logicalExpression + ... + | expr1 = functionArg NOT? BETWEEN expr2 = functionArg AND expr3 = functionArg # between + +``` \ No newline at end of file diff --git a/docs/ppl-lang/ppl-comment.md b/docs/ppl-lang/ppl-comment.md new file mode 100644 index 000000000..3a869955b --- /dev/null +++ b/docs/ppl-lang/ppl-comment.md @@ -0,0 +1,34 @@ +## Comments + +Comments are not evaluated texts. PPL supports both line comments and block comments. + +### Line Comments + +Line comments begin with two slashes `//` and end with a new line. + +Example:: + + os> source=accounts | top gender // finds most common gender of all the accounts + fetched rows / total rows = 2/2 + +----------+ + | gender | + |----------| + | M | + | F | + +----------+ + +### Block Comments + +Block comments begin with a slash followed by an asterisk `\*` and end with an asterisk followed by a slash `*/`. 
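+
+As an additional illustrative sketch (query only, output not verified here), a block comment can annotate a single stage in the middle of a pipeline and can be combined with a trailing line comment:
+
+    source=accounts /* only the id column is needed */ | fields account_number // line comment at the end of the query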
+ +Example:: + + os> source=accounts | dedup 2 gender /* dedup the document with gender field keep 2 duplication */ | fields account_number, gender + fetched rows / total rows = 3/3 + +------------------+----------+ + | account_number | gender | + |------------------+----------| + | 1 | M | + | 6 | M | + | 13 | F | + +------------------+----------+ \ No newline at end of file diff --git a/docs/ppl-lang/ppl-eval-command.md b/docs/ppl-lang/ppl-eval-command.md index cd0898c1b..1908c087c 100644 --- a/docs/ppl-lang/ppl-eval-command.md +++ b/docs/ppl-lang/ppl-eval-command.md @@ -80,6 +80,8 @@ Assumptions: `a`, `b`, `c` are existing fields in `table` - `source = table | eval f = case(a = 0, 'zero', a = 1, 'one', a = 2, 'two', a = 3, 'three', a = 4, 'four', a = 5, 'five', a = 6, 'six', a = 7, 'se7en', a = 8, 'eight', a = 9, 'nine')` - `source = table | eval f = case(a = 0, 'zero', a = 1, 'one' else 'unknown')` - `source = table | eval f = case(a = 0, 'zero', a = 1, 'one' else concat(a, ' is an incorrect binary digit'))` +- `source = table | eval f = a in ('foo', 'bar') | fields f` +- `source = table | eval f = a not in ('foo', 'bar') | fields f` Eval with `case` example: ```sql diff --git a/docs/ppl-lang/ppl-eventstats-command.md b/docs/ppl-lang/ppl-eventstats-command.md new file mode 100644 index 000000000..9a65d5052 --- /dev/null +++ b/docs/ppl-lang/ppl-eventstats-command.md @@ -0,0 +1,327 @@ +## PPL `eventstats` command + +### Description +The `eventstats` command enriches your event data with calculated summary statistics. It operates by analyzing specified fields within your events, computing various statistical measures, and then appending these results as new fields to each original event. + +Key aspects of `eventstats`: + +1. It performs calculations across the entire result set or within defined groups. +2. The original events remain intact, with new fields added to contain the statistical results. +3. The command is particularly useful for comparative analysis, identifying outliers, or providing additional context to individual events. + +### Difference between [`stats`](ppl-stats-command.md) and `eventstats` +The `stats` and `eventstats` commands are both used for calculating statistics, but they have some key differences in how they operate and what they produce: + +- Output Format: + - `stats`: Produces a summary table with only the calculated statistics. + - `eventstats`: Adds the calculated statistics as new fields to the existing events, preserving the original data. +- Event Retention: + - `stats`: Reduces the result set to only the statistical summary, discarding individual events. + - `eventstats`: Retains all original events and adds new fields with the calculated statistics. +- Use Cases: + - `stats`: Best for creating summary reports or dashboards. Often used as a final command to summarize results. + - `eventstats`: Useful when you need to enrich events with statistical context for further analysis or filtering. Can be used mid-search to add statistics that can be used in subsequent commands. + +### Syntax +`eventstats ... [by-clause]` + +### **aggregation:** +mandatory. A aggregation function. The argument of aggregation must be field. + +**by-clause**: optional. + +#### Syntax: +`by [span-expression,] [field,]...` + +**Description:** + +The by clause could be the fields and expressions like scalar functions and aggregation functions. 
+Besides, the span clause can be used to split specific field into buckets in the same interval, the eventstats then does the aggregation by these span buckets. + +**Default**: + +If no `` is specified, the eventstats command aggregates over the entire result set. + +### **`span-expression`**: +optional, at most one. + +#### Syntax: +`span(field_expr, interval_expr)` + +**Description:** + +The unit of the interval expression is the natural unit by default. +If the field is a date and time type field, and the interval is in date/time units, you will need to specify the unit in the interval expression. + +For example, to split the field ``age`` into buckets by 10 years, it looks like ``span(age, 10)``. And here is another example of time span, the span to split a ``timestamp`` field into hourly intervals, it looks like ``span(timestamp, 1h)``. + +* Available time unit: +``` ++----------------------------+ +| Span Interval Units | ++============================+ +| millisecond (ms) | ++----------------------------+ +| second (s) | ++----------------------------+ +| minute (m, case sensitive) | ++----------------------------+ +| hour (h) | ++----------------------------+ +| day (d) | ++----------------------------+ +| week (w) | ++----------------------------+ +| month (M, case sensitive) | ++----------------------------+ +| quarter (q) | ++----------------------------+ +| year (y) | ++----------------------------+ +``` + +### Aggregation Functions + +#### _COUNT_ + +**Description** + +Returns a count of the number of expr in the rows retrieved by a SELECT statement. + +Example: + + os> source=accounts | eventstats count(); + fetched rows / total rows = 4/4 + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+---------+ + | account_number | balance | firstname | lastname | age | gender | address | employer | email | city | state | count() | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+---------+ + | 1 | 39225 | Amber | Duke | 32 | M | 880 Holmes Lane | Pyrami | amberduke@pyrami.com | Brogan | IL | 4 | + | 6 | 5686 | Hattie | Bond | 36 | M | 671 Bristol Street | Netagy | hattiebond@netagy.com | Dante | TN | 4 | + | 13 | 32838 | Nanette | Bates | 28 | F | 789 Madison Street | Quility | | Nogal | VA | 4 | + | 18 | 4180 | Dale | Adams | 33 | M | 467 Hutchinson Court | | daleadams@boink.com | Orick | MD | 4 | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+---------+ + + +#### _SUM_ + +**Description** + +`SUM(expr)`. Returns the sum of expr. 
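+
+As an additional illustrative sketch (query only, output not verified here), the result can be renamed with `as` and grouped with a by-clause:
+
+    source=accounts | eventstats sum(balance) as total_balance by gender // appends the per-gender balance total to every event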
+ +Example: + + os> source=accounts | eventstats sum(age) by gender; + fetched rows / total rows = 4/4 + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+--------------------+ + | account_number | balance | firstname | lastname | age | gender | address | employer | email | city | state | sum(age) by gender | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+--------------------+ + | 1 | 39225 | Amber | Duke | 32 | M | 880 Holmes Lane | Pyrami | amberduke@pyrami.com | Brogan | IL | 101 | + | 6 | 5686 | Hattie | Bond | 36 | M | 671 Bristol Street | Netagy | hattiebond@netagy.com | Dante | TN | 101 | + | 13 | 32838 | Nanette | Bates | 28 | F | 789 Madison Street | Quility | | Nogal | VA | 28 | + | 18 | 4180 | Dale | Adams | 33 | M | 467 Hutchinson Court | | daleadams@boink.com | Orick | MD | 101 | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+--------------------+ + + +#### _AVG_ + +**Description** + +`AVG(expr)`. Returns the average value of expr. + +Example: + + os> source=accounts | eventstats avg(age) by gender; + fetched rows / total rows = 4/4 + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+--------------------+ + | account_number | balance | firstname | lastname | age | gender | address | employer | email | city | state | avg(age) by gender | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+--------------------+ + | 1 | 39225 | Amber | Duke | 32 | M | 880 Holmes Lane | Pyrami | amberduke@pyrami.com | Brogan | IL | 33.67 | + | 6 | 5686 | Hattie | Bond | 36 | M | 671 Bristol Street | Netagy | hattiebond@netagy.com | Dante | TN | 33.67 | + | 13 | 32838 | Nanette | Bates | 28 | F | 789 Madison Street | Quility | | Nogal | VA | 28.00 | + | 18 | 4180 | Dale | Adams | 33 | M | 467 Hutchinson Court | | daleadams@boink.com | Orick | MD | 33.67 | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+--------------------+ + + +#### MAX + +**Description** + +`MAX(expr)` Returns the maximum value of expr. 
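+
+As an additional illustrative sketch (query only, output not verified here), the maximum can also be computed per group and renamed:
+
+    source=accounts | eventstats max(balance) as max_balance by gender // appends the per-gender maximum balance to every event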
+ +Example: + + os> source=accounts | eventstats max(age); + fetched rows / total rows = 4/4 + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+-----------+ + | account_number | balance | firstname | lastname | age | gender | address | employer | email | city | state | max(age) | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+-----------+ + | 1 | 39225 | Amber | Duke | 32 | M | 880 Holmes Lane | Pyrami | amberduke@pyrami.com | Brogan | IL | 36 | + | 6 | 5686 | Hattie | Bond | 36 | M | 671 Bristol Street | Netagy | hattiebond@netagy.com | Dante | TN | 36 | + | 13 | 32838 | Nanette | Bates | 28 | F | 789 Madison Street | Quility | | Nogal | VA | 36 | + | 18 | 4180 | Dale | Adams | 33 | M | 467 Hutchinson Court | | daleadams@boink.com | Orick | MD | 36 | ++----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+-----------+ + + +#### MIN + +**Description** + +`MIN(expr)` Returns the minimum value of expr. + +Example: + + os> source=accounts | eventstats min(age); + fetched rows / total rows = 4/4 + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+-----------+ + | account_number | balance | firstname | lastname | age | gender | address | employer | email | city | state | min(age) | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+-----------+ + | 1 | 39225 | Amber | Duke | 32 | M | 880 Holmes Lane | Pyrami | amberduke@pyrami.com | Brogan | IL | 28 | + | 6 | 5686 | Hattie | Bond | 36 | M | 671 Bristol Street | Netagy | hattiebond@netagy.com | Dante | TN | 28 | + | 13 | 32838 | Nanette | Bates | 28 | F | 789 Madison Street | Quility | | Nogal | VA | 28 | + | 18 | 4180 | Dale | Adams | 33 | M | 467 Hutchinson Court | | daleadams@boink.com | Orick | MD | 28 | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+-----------+ + + +#### STDDEV_SAMP + +**Description** + +`STDDEV_SAMP(expr)` Return the sample standard deviation of expr. 
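+
+As an additional illustrative sketch (query only, output not verified here), the sample standard deviation can also be computed per group and renamed:
+
+    source=accounts | eventstats stddev_samp(balance) as balance_stddev by gender // appends the per-gender sample standard deviation of balance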
+ +Example: + + os> source=accounts | eventstats stddev_samp(age); + fetched rows / total rows = 4/4 + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+------------------------+ + | account_number | balance | firstname | lastname | age | gender | address | employer | email | city | state | stddev_samp(age) | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+------------------------+ + | 1 | 39225 | Amber | Duke | 32 | M | 880 Holmes Lane | Pyrami | amberduke@pyrami.com | Brogan | IL | 3.304037933599835 | + | 6 | 5686 | Hattie | Bond | 36 | M | 671 Bristol Street | Netagy | hattiebond@netagy.com | Dante | TN | 3.304037933599835 | + | 13 | 32838 | Nanette | Bates | 28 | F | 789 Madison Street | Quility | | Nogal | VA | 3.304037933599835 | + | 18 | 4180 | Dale | Adams | 33 | M | 467 Hutchinson Court | | daleadams@boink.com | Orick | MD | 3.304037933599835 | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+------------------------+ + + +#### STDDEV_POP + +**Description** + +`STDDEV_POP(expr)` Return the population standard deviation of expr. + +Example: + + os> source=accounts | eventstats stddev_pop(age); + fetched rows / total rows = 4/4 + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+------------------------+ + | account_number | balance | firstname | lastname | age | gender | address | employer | email | city | state | stddev_pop(age) | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+------------------------+ + | 1 | 39225 | Amber | Duke | 32 | M | 880 Holmes Lane | Pyrami | amberduke@pyrami.com | Brogan | IL | 2.8613807855648994 | + | 6 | 5686 | Hattie | Bond | 36 | M | 671 Bristol Street | Netagy | hattiebond@netagy.com | Dante | TN | 2.8613807855648994 | + | 13 | 32838 | Nanette | Bates | 28 | F | 789 Madison Street | Quility | | Nogal | VA | 2.8613807855648994 | + | 18 | 4180 | Dale | Adams | 33 | M | 467 Hutchinson Court | | daleadams@boink.com | Orick | MD | 2.8613807855648994 | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+------------------------+ + + +#### PERCENTILE or PERCENTILE_APPROX + +**Description** + +`PERCENTILE(expr, percent)` or `PERCENTILE_APPROX(expr, percent)` Return the approximate percentile value of expr at the specified percentage. + +* percent: The number must be a constant between 0 and 100. 
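+
+As an additional illustrative sketch (query only, output not verified here), a median can be attached to every event by using a percentage of 50:
+
+    source=accounts | eventstats percentile(balance, 50) as median_balance // appends the approximate median balance to every event
+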
+--- + +Examples: + + os> source=accounts | eventstats percentile(age, 90) by gender; + fetched rows / total rows = 4/4 + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+--------------------------------+ + | account_number | balance | firstname | lastname | age | gender | address | employer | email | city | state | percentile(age, 90) by gender | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+--------------------------------+ + | 1 | 39225 | Amber | Duke | 32 | M | 880 Holmes Lane | Pyrami | amberduke@pyrami.com | Brogan | IL | 36 | + | 6 | 5686 | Hattie | Bond | 36 | M | 671 Bristol Street | Netagy | hattiebond@netagy.com | Dante | TN | 36 | + | 13 | 32838 | Nanette | Bates | 28 | F | 789 Madison Street | Quility | | Nogal | VA | 28 | + | 18 | 4180 | Dale | Adams | 33 | M | 467 Hutchinson Court | | daleadams@boink.com | Orick | MD | 36 | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+--------------------------------+ + + +### Example 1: Calculate the average, sum and count of a field by group + +The example show calculate the average age, sum age and count of events of all the accounts group by gender. + +PPL query: + + os> source=accounts | eventstats avg(age) as avg_age, sum(age) as sum_age, count() as count by gender; + fetched rows / total rows = 4/4 + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+-----------+-----------+-------+ + | account_number | balance | firstname | lastname | age | gender | address | employer | email | city | state | avg_age | sum_age | count | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+-----------+-----------+-------+ + | 1 | 39225 | Amber | Duke | 32 | M | 880 Holmes Lane | Pyrami | amberduke@pyrami.com | Brogan | IL | 33.666667 | 101 | 3 | + | 6 | 5686 | Hattie | Bond | 36 | M | 671 Bristol Street | Netagy | hattiebond@netagy.com | Dante | TN | 33.666667 | 101 | 3 | + | 13 | 32838 | Nanette | Bates | 28 | F | 789 Madison Street | Quility | | Nogal | VA | 28.000000 | 28 | 1 | + | 18 | 4180 | Dale | Adams | 33 | M | 467 Hutchinson Court | | daleadams@boink.com | Orick | MD | 33.666667 | 101 | 3 | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+-----------+-----------+-------+ + + +### Example 2: Calculate the count by a span + +The example gets the count of age by the interval of 10 years. 
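+
+`span(age, 10)` assigns each document to the 10-year interval its age falls into, and the count is then taken per interval. For the sample accounts this works out as:
+
+    ages 32, 33, 36 -> interval [30, 40) -> count 3
+    age  28         -> interval [20, 30) -> count 1
+
+These counts appear in the `age_span` column of the output below.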
+ +PPL query: + + os> source=accounts | eventstats count(age) by span(age, 10) as age_span + fetched rows / total rows = 4/4 + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+----------+ + | account_number | balance | firstname | lastname | age | gender | address | employer | email | city | state | age_span | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+----------+ + | 1 | 39225 | Amber | Duke | 32 | M | 880 Holmes Lane | Pyrami | amberduke@pyrami.com | Brogan | IL | 3 | + | 6 | 5686 | Hattie | Bond | 36 | M | 671 Bristol Street | Netagy | hattiebond@netagy.com | Dante | TN | 3 | + | 13 | 32838 | Nanette | Bates | 28 | F | 789 Madison Street | Quility | | Nogal | VA | 1 | + | 18 | 4180 | Dale | Adams | 33 | M | 467 Hutchinson Court | | daleadams@boink.com | Orick | MD | 3 | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+----------+ + + +### Example 3: Calculate the count by a gender and span + +The example gets the count of age by the interval of 5 years and group by gender. + +PPL query: + + os> source=accounts | eventstats count() as cnt by span(age, 5) as age_span, gender + fetched rows / total rows = 4/4 + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+-----+ + | account_number | balance | firstname | lastname | age | gender | address | employer | email | city | state | cnt | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+-----+ + | 1 | 39225 | Amber | Duke | 32 | M | 880 Holmes Lane | Pyrami | amberduke@pyrami.com | Brogan | IL | 2 | + | 6 | 5686 | Hattie | Bond | 36 | M | 671 Bristol Street | Netagy | hattiebond@netagy.com | Dante | TN | 1 | + | 13 | 32838 | Nanette | Bates | 28 | F | 789 Madison Street | Quility | | Nogal | VA | 1 | + | 18 | 4180 | Dale | Adams | 33 | M | 467 Hutchinson Court | | daleadams@boink.com | Orick | MD | 2 | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+-----+ + + +### Usage +- `source = table | eventstats avg(a) ` +- `source = table | where a < 50 | eventstats avg(c) ` +- `source = table | eventstats max(c) by b` +- `source = table | eventstats count(c) by b | head 5` +- `source = table | eventstats distinct_count(c)` +- `source = table | eventstats stddev_samp(c)` +- `source = table | eventstats stddev_pop(c)` +- `source = table | eventstats percentile(c, 90)` +- `source = table | eventstats percentile_approx(c, 99)` + +**Aggregations With Span** +- `source = table | eventstats count(a) by span(a, 10) as a_span` +- `source = table | eventstats sum(age) by span(age, 5) as age_span | head 2` +- `source = table | eventstats avg(age) by span(age, 20) as age_span, country | sort - age_span | head 2` + +**Aggregations With TimeWindow Span (tumble windowing function)** + +- `source = table | eventstats sum(productsAmount) by span(transactionDate, 1d) as age_date | sort age_date` +- `source = table | eventstats sum(productsAmount) by span(transactionDate, 1w) as age_date, productId` + +**Aggregations Group by Multiple Levels** + +- `source = table | eventstats avg(age) 
as avg_state_age by country, state | eventstats avg(avg_state_age) as avg_country_age by country`
+- `source = table | eventstats avg(age) as avg_city_age by country, state, city | eval new_avg_city_age = avg_city_age - 1 | eventstats avg(new_avg_city_age) as avg_state_age by country, state | where avg_state_age > 18 | eventstats avg(avg_state_age) as avg_adult_country_age by country`
+
diff --git a/docs/ppl-lang/ppl-fieldsummary-command.md b/docs/ppl-lang/ppl-fieldsummary-command.md
new file mode 100644
index 000000000..468c2046b
--- /dev/null
+++ b/docs/ppl-lang/ppl-fieldsummary-command.md
@@ -0,0 +1,83 @@
+## PPL `fieldsummary` command
+
+**Description**
+Use the `fieldsummary` command to:
+ - Calculate basic statistics for each field (count, distinct count, min, max, avg, stddev, mean)
+ - Determine the data type of each field
+
+**Syntax**
+
+`... | fieldsummary includefields= <field-list> (nulls=true/false)`
+
+* The command accepts any preceding pipe commands before the terminal `fieldsummary` command and takes them into account.
+* `includefields`: the list of columns to be collected with statistics into a unified result set
+* `nulls`: optional; if true, include the null values in the aggregation calculations (null is replaced with zero for numeric values)
+
+### Example 1:
+
+PPL query:
+
+    os> source = t | where status_code != 200 | fieldsummary includefields= status_code nulls=true
+    +---------------+-------+----------------+-----+-----+-------+-------+-------------------+-------+--------+
+    | Fields        | COUNT | COUNT_DISTINCT | MIN | MAX | AVG   | MEAN  | STDDEV            | Nulls | TYPEOF |
+    +---------------+-------+----------------+-----+-----+-------+-------+-------------------+-------+--------+
+    | "status_code" | 2     | 2              | 301 | 403 | 352.0 | 352.0 | 72.12489168102785 | 0     | "int"  |
+    +---------------+-------+----------------+-----+-----+-------+-------+-------------------+-------+--------+
+
+### Example 2:
+
+PPL query:
+
+    os> source = t | fieldsummary includefields= id, status_code, request_path nulls=true
+    +----------------+-------+----------------+--------+-------+-------+-------+--------------------+-------+----------+
+    | Fields         | COUNT | COUNT_DISTINCT | MIN    | MAX   | AVG   | MEAN  | STDDEV             | Nulls | TYPEOF   |
+    +----------------+-------+----------------+--------+-------+-------+-------+--------------------+-------+----------+
+    | "id"           | 6     | 6              | 1      | 6     | 3.5   | 3.5   | 1.8708286933869707 | 0     | "int"    |
+    | "status_code"  | 4     | 3              | 200    | 403   | 184.0 | 184.0 | 161.16699413961905 | 2     | "int"    |
+    | "request_path" | 2     | 2              | /about | /home | 0.0   | 0.0   | 0                  | 2     | "string" |
+    +----------------+-------+----------------+--------+-------+-------+-------+--------------------+-------+----------+
+
+### Additional Info
+The actual query is translated into the following SQL-like statement:
+
+```sql
+    SELECT
+      id AS Field,
+      COUNT(id) AS COUNT,
+      COUNT(DISTINCT id) AS COUNT_DISTINCT,
+      MIN(id) AS MIN,
+      MAX(id) AS MAX,
+      AVG(id) AS AVG,
+      MEAN(id) AS MEAN,
+      STDDEV(id) AS STDDEV,
+      (COUNT(1) - COUNT(id)) AS Nulls,
+      TYPEOF(id) AS TYPEOF
+    FROM
+      t
+    GROUP BY
+      TYPEOF(id), id
+UNION
+    SELECT
+      status_code AS Field,
+      COUNT(status_code) AS COUNT,
+      COUNT(DISTINCT status_code) AS COUNT_DISTINCT,
+      MIN(status_code) AS MIN,
+      MAX(status_code) AS MAX,
+      AVG(status_code) AS AVG,
+      MEAN(status_code) AS MEAN,
+      STDDEV(status_code) AS STDDEV,
+      (COUNT(1) - COUNT(status_code)) AS Nulls,
+      TYPEOF(status_code) AS TYPEOF
+    FROM
+      t
+    GROUP BY
+      TYPEOF(status_code), status_code;
+```
+For each such column (id, status_code) there is a separate statement, and all the fields are presented together in the result using a UNION operator.
+
+
+### Limitation
+ - The `topvalues` option was removed from this command due to the possible performance impact of such a sub-query. As an alternative, one can use the `top` command directly as shown [here](ppl-top-command.md).
+
diff --git a/docs/ppl-lang/ppl-fillnull-command.md b/docs/ppl-lang/ppl-fillnull-command.md
new file mode 100644
index 000000000..00064849c
--- /dev/null
+++ b/docs/ppl-lang/ppl-fillnull-command.md
@@ -0,0 +1,92 @@
+## PPL `fillnull` command
+
+### Description
+Use the ``fillnull`` command to replace null values with a provided value in one or more fields of the search result.
+
+
+### Syntax
+`fillnull [with <null-replacement> in <nullable-field>["," <nullable-field>]] | [using <nullable-field> = <null-replacement>["," <nullable-field> = <null-replacement>]]`
+
+* null-replacement: mandatory. The value used to replace `null`s.
+* nullable-field: mandatory. Field reference. The `null` values in the referenced field will be replaced with the null-replacement value.
+
+
+### Example 1: fillnull one field
+
+The example shows how to fill null values in a single field.
+
+PPL query:
+
+    os> source=logs | fields status_code | eval input=status_code | fillnull value = 0 status_code;
+| input | status_code |
+|-------|-------------|
+| 403   | 403         |
+| 403   | 403         |
+| NULL  | 0           |
+| NULL  | 0           |
+| 200   | 200         |
+| 404   | 404         |
+| 500   | 500         |
+| NULL  | 0           |
+| 500   | 500         |
+| 404   | 404         |
+| 200   | 200         |
+| 500   | 500         |
+| NULL  | 0           |
+| NULL  | 0           |
+| 404   | 404         |
+
+
+### Example 2: fillnull applied to multiple fields
+
+The example shows fillnull applied to multiple fields.
+
+PPL query:
+
+    os> source=logs | fields request_path, timestamp | eval input_request_path=request_path, input_timestamp = timestamp | fillnull value = '???' request_path, timestamp;
+| input_request_path | input_timestamp       | request_path | timestamp              |
+|--------------------|-----------------------|--------------|------------------------|
+| /contact           | NULL                  | /contact     | ???                    |
+| /home              | NULL                  | /home        | ???                    |
+| /about             | 2023-10-01 10:30:00   | /about       | 2023-10-01 10:30:00    |
+| /home              | 2023-10-01 10:15:00   | /home        | 2023-10-01 10:15:00    |
+| NULL               | 2023-10-01 10:20:00   | ???          | 2023-10-01 10:20:00    |
+| NULL               | 2023-10-01 11:05:00   | ???          | 2023-10-01 11:05:00    |
+| /about             | NULL                  | /about       | ???                    |
+| /home              | 2023-10-01 10:00:00   | /home        | 2023-10-01 10:00:00    |
+| /contact           | NULL                  | /contact     | ???                    |
+| NULL               | 2023-10-01 10:05:00   | ???          | 2023-10-01 10:05:00    |
+| NULL               | 2023-10-01 10:50:00   | ???          | 2023-10-01 10:50:00    |
+| /services          | NULL                  | /services    | ???                    |
+| /home              | 2023-10-01 10:45:00   | /home        | 2023-10-01 10:45:00    |
+| /services          | 2023-10-01 11:00:00   | /services    | 2023-10-01 11:00:00    |
+| NULL               | 2023-10-01 10:35:00   | ???          | 2023-10-01 10:35:00    |
+
+### Example 3: fillnull applied to multiple fields with various `null` replacement values
+
+The example shows fillnull with different values used to replace `null`s in each field:
+
+- `/error` in `request_path` field
+- `1970-01-01 00:00:00` in `timestamp` field
+
+PPL query:
+
+    os> source=logs | fields request_path, timestamp | eval input_request_path=request_path, input_timestamp = timestamp | fillnull using request_path = '/error', timestamp='1970-01-01 00:00:00';
+
+
+| input_request_path | input_timestamp       | request_path | timestamp              |
+|--------------------|-----------------------|--------------|------------------------|
+| /contact           | NULL                  | /contact     | 1970-01-01 00:00:00    |
+| /home              | NULL                  | /home        | 1970-01-01 00:00:00    |
+| /about             | 2023-10-01 10:30:00   | /about       | 2023-10-01 10:30:00    |
+| /home              | 2023-10-01 10:15:00   | /home        | 2023-10-01 10:15:00    |
+| NULL               | 2023-10-01 10:20:00   | /error       | 2023-10-01 10:20:00    |
+| NULL               | 2023-10-01 11:05:00   | /error       | 2023-10-01 11:05:00    |
+| /about             | NULL                  | /about       | 1970-01-01 00:00:00    |
+| /home              | 2023-10-01 10:00:00   | /home        | 2023-10-01 10:00:00    |
+| /contact           | NULL                  | /contact     | 1970-01-01 00:00:00    |
+| NULL               | 2023-10-01 10:05:00   | /error       | 2023-10-01 10:05:00    |
+| NULL               | 2023-10-01 10:50:00   | /error       | 2023-10-01 10:50:00    |
+| /services          | NULL                  | /services    | 1970-01-01 00:00:00    |
+| /home              | 2023-10-01 10:45:00   | /home        | 2023-10-01 10:45:00    |
+| /services          | 2023-10-01 11:00:00   | /services    | 2023-10-01 11:00:00    |
+| NULL               | 2023-10-01 10:35:00   | /error       | 2023-10-01 10:35:00    |
\ No newline at end of file
diff --git a/docs/ppl-lang/ppl-flatten-command.md b/docs/ppl-lang/ppl-flatten-command.md
new file mode 100644
index 000000000..4c1ae5d0d
--- /dev/null
+++ b/docs/ppl-lang/ppl-flatten-command.md
@@ -0,0 +1,90 @@
+## PPL `flatten` command
+
+### Description
+Use the `flatten` command to flatten a field of type:
+- `struct`
+- `array<struct>`
+
+
+### Syntax
+`flatten <field>`
+
+* field: the field to be flattened. The field must be of a supported type.
+
+### Test table
+#### Schema
+| col\_name | data\_type                                    |
+|-----------|-----------------------------------------------|
+| \_time    | string                                        |
+| bridges   | array\<struct\<length:bigint, name:string\>\> |
+| city      | string                                        |
+| coor      | struct\<alt:bigint, lat:double, long:double\> |
+| country   | string                                        |
+#### Data
+| \_time              | bridges                                       | city     | coor                    | country        |
+|---------------------|-----------------------------------------------|----------|-------------------------|----------------|
+| 2024-09-13T12:00:00 | [{801, Tower Bridge}, {928, London Bridge}]   | London   | {35, 51.5074, -0.1278}  | England        |
+| 2024-09-13T12:00:00 | [{232, Pont Neuf}, {160, Pont Alexandre III}] | Paris    | {35, 48.8566, 2.3522}   | France         |
+| 2024-09-13T12:00:00 | [{48, Rialto Bridge}, {11, Bridge of Sighs}]  | Venice   | {2, 45.4408, 12.3155}   | Italy          |
+| 2024-09-13T12:00:00 | [{516, Charles Bridge}, {343, Legion Bridge}] | Prague   | {200, 50.0755, 14.4378} | Czech Republic |
+| 2024-09-13T12:00:00 | [{375, Chain Bridge}, {333, Liberty Bridge}]  | Budapest | {96, 47.4979, 19.0402}  | Hungary        |
+| 1990-09-13T12:00:00 | NULL                                          | Warsaw   | NULL                    | Poland         |
+
+
+
+### Example 1: flatten struct
+This example shows how to flatten a struct field.
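+
+For intuition, flattening the `coor` struct is roughly comparable to expanding its fields with a Spark SQL projection like the sketch below (an illustration only, reusing the test table above; the actual plan generated for the PPL command may differ):
+
+```sql
+SELECT _time, bridges, city, country, coor.* FROM `table`
+```
+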
+PPL query: + - `source=table | flatten coor` + +| \_time | bridges | city | country | alt | lat | long | +|---------------------|----------------------------------------------|---------|---------------|-----|--------|--------| +| 2024-09-13T12:00:00 | [{801, Tower Bridge}, {928, London Bridge}] | London | England | 35 | 51.5074| -0.1278| +| 2024-09-13T12:00:00 | [{232, Pont Neuf}, {160, Pont Alexandre III}]| Paris | France | 35 | 48.8566| 2.3522 | +| 2024-09-13T12:00:00 | [{48, Rialto Bridge}, {11, Bridge of Sighs}] | Venice | Italy | 2 | 45.4408| 12.3155| +| 2024-09-13T12:00:00 | [{516, Charles Bridge}, {343, Legion Bridge}]| Prague | Czech Republic| 200 | 50.0755| 14.4378| +| 2024-09-13T12:00:00 | [{375, Chain Bridge}, {333, Liberty Bridge}] | Budapest| Hungary | 96 | 47.4979| 19.0402| +| 1990-09-13T12:00:00 | NULL | Warsaw | Poland | NULL| NULL | NULL | + + + +### Example 2: flatten array + +The example shows how to flatten an array of struct fields. + +PPL query: + - `source=table | flatten bridges` + +| \_time | city | coor | country | length | name | +|---------------------|---------|------------------------|---------------|--------|-------------------| +| 2024-09-13T12:00:00 | London | {35, 51.5074, -0.1278} | England | 801 | Tower Bridge | +| 2024-09-13T12:00:00 | London | {35, 51.5074, -0.1278} | England | 928 | London Bridge | +| 2024-09-13T12:00:00 | Paris | {35, 48.8566, 2.3522} | France | 232 | Pont Neuf | +| 2024-09-13T12:00:00 | Paris | {35, 48.8566, 2.3522} | France | 160 | Pont Alexandre III| +| 2024-09-13T12:00:00 | Venice | {2, 45.4408, 12.3155} | Italy | 48 | Rialto Bridge | +| 2024-09-13T12:00:00 | Venice | {2, 45.4408, 12.3155} | Italy | 11 | Bridge of Sighs | +| 2024-09-13T12:00:00 | Prague | {200, 50.0755, 14.4378}| Czech Republic| 516 | Charles Bridge | +| 2024-09-13T12:00:00 | Prague | {200, 50.0755, 14.4378}| Czech Republic| 343 | Legion Bridge | +| 2024-09-13T12:00:00 | Budapest| {96, 47.4979, 19.0402} | Hungary | 375 | Chain Bridge | +| 2024-09-13T12:00:00 | Budapest| {96, 47.4979, 19.0402} | Hungary | 333 | Liberty Bridge | +| 1990-09-13T12:00:00 | Warsaw | NULL | Poland | NULL | NULL | + + +### Example 3: flatten array and struct +This example shows how to flatten multiple fields. 
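+
+For intuition, flattening the array of structs and then the struct is roughly comparable to the Spark SQL sketch below (an illustration only; `b` is a hypothetical alias over the exploded `bridges` entries, and the actual plan generated for the PPL command may differ):
+
+```sql
+SELECT _time, city, country, b.length, b.name, coor.*
+FROM `table`
+LATERAL VIEW OUTER INLINE(bridges) b AS length, name
+```
+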
+PPL query: + - `source=table | flatten bridges | flatten coor` + +| \_time | city | country | length | name | alt | lat | long | +|---------------------|---------|---------------|--------|-------------------|------|--------|--------| +| 2024-09-13T12:00:00 | London | England | 801 | Tower Bridge | 35 | 51.5074| -0.1278| +| 2024-09-13T12:00:00 | London | England | 928 | London Bridge | 35 | 51.5074| -0.1278| +| 2024-09-13T12:00:00 | Paris | France | 232 | Pont Neuf | 35 | 48.8566| 2.3522 | +| 2024-09-13T12:00:00 | Paris | France | 160 | Pont Alexandre III| 35 | 48.8566| 2.3522 | +| 2024-09-13T12:00:00 | Venice | Italy | 48 | Rialto Bridge | 2 | 45.4408| 12.3155| +| 2024-09-13T12:00:00 | Venice | Italy | 11 | Bridge of Sighs | 2 | 45.4408| 12.3155| +| 2024-09-13T12:00:00 | Prague | Czech Republic| 516 | Charles Bridge | 200 | 50.0755| 14.4378| +| 2024-09-13T12:00:00 | Prague | Czech Republic| 343 | Legion Bridge | 200 | 50.0755| 14.4378| +| 2024-09-13T12:00:00 | Budapest| Hungary | 375 | Chain Bridge | 96 | 47.4979| 19.0402| +| 2024-09-13T12:00:00 | Budapest| Hungary | 333 | Liberty Bridge | 96 | 47.4979| 19.0402| +| 1990-09-13T12:00:00 | Warsaw | Poland | NULL | NULL | NULL | NULL | NULL | \ No newline at end of file diff --git a/docs/ppl-lang/ppl-search-command.md b/docs/ppl-lang/ppl-search-command.md index f81d9d907..bccfd04f0 100644 --- a/docs/ppl-lang/ppl-search-command.md +++ b/docs/ppl-lang/ppl-search-command.md @@ -32,7 +32,7 @@ The example show fetch all the document from accounts index with . PPL query: - os> source=accounts account_number=1 or gender="F"; + os> SEARCH source=accounts account_number=1 or gender="F"; +------------------+-------------+--------------------+-----------+----------+--------+------------+---------+-------+----------------------+------------+ | account_number | firstname | address | balance | gender | city | employer | state | age | email | lastname | |------------------+-------------+--------------------+-----------+----------+--------+------------+---------+-------+----------------------+------------| diff --git a/docs/ppl-lang/ppl-subquery-command.md b/docs/ppl-lang/ppl-subquery-command.md index ac0f98fe8..c4a0c337c 100644 --- a/docs/ppl-lang/ppl-subquery-command.md +++ b/docs/ppl-lang/ppl-subquery-command.md @@ -1,6 +1,6 @@ ## PPL SubQuery Commands: -**Syntax** +### Syntax The subquery command should be implemented using a clean, logical syntax that integrates with existing PPL structure. 
```sql @@ -21,13 +21,15 @@ For additional info See [Issue](https://github.com/opensearch-project/opensearch --- -**InSubquery usage** +### InSubquery usage - `source = outer | where a in [ source = inner | fields b ]` - `source = outer | where (a) in [ source = inner | fields b ]` - `source = outer | where (a,b,c) in [ source = inner | fields d,e,f ]` - `source = outer | where a not in [ source = inner | fields b ]` - `source = outer | where (a) not in [ source = inner | fields b ]` - `source = outer | where (a,b,c) not in [ source = inner | fields d,e,f ]` +- `source = outer a in [ source = inner | fields b ]` (search filtering with subquery) +- `source = outer a not in [ source = inner | fields b ]` (search filtering with subquery) - `source = outer | where a in [ source = inner1 | where b not in [ source = inner2 | fields c ] | fields b ]` (nested) - `source = table1 | inner join left = l right = r on l.a = r.a AND r.a in [ source = inner | fields d ] | fields l.a, r.a, b, c` (as join filter) @@ -111,8 +113,9 @@ source = supplier nation | sort s_name ``` +--- -**ExistsSubquery usage** +### ExistsSubquery usage Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table inner, `e`, `f` are fields of table inner2 @@ -120,6 +123,9 @@ Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table in - `source = outer | where not exists [ source = inner | where a = c ]` - `source = outer | where exists [ source = inner | where a = c and b = d ]` - `source = outer | where not exists [ source = inner | where a = c and b = d ]` +- `source = outer exists [ source = inner | where a = c ]` (search filtering with subquery) +- `source = outer not exists [ source = inner | where a = c ]` (search filtering with subquery) +- `source = table as t1 exists [ source = table as t2 | where t1.a = t2.a ]` (table alias is useful in exists subquery) - `source = outer | where exists [ source = inner1 | where a = c and exists [ source = inner2 | where c = e ] ]` (nested) - `source = outer | where exists [ source = inner1 | where a = c | where exists [ source = inner2 | where c = e ] ]` (nested) - `source = outer | where exists [ source = inner | where c > 10 ]` (uncorrelated exists) @@ -163,8 +169,9 @@ source = orders | sort o_orderpriority | fields o_orderpriority, order_count ``` +--- -**ScalarSubquery usage** +### ScalarSubquery usage Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table inner, `e`, `f` are fields of table nested @@ -172,8 +179,11 @@ Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table in - `source = outer | eval m = [ source = inner | stats max(c) ] | fields m, a` - `source = outer | eval m = [ source = inner | stats max(c) ] + b | fields m, a` -**Uncorrelated scalar subquery in Select and Where** -- `source = outer | where a > [ source = inner | stats min(c) ] | eval m = [ source = inner | stats max(c) ] | fields m, a` +**Uncorrelated scalar subquery in Where** +- `source = outer | where a > [ source = inner | stats min(c) ] | fields a` + +**Uncorrelated scalar subquery in Search filter** +- `source = outer a > [ source = inner | stats min(c) ] | fields a` **Correlated scalar subquery in Select** - `source = outer | eval m = [ source = inner | where outer.b = inner.d | stats max(c) ] | fields m, a` @@ -185,6 +195,10 @@ Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table in - `source = outer | where a = [ source = inner | where b = d | stats max(c) ]` - `source = outer | where [ source 
= inner | where outer.b = inner.d OR inner.d = 1 | stats count() ] > 0 | fields a` +**Correlated scalar subquery in Search filter** +- `source = outer a = [ source = inner | where b = d | stats max(c) ]` +- `source = outer [ source = inner | where outer.b = inner.d OR inner.d = 1 | stats count() ] > 0 | fields a` + **Nested scalar subquery** - `source = outer | where a = [ source = inner | stats max(c) | sort c ] OR b = [ source = inner | where c = 1 | stats min(d) | sort d ]` - `source = outer | where a = [ source = inner | where c = [ source = nested | stats max(e) by f | sort f ] | stats max(d) by c | sort c | head 1 ]` @@ -240,27 +254,77 @@ source = spark_catalog.default.outer source = spark_catalog.default.inner | where c = 1 | stats min(d) | sort d ] ``` +--- + +### (Relation) Subquery +`InSubquery`, `ExistsSubquery` and `ScalarSubquery` are all subquery expressions. But `RelationSubquery` is not a subquery expression, it is a subquery plan which is common used in Join or From clause. + +- `source = table1 | join left = l right = r [ source = table2 | where d > 10 | head 5 ]` (subquery in join right side) +- `source = [ source = table1 | join left = l right = r [ source = table2 | where d > 10 | head 5 ] | stats count(a) by b ] as outer | head 1` -### **Additional Context** +**_SQL Migration examples with Subquery PPL:_** + +tpch q13 +```sql +select + c_count, + count(*) as custdist +from + ( + select + c_custkey, + count(o_orderkey) as c_count + from + customer left outer join orders on + c_custkey = o_custkey + and o_comment not like '%special%requests%' + group by + c_custkey + ) as c_orders +group by + c_count +order by + custdist desc, + c_count desc +``` +Rewritten by PPL (Relation) Subquery: +```sql +SEARCH source = [ + SEARCH source = customer + | LEFT OUTER JOIN left = c right = o ON c_custkey = o_custkey + [ + SEARCH source = orders + | WHERE not like(o_comment, '%special%requests%') + ] + | STATS COUNT(o_orderkey) AS c_count BY c_custkey +] AS c_orders +| STATS COUNT(o_orderkey) AS c_count BY c_custkey +| STATS COUNT(1) AS custdist BY c_count +| SORT - custdist, - c_count +``` +--- -`InSubquery`, `ExistsSubquery` and `ScalarSubquery` are all subquery expression. The common usage of subquery expression is in `where` clause: +### Additional Context -The `where` command syntax is: +`InSubquery`, `ExistsSubquery` and `ScalarSubquery` as subquery expressions, their common usage is in `where` clause and `search filter`. +Where command: +``` +| where | ... ``` -| where +Search filter: ``` -So the subquery is part of boolean expression, such as +search source=* | ... +``` +A subquery expression could be used in boolean expression, for example ```sql -| where orders.order_id in (subquery source=returns | where return_reason="damaged" | return order_id) +| where orders.order_id in [ source=returns | where return_reason="damaged" | field order_id ] ``` -The `orders.order_id in (subquery source=...)` is a ``. - -In general, we name this kind of subquery clause the `InSubquery` expression, it is a ``, one kind of `subquery expressions`. +The `orders.order_id in [ source=... ]` is a ``. -PS: there are many kinds of `subquery expressions`, another commonly used one is `ScalarSubquery` expression: +In general, we name this kind of subquery clause the `InSubquery` expression, it is a ``. **Subquery with Different Join Types** @@ -326,4 +390,18 @@ source = outer | eval l = "nonEmpty" | fields l ``` -This query just print "nonEmpty" if the inner table is not empty. 
\ No newline at end of file
+This query just prints "nonEmpty" if the inner table is not empty.
+
+**Table alias in subquery**
+
+Table aliases are useful in queries that contain a subquery, for example
+
+```sql
+select a, (
+        select sum(b)
+        from catalog.schema.table1 as t1
+        where t1.a = t2.a
+   ) sum_b
+ from catalog.schema.table2 as t2
+```
+`t1` and `t2` are table aliases used in the correlated subquery, and `sum_b` is the subquery alias.
diff --git a/docs/ppl-lang/ppl-trendline-command.md b/docs/ppl-lang/ppl-trendline-command.md
new file mode 100644
index 000000000..393a9dd59
--- /dev/null
+++ b/docs/ppl-lang/ppl-trendline-command.md
@@ -0,0 +1,60 @@
+## PPL trendline Command
+
+**Description**
+Use the ``trendline`` command to calculate moving averages of fields.
+
+
+### Syntax
+`TRENDLINE [sort <[+|-] sort-field>] SMA(number-of-datapoints, field) [AS alias] [SMA(number-of-datapoints, field) [AS alias]]...`
+
+* [+|-]: optional. The plus [+] stands for ascending order with NULL/MISSING first, and the minus [-] stands for descending order with NULL/MISSING last. **Default:** ascending order and NULL/MISSING first.
+* sort-field: mandatory when sorting is used. The field used to sort.
+* number-of-datapoints: mandatory. The number of datapoints used to calculate the moving average (must be greater than zero).
+* field: mandatory. The name of the field the moving average should be calculated for.
+* alias: optional. The name of the resulting column containing the moving average.
+
+At the moment, only the Simple Moving Average (SMA) type is supported.
+
+It is calculated as follows:
+
+    f[i]: The value of field 'f' in the i-th data-point
+    n: The number of data-points in the moving window (period)
+    t: The current time index
+
+    SMA(t) = (1/n) * Σ(f[i]), where i = t-n+1 to t
+
+### Example 1: Calculate simple moving average for a timeseries of temperatures
+
+The example calculates the simple moving average over temperatures using two datapoints.
+
+PPL query:
+
+    os> source=t | trendline sma(2, temperature) as temp_trend;
+    fetched rows / total rows = 5/5
+    +-----------+---------+--------------------+----------+
+    |temperature|device-id|           timestamp|temp_trend|
+    +-----------+---------+--------------------+----------+
+    |         12|     1492|2023-04-06 17:07:...|      NULL|
+    |         12|     1492|2023-04-06 17:07:...|      12.0|
+    |         13|      256|2023-04-06 17:07:...|      12.5|
+    |         14|      257|2023-04-06 17:07:...|      13.5|
+    |         15|      258|2023-04-06 17:07:...|      14.5|
+    +-----------+---------+--------------------+----------+
+
+### Example 2: Calculate simple moving averages for a timeseries of temperatures with sorting
+
+The example calculates two simple moving averages over temperatures, using two and three datapoints, sorted descending by device-id.
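+
+As a quick check of the SMA formula above: with the rows sorted descending by device-id, the three-datapoint window ending at the third row covers the temperatures 15, 14 and 13, so
+
+    temp_trend_3 = (15 + 14 + 13) / 3 = 14.0
+
+which matches the third row of the output below; the leading rows stay NULL until a full window of datapoints is available.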
+ +PPL query: + + os> source=t | trendline sort - device-id sma(2, temperature) as temp_trend_2 sma(3, temperature) as temp_trend_3; + fetched rows / total rows = 5/5 + +-----------+---------+--------------------+------------+------------------+ + |temperature|device-id| timestamp|temp_trend_2| temp_trend_3| + +-----------+---------+--------------------+------------+------------------+ + | 15| 258|2023-04-06 17:07:...| NULL| NULL| + | 14| 257|2023-04-06 17:07:...| 14.5| NULL| + | 13| 256|2023-04-06 17:07:...| 13.5| 14.0| + | 12| 1492|2023-04-06 17:07:...| 12.5| 13.0| + | 12| 1492|2023-04-06 17:07:...| 12.0|12.333333333333334| + +-----------+---------+--------------------+------------+------------------+ diff --git a/docs/ppl-lang/ppl-where-command.md b/docs/ppl-lang/ppl-where-command.md index 94ddc1f5c..c954623c3 100644 --- a/docs/ppl-lang/ppl-where-command.md +++ b/docs/ppl-lang/ppl-where-command.md @@ -41,6 +41,10 @@ PPL query: - `source = table | where isempty(a)` - `source = table | where isblank(a)` - `source = table | where case(length(a) > 6, 'True' else 'False') = 'True'` +- `source = table | where a between 1 and 4` - Note: This returns a >= 1 and a <= 4, i.e. [1, 4] +- `source = table | where b not between '2024-09-10' and '2025-09-10'` - Note: This returns b >= '2024-09-10' and b <= '2025-09-10' +- `source = table | where cidrmatch(ip, '192.169.1.0/24')` +- `source = table | where cidrmatch(ipv6, '2003:db8::/32')` - `source = table | eval status_category = case(a >= 200 AND a < 300, 'Success', diff --git a/flint-commons/src/main/scala/org/opensearch/flint/common/metadata/log/FlintMetadataLogEntry.scala b/flint-commons/src/main/scala/org/opensearch/flint/common/metadata/log/FlintMetadataLogEntry.scala index 982b7df23..1cae64b83 100644 --- a/flint-commons/src/main/scala/org/opensearch/flint/common/metadata/log/FlintMetadataLogEntry.scala +++ b/flint-commons/src/main/scala/org/opensearch/flint/common/metadata/log/FlintMetadataLogEntry.scala @@ -18,6 +18,10 @@ import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry.IndexState * log entry id * @param state * Flint index state + * @param lastRefreshStartTime + * timestamp when last refresh started for manual or external scheduler refresh + * @param lastRefreshCompleteTime + * timestamp when last refresh completed for manual or external scheduler refresh * @param entryVersion * entry version fields for consistency control * @param error @@ -28,10 +32,12 @@ import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry.IndexState case class FlintMetadataLogEntry( id: String, /** - * This is currently used as streaming job start time. In future, this should represent the - * create timestamp of the log entry + * This is currently used as streaming job start time for internal scheduler. 
In future, this + * should represent the create timestamp of the log entry */ createTime: Long, + lastRefreshStartTime: Long, + lastRefreshCompleteTime: Long, state: IndexState, entryVersion: Map[String, Any], error: String, @@ -40,26 +46,48 @@ case class FlintMetadataLogEntry( def this( id: String, createTime: Long, + lastRefreshStartTime: Long, + lastRefreshCompleteTime: Long, state: IndexState, entryVersion: JMap[String, Any], error: String, properties: JMap[String, Any]) = { - this(id, createTime, state, entryVersion.asScala.toMap, error, properties.asScala.toMap) + this( + id, + createTime, + lastRefreshStartTime, + lastRefreshCompleteTime, + state, + entryVersion.asScala.toMap, + error, + properties.asScala.toMap) } def this( id: String, createTime: Long, + lastRefreshStartTime: Long, + lastRefreshCompleteTime: Long, state: IndexState, entryVersion: JMap[String, Any], error: String, properties: Map[String, Any]) = { - this(id, createTime, state, entryVersion.asScala.toMap, error, properties) + this( + id, + createTime, + lastRefreshStartTime, + lastRefreshCompleteTime, + state, + entryVersion.asScala.toMap, + error, + properties) } } object FlintMetadataLogEntry { + val EMPTY_TIMESTAMP = 0L + /** * Flint index state enum. */ diff --git a/flint-core/src/main/java/org/opensearch/flint/core/IRestHighLevelClient.java b/flint-core/src/main/java/org/opensearch/flint/core/IRestHighLevelClient.java index 04ef216c4..9facd89ef 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/IRestHighLevelClient.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/IRestHighLevelClient.java @@ -87,6 +87,10 @@ static void recordOperationSuccess(String metricNamePrefix) { MetricsUtil.incrementCounter(successMetricName); } + static void recordLatency(String metricNamePrefix, long latencyMilliseconds) { + MetricsUtil.addHistoricGauge(metricNamePrefix + ".processingTime", latencyMilliseconds); + } + /** * Records the failure of an OpenSearch operation by incrementing the corresponding metric counter. * If the exception is an OpenSearchException with a specific status code (e.g., 403), @@ -107,6 +111,8 @@ static void recordOperationFailure(String metricNamePrefix, Exception e) { if (statusCode == 403) { String forbiddenErrorMetricName = metricNamePrefix + ".403.count"; MetricsUtil.incrementCounter(forbiddenErrorMetricName); + } else if (statusCode == 429) { + MetricsUtil.incrementCounter(metricNamePrefix + ".429.count"); } String failureMetricName = metricNamePrefix + "." 
+ (statusCode / 100) + "xx.count"; diff --git a/flint-core/src/main/java/org/opensearch/flint/core/RestHighLevelClientWrapper.java b/flint-core/src/main/java/org/opensearch/flint/core/RestHighLevelClientWrapper.java index 1b83f032a..31f012256 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/RestHighLevelClientWrapper.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/RestHighLevelClientWrapper.java @@ -5,6 +5,8 @@ package org.opensearch.flint.core; +import java.util.Arrays; +import java.util.function.Consumer; import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.bulk.BulkResponse; @@ -40,7 +42,10 @@ import org.opensearch.flint.core.storage.BulkRequestRateLimiter; import org.opensearch.flint.core.storage.OpenSearchBulkRetryWrapper; +import static org.opensearch.flint.core.metrics.MetricConstants.OS_BULK_OP_METRIC_PREFIX; +import static org.opensearch.flint.core.metrics.MetricConstants.OS_CREATE_OP_METRIC_PREFIX; import static org.opensearch.flint.core.metrics.MetricConstants.OS_READ_OP_METRIC_PREFIX; +import static org.opensearch.flint.core.metrics.MetricConstants.OS_SEARCH_OP_METRIC_PREFIX; import static org.opensearch.flint.core.metrics.MetricConstants.OS_WRITE_OP_METRIC_PREFIX; /** @@ -67,112 +72,126 @@ public RestHighLevelClientWrapper(RestHighLevelClient client, BulkRequestRateLim @Override public BulkResponse bulk(BulkRequest bulkRequest, RequestOptions options) throws IOException { - return execute(OS_WRITE_OP_METRIC_PREFIX, () -> { + return execute(() -> { try { rateLimiter.acquirePermit(); return bulkRetryWrapper.bulkWithPartialRetry(client, bulkRequest, options); } catch (InterruptedException e) { throw new RuntimeException("rateLimiter.acquirePermit was interrupted.", e); } - }); + }, OS_WRITE_OP_METRIC_PREFIX, OS_BULK_OP_METRIC_PREFIX); } @Override public ClearScrollResponse clearScroll(ClearScrollRequest clearScrollRequest, RequestOptions options) throws IOException { - return execute(OS_READ_OP_METRIC_PREFIX, () -> client.clearScroll(clearScrollRequest, options)); + return execute(() -> client.clearScroll(clearScrollRequest, options), + OS_READ_OP_METRIC_PREFIX); } @Override public CreateIndexResponse createIndex(CreateIndexRequest createIndexRequest, RequestOptions options) throws IOException { - return execute(OS_WRITE_OP_METRIC_PREFIX, () -> client.indices().create(createIndexRequest, options)); + return execute(() -> client.indices().create(createIndexRequest, options), + OS_WRITE_OP_METRIC_PREFIX, OS_CREATE_OP_METRIC_PREFIX); } @Override public void updateIndexMapping(PutMappingRequest putMappingRequest, RequestOptions options) throws IOException { - execute(OS_WRITE_OP_METRIC_PREFIX, () -> client.indices().putMapping(putMappingRequest, options)); + execute(() -> client.indices().putMapping(putMappingRequest, options), + OS_WRITE_OP_METRIC_PREFIX); } @Override public void deleteIndex(DeleteIndexRequest deleteIndexRequest, RequestOptions options) throws IOException { - execute(OS_WRITE_OP_METRIC_PREFIX, () -> client.indices().delete(deleteIndexRequest, options)); + execute(() -> client.indices().delete(deleteIndexRequest, options), + OS_WRITE_OP_METRIC_PREFIX); } @Override public DeleteResponse delete(DeleteRequest deleteRequest, RequestOptions options) throws IOException { - return execute(OS_WRITE_OP_METRIC_PREFIX, () -> client.delete(deleteRequest, options)); + return execute(() -> client.delete(deleteRequest, options), OS_WRITE_OP_METRIC_PREFIX); } 
@Override public GetResponse get(GetRequest getRequest, RequestOptions options) throws IOException { - return execute(OS_READ_OP_METRIC_PREFIX, () -> client.get(getRequest, options)); + return execute(() -> client.get(getRequest, options), OS_READ_OP_METRIC_PREFIX); } @Override public GetIndexResponse getIndex(GetIndexRequest getIndexRequest, RequestOptions options) throws IOException { - return execute(OS_READ_OP_METRIC_PREFIX, () -> client.indices().get(getIndexRequest, options)); + return execute(() -> client.indices().get(getIndexRequest, options), + OS_READ_OP_METRIC_PREFIX); } @Override public IndexResponse index(IndexRequest indexRequest, RequestOptions options) throws IOException { - return execute(OS_WRITE_OP_METRIC_PREFIX, () -> client.index(indexRequest, options)); + return execute(() -> client.index(indexRequest, options), OS_WRITE_OP_METRIC_PREFIX); } @Override public Boolean doesIndexExist(GetIndexRequest getIndexRequest, RequestOptions options) throws IOException { - return execute(OS_READ_OP_METRIC_PREFIX, () -> client.indices().exists(getIndexRequest, options)); + return execute(() -> client.indices().exists(getIndexRequest, options), + OS_READ_OP_METRIC_PREFIX); } @Override public SearchResponse search(SearchRequest searchRequest, RequestOptions options) throws IOException { - return execute(OS_READ_OP_METRIC_PREFIX, () -> client.search(searchRequest, options)); + return execute(() -> client.search(searchRequest, options), OS_READ_OP_METRIC_PREFIX, OS_SEARCH_OP_METRIC_PREFIX); } @Override public SearchResponse scroll(SearchScrollRequest searchScrollRequest, RequestOptions options) throws IOException { - return execute(OS_READ_OP_METRIC_PREFIX, () -> client.scroll(searchScrollRequest, options)); + return execute(() -> client.scroll(searchScrollRequest, options), OS_READ_OP_METRIC_PREFIX); } @Override public UpdateResponse update(UpdateRequest updateRequest, RequestOptions options) throws IOException { - return execute(OS_WRITE_OP_METRIC_PREFIX, () -> client.update(updateRequest, options)); + return execute(() -> client.update(updateRequest, options), OS_WRITE_OP_METRIC_PREFIX); } - @Override - public IndicesStatsResponse stats(IndicesStatsRequest request) throws IOException { - return execute(OS_READ_OP_METRIC_PREFIX, - () -> { - OpenSearchClient openSearchClient = - new OpenSearchClient(new RestClientTransport(client.getLowLevelClient(), - new JacksonJsonpMapper())); - return openSearchClient.indices().stats(request); - }); - } + @Override + public IndicesStatsResponse stats(IndicesStatsRequest request) throws IOException { + return execute(() -> { + OpenSearchClient openSearchClient = + new OpenSearchClient(new RestClientTransport(client.getLowLevelClient(), + new JacksonJsonpMapper())); + return openSearchClient.indices().stats(request); + }, OS_READ_OP_METRIC_PREFIX + ); + } @Override public CreatePitResponse createPit(CreatePitRequest request) throws IOException { - return execute(OS_WRITE_OP_METRIC_PREFIX, () -> openSearchClient().createPit(request)); + return execute(() -> openSearchClient().createPit(request), OS_WRITE_OP_METRIC_PREFIX); } /** * Executes a given operation, tracks metrics, and handles exceptions. 
* - * @param metricNamePrefix the prefix for the metric name - * @param operation the operation to execute * @param the return type of the operation + * @param operation the operation to execute + * @param metricNamePrefixes array of prefixes for the metric name * @return the result of the operation * @throws IOException if an I/O exception occurs */ - private T execute(String metricNamePrefix, IOCallable operation) throws IOException { + private T execute(IOCallable operation, String... metricNamePrefixes) throws IOException { + long startTime = System.currentTimeMillis(); try { T result = operation.call(); - IRestHighLevelClient.recordOperationSuccess(metricNamePrefix); + eachPrefix(IRestHighLevelClient::recordOperationSuccess, metricNamePrefixes); return result; } catch (Exception e) { - IRestHighLevelClient.recordOperationFailure(metricNamePrefix, e); + eachPrefix(prefix -> IRestHighLevelClient.recordOperationFailure(prefix, e), metricNamePrefixes); throw e; + } finally { + long latency = System.currentTimeMillis() - startTime; + eachPrefix(prefix -> IRestHighLevelClient.recordLatency(prefix, latency), metricNamePrefixes); } } + private static void eachPrefix(Consumer fn, String... prefixes) { + Arrays.stream(prefixes).forEach(fn); + } + /** * Functional interface for operations that can throw IOException. * diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/HistoricGauge.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/HistoricGauge.java new file mode 100644 index 000000000..8e5288110 --- /dev/null +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/HistoricGauge.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.metrics; + +import com.codahale.metrics.Gauge; +import com.google.common.annotations.VisibleForTesting; +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import lombok.AllArgsConstructor; +import lombok.Value; + +/** + * Gauge which stores historic data points with timestamps. + * This is used for emitting separate data points per request, instead of single aggregated metrics. + */ +public class HistoricGauge implements Gauge { + @AllArgsConstructor + @Value + public static class DataPoint { + Long value; + long timestamp; + } + + private final List dataPoints = Collections.synchronizedList(new LinkedList<>()); + + /** + * This method will just return first value. + * @return first value + */ + @Override + public Long getValue() { + if (!dataPoints.isEmpty()) { + return dataPoints.get(0).value; + } else { + return null; + } + } + + /** + * Add new data point. Current time stamp will be attached to the data point. 
+ * @param value metric value + */ + public void addDataPoint(Long value) { + dataPoints.add(new DataPoint(value, System.currentTimeMillis())); + } + + /** + * Return copy of dataPoints and remove them from internal list + * @return copy of the data points + */ + public List pollDataPoints() { + int size = dataPoints.size(); + List result = new ArrayList<>(dataPoints.subList(0, size)); + if (size > 0) { + dataPoints.subList(0, size).clear(); + } + return result; + } + + @VisibleForTesting + public List getDataPoints() { + return dataPoints; + } +} diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java index 4cdfcee01..ef4d01652 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java @@ -22,6 +22,24 @@ public final class MetricConstants { */ public static final String OS_WRITE_OP_METRIC_PREFIX = "opensearch.write"; + /** + * Prefixes for OpenSearch API specific metrics + */ + public static final String OS_CREATE_OP_METRIC_PREFIX = "opensearch.create"; + public static final String OS_SEARCH_OP_METRIC_PREFIX = "opensearch.search"; + public static final String OS_BULK_OP_METRIC_PREFIX = "opensearch.bulk"; + + /** + * Metric name for request size of opensearch bulk request + */ + public static final String OPENSEARCH_BULK_SIZE_METRIC = "opensearch.bulk.size.count"; + + /** + * Metric name for opensearch bulk request retry count + */ + public static final String OPENSEARCH_BULK_RETRY_COUNT_METRIC = "opensearch.bulk.retry.count"; + public static final String OPENSEARCH_BULK_ALL_RETRY_FAILED_COUNT_METRIC = "opensearch.bulk.allRetryFailed.count"; + /** * Metric name for counting the errors encountered with Amazon S3 operations. */ @@ -112,7 +130,52 @@ public final class MetricConstants { */ public static final String STREAMING_HEARTBEAT_FAILED_METRIC = "streaming.heartbeat.failed.count"; + /** + * Metric for tracking the latency of query execution (start to complete query execution) excluding result write. 
+ */ + public static final String QUERY_EXECUTION_TIME_METRIC = "query.execution.processingTime"; + + /** + * Metric for tracking the total bytes read from input + */ + public static final String INPUT_TOTAL_BYTES_READ = "input.totalBytesRead.count"; + + /** + * Metric for tracking the total records read from input + */ + public static final String INPUT_TOTAL_RECORDS_READ = "input.totalRecordsRead.count"; + + /** + * Metric for tracking the total bytes written to output + */ + public static final String OUTPUT_TOTAL_BYTES_WRITTEN = "output.totalBytesWritten.count"; + + /** + * Metric for tracking the total records written to output + */ + public static final String OUTPUT_TOTAL_RECORDS_WRITTEN = "output.totalRecordsWritten.count"; + + /** + * Metric for tracking the latency of checkpoint deletion + */ + public static final String CHECKPOINT_DELETE_TIME_METRIC = "checkpoint.delete.processingTime"; + + /** + * Prefix for externalScheduler metrics + */ + public static final String EXTERNAL_SCHEDULER_METRIC_PREFIX = "externalScheduler."; + + /** + * Metric prefix for tracking the index state transitions + */ + public static final String INDEX_STATE_UPDATED_TO_PREFIX = "indexState.updatedTo."; + + /** + * Metric for tracking the index state transitions + */ + public static final String INITIAL_CONDITION_CHECK_FAILED_PREFIX = "initialConditionCheck.failed."; + private MetricConstants() { // Private constructor to prevent instantiation } -} \ No newline at end of file +} diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricsUtil.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricsUtil.java index 81a482d5e..5a0f0f5ad 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricsUtil.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricsUtil.java @@ -9,6 +9,7 @@ import com.codahale.metrics.Gauge; import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.Timer; +import java.util.function.Supplier; import org.apache.spark.SparkEnv; import org.apache.spark.metrics.source.FlintMetricSource; import org.apache.spark.metrics.source.FlintIndexMetricSource; @@ -75,6 +76,15 @@ public static void decrementCounter(String metricName, boolean isIndexMetric) { } } + public static void setCounter(String metricName, boolean isIndexMetric, long n) { + Counter counter = getOrCreateCounter(metricName, isIndexMetric); + if (counter != null) { + counter.dec(counter.getCount()); + counter.inc(n); + LOG.info("counter: " + counter.getCount()); + } + } + /** * Retrieves a {@link Timer.Context} for the specified metric name, creating a new timer if one does not already exist. * @@ -107,6 +117,43 @@ public static Long stopTimer(Timer.Context context) { return context != null ? context.stop() : null; } + public static Timer getTimer(String metricName, boolean isIndexMetric) { + return getOrCreateTimer(metricName, isIndexMetric); + } + + /** + * Registers a HistoricGauge metric with the provided name and value. + * + * @param metricName The name of the HistoricGauge metric to register. 
+ * @param value The value to be stored + */ + public static void addHistoricGauge(String metricName, final long value) { + HistoricGauge historicGauge = getOrCreateHistoricGauge(metricName); + if (historicGauge != null) { + historicGauge.addDataPoint(value); + } + } + + /** + * Automatically emit latency metric as Historic Gauge for the execution of supplier + * @param supplier the lambda to be metered + * @param metricName name of the metric + * @return value returned by supplier + */ + public static T withLatencyAsHistoricGauge(Supplier supplier, String metricName) { + long startTime = System.currentTimeMillis(); + try { + return supplier.get(); + } finally { + addHistoricGauge(metricName, System.currentTimeMillis() - startTime); + } + } + + private static HistoricGauge getOrCreateHistoricGauge(String metricName) { + MetricRegistry metricRegistry = getMetricRegistry(false); + return metricRegistry != null ? metricRegistry.gauge(metricName, HistoricGauge::new) : null; + } + /** * Registers a gauge metric with the provided name and value. * diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java index a5ea190c5..9104e1b34 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java @@ -47,6 +47,7 @@ import java.util.stream.LongStream; import java.util.stream.Stream; import org.apache.spark.metrics.sink.CloudWatchSink.DimensionNameGroups; +import org.opensearch.flint.core.metrics.HistoricGauge; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -145,7 +146,11 @@ public void report(final SortedMap gauges, gauges.size() + counters.size() + 10 * histograms.size() + 10 * timers.size()); for (final Map.Entry gaugeEntry : gauges.entrySet()) { - processGauge(gaugeEntry.getKey(), gaugeEntry.getValue(), metricData); + if (gaugeEntry.getValue() instanceof HistoricGauge) { + processHistoricGauge(gaugeEntry.getKey(), (HistoricGauge) gaugeEntry.getValue(), metricData); + } else { + processGauge(gaugeEntry.getKey(), gaugeEntry.getValue(), metricData); + } } for (final Map.Entry counterEntry : counters.entrySet()) { @@ -227,6 +232,13 @@ private void processGauge(final String metricName, final Gauge gauge, final List } } + private void processHistoricGauge(final String metricName, final HistoricGauge gauge, final List metricData) { + for (HistoricGauge.DataPoint dataPoint: gauge.pollDataPoints()) { + stageMetricDatum(true, metricName, dataPoint.getValue().doubleValue(), StandardUnit.None, DIMENSION_GAUGE, metricData, + dataPoint.getTimestamp()); + } + } + private void processCounter(final String metricName, final Counting counter, final List metricData) { long currentCount = counter.getCount(); Long lastCount = lastPolledCounts.get(counter); @@ -333,12 +345,25 @@ private void processHistogram(final String metricName, final Histogram histogram *

* If {@link Builder#withZeroValuesSubmission()} is {@code true}, then all values will be submitted */ + private void stageMetricDatum(final boolean metricConfigured, + final String metricName, + final double metricValue, + final StandardUnit standardUnit, + final String dimensionValue, + final List metricData + ) { + stageMetricDatum(metricConfigured, metricName, metricValue, standardUnit, + dimensionValue, metricData, builder.clock.getTime()); + } + private void stageMetricDatum(final boolean metricConfigured, final String metricName, final double metricValue, final StandardUnit standardUnit, final String dimensionValue, - final List metricData) { + final List metricData, + final Long timestamp + ) { // Only submit metrics that show some data, so let's save some money if (metricConfigured && (builder.withZeroValuesSubmission || metricValue > 0)) { final DimensionedName dimensionedName = DimensionedName.decode(metricName); @@ -351,7 +376,7 @@ private void stageMetricDatum(final boolean metricConfigured, MetricInfo metricInfo = getMetricInfo(dimensionedName, dimensions); for (Set dimensionSet : metricInfo.getDimensionSets()) { MetricDatum datum = new MetricDatum() - .withTimestamp(new Date(builder.clock.getTime())) + .withTimestamp(new Date(timestamp)) .withValue(cleanMetricValue(metricValue)) .withMetricName(metricInfo.getMetricName()) .withDimensions(dimensionSet) diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedName.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedName.java index da3a446d4..e3029d61c 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedName.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedName.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.flint.core.metrics.reporter; import com.amazonaws.services.cloudwatch.model.Dimension; diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedNameBuilder.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedNameBuilder.java index 603c58f95..dc4e6bc4d 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedNameBuilder.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedNameBuilder.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.flint.core.metrics.reporter; import com.amazonaws.services.cloudwatch.model.Dimension; diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/metadata/log/DefaultOptimisticTransaction.java b/flint-core/src/main/scala/org/opensearch/flint/core/metadata/log/DefaultOptimisticTransaction.java index e6fed4126..466787c81 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/metadata/log/DefaultOptimisticTransaction.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/metadata/log/DefaultOptimisticTransaction.java @@ -16,6 +16,8 @@ import org.opensearch.flint.common.metadata.log.FlintMetadataLog; import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry; import org.opensearch.flint.common.metadata.log.OptimisticTransaction; +import org.opensearch.flint.core.metrics.MetricConstants; +import org.opensearch.flint.core.metrics.MetricsUtil; /** * Default optimistic transaction implementation that captures the basic 
workflow for @@ -73,6 +75,7 @@ public T commit(Function operation) { // Perform initial log check if (!initialCondition.test(latest)) { LOG.warning("Initial log entry doesn't satisfy precondition " + latest); + emitConditionCheckFailedMetric(latest); throw new IllegalStateException( String.format("Index state [%s] doesn't satisfy precondition", latest.state())); } @@ -86,6 +89,8 @@ public T commit(Function operation) { initialLog = initialLog.copy( initialLog.id(), initialLog.createTime(), + initialLog.lastRefreshStartTime(), + initialLog.lastRefreshCompleteTime(), initialLog.state(), latest.entryVersion(), initialLog.error(), @@ -102,6 +107,7 @@ public T commit(Function operation) { metadataLog.purge(); } else { metadataLog.add(finalLog); + emitFinalLogStateMetric(finalLog); } return result; } catch (Exception e) { @@ -117,4 +123,12 @@ public T commit(Function operation) { throw new IllegalStateException("Failed to commit transaction operation", e); } } + + private void emitConditionCheckFailedMetric(FlintMetadataLogEntry latest) { + MetricsUtil.addHistoricGauge(MetricConstants.INITIAL_CONDITION_CHECK_FAILED_PREFIX + latest.state() + ".count", 1); + } + + private void emitFinalLogStateMetric(FlintMetadataLogEntry finalLog) { + MetricsUtil.addHistoricGauge(MetricConstants.INDEX_STATE_UPDATED_TO_PREFIX + finalLog.state() + ".count", 1); + } } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/metrics/ReadWriteBytesSparkListener.scala b/flint-core/src/main/scala/org/opensearch/flint/core/metrics/ReadWriteBytesSparkListener.scala new file mode 100644 index 000000000..bfafd3eb3 --- /dev/null +++ b/flint-core/src/main/scala/org/opensearch/flint/core/metrics/ReadWriteBytesSparkListener.scala @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.metrics + +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} +import org.apache.spark.sql.SparkSession + +/** + * Collect and emit bytesRead/Written and recordsRead/Written metrics + */ +class ReadWriteBytesSparkListener extends SparkListener with Logging { + var bytesRead: Long = 0 + var recordsRead: Long = 0 + var bytesWritten: Long = 0 + var recordsWritten: Long = 0 + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + val inputMetrics = taskEnd.taskMetrics.inputMetrics + val outputMetrics = taskEnd.taskMetrics.outputMetrics + val ids = s"(${taskEnd.taskInfo.taskId}, ${taskEnd.taskInfo.partitionId})" + logInfo( + s"${ids} Input: bytesRead=${inputMetrics.bytesRead}, recordsRead=${inputMetrics.recordsRead}") + logInfo( + s"${ids} Output: bytesWritten=${outputMetrics.bytesWritten}, recordsWritten=${outputMetrics.recordsWritten}") + + bytesRead += inputMetrics.bytesRead + recordsRead += inputMetrics.recordsRead + bytesWritten += outputMetrics.bytesWritten + recordsWritten += outputMetrics.recordsWritten + } + + def emitMetrics(): Unit = { + logInfo(s"Input: totalBytesRead=${bytesRead}, totalRecordsRead=${recordsRead}") + logInfo(s"Output: totalBytesWritten=${bytesWritten}, totalRecordsWritten=${recordsWritten}") + MetricsUtil.addHistoricGauge(MetricConstants.INPUT_TOTAL_BYTES_READ, bytesRead) + MetricsUtil.addHistoricGauge(MetricConstants.INPUT_TOTAL_RECORDS_READ, recordsRead) + MetricsUtil.addHistoricGauge(MetricConstants.OUTPUT_TOTAL_BYTES_WRITTEN, bytesWritten) + MetricsUtil.addHistoricGauge(MetricConstants.OUTPUT_TOTAL_RECORDS_WRITTEN, recordsWritten) + } +} + +object 
ReadWriteBytesSparkListener { + def withMetrics[T](spark: SparkSession, lambda: () => T): T = { + val listener = new ReadWriteBytesSparkListener() + spark.sparkContext.addSparkListener(listener) + + val result = lambda() + + spark.sparkContext.removeSparkListener(listener) + listener.emitMetrics() + + result + } +} diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/BulkRequestRateLimiter.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/BulkRequestRateLimiter.java index af298cc8f..797dc2d02 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/BulkRequestRateLimiter.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/BulkRequestRateLimiter.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.flint.core.storage; import dev.failsafe.RateLimiter; diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/BulkRequestRateLimiterHolder.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/BulkRequestRateLimiterHolder.java index 0453c70c8..73fdb8843 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/BulkRequestRateLimiterHolder.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/BulkRequestRateLimiterHolder.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.flint.core.storage; import org.opensearch.flint.core.FlintOptions; diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintMetadataLogEntryOpenSearchConverter.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintMetadataLogEntryOpenSearchConverter.java index 0b78304d2..f90dda9a0 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintMetadataLogEntryOpenSearchConverter.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintMetadataLogEntryOpenSearchConverter.java @@ -101,7 +101,7 @@ public static String toJson(FlintMetadataLogEntry logEntry) throws JsonProcessin ObjectMapper mapper = new ObjectMapper(); ObjectNode json = mapper.createObjectNode(); - json.put("version", "1.0"); + json.put("version", "1.1"); json.put("latestId", logEntry.id()); json.put("type", "flintindexstate"); json.put("state", logEntry.state().toString()); @@ -109,6 +109,8 @@ public static String toJson(FlintMetadataLogEntry logEntry) throws JsonProcessin json.put("jobId", jobId); json.put("dataSourceName", logEntry.properties().get("dataSourceName").get().toString()); json.put("jobStartTime", logEntry.createTime()); + json.put("lastRefreshStartTime", logEntry.lastRefreshStartTime()); + json.put("lastRefreshCompleteTime", logEntry.lastRefreshCompleteTime()); json.put("lastUpdateTime", lastUpdateTime); json.put("error", logEntry.error()); @@ -138,6 +140,8 @@ public static FlintMetadataLogEntry constructLogEntry( id, /* sourceMap may use Integer or Long even though it's always long in index mapping */ ((Number) sourceMap.get("jobStartTime")).longValue(), + ((Number) sourceMap.get("lastRefreshStartTime")).longValue(), + ((Number) sourceMap.get("lastRefreshCompleteTime")).longValue(), FlintMetadataLogEntry.IndexState$.MODULE$.from((String) sourceMap.get("state")), Map.of("seqNo", seqNo, "primaryTerm", primaryTerm), (String) sourceMap.get("error"), diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java 
b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java index 8c327b664..24c9df492 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java @@ -132,7 +132,9 @@ public void purge() { public FlintMetadataLogEntry emptyLogEntry() { return new FlintMetadataLogEntry( "", - 0L, + FlintMetadataLogEntry.EMPTY_TIMESTAMP(), + FlintMetadataLogEntry.EMPTY_TIMESTAMP(), + FlintMetadataLogEntry.EMPTY_TIMESTAMP(), FlintMetadataLogEntry.IndexState$.MODULE$.EMPTY(), Map.of("seqNo", UNASSIGNED_SEQ_NO, "primaryTerm", UNASSIGNED_PRIMARY_TERM), "", @@ -146,6 +148,8 @@ private FlintMetadataLogEntry createLogEntry(FlintMetadataLogEntry logEntry) { logEntry.copy( latestId, logEntry.createTime(), + logEntry.lastRefreshStartTime(), + logEntry.lastRefreshCompleteTime(), logEntry.state(), logEntry.entryVersion(), logEntry.error(), @@ -184,6 +188,8 @@ private FlintMetadataLogEntry writeLogEntry( logEntry = new FlintMetadataLogEntry( logEntry.id(), logEntry.createTime(), + logEntry.lastRefreshStartTime(), + logEntry.lastRefreshCompleteTime(), logEntry.state(), Map.of("seqNo", response.getSeqNo(), "primaryTerm", response.getPrimaryTerm()), logEntry.error(), diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchBulkRetryWrapper.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchBulkRetryWrapper.java index 279c9b642..14e3b7099 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchBulkRetryWrapper.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchBulkRetryWrapper.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.flint.core.storage; import dev.failsafe.Failsafe; @@ -6,6 +11,7 @@ import dev.failsafe.function.CheckedPredicate; import java.util.Arrays; import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Logger; import org.opensearch.action.DocWriteRequest; @@ -15,6 +21,8 @@ import org.opensearch.client.RequestOptions; import org.opensearch.client.RestHighLevelClient; import org.opensearch.flint.core.http.FlintRetryOptions; +import org.opensearch.flint.core.metrics.MetricConstants; +import org.opensearch.flint.core.metrics.MetricsUtil; import org.opensearch.rest.RestStatus; public class OpenSearchBulkRetryWrapper { @@ -37,11 +45,19 @@ public OpenSearchBulkRetryWrapper(FlintRetryOptions retryOptions) { */ public BulkResponse bulkWithPartialRetry(RestHighLevelClient client, BulkRequest bulkRequest, RequestOptions options) { + final AtomicInteger requestCount = new AtomicInteger(0); try { final AtomicReference nextRequest = new AtomicReference<>(bulkRequest); - return Failsafe + BulkResponse res = Failsafe .with(retryPolicy) + .onFailure((event) -> { + if (event.isRetry()) { + MetricsUtil.addHistoricGauge( + MetricConstants.OPENSEARCH_BULK_ALL_RETRY_FAILED_COUNT_METRIC, 1); + } + }) .get(() -> { + requestCount.incrementAndGet(); BulkResponse response = client.bulk(nextRequest.get(), options); if (retryPolicy.getConfig().allowsRetries() && bulkItemRetryableResultPredicate.test( response)) { @@ -49,11 +65,15 @@ public BulkResponse bulkWithPartialRetry(RestHighLevelClient client, BulkRequest } return response; }); + return res; } catch (FailsafeException ex) { 
LOG.severe("Request failed permanently. Re-throwing original exception."); // unwrap original exception and throw throw new RuntimeException(ex.getCause()); + } finally { + MetricsUtil.addHistoricGauge(MetricConstants.OPENSEARCH_BULK_SIZE_METRIC, bulkRequest.estimatedSizeInBytes()); + MetricsUtil.addHistoricGauge(MetricConstants.OPENSEARCH_BULK_RETRY_COUNT_METRIC, requestCount.get() - 1); } } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchUpdater.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchUpdater.java index 0d84b4956..eb7264a84 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchUpdater.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchUpdater.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.flint.core.storage; import org.opensearch.action.support.WriteRequest; diff --git a/flint-core/src/test/java/org/opensearch/flint/core/metadata/log/DefaultOptimisticTransactionMetricTest.java b/flint-core/src/test/java/org/opensearch/flint/core/metadata/log/DefaultOptimisticTransactionMetricTest.java new file mode 100644 index 000000000..838e9978c --- /dev/null +++ b/flint-core/src/test/java/org/opensearch/flint/core/metadata/log/DefaultOptimisticTransactionMetricTest.java @@ -0,0 +1,68 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.metadata.log; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.flint.common.metadata.log.FlintMetadataLog; +import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry; + +import java.util.Optional; +import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry.IndexState$; +import org.opensearch.flint.core.metrics.MetricsTestUtil; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +@ExtendWith(MockitoExtension.class) +class DefaultOptimisticTransactionMetricTest { + + @Mock + private FlintMetadataLog metadataLog; + + @Mock + private FlintMetadataLogEntry logEntry; + + @InjectMocks + private DefaultOptimisticTransaction transaction; + + @Test + void testCommitWithValidInitialCondition() throws Exception { + MetricsTestUtil.withMetricEnv(verifier -> { + when(metadataLog.getLatest()).thenReturn(Optional.of(logEntry)); + when(metadataLog.add(any(FlintMetadataLogEntry.class))).thenReturn(logEntry); + when(logEntry.state()).thenReturn(IndexState$.MODULE$.ACTIVE()); + + transaction.initialLog(entry -> true) + .transientLog(entry -> logEntry) + .finalLog(entry -> logEntry) + .commit(entry -> "Success"); + + verify(metadataLog, times(2)).add(logEntry); + verifier.assertHistoricGauge("indexState.updatedTo.active.count", 1); + }); + } + + @Test + void testConditionCheckFailed() throws Exception { + MetricsTestUtil.withMetricEnv(verifier -> { + when(metadataLog.getLatest()).thenReturn(Optional.of(logEntry)); + when(logEntry.state()).thenReturn(IndexState$.MODULE$.DELETED()); + + transaction.initialLog(entry -> false) + .finalLog(entry -> logEntry); + + assertThrows(IllegalStateException.class, () -> { + transaction.commit(entry -> "Should Fail"); + }); + 
verifier.assertHistoricGauge("initialConditionCheck.failed.deleted.count", 1); + }); + } +} diff --git a/flint-core/src/test/java/org/opensearch/flint/core/metrics/HistoricGaugeTest.java b/flint-core/src/test/java/org/opensearch/flint/core/metrics/HistoricGaugeTest.java new file mode 100644 index 000000000..f3d842af2 --- /dev/null +++ b/flint-core/src/test/java/org/opensearch/flint/core/metrics/HistoricGaugeTest.java @@ -0,0 +1,79 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.metrics; + +import org.junit.Test; +import static org.junit.Assert.*; +import org.opensearch.flint.core.metrics.HistoricGauge.DataPoint; + +import java.util.List; + +public class HistoricGaugeTest { + + @Test + public void testGetValue_EmptyGauge_ShouldReturnNull() { + HistoricGauge gauge= new HistoricGauge(); + assertNull(gauge.getValue()); + } + + @Test + public void testGetValue_WithSingleDataPoint_ShouldReturnFirstValue() { + HistoricGauge gauge= new HistoricGauge(); + Long value = 100L; + gauge.addDataPoint(value); + + assertEquals(value, gauge.getValue()); + } + + @Test + public void testGetValue_WithMultipleDataPoints_ShouldReturnFirstValue() { + HistoricGauge gauge= new HistoricGauge(); + Long firstValue = 100L; + Long secondValue = 200L; + gauge.addDataPoint(firstValue); + gauge.addDataPoint(secondValue); + + assertEquals(firstValue, gauge.getValue()); + } + + @Test + public void testPollDataPoints_WithMultipleDataPoints_ShouldReturnAndClearDataPoints() { + HistoricGauge gauge= new HistoricGauge(); + gauge.addDataPoint(100L); + gauge.addDataPoint(200L); + gauge.addDataPoint(300L); + + List dataPoints = gauge.pollDataPoints(); + + assertEquals(3, dataPoints.size()); + assertEquals(Long.valueOf(100L), dataPoints.get(0).getValue()); + assertEquals(Long.valueOf(200L), dataPoints.get(1).getValue()); + assertEquals(Long.valueOf(300L), dataPoints.get(2).getValue()); + + assertTrue(gauge.pollDataPoints().isEmpty()); + } + + @Test + public void testAddDataPoint_ShouldAddDataPointWithCorrectValueAndTimestamp() { + HistoricGauge gauge= new HistoricGauge(); + Long value = 100L; + gauge.addDataPoint(value); + + List dataPoints = gauge.pollDataPoints(); + + assertEquals(1, dataPoints.size()); + assertEquals(value, dataPoints.get(0).getValue()); + assertTrue(dataPoints.get(0).getTimestamp() > 0); + } + + @Test + public void testPollDataPoints_EmptyGauge_ShouldReturnEmptyList() { + HistoricGauge gauge= new HistoricGauge(); + List dataPoints = gauge.pollDataPoints(); + + assertTrue(dataPoints.isEmpty()); + } +} diff --git a/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsTestUtil.java b/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsTestUtil.java new file mode 100644 index 000000000..05febb92b --- /dev/null +++ b/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsTestUtil.java @@ -0,0 +1,79 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.metrics; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.RETURNS_DEEP_STUBS; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.when; + +import com.codahale.metrics.MetricRegistry; +import java.util.List; +import lombok.AllArgsConstructor; 
+import org.apache.spark.SparkEnv; +import org.apache.spark.metrics.source.Source; +import org.mockito.MockedStatic; +import org.opensearch.flint.core.metrics.HistoricGauge.DataPoint; + +/** + * Utility class for verifying metrics + */ +public class MetricsTestUtil { + @AllArgsConstructor + public static class MetricsVerifier { + + final MetricRegistry metricRegistry; + + public void assertMetricExist(String metricName) { + assertNotNull(metricRegistry.getMetrics().get(metricName)); + } + + public void assertMetricClass(String metricName, Class clazz) { + assertMetricExist(metricName); + assertEquals(clazz, metricRegistry.getMetrics().get(metricName).getClass()); + } + + public void assertHistoricGauge(String metricName, long... values) { + HistoricGauge historicGauge = getHistoricGauge(metricName); + List dataPoints = historicGauge.getDataPoints(); + for (int i = 0; i < values.length; i++) { + assertEquals(values[i], dataPoints.get(i).getValue().longValue()); + } + } + + private HistoricGauge getHistoricGauge(String metricName) { + assertMetricClass(metricName, HistoricGauge.class); + return (HistoricGauge) metricRegistry.getMetrics().get(metricName); + } + + public void assertMetricNotExist(String metricName) { + assertNull(metricRegistry.getMetrics().get(metricName)); + } + } + + @FunctionalInterface + public interface ThrowableConsumer { + void accept(T t) throws Exception; + } + + public static void withMetricEnv(ThrowableConsumer test) throws Exception { + try (MockedStatic sparkEnvMock = mockStatic(SparkEnv.class)) { + SparkEnv sparkEnv = mock(SparkEnv.class, RETURNS_DEEP_STUBS); + sparkEnvMock.when(SparkEnv::get).thenReturn(sparkEnv); + + Source metricSource = mock(Source.class); + MetricRegistry metricRegistry = new MetricRegistry(); + when(metricSource.metricRegistry()).thenReturn(metricRegistry); + when(sparkEnv.metricsSystem().getSourcesByName(any()).head()).thenReturn(metricSource); + + test.accept(new MetricsVerifier(metricRegistry)); + } + } +} diff --git a/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsUtilTest.java b/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsUtilTest.java index b54269ce0..c586c729a 100644 --- a/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsUtilTest.java +++ b/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsUtilTest.java @@ -1,8 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.flint.core.metrics; import com.codahale.metrics.Counter; import com.codahale.metrics.Gauge; import com.codahale.metrics.Timer; +import java.time.Duration; +import java.util.List; import org.apache.spark.SparkEnv; import org.apache.spark.metrics.source.FlintMetricSource; import org.apache.spark.metrics.source.FlintIndexMetricSource; @@ -14,6 +21,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import org.opensearch.flint.core.metrics.HistoricGauge.DataPoint; import static org.junit.Assert.assertEquals; import static org.mockito.ArgumentMatchers.any; @@ -101,6 +109,34 @@ private void testStartStopTimerHelper(boolean isIndexMetric) { } } + @Test + public void testGetTimer() { + try (MockedStatic sparkEnvMock = mockStatic(SparkEnv.class)) { + // Mock SparkEnv + SparkEnv sparkEnv = mock(SparkEnv.class, RETURNS_DEEP_STUBS); + sparkEnvMock.when(SparkEnv::get).thenReturn(sparkEnv); + + // Mock appropriate MetricSource + String sourceName = FlintMetricSource.FLINT_INDEX_METRIC_SOURCE_NAME(); + 
Source metricSource = Mockito.spy(new FlintIndexMetricSource()); + when(sparkEnv.metricsSystem().getSourcesByName(sourceName).head()).thenReturn( + metricSource); + + // Test the methods + String testMetric = "testPrefix.processingTime"; + long duration = 500; + MetricsUtil.getTimer(testMetric, true).update(duration, TimeUnit.MILLISECONDS); + + // Verify interactions + verify(sparkEnv.metricsSystem(), times(0)).registerSource(any()); + verify(metricSource, times(1)).metricRegistry(); + Timer timer = metricSource.metricRegistry().getTimers().get(testMetric); + Assertions.assertNotNull(timer); + Assertions.assertEquals(1L, timer.getCount()); + assertEquals(Duration.ofMillis(duration).getNano(), timer.getSnapshot().getMean(), 0.1); + } + } + @Test public void testRegisterGauge() { testRegisterGaugeHelper(false); @@ -169,4 +205,31 @@ public void testDefaultBehavior() { Assertions.assertNotNull(flintMetricSource.metricRegistry().getGauges().get(testGaugeMetric)); } } -} \ No newline at end of file + + @Test + public void testAddHistoricGauge() { + try (MockedStatic sparkEnvMock = mockStatic(SparkEnv.class)) { + SparkEnv sparkEnv = mock(SparkEnv.class, RETURNS_DEEP_STUBS); + sparkEnvMock.when(SparkEnv::get).thenReturn(sparkEnv); + + String sourceName = FlintMetricSource.FLINT_METRIC_SOURCE_NAME(); + Source metricSource = Mockito.spy(new FlintMetricSource()); + when(sparkEnv.metricsSystem().getSourcesByName(sourceName).head()).thenReturn(metricSource); + + long value1 = 100L; + long value2 = 200L; + String gaugeName = "test.gauge"; + MetricsUtil.addHistoricGauge(gaugeName, value1); + MetricsUtil.addHistoricGauge(gaugeName, value2); + + verify(sparkEnv.metricsSystem(), times(0)).registerSource(any()); + verify(metricSource, times(2)).metricRegistry(); + + HistoricGauge gauge = (HistoricGauge)metricSource.metricRegistry().getGauges().get(gaugeName); + Assertions.assertNotNull(gauge); + List dataPoints = gauge.pollDataPoints(); + Assertions.assertEquals(value1, dataPoints.get(0).getValue()); + Assertions.assertEquals(value2, dataPoints.get(1).getValue()); + } + } +} diff --git a/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionedNameTest.java b/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionedNameTest.java index 6bc6a9c2d..c9f6d62f5 100644 --- a/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionedNameTest.java +++ b/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionedNameTest.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.flint.core.metrics.reporter; import static org.hamcrest.CoreMatchers.hasItems; diff --git a/flint-core/src/test/scala/org/opensearch/flint/core/auth/AWSRequestSigningApacheInterceptorTest.java b/flint-core/src/test/scala/org/opensearch/flint/core/auth/AWSRequestSigningApacheInterceptorTest.java index ae8fdfa9a..599ed107d 100644 --- a/flint-core/src/test/scala/org/opensearch/flint/core/auth/AWSRequestSigningApacheInterceptorTest.java +++ b/flint-core/src/test/scala/org/opensearch/flint/core/auth/AWSRequestSigningApacheInterceptorTest.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.flint.core.auth; import static org.junit.jupiter.api.Assertions.assertEquals; diff --git a/flint-core/src/test/scala/org/opensearch/flint/core/storage/BulkRequestRateLimiterHolderTest.java 
b/flint-core/src/test/scala/org/opensearch/flint/core/storage/BulkRequestRateLimiterHolderTest.java index f2f160973..b12c2c522 100644 --- a/flint-core/src/test/scala/org/opensearch/flint/core/storage/BulkRequestRateLimiterHolderTest.java +++ b/flint-core/src/test/scala/org/opensearch/flint/core/storage/BulkRequestRateLimiterHolderTest.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.flint.core.storage; import static org.junit.jupiter.api.Assertions.assertEquals; diff --git a/flint-core/src/test/scala/org/opensearch/flint/core/storage/BulkRequestRateLimiterTest.java b/flint-core/src/test/scala/org/opensearch/flint/core/storage/BulkRequestRateLimiterTest.java index d86f06d24..b87c9f797 100644 --- a/flint-core/src/test/scala/org/opensearch/flint/core/storage/BulkRequestRateLimiterTest.java +++ b/flint-core/src/test/scala/org/opensearch/flint/core/storage/BulkRequestRateLimiterTest.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.flint.core.storage; diff --git a/flint-core/src/test/scala/org/opensearch/flint/core/storage/FlintMetadataLogEntryOpenSearchConverterSuite.scala b/flint-core/src/test/scala/org/opensearch/flint/core/storage/FlintMetadataLogEntryOpenSearchConverterSuite.scala index 577dfc5fc..2708d48e8 100644 --- a/flint-core/src/test/scala/org/opensearch/flint/core/storage/FlintMetadataLogEntryOpenSearchConverterSuite.scala +++ b/flint-core/src/test/scala/org/opensearch/flint/core/storage/FlintMetadataLogEntryOpenSearchConverterSuite.scala @@ -25,6 +25,10 @@ class FlintMetadataLogEntryOpenSearchConverterTest val sourceMap = JMap.of( "jobStartTime", 1234567890123L.asInstanceOf[Object], + "lastRefreshStartTime", + 1234567890123L.asInstanceOf[Object], + "lastRefreshCompleteTime", + 1234567890123L.asInstanceOf[Object], "state", "active".asInstanceOf[Object], "dataSourceName", @@ -36,6 +40,8 @@ class FlintMetadataLogEntryOpenSearchConverterTest when(mockLogEntry.id).thenReturn("id") when(mockLogEntry.state).thenReturn(FlintMetadataLogEntry.IndexState.ACTIVE) when(mockLogEntry.createTime).thenReturn(1234567890123L) + when(mockLogEntry.lastRefreshStartTime).thenReturn(1234567890123L) + when(mockLogEntry.lastRefreshCompleteTime).thenReturn(1234567890123L) when(mockLogEntry.error).thenReturn("") when(mockLogEntry.properties).thenReturn(Map("dataSourceName" -> "testDataSource")) } @@ -45,7 +51,7 @@ class FlintMetadataLogEntryOpenSearchConverterTest val expectedJsonWithoutLastUpdateTime = s""" |{ - | "version": "1.0", + | "version": "1.1", | "latestId": "id", | "type": "flintindexstate", | "state": "active", @@ -53,6 +59,8 @@ class FlintMetadataLogEntryOpenSearchConverterTest | "jobId": "unknown", | "dataSourceName": "testDataSource", | "jobStartTime": 1234567890123, + | "lastRefreshStartTime": 1234567890123, + | "lastRefreshCompleteTime": 1234567890123, | "error": "" |} |""".stripMargin @@ -67,15 +75,22 @@ class FlintMetadataLogEntryOpenSearchConverterTest logEntry shouldBe a[FlintMetadataLogEntry] logEntry.id shouldBe "id" logEntry.createTime shouldBe 1234567890123L + logEntry.lastRefreshStartTime shouldBe 1234567890123L + logEntry.lastRefreshCompleteTime shouldBe 1234567890123L logEntry.state shouldBe FlintMetadataLogEntry.IndexState.ACTIVE logEntry.error shouldBe "" logEntry.properties.get("dataSourceName").get shouldBe "testDataSource" } - it should "construct log entry with integer jobStartTime value" in { + it should 
"construct log entry with integer timestamp value" in { + // Use Integer instead of Long for timestamps val testSourceMap = JMap.of( "jobStartTime", - 1234567890.asInstanceOf[Object], // Integer instead of Long + 1234567890.asInstanceOf[Object], + "lastRefreshStartTime", + 1234567890.asInstanceOf[Object], + "lastRefreshCompleteTime", + 1234567890.asInstanceOf[Object], "state", "active".asInstanceOf[Object], "dataSourceName", @@ -87,6 +102,8 @@ class FlintMetadataLogEntryOpenSearchConverterTest logEntry shouldBe a[FlintMetadataLogEntry] logEntry.id shouldBe "id" logEntry.createTime shouldBe 1234567890 + logEntry.lastRefreshStartTime shouldBe 1234567890 + logEntry.lastRefreshCompleteTime shouldBe 1234567890 logEntry.state shouldBe FlintMetadataLogEntry.IndexState.ACTIVE logEntry.error shouldBe "" logEntry.properties.get("dataSourceName").get shouldBe "testDataSource" diff --git a/flint-core/src/test/scala/org/opensearch/flint/core/storage/OpenSearchBulkRetryWrapperTest.java b/flint-core/src/test/scala/org/opensearch/flint/core/storage/OpenSearchBulkRetryWrapperTest.java index fa57da842..43bd8d2b2 100644 --- a/flint-core/src/test/scala/org/opensearch/flint/core/storage/OpenSearchBulkRetryWrapperTest.java +++ b/flint-core/src/test/scala/org/opensearch/flint/core/storage/OpenSearchBulkRetryWrapperTest.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.flint.core.storage; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -24,11 +29,14 @@ import org.opensearch.client.RequestOptions; import org.opensearch.client.RestHighLevelClient; import org.opensearch.flint.core.http.FlintRetryOptions; +import org.opensearch.flint.core.metrics.MetricConstants; +import org.opensearch.flint.core.metrics.MetricsTestUtil; import org.opensearch.rest.RestStatus; @ExtendWith(MockitoExtension.class) class OpenSearchBulkRetryWrapperTest { + private static final long ESTIMATED_SIZE_IN_BYTES = 1000L; @Mock BulkRequest bulkRequest; @Mock @@ -45,12 +53,7 @@ class OpenSearchBulkRetryWrapperTest { DocWriteResponse docWriteResponse; @Mock IndexRequest indexRequest0, indexRequest1; - @Mock IndexRequest docWriteRequest2; -// BulkItemRequest[] bulkItemRequests = new BulkItemRequest[] { -// new BulkItemRequest(0, docWriteRequest0), -// new BulkItemRequest(1, docWriteRequest1), -// new BulkItemRequest(2, docWriteRequest2), -// }; + BulkItemResponse successItem = new BulkItemResponse(0, OpType.CREATE, docWriteResponse); BulkItemResponse failureItem = new BulkItemResponse(0, OpType.CREATE, new Failure("index", "id", null, @@ -65,87 +68,125 @@ class OpenSearchBulkRetryWrapperTest { @Test public void withRetryWhenCallSucceed() throws Exception { - OpenSearchBulkRetryWrapper bulkRetryWrapper = new OpenSearchBulkRetryWrapper( - retryOptionsWithRetry); - when(client.bulk(bulkRequest, options)).thenReturn(successResponse); - when(successResponse.hasFailures()).thenReturn(false); - - BulkResponse response = bulkRetryWrapper.bulkWithPartialRetry(client, bulkRequest, options); - - assertEquals(response, successResponse); - verify(client).bulk(bulkRequest, options); + MetricsTestUtil.withMetricEnv(verifier -> { + OpenSearchBulkRetryWrapper bulkRetryWrapper = new OpenSearchBulkRetryWrapper( + retryOptionsWithRetry); + when(client.bulk(bulkRequest, options)).thenReturn(successResponse); + when(successResponse.hasFailures()).thenReturn(false); + when(bulkRequest.estimatedSizeInBytes()).thenReturn(ESTIMATED_SIZE_IN_BYTES); + + BulkResponse response = 
bulkRetryWrapper.bulkWithPartialRetry(client, bulkRequest, options); + + assertEquals(response, successResponse); + verify(client).bulk(bulkRequest, options); + + verifier.assertHistoricGauge(MetricConstants.OPENSEARCH_BULK_SIZE_METRIC, ESTIMATED_SIZE_IN_BYTES); + verifier.assertHistoricGauge(MetricConstants.OPENSEARCH_BULK_RETRY_COUNT_METRIC, 0); + verifier.assertMetricNotExist(MetricConstants.OPENSEARCH_BULK_ALL_RETRY_FAILED_COUNT_METRIC); + }); } @Test public void withRetryWhenCallConflict() throws Exception { - OpenSearchBulkRetryWrapper bulkRetryWrapper = new OpenSearchBulkRetryWrapper( - retryOptionsWithRetry); - when(client.bulk(any(), eq(options))) - .thenReturn(conflictResponse); - mockConflictResponse(); - when(conflictResponse.hasFailures()).thenReturn(true); - - BulkResponse response = bulkRetryWrapper.bulkWithPartialRetry(client, bulkRequest, options); - - assertEquals(response, conflictResponse); - verify(client).bulk(bulkRequest, options); + MetricsTestUtil.withMetricEnv(verifier -> { + OpenSearchBulkRetryWrapper bulkRetryWrapper = new OpenSearchBulkRetryWrapper( + retryOptionsWithRetry); + when(client.bulk(any(), eq(options))) + .thenReturn(conflictResponse); + mockConflictResponse(); + when(conflictResponse.hasFailures()).thenReturn(true); + when(bulkRequest.estimatedSizeInBytes()).thenReturn(ESTIMATED_SIZE_IN_BYTES); + + BulkResponse response = bulkRetryWrapper.bulkWithPartialRetry(client, bulkRequest, options); + + assertEquals(response, conflictResponse); + verify(client).bulk(bulkRequest, options); + + verifier.assertHistoricGauge(MetricConstants.OPENSEARCH_BULK_SIZE_METRIC, ESTIMATED_SIZE_IN_BYTES); + verifier.assertHistoricGauge(MetricConstants.OPENSEARCH_BULK_RETRY_COUNT_METRIC, 0); + verifier.assertMetricNotExist(MetricConstants.OPENSEARCH_BULK_ALL_RETRY_FAILED_COUNT_METRIC); + }); } @Test public void withRetryWhenCallFailOnce() throws Exception { - OpenSearchBulkRetryWrapper bulkRetryWrapper = new OpenSearchBulkRetryWrapper( - retryOptionsWithRetry); - when(client.bulk(any(), eq(options))) - .thenReturn(failureResponse) - .thenReturn(successResponse); - mockFailureResponse(); - when(successResponse.hasFailures()).thenReturn(false); - when(bulkRequest.requests()).thenReturn(ImmutableList.of(indexRequest0, indexRequest1)); - - BulkResponse response = bulkRetryWrapper.bulkWithPartialRetry(client, bulkRequest, options); - - assertEquals(response, successResponse); - verify(client, times(2)).bulk(any(), eq(options)); + MetricsTestUtil.withMetricEnv(verifier -> { + OpenSearchBulkRetryWrapper bulkRetryWrapper = new OpenSearchBulkRetryWrapper( + retryOptionsWithRetry); + when(client.bulk(any(), eq(options))) + .thenReturn(failureResponse) + .thenReturn(successResponse); + mockFailureResponse(); + when(successResponse.hasFailures()).thenReturn(false); + when(bulkRequest.requests()).thenReturn(ImmutableList.of(indexRequest0, indexRequest1)); + when(bulkRequest.estimatedSizeInBytes()).thenReturn(ESTIMATED_SIZE_IN_BYTES); + + BulkResponse response = bulkRetryWrapper.bulkWithPartialRetry(client, bulkRequest, options); + + assertEquals(response, successResponse); + verify(client, times(2)).bulk(any(), eq(options)); + verifier.assertHistoricGauge(MetricConstants.OPENSEARCH_BULK_SIZE_METRIC, ESTIMATED_SIZE_IN_BYTES); + verifier.assertHistoricGauge(MetricConstants.OPENSEARCH_BULK_RETRY_COUNT_METRIC, 1); + verifier.assertMetricNotExist(MetricConstants.OPENSEARCH_BULK_ALL_RETRY_FAILED_COUNT_METRIC); + }); } @Test public void withRetryWhenAllCallFail() throws Exception { - 
OpenSearchBulkRetryWrapper bulkRetryWrapper = new OpenSearchBulkRetryWrapper( - retryOptionsWithRetry); - when(client.bulk(any(), eq(options))) - .thenReturn(failureResponse); - mockFailureResponse(); - - BulkResponse response = bulkRetryWrapper.bulkWithPartialRetry(client, bulkRequest, options); - - assertEquals(response, failureResponse); - verify(client, times(3)).bulk(any(), eq(options)); + MetricsTestUtil.withMetricEnv(verifier -> { + OpenSearchBulkRetryWrapper bulkRetryWrapper = new OpenSearchBulkRetryWrapper( + retryOptionsWithRetry); + when(client.bulk(any(), eq(options))) + .thenReturn(failureResponse); + when(bulkRequest.estimatedSizeInBytes()).thenReturn(ESTIMATED_SIZE_IN_BYTES); + mockFailureResponse(); + + BulkResponse response = bulkRetryWrapper.bulkWithPartialRetry(client, bulkRequest, options); + + assertEquals(response, failureResponse); + verify(client, times(3)).bulk(any(), eq(options)); + verifier.assertHistoricGauge(MetricConstants.OPENSEARCH_BULK_SIZE_METRIC, ESTIMATED_SIZE_IN_BYTES); + verifier.assertHistoricGauge(MetricConstants.OPENSEARCH_BULK_RETRY_COUNT_METRIC, 2); + verifier.assertHistoricGauge(MetricConstants.OPENSEARCH_BULK_ALL_RETRY_FAILED_COUNT_METRIC, 1); + }); } @Test public void withRetryWhenCallThrowsShouldNotRetry() throws Exception { - OpenSearchBulkRetryWrapper bulkRetryWrapper = new OpenSearchBulkRetryWrapper( - retryOptionsWithRetry); - when(client.bulk(bulkRequest, options)).thenThrow(new RuntimeException("test")); - - assertThrows(RuntimeException.class, - () -> bulkRetryWrapper.bulkWithPartialRetry(client, bulkRequest, options)); - - verify(client).bulk(bulkRequest, options); + MetricsTestUtil.withMetricEnv(verifier -> { + OpenSearchBulkRetryWrapper bulkRetryWrapper = new OpenSearchBulkRetryWrapper( + retryOptionsWithRetry); + when(client.bulk(bulkRequest, options)).thenThrow(new RuntimeException("test")); + when(bulkRequest.estimatedSizeInBytes()).thenReturn(ESTIMATED_SIZE_IN_BYTES); + + assertThrows(RuntimeException.class, + () -> bulkRetryWrapper.bulkWithPartialRetry(client, bulkRequest, options)); + + verify(client).bulk(bulkRequest, options); + verifier.assertHistoricGauge(MetricConstants.OPENSEARCH_BULK_SIZE_METRIC, ESTIMATED_SIZE_IN_BYTES); + verifier.assertHistoricGauge(MetricConstants.OPENSEARCH_BULK_RETRY_COUNT_METRIC, 0); + verifier.assertMetricNotExist(MetricConstants.OPENSEARCH_BULK_ALL_RETRY_FAILED_COUNT_METRIC); + }); } @Test public void withoutRetryWhenCallFail() throws Exception { - OpenSearchBulkRetryWrapper bulkRetryWrapper = new OpenSearchBulkRetryWrapper( - retryOptionsWithoutRetry); - when(client.bulk(bulkRequest, options)) - .thenReturn(failureResponse); - mockFailureResponse(); - - BulkResponse response = bulkRetryWrapper.bulkWithPartialRetry(client, bulkRequest, options); - - assertEquals(response, failureResponse); - verify(client).bulk(bulkRequest, options); + MetricsTestUtil.withMetricEnv(verifier -> { + OpenSearchBulkRetryWrapper bulkRetryWrapper = new OpenSearchBulkRetryWrapper( + retryOptionsWithoutRetry); + when(client.bulk(bulkRequest, options)) + .thenReturn(failureResponse); + when(bulkRequest.estimatedSizeInBytes()).thenReturn(ESTIMATED_SIZE_IN_BYTES); + mockFailureResponse(); + + BulkResponse response = bulkRetryWrapper.bulkWithPartialRetry(client, bulkRequest, options); + + assertEquals(response, failureResponse); + verify(client).bulk(bulkRequest, options); + verifier.assertHistoricGauge(MetricConstants.OPENSEARCH_BULK_SIZE_METRIC, ESTIMATED_SIZE_IN_BYTES); + 
verifier.assertHistoricGauge(MetricConstants.OPENSEARCH_BULK_RETRY_COUNT_METRIC, 0); + verifier.assertMetricNotExist(MetricConstants.OPENSEARCH_BULK_ALL_RETRY_FAILED_COUNT_METRIC); + }); } private void mockFailureResponse() { diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala index 68721d235..bdcc120c0 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala @@ -253,6 +253,10 @@ object FlintSparkConf { FlintConfig("spark.metadata.accessAWSCredentialsProvider") .doc("AWS credentials provider for metadata access permission") .createOptional() + val METADATA_CACHE_WRITE = FlintConfig("spark.flint.metadataCacheWrite.enabled") + .doc("Enable Flint metadata cache write to Flint index mappings") + .createWithDefault("false") + val CUSTOM_SESSION_MANAGER = FlintConfig("spark.flint.job.customSessionManager") .createOptional() @@ -309,6 +313,8 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable def monitorMaxErrorCount(): Int = MONITOR_MAX_ERROR_COUNT.readFrom(reader).toInt + def isMetadataCacheWriteEnabled: Boolean = METADATA_CACHE_WRITE.readFrom(reader).toBoolean + /** * spark.sql.session.timeZone */ diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala index 309877fdd..68d2409ee 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala @@ -20,9 +20,11 @@ import org.opensearch.flint.core.metadata.log.FlintMetadataLogServiceBuilder import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN import org.opensearch.flint.spark.FlintSparkIndexOptions.OptionName._ import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex +import org.opensearch.flint.spark.metadatacache.FlintMetadataCacheWriterBuilder import org.opensearch.flint.spark.mv.FlintSparkMaterializedView import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.RefreshMode._ +import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.SchedulerMode import org.opensearch.flint.spark.scheduler.{AsyncQuerySchedulerBuilder, FlintSparkJobExternalSchedulingService, FlintSparkJobInternalSchedulingService, FlintSparkJobSchedulingService} import org.opensearch.flint.spark.scheduler.AsyncQuerySchedulerBuilder.AsyncQuerySchedulerAction import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex @@ -55,8 +57,10 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w FlintIndexMetadataServiceBuilder.build(flintSparkConf.flintOptions()) } + private val flintMetadataCacheWriter = FlintMetadataCacheWriterBuilder.build(flintSparkConf) + private val flintAsyncQueryScheduler: AsyncQueryScheduler = { - AsyncQuerySchedulerBuilder.build(flintSparkConf.flintOptions()) + AsyncQuerySchedulerBuilder.build(spark, flintSparkConf.flintOptions()) } override protected val flintMetadataLogService: FlintMetadataLogService = { @@ -116,7 +120,6 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w throw new IllegalStateException(s"Flint index $indexName 
already exists") } } else { - val metadata = index.metadata() val jobSchedulingService = FlintSparkJobSchedulingService.create( index, spark, @@ -128,15 +131,18 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w .transientLog(latest => latest.copy(state = CREATING)) .finalLog(latest => latest.copy(state = ACTIVE)) .commit(latest => { - if (latest == null) { // in case transaction capability is disabled - flintClient.createIndex(indexName, metadata) - flintIndexMetadataService.updateIndexMetadata(indexName, metadata) - } else { - logInfo(s"Creating index with metadata log entry ID ${latest.id}") - flintClient.createIndex(indexName, metadata.copy(latestId = Some(latest.id))) - flintIndexMetadataService - .updateIndexMetadata(indexName, metadata.copy(latestId = Some(latest.id))) + val metadata = latest match { + case null => // in case transaction capability is disabled + index.metadata() + case latestEntry => + logInfo(s"Creating index with metadata log entry ID ${latestEntry.id}") + index + .metadata() + .copy(latestId = Some(latestEntry.id), latestLogEntry = Some(latest)) } + flintClient.createIndex(indexName, metadata) + flintIndexMetadataService.updateIndexMetadata(indexName, metadata) + flintMetadataCacheWriter.updateMetadataCache(indexName, metadata) jobSchedulingService.handleJob(index, AsyncQuerySchedulerAction.SCHEDULE) }) } @@ -155,22 +161,10 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w val index = describeIndex(indexName) .getOrElse(throw new IllegalStateException(s"Index $indexName doesn't exist")) val indexRefresh = FlintSparkIndexRefresh.create(indexName, index) - tx - .initialLog(latest => latest.state == ACTIVE) - .transientLog(latest => - latest.copy(state = REFRESHING, createTime = System.currentTimeMillis())) - .finalLog(latest => { - // Change state to active if full, otherwise update index state regularly - if (indexRefresh.refreshMode == AUTO) { - logInfo("Scheduling index state monitor") - flintIndexMonitor.startMonitor(indexName) - latest - } else { - logInfo("Updating index state to active") - latest.copy(state = ACTIVE) - } - }) - .commit(_ => indexRefresh.start(spark, flintSparkConf)) + indexRefresh.refreshMode match { + case AUTO => refreshIndexAuto(index, indexRefresh, tx) + case FULL | INCREMENTAL => refreshIndexManual(index, indexRefresh, tx) + } }.flatten /** @@ -189,7 +183,7 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w attachLatestLogEntry(indexName, metadata) } .toList - .flatMap(FlintSparkIndexFactory.create) + .flatMap(metadata => FlintSparkIndexFactory.create(spark, metadata)) } else { Seq.empty } @@ -208,7 +202,7 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w if (flintClient.exists(indexName)) { val metadata = flintIndexMetadataService.getIndexMetadata(indexName) val metadataWithEntry = attachLatestLogEntry(indexName, metadata) - FlintSparkIndexFactory.create(metadataWithEntry) + FlintSparkIndexFactory.create(spark, metadataWithEntry) } else { Option.empty } @@ -229,16 +223,16 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w val originalOptions = describeIndex(indexName) .getOrElse(throw new IllegalStateException(s"Index $indexName doesn't exist")) .options - validateUpdateAllowed(originalOptions, index.options) - val isSchedulerModeChanged = - index.options.isExternalSchedulerEnabled() != originalOptions.isExternalSchedulerEnabled() + validateUpdateAllowed(originalOptions, 
index.options) withTransaction[Option[String]](indexName, "Update Flint index") { tx => // Relies on validation to prevent: // 1. auto-to-auto updates besides scheduler_mode // 2. any manual-to-manual updates // 3. both refresh_mode and scheduler_mode updated - (index.options.autoRefresh(), isSchedulerModeChanged) match { + ( + index.options.autoRefresh(), + isSchedulerModeChanged(originalOptions, index.options)) match { case (true, true) => updateSchedulerMode(index, tx) case (true, false) => updateIndexManualToAuto(index, tx) case (false, false) => updateIndexAutoToManual(index, tx) @@ -333,7 +327,7 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w val index = describeIndex(indexName) if (index.exists(_.options.autoRefresh())) { - val updatedIndex = FlintSparkIndexFactory.createWithDefaultOptions(index.get).get + val updatedIndex = FlintSparkIndexFactory.createWithDefaultOptions(spark, index.get).get FlintSparkIndexRefresh .create(updatedIndex.name(), updatedIndex) .validate(spark) @@ -478,11 +472,17 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w "Altering index to full/incremental refresh") case (false, true) => - // original refresh_mode is auto, only allow changing scheduler_mode - validateChangedOptions( - changedOptions, - Set(SCHEDULER_MODE), - "Altering index when auto_refresh remains true") + // original refresh_mode is auto, only allow changing scheduler_mode and potentially refresh_interval + var allowedOptions = Set(SCHEDULER_MODE) + val schedulerMode = + if (updatedOptions.isExternalSchedulerEnabled()) SchedulerMode.EXTERNAL + else SchedulerMode.INTERNAL + val contextPrefix = + s"Altering index when auto_refresh remains true and scheduler_mode is $schedulerMode" + if (updatedOptions.isExternalSchedulerEnabled()) { + allowedOptions += REFRESH_INTERVAL + } + validateChangedOptions(changedOptions, allowedOptions, contextPrefix) case (false, false) => // original refresh_mode is full/incremental, not allowed to change any options @@ -507,6 +507,69 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w } } + private def isSchedulerModeChanged( + originalOptions: FlintSparkIndexOptions, + updatedOptions: FlintSparkIndexOptions): Boolean = { + updatedOptions.isExternalSchedulerEnabled() != originalOptions.isExternalSchedulerEnabled() + } + + /** + * Handles refresh for refresh mode AUTO, which is used exclusively by auto refresh index with + * internal scheduler. Refresh start time and complete time aren't tracked for streaming job. + * TODO: in future, track MicroBatchExecution time for streaming job and update as well + */ + private def refreshIndexAuto( + index: FlintSparkIndex, + indexRefresh: FlintSparkIndexRefresh, + tx: OptimisticTransaction[Option[String]]): Option[String] = { + val indexName = index.name + tx + .initialLog(latest => latest.state == ACTIVE) + .transientLog(latest => + latest.copy(state = REFRESHING, createTime = System.currentTimeMillis())) + .finalLog(latest => { + logInfo("Scheduling index state monitor") + flintIndexMonitor.startMonitor(indexName) + latest + }) + .commit(_ => indexRefresh.start(spark, flintSparkConf)) + } + + /** + * Handles refresh for refresh mode FULL and INCREMENTAL, which is used by full refresh index, + * incremental refresh index, and auto refresh index with external scheduler. Stores refresh + * start time and complete time. 
+ */ + private def refreshIndexManual( + index: FlintSparkIndex, + indexRefresh: FlintSparkIndexRefresh, + tx: OptimisticTransaction[Option[String]]): Option[String] = { + val indexName = index.name + tx + .initialLog(latest => latest.state == ACTIVE) + .transientLog(latest => { + val currentTime = System.currentTimeMillis() + val updatedLatest = latest + .copy(state = REFRESHING, createTime = currentTime, lastRefreshStartTime = currentTime) + flintMetadataCacheWriter + .updateMetadataCache( + indexName, + index.metadata.copy(latestLogEntry = Some(updatedLatest))) + updatedLatest + }) + .finalLog(latest => { + logInfo("Updating index state to active") + val updatedLatest = + latest.copy(state = ACTIVE, lastRefreshCompleteTime = System.currentTimeMillis()) + flintMetadataCacheWriter + .updateMetadataCache( + indexName, + index.metadata.copy(latestLogEntry = Some(updatedLatest))) + updatedLatest + }) + .commit(_ => indexRefresh.start(spark, flintSparkConf)) + } + private def updateIndexAutoToManual( index: FlintSparkIndex, tx: OptimisticTransaction[Option[String]]): Option[String] = { @@ -526,8 +589,10 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w .transientLog(latest => latest.copy(state = UPDATING)) .finalLog(latest => latest.copy(state = jobSchedulingService.stateTransitions.finalStateForUnschedule)) - .commit(_ => { + .commit(latest => { flintIndexMetadataService.updateIndexMetadata(indexName, index.metadata) + flintMetadataCacheWriter + .updateMetadataCache(indexName, index.metadata.copy(latestLogEntry = Some(latest))) logInfo("Update index options complete") jobSchedulingService.handleJob(index, AsyncQuerySchedulerAction.UNSCHEDULE) None @@ -553,8 +618,10 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w .finalLog(latest => { latest.copy(state = jobSchedulingService.stateTransitions.finalStateForUpdate) }) - .commit(_ => { + .commit(latest => { flintIndexMetadataService.updateIndexMetadata(indexName, index.metadata) + flintMetadataCacheWriter + .updateMetadataCache(indexName, index.metadata.copy(latestLogEntry = Some(latest))) logInfo("Update index options complete") jobSchedulingService.handleJob(index, AsyncQuerySchedulerAction.UPDATE) }) @@ -566,7 +633,7 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w val indexName = index.name val indexLogEntry = index.latestLogEntry.get val internalSchedulingService = - new FlintSparkJobInternalSchedulingService(spark, flintIndexMonitor) + new FlintSparkJobInternalSchedulingService(spark, flintSparkConf, flintIndexMonitor) val externalSchedulingService = new FlintSparkJobExternalSchedulingService(flintAsyncQueryScheduler, flintSparkConf) @@ -587,7 +654,8 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w flintIndexMetadataService.updateIndexMetadata(indexName, index.metadata) logInfo("Update index options complete") oldService.handleJob(index, AsyncQuerySchedulerAction.UNSCHEDULE) - logInfo(s"Unscheduled ${if (isExternal) "internal" else "external"} jobs") + logInfo( + s"Unscheduled refresh jobs from ${if (isExternal) "internal" else "external"} scheduler") newService.handleJob(index, AsyncQuerySchedulerAction.UPDATE) }) } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkCheckpoint.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkCheckpoint.scala index 4c18fea77..6ae55bd33 100644 --- 
a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkCheckpoint.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkCheckpoint.scala @@ -8,6 +8,7 @@ package org.opensearch.flint.spark import java.util.UUID import org.apache.hadoop.fs.{FSDataOutputStream, Path} +import org.opensearch.flint.core.metrics.{MetricConstants, MetricsUtil} import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession @@ -75,7 +76,9 @@ class FlintSparkCheckpoint(spark: SparkSession, val checkpointLocation: String) */ def delete(): Unit = { try { - checkpointManager.delete(checkpointRootDir) + MetricsUtil.withLatencyAsHistoricGauge( + () => checkpointManager.delete(checkpointRootDir), + MetricConstants.CHECKPOINT_DELETE_TIME_METRIC) logInfo(s"Checkpoint directory $checkpointRootDir deleted") } catch { case e: Exception => diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala index 0391741cf..2ff2883a9 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala @@ -92,7 +92,7 @@ abstract class FlintSparkIndexBuilder(flint: FlintSpark) { val updatedMetadata = index .metadata() .copy(options = updatedOptions.options.mapValues(_.asInstanceOf[AnyRef]).asJava) - validateIndex(FlintSparkIndexFactory.create(updatedMetadata).get) + validateIndex(FlintSparkIndexFactory.create(flint.spark, updatedMetadata).get) } /** diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala index 78636d992..ca659550d 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala @@ -25,6 +25,7 @@ import org.opensearch.flint.spark.skipping.partition.PartitionSkippingStrategy import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession /** * Flint Spark index factory that encapsulates specific Flint index instance creation. This is for @@ -35,14 +36,16 @@ object FlintSparkIndexFactory extends Logging { /** * Creates Flint index from generic Flint metadata. * + * @param spark + * Spark session * @param metadata * Flint metadata * @return * Flint index instance, or None if any error during creation */ - def create(metadata: FlintMetadata): Option[FlintSparkIndex] = { + def create(spark: SparkSession, metadata: FlintMetadata): Option[FlintSparkIndex] = { try { - Some(doCreate(metadata)) + Some(doCreate(spark, metadata)) } catch { case e: Exception => logWarning(s"Failed to create Flint index from metadata $metadata", e) @@ -53,24 +56,26 @@ object FlintSparkIndexFactory extends Logging { /** * Creates Flint index with default options. 
* + * @param spark + * Spark session * @param index * Flint index - * @param metadata - * Flint metadata * @return * Flint index with default options */ - def createWithDefaultOptions(index: FlintSparkIndex): Option[FlintSparkIndex] = { + def createWithDefaultOptions( + spark: SparkSession, + index: FlintSparkIndex): Option[FlintSparkIndex] = { val originalOptions = index.options val updatedOptions = FlintSparkIndexOptions.updateOptionsWithDefaults(index.name(), originalOptions) val updatedMetadata = index .metadata() .copy(options = updatedOptions.options.mapValues(_.asInstanceOf[AnyRef]).asJava) - this.create(updatedMetadata) + this.create(spark, updatedMetadata) } - private def doCreate(metadata: FlintMetadata): FlintSparkIndex = { + private def doCreate(spark: SparkSession, metadata: FlintMetadata): FlintSparkIndex = { val indexOptions = FlintSparkIndexOptions( metadata.options.asScala.mapValues(_.asInstanceOf[String]).toMap) val latestLogEntry = metadata.latestLogEntry @@ -118,6 +123,7 @@ object FlintSparkIndexFactory extends Logging { FlintSparkMaterializedView( metadata.name, metadata.source, + getMvSourceTables(spark, metadata), metadata.indexedColumns.map { colInfo => getString(colInfo, "columnName") -> getString(colInfo, "columnType") }.toMap, @@ -134,6 +140,15 @@ object FlintSparkIndexFactory extends Logging { .toMap } + private def getMvSourceTables(spark: SparkSession, metadata: FlintMetadata): Array[String] = { + val sourceTables = getArrayString(metadata.properties, "sourceTables") + if (sourceTables.isEmpty) { + FlintSparkMaterializedView.extractSourceTableNames(spark, metadata.source) + } else { + sourceTables + } + } + private def getString(map: java.util.Map[String, AnyRef], key: String): String = { map.get(key).asInstanceOf[String] } @@ -146,4 +161,12 @@ object FlintSparkIndexFactory extends Logging { Some(value.asInstanceOf[String]) } } + + private def getArrayString(map: java.util.Map[String, AnyRef], key: String): Array[String] = { + map.get(key) match { + case list: java.util.ArrayList[_] => + list.toArray.map(_.toString) + case _ => Array.empty[String] + } + } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala index 4bfc50c55..9b58a696c 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala @@ -10,7 +10,6 @@ import java.util.{Collections, UUID} import org.json4s.{Formats, NoTypeHints} import org.json4s.native.JsonMethods._ import org.json4s.native.Serialization -import org.opensearch.flint.core.logging.CustomLogging.logInfo import org.opensearch.flint.spark.FlintSparkIndexOptions.OptionName.{AUTO_REFRESH, CHECKPOINT_LOCATION, EXTRA_OPTIONS, INCREMENTAL_REFRESH, INDEX_SETTINGS, OptionName, OUTPUT_MODE, REFRESH_INTERVAL, SCHEDULER_MODE, WATERMARK_DELAY} import org.opensearch.flint.spark.FlintSparkIndexOptions.validateOptionNames import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.SchedulerMode @@ -257,7 +256,7 @@ object FlintSparkIndexOptions { .externalSchedulerIntervalThreshold()) case (false, _, Some("external")) => throw new IllegalArgumentException( - "spark.flint.job.externalScheduler.enabled is false but refresh interval is set to external scheduler mode") + "spark.flint.job.externalScheduler.enabled is false but scheduler_mode is set to 
external") case _ => updatedOptions += (SCHEDULER_MODE.toString -> SchedulerMode.INTERNAL.toString) } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkValidationHelper.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkValidationHelper.scala index 1aaa85075..7e9922655 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkValidationHelper.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkValidationHelper.scala @@ -11,9 +11,8 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.command.DDLUtils -import org.apache.spark.sql.flint.{loadTable, parseTableName, qualifyTableName} +import org.apache.spark.sql.flint.{loadTable, parseTableName} /** * Flint Spark validation helper. @@ -31,16 +30,10 @@ trait FlintSparkValidationHelper extends Logging { * true if all non Hive, otherwise false */ def isTableProviderSupported(spark: SparkSession, index: FlintSparkIndex): Boolean = { - // Extract source table name (possibly more than one for MV query) val tableNames = index match { case skipping: FlintSparkSkippingIndex => Seq(skipping.tableName) case covering: FlintSparkCoveringIndex => Seq(covering.tableName) - case mv: FlintSparkMaterializedView => - spark.sessionState.sqlParser - .parsePlan(mv.query) - .collect { case relation: UnresolvedRelation => - qualifyTableName(spark, relation.tableName) - } + case mv: FlintSparkMaterializedView => mv.sourceTables.toSeq } // Validate if any source table is not supported (currently Hive only) diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintDisabledMetadataCacheWriter.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintDisabledMetadataCacheWriter.scala new file mode 100644 index 000000000..4099da3ff --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintDisabledMetadataCacheWriter.scala @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.metadatacache + +import org.opensearch.flint.common.metadata.FlintMetadata + +/** + * Default implementation of {@link FlintMetadataCacheWriter} that does nothing + */ +class FlintDisabledMetadataCacheWriter extends FlintMetadataCacheWriter { + override def updateMetadataCache(indexName: String, metadata: FlintMetadata): Unit = { + // Do nothing + } +} diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCache.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCache.scala new file mode 100644 index 000000000..e1c0f318c --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCache.scala @@ -0,0 +1,81 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.metadatacache + +import scala.collection.JavaConverters.mapAsScalaMapConverter + +import org.opensearch.flint.common.metadata.FlintMetadata +import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry +import org.opensearch.flint.spark.FlintSparkIndexOptions +import 
org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE +import org.opensearch.flint.spark.scheduler.util.IntervalSchedulerParser + +/** + * Flint metadata cache defines metadata required to store in read cache for frontend user to + * access. + */ +case class FlintMetadataCache( + metadataCacheVersion: String, + /** Refresh interval for Flint index with auto refresh. Unit: seconds */ + refreshInterval: Option[Int], + /** Source table names for building the Flint index. */ + sourceTables: Array[String], + /** Timestamp when Flint index is last refreshed. Unit: milliseconds */ + lastRefreshTime: Option[Long]) { + + /** + * Convert FlintMetadataCache to a map. Skips a field if its value is not defined. + */ + def toMap: Map[String, AnyRef] = { + val fieldNames = getClass.getDeclaredFields.map(_.getName) + val fieldValues = productIterator.toList + + fieldNames + .zip(fieldValues) + .flatMap { + case (_, None) => List.empty + case (name, Some(value)) => List((name, value)) + case (name, value) => List((name, value)) + } + .toMap + .mapValues(_.asInstanceOf[AnyRef]) + } +} + +object FlintMetadataCache { + + val metadataCacheVersion = "1.0" + + def apply(metadata: FlintMetadata): FlintMetadataCache = { + val indexOptions = FlintSparkIndexOptions( + metadata.options.asScala.mapValues(_.asInstanceOf[String]).toMap) + val refreshInterval = if (indexOptions.autoRefresh()) { + indexOptions + .refreshInterval() + .map(IntervalSchedulerParser.parseAndConvertToMillis) + .map(millis => (millis / 1000).toInt) // convert to seconds + } else { + None + } + val sourceTables = metadata.kind match { + case MV_INDEX_TYPE => + metadata.properties.get("sourceTables") match { + case list: java.util.ArrayList[_] => + list.toArray.map(_.toString) + case _ => Array.empty[String] + } + case _ => Array(metadata.source) + } + val lastRefreshTime: Option[Long] = metadata.latestLogEntry.flatMap { entry => + entry.lastRefreshCompleteTime match { + case FlintMetadataLogEntry.EMPTY_TIMESTAMP => None + case timestamp => Some(timestamp) + } + } + + FlintMetadataCache(metadataCacheVersion, refreshInterval, sourceTables, lastRefreshTime) + } +} diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCacheWriter.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCacheWriter.scala new file mode 100644 index 000000000..c256463c3 --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCacheWriter.scala @@ -0,0 +1,27 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.metadatacache + +import org.opensearch.flint.common.metadata.{FlintIndexMetadataService, FlintMetadata} + +/** + * Writes {@link FlintMetadataCache} to a storage of choice. This is different from {@link + * FlintIndexMetadataService} which persists the full index metadata to a storage for single + * source of truth. + */ +trait FlintMetadataCacheWriter { + + /** + * Update metadata cache for a Flint index. 
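// Editor's note: a minimal illustrative sketch, not part of this change. Values are assumed
// examples (mirroring the test fixtures later in this patch) showing what a cache entry built
// from the FlintMetadataCache case class above contains and how toMap treats Option fields.
val exampleCacheEntry = FlintMetadataCache(
  metadataCacheVersion = "1.0",
  refreshInterval = Some(600), // "10 Minutes" parsed to millis, then converted to seconds
  sourceTables = Array("spark_catalog.default.test_table"),
  lastRefreshTime = Some(1234567890123L))
// toMap skips fields whose value is None, so an index that has never completed a refresh
// simply has no "lastRefreshTime" key in the serialized map.
val exampleCacheMap: Map[String, AnyRef] = exampleCacheEntry.toMap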
+ * + * @param indexName + * index name + * @param metadata + * index metadata to update the cache + */ + def updateMetadataCache(indexName: String, metadata: FlintMetadata): Unit + +} diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCacheWriterBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCacheWriterBuilder.scala new file mode 100644 index 000000000..be821ae25 --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCacheWriterBuilder.scala @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.metadatacache + +import org.apache.spark.sql.flint.config.FlintSparkConf + +object FlintMetadataCacheWriterBuilder { + def build(flintSparkConf: FlintSparkConf): FlintMetadataCacheWriter = { + if (flintSparkConf.isMetadataCacheWriteEnabled) { + new FlintOpenSearchMetadataCacheWriter(flintSparkConf.flintOptions()) + } else { + new FlintDisabledMetadataCacheWriter + } + } +} diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriter.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriter.scala new file mode 100644 index 000000000..2bc373792 --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriter.scala @@ -0,0 +1,106 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.metadatacache + +import java.util + +import scala.collection.JavaConverters._ + +import org.opensearch.client.RequestOptions +import org.opensearch.client.indices.PutMappingRequest +import org.opensearch.common.xcontent.XContentType +import org.opensearch.flint.common.metadata.{FlintIndexMetadataService, FlintMetadata} +import org.opensearch.flint.core.{FlintOptions, IRestHighLevelClient} +import org.opensearch.flint.core.metadata.FlintIndexMetadataServiceBuilder +import org.opensearch.flint.core.metadata.FlintJsonHelper._ +import org.opensearch.flint.core.storage.{FlintOpenSearchIndexMetadataService, OpenSearchClientUtils} + +import org.apache.spark.internal.Logging + +/** + * Writes {@link FlintMetadataCache} to index mappings `_meta` field for frontend user to access. + */ +class FlintOpenSearchMetadataCacheWriter(options: FlintOptions) + extends FlintMetadataCacheWriter + with Logging { + + /** + * Since metadata cache shares the index mappings _meta field with OpenSearch index metadata + * storage, this flag is to allow for preserving index metadata that is already stored in _meta + * when updating metadata cache. 
+ */ + private val includeSpec: Boolean = + FlintIndexMetadataServiceBuilder + .build(options) + .isInstanceOf[FlintOpenSearchIndexMetadataService] + + override def updateMetadataCache(indexName: String, metadata: FlintMetadata): Unit = { + logInfo(s"Updating metadata cache for $indexName"); + val osIndexName = OpenSearchClientUtils.sanitizeIndexName(indexName) + var client: IRestHighLevelClient = null + try { + client = OpenSearchClientUtils.createClient(options) + val request = new PutMappingRequest(osIndexName) + request.source(serialize(metadata), XContentType.JSON) + client.updateIndexMapping(request, RequestOptions.DEFAULT) + } catch { + case e: Exception => + throw new IllegalStateException( + s"Failed to update metadata cache for Flint index $osIndexName", + e) + } finally + if (client != null) { + client.close() + } + } + + /** + * Serialize FlintMetadataCache from FlintMetadata. Modified from {@link + * FlintOpenSearchIndexMetadataService} + */ + private[metadatacache] def serialize(metadata: FlintMetadata): String = { + try { + buildJson(builder => { + objectField(builder, "_meta") { + // If _meta is used as index metadata storage, preserve them. + if (includeSpec) { + builder + .field("version", metadata.version.version) + .field("name", metadata.name) + .field("kind", metadata.kind) + .field("source", metadata.source) + .field("indexedColumns", metadata.indexedColumns) + + if (metadata.latestId.isDefined) { + builder.field("latestId", metadata.latestId.get) + } + optionalObjectField(builder, "options", metadata.options) + } + + optionalObjectField(builder, "properties", buildPropertiesMap(metadata)) + } + builder.field("properties", metadata.schema) + }) + } catch { + case e: Exception => + throw new IllegalStateException("Failed to jsonify cache metadata", e) + } + } + + /** + * Since _meta.properties is shared by both index metadata and metadata cache, here we merge the + * two maps. 
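// Editor's note: a small sketch, not part of this change, of how the merge below behaves.
// With Scala's Map ++, the right-hand operand wins on duplicate keys, so properties already
// stored by the index metadata service take precedence over the freshly computed cache fields.
// The key names and values are assumed examples.
val cacheProperties = Map[String, AnyRef]("metadataCacheVersion" -> "1.0", "refreshInterval" -> Int.box(600))
val indexProperties = Map[String, AnyRef]("sourceTables" -> Array("spark_catalog.default.t1"))
val mergedProperties = cacheProperties ++ indexProperties // union of keys; index metadata wins on conflicts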
+ */ + private def buildPropertiesMap(metadata: FlintMetadata): util.Map[String, AnyRef] = { + val metadataCacheProperties = FlintMetadataCache(metadata).toMap + + if (includeSpec) { + (metadataCacheProperties ++ metadata.properties.asScala).asJava + } else { + metadataCacheProperties.asJava + } + } +} diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala index caa75be75..aecfc99df 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala @@ -34,6 +34,8 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap * MV name * @param query * source query that generates MV data + * @param sourceTables + * source table names * @param outputSchema * output schema * @param options @@ -44,6 +46,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap case class FlintSparkMaterializedView( mvName: String, query: String, + sourceTables: Array[String], outputSchema: Map[String, String], override val options: FlintSparkIndexOptions = empty, override val latestLogEntry: Option[FlintMetadataLogEntry] = None) @@ -64,6 +67,7 @@ case class FlintSparkMaterializedView( metadataBuilder(this) .name(mvName) .source(query) + .addProperty("sourceTables", sourceTables) .indexedColumns(indexColumnMaps) .schema(schema) .build() @@ -165,10 +169,30 @@ object FlintSparkMaterializedView { flintIndexNamePrefix(mvName) } + /** + * Extract source table names (possibly more than one) from the query. + * + * @param spark + * Spark session + * @param query + * source query that generates MV data + * @return + * source table names + */ + def extractSourceTableNames(spark: SparkSession, query: String): Array[String] = { + spark.sessionState.sqlParser + .parsePlan(query) + .collect { case relation: UnresolvedRelation => + qualifyTableName(spark, relation.tableName) + } + .toArray + } + /** Builder class for MV build */ class Builder(flint: FlintSpark) extends FlintSparkIndexBuilder(flint) { private var mvName: String = "" private var query: String = "" + private var sourceTables: Array[String] = Array.empty[String] /** * Set MV name. 
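// Editor's note: a brief usage sketch, not part of this change, for extractSourceTableNames
// added in the hunk above. The query and table names are assumed examples; with the default
// catalog and database, every UnresolvedRelation found in the parsed plan is qualified.
val mvSources = FlintSparkMaterializedView.extractSourceTableNames(
  spark,
  "SELECT t1.id, COUNT(*) AS cnt FROM t1 JOIN t2 ON t1.id = t2.id GROUP BY t1.id")
// mvSources == Array("spark_catalog.default.t1", "spark_catalog.default.t2")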
@@ -193,6 +217,7 @@ object FlintSparkMaterializedView { */ def query(query: String): Builder = { this.query = query + this.sourceTables = extractSourceTableNames(flint.spark, query) this } @@ -221,7 +246,7 @@ object FlintSparkMaterializedView { field.name -> field.dataType.simpleString } .toMap - FlintSparkMaterializedView(mvName, query, outputSchema, indexOptions) + FlintSparkMaterializedView(mvName, query, sourceTables, outputSchema, indexOptions) } } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/AutoIndexRefresh.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/AutoIndexRefresh.scala index d343fd999..bedeeba54 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/AutoIndexRefresh.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/AutoIndexRefresh.scala @@ -7,6 +7,7 @@ package org.opensearch.flint.spark.refresh import java.util.Collections +import org.opensearch.flint.core.metrics.ReadWriteBytesSparkListener import org.opensearch.flint.spark.{FlintSparkIndex, FlintSparkIndexOptions, FlintSparkValidationHelper} import org.opensearch.flint.spark.FlintSparkIndex.{quotedTableName, StreamingRefresh} import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.RefreshMode.{AUTO, RefreshMode} @@ -67,15 +68,17 @@ class AutoIndexRefresh(indexName: String, index: FlintSparkIndex) // Flint index has specialized logic and capability for incremental refresh case refresh: StreamingRefresh => logInfo("Start refreshing index in streaming style") - val job = - refresh - .buildStream(spark) - .writeStream - .queryName(indexName) - .format(FLINT_DATASOURCE) - .options(flintSparkConf.properties) - .addSinkOptions(options, flintSparkConf) - .start(indexName) + val job = ReadWriteBytesSparkListener.withMetrics( + spark, + () => + refresh + .buildStream(spark) + .writeStream + .queryName(indexName) + .format(FLINT_DATASOURCE) + .options(flintSparkConf.properties) + .addSinkOptions(options, flintSparkConf) + .start(indexName)) Some(job.id.toString) // Otherwise, fall back to foreachBatch + batch refresh diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/AsyncQuerySchedulerBuilder.java b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/AsyncQuerySchedulerBuilder.java index 3620608b0..330b38f02 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/AsyncQuerySchedulerBuilder.java +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/AsyncQuerySchedulerBuilder.java @@ -7,9 +7,12 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.flint.config.FlintSparkConf; import org.opensearch.flint.common.scheduler.AsyncQueryScheduler; import org.opensearch.flint.core.FlintOptions; +import java.io.IOException; import java.lang.reflect.Constructor; /** @@ -28,11 +31,27 @@ public enum AsyncQuerySchedulerAction { REMOVE } - public static AsyncQueryScheduler build(FlintOptions options) { + public static AsyncQueryScheduler build(SparkSession sparkSession, FlintOptions options) throws IOException { + return new AsyncQuerySchedulerBuilder().doBuild(sparkSession, options); + } + + /** + * Builds an AsyncQueryScheduler based on the provided options. + * + * @param sparkSession The SparkSession to be used. 
+ * @param options The FlintOptions containing configuration details. + * @return An instance of AsyncQueryScheduler. + */ + protected AsyncQueryScheduler doBuild(SparkSession sparkSession, FlintOptions options) throws IOException { String className = options.getCustomAsyncQuerySchedulerClass(); if (className.isEmpty()) { - return new OpenSearchAsyncQueryScheduler(options); + OpenSearchAsyncQueryScheduler scheduler = createOpenSearchAsyncQueryScheduler(options); + // Check if the scheduler has access to the required index. Disable the external scheduler otherwise. + if (!hasAccessToSchedulerIndex(scheduler)){ + setExternalSchedulerEnabled(sparkSession, false); + } + return scheduler; } // Attempts to instantiate AsyncQueryScheduler using reflection @@ -45,4 +64,16 @@ public static AsyncQueryScheduler build(FlintOptions options) { throw new RuntimeException("Failed to instantiate AsyncQueryScheduler: " + className, e); } } + + protected OpenSearchAsyncQueryScheduler createOpenSearchAsyncQueryScheduler(FlintOptions options) { + return new OpenSearchAsyncQueryScheduler(options); + } + + protected boolean hasAccessToSchedulerIndex(OpenSearchAsyncQueryScheduler scheduler) throws IOException { + return scheduler.hasAccessToSchedulerIndex(); + } + + protected void setExternalSchedulerEnabled(SparkSession sparkSession, boolean enabled) { + sparkSession.sqlContext().setConf(FlintSparkConf.EXTERNAL_SCHEDULER_ENABLED().key(), String.valueOf(enabled)); + } } \ No newline at end of file diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/FlintSparkJobExternalSchedulingService.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/FlintSparkJobExternalSchedulingService.scala index d043746c0..197b2f8c7 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/FlintSparkJobExternalSchedulingService.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/FlintSparkJobExternalSchedulingService.scala @@ -10,6 +10,7 @@ import java.time.Instant import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry.IndexState import org.opensearch.flint.common.scheduler.AsyncQueryScheduler import org.opensearch.flint.common.scheduler.model.{AsyncQuerySchedulerRequest, LangType} +import org.opensearch.flint.core.metrics.{MetricConstants, MetricsUtil} import org.opensearch.flint.core.storage.OpenSearchClientUtils import org.opensearch.flint.spark.FlintSparkIndex import org.opensearch.flint.spark.refresh.util.RefreshMetricsAspect @@ -83,8 +84,16 @@ class FlintSparkJobExternalSchedulingService( case AsyncQuerySchedulerAction.REMOVE => flintAsyncQueryScheduler.removeJob(request) case _ => throw new IllegalArgumentException(s"Unsupported action: $action") } + addExternalSchedulerMetrics(action) None // Return None for all cases } } + + private def addExternalSchedulerMetrics(action: AsyncQuerySchedulerAction): Unit = { + val actionName = action.name().toLowerCase() + MetricsUtil.addHistoricGauge( + MetricConstants.EXTERNAL_SCHEDULER_METRIC_PREFIX + actionName + ".count", + 1) + } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/FlintSparkJobInternalSchedulingService.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/FlintSparkJobInternalSchedulingService.scala index d22eff2c9..8928357c7 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/FlintSparkJobInternalSchedulingService.scala 
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/FlintSparkJobInternalSchedulingService.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.flint.config.FlintSparkConf */ class FlintSparkJobInternalSchedulingService( spark: SparkSession, + flintSparkConf: FlintSparkConf, flintIndexMonitor: FlintSparkIndexMonitor) extends FlintSparkJobSchedulingService with Logging { @@ -55,12 +56,9 @@ class FlintSparkJobInternalSchedulingService( index: FlintSparkIndex, action: AsyncQuerySchedulerAction): Option[String] = { val indexName = index.name() - action match { case AsyncQuerySchedulerAction.SCHEDULE => None // No-op case AsyncQuerySchedulerAction.UPDATE => - logInfo("Scheduling index state monitor") - flintIndexMonitor.startMonitor(indexName) startRefreshingJob(index) case AsyncQuerySchedulerAction.UNSCHEDULE => logInfo("Stopping index state monitor") @@ -81,7 +79,17 @@ class FlintSparkJobInternalSchedulingService( private def startRefreshingJob(index: FlintSparkIndex): Option[String] = { logInfo(s"Starting refreshing job for index ${index.name()}") val indexRefresh = FlintSparkIndexRefresh.create(index.name(), index) - indexRefresh.start(spark, new FlintSparkConf(spark.conf.getAll.toMap.asJava)) + val jobId = indexRefresh.start(spark, flintSparkConf) + + // NOTE: Resolution for previous concurrency issue + // This code addresses a previously identified concurrency issue with recoverIndex + // where scheduled FlintSparkIndexMonitorTask couldn't detect the active Spark streaming job ID. The issue + // was caused by starting the FlintSparkIndexMonitor before the Spark streaming job was fully + // initialized. In this fixed version, we start the monitor after the streaming job has been + // initiated, ensuring that the job ID is available for detection. 
+ logInfo("Scheduling index state monitor") + flintIndexMonitor.startMonitor(index.name()) + jobId } /** diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/FlintSparkJobSchedulingService.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/FlintSparkJobSchedulingService.scala index 6e25d8a8c..b813c7dd0 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/FlintSparkJobSchedulingService.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/FlintSparkJobSchedulingService.scala @@ -70,7 +70,7 @@ object FlintSparkJobSchedulingService { if (isExternalSchedulerEnabled(index)) { new FlintSparkJobExternalSchedulingService(flintAsyncQueryScheduler, flintSparkConf) } else { - new FlintSparkJobInternalSchedulingService(spark, flintIndexMonitor) + new FlintSparkJobInternalSchedulingService(spark, flintSparkConf, flintIndexMonitor) } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/OpenSearchAsyncQueryScheduler.java b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/OpenSearchAsyncQueryScheduler.java index 19532254b..a1ef45825 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/OpenSearchAsyncQueryScheduler.java +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/OpenSearchAsyncQueryScheduler.java @@ -9,6 +9,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ObjectNode; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMap; import org.apache.commons.io.IOUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -37,6 +38,7 @@ import org.opensearch.jobscheduler.spi.schedule.Schedule; import org.opensearch.rest.RestStatus; +import java.io.IOException; import java.io.InputStream; import java.nio.charset.StandardCharsets; import java.time.Instant; @@ -55,6 +57,11 @@ public class OpenSearchAsyncQueryScheduler implements AsyncQueryScheduler { private static final ObjectMapper mapper = new ObjectMapper(); private final FlintOptions flintOptions; + @VisibleForTesting + public OpenSearchAsyncQueryScheduler() { + this.flintOptions = new FlintOptions(ImmutableMap.of()); + } + public OpenSearchAsyncQueryScheduler(FlintOptions options) { this.flintOptions = options; } @@ -124,6 +131,28 @@ void createAsyncQuerySchedulerIndex(IRestHighLevelClient client) { } } + /** + * Checks if the current setup has access to the scheduler index. + * + * This method attempts to create a client and ensure that the scheduler index exists. + * If these operations succeed, it indicates that the user has the necessary permissions + * to access and potentially modify the scheduler index. 
+ * + * @see #createClient() + * @see #ensureIndexExists(IRestHighLevelClient) + */ + public boolean hasAccessToSchedulerIndex() throws IOException { + IRestHighLevelClient client = createClient(); + try { + ensureIndexExists(client); + return true; + } catch (Throwable e) { + LOG.error("Failed to ensure index exists", e); + return false; + } finally { + client.close(); + } + } private void ensureIndexExists(IRestHighLevelClient client) { try { if (!client.doesIndexExist(new GetIndexRequest(SCHEDULER_INDEX_NAME), RequestOptions.DEFAULT)) { diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/util/IntervalSchedulerParser.java b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/util/IntervalSchedulerParser.java index 8745681b9..9622b4c64 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/util/IntervalSchedulerParser.java +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/util/IntervalSchedulerParser.java @@ -18,21 +18,30 @@ public class IntervalSchedulerParser { /** - * Parses a schedule string into an IntervalSchedule. + * Parses a schedule string into an interval in milliseconds. * * @param scheduleStr the schedule string to parse - * @return the parsed IntervalSchedule + * @return the parsed interval in milliseconds * @throws IllegalArgumentException if the schedule string is invalid */ - public static IntervalSchedule parse(String scheduleStr) { + public static Long parseAndConvertToMillis(String scheduleStr) { if (Strings.isNullOrEmpty(scheduleStr)) { throw new IllegalArgumentException("Schedule string must not be null or empty."); } - Long millis = Triggers.convert(scheduleStr); + return Triggers.convert(scheduleStr); + } + /** + * Parses a schedule string into an IntervalSchedule.
+ * + * @param scheduleStr the schedule string to parse + * @return the parsed IntervalSchedule + * @throws IllegalArgumentException if the schedule string is invalid + */ + public static IntervalSchedule parse(String scheduleStr) { // Convert milliseconds to minutes (rounding down) - int minutes = (int) (millis / (60 * 1000)); + int minutes = (int) (parseAndConvertToMillis(scheduleStr) / (60 * 1000)); // Use the current time as the start time Instant startTime = Instant.now(); diff --git a/flint-spark-integration/src/test/java/org/opensearch/flint/core/scheduler/AsyncQuerySchedulerBuilderTest.java b/flint-spark-integration/src/test/java/org/opensearch/flint/core/scheduler/AsyncQuerySchedulerBuilderTest.java index 67b5afee5..3c65a96a5 100644 --- a/flint-spark-integration/src/test/java/org/opensearch/flint/core/scheduler/AsyncQuerySchedulerBuilderTest.java +++ b/flint-spark-integration/src/test/java/org/opensearch/flint/core/scheduler/AsyncQuerySchedulerBuilderTest.java @@ -5,43 +5,80 @@ package org.opensearch.flint.core.scheduler; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.SQLContext; +import org.junit.Before; import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import org.opensearch.flint.common.scheduler.AsyncQueryScheduler; import org.opensearch.flint.common.scheduler.model.AsyncQuerySchedulerRequest; import org.opensearch.flint.core.FlintOptions; import org.opensearch.flint.spark.scheduler.AsyncQuerySchedulerBuilder; import org.opensearch.flint.spark.scheduler.OpenSearchAsyncQueryScheduler; +import java.io.IOException; + import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class AsyncQuerySchedulerBuilderTest { + @Mock + private SparkSession sparkSession; + + @Mock + private SQLContext sqlContext; + + private AsyncQuerySchedulerBuilderForLocalTest testBuilder; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + when(sparkSession.sqlContext()).thenReturn(sqlContext); + } + + @Test + public void testBuildWithEmptyClassNameAndAccessibleIndex() throws IOException { + FlintOptions options = mock(FlintOptions.class); + when(options.getCustomAsyncQuerySchedulerClass()).thenReturn(""); + OpenSearchAsyncQueryScheduler mockScheduler = mock(OpenSearchAsyncQueryScheduler.class); + + AsyncQueryScheduler scheduler = testBuilder.build(mockScheduler, true, sparkSession, options); + assertTrue(scheduler instanceof OpenSearchAsyncQueryScheduler); + verify(sqlContext, never()).setConf(anyString(), anyString()); + } @Test - public void testBuildWithEmptyClassName() { + public void testBuildWithEmptyClassNameAndInaccessibleIndex() throws IOException { FlintOptions options = mock(FlintOptions.class); when(options.getCustomAsyncQuerySchedulerClass()).thenReturn(""); + OpenSearchAsyncQueryScheduler mockScheduler = mock(OpenSearchAsyncQueryScheduler.class); - AsyncQueryScheduler scheduler = AsyncQuerySchedulerBuilder.build(options); + AsyncQueryScheduler scheduler = testBuilder.build(mockScheduler, false, sparkSession, options); assertTrue(scheduler instanceof OpenSearchAsyncQueryScheduler); + verify(sqlContext).setConf("spark.flint.job.externalScheduler.enabled", "false"); } @Test - public void testBuildWithCustomClassName() { + public void testBuildWithCustomClassName() throws IOException { FlintOptions 
options = mock(FlintOptions.class); - when(options.getCustomAsyncQuerySchedulerClass()).thenReturn("org.opensearch.flint.core.scheduler.AsyncQuerySchedulerBuilderTest$AsyncQuerySchedulerForLocalTest"); + when(options.getCustomAsyncQuerySchedulerClass()) + .thenReturn("org.opensearch.flint.core.scheduler.AsyncQuerySchedulerBuilderTest$AsyncQuerySchedulerForLocalTest"); - AsyncQueryScheduler scheduler = AsyncQuerySchedulerBuilder.build(options); + AsyncQueryScheduler scheduler = AsyncQuerySchedulerBuilder.build(sparkSession, options); assertTrue(scheduler instanceof AsyncQuerySchedulerForLocalTest); } @Test(expected = RuntimeException.class) - public void testBuildWithInvalidClassName() { + public void testBuildWithInvalidClassName() throws IOException { FlintOptions options = mock(FlintOptions.class); when(options.getCustomAsyncQuerySchedulerClass()).thenReturn("invalid.ClassName"); - AsyncQuerySchedulerBuilder.build(options); + AsyncQuerySchedulerBuilder.build(sparkSession, options); } public static class AsyncQuerySchedulerForLocalTest implements AsyncQueryScheduler { @@ -65,4 +102,35 @@ public void removeJob(AsyncQuerySchedulerRequest asyncQuerySchedulerRequest) { // Custom implementation } } + + public static class OpenSearchAsyncQuerySchedulerForLocalTest extends OpenSearchAsyncQueryScheduler { + @Override + public boolean hasAccessToSchedulerIndex() { + return true; + } + } + + public static class AsyncQuerySchedulerBuilderForLocalTest extends AsyncQuerySchedulerBuilder { + private OpenSearchAsyncQueryScheduler mockScheduler; + private Boolean mockHasAccess; + + public AsyncQuerySchedulerBuilderForLocalTest(OpenSearchAsyncQueryScheduler mockScheduler, Boolean mockHasAccess) { + this.mockScheduler = mockScheduler; + this.mockHasAccess = mockHasAccess; + } + + @Override + protected OpenSearchAsyncQueryScheduler createOpenSearchAsyncQueryScheduler(FlintOptions options) { + return mockScheduler != null ? mockScheduler : super.createOpenSearchAsyncQueryScheduler(options); + } + + @Override + protected boolean hasAccessToSchedulerIndex(OpenSearchAsyncQueryScheduler scheduler) throws IOException { + return mockHasAccess != null ? 
mockHasAccess : super.hasAccessToSchedulerIndex(scheduler); + } + + public static AsyncQueryScheduler build(OpenSearchAsyncQueryScheduler asyncQueryScheduler, Boolean hasAccess, SparkSession sparkSession, FlintOptions options) throws IOException { + return new AsyncQuerySchedulerBuilderForLocalTest(asyncQueryScheduler, hasAccess).doBuild(sparkSession, options); + } + } } \ No newline at end of file diff --git a/flint-spark-integration/src/test/java/org/opensearch/flint/core/scheduler/util/IntervalSchedulerParserTest.java b/flint-spark-integration/src/test/java/org/opensearch/flint/core/scheduler/util/IntervalSchedulerParserTest.java index 2ad1fea9c..731e1ae5c 100644 --- a/flint-spark-integration/src/test/java/org/opensearch/flint/core/scheduler/util/IntervalSchedulerParserTest.java +++ b/flint-spark-integration/src/test/java/org/opensearch/flint/core/scheduler/util/IntervalSchedulerParserTest.java @@ -23,53 +23,92 @@ public void testParseNull() { IntervalSchedulerParser.parse(null); } + @Test(expected = IllegalArgumentException.class) + public void testParseMillisNull() { + IntervalSchedulerParser.parseAndConvertToMillis(null); + } + @Test(expected = IllegalArgumentException.class) public void testParseEmptyString() { IntervalSchedulerParser.parse(""); } + @Test(expected = IllegalArgumentException.class) + public void testParseMillisEmptyString() { + IntervalSchedulerParser.parseAndConvertToMillis(""); + } + @Test public void testParseString() { - Schedule result = IntervalSchedulerParser.parse("10 minutes"); - assertTrue(result instanceof IntervalSchedule); - IntervalSchedule intervalSchedule = (IntervalSchedule) result; + Schedule schedule = IntervalSchedulerParser.parse("10 minutes"); + assertTrue(schedule instanceof IntervalSchedule); + IntervalSchedule intervalSchedule = (IntervalSchedule) schedule; assertEquals(10, intervalSchedule.getInterval()); assertEquals(ChronoUnit.MINUTES, intervalSchedule.getUnit()); } + @Test + public void testParseMillisString() { + Long millis = IntervalSchedulerParser.parseAndConvertToMillis("10 minutes"); + assertEquals(600000, millis.longValue()); + } + @Test(expected = IllegalArgumentException.class) public void testParseInvalidFormat() { IntervalSchedulerParser.parse("invalid format"); } + @Test(expected = IllegalArgumentException.class) + public void testParseMillisInvalidFormat() { + IntervalSchedulerParser.parseAndConvertToMillis("invalid format"); + } + @Test public void testParseStringScheduleMinutes() { - IntervalSchedule result = IntervalSchedulerParser.parse("5 minutes"); - assertEquals(5, result.getInterval()); - assertEquals(ChronoUnit.MINUTES, result.getUnit()); + IntervalSchedule schedule = IntervalSchedulerParser.parse("5 minutes"); + assertEquals(5, schedule.getInterval()); + assertEquals(ChronoUnit.MINUTES, schedule.getUnit()); + } + + @Test + public void testParseMillisStringScheduleMinutes() { + Long millis = IntervalSchedulerParser.parseAndConvertToMillis("5 minutes"); + assertEquals(300000, millis.longValue()); } @Test public void testParseStringScheduleHours() { - IntervalSchedule result = IntervalSchedulerParser.parse("2 hours"); - assertEquals(120, result.getInterval()); - assertEquals(ChronoUnit.MINUTES, result.getUnit()); + IntervalSchedule schedule = IntervalSchedulerParser.parse("2 hours"); + assertEquals(120, schedule.getInterval()); + assertEquals(ChronoUnit.MINUTES, schedule.getUnit()); + } + + @Test + public void testParseMillisStringScheduleHours() { + Long millis = IntervalSchedulerParser.parseAndConvertToMillis("2 
hours"); + assertEquals(7200000, millis.longValue()); } @Test public void testParseStringScheduleDays() { - IntervalSchedule result = IntervalSchedulerParser.parse("1 day"); - assertEquals(1440, result.getInterval()); - assertEquals(ChronoUnit.MINUTES, result.getUnit()); + IntervalSchedule schedule = IntervalSchedulerParser.parse("1 day"); + assertEquals(1440, schedule.getInterval()); + assertEquals(ChronoUnit.MINUTES, schedule.getUnit()); + } + + @Test + public void testParseMillisStringScheduleDays() { + Long millis = IntervalSchedulerParser.parseAndConvertToMillis("1 day"); + assertEquals(86400000, millis.longValue()); } @Test public void testParseStringScheduleStartTime() { Instant before = Instant.now(); - IntervalSchedule result = IntervalSchedulerParser.parse("30 minutes"); + IntervalSchedule schedule = IntervalSchedulerParser.parse("30 minutes"); Instant after = Instant.now(); - assertTrue(result.getStartTime().isAfter(before) || result.getStartTime().equals(before)); - assertTrue(result.getStartTime().isBefore(after) || result.getStartTime().equals(after)); + assertTrue(schedule.getStartTime().isAfter(before) || schedule.getStartTime().equals(before)); + assertTrue(schedule.getStartTime().isBefore(after) || schedule.getStartTime().equals(after)); } } \ No newline at end of file diff --git a/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala b/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala index ee8a52d96..b675265b7 100644 --- a/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala +++ b/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala @@ -9,8 +9,8 @@ import org.opensearch.flint.spark.FlintSparkExtensions import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation -import org.apache.spark.sql.flint.config.FlintConfigEntry -import org.apache.spark.sql.flint.config.FlintSparkConf.HYBRID_SCAN_ENABLED +import org.apache.spark.sql.flint.config.{FlintConfigEntry, FlintSparkConf} +import org.apache.spark.sql.flint.config.FlintSparkConf.{EXTERNAL_SCHEDULER_ENABLED, HYBRID_SCAN_ENABLED, METADATA_CACHE_WRITE} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -26,6 +26,10 @@ trait FlintSuite extends SharedSparkSession { // ConstantPropagation etc. 
.set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName) .set("spark.sql.extensions", classOf[FlintSparkExtensions].getName) + // Override scheduler class for unit testing + .set( + FlintSparkConf.CUSTOM_FLINT_SCHEDULER_CLASS.key, + "org.opensearch.flint.core.scheduler.AsyncQuerySchedulerBuilderTest$AsyncQuerySchedulerForLocalTest") conf } @@ -44,4 +48,22 @@ trait FlintSuite extends SharedSparkSession { setFlintSparkConf(HYBRID_SCAN_ENABLED, "false") } } + + protected def withExternalSchedulerEnabled(block: => Unit): Unit = { + setFlintSparkConf(EXTERNAL_SCHEDULER_ENABLED, "true") + try { + block + } finally { + setFlintSparkConf(EXTERNAL_SCHEDULER_ENABLED, "false") + } + } + + protected def withMetadataCacheWriteEnabled(block: => Unit): Unit = { + setFlintSparkConf(METADATA_CACHE_WRITE, "true") + try { + block + } finally { + setFlintSparkConf(METADATA_CACHE_WRITE, "false") + } + } } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexBuilderSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexBuilderSuite.scala index 063c32074..80b788253 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexBuilderSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexBuilderSuite.scala @@ -208,14 +208,14 @@ class FlintSparkIndexBuilderSuite None, None, Some( - "spark.flint.job.externalScheduler.enabled is false but refresh interval is set to external scheduler mode")), + "spark.flint.job.externalScheduler.enabled is false but scheduler_mode is set to external")), ( - "set external mode when interval above threshold and no mode specified", + "set internal mode when interval below threshold and no mode specified", true, "5 minutes", - Map("auto_refresh" -> "true", "refresh_interval" -> "10 minutes"), - Some(SchedulerMode.EXTERNAL.toString), - Some("10 minutes"), + Map("auto_refresh" -> "true", "refresh_interval" -> "1 minutes"), + Some(SchedulerMode.INTERNAL.toString), + Some("1 minutes"), None), ( "throw exception when interval below threshold but mode is external", diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexFactorySuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexFactorySuite.scala new file mode 100644 index 000000000..07720ff24 --- /dev/null +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexFactorySuite.scala @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark + +import org.opensearch.flint.core.storage.FlintOpenSearchIndexMetadataService +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE +import org.scalatest.matchers.should.Matchers._ + +import org.apache.spark.FlintSuite + +class FlintSparkIndexFactorySuite extends FlintSuite { + + test("create mv should generate source tables if missing in metadata") { + val testTable = "spark_catalog.default.mv_build_test" + val testMvName = "spark_catalog.default.mv" + val testQuery = s"SELECT * FROM $testTable" + + val content = + s""" { + | "_meta": { + | "kind": "$MV_INDEX_TYPE", + | "indexedColumns": [ + | { + | "columnType": "int", + | "columnName": "age" + | } + | ], + | "name": "$testMvName", + | "source": "$testQuery" + | }, + | "properties": { + | "age": { + | 
"type": "integer" + | } + | } + | } + |""".stripMargin + + val metadata = FlintOpenSearchIndexMetadataService.deserialize(content) + val index = FlintSparkIndexFactory.create(spark, metadata) + index shouldBe defined + index.get + .asInstanceOf[FlintSparkMaterializedView] + .sourceTables should contain theSameElementsAs Array(testTable) + } +} diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndexSuite.scala index 2c5518778..96c71d94b 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndexSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndexSuite.scala @@ -246,6 +246,8 @@ class ApplyFlintSparkCoveringIndexSuite extends FlintSuite with Matchers { new FlintMetadataLogEntry( "id", 0, + 0, + 0, state, Map("seqNo" -> 0, "primaryTerm" -> 0), "", diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCacheSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCacheSuite.scala new file mode 100644 index 000000000..6ec6cf696 --- /dev/null +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCacheSuite.scala @@ -0,0 +1,150 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.metadatacache + +import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry +import org.opensearch.flint.core.storage.FlintOpenSearchIndexMetadataService +import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.COVERING_INDEX_TYPE +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE +import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.SKIPPING_INDEX_TYPE +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +class FlintMetadataCacheSuite extends AnyFlatSpec with Matchers { + val flintMetadataLogEntry = FlintMetadataLogEntry( + "id", + 0L, + 0L, + 1234567890123L, + FlintMetadataLogEntry.IndexState.ACTIVE, + Map.empty[String, Any], + "", + Map.empty[String, Any]) + + it should "construct from skipping index FlintMetadata" in { + val content = + s""" { + | "_meta": { + | "kind": "$SKIPPING_INDEX_TYPE", + | "source": "spark_catalog.default.test_table", + | "options": { + | "auto_refresh": "true", + | "refresh_interval": "10 Minutes" + | } + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin + val metadata = FlintOpenSearchIndexMetadataService + .deserialize(content) + .copy(latestLogEntry = Some(flintMetadataLogEntry)) + + val metadataCache = FlintMetadataCache(metadata) + metadataCache.metadataCacheVersion shouldBe FlintMetadataCache.metadataCacheVersion + metadataCache.refreshInterval.get shouldBe 600 + metadataCache.sourceTables shouldBe Array("spark_catalog.default.test_table") + metadataCache.lastRefreshTime.get shouldBe 1234567890123L + } + + it should "construct from covering index FlintMetadata" in { + val content = + s""" { + | "_meta": { + | "kind": "$COVERING_INDEX_TYPE", + | "source": "spark_catalog.default.test_table", + | "options": { + | "auto_refresh": "true", + | "refresh_interval": "10 Minutes" + | } + | }, + | "properties": { + | "age": { + | 
"type": "integer" + | } + | } + | } + |""".stripMargin + val metadata = FlintOpenSearchIndexMetadataService + .deserialize(content) + .copy(latestLogEntry = Some(flintMetadataLogEntry)) + + val metadataCache = FlintMetadataCache(metadata) + metadataCache.metadataCacheVersion shouldBe FlintMetadataCache.metadataCacheVersion + metadataCache.refreshInterval.get shouldBe 600 + metadataCache.sourceTables shouldBe Array("spark_catalog.default.test_table") + metadataCache.lastRefreshTime.get shouldBe 1234567890123L + } + + it should "construct from materialized view FlintMetadata" in { + val content = + s""" { + | "_meta": { + | "kind": "$MV_INDEX_TYPE", + | "source": "spark_catalog.default.wrong_table", + | "options": { + | "auto_refresh": "true", + | "refresh_interval": "10 Minutes" + | }, + | "properties": { + | "sourceTables": [ + | "spark_catalog.default.test_table", + | "spark_catalog.default.another_table" + | ] + | } + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin + val metadata = FlintOpenSearchIndexMetadataService + .deserialize(content) + .copy(latestLogEntry = Some(flintMetadataLogEntry)) + + val metadataCache = FlintMetadataCache(metadata) + metadataCache.metadataCacheVersion shouldBe FlintMetadataCache.metadataCacheVersion + metadataCache.refreshInterval.get shouldBe 600 + metadataCache.sourceTables shouldBe Array( + "spark_catalog.default.test_table", + "spark_catalog.default.another_table") + metadataCache.lastRefreshTime.get shouldBe 1234567890123L + } + + it should "construct from FlintMetadata excluding invalid fields" in { + // Set auto_refresh = false and lastRefreshCompleteTime = 0 + val content = + s""" { + | "_meta": { + | "kind": "$SKIPPING_INDEX_TYPE", + | "source": "spark_catalog.default.test_table", + | "options": { + | "refresh_interval": "10 Minutes" + | } + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin + val metadata = FlintOpenSearchIndexMetadataService + .deserialize(content) + .copy(latestLogEntry = Some(flintMetadataLogEntry.copy(lastRefreshCompleteTime = 0L))) + + val metadataCache = FlintMetadataCache(metadata) + metadataCache.metadataCacheVersion shouldBe FlintMetadataCache.metadataCacheVersion + metadataCache.refreshInterval shouldBe empty + metadataCache.sourceTables shouldBe Array("spark_catalog.default.test_table") + metadataCache.lastRefreshTime shouldBe empty + } +} diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala index c1df42883..1c9a9e83c 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala @@ -10,7 +10,7 @@ import scala.collection.JavaConverters.{mapAsJavaMapConverter, mapAsScalaMapConv import org.opensearch.flint.spark.FlintSparkIndexOptions import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE import org.opensearch.flint.spark.mv.FlintSparkMaterializedViewSuite.{streamingRelation, StreamingDslLogicalPlan} -import org.scalatest.matchers.should.Matchers.{contain, convertToAnyShouldWrapper, the} +import org.scalatest.matchers.should.Matchers._ import org.scalatestplus.mockito.MockitoSugar.mock import org.apache.spark.FlintSuite @@ -37,31 +37,34 @@ class 
FlintSparkMaterializedViewSuite extends FlintSuite { val testQuery = "SELECT 1" test("get mv name") { - val mv = FlintSparkMaterializedView(testMvName, testQuery, Map.empty) + val mv = FlintSparkMaterializedView(testMvName, testQuery, Array.empty, Map.empty) mv.name() shouldBe "flint_spark_catalog_default_mv" } test("get mv name with dots") { val testMvNameDots = "spark_catalog.default.mv.2023.10" - val mv = FlintSparkMaterializedView(testMvNameDots, testQuery, Map.empty) + val mv = FlintSparkMaterializedView(testMvNameDots, testQuery, Array.empty, Map.empty) mv.name() shouldBe "flint_spark_catalog_default_mv.2023.10" } test("should fail if get name with unqualified MV name") { the[IllegalArgumentException] thrownBy - FlintSparkMaterializedView("mv", testQuery, Map.empty).name() + FlintSparkMaterializedView("mv", testQuery, Array.empty, Map.empty).name() the[IllegalArgumentException] thrownBy - FlintSparkMaterializedView("default.mv", testQuery, Map.empty).name() + FlintSparkMaterializedView("default.mv", testQuery, Array.empty, Map.empty).name() } test("get metadata") { - val mv = FlintSparkMaterializedView(testMvName, testQuery, Map("test_col" -> "integer")) + val mv = + FlintSparkMaterializedView(testMvName, testQuery, Array.empty, Map("test_col" -> "integer")) val metadata = mv.metadata() metadata.name shouldBe mv.mvName metadata.kind shouldBe MV_INDEX_TYPE metadata.source shouldBe "SELECT 1" + metadata.properties should contain key "sourceTables" + metadata.properties.get("sourceTables").asInstanceOf[Array[String]] should have size 0 metadata.indexedColumns shouldBe Array( Map("columnName" -> "test_col", "columnType" -> "integer").asJava) metadata.schema shouldBe Map("test_col" -> Map("type" -> "integer").asJava).asJava @@ -74,6 +77,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { val mv = FlintSparkMaterializedView( testMvName, testQuery, + Array.empty, Map("test_col" -> "integer"), indexOptions) @@ -83,12 +87,12 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { } test("build batch data frame") { - val mv = FlintSparkMaterializedView(testMvName, testQuery, Map.empty) + val mv = FlintSparkMaterializedView(testMvName, testQuery, Array.empty, Map.empty) mv.build(spark, None).collect() shouldBe Array(Row(1)) } test("should fail if build given other source data frame") { - val mv = FlintSparkMaterializedView(testMvName, testQuery, Map.empty) + val mv = FlintSparkMaterializedView(testMvName, testQuery, Array.empty, Map.empty) the[IllegalArgumentException] thrownBy mv.build(spark, Some(mock[DataFrame])) } @@ -103,7 +107,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { |""".stripMargin val options = Map("watermark_delay" -> "30 Seconds") - withAggregateMaterializedView(testQuery, options) { actualPlan => + withAggregateMaterializedView(testQuery, Array(testTable), options) { actualPlan => comparePlans( actualPlan, streamingRelation(testTable) @@ -128,7 +132,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { |""".stripMargin val options = Map("watermark_delay" -> "30 Seconds") - withAggregateMaterializedView(testQuery, options) { actualPlan => + withAggregateMaterializedView(testQuery, Array(testTable), options) { actualPlan => comparePlans( actualPlan, streamingRelation(testTable) @@ -144,7 +148,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { test("build stream with non-aggregate query") { val testQuery = s"SELECT name, age FROM $testTable WHERE age > 30" - withAggregateMaterializedView(testQuery, Map.empty) { actualPlan 
=> + withAggregateMaterializedView(testQuery, Array(testTable), Map.empty) { actualPlan => comparePlans( actualPlan, streamingRelation(testTable) @@ -158,7 +162,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { val testQuery = s"SELECT name, age FROM $testTable" val options = Map("extra_options" -> s"""{"$testTable": {"maxFilesPerTrigger": "1"}}""") - withAggregateMaterializedView(testQuery, options) { actualPlan => + withAggregateMaterializedView(testQuery, Array(testTable), options) { actualPlan => comparePlans( actualPlan, streamingRelation(testTable, Map("maxFilesPerTrigger" -> "1")) @@ -175,6 +179,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { val mv = FlintSparkMaterializedView( testMvName, s"SELECT name, COUNT(*) AS count FROM $testTable GROUP BY name", + Array(testTable), Map.empty) the[IllegalStateException] thrownBy @@ -182,14 +187,20 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { } } - private def withAggregateMaterializedView(query: String, options: Map[String, String])( - codeBlock: LogicalPlan => Unit): Unit = { + private def withAggregateMaterializedView( + query: String, + sourceTables: Array[String], + options: Map[String, String])(codeBlock: LogicalPlan => Unit): Unit = { withTable(testTable) { sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") - val mv = - FlintSparkMaterializedView(testMvName, query, Map.empty, FlintSparkIndexOptions(options)) + FlintSparkMaterializedView( + testMvName, + query, + sourceTables, + Map.empty, + FlintSparkIndexOptions(options)) val actualPlan = mv.buildStream(spark).queryExecution.logical codeBlock(actualPlan) diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndexSuite.scala index f03116de9..d56c4e66f 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndexSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndexSuite.scala @@ -132,6 +132,8 @@ class ApplyFlintSparkSkippingIndexSuite extends FlintSuite with Matchers { new FlintMetadataLogEntry( "id", 0L, + 0L, + 0L, indexState, Map.empty[String, Any], "", diff --git a/integ-test/src/integration/scala/org/opensearch/flint/core/FlintMetadataLogITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/core/FlintMetadataLogITSuite.scala index 9aeba7512..33702f23f 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/core/FlintMetadataLogITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/core/FlintMetadataLogITSuite.scala @@ -25,10 +25,12 @@ class FlintMetadataLogITSuite extends OpenSearchTransactionSuite with Matchers { val testFlintIndex = "flint_test_index" val testLatestId: String = Base64.getEncoder.encodeToString(testFlintIndex.getBytes) - val testCreateTime = 1234567890123L + val testTimestamp = 1234567890123L val flintMetadataLogEntry = FlintMetadataLogEntry( testLatestId, - testCreateTime, + testTimestamp, + testTimestamp, + testTimestamp, ACTIVE, Map("seqNo" -> UNASSIGNED_SEQ_NO, "primaryTerm" -> UNASSIGNED_PRIMARY_TERM), "", @@ -85,8 +87,10 @@ class FlintMetadataLogITSuite extends OpenSearchTransactionSuite with Matchers { val latest = metadataLog.get.getLatest latest.isPresent shouldBe true latest.get.id shouldBe testLatestId - latest.get.createTime shouldBe 
testCreateTime + latest.get.createTime shouldBe testTimestamp latest.get.error shouldBe "" + latest.get.lastRefreshStartTime shouldBe testTimestamp + latest.get.lastRefreshCompleteTime shouldBe testTimestamp latest.get.properties.get("dataSourceName").get shouldBe testDataSourceName } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/core/FlintTransactionITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/core/FlintTransactionITSuite.scala index 605e8e7fd..df5f4eec2 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/core/FlintTransactionITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/core/FlintTransactionITSuite.scala @@ -24,6 +24,7 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { val testFlintIndex = "flint_test_index" val testLatestId: String = Base64.getEncoder.encodeToString(testFlintIndex.getBytes) + val testTimestamp = 1234567890123L var flintMetadataLogService: FlintMetadataLogService = _ override def beforeAll(): Unit = { @@ -40,6 +41,8 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { latest.id shouldBe testLatestId latest.state shouldBe EMPTY latest.createTime shouldBe 0L + latest.lastRefreshStartTime shouldBe 0L + latest.lastRefreshCompleteTime shouldBe 0L latest.error shouldBe "" latest.properties.get("dataSourceName").get shouldBe testDataSourceName true @@ -49,11 +52,12 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { } test("should preserve original values when transition") { - val testCreateTime = 1234567890123L createLatestLogEntry( FlintMetadataLogEntry( id = testLatestId, - createTime = testCreateTime, + createTime = testTimestamp, + lastRefreshStartTime = testTimestamp, + lastRefreshCompleteTime = testTimestamp, state = ACTIVE, Map("seqNo" -> UNASSIGNED_SEQ_NO, "primaryTerm" -> UNASSIGNED_PRIMARY_TERM), error = "", @@ -63,8 +67,10 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { .startTransaction(testFlintIndex) .initialLog(latest => { latest.id shouldBe testLatestId - latest.createTime shouldBe testCreateTime + latest.createTime shouldBe testTimestamp latest.error shouldBe "" + latest.lastRefreshStartTime shouldBe testTimestamp + latest.lastRefreshCompleteTime shouldBe testTimestamp latest.properties.get("dataSourceName").get shouldBe testDataSourceName true }) @@ -72,8 +78,10 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { .finalLog(latest => latest.copy(state = DELETED)) .commit(latest => { latest.id shouldBe testLatestId - latest.createTime shouldBe testCreateTime + latest.createTime shouldBe testTimestamp latest.error shouldBe "" + latest.lastRefreshStartTime shouldBe testTimestamp + latest.lastRefreshCompleteTime shouldBe testTimestamp latest.properties.get("dataSourceName").get shouldBe testDataSourceName }) @@ -112,7 +120,9 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { createLatestLogEntry( FlintMetadataLogEntry( id = testLatestId, - createTime = 1234567890123L, + createTime = testTimestamp, + lastRefreshStartTime = testTimestamp, + lastRefreshCompleteTime = testTimestamp, state = ACTIVE, Map("seqNo" -> UNASSIGNED_SEQ_NO, "primaryTerm" -> UNASSIGNED_PRIMARY_TERM), error = "", @@ -198,7 +208,9 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { createLatestLogEntry( FlintMetadataLogEntry( id = testLatestId, - createTime = 1234567890123L, + 
createTime = testTimestamp, + lastRefreshStartTime = testTimestamp, + lastRefreshCompleteTime = testTimestamp, state = ACTIVE, Map("seqNo" -> UNASSIGNED_SEQ_NO, "primaryTerm" -> UNASSIGNED_PRIMARY_TERM), error = "", @@ -240,6 +252,8 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { latest.id shouldBe testLatestId latest.state shouldBe EMPTY latest.createTime shouldBe 0L + latest.lastRefreshStartTime shouldBe 0L + latest.lastRefreshCompleteTime shouldBe 0L latest.error shouldBe "" latest.properties.get("dataSourceName").get shouldBe testDataSourceName true diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala index c2f0f9101..fc77faaea 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala @@ -17,10 +17,10 @@ import org.opensearch.flint.common.FlintVersion.current import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.storage.{FlintOpenSearchIndexMetadataService, OpenSearchClientUtils} import org.opensearch.flint.spark.FlintSparkIndex.quotedTableName -import org.opensearch.flint.spark.FlintSparkIndexOptions.OptionName.CHECKPOINT_LOCATION -import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.getFlintIndexName +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.{extractSourceTableNames, getFlintIndexName} import org.opensearch.flint.spark.scheduler.OpenSearchAsyncQueryScheduler -import org.scalatest.matchers.must.Matchers.defined +import org.scalatest.matchers.must.Matchers._ import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper import org.apache.spark.sql.{DataFrame, Row} @@ -52,6 +52,29 @@ class FlintSparkMaterializedViewITSuite extends FlintSparkSuite { deleteTestIndex(testFlintIndex) } + test("extract source table names from materialized view source query successfully") { + val testComplexQuery = s""" + | SELECT * + | FROM ( + | SELECT 1 + | FROM table1 + | LEFT JOIN `table2` + | ) + | UNION ALL + | SELECT 1 + | FROM spark_catalog.default.`table/3` + | INNER JOIN spark_catalog.default.`table.4` + |""".stripMargin + extractSourceTableNames(flint.spark, testComplexQuery) should contain theSameElementsAs + Array( + "spark_catalog.default.table1", + "spark_catalog.default.table2", + "spark_catalog.default.`table/3`", + "spark_catalog.default.`table.4`") + + extractSourceTableNames(flint.spark, "SELECT 1") should have size 0 + } + test("create materialized view with metadata successfully") { withTempDir { checkpointDir => val indexOptions = @@ -92,7 +115,9 @@ class FlintSparkMaterializedViewITSuite extends FlintSparkSuite { | "scheduler_mode":"internal" | }, | "latestId": "$testLatestId", - | "properties": {} + | "properties": { + | "sourceTables": ["$testTable"] + | } | }, | "properties": { | "startTime": { @@ -108,6 +133,22 @@ class FlintSparkMaterializedViewITSuite extends FlintSparkSuite { } } + test("create materialized view should parse source tables successfully") { + val indexOptions = FlintSparkIndexOptions(Map.empty) + flint + .materializedView() + .name(testMvName) + .query(testQuery) + .options(indexOptions, testFlintIndex) + .create() + + val index = flint.describeIndex(testFlintIndex) + index 
shouldBe defined + index.get + .asInstanceOf[FlintSparkMaterializedView] + .sourceTables should contain theSameElementsAs Array(testTable) + } + test("create materialized view with default checkpoint location successfully") { withTempDir { checkpointDir => setFlintSparkConf( diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala index 1ecf48d28..c53eee548 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala @@ -5,7 +5,7 @@ package org.opensearch.flint.spark -import java.nio.file.{Files, Paths} +import java.nio.file.{Files, Path, Paths} import java.util.Comparator import java.util.concurrent.{ScheduledExecutorService, ScheduledFuture} @@ -23,6 +23,7 @@ import org.scalatestplus.mockito.MockitoSugar.mock import org.apache.spark.{FlintSuite, SparkConf} import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.flint.config.FlintSparkConf.{CHECKPOINT_MANDATORY, HOST_ENDPOINT, HOST_PORT, REFRESH_POLICY} import org.apache.spark.sql.streaming.StreamTest @@ -49,6 +50,8 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit override def beforeAll(): Unit = { super.beforeAll() + // Revoke override in FlintSuite on IT + conf.unsetConf(FlintSparkConf.CUSTOM_FLINT_SCHEDULER_CLASS.key) // Replace executor to avoid impact on IT. // TODO: Currently no IT test scheduler so no need to restore it back. @@ -534,6 +537,28 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit |""".stripMargin) } + protected def createMultiValueStructTable(testTable: String): Unit = { + // CSV doesn't support struct field + sql(s""" + | CREATE TABLE $testTable + | ( + | int_col INT, + | multi_value Array<STRUCT<name: STRING, value: INT>> + | ) + | USING JSON + |""".stripMargin) + + sql(s""" + | INSERT INTO $testTable + | SELECT /*+ COALESCE(1) */ * + | FROM VALUES + | ( 1, array(STRUCT("1_one", 1), STRUCT(null, 11), STRUCT("1_three", null)) ), + | ( 2, array(STRUCT("2_Monday", 2), null) ), + | ( 3, array(STRUCT("3_third", 3), STRUCT("3_4th", 4)) ), + | ( 4, null ) + |""".stripMargin) + } + + protected def createTableIssue112(testTable: String): Unit = { sql(s""" | CREATE TABLE $testTable ( @@ -642,4 +667,153 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit | (6, 403, '/home', null) | """.stripMargin) + } + + protected def createNullableJsonContentTable(testTable: String): Unit = { + sql(s""" + | CREATE TABLE $testTable + | ( + | id INT, + | jString STRING, + | isValid BOOLEAN + | ) + | USING $tableType $tableOptions + |""".stripMargin) + + sql(s""" + | INSERT INTO $testTable + | VALUES (1, '{"account_number":1,"balance":39225,"age":32,"gender":"M"}', true), + | (2, '{"f1":"abc","f2":{"f3":"a","f4":"b"}}', true), + | (3, '[1,2,3,{"f1":1,"f2":[5,6]},4]', true), + | (4, '[]', true), + | (5, '{"teacher":"Alice","student":[{"name":"Bob","rank":1},{"name":"Charlie","rank":2}]}', true), + | (6, '[1,2,3]', true), + | (7, '[1,2', false), + | (8, '[invalid json]', false), + | (9, '{"invalid": "json"', false), + | (10, 'invalid json', false), + | (11, null, false) + | """.stripMargin) + } + + protected def createIpAddressTable(testTable: String): Unit = { + sql(s""" + | CREATE TABLE $testTable + | ( + | id INT, + | ipAddress STRING, + | isV6 BOOLEAN, + | isValid BOOLEAN + 
| ) + | USING $tableType $tableOptions + |""".stripMargin) + + sql(s""" + | INSERT INTO $testTable + | VALUES (1, '127.0.0.1', false, true), + | (2, '192.168.1.0', false, true), + | (3, '192.168.1.1', false, true), + | (4, '192.168.2.1', false, true), + | (5, '192.168.2.', false, false), + | (6, '2001:db8::ff00:12:3455', true, true), + | (7, '2001:db8::ff00:12:3456', true, true), + | (8, '2001:db8::ff00:13:3457', true, true), + | (9, '2001:db8::ff00:12:', true, false) + | """.stripMargin) + } + + protected def createNestedJsonContentTable(tempFile: Path, testTable: String): Unit = { + val json = + """ + |[ + | { + | "_time": "2024-09-13T12:00:00", + | "bridges": [ + | {"name": "Tower Bridge", "length": 801}, + | {"name": "London Bridge", "length": 928} + | ], + | "city": "London", + | "country": "England", + | "coor": { + | "lat": 51.5074, + | "long": -0.1278, + | "alt": 35 + | } + | }, + | { + | "_time": "2024-09-13T12:00:00", + | "bridges": [ + | {"name": "Pont Neuf", "length": 232}, + | {"name": "Pont Alexandre III", "length": 160} + | ], + | "city": "Paris", + | "country": "France", + | "coor": { + | "lat": 48.8566, + | "long": 2.3522, + | "alt": 35 + | } + | }, + | { + | "_time": "2024-09-13T12:00:00", + | "bridges": [ + | {"name": "Rialto Bridge", "length": 48}, + | {"name": "Bridge of Sighs", "length": 11} + | ], + | "city": "Venice", + | "country": "Italy", + | "coor": { + | "lat": 45.4408, + | "long": 12.3155, + | "alt": 2 + | } + | }, + | { + | "_time": "2024-09-13T12:00:00", + | "bridges": [ + | {"name": "Charles Bridge", "length": 516}, + | {"name": "Legion Bridge", "length": 343} + | ], + | "city": "Prague", + | "country": "Czech Republic", + | "coor": { + | "lat": 50.0755, + | "long": 14.4378, + | "alt": 200 + | } + | }, + | { + | "_time": "2024-09-13T12:00:00", + | "bridges": [ + | {"name": "Chain Bridge", "length": 375}, + | {"name": "Liberty Bridge", "length": 333} + | ], + | "city": "Budapest", + | "country": "Hungary", + | "coor": { + | "lat": 47.4979, + | "long": 19.0402, + | "alt": 96 + | } + | }, + | { + | "_time": "1990-09-13T12:00:00", + | "bridges": null, + | "city": "Warsaw", + | "country": "Poland", + | "coor": null + | } + |] + |""".stripMargin + val tempFile = Files.createTempFile("jsonTestData", ".json") + val absolutPath = tempFile.toAbsolutePath.toString; + Files.write(tempFile, json.getBytes) + sql(s""" + | CREATE TEMPORARY VIEW $testTable + | USING org.apache.spark.sql.json + | OPTIONS ( + | path "$absolutPath", + | multiLine true + | ); + |""".stripMargin) + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkTransactionITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkTransactionITSuite.scala index debb95370..e16d40f2a 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkTransactionITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkTransactionITSuite.scala @@ -55,6 +55,8 @@ class FlintSparkTransactionITSuite extends OpenSearchTransactionSuite with Match latestLogEntry(testLatestId) should (contain("latestId" -> testLatestId) and contain("state" -> "active") and contain("jobStartTime" -> 0) + and contain("lastRefreshStartTime" -> 0) + and contain("lastRefreshCompleteTime" -> 0) and contain("dataSourceName" -> testDataSourceName)) implicit val formats: Formats = Serialization.formats(NoTypeHints) @@ -77,9 +79,25 @@ class FlintSparkTransactionITSuite extends OpenSearchTransactionSuite with Match .create() 
flint.refreshIndex(testFlintIndex) - val latest = latestLogEntry(testLatestId) + var latest = latestLogEntry(testLatestId) + val prevJobStartTime = latest("jobStartTime").asInstanceOf[Number].longValue() + val prevLastRefreshStartTime = latest("lastRefreshStartTime").asInstanceOf[Number].longValue() + val prevLastRefreshCompleteTime = + latest("lastRefreshCompleteTime").asInstanceOf[Number].longValue() latest should contain("state" -> "active") - latest("jobStartTime").asInstanceOf[Number].longValue() should be > 0L + prevJobStartTime should be > 0L + prevLastRefreshStartTime should be > 0L + prevLastRefreshCompleteTime should be > prevLastRefreshStartTime + + flint.refreshIndex(testFlintIndex) + latest = latestLogEntry(testLatestId) + val jobStartTime = latest("jobStartTime").asInstanceOf[Number].longValue() + val lastRefreshStartTime = latest("lastRefreshStartTime").asInstanceOf[Number].longValue() + val lastRefreshCompleteTime = + latest("lastRefreshCompleteTime").asInstanceOf[Number].longValue() + jobStartTime should be > prevLastRefreshCompleteTime + lastRefreshStartTime should be > prevLastRefreshCompleteTime + lastRefreshCompleteTime should be > lastRefreshStartTime } test("incremental refresh index") { @@ -97,9 +115,26 @@ class FlintSparkTransactionITSuite extends OpenSearchTransactionSuite with Match .create() flint.refreshIndex(testFlintIndex) - val latest = latestLogEntry(testLatestId) + var latest = latestLogEntry(testLatestId) + val prevJobStartTime = latest("jobStartTime").asInstanceOf[Number].longValue() + val prevLastRefreshStartTime = + latest("lastRefreshStartTime").asInstanceOf[Number].longValue() + val prevLastRefreshCompleteTime = + latest("lastRefreshCompleteTime").asInstanceOf[Number].longValue() latest should contain("state" -> "active") - latest("jobStartTime").asInstanceOf[Number].longValue() should be > 0L + prevJobStartTime should be > 0L + prevLastRefreshStartTime should be > 0L + prevLastRefreshCompleteTime should be > prevLastRefreshStartTime + + flint.refreshIndex(testFlintIndex) + latest = latestLogEntry(testLatestId) + val jobStartTime = latest("jobStartTime").asInstanceOf[Number].longValue() + val lastRefreshStartTime = latest("lastRefreshStartTime").asInstanceOf[Number].longValue() + val lastRefreshCompleteTime = + latest("lastRefreshCompleteTime").asInstanceOf[Number].longValue() + jobStartTime should be > prevLastRefreshCompleteTime + lastRefreshStartTime should be > prevLastRefreshCompleteTime + lastRefreshCompleteTime should be > lastRefreshStartTime } } @@ -142,6 +177,8 @@ class FlintSparkTransactionITSuite extends OpenSearchTransactionSuite with Match val latest = latestLogEntry(testLatestId) latest should contain("state" -> "refreshing") latest("jobStartTime").asInstanceOf[Number].longValue() should be > 0L + latest("lastRefreshStartTime").asInstanceOf[Number].longValue() shouldBe 0L + latest("lastRefreshCompleteTime").asInstanceOf[Number].longValue() shouldBe 0L } test("update auto refresh index to full refresh index") { @@ -153,13 +190,24 @@ class FlintSparkTransactionITSuite extends OpenSearchTransactionSuite with Match .create() flint.refreshIndex(testFlintIndex) + var latest = latestLogEntry(testLatestId) + val prevLastRefreshStartTime = latest("lastRefreshStartTime").asInstanceOf[Number].longValue() + val prevLastRefreshCompleteTime = + latest("lastRefreshCompleteTime").asInstanceOf[Number].longValue() + val index = flint.describeIndex(testFlintIndex).get val updatedIndex = flint .skippingIndex() .copyWithUpdate(index, 
FlintSparkIndexOptions(Map("auto_refresh" -> "false"))) flint.updateIndex(updatedIndex) - val latest = latestLogEntry(testLatestId) + latest = latestLogEntry(testLatestId) latest should contain("state" -> "active") + latest("lastRefreshStartTime") + .asInstanceOf[Number] + .longValue() shouldBe prevLastRefreshStartTime + latest("lastRefreshCompleteTime") + .asInstanceOf[Number] + .longValue() shouldBe prevLastRefreshCompleteTime } test("delete and vacuum index") { diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkUpdateIndexITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkUpdateIndexITSuite.scala index 53889045f..c9f6c47f7 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkUpdateIndexITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkUpdateIndexITSuite.scala @@ -5,10 +5,15 @@ package org.opensearch.flint.spark +import scala.jdk.CollectionConverters.mapAsJavaMapConverter + import com.stephenn.scalatest.jsonassert.JsonMatchers.matchJson import org.json4s.native.JsonMethods._ +import org.opensearch.action.get.GetRequest import org.opensearch.client.RequestOptions -import org.opensearch.flint.core.storage.FlintOpenSearchIndexMetadataService +import org.opensearch.flint.core.FlintOptions +import org.opensearch.flint.core.storage.{FlintOpenSearchIndexMetadataService, OpenSearchClientUtils} +import org.opensearch.flint.spark.scheduler.OpenSearchAsyncQueryScheduler import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName import org.opensearch.index.query.QueryBuilders import org.opensearch.index.reindex.DeleteByQueryRequest @@ -180,6 +185,91 @@ class FlintSparkUpdateIndexITSuite extends FlintSparkSuite { } } + test("update auto refresh index to switch scheduler mode") { + setFlintSparkConf(FlintSparkConf.EXTERNAL_SCHEDULER_ENABLED, "true") + + withTempDir { checkpointDir => + // Create auto refresh Flint index + flint + .skippingIndex() + .onTable(testTable) + .addPartitions("year", "month") + .options( + FlintSparkIndexOptions( + Map( + "auto_refresh" -> "true", + "refresh_interval" -> "4 Minute", + "checkpoint_location" -> checkpointDir.getAbsolutePath)), + testIndex) + .create() + flint.refreshIndex(testIndex) + + val indexInitial = flint.describeIndex(testIndex).get + indexInitial.options.refreshInterval() shouldBe Some("4 Minute") + indexInitial.options.isExternalSchedulerEnabled() shouldBe false + + // Update Flint index to change refresh interval + val updatedIndex = flint + .skippingIndex() + .copyWithUpdate( + indexInitial, + FlintSparkIndexOptions( + Map("scheduler_mode" -> "external", "refresh_interval" -> "5 Minutes"))) + flint.updateIndex(updatedIndex) + + // Verify index after update + val indexFinal = flint.describeIndex(testIndex).get + indexFinal.options.autoRefresh() shouldBe true + indexFinal.options.refreshInterval() shouldBe Some("5 Minutes") + indexFinal.options.isExternalSchedulerEnabled() shouldBe true + indexFinal.options.checkpointLocation() shouldBe Some(checkpointDir.getAbsolutePath) + + // Verify scheduler index is updated + verifySchedulerIndex(testIndex, 5, "MINUTES") + } + } + + test("update auto refresh index to change refresh interval") { + setFlintSparkConf(FlintSparkConf.EXTERNAL_SCHEDULER_ENABLED, "true") + + withTempDir { checkpointDir => + // Create auto refresh Flint index + flint + .skippingIndex() + .onTable(testTable) + .addPartitions("year", "month") + .options( + 
FlintSparkIndexOptions( + Map( + "auto_refresh" -> "true", + "refresh_interval" -> "10 Minute", + "checkpoint_location" -> checkpointDir.getAbsolutePath)), + testIndex) + .create() + + val indexInitial = flint.describeIndex(testIndex).get + indexInitial.options.refreshInterval() shouldBe Some("10 Minute") + verifySchedulerIndex(testIndex, 10, "MINUTES") + + // Update Flint index to change refresh interval + val updatedIndex = flint + .skippingIndex() + .copyWithUpdate( + indexInitial, + FlintSparkIndexOptions(Map("refresh_interval" -> "5 Minutes"))) + flint.updateIndex(updatedIndex) + + // Verify index after update + val indexFinal = flint.describeIndex(testIndex).get + indexFinal.options.autoRefresh() shouldBe true + indexFinal.options.refreshInterval() shouldBe Some("5 Minutes") + indexFinal.options.checkpointLocation() shouldBe Some(checkpointDir.getAbsolutePath) + + // Verify scheduler index is updated + verifySchedulerIndex(testIndex, 5, "MINUTES") + } + } + // Test update options validation failure with external scheduler Seq( ( @@ -207,12 +297,32 @@ class FlintSparkUpdateIndexITSuite extends FlintSparkSuite { (Map.empty[String, String], Map("checkpoint_location" -> "s3a://test/"))), "No options can be updated when auto_refresh remains false"), ( - "update other index option besides scheduler_mode when auto_refresh is true", + "update index option when refresh_interval value belows threshold", + Seq( + ( + Map("auto_refresh" -> "true", "checkpoint_location" -> "s3a://test/"), + Map("refresh_interval" -> "4 minutes"))), + "Input refresh_interval is 4 minutes, required above the interval threshold of external scheduler: 5 minutes"), + ( + "update index option when no change on auto_refresh", + Seq( + ( + Map("auto_refresh" -> "true", "checkpoint_location" -> "s3a://test/"), + Map("scheduler_mode" -> "internal", "refresh_interval" -> "4 minutes")), + ( + Map( + "auto_refresh" -> "true", + "scheduler_mode" -> "internal", + "checkpoint_location" -> "s3a://test/"), + Map("refresh_interval" -> "4 minutes"))), + "Altering index when auto_refresh remains true and scheduler_mode is internal only allows changing: Set(scheduler_mode). Invalid options"), + ( + "update other index option besides scheduler_mode and refresh_interval when auto_refresh is true", Seq( ( Map("auto_refresh" -> "true", "checkpoint_location" -> "s3a://test/"), Map("watermark_delay" -> "1 Minute"))), - "Altering index when auto_refresh remains true only allows changing: Set(scheduler_mode). Invalid options"), + "Altering index when auto_refresh remains true and scheduler_mode is external only allows changing: Set(scheduler_mode, refresh_interval). 
Invalid options"), ( "convert to full refresh with disallowed options", Seq( @@ -655,4 +765,28 @@ class FlintSparkUpdateIndexITSuite extends FlintSparkSuite { flint.queryIndex(testIndex).collect().toSet should have size 1 } } + + private def verifySchedulerIndex( + indexName: String, + expectedPeriod: Int, + expectedUnit: String): Unit = { + val client = OpenSearchClientUtils.createClient(new FlintOptions(openSearchOptions.asJava)) + val response = client.get( + new GetRequest(OpenSearchAsyncQueryScheduler.SCHEDULER_INDEX_NAME, indexName), + RequestOptions.DEFAULT) + + response.isExists shouldBe true + val sourceMap = response.getSourceAsMap + + sourceMap.get("jobId") shouldBe indexName + sourceMap.get( + "scheduledQuery") shouldBe s"REFRESH SKIPPING INDEX ON spark_catalog.default.`test`" + sourceMap.get("enabled") shouldBe true + sourceMap.get("queryLang") shouldBe "sql" + + val schedule = sourceMap.get("schedule").asInstanceOf[java.util.Map[String, Any]] + val interval = schedule.get("interval").asInstanceOf[java.util.Map[String, Any]] + interval.get("period") shouldBe expectedPeriod + interval.get("unit") shouldBe expectedUnit + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriterITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriterITSuite.scala new file mode 100644 index 000000000..c0d253fd3 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriterITSuite.scala @@ -0,0 +1,457 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.metadatacache + +import java.util.{Base64, List} + +import scala.collection.JavaConverters._ + +import com.stephenn.scalatest.jsonassert.JsonMatchers.matchJson +import org.json4s.native.JsonMethods._ +import org.opensearch.flint.common.FlintVersion.current +import org.opensearch.flint.common.metadata.FlintMetadata +import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry +import org.opensearch.flint.core.FlintOptions +import org.opensearch.flint.core.storage.{FlintOpenSearchClient, FlintOpenSearchIndexMetadataService} +import org.opensearch.flint.spark.{FlintSparkIndexOptions, FlintSparkSuite} +import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.COVERING_INDEX_TYPE +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE +import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, SKIPPING_INDEX_TYPE} +import org.scalatest.Entry +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.sql.flint.config.FlintSparkConf + +class FlintOpenSearchMetadataCacheWriterITSuite extends FlintSparkSuite with Matchers { + + /** Lazy initialize after container started. 
*/ + lazy val options = new FlintOptions(openSearchOptions.asJava) + lazy val flintClient = new FlintOpenSearchClient(options) + lazy val flintMetadataCacheWriter = new FlintOpenSearchMetadataCacheWriter(options) + lazy val flintIndexMetadataService = new FlintOpenSearchIndexMetadataService(options) + + private val testTable = "spark_catalog.default.metadatacache_test" + private val testFlintIndex = getSkippingIndexName(testTable) + private val testLatestId: String = Base64.getEncoder.encodeToString(testFlintIndex.getBytes) + private val testLastRefreshCompleteTime = 1234567890123L + private val flintMetadataLogEntry = FlintMetadataLogEntry( + testLatestId, + 0L, + 0L, + testLastRefreshCompleteTime, + FlintMetadataLogEntry.IndexState.ACTIVE, + Map.empty[String, Any], + "", + Map.empty[String, Any]) + + override def beforeAll(): Unit = { + super.beforeAll() + createPartitionedMultiRowAddressTable(testTable) + } + + override def afterAll(): Unit = { + sql(s"DROP TABLE $testTable") + super.afterAll() + } + + override def afterEach(): Unit = { + deleteTestIndex(testFlintIndex) + super.afterEach() + } + + test("build disabled metadata cache writer") { + FlintMetadataCacheWriterBuilder + .build(FlintSparkConf()) shouldBe a[FlintDisabledMetadataCacheWriter] + } + + test("build opensearch metadata cache writer") { + setFlintSparkConf(FlintSparkConf.METADATA_CACHE_WRITE, "true") + withMetadataCacheWriteEnabled { + FlintMetadataCacheWriterBuilder + .build(FlintSparkConf()) shouldBe a[FlintOpenSearchMetadataCacheWriter] + } + } + + test("serialize metadata cache to JSON") { + val expectedMetadataJson: String = s""" + | { + | "_meta": { + | "version": "${current()}", + | "name": "$testFlintIndex", + | "kind": "test_kind", + | "source": "$testTable", + | "indexedColumns": [ + | { + | "test_field": "spark_type" + | }], + | "options": { + | "auto_refresh": "true", + | "refresh_interval": "10 Minutes" + | }, + | "properties": { + | "metadataCacheVersion": "${FlintMetadataCache.metadataCacheVersion}", + | "refreshInterval": 600, + | "sourceTables": ["$testTable"], + | "lastRefreshTime": $testLastRefreshCompleteTime + | }, + | "latestId": "$testLatestId" + | }, + | "properties": { + | "test_field": { + | "type": "os_type" + | } + | } + | } + |""".stripMargin + val builder = new FlintMetadata.Builder + builder.name(testFlintIndex) + builder.kind("test_kind") + builder.source(testTable) + builder.addIndexedColumn(Map[String, AnyRef]("test_field" -> "spark_type").asJava) + builder.options( + Map("auto_refresh" -> "true", "refresh_interval" -> "10 Minutes") + .mapValues(_.asInstanceOf[AnyRef]) + .asJava) + builder.schema(Map[String, AnyRef]("test_field" -> Map("type" -> "os_type").asJava).asJava) + builder.latestLogEntry(flintMetadataLogEntry) + + val metadata = builder.build() + flintMetadataCacheWriter.serialize(metadata) should matchJson(expectedMetadataJson) + } + + test("write metadata cache to index mappings") { + val metadata = FlintOpenSearchIndexMetadataService + .deserialize("{}") + .copy(latestLogEntry = Some(flintMetadataLogEntry)) + flintClient.createIndex(testFlintIndex, metadata) + flintMetadataCacheWriter.updateMetadataCache(testFlintIndex, metadata) + + val properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties + properties should have size 3 + properties should contain allOf (Entry( + "metadataCacheVersion", + FlintMetadataCache.metadataCacheVersion), + Entry("lastRefreshTime", testLastRefreshCompleteTime)) + } + + Seq(SKIPPING_INDEX_TYPE, 
COVERING_INDEX_TYPE).foreach { case kind => + test(s"write metadata cache to $kind index mappings with source tables") { + val content = + s""" { + | "_meta": { + | "kind": "$kind", + | "source": "$testTable" + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin + val metadata = FlintOpenSearchIndexMetadataService + .deserialize(content) + .copy(latestLogEntry = Some(flintMetadataLogEntry)) + flintClient.createIndex(testFlintIndex, metadata) + flintMetadataCacheWriter.updateMetadataCache(testFlintIndex, metadata) + + val properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties + properties + .get("sourceTables") + .asInstanceOf[List[String]] + .toArray should contain theSameElementsAs Array(testTable) + } + } + + test(s"write metadata cache to materialized view index mappings with source tables") { + val testTable2 = "spark_catalog.default.metadatacache_test2" + val content = + s""" { + | "_meta": { + | "kind": "$MV_INDEX_TYPE", + | "properties": { + | "sourceTables": [ + | "$testTable", "$testTable2" + | ] + | } + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin + val metadata = FlintOpenSearchIndexMetadataService + .deserialize(content) + .copy(latestLogEntry = Some(flintMetadataLogEntry)) + flintClient.createIndex(testFlintIndex, metadata) + flintMetadataCacheWriter.updateMetadataCache(testFlintIndex, metadata) + + val properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties + properties + .get("sourceTables") + .asInstanceOf[List[String]] + .toArray should contain theSameElementsAs Array(testTable, testTable2) + } + + test("write metadata cache to index mappings with refresh interval") { + val content = + """ { + | "_meta": { + | "kind": "test_kind", + | "options": { + | "auto_refresh": "true", + | "refresh_interval": "10 Minutes" + | } + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin + val metadata = FlintOpenSearchIndexMetadataService + .deserialize(content) + .copy(latestLogEntry = Some(flintMetadataLogEntry)) + flintClient.createIndex(testFlintIndex, metadata) + flintMetadataCacheWriter.updateMetadataCache(testFlintIndex, metadata) + + val properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties + properties should have size 4 + properties should contain allOf (Entry( + "metadataCacheVersion", + FlintMetadataCache.metadataCacheVersion), + Entry("refreshInterval", 600), + Entry("lastRefreshTime", testLastRefreshCompleteTime)) + } + + test("exclude refresh interval in metadata cache when auto refresh is false") { + val content = + """ { + | "_meta": { + | "kind": "test_kind", + | "options": { + | "refresh_interval": "10 Minutes" + | } + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin + val metadata = FlintOpenSearchIndexMetadataService + .deserialize(content) + .copy(latestLogEntry = Some(flintMetadataLogEntry)) + flintClient.createIndex(testFlintIndex, metadata) + flintMetadataCacheWriter.updateMetadataCache(testFlintIndex, metadata) + + val properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties + properties should have size 3 + properties should contain allOf (Entry( + "metadataCacheVersion", + FlintMetadataCache.metadataCacheVersion), + Entry("lastRefreshTime", testLastRefreshCompleteTime)) + } + + test("exclude last refresh time in metadata cache when index has not been refreshed") { + 
val metadata = FlintOpenSearchIndexMetadataService + .deserialize("{}") + .copy(latestLogEntry = Some(flintMetadataLogEntry.copy(lastRefreshCompleteTime = 0L))) + flintClient.createIndex(testFlintIndex, metadata) + flintMetadataCacheWriter.updateMetadataCache(testFlintIndex, metadata) + + val properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties + properties should have size 2 + properties should contain( + Entry("metadataCacheVersion", FlintMetadataCache.metadataCacheVersion)) + } + + test("write metadata cache to index mappings and preserve other index metadata") { + val content = + """ { + | "_meta": { + | "kind": "test_kind" + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin + + val metadata = FlintOpenSearchIndexMetadataService + .deserialize(content) + .copy(latestLogEntry = Some(flintMetadataLogEntry)) + flintClient.createIndex(testFlintIndex, metadata) + + flintIndexMetadataService.updateIndexMetadata(testFlintIndex, metadata) + flintMetadataCacheWriter.updateMetadataCache(testFlintIndex, metadata) + + flintIndexMetadataService.getIndexMetadata(testFlintIndex).kind shouldBe "test_kind" + flintIndexMetadataService.getIndexMetadata(testFlintIndex).name shouldBe empty + flintIndexMetadataService.getIndexMetadata(testFlintIndex).schema should have size 1 + var properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties + properties should have size 3 + properties should contain allOf (Entry( + "metadataCacheVersion", + FlintMetadataCache.metadataCacheVersion), + Entry("lastRefreshTime", testLastRefreshCompleteTime)) + + val newContent = + """ { + | "_meta": { + | "kind": "test_kind", + | "name": "test_name" + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin + + val newMetadata = FlintOpenSearchIndexMetadataService + .deserialize(newContent) + .copy(latestLogEntry = Some(flintMetadataLogEntry)) + flintIndexMetadataService.updateIndexMetadata(testFlintIndex, newMetadata) + flintMetadataCacheWriter.updateMetadataCache(testFlintIndex, newMetadata) + + flintIndexMetadataService.getIndexMetadata(testFlintIndex).kind shouldBe "test_kind" + flintIndexMetadataService.getIndexMetadata(testFlintIndex).name shouldBe "test_name" + flintIndexMetadataService.getIndexMetadata(testFlintIndex).schema should have size 1 + properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties + properties should have size 3 + properties should contain allOf (Entry( + "metadataCacheVersion", + FlintMetadataCache.metadataCacheVersion), + Entry("lastRefreshTime", testLastRefreshCompleteTime)) + } + + Seq( + ( + "auto refresh index with external scheduler", + Map( + "auto_refresh" -> "true", + "scheduler_mode" -> "external", + "refresh_interval" -> "10 Minute", + "checkpoint_location" -> "s3a://test/"), + s""" + | { + | "metadataCacheVersion": "${FlintMetadataCache.metadataCacheVersion}", + | "refreshInterval": 600, + | "sourceTables": ["$testTable"] + | } + |""".stripMargin), + ( + "full refresh index", + Map.empty[String, String], + s""" + | { + | "metadataCacheVersion": "${FlintMetadataCache.metadataCacheVersion}", + | "sourceTables": ["$testTable"] + | } + |""".stripMargin), + ( + "incremental refresh index", + Map("incremental_refresh" -> "true", "checkpoint_location" -> "s3a://test/"), + s""" + | { + | "metadataCacheVersion": "${FlintMetadataCache.metadataCacheVersion}", + | "sourceTables": ["$testTable"] + | } + |""".stripMargin)).foreach { case 
(refreshMode, optionsMap, expectedJson) => + test(s"write metadata cache for $refreshMode") { + withExternalSchedulerEnabled { + withMetadataCacheWriteEnabled { + withTempDir { checkpointDir => + // update checkpoint_location if available in optionsMap + val indexOptions = FlintSparkIndexOptions( + optionsMap + .get("checkpoint_location") + .map(_ => + optionsMap.updated("checkpoint_location", checkpointDir.getAbsolutePath)) + .getOrElse(optionsMap)) + + flint + .skippingIndex() + .onTable(testTable) + .addMinMax("age") + .options(indexOptions, testFlintIndex) + .create() + + var index = flint.describeIndex(testFlintIndex) + index shouldBe defined + val propertiesJson = + compact( + render( + parse( + flintMetadataCacheWriter.serialize( + index.get.metadata())) \ "_meta" \ "properties")) + propertiesJson should matchJson(expectedJson) + + flint.refreshIndex(testFlintIndex) + index = flint.describeIndex(testFlintIndex) + index shouldBe defined + val lastRefreshTime = + compact( + render( + parse( + flintMetadataCacheWriter.serialize( + index.get.metadata())) \ "_meta" \ "properties" \ "lastRefreshTime")).toLong + lastRefreshTime should be > 0L + } + } + } + } + } + + test("write metadata cache for auto refresh index with internal scheduler") { + withMetadataCacheWriteEnabled { + withTempDir { checkpointDir => + flint + .skippingIndex() + .onTable(testTable) + .addMinMax("age") + .options( + FlintSparkIndexOptions( + Map( + "auto_refresh" -> "true", + "scheduler_mode" -> "internal", + "refresh_interval" -> "10 Minute", + "checkpoint_location" -> checkpointDir.getAbsolutePath)), + testFlintIndex) + .create() + + var index = flint.describeIndex(testFlintIndex) + index shouldBe defined + val propertiesJson = + compact( + render(parse( + flintMetadataCacheWriter.serialize(index.get.metadata())) \ "_meta" \ "properties")) + propertiesJson should matchJson(s""" + | { + | "metadataCacheVersion": "${FlintMetadataCache.metadataCacheVersion}", + | "refreshInterval": 600, + | "sourceTables": ["$testTable"] + | } + |""".stripMargin) + + flint.refreshIndex(testFlintIndex) + index = flint.describeIndex(testFlintIndex) + index shouldBe defined + compact(render(parse( + flintMetadataCacheWriter.serialize( + index.get.metadata())) \ "_meta" \ "properties")) should not include "lastRefreshTime" + } + } + } +} diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintPPLSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintPPLSuite.scala index 1ece33ce1..465ce7d12 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintPPLSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintPPLSuite.scala @@ -8,11 +8,8 @@ package org.opensearch.flint.spark.ppl import org.opensearch.flint.spark.{FlintPPLSparkExtensions, FlintSparkExtensions, FlintSparkSuite} import org.apache.spark.SparkConf -import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode -import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation +import org.apache.spark.sql.{DataFrame, QueryTest, Row} import org.apache.spark.sql.flint.config.FlintSparkConf.OPTIMIZER_RULE_ENABLED -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SharedSparkSession trait FlintPPLSuite extends FlintSparkSuite { override protected def sparkConf: SparkConf = { @@ -24,4 +21,15 @@ trait FlintPPLSuite extends FlintSparkSuite { .set(OPTIMIZER_RULE_ENABLED.key, "false") conf } + + def assertSameRows(expected: Seq[Row], df: 
DataFrame): Unit = { + QueryTest.sameRows(expected, df.collect().toSeq).foreach { results => + fail(s""" + |Results do not match for query: + |${df.queryExecution} + |== Results == + |$results + """.stripMargin) + } + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala index 55d3d0709..bcfe22764 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala @@ -1122,4 +1122,166 @@ class FlintSparkPPLAggregationsITSuite comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } + + test("test count() at the first of stats clause") { + val frame = sql(s""" + | source = $testTable | eval a = 1 | stats count() as cnt, sum(a) as sum, avg(a) as avg + | """.stripMargin) + assertSameRows(Seq(Row(4, 4, 1.0)), frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val eval = Project(Seq(UnresolvedStar(None), Alias(Literal(1), "a")()), table) + val sum = + Alias( + UnresolvedFunction(Seq("SUM"), Seq(UnresolvedAttribute("a")), isDistinct = false), + "sum")() + val avg = + Alias( + UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("a")), isDistinct = false), + "avg")() + val count = + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false), + "cnt")() + val aggregate = Aggregate(Seq.empty, Seq(count, sum, avg), eval) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test count() in the middle of stats clause") { + val frame = sql(s""" + | source = $testTable | eval a = 1 | stats sum(a) as sum, count() as cnt, avg(a) as avg + | """.stripMargin) + assertSameRows(Seq(Row(4, 4, 1.0)), frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val eval = Project(Seq(UnresolvedStar(None), Alias(Literal(1), "a")()), table) + val sum = + Alias( + UnresolvedFunction(Seq("SUM"), Seq(UnresolvedAttribute("a")), isDistinct = false), + "sum")() + val avg = + Alias( + UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("a")), isDistinct = false), + "avg")() + val count = + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false), + "cnt")() + val aggregate = Aggregate(Seq.empty, Seq(sum, count, avg), eval) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test count() at the end of stats clause") { + val frame = sql(s""" + | source = $testTable | eval a = 1 | stats sum(a) as sum, avg(a) as avg, count() as cnt + | """.stripMargin) + assertSameRows(Seq(Row(4, 1.0, 4)), frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val eval = Project(Seq(UnresolvedStar(None), Alias(Literal(1), "a")()), table) + val sum = + Alias( + UnresolvedFunction(Seq("SUM"), Seq(UnresolvedAttribute("a")), isDistinct = false), + "sum")() + val avg = + Alias( + UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("a")), isDistinct = false), + 
"avg")() + val count = + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false), + "cnt")() + val aggregate = Aggregate(Seq.empty, Seq(sum, avg, count), eval) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test count() at the first of stats by clause") { + val frame = sql(s""" + | source = $testTable | eval a = 1 | stats count() as cnt, sum(a) as sum, avg(a) as avg by country + | """.stripMargin) + assertSameRows(Seq(Row(2, 2, 1.0, "Canada"), Row(2, 2, 1.0, "USA")), frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val eval = Project(Seq(UnresolvedStar(None), Alias(Literal(1), "a")()), table) + val sum = + Alias( + UnresolvedFunction(Seq("SUM"), Seq(UnresolvedAttribute("a")), isDistinct = false), + "sum")() + val avg = + Alias( + UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("a")), isDistinct = false), + "avg")() + val count = + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false), + "cnt")() + val grouping = + Alias(UnresolvedAttribute("country"), "country")() + val aggregate = Aggregate(Seq(grouping), Seq(count, sum, avg, grouping), eval) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test count() in the middle of stats by clause") { + val frame = sql(s""" + | source = $testTable | eval a = 1 | stats sum(a) as sum, count() as cnt, avg(a) as avg by country + | """.stripMargin) + assertSameRows(Seq(Row(2, 2, 1.0, "Canada"), Row(2, 2, 1.0, "USA")), frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val eval = Project(Seq(UnresolvedStar(None), Alias(Literal(1), "a")()), table) + val sum = + Alias( + UnresolvedFunction(Seq("SUM"), Seq(UnresolvedAttribute("a")), isDistinct = false), + "sum")() + val avg = + Alias( + UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("a")), isDistinct = false), + "avg")() + val count = + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false), + "cnt")() + val grouping = + Alias(UnresolvedAttribute("country"), "country")() + val aggregate = Aggregate(Seq(grouping), Seq(sum, count, avg, grouping), eval) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test count() at the end of stats by clause") { + val frame = sql(s""" + | source = $testTable | eval a = 1 | stats sum(a) as sum, avg(a) as avg, count() as cnt by country + | """.stripMargin) + assertSameRows(Seq(Row(2, 1.0, 2, "Canada"), Row(2, 1.0, 2, "USA")), frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val eval = Project(Seq(UnresolvedStar(None), Alias(Literal(1), "a")()), table) + val sum = + Alias( + UnresolvedFunction(Seq("SUM"), Seq(UnresolvedAttribute("a")), isDistinct = false), + "sum")() + val avg = + Alias( + UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("a")), isDistinct = false), + "avg")() + val count = + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false), + "cnt")() + val grouping = + Alias(UnresolvedAttribute("country"), 
"country")() + val aggregate = Aggregate(Seq(grouping), Seq(sum, avg, count, grouping), eval) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBetweenITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBetweenITSuite.scala new file mode 100644 index 000000000..ce0be1409 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBetweenITSuite.scala @@ -0,0 +1,136 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import java.sql.Timestamp + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLBetweenITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + private val timeSeriesTestTable = "spark_catalog.default.flint_ppl_timeseries_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test tables + createPartitionedStateCountryTable(testTable) + createTimeSeriesTable(timeSeriesTestTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test between should return records between two integer values") { + val frame = sql(s""" + | source = $testTable | where age between 20 and 30 + | """.stripMargin) + + val results = frame.collect() + assert(results.length == 3) + assert(frame.columns.length == 6) + + results.foreach(row => { + val age = row.getAs[Int]("age") + assert(age >= 20 && age <= 30, s"Age $age is not between 20 and 30") + }) + } + + test("test between should return records between two integer computed values") { + val frame = sql(s""" + | source = $testTable | where age between 20 + 1 and 30 - 1 + | """.stripMargin) + + val results = frame.collect() + assert(results.length == 1) + assert(frame.columns.length == 6) + + results.foreach(row => { + val age = row.getAs[Int]("age") + assert(age >= 21 && age <= 29, s"Age $age is not between 21 and 29") + }) + } + + test("test between should return records NOT between two integer values") { + val frame = sql(s""" + | source = $testTable | where age NOT between 20 and 30 + | """.stripMargin) + + val results = frame.collect() + assert(results.length == 1) + assert(frame.columns.length == 6) + + results.foreach(row => { + val age = row.getAs[Int]("age") + assert(age < 20 || age > 30, s"Age $age is not between 20 and 30") + }) + } + + test("test between should return records where NOT between two integer values") { + val frame = sql(s""" + | source = $testTable | where NOT age between 20 and 30 + | """.stripMargin) + + val results = frame.collect() + assert(results.length == 1) + assert(frame.columns.length == 6) + + results.foreach(row => { + val age = row.getAs[Int]("age") + assert(age < 20 || age > 30, s"Age $age is not between 20 and 30") + }) + } + + test("test between should return records between two date values") { + val frame = sql(s""" + | source = $timeSeriesTestTable | where time between '2023-10-01 00:01:00' and '2023-10-01 00:10:00' + | """.stripMargin) + + val results = frame.collect() + assert(results.length == 2) + 
assert(frame.columns.length == 4) + + results.foreach(row => { + val ts = row.getAs[Timestamp]("time") + assert( + !ts.before(Timestamp.valueOf("2023-10-01 00:01:00")) || !ts.after( + Timestamp.valueOf("2023-10-01 00:01:00")), + s"Timestamp $ts is not between '2023-10-01 00:01:00' and '2023-10-01 00:10:00'") + }) + } + + test("test between should return records NOT between two date values") { + val frame = sql(s""" + | source = $timeSeriesTestTable | where time NOT between '2023-10-01 00:01:00' and '2023-10-01 00:10:00' + | """.stripMargin) + + val results = frame.collect() + assert(results.length == 3) + assert(frame.columns.length == 4) + + results.foreach(row => { + val ts = row.getAs[Timestamp]("time") + assert( + ts.before(Timestamp.valueOf("2023-10-01 00:01:00")) || ts.after( + Timestamp.valueOf("2023-10-01 00:01:00")), + s"Timestamp $ts is not between '2023-10-01 00:01:00' and '2023-10-01 00:10:00'") + }) + + } +} diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBuiltInDateTimeFunctionITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBuiltInDateTimeFunctionITSuite.scala new file mode 100644 index 000000000..8001a690d --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBuiltInDateTimeFunctionITSuite.scala @@ -0,0 +1,585 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import java.sql.{Date, Timestamp} + +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation} +import org.apache.spark.sql.catalyst.expressions.{GreaterThan, Literal} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLBuiltInDateTimeFunctionITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createPartitionedStateCountryTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test adddate(date, numDays)") { + val frame = sql(s""" + | source = $testTable + | | eval `'2020-08-26' + 1` = ADDDATE(DATE('2020-08-26'), 1), `'2020-08-26' + (-1)` = ADDDATE(DATE('2020-08-26'), -1) + | | fields `'2020-08-26' + 1`, `'2020-08-26' + (-1)` | head 1 + | """.stripMargin) + assertSameRows(Seq(Row(Date.valueOf("2020-08-27"), Date.valueOf("2020-08-25"))), frame) + } + + test("test subdate(date, numDays)") { + val frame = sql(s""" + | source = $testTable + | | eval `'2020-08-26' - 1` = SUBDATE(DATE('2020-08-26'), 1), `'2020-08-26' - (-1)` = SUBDATE(DATE('2020-08-26'), -1) + | | fields `'2020-08-26' - 1`, `'2020-08-26' - (-1)` | head 1 + | """.stripMargin) + assertSameRows(Seq(Row(Date.valueOf("2020-08-25"), Date.valueOf("2020-08-27"))), frame) + } + + test("test CURRENT_DATE, CURDATE are synonyms") { + val frame = sql(s""" + | source = $testTable + | | eval `CURRENT_DATE` = CURRENT_DATE(), `CURDATE` = CURDATE() + | | where CURRENT_DATE = CURDATE + | | fields 
CURRENT_DATE, CURDATE | head 1 + | """.stripMargin) + val results: Array[Row] = frame.collect() + assert(results.length == 1) + } + + test("test LOCALTIME, LOCALTIMESTAMP, NOW are synonyms") { + val frame = sql(s""" + | source = $testTable + | | eval `LOCALTIME` = LOCALTIME(), `LOCALTIMESTAMP` = LOCALTIMESTAMP(), `NOW` = NOW() + | | where LOCALTIME = LOCALTIMESTAMP and LOCALTIME = NOW + | | fields LOCALTIME, LOCALTIMESTAMP, NOW | head 1 + | """.stripMargin) + val results: Array[Row] = frame.collect() + assert(results.length == 1) + } + + test("test DATE, TIMESTAMP") { + val frame = sql(s""" + | source = $testTable + | | eval `DATE('2020-08-26')` = DATE('2020-08-26') + | | eval `DATE(TIMESTAMP('2020-08-26 13:49:00'))` = DATE(TIMESTAMP('2020-08-26 13:49:00')) + | | eval `DATE('2020-08-26 13:49')` = DATE('2020-08-26 13:49') + | | fields `DATE('2020-08-26')`, `DATE(TIMESTAMP('2020-08-26 13:49:00'))`, `DATE('2020-08-26 13:49')` + | | head 1 + | """.stripMargin) + assertSameRows( + Seq( + Row(Date.valueOf("2020-08-26"), Date.valueOf("2020-08-26"), Date.valueOf("2020-08-26"))), + frame) + } + + test("test DATE_FORMAT") { + val frame = sql(s""" + | source = $testTable + | | eval format1 = DATE_FORMAT(TIMESTAMP('1998-01-31 13:14:15.012345'), 'yyyy-MMM-dd hh:mm:ss a') + | | eval format2 = DATE_FORMAT('1998-01-31 13:14:15.012345', 'HH:mm:ss.SSSSSS') + | | fields format1, format2 + | | head 1 + | """.stripMargin) + assertSameRows(Seq(Row("1998-Jan-31 01:14:15 PM", "13:14:15.012345")), frame) + } + + test("test DATEDIFF") { + val frame = sql(s""" + | source = $testTable + | | eval diff1 = DATEDIFF(DATE('2020-08-27'), DATE('2020-08-26')) + | | eval diff2 = DATEDIFF(DATE('2020-08-26'), DATE('2020-08-27')) + | | eval diff3 = DATEDIFF(DATE('2020-08-27'), DATE('2020-08-27')) + | | eval diff4 = DATEDIFF(DATE('2020-08-26'), '2020-08-27') + | | eval diff5 = DATEDIFF(TIMESTAMP('2000-01-02 00:00:00'), TIMESTAMP('2000-01-01 23:59:59')) + | | eval diff6 = DATEDIFF(DATE('2001-02-01'), TIMESTAMP('2004-01-01 00:00:00')) + | | fields diff1, diff2, diff3, diff4, diff5, diff6 + | | head 1 + | """.stripMargin) + assertSameRows(Seq(Row(1, -1, 0, -1, 1, -1064)), frame) + } + + test("test DAY, DAYOFMONTH, DAY_OF_MONTH are synonyms") { + val frame = sql(s""" + | source = $testTable + | | eval `DAY(DATE('2020-08-26'))` = DAY(DATE('2020-08-26')) + | | eval `DAYOFMONTH(DATE('2020-08-26'))` = DAYOFMONTH(DATE('2020-08-26')) + | | eval `DAY_OF_MONTH(DATE('2020-08-26'))` = DAY_OF_MONTH(DATE('2020-08-26')) + | | fields `DAY(DATE('2020-08-26'))`, `DAYOFMONTH(DATE('2020-08-26'))`, `DAY_OF_MONTH(DATE('2020-08-26'))` + | | head 1 + | """.stripMargin) + assertSameRows(Seq(Row(26, 26, 26)), frame) + } + + test("test DAYOFWEEK, DAY_OF_WEEK are synonyms") { + val frame = sql(s""" + | source = $testTable + | | eval `DAYOFWEEK(DATE('2020-08-26'))` = DAYOFWEEK(DATE('2020-08-26')) + | | eval `DAY_OF_WEEK(DATE('2020-08-26'))` = DAY_OF_WEEK(DATE('2020-08-26')) + | | fields `DAYOFWEEK(DATE('2020-08-26'))`, `DAY_OF_WEEK(DATE('2020-08-26'))` + | | head 1 + | """.stripMargin) + assertSameRows(Seq(Row(4, 4)), frame) + } + + test("test DAYOFYEAR, DAY_OF_YEAR are synonyms") { + val frame = sql(s""" + | source = $testTable + | | eval `DAY_OF_YEAR(DATE('2020-08-26'))` = DAY_OF_YEAR(DATE('2020-08-26')) + | | eval `DAYOFYEAR(DATE('2020-08-26'))` = DAYOFYEAR(DATE('2020-08-26')) + | | fields `DAY_OF_YEAR(DATE('2020-08-26'))`, `DAYOFYEAR(DATE('2020-08-26'))` + | | head 1 + | """.stripMargin) + assertSameRows(Seq(Row(239, 239)), frame) + } + + test("test WEEK, 
WEEK_OF_YEAR are synonyms") { + val frame = sql(s""" + | source = $testTable + | | eval `WEEK(DATE('2008-02-20'))` = WEEK(DATE('2008-02-20')) + | | eval `WEEK_OF_YEAR(DATE('2008-02-20'))` = WEEK_OF_YEAR(DATE('2008-02-20')) + | | fields `WEEK(DATE('2008-02-20'))`, `WEEK_OF_YEAR(DATE('2008-02-20'))` + | | head 1 + | """.stripMargin) + assertSameRows(Seq(Row(8, 8)), frame) + } + + test("test MONTH, MONTH_OF_YEAR are synonyms") { + val frame = sql(s""" + | source = $testTable + | | eval `MONTH(DATE('2020-08-26'))` = MONTH(DATE('2020-08-26')) + | | eval `MONTH_OF_YEAR(DATE('2020-08-26'))` = MONTH_OF_YEAR(DATE('2020-08-26')) + | | fields `MONTH(DATE('2020-08-26'))`, `MONTH_OF_YEAR(DATE('2020-08-26'))` + | | head 1 + | """.stripMargin) + assertSameRows(Seq(Row(8, 8)), frame) + } + test("test WEEKDAY") { + val frame = sql(s""" + | source = $testTable + | | eval `weekday(DATE('2020-08-26'))` = weekday(DATE('2020-08-26')) + | | eval `weekday(DATE('2020-08-27'))` = weekday(DATE('2020-08-27')) + | | fields `weekday(DATE('2020-08-26'))`, `weekday(DATE('2020-08-27'))` + | | head 1 + | """.stripMargin) + assertSameRows(Seq(Row(2, 3)), frame) + } + + test("test YEAR") { + val frame = sql(s""" + | source = $testTable + | | eval `YEAR(DATE('2020-08-26'))` = YEAR(DATE('2020-08-26')) | fields `YEAR(DATE('2020-08-26'))` + | | head 1 + | """.stripMargin) + assertSameRows(Seq(Row(2020)), frame) + } + + test("test from_unixtime and unix_timestamp") { + val frame = sql(s""" + | source = $testTable |where unix_timestamp(from_unixtime(1700000001)) > 1700000000 | fields name, age + | """.stripMargin) + assertSameRows( + Seq(Row("Jake", 70), Row("Hello", 30), Row("John", 25), Row("Jane", 20)), + frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = GreaterThan( + UnresolvedFunction( + "unix_timestamp", + seq(UnresolvedFunction("from_unixtime", seq(Literal(1700000001)), isDistinct = false)), + isDistinct = false), + Literal(1700000000)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test DATE_ADD") { + val frame1 = sql(s""" + | source = $testTable | eval `'2020-08-26' + 2d` = DATE_ADD(DATE('2020-08-26'), INTERVAL 2 DAY) + | | fields `'2020-08-26' + 2d` | head 1 + | """.stripMargin) + assertSameRows(Seq(Row(Date.valueOf("2020-08-28"))), frame1) + + val frame2 = sql(s""" + | source = $testTable | eval `'2020-08-26' - 2d` = DATE_ADD(DATE('2020-08-26'), INTERVAL -2 DAY) + | | fields `'2020-08-26' - 2d` | head 1 + | """.stripMargin) + assertSameRows(Seq(Row(Date.valueOf("2020-08-24"))), frame2) + + val frame3 = sql(s""" + | source = $testTable | eval `'2020-08-26' + 2m` = DATE_ADD(DATE('2020-08-26'), INTERVAL 2 MONTH) + | | fields `'2020-08-26' + 2m` | head 1 + | """.stripMargin) + assertSameRows(Seq(Row(Date.valueOf("2020-10-26"))), frame3) + + val frame4 = sql(s""" + | source = $testTable | eval `'2020-08-26' + 2y` = DATE_ADD(DATE('2020-08-26'), INTERVAL 2 YEAR) + | | fields `'2020-08-26' + 2y` | head 1 + | """.stripMargin) + assertSameRows(Seq(Row(Date.valueOf("2022-08-26"))), frame4) + + val ex = intercept[AnalysisException](sql(s""" + | source = $testTable | eval `'2020-08-26 01:01:01' + 2h` = DATE_ADD(TIMESTAMP('2020-08-26 01:01:01'), INTERVAL 2 HOUR) + | | fields `'2020-08-26 01:01:01' + 2h` 
| head 1 + | """.stripMargin)) + assert(ex.getMessage.contains("""Parameter 1 requires the "DATE" type""")) + } + + test("test DATE_SUB") { + val frame1 = sql(s""" + | source = $testTable | eval `'2020-08-26' - 2d` = DATE_SUB(DATE('2020-08-26'), INTERVAL 2 DAY) + | | fields `'2020-08-26' - 2d` | head 1 + | """.stripMargin) + assertSameRows(Seq(Row(Date.valueOf("2020-08-24"))), frame1) + + val frame2 = sql(s""" + | source = $testTable | eval `'2020-08-26' + 2d` = DATE_SUB(DATE('2020-08-26'), INTERVAL -2 DAY) + | | fields `'2020-08-26' + 2d` | head 1 + | """.stripMargin) + assertSameRows(Seq(Row(Date.valueOf("2020-08-28"))), frame2) + + val frame3 = sql(s""" + | source = $testTable | eval `'2020-08-26' - 12m` = DATE_SUB(DATE('2020-08-26'), INTERVAL 12 MONTH) + | | fields `'2020-08-26' - 12m` | head 1 + | """.stripMargin) + assertSameRows(Seq(Row(Date.valueOf("2019-08-26"))), frame3) + + val frame4 = sql(s""" + | source = $testTable | eval `'2020-08-26' - 2y` = DATE_SUB(DATE('2020-08-26'), INTERVAL 2 YEAR) + | | fields `'2020-08-26' - 2y` | head 1 + | """.stripMargin) + assertSameRows(Seq(Row(Date.valueOf("2018-08-26"))), frame4) + + val ex = intercept[AnalysisException](sql(s""" + | source = $testTable | eval `'2020-08-26 01:01:01' - 2h` = DATE_SUB(TIMESTAMP('2020-08-26 01:01:01'), INTERVAL 2 HOUR) + | | fields `'2020-08-26 01:01:01' - 2h` | head 1 + | """.stripMargin)) + assert(ex.getMessage.contains("""Parameter 1 requires the "DATE" type""")) + } + + test("test TIMESTAMPADD") { + val frame = sql(s""" + | source = $testTable + | | eval `TIMESTAMPADD(DAY, 17, '2000-01-01 00:00:00')` = TIMESTAMPADD(DAY, 17, '2000-01-01 00:00:00') + | | eval `TIMESTAMPADD(DAY, 17, TIMESTAMP('2000-01-01 00:00:00'))` = TIMESTAMPADD(DAY, 17, TIMESTAMP('2000-01-01 00:00:00')) + | | eval `TIMESTAMPADD(QUARTER, -1, '2000-01-01 00:00:00')` = TIMESTAMPADD(QUARTER, -1, '2000-01-01 00:00:00') + | | fields `TIMESTAMPADD(DAY, 17, '2000-01-01 00:00:00')`, `TIMESTAMPADD(DAY, 17, TIMESTAMP('2000-01-01 00:00:00'))`, `TIMESTAMPADD(QUARTER, -1, '2000-01-01 00:00:00')` + | | head 1 + | """.stripMargin) + assertSameRows( + Seq( + Row( + Timestamp.valueOf("2000-01-18 00:00:00"), + Timestamp.valueOf("2000-01-18 00:00:00"), + Timestamp.valueOf("1999-10-01 00:00:00"))), + frame) + } + + test("test TIMESTAMPDIFF") { + val frame = sql(s""" + | source = $testTable + | | eval `TIMESTAMPDIFF(YEAR, '1997-01-01 00:00:00', '2001-03-06 00:00:00')` = TIMESTAMPDIFF(YEAR, '1997-01-01 00:00:00', '2001-03-06 00:00:00') + | | eval `TIMESTAMPDIFF(SECOND, TIMESTAMP('2000-01-01 00:00:23'), TIMESTAMP('2000-01-01 00:00:00'))` = TIMESTAMPDIFF(SECOND, TIMESTAMP('2000-01-01 00:00:23'), TIMESTAMP('2000-01-01 00:00:00')) + | | fields `TIMESTAMPDIFF(YEAR, '1997-01-01 00:00:00', '2001-03-06 00:00:00')`, `TIMESTAMPDIFF(SECOND, TIMESTAMP('2000-01-01 00:00:23'), TIMESTAMP('2000-01-01 00:00:00'))` + | | head 1 + | """.stripMargin) + assertSameRows(Seq(Row(4, -23)), frame) + } + + test("test CURRENT_TIMEZONE") { + val frame = sql(s""" + | source = $testTable + | | eval `CURRENT_TIMEZONE` = CURRENT_TIMEZONE() + | | fields `CURRENT_TIMEZONE` + | """.stripMargin) + assert(frame.collect().length > 0) + } + + test("test UTC_TIMESTAMP") { + val frame = sql(s""" + | source = $testTable + | | eval `UTC_TIMESTAMP` = UTC_TIMESTAMP() + | | fields `UTC_TIMESTAMP` + | """.stripMargin) + assert(frame.collect().length > 0) + } + + test("test hour, minute, second, HOUR_OF_DAY, MINUTE_OF_HOUR") { + val frame = sql(s""" + | source = $testTable + | | eval h =
hour(timestamp('01:02:03')), m = minute(timestamp('01:02:03')), s = second(timestamp('01:02:03')) + | | eval hs = hour('2024-07-30 01:02:03'), ms = minute('2024-07-30 01:02:03'), ss = second('01:02:03') + | | eval h_d = HOUR_OF_DAY(timestamp('01:02:03')), m_h = MINUTE_OF_HOUR(timestamp('01:02:03')), s_m = SECOND_OF_MINUTE(timestamp('01:02:03')) + | | fields h, m, s, hs, ms, ss, h_d, m_h, s_m | head 1 + | """.stripMargin) + assertSameRows(Seq(Row(1, 2, 3, 1, 2, 3, 1, 2, 3)), frame) + } + + test("test LAST_DAY") { + val frame = sql(s""" + | source = $testTable + | | eval `last_day('2023-02-06')` = last_day('2023-02-06') + | | fields `last_day('2023-02-06')` + | | head 1 + | """.stripMargin) + assertSameRows(Seq(Row(Date.valueOf("2023-02-28"))), frame) + } + + test("test MAKE_DATE") { + val frame = sql(s""" + | source = $testTable + | | eval `MAKE_DATE(1945, 5, 9)` = MAKE_DATE(1945, 5, 9) | fields `MAKE_DATE(1945, 5, 9)` + | | head 1 + | """.stripMargin) + assertSameRows(Seq(Row(Date.valueOf("1945-05-09"))), frame) + } + + test("test QUARTER") { + val frame = sql(s""" + | source = $testTable + | | eval `QUARTER(DATE('2020-08-26'))` = QUARTER(DATE('2020-08-26')) | fields `QUARTER(DATE('2020-08-26'))` + | | head 1 + | """.stripMargin) + assertSameRows(Seq(Row(3)), frame) + } + + test("test CURRENT_TIME is not supported") { + val ex = intercept[UnsupportedOperationException](sql(s""" + | source = $testTable + | | eval `CURRENT_TIME` = CURRENT_TIME() + | | fields CURRENT_TIME | head 1 + | """.stripMargin)) + assert(ex.getMessage.contains("CURRENT_TIME is not a builtin function of PPL")) + } + + test("test CONVERT_TZ is not supported") { + val ex = intercept[UnsupportedOperationException](sql(s""" + | source = $testTable + | | eval `CONVERT_TZ` = CONVERT_TZ() + | | fields CONVERT_TZ | head 1 + | """.stripMargin)) + assert(ex.getMessage.contains("CONVERT_TZ is not a builtin function of PPL")) + } + + test("test ADDTIME is not supported") { + val ex = intercept[UnsupportedOperationException](sql(s""" + | source = $testTable + | | eval `ADDTIME` = ADDTIME() + | | fields ADDTIME | head 1 + | """.stripMargin)) + assert(ex.getMessage.contains("ADDTIME is not a builtin function of PPL")) + } + + test("test DATETIME is not supported") { + val ex = intercept[UnsupportedOperationException](sql(s""" + | source = $testTable + | | eval `DATETIME` = DATETIME() + | | fields DATETIME | head 1 + | """.stripMargin)) + assert(ex.getMessage.contains("DATETIME is not a builtin function of PPL")) + } + + test("test DAYNAME is not supported") { + val ex = intercept[UnsupportedOperationException](sql(s""" + | source = $testTable + | | eval `DAYNAME` = DAYNAME() + | | fields DAYNAME | head 1 + | """.stripMargin)) + assert(ex.getMessage.contains("DAYNAME is not a builtin function of PPL")) + } + + test("test FROM_DAYS is not supported") { + val ex = intercept[UnsupportedOperationException](sql(s""" + | source = $testTable + | | eval `FROM_DAYS` = FROM_DAYS() + | | fields FROM_DAYS | head 1 + | """.stripMargin)) + assert(ex.getMessage.contains("FROM_DAYS is not a builtin function of PPL")) + } + + test("test GET_FORMAT is not supported") { + intercept[Exception](sql(s""" + | source = $testTable + | | eval `GET_FORMAT` = GET_FORMAT(DATE, 'USA') + | | fields GET_FORMAT | head 1 + | """.stripMargin)) + } + + test("test MAKETIME is not supported") { + val ex = intercept[UnsupportedOperationException](sql(s""" + | source = $testTable + | | eval `MAKETIME` = MAKETIME() + | | fields MAKETIME | head 1 + | """.stripMargin)) + 
assert(ex.getMessage.contains("MAKETIME is not a builtin function of PPL")) + } + + test("test MICROSECOND is not supported") { + val ex = intercept[UnsupportedOperationException](sql(s""" + | source = $testTable + | | eval `MICROSECOND` = MICROSECOND() + | | fields MICROSECOND | head 1 + | """.stripMargin)) + assert(ex.getMessage.contains("MICROSECOND is not a builtin function of PPL")) + } + + test("test MINUTE_OF_DAY is not supported") { + val ex = intercept[UnsupportedOperationException](sql(s""" + | source = $testTable + | | eval `MINUTE_OF_DAY` = MINUTE_OF_DAY() + | | fields MINUTE_OF_DAY | head 1 + | """.stripMargin)) + assert(ex.getMessage.contains("MINUTE_OF_DAY is not a builtin function of PPL")) + } + + test("test PERIOD_ADD is not supported") { + val ex = intercept[UnsupportedOperationException](sql(s""" + | source = $testTable + | | eval `PERIOD_ADD` = PERIOD_ADD() + | | fields PERIOD_ADD | head 1 + | """.stripMargin)) + assert(ex.getMessage.contains("PERIOD_ADD is not a builtin function of PPL")) + } + + test("test PERIOD_DIFF is not supported") { + val ex = intercept[UnsupportedOperationException](sql(s""" + | source = $testTable + | | eval `PERIOD_DIFF` = PERIOD_DIFF() + | | fields PERIOD_DIFF | head 1 + | """.stripMargin)) + assert(ex.getMessage.contains("PERIOD_DIFF is not a builtin function of PPL")) + } + + test("test SEC_TO_TIME is not supported") { + val ex = intercept[UnsupportedOperationException](sql(s""" + | source = $testTable + | | eval `SEC_TO_TIME` = SEC_TO_TIME() + | | fields SEC_TO_TIME | head 1 + | """.stripMargin)) + assert(ex.getMessage.contains("SEC_TO_TIME is not a builtin function of PPL")) + } + + test("test STR_TO_DATE is not supported") { + val ex = intercept[UnsupportedOperationException](sql(s""" + | source = $testTable + | | eval `STR_TO_DATE` = STR_TO_DATE() + | | fields STR_TO_DATE | head 1 + | """.stripMargin)) + assert(ex.getMessage.contains("STR_TO_DATE is not a builtin function of PPL")) + } + + test("test SUBTIME is not supported") { + val ex = intercept[UnsupportedOperationException](sql(s""" + | source = $testTable + | | eval `SUBTIME` = SUBTIME() + | | fields SUBTIME | head 1 + | """.stripMargin)) + assert(ex.getMessage.contains("SUBTIME is not a builtin function of PPL")) + } + + test("test TIME is not supported") { + val ex = intercept[UnsupportedOperationException](sql(s""" + | source = $testTable + | | eval `TIME` = TIME() + | | fields TIME | head 1 + | """.stripMargin)) + assert(ex.getMessage.contains("TIME is not a builtin function of PPL")) + } + + test("test TIME_FORMAT is not supported") { + val ex = intercept[UnsupportedOperationException](sql(s""" + | source = $testTable + | | eval `TIME_FORMAT` = TIME_FORMAT() + | | fields TIME_FORMAT | head 1 + | """.stripMargin)) + assert(ex.getMessage.contains("TIME_FORMAT is not a builtin function of PPL")) + } + + test("test TIME_TO_SEC is not supported") { + val ex = intercept[UnsupportedOperationException](sql(s""" + | source = $testTable + | | eval `TIME_TO_SEC` = TIME_TO_SEC() + | | fields TIME_TO_SEC | head 1 + | """.stripMargin)) + assert(ex.getMessage.contains("TIME_TO_SEC is not a builtin function of PPL")) + } + + test("test TIMEDIFF is not supported") { + val ex = intercept[UnsupportedOperationException](sql(s""" + | source = $testTable + | | eval `TIMEDIFF` = TIMEDIFF() + | | fields TIMEDIFF | head 1 + | """.stripMargin)) + assert(ex.getMessage.contains("TIMEDIFF is not a builtin function of PPL")) + } + + test("test TO_DAYS is not supported") { + val ex = 
intercept[UnsupportedOperationException](sql(s""" + | source = $testTable + | | eval `TO_DAYS` = TO_DAYS() + | | fields TO_DAYS | head 1 + | """.stripMargin)) + assert(ex.getMessage.contains("TO_DAYS is not a builtin function of PPL")) + } + + test("test TO_SECONDS is not supported") { + val ex = intercept[UnsupportedOperationException](sql(s""" + | source = $testTable + | | eval `TO_SECONDS` = TO_SECONDS() + | | fields TO_SECONDS | head 1 + | """.stripMargin)) + assert(ex.getMessage.contains("TO_SECONDS is not a builtin function of PPL")) + } + + test("test UTC_DATE is not supported") { + val ex = intercept[UnsupportedOperationException](sql(s""" + | source = $testTable + | | eval `UTC_DATE` = UTC_DATE() + | | fields UTC_DATE | head 1 + | """.stripMargin)) + assert(ex.getMessage.contains("UTC_DATE is not a builtin function of PPL")) + } + + test("test UTC_TIME is not supported") { + val ex = intercept[UnsupportedOperationException](sql(s""" + | source = $testTable + | | eval `UTC_TIME` = UTC_TIME() + | | fields UTC_TIME | head 1 + | """.stripMargin)) + assert(ex.getMessage.contains("UTC_TIME is not a builtin function of PPL")) + } + + test("test YEARWEEK is not supported") { + val ex = intercept[UnsupportedOperationException](sql(s""" + | source = $testTable + | | eval `YEARWEEK` = YEARWEEK() + | | fields YEARWEEK | head 1 + | """.stripMargin)) + assert(ex.getMessage.contains("YEARWEEK is not a builtin function of PPL")) + } +} diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBuiltinFunctionITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBuiltinFunctionITSuite.scala index 4c35549df..763c2411b 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBuiltinFunctionITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBuiltinFunctionITSuite.scala @@ -605,31 +605,6 @@ class FlintSparkPPLBuiltinFunctionITSuite comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } - test("test time functions - from_unixtime and unix_timestamp") { - val frame = sql(s""" - | source = $testTable |where unix_timestamp(from_unixtime(1700000001)) > 1700000000 | fields name, age - | """.stripMargin) - - val results: Array[Row] = frame.collect() - val expectedResults: Array[Row] = - Array(Row("Jake", 70), Row("Hello", 30), Row("John", 25), Row("Jane", 20)) - implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) - assert(results.sorted.sameElements(expectedResults.sorted)) - - val logicalPlan: LogicalPlan = frame.queryExecution.logical - val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) - val filterExpr = GreaterThan( - UnresolvedFunction( - "unix_timestamp", - seq(UnresolvedFunction("from_unixtime", seq(Literal(1700000001)), isDistinct = false)), - isDistinct = false), - Literal(1700000000)) - val filterPlan = Filter(filterExpr, table) - val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) - val expectedPlan = Project(projectList, filterPlan) - comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) - } - test("test arithmetic operators (+ - * / %)") { val frame = sql(s""" | source = $testTable | where (sqrt(pow(age, 2)) + sqrt(pow(age, 2)) / 1 - sqrt(pow(age, 2)) * 1) % 25.0 = 0 | fields name, age @@ -810,6 +785,42 @@ class FlintSparkPPLBuiltinFunctionITSuite assert(results.sameElements(expectedResults)) } + test("test cryptographic hash functions - 
md5") { + val frame = sql(s""" + | source = $testTable | eval a = md5('Spark') = '8cde774d6f7333752ed72cacddb05126' | fields age, a + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row(70, true), Row(30, true), Row(25, true), Row(20, true)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + } + + test("test cryptographic hash functions - sha1") { + val frame = sql(s""" + | source = $testTable | eval a = sha1('Spark') = '85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c' | fields age, a + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row(70, true), Row(30, true), Row(25, true), Row(20, true)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + } + + test("test cryptographic hash functions - sha2") { + val frame = sql(s""" + | source = $testTable | eval a = sha2('Spark',256) = '529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b' | fields age, a + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row(70, true), Row(30, true), Row(25, true), Row(20, true)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + } + // Todo // +---------------------------------------+ // | Below tests are not supported (cast) | diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCidrmatchITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCidrmatchITSuite.scala new file mode 100644 index 000000000..d9cf8968b --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCidrmatchITSuite.scala @@ -0,0 +1,96 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.SparkException +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLCidrmatchITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createIpAddressTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test cidrmatch for ipv4 for 192.168.1.0/24") { + val frame = sql(s""" + | source = $testTable | where isV6 = false and isValid = true and cidrmatch(ipAddress, '192.168.1.0/24') + | """.stripMargin) + + val results = frame.collect() + assert(results.length == 2) + } + + test("test cidrmatch for ipv4 for 192.169.1.0/24") { + val frame = sql(s""" + | source = $testTable | where isV6 = false and isValid = true and cidrmatch(ipAddress, '192.169.1.0/24') + | """.stripMargin) + + val results = frame.collect() + assert(results.length == 0) + } + + test("test cidrmatch for ipv6 for 2001:db8::/32") { + val frame = sql(s""" + | source = $testTable | where isV6 = true and isValid = true and cidrmatch(ipAddress, 
'2001:db8::/32') + | """.stripMargin) + + val results = frame.collect() + assert(results.length == 3) + } + + test("test cidrmatch for ipv6 for 2003:db8::/32") { + val frame = sql(s""" + | source = $testTable | where isV6 = true and isValid = true and cidrmatch(ipAddress, '2003:db8::/32') + | """.stripMargin) + + val results = frame.collect() + assert(results.length == 0) + } + + test("test cidrmatch for ipv6 with ipv4 cidr") { + val frame = sql(s""" + | source = $testTable | where isV6 = true and isValid = true and cidrmatch(ipAddress, '192.169.1.0/24') + | """.stripMargin) + + assertThrows[SparkException](frame.collect()) + } + + test("test cidrmatch for invalid ipv4 addresses") { + val frame = sql(s""" + | source = $testTable | where isV6 = false and isValid = false and cidrmatch(ipAddress, '192.169.1.0/24') + | """.stripMargin) + + assertThrows[SparkException](frame.collect()) + } + + test("test cidrmatch for invalid ipv6 addresses") { + val frame = sql(s""" + | source = $testTable | where isV6 = true and isValid = false and cidrmatch(ipAddress, '2003:db8::/32') + | """.stripMargin) + + assertThrows[SparkException](frame.collect()) + } +} diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCommentITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCommentITSuite.scala new file mode 100644 index 000000000..71d9f1693 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCommentITSuite.scala @@ -0,0 +1,79 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, EqualTo, GreaterThanOrEqual, LessThan, Literal, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLCommentITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + createPartitionedStateCountryTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test block comment") { + val frame = sql(s""" + | /* + | * This is a + | * multiple + | * line block + | * comment + | */ + | source = /* block comment */ $testTable /* block comment */ + | | eval /* + | This is a + | multiple + | line + | block + | comment + | */ col = 1 + | | /* block comment */ fields name, /* block comment */ age + | /* block comment */ + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row("Jake", 70), Row("Hello", 30), Row("John", 25), Row("Jane", 20)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + } + + test("test line comment") { + val frame = sql(s""" + | source = $testTable //line comment + | | eval col = 1 // line comment + | | fields name, age // line comment + | /////////line
comment + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row("Jake", 70), Row("Hello", 30), Row("John", 25), Row("Jane", 20)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + } +} diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala index e10b2e2a6..c3dd1d533 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala @@ -9,7 +9,7 @@ import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, EqualTo, GreaterThanOrEqual, LessThan, Literal, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, EqualTo, GreaterThanOrEqual, In, LessThan, Literal, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, LocalLimit, LogicalPlan, Project, Sort} import org.apache.spark.sql.streaming.StreamTest @@ -688,4 +688,20 @@ class FlintSparkPPLEvalITSuite implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) } + + test("test IN expr in eval") { + val frame = sql(s""" + | source = $testTable | eval in = state in ('California', 'New York') | fields in + | """.stripMargin) + assertSameRows(Seq(Row(true), Row(true), Row(false), Row(false)), frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val in = Alias( + In(UnresolvedAttribute("state"), Seq(Literal("California"), Literal("New York"))), + "in")() + val eval = Project(Seq(UnresolvedStar(None), in), table) + val expectedPlan = Project(Seq(UnresolvedAttribute("in")), eval) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEventstatsITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEventstatsITSuite.scala new file mode 100644 index 000000000..f1d287429 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEventstatsITSuite.scala @@ -0,0 +1,379 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Divide, Floor, Literal, Multiply, RowFrame, SpecifiedWindowFrame, UnboundedFollowing, UnboundedPreceding, WindowExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Window} +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLEventstatsITSuite + extends 
QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createPartitionedStateCountryTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test eventstats avg") { + val frame = sql(s""" + | source = $testTable | eventstats avg(age) + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 36.25), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 36.25), + Row("Jake", 70, "California", "USA", 2023, 4, 36.25), + Row("Hello", 30, "New York", "USA", 2023, 4, 36.25)) + assertSameRows(expected, frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val windowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + Nil, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgWindowExprAlias = Alias(windowExpression, "avg(age)")() + val windowPlan = Window(Seq(avgWindowExprAlias), Nil, Nil, table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test eventstats avg, max, min, count") { + val frame = sql(s""" + | source = $testTable | eventstats avg(age) as avg_age, max(age) as max_age, min(age) as min_age, count(age) as count + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 36.25, 70, 20, 4), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 36.25, 70, 20, 4), + Row("Jake", 70, "California", "USA", 2023, 4, 36.25, 70, 20, 4), + Row("Hello", 30, "New York", "USA", 2023, 4, 36.25, 70, 20, 4)) + assertSameRows(expected, frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val avgWindowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + Nil, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgWindowExprAlias = Alias(avgWindowExpression, "avg_age")() + + val maxWindowExpression = WindowExpression( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + Nil, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val maxWindowExprAlias = Alias(maxWindowExpression, "max_age")() + + val minWindowExpression = WindowExpression( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + Nil, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val minWindowExprAlias = Alias(minWindowExpression, "min_age")() + + val countWindowExpression = WindowExpression( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + Nil, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val countWindowExprAlias = Alias(countWindowExpression, "count")() + val windowPlan = Window( + Seq(avgWindowExprAlias, 
maxWindowExprAlias, minWindowExprAlias, countWindowExprAlias), + Nil, + Nil, + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test eventstats avg by country") { + val frame = sql(s""" + | source = $testTable | eventstats avg(age) by country + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 22.5), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 22.5), + Row("Jake", 70, "California", "USA", 2023, 4, 50), + Row("Hello", 30, "New York", "USA", 2023, 4, 50)) + assertSameRows(expected, frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val partitionSpec = Seq(Alias(UnresolvedAttribute("country"), "country")()) + val windowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgWindowExprAlias = Alias(windowExpression, "avg(age)")() + val windowPlan = Window(Seq(avgWindowExprAlias), partitionSpec, Nil, table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test eventstats avg, max, min, count by country") { + val frame = sql(s""" + | source = $testTable | eventstats avg(age) as avg_age, max(age) as max_age, min(age) as min_age, count(age) as count by country + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 22.5, 25, 20, 2), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 22.5, 25, 20, 2), + Row("Jake", 70, "California", "USA", 2023, 4, 50, 70, 30, 2), + Row("Hello", 30, "New York", "USA", 2023, 4, 50, 70, 30, 2)) + assertSameRows(expected, frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val partitionSpec = Seq(Alias(UnresolvedAttribute("country"), "country")()) + val avgWindowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgWindowExprAlias = Alias(avgWindowExpression, "avg_age")() + + val maxWindowExpression = WindowExpression( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val maxWindowExprAlias = Alias(maxWindowExpression, "max_age")() + + val minWindowExpression = WindowExpression( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val minWindowExprAlias = Alias(minWindowExpression, "min_age")() + + val countWindowExpression = WindowExpression( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val countWindowExprAlias = Alias(countWindowExpression, "count")() + val windowPlan = Window( + Seq(avgWindowExprAlias, maxWindowExprAlias, minWindowExprAlias, 
countWindowExprAlias), + partitionSpec, + Nil, + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test eventstats avg, max, min, count by span") { + val frame = sql(s""" + | source = $testTable | eventstats avg(age) as avg_age, max(age) as max_age, min(age) as min_age, count(age) as count by span(age, 10) as age_span + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 22.5, 25, 20, 2), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 22.5, 25, 20, 2), + Row("Jake", 70, "California", "USA", 2023, 4, 70, 70, 70, 1), + Row("Hello", 30, "New York", "USA", 2023, 4, 30, 30, 30, 1)) + assertSameRows(expected, frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val partitionSpec = Seq(span) + val windowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgWindowExprAlias = Alias(windowExpression, "avg_age")() + + val maxWindowExpression = WindowExpression( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val maxWindowExprAlias = Alias(maxWindowExpression, "max_age")() + + val minWindowExpression = WindowExpression( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val minWindowExprAlias = Alias(minWindowExpression, "min_age")() + + val countWindowExpression = WindowExpression( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val countWindowExprAlias = Alias(countWindowExpression, "count")() + val windowPlan = Window( + Seq(avgWindowExprAlias, maxWindowExprAlias, minWindowExprAlias, countWindowExprAlias), + partitionSpec, + Nil, + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test eventstats avg, max, min, count by span and country") { + val frame = sql(s""" + | source = $testTable | eventstats avg(age) as avg_age, max(age) as max_age, min(age) as min_age, count(age) as count by span(age, 10) as age_span, country + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 22.5, 25, 20, 2), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 22.5, 25, 20, 2), + Row("Jake", 70, "California", "USA", 2023, 4, 70, 70, 70, 1), + Row("Hello", 30, "New York", "USA", 2023, 4, 30, 30, 30, 1)) + assertSameRows(expected, frame) + } + + test("test eventstats avg, max, min, count by span and state") { + val frame = sql(s""" + | source = $testTable | eventstats avg(age) as avg_age, max(age) as max_age, min(age) as min_age, count(age) as count by span(age, 10) as age_span, state + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 25, 
25, 25, 1), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 20, 20, 20, 1), + Row("Jake", 70, "California", "USA", 2023, 4, 70, 70, 70, 1), + Row("Hello", 30, "New York", "USA", 2023, 4, 30, 30, 30, 1)) + assertSameRows(expected, frame) + } + + test("test eventstats stddev by span with filter") { + val frame = sql(s""" + | source = $testTable | where country != 'USA' | eventstats stddev_samp(age) by span(age, 10) as age_span + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 3.5355339059327378), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 3.5355339059327378)) + assertSameRows(expected, frame) + } + + test("test eventstats stddev_pop by span with filter") { + val frame = sql(s""" + | source = $testTable | where state != 'California' | eventstats stddev_pop(age) by span(age, 10) as age_span + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 2.5), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 2.5), + Row("Hello", 30, "New York", "USA", 2023, 4, 0.0)) + assertSameRows(expected, frame) + } + + test("test eventstats percentile by span with filter") { + val frame = sql(s""" + | source = $testTable | where state != 'California' | eventstats percentile_approx(age, 60) by span(age, 10) as age_span + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 25), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 25), + Row("Hello", 30, "New York", "USA", 2023, 4, 30)) + assertSameRows(expected, frame) + } + + test("test multiple eventstats") { + val frame = sql(s""" + | source = $testTable | eventstats avg(age) as avg_age by state, country | eventstats avg(avg_age) as avg_state_age by country + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 25.0, 22.5), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 20.0, 22.5), + Row("Jake", 70, "California", "USA", 2023, 4, 70.0, 50.0), + Row("Hello", 30, "New York", "USA", 2023, 4, 30.0, 50.0)) + assertSameRows(expected, frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val partitionSpec = Seq( + Alias(UnresolvedAttribute("state"), "state")(), + Alias(UnresolvedAttribute("country"), "country")()) + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val avgAgeWindowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgAgeWindowExprAlias = Alias(avgAgeWindowExpression, "avg_age")() + val windowPlan1 = Window(Seq(avgAgeWindowExprAlias), partitionSpec, Nil, table) + + val countryPartitionSpec = Seq(Alias(UnresolvedAttribute("country"), "country")()) + val avgStateAgeWindowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("avg_age")), isDistinct = false), + WindowSpecDefinition( + countryPartitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgStateAgeWindowExprAlias = Alias(avgStateAgeWindowExpression, "avg_state_age")() + val windowPlan2 = + Window(Seq(avgStateAgeWindowExprAlias), countryPartitionSpec, Nil, windowPlan1) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan2) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test multiple eventstats with eval") { + val frame = sql(s""" + | source = $testTable | eventstats avg(age) as avg_age by state, 
country | eval new_avg_age = avg_age - 10 | eventstats avg(new_avg_age) as avg_state_age by country + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 25.0, 15.0, 12.5), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 20.0, 10.0, 12.5), + Row("Jake", 70, "California", "USA", 2023, 4, 70.0, 60.0, 40.0), + Row("Hello", 30, "New York", "USA", 2023, 4, 30.0, 20.0, 40.0)) + assertSameRows(expected, frame) + } + + test("test multiple eventstats with eval and filter") { + val frame = sql(s""" + | source = $testTable| eventstats avg(age) as avg_age by country, state, name | eval avg_age_divide_20 = avg_age - 20 | eventstats avg(avg_age_divide_20) + | as avg_state_age by country, state | where avg_state_age > 0 | eventstats count(avg_state_age) as count_country_age_greater_20 by country + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 25.0, 5.0, 5.0, 1), + Row("Jake", 70, "California", "USA", 2023, 4, 70.0, 50.0, 50.0, 2), + Row("Hello", 30, "New York", "USA", 2023, 4, 30.0, 10.0, 10.0, 2)) + assertSameRows(expected, frame) + } + + test("test eventstats distinct_count by span with filter") { + val exception = intercept[AnalysisException](sql(s""" + | source = $testTable | where state != 'California' | eventstats distinct_count(age) by span(age, 10) as age_span + | """.stripMargin)) + assert(exception.message.contains("Distinct window functions are not supported")) + } +} diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExistsSubqueryITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExistsSubqueryITSuite.scala index 81bdd99df..8009015b1 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExistsSubqueryITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExistsSubqueryITSuite.scala @@ -84,6 +84,44 @@ class FlintSparkPPLExistsSubqueryITSuite comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } + test("test simple exists subquery in search filter") { + val frame = sql(s""" + | source = $outerTable exists [ source = $innerTable | where id = uid ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(1002, "John", 120000), + Row(1003, "David", 120000), + Row(1000, "Jake", 100000), + Row(1005, "Jane", 90000), + Row(1006, "Tommy", 30000)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val existsSubquery = Filter( + Exists(Filter(EqualTo(UnresolvedAttribute("id"), UnresolvedAttribute("uid")), inner)), + outer) + val sortedPlan = Sort( + Seq(SortOrder(UnresolvedAttribute("salary"), Descending)), + global = true, + existsSubquery) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("salary")), + sortedPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + test("test not exists subquery") { val frame = sql(s""" | source = $outerTable @@ -122,6 +160,41 @@ class FlintSparkPPLExistsSubqueryITSuite comparePlans(logicalPlan, 
expectedPlan, checkAnalysis = false) } + test("test not exists subquery in search filter") { + val frame = sql(s""" + | source = $outerTable not exists [ source = $innerTable | where id = uid ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row(1001, "Hello", 70000), Row(1004, "David", 0)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val existsSubquery = + Filter( + Not( + Exists(Filter(EqualTo(UnresolvedAttribute("id"), UnresolvedAttribute("uid")), inner))), + outer) + val sortedPlan = Sort( + Seq(SortOrder(UnresolvedAttribute("salary"), Descending)), + global = true, + existsSubquery) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("salary")), + sortedPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + test("test empty exists subquery") { var frame = sql(s""" | source = $outerTable diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala new file mode 100644 index 000000000..5a5990001 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala @@ -0,0 +1,751 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.flint.spark.ppl + +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, EqualTo, Expression, Literal, NamedExpression, Not, SortOrder, Subtract} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLFieldSummaryITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createNullableTableHttpLog(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test fieldsummary with single field includefields(status_code) & nulls=true ") { + val frame = sql(s""" + | source = $testTable | fieldsummary includefields= status_code nulls=true + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row("status_code", 4, 3, 200, 403, 184.0, 184.0, 161.16699413961905, 2, "int")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + // Aggregate with functions applied to 
status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()) + + val table = + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + aggregateExpressions, + table) + val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan) + // Compare the two plans + comparePlans(expectedPlan, logicalPlan, false) + } + + test("test fieldsummary with single field includefields(status_code) & nulls=false ") { + val frame = sql(s""" + | source = $testTable | fieldsummary includefields= status_code nulls=false + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row("status_code", 4, 3, 200, 403, 276.0, 276.0, 97.1356439899038, 2, "int")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction("MEAN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction("STDDEV", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "STDDEV")(), + 
Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()) + + val table = + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + aggregateExpressions, + table) + val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan) + // Compare the two plans + comparePlans(expectedPlan, logicalPlan, false) + } + + test( + "test fieldsummary with single field includefields(status_code) & nulls=true with a where filter ") { + val frame = sql(s""" + | source = $testTable | where status_code != 200 | fieldsummary includefields= status_code nulls=true + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row("status_code", 2, 2, 301, 403, 352.0, 352.0, 72.12489168102785, 0, "int")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()) + + val table = + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val filterCondition = Not(EqualTo(UnresolvedAttribute("status_code"), Literal(200))) + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + aggregateExpressions, + Filter(filterCondition, table)) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregatePlan) + // Compare the two plans + comparePlans(expectedPlan, 
logicalPlan, false) + } + + test( + "test fieldsummary with single field includefields(status_code) & nulls=false with a where filter ") { + val frame = sql(s""" + | source = $testTable | where status_code != 200 | fieldsummary includefields= status_code nulls=false + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row("status_code", 2, 2, 301, 403, 352.0, 352.0, 72.12489168102785, 0, "int")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction("MEAN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction("STDDEV", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()) + + val table = + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val filterCondition = Not(EqualTo(UnresolvedAttribute("status_code"), Literal(200))) + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + aggregateExpressions, + Filter(filterCondition, table)) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregatePlan) + // Compare the two plans + comparePlans(expectedPlan, logicalPlan, false) + } + + test( + "test fieldsummary with single field includefields(id, status_code, request_path) & nulls=true") { + val frame = sql(s""" + | source = $testTable | fieldsummary includefields= id, status_code, request_path nulls=true + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("id", 6L, 6L, "1", "6", 3.5, 3.5, 1.8708286933869707, 0, "int"), + Row("status_code", 4L, 3L, "200", "403", 184.0, 184.0, 161.16699413961905, 2, "int"), + Row("request_path", 4L, 3L, "/about", "/home", 0.0, 0.0, 0.0, 2, "string")) + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val table = + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregateIdPlan = Aggregate( + Seq( + Alias( + UnresolvedFunction("TYPEOF", 
Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("id"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("id"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("id"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("id"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF")()), + table) + val idProj = Project(seq(UnresolvedStar(None)), aggregateIdPlan) + + // Aggregate with functions applied to status_code + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregateStatusCodePlan = Aggregate( + Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "TYPEOF")()), + table) + val statusCodeProj = + Project(seq(UnresolvedStar(None)), aggregateStatusCodePlan) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + Seq( + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF")()), + Seq( + 
Alias(Literal("request_path"), "Field")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("request_path"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("request_path"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("request_path"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF")()), + table) + val requestPathProj = Project(seq(UnresolvedStar(None)), aggregatePlan) + + val expectedPlan = + Union(seq(idProj, statusCodeProj, requestPathProj), true, true) + // Compare the two plans + comparePlans(expectedPlan, logicalPlan, false) + } + + test( + "test fieldsummary with single field includefields(id, status_code, request_path) & nulls=false") { + val frame = sql(s""" + | source = $testTable | fieldsummary includefields= id, status_code, request_path nulls=false + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("id", 6L, 6L, "1", "6", 3.5, 3.5, 1.8708286933869707, 0, "int"), + Row("status_code", 4L, 3L, "200", "403", 276.0, 276.0, 97.1356439899038, 2, "int"), + Row("request_path", 4L, 3L, "/about", "/home", null, null, null, 2, "string")) + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val table = + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregateIdPlan = Aggregate( + Seq( + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("id"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("id")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction("MEAN", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction("STDDEV", Seq(UnresolvedAttribute("id")), 
isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF")()), + table) + val idProj = Project(seq(UnresolvedStar(None)), aggregateIdPlan) + + // Aggregate with functions applied to status_code + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregateStatusCodePlan = Aggregate( + Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction("MEAN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "TYPEOF")()), + table) + val statusCodeProj = + Project(seq(UnresolvedStar(None)), aggregateStatusCodePlan) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + Seq( + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("request_path"), "Field")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF")()), + table) + val requestPathProj = Project(seq(UnresolvedStar(None)), aggregatePlan) + + val expectedPlan = + 
Union(seq(idProj, statusCodeProj, requestPathProj), true, true) + // Compare the two plans + comparePlans(expectedPlan, logicalPlan, false) + } + +} diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala index 14ef7ccc4..f2d7ee844 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala @@ -7,7 +7,7 @@ package org.opensearch.flint.spark.ppl import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, Divide, EqualTo, Floor, GreaterThan, LessThan, LessThanOrEqual, Literal, Multiply, Not, Or, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, Divide, EqualTo, Floor, GreaterThan, In, LessThan, LessThanOrEqual, Literal, Multiply, Not, Or, SortOrder} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.streaming.StreamTest @@ -453,4 +453,18 @@ class FlintSparkPPLFiltersITSuite // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } + + test("test NOT IN expr in filter") { + val frame = sql(s""" + | source = $testTable | where state not in ('California', 'New York') | fields state + | """.stripMargin) + assertSameRows(Seq(Row("Ontario"), Row("Quebec")), frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val in = In(UnresolvedAttribute("state"), Seq(Literal("California"), Literal("New York"))) + val filter = Filter(Not(in), table) + val expectedPlan = Project(Seq(UnresolvedAttribute("state")), filter) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFlattenITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFlattenITSuite.scala new file mode 100644 index 000000000..e714a5f7e --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFlattenITSuite.scala @@ -0,0 +1,350 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.flint.spark.ppl + +import java.nio.file.Files + +import org.opensearch.flint.spark.FlattenGenerator +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, GeneratorOuter, Literal, Or} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLFlattenITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + private val testTable = "flint_ppl_test" + private val structNestedTable = "spark_catalog.default.flint_ppl_struct_nested_test" + private val structTable = "spark_catalog.default.flint_ppl_struct_test" + private val multiValueTable = 
"spark_catalog.default.flint_ppl_multi_value_test" + private val tempFile = Files.createTempFile("jsonTestData", ".json") + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createNestedJsonContentTable(tempFile, testTable) + createStructNestedTable(structNestedTable) + createStructTable(structTable) + createMultiValueStructTable(multiValueTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + override def afterAll(): Unit = { + super.afterAll() + Files.deleteIfExists(tempFile) + } + + test("flatten for structs") { + val frame = sql(s""" + | source = $testTable + | | where country = 'England' or country = 'Poland' + | | fields coor + | | flatten coor + | """.stripMargin) + + assert(frame.columns.sameElements(Array("alt", "lat", "long"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row(35, 51.5074, -0.1278), Row(null, null, null)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](1)) + assert(results.sorted.sameElements(expectedResults.sorted)) + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("flint_ppl_test")) + val filter = Filter( + Or( + EqualTo(UnresolvedAttribute("country"), Literal("England")), + EqualTo(UnresolvedAttribute("country"), Literal("Poland"))), + table) + val projectCoor = Project(Seq(UnresolvedAttribute("coor")), filter) + val flattenCoor = flattenPlanFor("coor", projectCoor) + val expectedPlan = Project(Seq(UnresolvedStar(None)), flattenCoor) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + private def flattenPlanFor(flattenedColumn: String, parentPlan: LogicalPlan): LogicalPlan = { + val flattenGenerator = new FlattenGenerator(UnresolvedAttribute(flattenedColumn)) + val outerGenerator = GeneratorOuter(flattenGenerator) + val generate = Generate(outerGenerator, seq(), outer = true, None, seq(), parentPlan) + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute(flattenedColumn)), generate) + dropSourceColumn + } + + test("flatten for arrays") { + val frame = sql(s""" + | source = $testTable + | | fields bridges + | | flatten bridges + | """.stripMargin) + + assert(frame.columns.sameElements(Array("length", "name"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row(null, null), + Row(11L, "Bridge of Sighs"), + Row(48L, "Rialto Bridge"), + Row(160L, "Pont Alexandre III"), + Row(232L, "Pont Neuf"), + Row(801L, "Tower Bridge"), + Row(928L, "London Bridge"), + Row(343L, "Legion Bridge"), + Row(516L, "Charles Bridge"), + Row(333L, "Liberty Bridge"), + Row(375L, "Chain Bridge")) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("flint_ppl_test")) + val projectCoor = Project(Seq(UnresolvedAttribute("bridges")), table) + val flattenBridges = flattenPlanFor("bridges", projectCoor) + val expectedPlan = Project(Seq(UnresolvedStar(None)), flattenBridges) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("flatten for structs and arrays") { + val frame = sql(s""" + | source = $testTable | flatten bridges | flatten coor + | 
""".stripMargin) + + assert( + frame.columns.sameElements( + Array("_time", "city", "country", "length", "name", "alt", "lat", "long"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("1990-09-13T12:00:00", "Warsaw", "Poland", null, null, null, null, null), + Row( + "2024-09-13T12:00:00", + "Venice", + "Italy", + 11L, + "Bridge of Sighs", + 2, + 45.4408, + 12.3155), + Row("2024-09-13T12:00:00", "Venice", "Italy", 48L, "Rialto Bridge", 2, 45.4408, 12.3155), + Row( + "2024-09-13T12:00:00", + "Paris", + "France", + 160L, + "Pont Alexandre III", + 35, + 48.8566, + 2.3522), + Row("2024-09-13T12:00:00", "Paris", "France", 232L, "Pont Neuf", 35, 48.8566, 2.3522), + Row( + "2024-09-13T12:00:00", + "London", + "England", + 801L, + "Tower Bridge", + 35, + 51.5074, + -0.1278), + Row( + "2024-09-13T12:00:00", + "London", + "England", + 928L, + "London Bridge", + 35, + 51.5074, + -0.1278), + Row( + "2024-09-13T12:00:00", + "Prague", + "Czech Republic", + 343L, + "Legion Bridge", + 200, + 50.0755, + 14.4378), + Row( + "2024-09-13T12:00:00", + "Prague", + "Czech Republic", + 516L, + "Charles Bridge", + 200, + 50.0755, + 14.4378), + Row( + "2024-09-13T12:00:00", + "Budapest", + "Hungary", + 333L, + "Liberty Bridge", + 96, + 47.4979, + 19.0402), + Row( + "2024-09-13T12:00:00", + "Budapest", + "Hungary", + 375L, + "Chain Bridge", + 96, + 47.4979, + 19.0402)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](3)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("flint_ppl_test")) + val flattenBridges = flattenPlanFor("bridges", table) + val flattenCoor = flattenPlanFor("coor", flattenBridges) + val expectedPlan = Project(Seq(UnresolvedStar(None)), flattenCoor) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test flatten and stats") { + val frame = sql(s""" + | source = $testTable + | | fields country, bridges + | | flatten bridges + | | fields country, length + | | stats avg(length) as avg by country + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row(null, "Poland"), + Row(196d, "France"), + Row(429.5, "Czech Republic"), + Row(864.5, "England"), + Row(29.5, "Italy"), + Row(354.0, "Hungary")) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("flint_ppl_test")) + val projectCountryBridges = + Project(Seq(UnresolvedAttribute("country"), UnresolvedAttribute("bridges")), table) + val flattenBridges = flattenPlanFor("bridges", projectCountryBridges) + val projectCountryLength = + Project(Seq(UnresolvedAttribute("country"), UnresolvedAttribute("length")), flattenBridges) + val average = Alias( + UnresolvedFunction( + seq("AVG"), + seq(UnresolvedAttribute("length")), + isDistinct = false, + None, + ignoreNulls = false), + "avg")() + val country = Alias(UnresolvedAttribute("country"), "country")() + val grouping = Alias(UnresolvedAttribute("country"), "country")() + val aggregate = Aggregate(Seq(grouping), Seq(average, country), projectCountryLength) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + 
test("flatten struct table") { + val frame = sql(s""" + | source = $structTable + | | flatten struct_col + | | flatten field1 + | """.stripMargin) + + assert(frame.columns.sameElements(Array("int_col", "field2", "subfield"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row(30, 123, "value1"), Row(40, 456, "value2"), Row(50, 789, "value3")) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_struct_test")) + val flattenStructCol = flattenPlanFor("struct_col", table) + val flattenField1 = flattenPlanFor("field1", flattenStructCol) + val expectedPlan = Project(Seq(UnresolvedStar(None)), flattenField1) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("flatten struct nested table") { + val frame = sql(s""" + | source = $structNestedTable + | | flatten struct_col + | | flatten field1 + | | flatten struct_col2 + | | flatten field1 + | """.stripMargin) + + assert( + frame.columns.sameElements(Array("int_col", "field2", "subfield", "field2", "subfield"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row(30, 123, "value1", 23, "valueA"), + Row(40, 123, "value5", 33, "valueB"), + Row(30, 823, "value4", 83, "valueC"), + Row(40, 456, "value2", 46, "valueD"), + Row(50, 789, "value3", 89, "valueE")) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_struct_nested_test")) + val flattenStructCol = flattenPlanFor("struct_col", table) + val flattenField1 = flattenPlanFor("field1", flattenStructCol) + val flattenStructCol2 = flattenPlanFor("struct_col2", flattenField1) + val flattenField1Again = flattenPlanFor("field1", flattenStructCol2) + val expectedPlan = Project(Seq(UnresolvedStar(None)), flattenField1Again) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("flatten multi value nullable") { + val frame = sql(s""" + | source = $multiValueTable + | | flatten multi_value + | """.stripMargin) + + assert(frame.columns.sameElements(Array("int_col", "name", "value"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row(1, "1_one", 1), + Row(1, null, 11), + Row(1, "1_three", null), + Row(2, "2_Monday", 2), + Row(2, null, null), + Row(3, "3_third", 3), + Row(3, "3_4th", 4), + Row(4, null, null)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_multi_value_test")) + val flattenMultiValue = flattenPlanFor("multi_value", table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), flattenMultiValue) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } +} diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLInSubqueryITSuite.scala 
b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLInSubqueryITSuite.scala index 9d8c2c12d..107390dff 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLInSubqueryITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLInSubqueryITSuite.scala @@ -87,6 +87,45 @@ class FlintSparkPPLInSubqueryITSuite comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } + test("test filter id in (select uid from inner)") { + val frame = sql(s""" + source = $outerTable id in [ source = $innerTable | fields uid ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(1003, "David", 120000), + Row(1002, "John", 120000), + Row(1000, "Jake", 100000), + Row(1005, "Jane", 90000), + Row(1006, "Tommy", 30000)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val inSubquery = + Filter( + InSubquery( + Seq(UnresolvedAttribute("id")), + ListQuery(Project(Seq(UnresolvedAttribute("uid")), inner))), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("salary"), Descending)), global = true, inSubquery) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("salary")), + sortedPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + test("test where (id) in (select uid from inner)") { // id (0, 1, 2, 3, 4, 5, 6), uid (0, 2, 3, 5, 6) // InSubquery: (0, 2, 3, 5, 6) @@ -214,6 +253,41 @@ class FlintSparkPPLInSubqueryITSuite comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } + test("test filter id not in (select uid from inner)") { + val frame = sql(s""" + source = $outerTable id not in [ source = $innerTable | fields uid ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row(1001, "Hello", 70000), Row(1004, "David", 0)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val inSubquery = + Filter( + Not( + InSubquery( + Seq(UnresolvedAttribute("id")), + ListQuery(Project(Seq(UnresolvedAttribute("uid")), inner)))), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("salary"), Descending)), global = true, inSubquery) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("salary")), + sortedPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + test("test where (id, name) not in (select uid, name from inner)") { // Not InSubquery: (1, 4, 6) val frame = sql(s""" diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJoinITSuite.scala 
b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJoinITSuite.scala index b276149a0..00e55d50a 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJoinITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJoinITSuite.scala @@ -7,9 +7,9 @@ package org.opensearch.flint.spark.ppl import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Divide, EqualTo, Floor, LessThan, Literal, Multiply, Or, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Divide, EqualTo, Floor, GreaterThan, LessThan, Literal, Multiply, Or, SortOrder} import org.apache.spark.sql.catalyst.plans.{Cross, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, JoinHint, LogicalPlan, Project, Sort, SubqueryAlias} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, Join, JoinHint, LocalLimit, LogicalPlan, Project, Sort, SubqueryAlias} import org.apache.spark.sql.streaming.StreamTest class FlintSparkPPLJoinITSuite @@ -738,4 +738,190 @@ class FlintSparkPPLJoinITSuite case j @ Join(_, _, Inner, _, JoinHint.NONE) => j }.size == 1) } + + test("test inner join with relation subquery") { + val frame = sql(s""" + | source = $testTable1 + | | where country = 'USA' OR country = 'England' + | | inner join left=a, right=b + | ON a.name = b.name + | [ + | source = $testTable2 + | | where salary > 0 + | | fields name, country, salary + | | sort salary + | | head 3 + | ] + | | stats avg(salary) by span(age, 10) as age_span, b.country + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row(70000.0, "USA", 30), Row(100000.0, "England", 70)) + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val filterExpr = Or( + EqualTo(UnresolvedAttribute("country"), Literal("USA")), + EqualTo(UnresolvedAttribute("country"), Literal("England"))) + val plan1 = SubqueryAlias("a", Filter(filterExpr, table1)) + val rightSubquery = + GlobalLimit( + Literal(3), + LocalLimit( + Literal(3), + Sort( + Seq(SortOrder(UnresolvedAttribute("salary"), Ascending)), + global = true, + Project( + Seq( + UnresolvedAttribute("name"), + UnresolvedAttribute("country"), + UnresolvedAttribute("salary")), + Filter(GreaterThan(UnresolvedAttribute("salary"), Literal(0)), table2))))) + val plan2 = SubqueryAlias("b", rightSubquery) + + val joinCondition = EqualTo(UnresolvedAttribute("a.name"), UnresolvedAttribute("b.name")) + val joinPlan = Join(plan1, plan2, Inner, Some(joinCondition), JoinHint.NONE) + + val salaryField = UnresolvedAttribute("salary") + val countryField = UnresolvedAttribute("b.country") + val countryAlias = Alias(countryField, "b.country")() + val star = Seq(UnresolvedStar(None)) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(salaryField), isDistinct = false), "avg(salary)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() 
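+ // Expected aggregation over the join result: avg(salary) grouped by the 10-year age span and b.country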
+ val aggregatePlan = + Aggregate(Seq(countryAlias, span), Seq(aggregateExpressions, countryAlias, span), joinPlan) + + val expectedPlan = Project(star, aggregatePlan) + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test left outer join with relation subquery") { + val frame = sql(s""" + | source = $testTable1 + | | where country = 'USA' OR country = 'England' + | | left join left=a, right=b + | ON a.name = b.name + | [ + | source = $testTable2 + | | where salary > 0 + | | fields name, country, salary + | | sort salary + | | head 3 + | ] + | | stats avg(salary) by span(age, 10) as age_span, b.country + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row(70000.0, "USA", 30), Row(100000.0, "England", 70), Row(null, null, 40)) + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val filterExpr = Or( + EqualTo(UnresolvedAttribute("country"), Literal("USA")), + EqualTo(UnresolvedAttribute("country"), Literal("England"))) + val plan1 = SubqueryAlias("a", Filter(filterExpr, table1)) + val rightSubquery = + GlobalLimit( + Literal(3), + LocalLimit( + Literal(3), + Sort( + Seq(SortOrder(UnresolvedAttribute("salary"), Ascending)), + global = true, + Project( + Seq( + UnresolvedAttribute("name"), + UnresolvedAttribute("country"), + UnresolvedAttribute("salary")), + Filter(GreaterThan(UnresolvedAttribute("salary"), Literal(0)), table2))))) + val plan2 = SubqueryAlias("b", rightSubquery) + + val joinCondition = EqualTo(UnresolvedAttribute("a.name"), UnresolvedAttribute("b.name")) + val joinPlan = Join(plan1, plan2, LeftOuter, Some(joinCondition), JoinHint.NONE) + + val salaryField = UnresolvedAttribute("salary") + val countryField = UnresolvedAttribute("b.country") + val countryAlias = Alias(countryField, "b.country")() + val star = Seq(UnresolvedStar(None)) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(salaryField), isDistinct = false), "avg(salary)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = + Aggregate(Seq(countryAlias, span), Seq(aggregateExpressions, countryAlias, span), joinPlan) + + val expectedPlan = Project(star, aggregatePlan) + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with relation subquery") { + val frame = sql(s""" + | source = $testTable1 + | | where country = 'Canada' OR country = 'England' + | | inner join left=a, right=b + | ON a.name = b.name AND a.year = 2023 AND a.month = 4 AND b.year = 2023 AND b.month = 4 + | [ + | source = $testTable2 + | ] + | | eval a_name = a.name + | | eval a_country = a.country + | | eval b_country = b.country + | | fields a_name, age, state, a_country, occupation, b_country, salary + | | left join left=a, right=b + | ON a.a_name = b.name + | [ + | source = $testTable3 + | ] + | | eval aa_country = a.a_country + | | eval ab_country = a.b_country + | | eval bb_country = b.country + | | fields a_name, age, state, aa_country, occupation, ab_country, salary, bb_country, hobby, language + | | cross 
join left=a, right=b + | [ + | source = $testTable2 + | ] + | | eval new_country = a.aa_country + | | eval new_salary = b.salary + | | stats avg(new_salary) as avg_salary by span(age, 5) as age_span, state + | | left semi join left=a, right=b + | ON a.state = b.state + | [ + | source = $testTable1 + | ] + | | eval new_avg_salary = floor(avg_salary) + | | fields state, age_span, new_avg_salary + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row("Quebec", 20, 83333), Row("Ontario", 25, 83333)) + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + assert(frame.queryExecution.optimizedPlan.collect { + case j @ Join(_, _, Cross, None, JoinHint.NONE) => j + }.size == 1) + assert(frame.queryExecution.optimizedPlan.collect { + case j @ Join(_, _, LeftOuter, _, JoinHint.NONE) => j + }.size == 1) + assert(frame.queryExecution.optimizedPlan.collect { + case j @ Join(_, _, Inner, _, JoinHint.NONE) => j + }.size == 1) + assert(frame.queryExecution.analyzed.collect { case s: SubqueryAlias => + s + }.size == 13) + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJsonFunctionITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJsonFunctionITSuite.scala new file mode 100644 index 000000000..7cc0a221d --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJsonFunctionITSuite.scala @@ -0,0 +1,386 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Literal, Not} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLJsonFunctionITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + + private val validJson1 = "{\"account_number\":1,\"balance\":39225,\"age\":32,\"gender\":\"M\"}" + private val validJson2 = "{\"f1\":\"abc\",\"f2\":{\"f3\":\"a\",\"f4\":\"b\"}}" + private val validJson3 = "[1,2,3,{\"f1\":1,\"f2\":[5,6]},4]" + private val validJson4 = "[]" + private val validJson5 = + "{\"teacher\":\"Alice\",\"student\":[{\"name\":\"Bob\",\"rank\":1},{\"name\":\"Charlie\",\"rank\":2}]}" + private val validJson6 = "[1,2,3]" + private val invalidJson1 = "[1,2" + private val invalidJson2 = "[invalid json]" + private val invalidJson3 = "{\"invalid\": \"json\"" + private val invalidJson4 = "invalid json" + + override def beforeAll(): Unit = { + super.beforeAll() + // Create test table + createNullableJsonContentTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test json() function: valid JSON") { + Seq(validJson1, validJson2, validJson3, validJson4, validJson5).foreach { jsonStr => + val frame = sql(s""" + | source = $testTable + | | eval result = json('$jsonStr') | head 1 | fields result + | """.stripMargin) + 
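// For each valid JSON string, json() should return the original string unchanged +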
assertSameRows(Seq(Row(jsonStr)), frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val jsonFunc = Alias( + UnresolvedFunction( + "get_json_object", + Seq(Literal(jsonStr), Literal("$")), + isDistinct = false), + "result")() + val eval = Project(Seq(UnresolvedStar(None), jsonFunc), table) + val limit = GlobalLimit(Literal(1), LocalLimit(Literal(1), eval)) + val expectedPlan = Project(Seq(UnresolvedAttribute("result")), limit) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + } + + test("test json() function: invalid JSON") { + Seq(invalidJson1, invalidJson2, invalidJson3, invalidJson4).foreach { jsonStr => + val frame = sql(s""" + | source = $testTable + | | eval result = json('$jsonStr') | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(null)), frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val jsonFunc = Alias( + UnresolvedFunction( + "get_json_object", + Seq(Literal(jsonStr), Literal("$")), + isDistinct = false), + "result")() + val eval = Project(Seq(UnresolvedStar(None), jsonFunc), table) + val limit = GlobalLimit(Literal(1), LocalLimit(Literal(1), eval)) + val expectedPlan = Project(Seq(UnresolvedAttribute("result")), limit) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + } + + test("test json() function on field") { + val frame = sql(s""" + | source = $testTable + | | where isValid = true | eval result = json(jString) | fields result + | """.stripMargin) + assertSameRows( + Seq(validJson1, validJson2, validJson3, validJson4, validJson5, validJson6).map( + Row.apply(_)), + frame) + + val frame2 = sql(s""" + | source = $testTable + | | where isValid = false | eval result = json(jString) | fields result + | """.stripMargin) + assertSameRows(Seq(Row(null), Row(null), Row(null), Row(null), Row(null)), frame2) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val jsonFunc = Alias( + UnresolvedFunction( + "get_json_object", + Seq(UnresolvedAttribute("jString"), Literal("$")), + isDistinct = false), + "result")() + val eval = Project( + Seq(UnresolvedStar(None), jsonFunc), + Filter(EqualTo(UnresolvedAttribute("isValid"), Literal(true)), table)) + val expectedPlan = Project(Seq(UnresolvedAttribute("result")), eval) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test json_array()") { + // test string array + var frame = sql(s""" + | source = $testTable | eval result = json_array('this', 'is', 'a', 'string', 'array') | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(Seq("this", "is", "a", "string", "array").toArray)), frame) + + // test empty array + frame = sql(s""" + | source = $testTable | eval result = json_array() | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(Array.empty)), frame) + + // test number array + frame = sql(s""" + | source = $testTable | eval result = json_array(1, 2, 0, -1, 1.1, -0.11) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(Seq(1.0, 2.0, 0.0, -1.0, 1.1, -0.11).toArray)), frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val jsonFunc = Alias( + UnresolvedFunction( + "array", + 
Seq(Literal(1), Literal(2), Literal(0), Literal(-1), Literal(1.1), Literal(-0.11)), + isDistinct = false), + "result")() + val eval = Project(Seq(UnresolvedStar(None), jsonFunc), table) + val limit = GlobalLimit(Literal(1), LocalLimit(Literal(1), eval)) + val expectedPlan = Project(Seq(UnresolvedAttribute("result")), limit) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + + // item in json_array should all be the same type + val ex = intercept[AnalysisException](sql(s""" + | source = $testTable | eval result = json_array('this', 'is', 1.1, -0.11, true, false) | head 1 | fields result + | """.stripMargin)) + assert(ex.getMessage().contains("should all be the same type")) + } + + test("test json_array() with json()") { + val frame = sql(s""" + | source = $testTable | eval result = json(json_array(1,2,0,-1,1.1,-0.11)) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row("""[1.0,2.0,0.0,-1.0,1.1,-0.11]""")), frame) + } + + test("test json_array_length()") { + var frame = sql(s""" + | source = $testTable | eval result = json_array_length(json_array('this', 'is', 'a', 'string', 'array')) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(5)), frame) + + frame = sql(s""" + | source = $testTable | eval result = json_array_length(json_array(1, 2, 0, -1, 1.1, -0.11)) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(6)), frame) + + frame = sql(s""" + | source = $testTable | eval result = json_array_length(json_array()) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(0)), frame) + + frame = sql(s""" + | source = $testTable | eval result = json_array_length('[]') | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(0)), frame) + frame = sql(s""" + | source = $testTable | eval result = json_array_length('[1,2,3,4]') | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(4)), frame) + frame = sql(s""" + | source = $testTable | eval result = json_array_length('[1,2,3,{"f1":1,"f2":[5,6]},4]') | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(5)), frame) + frame = sql(s""" + | source = $testTable | eval result = json_array_length('{\"key\": 1}') | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(null)), frame) + frame = sql(s""" + | source = $testTable | eval result = json_array_length('[1,2') | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(null)), frame) + } + + test("test json_object()") { + // test value is a string + var frame = sql(s""" + | source = $testTable| eval result = json(json_object('key', 'string_value')) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row("""{"key":"string_value"}""")), frame) + + // test value is a number + frame = sql(s""" + | source = $testTable| eval result = json(json_object('key', 123.45)) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row("""{"key":123.45}""")), frame) + + // test value is a boolean + frame = sql(s""" + | source = $testTable| eval result = json(json_object('key', true)) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row("""{"key":true}""")), frame) + + frame = sql(s""" + | source = $testTable| eval result = json(json_object("a", 1, "b", 2, "c", 3)) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row("""{"a":1,"b":2,"c":3}""")), frame) + } + + test("test json_object() and json_array()") { + // test value is an empty array + var frame = sql(s""" + | source = $testTable| eval result 
= json(json_object('key', array())) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row("""{"key":[]}""")), frame) + + // test value is an array + frame = sql(s""" + | source = $testTable| eval result = json(json_object('key', array(1, 2, 3))) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row("""{"key":[1,2,3]}""")), frame) + + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val jsonFunc = Alias( + UnresolvedFunction( + "to_json", + Seq( + UnresolvedFunction( + "named_struct", + Seq( + Literal("key"), + UnresolvedFunction( + "array", + Seq(Literal(1), Literal(2), Literal(3)), + isDistinct = false)), + isDistinct = false)), + isDistinct = false), + "result")() + var expectedPlan = Project( + Seq(UnresolvedAttribute("result")), + GlobalLimit( + Literal(1), + LocalLimit(Literal(1), Project(Seq(UnresolvedStar(None), jsonFunc), table)))) + comparePlans(frame.queryExecution.logical, expectedPlan, checkAnalysis = false) + } + + test("test json_object() nested") { + val frame = sql(s""" + | source = $testTable | eval result = json(json_object('outer', json_object('inner', 123.45))) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row("""{"outer":{"inner":123.45}}""")), frame) + } + + test("test json_object(), json_array() and json()") { + val frame = sql(s""" + | source = $testTable | eval result = json(json_object("array", json_array(1,2,0,-1,1.1,-0.11))) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row("""{"array":[1.0,2.0,0.0,-1.0,1.1,-0.11]}""")), frame) + } + + test("test json_valid()") { + val frame = sql(s""" + | source = $testTable + | | where json_valid(jString) | fields jString + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Seq(validJson1, validJson2, validJson3, validJson4, validJson5, validJson6) + .map(Row.apply(_)) + .toArray + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val frame2 = sql(s""" + | source = $testTable + | | where not json_valid(jString) | fields jString + | """.stripMargin) + val results2: Array[Row] = frame2.collect() + val expectedResults2: Array[Row] = + Seq(invalidJson1, invalidJson2, invalidJson3, invalidJson4, null).map(Row.apply(_)).toArray + assert(results2.sameElements(expectedResults2)) + + val logicalPlan: LogicalPlan = frame2.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val jsonFunc = + UnresolvedFunction( + "isnotnull", + Seq( + UnresolvedFunction( + "get_json_object", + Seq(UnresolvedAttribute("jString"), Literal("$")), + isDistinct = false)), + isDistinct = false) + val where = Filter(Not(jsonFunc), table) + val expectedPlan = Project(Seq(UnresolvedAttribute("jString")), where) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test json_keys()") { + val frame = sql(s""" + | source = $testTable + | | where isValid = true + | | eval result = json_keys(json(jString)) | fields result + | """.stripMargin) + val expectedRows = Seq( + Row(Array("account_number", "balance", "age", "gender")), + Row(Array("f1", "f2")), + Row(null), + Row(null), + Row(Array("teacher", "student")), + Row(null)) + assertSameRows(expectedRows, frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val jsonFunc = 
Alias( + UnresolvedFunction( + "json_object_keys", + Seq( + UnresolvedFunction( + "get_json_object", + Seq(UnresolvedAttribute("jString"), Literal("$")), + isDistinct = false)), + isDistinct = false), + "result")() + val eval = Project( + Seq(UnresolvedStar(None), jsonFunc), + Filter(EqualTo(UnresolvedAttribute("isValid"), Literal(true)), table)) + val expectedPlan = Project(Seq(UnresolvedAttribute("result")), eval) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test json_extract()") { + val frame = sql(""" + | source = spark_catalog.default.flint_ppl_test | where id = 5 + | | eval root = json_extract(jString, '$') + | | eval teacher = json_extract(jString, '$.teacher') + | | eval students = json_extract(jString, '$.student') + | | eval students_* = json_extract(jString, '$.student[*]') + | | eval student_0 = json_extract(jString, '$.student[0]') + | | eval student_names = json_extract(jString, '$.student[*].name') + | | eval student_1_name = json_extract(jString, '$.student[1].name') + | | eval student_non_exist_key = json_extract(jString, '$.student[0].non_exist_key') + | | eval student_non_exist = json_extract(jString, '$.student[10]') + | | fields root, teacher, students, students_*, student_0, student_names, student_1_name, student_non_exist_key, student_non_exist + | """.stripMargin) + val expectedSeq = Seq( + Row( + """{"teacher":"Alice","student":[{"name":"Bob","rank":1},{"name":"Charlie","rank":2}]}""", + "Alice", + """[{"name":"Bob","rank":1},{"name":"Charlie","rank":2}]""", + """[{"name":"Bob","rank":1},{"name":"Charlie","rank":2}]""", + """{"name":"Bob","rank":1}""", + """["Bob","Charlie"]""", + "Charlie", + null, + null)) + assertSameRows(expectedSeq, frame) + } +} diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLLambdaFunctionITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLLambdaFunctionITSuite.scala new file mode 100644 index 000000000..f86502521 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLLambdaFunctionITSuite.scala @@ -0,0 +1,132 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.functions.{col, to_json} +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLLambdaFunctionITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + // Create test table + createNullableJsonContentTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test forall()") { + val frame = sql(s""" + | source = $testTable | eval array = json_array(1,2,0,-1,1.1,-0.11), result = forall(array, x -> x > 0) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(false)), frame) + + val frame2 = sql(s""" + | source = $testTable | eval array = json_array(1,2,0,-1,1.1,-0.11), result = forall(array, x -> x > -10) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(true)), frame2) + + val frame3 = sql(s""" + | source = $testTable | eval array = 
json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = forall(array, x -> x.a > 0) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(false)), frame3) + + val frame4 = sql(s""" + | source = $testTable | eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = exists(array, x -> x.b < 0) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(true)), frame4) + } + + test("test exists()") { + val frame = sql(s""" + | source = $testTable | eval array = json_array(1,2,0,-1,1.1,-0.11), result = exists(array, x -> x > 0) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(true)), frame) + + val frame2 = sql(s""" + | source = $testTable | eval array = json_array(1,2,0,-1,1.1,-0.11), result = exists(array, x -> x > 10) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(false)), frame2) + + val frame3 = sql(s""" + | source = $testTable | eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = exists(array, x -> x.a > 0) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(true)), frame3) + + val frame4 = sql(s""" + | source = $testTable | eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = exists(array, x -> x.b > 0) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(false)), frame4) + + } + + test("test filter()") { + val frame = sql(s""" + | source = $testTable| eval array = json_array(1,2,0,-1,1.1,-0.11), result = filter(array, x -> x > 0) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(Seq(1, 2, 1.1))), frame) + + val frame2 = sql(s""" + | source = $testTable| eval array = json_array(1,2,0,-1,1.1,-0.11), result = filter(array, x -> x > 10) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(Seq())), frame2) + + val frame3 = sql(s""" + | source = $testTable| eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = filter(array, x -> x.a > 0) | head 1 | fields result + | """.stripMargin) + + assertSameRows(Seq(Row("""[{"a":1,"b":-1}]""")), frame3.select(to_json(col("result")))) + + val frame4 = sql(s""" + | source = $testTable| eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = filter(array, x -> x.b > 0) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row("""[]""")), frame4.select(to_json(col("result")))) + } + + test("test transform()") { + val frame = sql(s""" + | source = $testTable | eval array = json_array(1,2,3), result = transform(array, x -> x + 1) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(Seq(2, 3, 4))), frame) + + val frame2 = sql(s""" + | source = $testTable | eval array = json_array(1,2,3), result = transform(array, (x, y) -> x + y) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(Seq(1, 3, 5))), frame2) + } + + test("test reduce()") { + val frame = sql(s""" + | source = $testTable | eval array = json_array(1,2,3), result = reduce(array, 0, (acc, x) -> acc + x) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(6)), frame) + + val frame2 = sql(s""" + | source = $testTable | eval array = json_array(1,2,3), result = reduce(array, 1, (acc, x) -> acc + x) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(7)), frame2) + + val frame3 = sql(s""" + | source = $testTable | eval array = json_array(1,2,3), result = reduce(array, 0, (acc, x) -> acc + 
x, acc -> acc * 10) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(60)), frame3) + } +} diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLScalarSubqueryITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLScalarSubqueryITSuite.scala index 654add8d8..24b4d77e6 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLScalarSubqueryITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLScalarSubqueryITSuite.scala @@ -132,12 +132,12 @@ class FlintSparkPPLScalarSubqueryITSuite test("test uncorrelated scalar subquery in select and where") { val frame = sql(s""" | source = $outerTable - | | eval count_dept = [ - | source = $innerTable | stats count(department) - | ] | | where id > [ | source = $innerTable | stats count(department) | ] + 999 + | | eval count_dept = [ + | source = $innerTable | stats count(department) + | ] | | fields name, count_dept | """.stripMargin) val results: Array[Row] = frame.collect() @@ -160,13 +160,50 @@ class FlintSparkPPLScalarSubqueryITSuite val countScalarSubqueryExpr = ScalarSubquery(countAggPlan) val plusScalarSubquery = UnresolvedFunction(Seq("+"), Seq(countScalarSubqueryExpr, Literal(999)), isDistinct = false) + val filter = Filter(GreaterThan(UnresolvedAttribute("id"), plusScalarSubquery), outer) + val evalProjectList = + Seq(UnresolvedStar(None), Alias(countScalarSubqueryExpr, "count_dept")()) + val evalProject = Project(evalProjectList, filter) + val expectedPlan = + Project(Seq(UnresolvedAttribute("name"), UnresolvedAttribute("count_dept")), evalProject) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test uncorrelated scalar subquery in select and from with filter") { + val frame = sql(s""" + | source = $outerTable id > [ source = $innerTable | stats count(department) ] + 999 + | | eval count_dept = [ + | source = $innerTable | stats count(department) + | ] + | | fields name, count_dept + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row("Jane", 5), Row("Tommy", 5)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val countAgg = Seq( + Alias( + UnresolvedFunction( + Seq("COUNT"), + Seq(UnresolvedAttribute("department")), + isDistinct = false), + "count(department)")()) + val countAggPlan = Aggregate(Seq(), countAgg, inner) + val countScalarSubqueryExpr = ScalarSubquery(countAggPlan) + val plusScalarSubquery = + UnresolvedFunction(Seq("+"), Seq(countScalarSubqueryExpr, Literal(999)), isDistinct = false) + val filter = Filter(GreaterThan(UnresolvedAttribute("id"), plusScalarSubquery), outer) val evalProjectList = Seq(UnresolvedStar(None), Alias(countScalarSubqueryExpr, "count_dept")()) - val evalProject = Project(evalProjectList, outer) - val filter = Filter(GreaterThan(UnresolvedAttribute("id"), plusScalarSubquery), evalProject) + val evalProject = Project(evalProjectList, filter) val expectedPlan = - Project(Seq(UnresolvedAttribute("name"), UnresolvedAttribute("count_dept")), filter) + Project(Seq(UnresolvedAttribute("name"), 
UnresolvedAttribute("count_dept")), evalProject) comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -302,6 +339,39 @@ class FlintSparkPPLScalarSubqueryITSuite comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } + test("test correlated scalar subquery in from with filter") { + val frame = sql(s""" + | source = $outerTable id = [ source = $innerTable | where id = uid | stats max(uid) ] + | | fields id, name + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(1000, "Jake"), + Row(1002, "John"), + Row(1003, "David"), + Row(1005, "Jane"), + Row(1006, "Tommy")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("MAX"), Seq(UnresolvedAttribute("uid")), isDistinct = false), + "max(uid)")()) + val innerFilter = + Filter(EqualTo(UnresolvedAttribute("id"), UnresolvedAttribute("uid")), inner) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, innerFilter) + val scalarSubqueryExpr = ScalarSubquery(aggregatePlan) + val outerFilter = Filter(EqualTo(UnresolvedAttribute("id"), scalarSubqueryExpr), outer) + val expectedPlan = + Project(Seq(UnresolvedAttribute("id"), UnresolvedAttribute("name")), outerFilter) + + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } test("test disjunctive correlated scalar subquery") { val frame = sql(s""" diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala new file mode 100644 index 000000000..bc4463537 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala @@ -0,0 +1,247 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CaseWhen, CurrentRow, Descending, LessThan, Literal, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLTrendlineITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createPartitionedStateCountryTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test trendline sma command without fields command and without alias") { + val frame = sql(s""" + | source = $testTable | sort - age | trendline sma(2, age) + | """.stripMargin) + + assert( + 
frame.columns.sameElements( + Array("name", "age", "state", "country", "year", "month", "age_trendline"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jake", 70, "California", "USA", 2023, 4, null), + Row("Hello", 30, "New York", "USA", 2023, 4, 50.0), + Row("John", 25, "Ontario", "Canada", 2023, 4, 27.5), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 22.5)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val ageField = UnresolvedAttribute("age") + val sort = Sort(Seq(SortOrder(ageField, Descending)), global = true, table) + val countWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val smaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val caseWhen = CaseWhen(Seq((LessThan(countWindow, Literal(2)), Literal(null))), smaWindow) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(caseWhen, "age_trendline")()) + val expectedPlan = Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sort)) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline sma command with fields command") { + val frame = sql(s""" + | source = $testTable | trendline sort - age sma(3, age) as age_sma | fields name, age, age_sma + | """.stripMargin) + + assert(frame.columns.sameElements(Array("name", "age", "age_sma"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jake", 70, null), + Row("Hello", 30, null), + Row("John", 25, 41.666666666666664), + Row("Jane", 20, 25)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val nameField = UnresolvedAttribute("name") + val ageField = UnresolvedAttribute("age") + val ageSmaField = UnresolvedAttribute("age_sma") + val sort = Sort(Seq(SortOrder(ageField, Descending)), global = true, table) + val countWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val smaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val caseWhen = CaseWhen(Seq((LessThan(countWindow, Literal(3)), Literal(null))), smaWindow) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(caseWhen, "age_sma")()) + val expectedPlan = + Project(Seq(nameField, ageField, ageSmaField), Project(trendlineProjectList, sort)) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test multiple trendline sma commands") { + val 
frame = sql(s""" + | source = $testTable | trendline sort + age sma(2, age) as two_points_sma sma(3, age) as three_points_sma | fields name, age, two_points_sma, three_points_sma + | """.stripMargin) + + assert(frame.columns.sameElements(Array("name", "age", "two_points_sma", "three_points_sma"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jane", 20, null, null), + Row("John", 25, 22.5, null), + Row("Hello", 30, 27.5, 25.0), + Row("Jake", 70, 50.0, 41.666666666666664)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val nameField = UnresolvedAttribute("name") + val ageField = UnresolvedAttribute("age") + val ageTwoPointsSmaField = UnresolvedAttribute("two_points_sma") + val ageThreePointsSmaField = UnresolvedAttribute("three_points_sma") + val sort = Sort(Seq(SortOrder(ageField, Ascending)), global = true, table) + val twoPointsCountWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val twoPointsSmaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val threePointsCountWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val threePointsSmaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val twoPointsCaseWhen = CaseWhen( + Seq((LessThan(twoPointsCountWindow, Literal(2)), Literal(null))), + twoPointsSmaWindow) + val threePointsCaseWhen = CaseWhen( + Seq((LessThan(threePointsCountWindow, Literal(3)), Literal(null))), + threePointsSmaWindow) + val trendlineProjectList = Seq( + UnresolvedStar(None), + Alias(twoPointsCaseWhen, "two_points_sma")(), + Alias(threePointsCaseWhen, "three_points_sma")()) + val expectedPlan = Project( + Seq(nameField, ageField, ageTwoPointsSmaField, ageThreePointsSmaField), + Project(trendlineProjectList, sort)) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline sma command on evaluated column") { + val frame = sql(s""" + | source = $testTable | eval doubled_age = age * 2 | trendline sort + age sma(2, doubled_age) as doubled_age_sma | fields name, doubled_age, doubled_age_sma + | """.stripMargin) + + assert(frame.columns.sameElements(Array("name", "doubled_age", "doubled_age_sma"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jane", 40, null), + Row("John", 50, 45.0), + Row("Hello", 60, 55.0), + Row("Jake", 140, 100.0)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = 
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val nameField = UnresolvedAttribute("name") + val ageField = UnresolvedAttribute("age") + val doubledAgeField = UnresolvedAttribute("doubled_age") + val doubledAgeSmaField = UnresolvedAttribute("doubled_age_sma") + val evalProject = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction("*", Seq(ageField, Literal(2)), isDistinct = false), + "doubled_age")()), + table) + val sort = Sort(Seq(SortOrder(ageField, Ascending)), global = true, evalProject) + val countWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val doubleAgeSmaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(doubledAgeField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val caseWhen = + CaseWhen(Seq((LessThan(countWindow, Literal(2)), Literal(null))), doubleAgeSmaWindow) + val trendlineProjectList = + Seq(UnresolvedStar(None), Alias(caseWhen, "doubled_age_sma")()) + val expectedPlan = Project( + Seq(nameField, doubledAgeField, doubledAgeSmaField), + Project(trendlineProjectList, sort)) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline sma command chaining") { + val frame = sql(s""" + | source = $testTable | eval age_1 = age, age_2 = age | trendline sort - age_1 sma(3, age_1) | trendline sort + age_2 sma(3, age_2) + | """.stripMargin) + + assert( + frame.columns.sameElements( + Array( + "name", + "age", + "state", + "country", + "year", + "month", + "age_1", + "age_2", + "age_1_trendline", + "age_2_trendline"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Hello", 30, "New York", "USA", 2023, 4, 30, 30, null, 25.0), + Row("Jake", 70, "California", "USA", 2023, 4, 70, 70, null, 41.666666666666664), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 20, 20, 25.0, null), + Row("John", 25, "Ontario", "Canada", 2023, 4, 25, 25, 41.666666666666664, null)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + } +} diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index 5d980f167..d052101c7 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -18,6 +18,7 @@ WHERE: 'WHERE'; FIELDS: 'FIELDS'; RENAME: 'RENAME'; STATS: 'STATS'; +EVENTSTATS: 'EVENTSTATS'; DEDUP: 'DEDUP'; SORT: 'SORT'; EVAL: 'EVAL'; @@ -36,6 +37,8 @@ KMEANS: 'KMEANS'; AD: 'AD'; ML: 'ML'; FILLNULL: 'FILLNULL'; +FLATTEN: 'FLATTEN'; +TRENDLINE: 'TRENDLINE'; //Native JOIN KEYWORDS JOIN: 'JOIN'; @@ -73,19 +76,24 @@ INDEX: 'INDEX'; D: 'D'; DESC: 'DESC'; DATASOURCES: 'DATASOURCES'; -VALUE: 'VALUE'; USING: 'USING'; WITH: 'WITH'; -// CLAUSE KEYWORDS -SORTBY: 'SORTBY'; - // FIELD KEYWORDS AUTO: 'AUTO'; STR: 'STR'; IP: 'IP'; NUM: 'NUM'; + +// FIELDSUMMARY keywords +FIELDSUMMARY: 'FIELDSUMMARY'; +INCLUDEFIELDS: 'INCLUDEFIELDS'; +NULLS: 'NULLS'; + +//TRENDLINE KEYWORDS +SMA: 'SMA'; + // ARGUMENT KEYWORDS KEEPEMPTY: 'KEEPEMPTY'; CONSECUTIVE: 'CONSECUTIVE'; @@ -195,6 +203,7 @@ RT_SQR_PRTHS: ']'; SINGLE_QUOTE: '\''; DOUBLE_QUOTE: '"'; BACKTICK: '`'; +ARROW: '->'; // 
Operators. Bit @@ -280,6 +289,11 @@ RADIANS: 'RADIANS'; SIN: 'SIN'; TAN: 'TAN'; +// CRYPTOGRAPHIC FUNCTIONS +MD5: 'MD5'; +SHA1: 'SHA1'; +SHA2: 'SHA2'; + // DATE AND TIME FUNCTIONS ADDDATE: 'ADDDATE'; ADDTIME: 'ADDTIME'; @@ -287,6 +301,7 @@ CURDATE: 'CURDATE'; CURRENT_DATE: 'CURRENT_DATE'; CURRENT_TIME: 'CURRENT_TIME'; CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP'; +CURRENT_TIMEZONE: 'CURRENT_TIMEZONE'; CURTIME: 'CURTIME'; DATE: 'DATE'; DATEDIFF: 'DATEDIFF'; @@ -308,6 +323,7 @@ LAST_DAY: 'LAST_DAY'; LOCALTIME: 'LOCALTIME'; LOCALTIMESTAMP: 'LOCALTIMESTAMP'; MAKEDATE: 'MAKEDATE'; +MAKE_DATE: 'MAKE_DATE'; MAKETIME: 'MAKETIME'; MONTHNAME: 'MONTHNAME'; NOW: 'NOW'; @@ -357,11 +373,41 @@ CAST: 'CAST'; ISEMPTY: 'ISEMPTY'; ISBLANK: 'ISBLANK'; +// JSON TEXT FUNCTIONS +JSON: 'JSON'; +JSON_OBJECT: 'JSON_OBJECT'; +JSON_ARRAY: 'JSON_ARRAY'; +JSON_ARRAY_LENGTH: 'JSON_ARRAY_LENGTH'; +JSON_EXTRACT: 'JSON_EXTRACT'; +JSON_KEYS: 'JSON_KEYS'; +JSON_VALID: 'JSON_VALID'; +//JSON_APPEND: 'JSON_APPEND'; +//JSON_DELETE: 'JSON_DELETE'; +//JSON_EXTEND: 'JSON_EXTEND'; +//JSON_SET: 'JSON_SET'; +//JSON_ARRAY_ALL_MATCH: 'JSON_ARRAY_ALL_MATCH'; +//JSON_ARRAY_ANY_MATCH: 'JSON_ARRAY_ANY_MATCH'; +//JSON_ARRAY_FILTER: 'JSON_ARRAY_FILTER'; +//JSON_ARRAY_MAP: 'JSON_ARRAY_MAP'; +//JSON_ARRAY_REDUCE: 'JSON_ARRAY_REDUCE'; + +// COLLECTION FUNCTIONS +ARRAY: 'ARRAY'; + +// LAMBDA FUNCTIONS +//EXISTS: 'EXISTS'; +FORALL: 'FORALL'; +FILTER: 'FILTER'; +TRANSFORM: 'TRANSFORM'; +REDUCE: 'REDUCE'; + // BOOL FUNCTIONS LIKE: 'LIKE'; ISNULL: 'ISNULL'; ISNOTNULL: 'ISNOTNULL'; ISPRESENT: 'ISPRESENT'; +BETWEEN: 'BETWEEN'; +CIDRMATCH: 'CIDRMATCH'; // FLOWCONTROL FUNCTIONS IFNULL: 'IFNULL'; @@ -443,5 +489,7 @@ SQUOTA_STRING: '\'' ('\\'. | '\'\'' | ~('\'' | '\\'))* '\'' BQUOTA_STRING: '`' ( '\\'. | '``' | ~('`'|'\\'))* '`'; fragment DEC_DIGIT: [0-9]; +LINE_COMMENT: '//' ('\\\n' | ~[\r\n])* '\r'? '\n'? -> channel(HIDDEN); +BLOCK_COMMENT: '/*' .*? '*/' -> channel(HIDDEN); ERROR_RECOGNITION: . -> channel(ERRORCHANNEL); diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 8b84f0348..9c82b27a7 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -52,6 +52,40 @@ commands | lookupCommand | renameCommand | fillnullCommand + | fieldsummaryCommand + | flattenCommand + | trendlineCommand + ; + +commandName + : SEARCH + | DESCRIBE + | SHOW + | AD + | ML + | KMEANS + | WHERE + | CORRELATE + | JOIN + | FIELDS + | STATS + | EVENTSTATS + | DEDUP + | EXPLAIN + | SORT + | HEAD + | TOP + | RARE + | EVAL + | GROK + | PARSE + | PATTERNS + | LOOKUP + | RENAME + | FILLNULL + | FIELDSUMMARY + | FLATTEN + | TRENDLINE ; searchCommand @@ -60,6 +94,15 @@ searchCommand | (SEARCH)? logicalExpression fromClause # searchFilterFrom ; +fieldsummaryCommand + : FIELDSUMMARY (fieldsummaryParameter)* + ; + +fieldsummaryParameter + : INCLUDEFIELDS EQUAL fieldList # fieldsummaryIncludeFields + | NULLS EQUAL booleanLiteral # fieldsummaryNulls + ; + describeCommand : DESCRIBE tableSourceClause ; @@ -115,7 +158,7 @@ renameCommand ; statsCommand - : STATS (PARTITIONS EQUAL partitions = integerLiteral)? (ALLNUM EQUAL allnum = booleanLiteral)? (DELIM EQUAL delim = stringLiteral)? statsAggTerm (COMMA statsAggTerm)* (statsByClause)? (DEDUP_SPLITVALUES EQUAL dedupsplit = booleanLiteral)? + : (STATS | EVENTSTATS) (PARTITIONS EQUAL partitions = integerLiteral)? (ALLNUM EQUAL allnum = booleanLiteral)? 
(DELIM EQUAL delim = stringLiteral)? statsAggTerm (COMMA statsAggTerm)* (statsByClause)? (DEDUP_SPLITVALUES EQUAL dedupsplit = booleanLiteral)? ; dedupCommand @@ -207,6 +250,21 @@ fillnullCommand : expression ; +flattenCommand + : FLATTEN fieldExpression + ; + +trendlineCommand + : TRENDLINE (SORT sortField)? trendlineClause (trendlineClause)* + ; + +trendlineClause + : trendlineType LT_PRTHS numberOfDataPoints = integerLiteral COMMA field = fieldExpression RT_PRTHS (AS alias = qualifiedName)? + ; + +trendlineType + : SMA + ; kmeansCommand : KMEANS (kmeansParameter)* @@ -247,17 +305,27 @@ mlArg // clauses fromClause - : SOURCE EQUAL tableSourceClause - | INDEX EQUAL tableSourceClause + : SOURCE EQUAL tableOrSubqueryClause + | INDEX EQUAL tableOrSubqueryClause + ; + +tableOrSubqueryClause + : LT_SQR_PRTHS subSearch RT_SQR_PRTHS (AS alias = qualifiedName)? + | tableSourceClause ; +// One tableSourceClause will generate one Relation node with/without one alias +// even if the relation contains more than one table sources. +// These table sources in one relation will be readed one by one in OpenSearch. +// But it may have different behaivours in different execution backends. +// For example, a Spark UnresovledRelation node only accepts one data source. tableSourceClause - : tableSource (COMMA tableSource)* + : tableSource (COMMA tableSource)* (AS alias = qualifiedName)? ; // join joinCommand - : (joinType) JOIN sideAlias joinHintList? joinCriteria? right = tableSource + : (joinType) JOIN sideAlias joinHintList? joinCriteria? right = tableOrSubqueryClause ; joinType @@ -279,13 +347,13 @@ joinCriteria ; joinHintList - : hintPair (COMMA? hintPair)* - ; + : hintPair (COMMA? hintPair)* + ; hintPair - : leftHintKey = LEFT_HINT DOT ID EQUAL leftHintValue = ident #leftHint - | rightHintKey = RIGHT_HINT DOT ID EQUAL rightHintValue = ident #rightHint - ; + : leftHintKey = LEFT_HINT DOT ID EQUAL leftHintValue = ident #leftHint + | rightHintKey = RIGHT_HINT DOT ID EQUAL rightHintValue = ident #rightHint + ; renameClasue : orignalField = wcFieldExpression AS renamedField = wcFieldExpression @@ -340,14 +408,6 @@ statsFunctionName | STDDEV_POP ; -takeAggFunction - : TAKE LT_PRTHS fieldExpression (COMMA size = integerLiteral)? RT_PRTHS - ; - -percentileAggFunction - : PERCENTILE LESS value = integerLiteral GREATER LT_PRTHS aggField = fieldExpression RT_PRTHS - ; - // expressions expression : logicalExpression @@ -365,7 +425,8 @@ logicalExpression comparisonExpression : left = valueExpression comparisonOperator right = valueExpression # compareExpr - | valueExpression IN valueList # inExpr + | valueExpression NOT? IN valueList # inExpr + | expr1 = functionArg NOT? BETWEEN expr2 = functionArg AND expr3 = functionArg # between ; valueExpressionList @@ -379,8 +440,11 @@ valueExpression | primaryExpression # valueExpressionDefault | positionFunction # positionFunctionCall | caseFunction # caseExpr + | timestampFunction # timestampFunctionCall | LT_PRTHS valueExpression RT_PRTHS # parentheticValueExpr | LT_SQR_PRTHS subSearch RT_SQR_PRTHS # scalarSubqueryExpr + | ident ARROW expression # lambda + | LT_PRTHS ident (COMMA ident)+ RT_PRTHS ARROW expression # lambda ; primaryExpression @@ -399,6 +463,7 @@ booleanExpression | isEmptyExpression # isEmptyExpr | valueExpressionList NOT? 
IN LT_SQR_PRTHS subSearch RT_SQR_PRTHS # inSubqueryExpr | EXISTS LT_SQR_PRTHS subSearch RT_SQR_PRTHS # existsSubqueryExpr + | cidrMatchFunctionCall # cidrFunctionCallExpr ; isEmptyExpression @@ -478,6 +543,10 @@ booleanFunctionCall : conditionFunctionBase LT_PRTHS functionArgs RT_PRTHS ; +cidrMatchFunctionCall + : CIDRMATCH LT_PRTHS ipAddress = functionArg COMMA cidrBlock = functionArg RT_PRTHS + ; + convertedDataType : typeName = DATE | typeName = TIME @@ -499,6 +568,10 @@ evalFunctionName | systemFunctionName | positionFunctionName | coalesceFunctionName + | cryptographicFunctionName + | jsonFunctionName + | collectionFunctionName + | lambdaFunctionName ; functionArgs @@ -614,6 +687,12 @@ trigonometricFunctionName | TAN ; +cryptographicFunctionName + : MD5 + | SHA1 + | SHA2 + ; + dateTimeFunctionName : ADDDATE | ADDTIME @@ -622,6 +701,7 @@ dateTimeFunctionName | CURRENT_DATE | CURRENT_TIME | CURRENT_TIMESTAMP + | CURRENT_TIMEZONE | CURTIME | DATE | DATEDIFF @@ -645,6 +725,7 @@ dateTimeFunctionName | LOCALTIME | LOCALTIMESTAMP | MAKEDATE + | MAKE_DATE | MAKETIME | MICROSECOND | MINUTE @@ -746,6 +827,7 @@ conditionFunctionBase | IFNULL | NULLIF | ISPRESENT + | JSON_VALID ; systemFunctionName @@ -774,6 +856,37 @@ textFunctionName | ISBLANK ; +jsonFunctionName + : JSON + | JSON_OBJECT + | JSON_ARRAY + | JSON_ARRAY_LENGTH + | JSON_EXTRACT + | JSON_KEYS + | JSON_VALID +// | JSON_APPEND +// | JSON_DELETE +// | JSON_EXTEND +// | JSON_SET +// | JSON_ARRAY_ALL_MATCH +// | JSON_ARRAY_ANY_MATCH +// | JSON_ARRAY_FILTER +// | JSON_ARRAY_MAP +// | JSON_ARRAY_REDUCE + ; + +collectionFunctionName + : ARRAY + ; + +lambdaFunctionName + : FORALL + | EXISTS + | FILTER + | TRANSFORM + | REDUCE + ; + positionFunctionName : POSITION ; @@ -817,6 +930,7 @@ literalValue | decimalLiteral | booleanLiteral | datetimeLiteral //#datetime + | intervalLiteral ; intervalLiteral @@ -946,47 +1060,41 @@ keywordsCanBeId | intervalUnit | dateTimeFunctionName | textFunctionName + | jsonFunctionName | mathematicalFunctionName | positionFunctionName - // commands - | SEARCH - | DESCRIBE - | SHOW - | FROM - | WHERE - | CORRELATE - | FIELDS - | RENAME - | STATS - | DEDUP - | SORT - | EVAL - | HEAD - | TOP - | RARE - | PARSE - | METHOD - | REGEX - | PUNCT - | GROK - | PATTERN - | PATTERNS - | NEW_FIELD - | KMEANS - | AD - | ML - | EXPLAIN + | cryptographicFunctionName + | singleFieldRelevanceFunctionName + | multiFieldRelevanceFunctionName + | commandName + | comparisonOperator + | explainMode + | correlationType // commands assist keywords + | IN | SOURCE | INDEX | DESC | DATASOURCES - // CLAUSEKEYWORDS - | SORTBY - // FIELDKEYWORDSAUTO + | AUTO | STR | IP | NUM + | FROM + | PATTERN + | NEW_FIELD + | SCOPE + | MAPPING + | WITH + | USING + | CAST + | GET_FORMAT + | EXTRACT + | INTERVAL + | PLUS + | MINUS + | INCLUDEFIELDS + | NULLS // ARGUMENT KEYWORDS | KEEPEMPTY | CONSECUTIVE @@ -1009,27 +1117,21 @@ keywordsCanBeId | TRAINING_DATA_SIZE | ANOMALY_SCORE_THRESHOLD // AGGREGATIONS - | AVG - | COUNT + | statsFunctionName | DISTINCT_COUNT + | PERCENTILE + | PERCENTILE_APPROX | ESTDC | ESTDC_ERROR - | MAX | MEAN | MEDIAN - | MIN | MODE | RANGE | STDEV | STDEVP - | SUM | SUMSQ | VAR_SAMP | VAR_POP - | STDDEV_SAMP - | STDDEV_POP - | PERCENTILE - | PERCENTILE_APPROX | TAKE | FIRST | LAST @@ -1057,4 +1159,7 @@ keywordsCanBeId | SEMI | ANTI | GEOIP + | BETWEEN + | CIDRMATCH + | trendlineType ; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java 
b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index 66303b2a6..e6722eae2 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -6,6 +6,7 @@ package org.opensearch.sql.ast; import org.opensearch.sql.ast.expression.*; +import org.opensearch.sql.ast.tree.FieldSummary; import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; import org.opensearch.sql.ast.expression.subquery.InSubquery; import org.opensearch.sql.ast.expression.subquery.ScalarSubquery; @@ -81,6 +82,10 @@ public T visitLookup(Lookup node, C context) { return visitChildren(node, context); } + public T visitTrendline(Trendline node, C context) { + return visitChildren(node, context); + } + public T visitCorrelation(Correlation node, C context) { return visitChildren(node, context); } @@ -149,6 +154,10 @@ public T visitFunction(Function node, C context) { return visitChildren(node, context); } + public T visitLambdaFunction(LambdaFunction node, C context) { + return visitChildren(node, context); + } + public T visitIsEmpty(IsEmpty node, C context) { return visitChildren(node, context); } @@ -179,6 +188,10 @@ public T visitField(Field node, C context) { return visitChildren(node, context); } + public T visitFieldList(FieldList node, C context) { + return visitChildren(node, context); + } + public T visitQualifiedName(QualifiedName node, C context) { return visitChildren(node, context); } @@ -269,9 +282,14 @@ public T visitExplain(Explain node, C context) { public T visitInSubquery(InSubquery node, C context) { return visitChildren(node, context); } + public T visitFillNull(FillNull fillNull, C context) { return visitChildren(fillNull, context); } + + public T visitFieldSummary(FieldSummary fieldSummary, C context) { + return visitChildren(fieldSummary, context); + } public T visitScalarSubquery(ScalarSubquery node, C context) { return visitChildren(node, context); @@ -282,4 +300,15 @@ public T visitExistsSubquery(ExistsSubquery node, C context) { } public T visitGeoIp(GeoIp node, C context) { return visitGeoip(node, context); } + + public T visitWindow(Window node, C context) { + return visitChildren(node, context); + } + public T visitCidr(Cidr node, C context) { + return visitChildren(node, context); + } + + public T visitFlatten(Flatten flatten, C context) { + return visitChildren(flatten, context); + } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Cidr.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Cidr.java new file mode 100644 index 000000000..fdbb3ef65 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Cidr.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import lombok.AllArgsConstructor; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; + +import java.util.Arrays; +import java.util.List; + +/** AST node that represents CIDR function. 
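+ * Illustrative usage (the field name is hypothetical, not taken from this patch): {@code where cidrmatch(client_ip, '192.168.0.0/24')} is true when the address falls inside the given block.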
*/ +@AllArgsConstructor +@Getter +@EqualsAndHashCode(callSuper = false) +@ToString +public class Cidr extends UnresolvedExpression { + private UnresolvedExpression ipAddress; + private UnresolvedExpression cidrBlock; + + @Override + public List getChild() { + return Arrays.asList(ipAddress, cidrBlock); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitCidr(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/FieldList.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/FieldList.java new file mode 100644 index 000000000..4f6ac5e14 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/FieldList.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import com.google.common.collect.ImmutableList; +import lombok.AllArgsConstructor; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; + +import java.util.List; + +/** Expression node that includes a list of fields nodes. */ +@Getter +@ToString +@EqualsAndHashCode(callSuper = false) +@AllArgsConstructor +public class FieldList extends UnresolvedExpression { + private final List fieldList; + + @Override + public List getChild() { + return ImmutableList.copyOf(fieldList); + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitFieldList(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/FieldsMapping.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/FieldsMapping.java index a53b1e130..6226a7f6b 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/FieldsMapping.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/FieldsMapping.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.sql.ast.expression; import lombok.Getter; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Interval.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Interval.java index fc09ec2f5..bf00b2106 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Interval.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Interval.java @@ -23,11 +23,6 @@ public class Interval extends UnresolvedExpression { private final UnresolvedExpression value; private final IntervalUnit unit; - public Interval(UnresolvedExpression value, String unit) { - this.value = value; - this.unit = IntervalUnit.of(unit); - } - @Override public List getChild() { return Collections.singletonList(value); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/IntervalUnit.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/IntervalUnit.java index a7e983473..6e1e0712c 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/IntervalUnit.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/IntervalUnit.java @@ -17,6 +17,7 @@ public enum IntervalUnit { UNKNOWN, MICROSECOND, + MILLISECOND, SECOND, MINUTE, HOUR, diff --git 
a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/IsEmpty.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/IsEmpty.java index 0374d1c90..1691992a6 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/IsEmpty.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/IsEmpty.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.sql.ast.expression; import com.google.common.collect.ImmutableList; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/LambdaFunction.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/LambdaFunction.java new file mode 100644 index 000000000..e1ee755b8 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/LambdaFunction.java @@ -0,0 +1,48 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.ast.AbstractNodeVisitor; + +/** + * Expression node of lambda function. Params include function name (@funcName) and function + * arguments (@funcArgs) + */ +@Getter +@EqualsAndHashCode(callSuper = false) +@RequiredArgsConstructor +public class LambdaFunction extends UnresolvedExpression { + private final UnresolvedExpression function; + private final List funcArgs; + + @Override + public List getChild() { + List children = new ArrayList<>(); + children.add(function); + children.addAll(funcArgs); + return children; + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitLambdaFunction(this, context); + } + + @Override + public String toString() { + return String.format( + "(%s) -> %s", + funcArgs.stream().map(Object::toString).collect(Collectors.joining(", ")), + function.toString() + ); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Scope.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Scope.java index fb18b8c1e..1ecea8779 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Scope.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Scope.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.sql.ast.expression; import lombok.EqualsAndHashCode; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Correlation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Correlation.java index f7bd8ad9a..cccc163f6 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Correlation.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Correlation.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.sql.ast.tree; import com.google.common.collect.ImmutableList; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java new file mode 100644 index 000000000..a8072e76b --- /dev/null +++ 
b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; +import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.AttributeList; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +import java.util.List; + +@Getter +@ToString +@EqualsAndHashCode(callSuper = false) +public class FieldSummary extends UnresolvedPlan { + private List includeFields; + private boolean includeNull; + private List collect; + private UnresolvedPlan child; + + public FieldSummary(List collect) { + this.collect = collect; + collect.forEach(exp -> { + if (exp instanceof Argument) { + this.includeNull = (boolean) ((Argument)exp).getValue().getValue(); + } + if (exp instanceof AttributeList) { + this.includeFields = ((AttributeList)exp).getAttrList(); + } + }); + } + + + @Override + public List getChild() { + return child == null ? List.of() : List.of(child); + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitFieldSummary(this, context); + } + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + this.child = child; + return this; + } + +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FillNull.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FillNull.java index 19bfea668..a1d591b9f 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FillNull.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FillNull.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.sql.ast.tree; import lombok.Getter; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Flatten.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Flatten.java new file mode 100644 index 000000000..e31fbb6e3 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Flatten.java @@ -0,0 +1,34 @@ +package org.opensearch.sql.ast.tree; + +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; +import org.opensearch.sql.ast.expression.Field; + +import java.util.List; + +@RequiredArgsConstructor +public class Flatten extends UnresolvedPlan { + + private UnresolvedPlan child; + + @Getter + private final Field fieldToBeFlattened; + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public List getChild() { + return child == null ? 
List.of() : List.of(child); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitFlatten(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java index e1732f75f..1b30a7998 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java @@ -8,7 +8,9 @@ import com.google.common.collect.ImmutableList; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; +import lombok.Getter; import lombok.RequiredArgsConstructor; +import lombok.Setter; import lombok.ToString; import org.opensearch.sql.ast.AbstractNodeVisitor; import org.opensearch.sql.ast.expression.QualifiedName; @@ -38,7 +40,7 @@ public Relation(UnresolvedExpression tableName, String alias) { } /** Optional alias name for the relation. */ - private String alias; + @Setter @Getter private String alias; /** * Return table name. @@ -53,15 +55,6 @@ public List getQualifiedNames() { return tableName.stream().map(t -> (QualifiedName) t).collect(Collectors.toList()); } - /** - * Return alias. - * - * @return alias. - */ - public String getAlias() { - return alias; - } - /** * Get Qualified name preservs parts of the user given identifiers. This can later be utilized to * determine DataSource,Schema and Table Name during Analyzer stage. So Passing QualifiedName diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java new file mode 100644 index 000000000..9fa1ae81d --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; +import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +import java.util.List; +import java.util.Optional; + +@ToString +@Getter +@RequiredArgsConstructor +@EqualsAndHashCode(callSuper = false) +public class Trendline extends UnresolvedPlan { + + private UnresolvedPlan child; + private final Optional sortByField; + private final List computations; + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public List getChild() { + return ImmutableList.of(child); + } + + @Override + public T accept(AbstractNodeVisitor visitor, C context) { + return visitor.visitTrendline(this, context); + } + + @Getter + public static class TrendlineComputation { + + private final Integer numberOfDataPoints; + private final UnresolvedExpression dataField; + private final String alias; + private final TrendlineType computationType; + + public TrendlineComputation(Integer numberOfDataPoints, UnresolvedExpression dataField, String alias, Trendline.TrendlineType computationType) { + this.numberOfDataPoints = numberOfDataPoints; + this.dataField = dataField; + this.alias = alias; + this.computationType = computationType; + } + + } + + public enum TrendlineType { + 
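+    // Simple moving average: the mean of the current row and the preceding (numberOfDataPoints - 1) rows,
+    // e.g. `trendline sort + age sma(3, age) as age_sma`; currently the only trendline type accepted by the parser.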
SMA + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Window.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Window.java new file mode 100644 index 000000000..26cd08831 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Window.java @@ -0,0 +1,45 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.Setter; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +import java.util.List; + +@Getter +@ToString +@EqualsAndHashCode(callSuper = false) +@RequiredArgsConstructor +public class Window extends UnresolvedPlan { + private final List windowFunctionList; + private final List partExprList; + private final List sortExprList; + @Setter private UnresolvedExpression span; + private UnresolvedPlan child; + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public List getChild() { + return ImmutableList.of(this.child); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitWindow(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/common/antlr/Parser.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/antlr/Parser.java index 7962f53ef..02faf4f3d 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/common/antlr/Parser.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/antlr/Parser.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.sql.common.antlr; import org.antlr.v4.runtime.tree.ParseTree; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index 6b549663a..13b5c20ef 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -52,18 +52,23 @@ public enum BuiltinFunctionName { SIN(FunctionName.of("sin")), TAN(FunctionName.of("tan")), + /** Cryptographic Functions. */ + MD5(FunctionName.of("md5")), + SHA1(FunctionName.of("sha1")), + SHA2(FunctionName.of("sha2")), + /** Date and Time Functions. 
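   * Note: entries commented out below are not registered here and therefore cannot be resolved as PPL built-in function names in this translation.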
*/ ADDDATE(FunctionName.of("adddate")), - ADDTIME(FunctionName.of("addtime")), - CONVERT_TZ(FunctionName.of("convert_tz")), +// ADDTIME(FunctionName.of("addtime")), +// CONVERT_TZ(FunctionName.of("convert_tz")), DATE(FunctionName.of("date")), DATEDIFF(FunctionName.of("datediff")), - DATETIME(FunctionName.of("datetime")), +// DATETIME(FunctionName.of("datetime")), DATE_ADD(FunctionName.of("date_add")), DATE_FORMAT(FunctionName.of("date_format")), DATE_SUB(FunctionName.of("date_sub")), DAY(FunctionName.of("day")), - DAYNAME(FunctionName.of("dayname")), +// DAYNAME(FunctionName.of("dayname")), DAYOFMONTH(FunctionName.of("dayofmonth")), DAY_OF_MONTH(FunctionName.of("day_of_month")), DAYOFWEEK(FunctionName.of("dayofweek")), @@ -71,56 +76,58 @@ public enum BuiltinFunctionName { DAY_OF_WEEK(FunctionName.of("day_of_week")), DAY_OF_YEAR(FunctionName.of("day_of_year")), EXTRACT(FunctionName.of("extract")), - FROM_DAYS(FunctionName.of("from_days")), +// FROM_DAYS(FunctionName.of("from_days")), FROM_UNIXTIME(FunctionName.of("from_unixtime")), - GET_FORMAT(FunctionName.of("get_format")), +// GET_FORMAT(FunctionName.of("get_format")), HOUR(FunctionName.of("hour")), HOUR_OF_DAY(FunctionName.of("hour_of_day")), LAST_DAY(FunctionName.of("last_day")), MAKEDATE(FunctionName.of("makedate")), - MAKETIME(FunctionName.of("maketime")), - MICROSECOND(FunctionName.of("microsecond")), + MAKE_DATE(FunctionName.of("make_date")), +// MAKETIME(FunctionName.of("maketime")), +// MICROSECOND(FunctionName.of("microsecond")), MINUTE(FunctionName.of("minute")), - MINUTE_OF_DAY(FunctionName.of("minute_of_day")), +// MINUTE_OF_DAY(FunctionName.of("minute_of_day")), MINUTE_OF_HOUR(FunctionName.of("minute_of_hour")), MONTH(FunctionName.of("month")), MONTH_OF_YEAR(FunctionName.of("month_of_year")), MONTHNAME(FunctionName.of("monthname")), - PERIOD_ADD(FunctionName.of("period_add")), - PERIOD_DIFF(FunctionName.of("period_diff")), +// PERIOD_ADD(FunctionName.of("period_add")), +// PERIOD_DIFF(FunctionName.of("period_diff")), QUARTER(FunctionName.of("quarter")), - SEC_TO_TIME(FunctionName.of("sec_to_time")), +// SEC_TO_TIME(FunctionName.of("sec_to_time")), SECOND(FunctionName.of("second")), SECOND_OF_MINUTE(FunctionName.of("second_of_minute")), - STR_TO_DATE(FunctionName.of("str_to_date")), +// STR_TO_DATE(FunctionName.of("str_to_date")), SUBDATE(FunctionName.of("subdate")), - SUBTIME(FunctionName.of("subtime")), - TIME(FunctionName.of("time")), - TIMEDIFF(FunctionName.of("timediff")), - TIME_TO_SEC(FunctionName.of("time_to_sec")), +// SUBTIME(FunctionName.of("subtime")), +// TIME(FunctionName.of("time")), +// TIMEDIFF(FunctionName.of("timediff")), +// TIME_TO_SEC(FunctionName.of("time_to_sec")), TIMESTAMP(FunctionName.of("timestamp")), TIMESTAMPADD(FunctionName.of("timestampadd")), TIMESTAMPDIFF(FunctionName.of("timestampdiff")), - TIME_FORMAT(FunctionName.of("time_format")), - TO_DAYS(FunctionName.of("to_days")), - TO_SECONDS(FunctionName.of("to_seconds")), - UTC_DATE(FunctionName.of("utc_date")), - UTC_TIME(FunctionName.of("utc_time")), +// TIME_FORMAT(FunctionName.of("time_format")), +// TO_DAYS(FunctionName.of("to_days")), +// TO_SECONDS(FunctionName.of("to_seconds")), +// UTC_DATE(FunctionName.of("utc_date")), +// UTC_TIME(FunctionName.of("utc_time")), UTC_TIMESTAMP(FunctionName.of("utc_timestamp")), + CURRENT_TIMEZONE(FunctionName.of("current_timezone")), UNIX_TIMESTAMP(FunctionName.of("unix_timestamp")), WEEK(FunctionName.of("week")), WEEKDAY(FunctionName.of("weekday")), WEEKOFYEAR(FunctionName.of("weekofyear")), 
WEEK_OF_YEAR(FunctionName.of("week_of_year")), YEAR(FunctionName.of("year")), - YEARWEEK(FunctionName.of("yearweek")), +// YEARWEEK(FunctionName.of("yearweek")), // `now`-like functions NOW(FunctionName.of("now")), CURDATE(FunctionName.of("curdate")), CURRENT_DATE(FunctionName.of("current_date")), - CURTIME(FunctionName.of("curtime")), - CURRENT_TIME(FunctionName.of("current_time")), +// CURTIME(FunctionName.of("curtime")), +// CURRENT_TIME(FunctionName.of("current_time")), LOCALTIME(FunctionName.of("localtime")), CURRENT_TIMESTAMP(FunctionName.of("current_timestamp")), LOCALTIMESTAMP(FunctionName.of("localtimestamp")), @@ -158,6 +165,8 @@ public enum BuiltinFunctionName { /** Aggregation Function. */ AVG(FunctionName.of("avg")), + MEAN(FunctionName.of("mean")), + STDDEV(FunctionName.of("stddev")), SUM(FunctionName.of("sum")), COUNT(FunctionName.of("count")), MIN(FunctionName.of("min")), @@ -198,6 +207,35 @@ public enum BuiltinFunctionName { TRIM(FunctionName.of("trim")), UPPER(FunctionName.of("upper")), + /** JSON Functions. */ + // If the function argument is a valid JSON, return itself, or return NULL + JSON(FunctionName.of("json")), + JSON_OBJECT(FunctionName.of("json_object")), + JSON_ARRAY(FunctionName.of("json_array")), + JSON_ARRAY_LENGTH(FunctionName.of("json_array_length")), + JSON_EXTRACT(FunctionName.of("json_extract")), + JSON_KEYS(FunctionName.of("json_keys")), + JSON_VALID(FunctionName.of("json_valid")), +// JSON_DELETE(FunctionName.of("json_delete")), +// JSON_APPEND(FunctionName.of("json_append")), +// JSON_EXTEND(FunctionName.of("json_extend")), +// JSON_SET(FunctionName.of("json_set")), +// JSON_ARRAY_ALL_MATCH(FunctionName.of("json_array_all_match")), +// JSON_ARRAY_ANY_MATCH(FunctionName.of("json_array_any_match")), +// JSON_ARRAY_FILTER(FunctionName.of("json_array_filter")), +// JSON_ARRAY_MAP(FunctionName.of("json_array_map")), +// JSON_ARRAY_REDUCE(FunctionName.of("json_array_reduce")), + + /** COLLECTION Functions **/ + ARRAY(FunctionName.of("array")), + + /** LAMBDA Functions **/ + ARRAY_FORALL(FunctionName.of("forall")), + ARRAY_EXISTS(FunctionName.of("exists")), + ARRAY_FILTER(FunctionName.of("filter")), + ARRAY_TRANSFORM(FunctionName.of("transform")), + ARRAY_AGGREGATE(FunctionName.of("reduce")), + /** NULL Test. 
*/ IS_NULL(FunctionName.of("is null")), IS_NOT_NULL(FunctionName.of("is not null")), diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java new file mode 100644 index 000000000..2541b3743 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import inet.ipaddr.AddressStringException; +import inet.ipaddr.IPAddressString; +import inet.ipaddr.IPAddressStringParameters; +import scala.Function2; +import scala.Serializable; +import scala.runtime.AbstractFunction2; + + +public interface SerializableUdf { + + Function2 cidrFunction = new SerializableAbstractFunction2<>() { + + IPAddressStringParameters valOptions = new IPAddressStringParameters.Builder() + .allowEmpty(false) + .setEmptyAsLoopback(false) + .allow_inet_aton(false) + .allowSingleSegment(false) + .toParams(); + + @Override + public Boolean apply(String ipAddress, String cidrBlock) { + + IPAddressString parsedIpAddress = new IPAddressString(ipAddress, valOptions); + + try { + parsedIpAddress.validate(); + } catch (AddressStringException e) { + throw new RuntimeException("The given ipAddress '"+ipAddress+"' is invalid. It must be a valid IPv4 or IPv6 address. Error details: "+e.getMessage()); + } + + IPAddressString parsedCidrBlock = new IPAddressString(cidrBlock, valOptions); + + try { + parsedCidrBlock.validate(); + } catch (AddressStringException e) { + throw new RuntimeException("The given cidrBlock '"+cidrBlock+"' is invalid. It must be a valid CIDR or netmask. Error details: "+e.getMessage()); + } + + if(parsedIpAddress.isIPv4() && parsedCidrBlock.isIPv6() || parsedIpAddress.isIPv6() && parsedCidrBlock.isIPv4()) { + throw new RuntimeException("The given ipAddress '"+ipAddress+"' and cidrBlock '"+cidrBlock+"' are not compatible. 
Both must be either IPv4 or IPv6."); + } + + return parsedCidrBlock.contains(parsedIpAddress); + } + }; + + abstract class SerializableAbstractFunction2 extends AbstractFunction2 + implements Serializable { + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java new file mode 100644 index 000000000..397419819 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java @@ -0,0 +1,450 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl; + +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute; +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$; +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; +import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; +import org.apache.spark.sql.catalyst.expressions.CaseWhen; +import org.apache.spark.sql.catalyst.expressions.CurrentRow$; +import org.apache.spark.sql.catalyst.expressions.Exists$; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual; +import org.apache.spark.sql.catalyst.expressions.In$; +import org.apache.spark.sql.catalyst.expressions.InSubquery$; +import org.apache.spark.sql.catalyst.expressions.LambdaFunction$; +import org.apache.spark.sql.catalyst.expressions.LessThan; +import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual; +import org.apache.spark.sql.catalyst.expressions.ListQuery$; +import org.apache.spark.sql.catalyst.expressions.MakeInterval$; +import org.apache.spark.sql.catalyst.expressions.NamedExpression; +import org.apache.spark.sql.catalyst.expressions.Predicate; +import org.apache.spark.sql.catalyst.expressions.RowFrame$; +import org.apache.spark.sql.catalyst.expressions.ScalaUDF; +import org.apache.spark.sql.catalyst.expressions.ScalarSubquery$; +import org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable; +import org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable$; +import org.apache.spark.sql.catalyst.expressions.SpecifiedWindowFrame; +import org.apache.spark.sql.catalyst.expressions.WindowExpression; +import org.apache.spark.sql.catalyst.expressions.WindowSpecDefinition; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.types.DataTypes; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.*; +import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; +import org.opensearch.sql.ast.expression.subquery.InSubquery; +import org.opensearch.sql.ast.expression.subquery.ScalarSubquery; +import org.opensearch.sql.ast.tree.Dedupe; +import org.opensearch.sql.ast.tree.Eval; +import org.opensearch.sql.ast.tree.FillNull; +import org.opensearch.sql.ast.tree.Kmeans; +import org.opensearch.sql.ast.tree.RareTopN; +import org.opensearch.sql.ast.tree.Trendline; +import org.opensearch.sql.ast.tree.UnresolvedPlan; +import org.opensearch.sql.expression.function.BuiltinFunctionName; +import org.opensearch.sql.expression.function.SerializableUdf; +import org.opensearch.sql.ppl.utils.AggregatorTransformer; +import org.opensearch.sql.ppl.utils.BuiltinFunctionTransformer; +import org.opensearch.sql.ppl.utils.ComparatorTransformer; +import org.opensearch.sql.ppl.utils.JavaToScalaTransformer; +import scala.Option; +import 
scala.PartialFunction; +import scala.Tuple2; +import scala.collection.Seq; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.Stack; +import java.util.function.BiFunction; +import java.util.stream.Collectors; + +import static java.util.Collections.emptyList; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.EQUAL; +import static org.opensearch.sql.ppl.CatalystPlanContext.findRelation; +import static org.opensearch.sql.ppl.utils.BuiltinFunctionTransformer.createIntervalArgs; +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.translate; +import static org.opensearch.sql.ppl.utils.RelationUtils.resolveField; +import static org.opensearch.sql.ppl.utils.WindowSpecTransformer.window; + +/** + * Class of building catalyst AST Expression nodes. + */ +public class CatalystExpressionVisitor extends AbstractNodeVisitor { + + private final AbstractNodeVisitor planVisitor; + + public CatalystExpressionVisitor(AbstractNodeVisitor planVisitor) { + this.planVisitor = planVisitor; + } + + public Expression analyze(UnresolvedExpression unresolved, CatalystPlanContext context) { + return unresolved.accept(this, context); + } + + @Override + public Expression visitLiteral(Literal node, CatalystPlanContext context) { + return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Literal( + translate(node.getValue(), node.getType()), translate(node.getType()))); + } + + /** + * generic binary (And, Or, Xor , ...) arithmetic expression resolver + * + * @param node + * @param transformer + * @param context + * @return + */ + public Expression visitBinaryArithmetic(BinaryExpression node, BiFunction transformer, CatalystPlanContext context) { + node.getLeft().accept(this, context); + Optional left = context.popNamedParseExpressions(); + node.getRight().accept(this, context); + Optional right = context.popNamedParseExpressions(); + if (left.isPresent() && right.isPresent()) { + return transformer.apply(left.get(), right.get()); + } else if (left.isPresent()) { + return context.getNamedParseExpressions().push(left.get()); + } else if (right.isPresent()) { + return context.getNamedParseExpressions().push(right.get()); + } + return null; + + } + + @Override + public Expression visitAnd(And node, CatalystPlanContext context) { + return visitBinaryArithmetic(node, + (left, right) -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.And(left, right)), context); + } + + @Override + public Expression visitOr(Or node, CatalystPlanContext context) { + return visitBinaryArithmetic(node, + (left, right) -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Or(left, right)), context); + } + + @Override + public Expression visitXor(Xor node, CatalystPlanContext context) { + return visitBinaryArithmetic(node, + (left, right) -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.BitwiseXor(left, right)), context); + } + + @Override + public Expression visitNot(Not node, CatalystPlanContext context) { + node.getExpression().accept(this, context); + Optional arg = context.popNamedParseExpressions(); + return arg.map(expression -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Not(expression))).orElse(null); + } + + @Override + public Expression visitSpan(Span node, CatalystPlanContext context) { + 
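+        // Translate the span's field and value into Catalyst expressions, then combine them into a
+        // window bucket expression via WindowSpecTransformer.window(field, value, unit).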
node.getField().accept(this, context); + Expression field = (Expression) context.popNamedParseExpressions().get(); + node.getValue().accept(this, context); + Expression value = (Expression) context.popNamedParseExpressions().get(); + return context.getNamedParseExpressions().push(window(field, value, node.getUnit())); + } + + @Override + public Expression visitAggregateFunction(AggregateFunction node, CatalystPlanContext context) { + node.getField().accept(this, context); + Expression arg = (Expression) context.popNamedParseExpressions().get(); + Expression aggregator = AggregatorTransformer.aggregator(node, arg); + return context.getNamedParseExpressions().push(aggregator); + } + + @Override + public Expression visitCompare(Compare node, CatalystPlanContext context) { + analyze(node.getLeft(), context); + Optional left = context.popNamedParseExpressions(); + analyze(node.getRight(), context); + Optional right = context.popNamedParseExpressions(); + if (left.isPresent() && right.isPresent()) { + Predicate comparator = ComparatorTransformer.comparator(node, left.get(), right.get()); + return context.getNamedParseExpressions().push((org.apache.spark.sql.catalyst.expressions.Expression) comparator); + } + return null; + } + + @Override + public Expression visitQualifiedName(QualifiedName node, CatalystPlanContext context) { + List relation = findRelation(context.traversalContext()); + if (!relation.isEmpty()) { + Optional resolveField = resolveField(relation, node, context.getRelations()); + return resolveField.map(qualifiedName -> context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(qualifiedName.getParts())))) + .orElse(resolveQualifiedNameWithSubqueryAlias(node, context)); + } + return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts()))); + } + + /** + * Resolve the qualified name with subquery alias:
+ * - subqueryAlias1.joinKey = subqueryAlias2.joinKey
+ * - tableName1.joinKey = subqueryAlias2.joinKey
+ * - subqueryAlias1.joinKey = tableName2.joinKey
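+ * e.g. l.joinKey = r.joinKey, where l and r are side aliases registered as SubqueryAlias nodes or table names of the resolved relations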
+ */ + private Expression resolveQualifiedNameWithSubqueryAlias(QualifiedName node, CatalystPlanContext context) { + if (node.getPrefix().isPresent() && + context.traversalContext().peek() instanceof org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias) { + if (context.getSubqueryAlias().stream().map(p -> (org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias) p) + .anyMatch(a -> a.alias().equalsIgnoreCase(node.getPrefix().get().toString()))) { + return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts()))); + } else if (context.getRelations().stream().map(p -> (UnresolvedRelation) p) + .anyMatch(a -> a.tableName().equalsIgnoreCase(node.getPrefix().get().toString()))) { + return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts()))); + } + } + return null; + } + + @Override + public Expression visitCorrelationMapping(FieldsMapping node, CatalystPlanContext context) { + return node.getChild().stream().map(expression -> + visitCompare((Compare) expression, context) + ).reduce(org.apache.spark.sql.catalyst.expressions.And::new).orElse(null); + } + + @Override + public Expression visitAllFields(AllFields node, CatalystPlanContext context) { + context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); + return context.getNamedParseExpressions().peek(); + } + + @Override + public Expression visitAlias(Alias node, CatalystPlanContext context) { + node.getDelegated().accept(this, context); + Expression arg = context.popNamedParseExpressions().get(); + return context.getNamedParseExpressions().push( + org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(arg, + node.getAlias() != null ? node.getAlias() : node.getName(), + NamedExpression.newExprId(), + seq(new java.util.ArrayList()), + Option.empty(), + seq(new java.util.ArrayList()))); + } + + @Override + public Expression visitEval(Eval node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : Eval"); + } + + @Override + public Expression visitFunction(Function node, CatalystPlanContext context) { + List arguments = + node.getFuncArgs().stream() + .map( + unresolvedExpression -> { + var ret = analyze(unresolvedExpression, context); + if (ret == null) { + throw new UnsupportedOperationException( + String.format("Invalid use of expression %s", unresolvedExpression)); + } else { + return context.popNamedParseExpressions().get(); + } + }) + .collect(Collectors.toList()); + Expression function = BuiltinFunctionTransformer.builtinFunction(node, arguments); + return context.getNamedParseExpressions().push(function); + } + + @Override + public Expression visitIsEmpty(IsEmpty node, CatalystPlanContext context) { + Stack namedParseExpressions = new Stack<>(); + namedParseExpressions.addAll(context.getNamedParseExpressions()); + Expression expression = visitCase(node.getCaseValue(), context); + namedParseExpressions.add(expression); + context.setNamedParseExpressions(namedParseExpressions); + return expression; + } + + @Override + public Expression visitFillNull(FillNull fillNull, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : FillNull"); + } + + @Override + public Expression visitInterval(Interval node, CatalystPlanContext context) { + node.getValue().accept(this, context); + Expression value = context.getNamedParseExpressions().pop(); + Expression[] intervalArgs = createIntervalArgs(node.getUnit(), value); + Expression interval = 
MakeInterval$.MODULE$.apply( + intervalArgs[0], intervalArgs[1], intervalArgs[2], intervalArgs[3], + intervalArgs[4], intervalArgs[5], intervalArgs[6], true); + return context.getNamedParseExpressions().push(interval); + } + + @Override + public Expression visitDedupe(Dedupe node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : Dedupe"); + } + + @Override + public Expression visitIn(In node, CatalystPlanContext context) { + node.getField().accept(this, context); + Expression value = context.popNamedParseExpressions().get(); + List list = node.getValueList().stream().map( expression -> { + expression.accept(this, context); + return context.popNamedParseExpressions().get(); + }).collect(Collectors.toList()); + return context.getNamedParseExpressions().push(In$.MODULE$.apply(value, seq(list))); + } + + @Override + public Expression visitKmeans(Kmeans node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : Kmeans"); + } + + @Override + public Expression visitCase(Case node, CatalystPlanContext context) { + Stack initialNameExpressions = new Stack<>(); + initialNameExpressions.addAll(context.getNamedParseExpressions()); + analyze(node.getElseClause(), context); + Expression elseValue = context.getNamedParseExpressions().pop(); + List> whens = new ArrayList<>(); + for (When when : node.getWhenClauses()) { + if (node.getCaseValue() == null) { + whens.add( + new Tuple2<>( + analyze(when.getCondition(), context), + analyze(when.getResult(), context) + ) + ); + } else { + // Merge case value and condition (compare value) into a single equal condition + Compare compare = new Compare(EQUAL.getName().getFunctionName(), node.getCaseValue(), when.getCondition()); + whens.add( + new Tuple2<>( + analyze(compare, context), analyze(when.getResult(), context) + ) + ); + } + context.retainAllNamedParseExpressions(e -> e); + } + context.setNamedParseExpressions(initialNameExpressions); + return context.getNamedParseExpressions().push(new CaseWhen(seq(whens), Option.apply(elseValue))); + } + + @Override + public Expression visitRareTopN(RareTopN node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : RareTopN"); + } + + @Override + public Expression visitWindowFunction(WindowFunction node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : WindowFunction"); + } + + @Override + public Expression visitInSubquery(InSubquery node, CatalystPlanContext outerContext) { + CatalystPlanContext innerContext = new CatalystPlanContext(); + visitExpressionList(node.getChild(), innerContext); + Seq values = innerContext.retainAllNamedParseExpressions(p -> p); + UnresolvedPlan outerPlan = node.getQuery(); + LogicalPlan subSearch = outerPlan.accept(planVisitor, innerContext); + Expression inSubQuery = InSubquery$.MODULE$.apply( + values, + ListQuery$.MODULE$.apply( + subSearch, + seq(new java.util.ArrayList()), + NamedExpression.newExprId(), + -1, + seq(new java.util.ArrayList()), + Option.empty())); + return outerContext.getNamedParseExpressions().push(inSubQuery); + } + + @Override + public Expression visitScalarSubquery(ScalarSubquery node, CatalystPlanContext context) { + CatalystPlanContext innerContext = new CatalystPlanContext(); + UnresolvedPlan outerPlan = node.getQuery(); + LogicalPlan subSearch = outerPlan.accept(planVisitor, innerContext); + Expression scalarSubQuery = ScalarSubquery$.MODULE$.apply( + subSearch, + seq(new java.util.ArrayList()), + 
NamedExpression.newExprId(), + seq(new java.util.ArrayList()), + Option.empty(), + Option.empty()); + return context.getNamedParseExpressions().push(scalarSubQuery); + } + + @Override + public Expression visitExistsSubquery(ExistsSubquery node, CatalystPlanContext context) { + CatalystPlanContext innerContext = new CatalystPlanContext(); + UnresolvedPlan outerPlan = node.getQuery(); + LogicalPlan subSearch = outerPlan.accept(planVisitor, innerContext); + Expression existsSubQuery = Exists$.MODULE$.apply( + subSearch, + seq(new java.util.ArrayList()), + NamedExpression.newExprId(), + seq(new java.util.ArrayList()), + Option.empty()); + return context.getNamedParseExpressions().push(existsSubQuery); + } + + @Override + public Expression visitBetween(Between node, CatalystPlanContext context) { + Expression value = analyze(node.getValue(), context); + Expression lower = analyze(node.getLowerBound(), context); + Expression upper = analyze(node.getUpperBound(), context); + context.retainAllNamedParseExpressions(p -> p); + return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.And(new GreaterThanOrEqual(value, lower), new LessThanOrEqual(value, upper))); + } + + @Override + public Expression visitCidr(org.opensearch.sql.ast.expression.Cidr node, CatalystPlanContext context) { + analyze(node.getIpAddress(), context); + Expression ipAddressExpression = context.getNamedParseExpressions().pop(); + analyze(node.getCidrBlock(), context); + Expression cidrBlockExpression = context.getNamedParseExpressions().pop(); + + ScalaUDF udf = new ScalaUDF(SerializableUdf.cidrFunction, + DataTypes.BooleanType, + seq(ipAddressExpression,cidrBlockExpression), + seq(), + Option.empty(), + Option.apply("cidr"), + false, + true); + + return context.getNamedParseExpressions().push(udf); + } + + @Override + public Expression visitLambdaFunction(LambdaFunction node, CatalystPlanContext context) { + PartialFunction transformer = JavaToScalaTransformer.toPartialFunction( + expr -> expr instanceof UnresolvedAttribute, + expr -> { + UnresolvedAttribute attr = (UnresolvedAttribute) expr; + return new UnresolvedNamedLambdaVariable(attr.nameParts()); + } + ); + Expression functionResult = node.getFunction().accept(this, context).transformUp(transformer); + context.popNamedParseExpressions(); + List argsResult = node.getFuncArgs().stream() + .map(arg -> UnresolvedNamedLambdaVariable$.MODULE$.apply(seq(arg.getParts()))) + .collect(Collectors.toList()); + return context.getNamedParseExpressions().push(LambdaFunction$.MODULE$.apply(functionResult, seq(argsResult), false)); + } + + @Override + public Expression visitGeoIp(GeoIp node, CatalystPlanContext context) { + + ScalaUDF udf = new ScalaUDF(); + + return context.getNamedParseExpressions().push(udf); + } + + private List visitExpressionList(List expressionList, CatalystPlanContext context) { + return expressionList.isEmpty() + ? 
emptyList() + : expressionList.stream().map(field -> analyze(field, context)) + .collect(Collectors.toList()); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java index 46a016d1a..61762f616 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java @@ -154,7 +154,13 @@ public LogicalPlan withProjectedFields(List projectedField this.projectedFields.addAll(projectedFields); return getPlan(); } - + + public LogicalPlan applyBranches(List> plans) { + plans.forEach(plan -> with(plan.apply(planBranches.get(0)))); + planBranches.remove(0); + return getPlan(); + } + /** * append plan with evolving plans branches * @@ -281,4 +287,5 @@ public static Optional findRelation(LogicalPlan plan) { // Return null if no UnresolvedRelation is found return Optional.empty(); } + } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 22b53d638..669459fba 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -6,30 +6,47 @@ package org.opensearch.sql.ppl; import org.apache.spark.sql.catalyst.TableIdentifier; -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$; import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction; import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; -import org.apache.spark.sql.catalyst.expressions.*; -import org.apache.spark.sql.catalyst.plans.logical.*; +import org.apache.spark.sql.catalyst.expressions.Ascending$; +import org.apache.spark.sql.catalyst.expressions.Descending$; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.GeneratorOuter; +import org.apache.spark.sql.catalyst.expressions.In$; +import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual; +import org.apache.spark.sql.catalyst.expressions.InSubquery$; +import org.apache.spark.sql.catalyst.expressions.LessThan; +import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual; +import org.apache.spark.sql.catalyst.expressions.ListQuery$; +import org.apache.spark.sql.catalyst.expressions.MakeInterval$; +import org.apache.spark.sql.catalyst.expressions.NamedExpression; +import org.apache.spark.sql.catalyst.expressions.SortDirection; +import org.apache.spark.sql.catalyst.expressions.SortOrder; +import org.apache.spark.sql.catalyst.plans.logical.Aggregate; +import org.apache.spark.sql.catalyst.plans.logical.DataFrameDropColumns$; +import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation$; +import org.apache.spark.sql.catalyst.plans.logical.Generate; +import org.apache.spark.sql.catalyst.plans.logical.Limit; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.catalyst.plans.logical.Project$; import org.apache.spark.sql.execution.ExplainMode; import org.apache.spark.sql.execution.command.DescribeTableCommand; import org.apache.spark.sql.execution.command.ExplainCommand; import org.apache.spark.sql.types.DataTypes; import 
org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.opensearch.flint.spark.FlattenGenerator; import org.opensearch.sql.ast.AbstractNodeVisitor; -import org.opensearch.sql.ast.expression.*; import org.opensearch.sql.ast.expression.Alias; -import org.opensearch.sql.ast.expression.And; -import org.opensearch.sql.ast.expression.BinaryExpression; +import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.In; +import org.opensearch.sql.ast.expression.Let; import org.opensearch.sql.ast.expression.Literal; -import org.opensearch.sql.ast.expression.Not; -import org.opensearch.sql.ast.expression.Or; +import org.opensearch.sql.ast.expression.ParseMethod; +import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.expression.WindowFunction; -import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; -import org.opensearch.sql.ast.expression.subquery.InSubquery; -import org.opensearch.sql.ast.expression.subquery.ScalarSubquery; import org.opensearch.sql.ast.statement.Explain; import org.opensearch.sql.ast.statement.Query; import org.opensearch.sql.ast.statement.Statement; @@ -38,8 +55,10 @@ import org.opensearch.sql.ast.tree.Dedupe; import org.opensearch.sql.ast.tree.DescribeRelation; import org.opensearch.sql.ast.tree.Eval; +import org.opensearch.sql.ast.tree.FieldSummary; import org.opensearch.sql.ast.tree.FillNull; import org.opensearch.sql.ast.tree.Filter; +import org.opensearch.sql.ast.tree.Flatten; import org.opensearch.sql.ast.tree.Head; import org.opensearch.sql.ast.tree.Join; import org.opensearch.sql.ast.tree.Kmeans; @@ -53,30 +72,28 @@ import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.ast.tree.SubqueryAlias; import org.opensearch.sql.ast.tree.TopAggregation; -import org.opensearch.sql.ast.tree.UnresolvedPlan; +import org.opensearch.sql.ast.tree.Trendline; +import org.opensearch.sql.ast.tree.Window; import org.opensearch.sql.common.antlr.SyntaxCheckException; -import org.opensearch.sql.ppl.utils.AggregatorTranslator; -import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator; -import org.opensearch.sql.ppl.utils.ComparatorTransformer; -import org.opensearch.sql.ppl.utils.ParseStrategy; +import org.opensearch.sql.ppl.utils.FieldSummaryTransformer; +import org.opensearch.sql.ppl.utils.ParseTransformer; import org.opensearch.sql.ppl.utils.SortUtils; +import org.opensearch.sql.ppl.utils.TrendlineCatalystUtils; +import org.opensearch.sql.ppl.utils.WindowSpecTransformer; +import scala.None$; import scala.Option; -import scala.Tuple2; import scala.collection.IterableLike; import scala.collection.Seq; -import java.util.*; -import java.util.Map; -import java.util.Stack; -import java.util.function.BiFunction; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.Optional; import java.util.stream.Collectors; import static java.util.Collections.emptyList; import static java.util.List.of; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.EQUAL; -import static org.opensearch.sql.ppl.CatalystPlanContext.findRelation; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; -import static org.opensearch.sql.ppl.utils.DataTypeTransformer.translate; import static org.opensearch.sql.ppl.utils.DedupeTransformer.retainMultipleDuplicateEvents; import static 
org.opensearch.sql.ppl.utils.DedupeTransformer.retainMultipleDuplicateEventsAndKeepEmpty; import static org.opensearch.sql.ppl.utils.DedupeTransformer.retainOneDuplicateEvent; @@ -88,28 +105,23 @@ import static org.opensearch.sql.ppl.utils.LookupTransformer.buildOutputProjectList; import static org.opensearch.sql.ppl.utils.LookupTransformer.buildProjectListFromFields; import static org.opensearch.sql.ppl.utils.RelationUtils.getTableIdentifier; -import static org.opensearch.sql.ppl.utils.RelationUtils.resolveField; -import static org.opensearch.sql.ppl.utils.WindowSpecTransformer.window; +import static scala.collection.JavaConverters.seqAsJavaList; /** * Utility class to traverse PPL logical plan and translate it into catalyst logical plan */ public class CatalystQueryPlanVisitor extends AbstractNodeVisitor { - private final ExpressionAnalyzer expressionAnalyzer; + private final CatalystExpressionVisitor expressionAnalyzer; public CatalystQueryPlanVisitor() { - this.expressionAnalyzer = new ExpressionAnalyzer(); + this.expressionAnalyzer = new CatalystExpressionVisitor(this); } public LogicalPlan visit(Statement plan, CatalystPlanContext context) { return plan.accept(this, context); } - - public LogicalPlan visitSubSearch(UnresolvedPlan plan, CatalystPlanContext context) { - return plan.accept(this, context); - } - + /** * Handle Query Statement. */ @@ -214,6 +226,30 @@ public LogicalPlan visitLookup(Lookup node, CatalystPlanContext context) { }); } + @Override + public LogicalPlan visitTrendline(Trendline node, CatalystPlanContext context) { + node.getChild().get(0).accept(this, context); + + node.getSortByField() + .ifPresent(sortField -> { + Expression sortFieldExpression = visitExpression(sortField, context); + Seq sortOrder = context + .retainAllNamedParseExpressions(exp -> SortUtils.sortOrder(sortFieldExpression, SortUtils.isSortedAscending(sortField))); + context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Sort(sortOrder, true, p)); + }); + + List trendlineProjectExpressions = new ArrayList<>(); + + if (context.getNamedParseExpressions().isEmpty()) { + // Create an UnresolvedStar for all-fields projection + trendlineProjectExpressions.add(UnresolvedStar$.MODULE$.apply(Option.empty())); + } + + trendlineProjectExpressions.addAll(TrendlineCatalystUtils.visitTrendlineComputations(expressionAnalyzer, node.getComputations(), context)); + + return context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(seq(trendlineProjectExpressions), p)); + } + @Override public LogicalPlan visitCorrelation(Correlation node, CatalystPlanContext context) { node.getChild().get(0).accept(this, context); @@ -301,6 +337,30 @@ private static LogicalPlan extractedAggregation(CatalystPlanContext context) { return context.apply(p -> new Aggregate(groupingExpression, aggregateExpressions, p)); } + @Override + public LogicalPlan visitWindow(Window node, CatalystPlanContext context) { + node.getChild().get(0).accept(this, context); + List windowFunctionExpList = visitExpressionList(node.getWindowFunctionList(), context); + Seq windowFunctionExpressions = context.retainAllNamedParseExpressions(p -> p); + List partitionExpList = visitExpressionList(node.getPartExprList(), context); + UnresolvedExpression span = node.getSpan(); + if (!Objects.isNull(span)) { + visitExpression(span, context); + } + Seq partitionSpec = context.retainAllNamedParseExpressions(p -> p); + Seq orderSpec = seq(new ArrayList()); + Seq aggregatorFunctions = seq( + seqAsJavaList(windowFunctionExpressions).stream() 
+ .map(w -> WindowSpecTransformer.buildAggregateWindowFunction(w, partitionSpec, orderSpec)) + .collect(Collectors.toList())); + return context.apply(p -> + new org.apache.spark.sql.catalyst.plans.logical.Window( + aggregatorFunctions, + partitionSpec, + orderSpec, + p)); + } + @Override public LogicalPlan visitAlias(Alias node, CatalystPlanContext context) { expressionAnalyzer.visitAlias(node, context); @@ -355,6 +415,12 @@ public LogicalPlan visitHead(Head node, CatalystPlanContext context) { node.getSize(), DataTypes.IntegerType), p)); } + @Override + public LogicalPlan visitFieldSummary(FieldSummary fieldSummary, CatalystPlanContext context) { + fieldSummary.getChild().get(0).accept(this, context); + return FieldSummaryTransformer.translate(fieldSummary, context); + } + @Override public LogicalPlan visitFillNull(FillNull fillNull, CatalystPlanContext context) { fillNull.getChild().get(0).accept(this, context); @@ -386,6 +452,20 @@ public LogicalPlan visitFillNull(FillNull fillNull, CatalystPlanContext context) return Objects.requireNonNull(resultWithoutDuplicatedColumns, "FillNull operation failed"); } + @Override + public LogicalPlan visitFlatten(Flatten flatten, CatalystPlanContext context) { + flatten.getChild().get(0).accept(this, context); + if (context.getNamedParseExpressions().isEmpty()) { + // Create an UnresolvedStar for all-fields projection + context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); + } + Expression field = visitExpression(flatten.getFieldToBeFlattened(), context); + context.retainAllNamedParseExpressions(p -> (NamedExpression) p); + FlattenGenerator flattenGenerator = new FlattenGenerator(field); + context.apply(p -> new Generate(new GeneratorOuter(flattenGenerator), seq(), true, (Option) None$.MODULE$, seq(), p)); + return context.apply(logicalPlan -> DataFrameDropColumns$.MODULE$.apply(seq(field), logicalPlan)); + } + private void visitFieldList(List fieldList, CatalystPlanContext context) { fieldList.forEach(field -> visitExpression(field, context)); } @@ -408,7 +488,7 @@ public LogicalPlan visitParse(Parse node, CatalystPlanContext context) { ParseMethod parseMethod = node.getParseMethod(); java.util.Map arguments = node.getArguments(); String pattern = (String) node.getPattern().getValue(); - return ParseStrategy.visitParseCommand(node, sourceField, parseMethod, arguments, pattern, context); + return ParseTransformer.visitParseCommand(node, sourceField, parseMethod, arguments, pattern, context); } @Override @@ -502,315 +582,4 @@ public LogicalPlan visitDedupe(Dedupe node, CatalystPlanContext context) { } } } - - /** - * Expression Analyzer. - */ - public class ExpressionAnalyzer extends AbstractNodeVisitor { - - public Expression analyze(UnresolvedExpression unresolved, CatalystPlanContext context) { - return unresolved.accept(this, context); - } - - @Override - public Expression visitLiteral(Literal node, CatalystPlanContext context) { - return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Literal( - translate(node.getValue(), node.getType()), translate(node.getType()))); - } - - /** - * generic binary (And, Or, Xor , ...) 
arithmetic expression resolver - * - * @param node - * @param transformer - * @param context - * @return - */ - public Expression visitBinaryArithmetic(BinaryExpression node, BiFunction transformer, CatalystPlanContext context) { - node.getLeft().accept(this, context); - Optional left = context.popNamedParseExpressions(); - node.getRight().accept(this, context); - Optional right = context.popNamedParseExpressions(); - if (left.isPresent() && right.isPresent()) { - return transformer.apply(left.get(), right.get()); - } else if (left.isPresent()) { - return context.getNamedParseExpressions().push(left.get()); - } else if (right.isPresent()) { - return context.getNamedParseExpressions().push(right.get()); - } - return null; - - } - - @Override - public Expression visitAnd(And node, CatalystPlanContext context) { - return visitBinaryArithmetic(node, - (left, right) -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.And(left, right)), context); - } - - @Override - public Expression visitOr(Or node, CatalystPlanContext context) { - return visitBinaryArithmetic(node, - (left, right) -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Or(left, right)), context); - } - - @Override - public Expression visitXor(Xor node, CatalystPlanContext context) { - return visitBinaryArithmetic(node, - (left, right) -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.BitwiseXor(left, right)), context); - } - - @Override - public Expression visitNot(Not node, CatalystPlanContext context) { - node.getExpression().accept(this, context); - Optional arg = context.popNamedParseExpressions(); - return arg.map(expression -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Not(expression))).orElse(null); - } - - @Override - public Expression visitSpan(Span node, CatalystPlanContext context) { - node.getField().accept(this, context); - Expression field = (Expression) context.popNamedParseExpressions().get(); - node.getValue().accept(this, context); - Expression value = (Expression) context.popNamedParseExpressions().get(); - return context.getNamedParseExpressions().push(window(field, value, node.getUnit())); - } - - @Override - public Expression visitAggregateFunction(AggregateFunction node, CatalystPlanContext context) { - node.getField().accept(this, context); - Expression arg = (Expression) context.popNamedParseExpressions().get(); - Expression aggregator = AggregatorTranslator.aggregator(node, arg); - return context.getNamedParseExpressions().push(aggregator); - } - - @Override - public Expression visitCompare(Compare node, CatalystPlanContext context) { - analyze(node.getLeft(), context); - Optional left = context.popNamedParseExpressions(); - analyze(node.getRight(), context); - Optional right = context.popNamedParseExpressions(); - if (left.isPresent() && right.isPresent()) { - Predicate comparator = ComparatorTransformer.comparator(node, left.get(), right.get()); - return context.getNamedParseExpressions().push((org.apache.spark.sql.catalyst.expressions.Expression) comparator); - } - return null; - } - - @Override - public Expression visitQualifiedName(QualifiedName node, CatalystPlanContext context) { - List relation = findRelation(context.traversalContext()); - if (!relation.isEmpty()) { - Optional resolveField = resolveField(relation, node, context.getRelations()); - return resolveField.map(qualifiedName -> 
context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(qualifiedName.getParts())))) - .orElse(resolveQualifiedNameWithSubqueryAlias(node, context)); - } - return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts()))); - } - - /** - * Resolve the qualified name with subquery alias:
- * - subqueryAlias1.joinKey = subqueryAlias2.joinKey
- * - tableName1.joinKey = subqueryAlias2.joinKey
- * - subqueryAlias1.joinKey = tableName2.joinKey
- */ - private Expression resolveQualifiedNameWithSubqueryAlias(QualifiedName node, CatalystPlanContext context) { - if (node.getPrefix().isPresent() && - context.traversalContext().peek() instanceof org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias) { - if (context.getSubqueryAlias().stream().map(p -> (org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias) p) - .anyMatch(a -> a.alias().equalsIgnoreCase(node.getPrefix().get().toString()))) { - return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts()))); - } else if (context.getRelations().stream().map(p -> (UnresolvedRelation) p) - .anyMatch(a -> a.tableName().equalsIgnoreCase(node.getPrefix().get().toString()))) { - return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts()))); - } - } - return null; - } - - @Override - public Expression visitCorrelationMapping(FieldsMapping node, CatalystPlanContext context) { - return node.getChild().stream().map(expression -> - visitCompare((Compare) expression, context) - ).reduce(org.apache.spark.sql.catalyst.expressions.And::new).orElse(null); - } - - @Override - public Expression visitAllFields(AllFields node, CatalystPlanContext context) { - // Case of aggregation step - no start projection can be added - if (context.getNamedParseExpressions().isEmpty()) { - // Create an UnresolvedStar for all-fields projection - context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); - } - return context.getNamedParseExpressions().peek(); - } - - @Override - public Expression visitAlias(Alias node, CatalystPlanContext context) { - node.getDelegated().accept(this, context); - Expression arg = context.popNamedParseExpressions().get(); - return context.getNamedParseExpressions().push( - org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(arg, - node.getAlias() != null ? 
node.getAlias() : node.getName(), - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty(), - seq(new java.util.ArrayList()))); - } - - @Override - public Expression visitEval(Eval node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : Eval"); - } - - @Override - public Expression visitFunction(Function node, CatalystPlanContext context) { - List arguments = - node.getFuncArgs().stream() - .map( - unresolvedExpression -> { - var ret = analyze(unresolvedExpression, context); - if (ret == null) { - throw new UnsupportedOperationException( - String.format("Invalid use of expression %s", unresolvedExpression)); - } else { - return context.popNamedParseExpressions().get(); - } - }) - .collect(Collectors.toList()); - Expression function = BuiltinFunctionTranslator.builtinFunction(node, arguments); - return context.getNamedParseExpressions().push(function); - } - - @Override - public Expression visitIsEmpty(IsEmpty node, CatalystPlanContext context) { - Stack namedParseExpressions = new Stack<>(); - namedParseExpressions.addAll(context.getNamedParseExpressions()); - Expression expression = visitCase(node.getCaseValue(), context); - namedParseExpressions.add(expression); - context.setNamedParseExpressions(namedParseExpressions); - return expression; - } - - @Override - public Expression visitFillNull(FillNull fillNull, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : FillNull"); - } - - @Override - public Expression visitInterval(Interval node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : Interval"); - } - - @Override - public Expression visitDedupe(Dedupe node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : Dedupe"); - } - - @Override - public Expression visitIn(In node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : In"); - } - - @Override - public Expression visitKmeans(Kmeans node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : Kmeans"); - } - - @Override - public Expression visitCase(Case node, CatalystPlanContext context) { - Stack initialNameExpressions = new Stack<>(); - initialNameExpressions.addAll(context.getNamedParseExpressions()); - analyze(node.getElseClause(), context); - Expression elseValue = context.getNamedParseExpressions().pop(); - List> whens = new ArrayList<>(); - for (When when : node.getWhenClauses()) { - if (node.getCaseValue() == null) { - whens.add( - new Tuple2<>( - analyze(when.getCondition(), context), - analyze(when.getResult(), context) - ) - ); - } else { - // Merge case value and condition (compare value) into a single equal condition - Compare compare = new Compare(EQUAL.getName().getFunctionName(), node.getCaseValue(), when.getCondition()); - whens.add( - new Tuple2<>( - analyze(compare, context), analyze(when.getResult(), context) - ) - ); - } - context.retainAllNamedParseExpressions(e -> e); - } - context.setNamedParseExpressions(initialNameExpressions); - return context.getNamedParseExpressions().push(new CaseWhen(seq(whens), Option.apply(elseValue))); - } - - @Override - public Expression visitRareTopN(RareTopN node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : RareTopN"); - } - - @Override - public Expression visitWindowFunction(WindowFunction node, CatalystPlanContext context) { - throw new IllegalStateException("Not 
Supported operation : WindowFunction"); - } - - @Override - public Expression visitInSubquery(InSubquery node, CatalystPlanContext outerContext) { - CatalystPlanContext innerContext = new CatalystPlanContext(); - visitExpressionList(node.getChild(), innerContext); - Seq values = innerContext.retainAllNamedParseExpressions(p -> p); - UnresolvedPlan outerPlan = node.getQuery(); - LogicalPlan subSearch = CatalystQueryPlanVisitor.this.visitSubSearch(outerPlan, innerContext); - Expression inSubQuery = InSubquery$.MODULE$.apply( - values, - ListQuery$.MODULE$.apply( - subSearch, - seq(new java.util.ArrayList()), - NamedExpression.newExprId(), - -1, - seq(new java.util.ArrayList()), - Option.empty())); - return outerContext.getNamedParseExpressions().push(inSubQuery); - } - - @Override - public Expression visitScalarSubquery(ScalarSubquery node, CatalystPlanContext context) { - CatalystPlanContext innerContext = new CatalystPlanContext(); - UnresolvedPlan outerPlan = node.getQuery(); - LogicalPlan subSearch = CatalystQueryPlanVisitor.this.visitSubSearch(outerPlan, innerContext); - Expression scalarSubQuery = ScalarSubquery$.MODULE$.apply( - subSearch, - seq(new java.util.ArrayList()), - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty(), - Option.empty()); - return context.getNamedParseExpressions().push(scalarSubQuery); - } - - @Override - public Expression visitExistsSubquery(ExistsSubquery node, CatalystPlanContext context) { - CatalystPlanContext innerContext = new CatalystPlanContext(); - UnresolvedPlan outerPlan = node.getQuery(); - LogicalPlan subSearch = CatalystQueryPlanVisitor.this.visitSubSearch(outerPlan, innerContext); - Expression existsSubQuery = Exists$.MODULE$.apply( - subSearch, - seq(new java.util.ArrayList()), - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty()); - return context.getNamedParseExpressions().push(existsSubQuery); - } - - @Override - public Expression visitGeoIp(GeoIp node, CatalystPlanContext context) { - - ScalaUDF udf = new ScalaUDF(); - - return context.getNamedParseExpressions().push(udf); - } - } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 8673b1582..4e6b1f131 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -21,6 +21,7 @@ import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.tree.FieldSummary; import org.opensearch.sql.ast.expression.FieldsMapping; import org.opensearch.sql.ast.expression.Let; import org.opensearch.sql.ast.expression.Literal; @@ -81,8 +82,8 @@ public class AstBuilder extends OpenSearchPPLParserBaseVisitor { */ private String query; - public AstBuilder(AstExpressionBuilder expressionBuilder, String query) { - this.expressionBuilder = expressionBuilder; + public AstBuilder(String query) { + this.expressionBuilder = new AstExpressionBuilder(this); this.query = query; } @@ -156,8 +157,12 @@ public UnresolvedPlan visitJoinCommand(OpenSearchPPLParser.JoinCommandContext ct Join.JoinHint joinHint = getJoinHint(ctx.joinHintList()); String leftAlias = ctx.sideAlias().leftAlias.getText(); String rightAlias = ctx.sideAlias().rightAlias.getText(); - // TODO when sub-search is supported, this part 
need to change. Now relation is the only supported plan for right side - UnresolvedPlan right = new SubqueryAlias(rightAlias, new Relation(this.internalVisitExpression(ctx.tableSource()), rightAlias)); + if (ctx.tableOrSubqueryClause().alias != null) { + // left and right aliases are required in join syntax. Setting by 'AS' causes ambiguous + throw new SyntaxCheckException("'AS' is not allowed in right subquery, use right= instead"); + } + UnresolvedPlan rightRelation = visit(ctx.tableOrSubqueryClause()); + UnresolvedPlan right = new SubqueryAlias(rightAlias, rightRelation); Optional joinCondition = ctx.joinCriteria() == null ? Optional.empty() : Optional.of(expressionBuilder.visitJoinCriteria(ctx.joinCriteria())); @@ -265,14 +270,24 @@ public UnresolvedPlan visitStatsCommand(OpenSearchPPLParser.StatsCommandContext .map(this::internalVisitExpression) .orElse(null); - Aggregation aggregation = - new Aggregation( - aggListBuilder.build(), - emptyList(), - groupList, - span, - ArgumentFactory.getArgumentList(ctx)); - return aggregation; + if (ctx.STATS() != null) { + Aggregation aggregation = + new Aggregation( + aggListBuilder.build(), + emptyList(), + groupList, + span, + ArgumentFactory.getArgumentList(ctx)); + return aggregation; + } else { + Window window = + new Window( + aggListBuilder.build(), + groupList, + emptyList()); + window.setSpan(span); + return window; + } } /** Dedup command. */ @@ -371,6 +386,30 @@ private java.util.Map buildLookupPair(List (Alias) and.getLeft(), and -> (Field) and.getRight(), (x, y) -> y, LinkedHashMap::new)); } + @Override + public UnresolvedPlan visitTrendlineCommand(OpenSearchPPLParser.TrendlineCommandContext ctx) { + List trendlineComputations = ctx.trendlineClause() + .stream() + .map(this::toTrendlineComputation) + .collect(Collectors.toList()); + return Optional.ofNullable(ctx.sortField()) + .map(this::internalVisitExpression) + .map(Field.class::cast) + .map(sort -> new Trendline(Optional.of(sort), trendlineComputations)) + .orElse(new Trendline(Optional.empty(), trendlineComputations)); + } + + private Trendline.TrendlineComputation toTrendlineComputation(OpenSearchPPLParser.TrendlineClauseContext ctx) { + int numberOfDataPoints = Integer.parseInt(ctx.numberOfDataPoints.getText()); + if (numberOfDataPoints < 1) { + throw new SyntaxCheckException("Number of trendline data-points must be greater than or equal to 1"); + } + Field dataField = (Field) expressionBuilder.visitFieldExpression(ctx.field); + String alias = ctx.alias == null?dataField.getField().toString()+"_trendline":ctx.alias.getText(); + String computationType = ctx.trendlineType().getText(); + return new Trendline.TrendlineComputation(numberOfDataPoints, dataField, alias, Trendline.TrendlineType.valueOf(computationType.toUpperCase())); + } + /** Top command. */ @Override public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) { @@ -411,8 +450,14 @@ public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) groupListBuilder.build()); return aggregation; } - - /** Rare command. */ + + /** Fieldsummary command. */ + @Override + public UnresolvedPlan visitFieldsummaryCommand(OpenSearchPPLParser.FieldsummaryCommandContext ctx) { + return new FieldSummary(ctx.fieldsummaryParameter().stream().map(arg -> expressionBuilder.visit(arg)).collect(Collectors.toList())); + } + + /** Rare command. 
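+ * Example: source=t | rare a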
*/ @Override public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ctx) { ImmutableList.Builder aggListBuilder = new ImmutableList.Builder<>(); @@ -451,16 +496,22 @@ public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ct return aggregation; } - /** From clause. */ @Override - public UnresolvedPlan visitFromClause(OpenSearchPPLParser.FromClauseContext ctx) { - return visitTableSourceClause(ctx.tableSourceClause()); + public UnresolvedPlan visitTableOrSubqueryClause(OpenSearchPPLParser.TableOrSubqueryClauseContext ctx) { + if (ctx.subSearch() != null) { + return ctx.alias != null + ? new SubqueryAlias(ctx.alias.getText(), visitSubSearch(ctx.subSearch())) + : visitSubSearch(ctx.subSearch()); + } else { + return visitTableSourceClause(ctx.tableSourceClause()); + } } @Override public UnresolvedPlan visitTableSourceClause(OpenSearchPPLParser.TableSourceClauseContext ctx) { - return new Relation( - ctx.tableSource().stream().map(this::internalVisitExpression).collect(Collectors.toList())); + return ctx.alias == null + ? new Relation(ctx.tableSource().stream().map(this::internalVisitExpression).collect(Collectors.toList())) + : new Relation(ctx.tableSource().stream().map(this::internalVisitExpression).collect(Collectors.toList()), ctx.alias.getText()); } @Override @@ -535,6 +586,12 @@ public UnresolvedPlan visitFillnullCommand(OpenSearchPPLParser.FillnullCommandCo } } + @Override + public UnresolvedPlan visitFlattenCommand(OpenSearchPPLParser.FlattenCommandContext ctx) { + Field unresolvedExpression = (Field) internalVisitExpression(ctx.fieldExpression()); + return new Flatten(unresolvedExpression); + } + /** AD command. */ @Override public UnresolvedPlan visitAdCommand(OpenSearchPPLParser.AdCommandContext ctx) { diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 394049651..2c4410bde 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -11,10 +11,41 @@ import org.antlr.v4.runtime.RuleContext; import org.opensearch.flint.spark.ppl.OpenSearchPPLParser; import org.opensearch.flint.spark.ppl.OpenSearchPPLParserBaseVisitor; -import org.opensearch.sql.ast.expression.*; +import org.opensearch.sql.ast.expression.AggregateFunction; +import org.opensearch.sql.ast.expression.Alias; +import org.opensearch.sql.ast.expression.AllFields; +import org.opensearch.sql.ast.expression.And; +import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.AttributeList; +import org.opensearch.sql.ast.expression.Between; +import org.opensearch.sql.ast.expression.Case; +import org.opensearch.sql.ast.expression.Cidr; +import org.opensearch.sql.ast.expression.Compare; +import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.EqualTo; +import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.Function; +import org.opensearch.sql.ast.expression.In; +import org.opensearch.sql.ast.expression.Interval; +import org.opensearch.sql.ast.expression.IntervalUnit; +import org.opensearch.sql.ast.expression.IsEmpty; +import org.opensearch.sql.ast.expression.LambdaFunction; +import org.opensearch.sql.ast.expression.Let; +import org.opensearch.sql.ast.expression.Literal; +import 
org.opensearch.sql.ast.expression.Not; +import org.opensearch.sql.ast.expression.Or; +import org.opensearch.sql.ast.expression.QualifiedName; +import org.opensearch.sql.ast.expression.Span; +import org.opensearch.sql.ast.expression.SpanUnit; +import org.opensearch.sql.ast.expression.UnresolvedArgument; +import org.opensearch.sql.ast.expression.UnresolvedExpression; +import org.opensearch.sql.ast.expression.When; +import org.opensearch.sql.ast.expression.Xor; import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; import org.opensearch.sql.ast.expression.subquery.InSubquery; import org.opensearch.sql.ast.expression.subquery.ScalarSubquery; +import org.opensearch.sql.ast.tree.Trendline; +import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.ppl.utils.ArgumentFactory; @@ -39,25 +70,21 @@ */ public class AstExpressionBuilder extends OpenSearchPPLParserBaseVisitor { - private static final int DEFAULT_TAKE_FUNCTION_SIZE_VALUE = 10; - - private AstBuilder astBuilder; - - /** Set AstBuilder back to AstExpressionBuilder for resolving the subquery plan in subquery expression */ - public void setAstBuilder(AstBuilder astBuilder) { - this.astBuilder = astBuilder; - } - /** * The function name mapping between fronted and core engine. */ - private static Map FUNCTION_NAME_MAPPING = + private static final Map FUNCTION_NAME_MAPPING = new ImmutableMap.Builder() .put("isnull", IS_NULL.getName().getFunctionName()) .put("isnotnull", IS_NOT_NULL.getName().getFunctionName()) .put("ispresent", IS_NOT_NULL.getName().getFunctionName()) .build(); + private AstBuilder astBuilder; + public AstExpressionBuilder(AstBuilder astBuilder) { + this.astBuilder = astBuilder; + } + @Override public UnresolvedExpression visitMappingCompareExpr(OpenSearchPPLParser.MappingCompareExprContext ctx) { return new Compare(ctx.comparisonOperator().getText(), visit(ctx.left), visit(ctx.right)); @@ -127,7 +154,7 @@ public UnresolvedExpression visitCompareExpr(OpenSearchPPLParser.CompareExprCont @Override public UnresolvedExpression visitBinaryArithmetic(OpenSearchPPLParser.BinaryArithmeticContext ctx) { return new Function( - ctx.binaryOperator.getText(), Arrays.asList(visit(ctx.left), visit(ctx.right))); + ctx.binaryOperator.getText(), Arrays.asList(visit(ctx.left), visit(ctx.right))); } @Override @@ -155,6 +182,20 @@ public UnresolvedExpression visitSortField(OpenSearchPPLParser.SortFieldContext ArgumentFactory.getArgumentList(ctx)); } + @Override + public UnresolvedExpression visitFieldsummaryIncludeFields(OpenSearchPPLParser.FieldsummaryIncludeFieldsContext ctx) { + List list = ctx.fieldList().fieldExpression().stream() + .map(this::visitFieldExpression) + .collect(Collectors.toList()); + return new AttributeList(list); + } + + @Override + public UnresolvedExpression visitFieldsummaryNulls(OpenSearchPPLParser.FieldsummaryNullsContext ctx) { + return new Argument("NULLS",(Literal)visitBooleanLiteral(ctx.booleanLiteral())); + } + + /** * Aggregation function. 
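+ * e.g. avg(a); a distinct count such as count(distinct b) becomes AggregateFunction("count", b, true)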
*/ @@ -173,14 +214,6 @@ public UnresolvedExpression visitDistinctCountFunctionCall(OpenSearchPPLParser.D return new AggregateFunction("count", visit(ctx.valueExpression()), true); } - @Override - public UnresolvedExpression visitPercentileAggFunction(OpenSearchPPLParser.PercentileAggFunctionContext ctx) { - return new AggregateFunction( - ctx.PERCENTILE().getText(), - visit(ctx.aggField), - Collections.singletonList(new Argument("rank", (Literal) visit(ctx.value)))); - } - @Override public UnresolvedExpression visitPercentileFunctionCall(OpenSearchPPLParser.PercentileFunctionCallContext ctx) { return new AggregateFunction( @@ -212,7 +245,7 @@ public UnresolvedExpression visitCaseExpr(OpenSearchPPLParser.CaseExprContext ct }) .collect(Collectors.toList()); UnresolvedExpression elseValue = new Literal(null, DataType.NULL); - if(ctx.caseFunction().valueExpression().size() > ctx.caseFunction().logicalExpression().size()) { + if (ctx.caseFunction().valueExpression().size() > ctx.caseFunction().logicalExpression().size()) { // else value is present elseValue = visit(ctx.caseFunction().valueExpression(ctx.caseFunction().valueExpression().size() - 1)); } @@ -245,15 +278,18 @@ public UnresolvedExpression visitConvertedDataType(OpenSearchPPLParser.Converted return new Literal(ctx.getText(), DataType.STRING); } + @Override + public UnresolvedExpression visitBetween(OpenSearchPPLParser.BetweenContext ctx) { + UnresolvedExpression betweenExpr = new Between(visit(ctx.expr1),visit(ctx.expr2),visit(ctx.expr3)); + return ctx.NOT() != null ? new Not(betweenExpr) : betweenExpr; + } + private Function buildFunction( String functionName, List args) { return new Function( functionName, args.stream().map(this::visitFunctionArg).collect(Collectors.toList())); } - public AstExpressionBuilder() { - } - @Override public UnresolvedExpression visitMultiFieldRelevanceFunction( OpenSearchPPLParser.MultiFieldRelevanceFunctionContext ctx) { @@ -267,7 +303,7 @@ public UnresolvedExpression visitTableSource(OpenSearchPPLParser.TableSourceCont if (ctx.getChild(0) instanceof OpenSearchPPLParser.IdentsAsTableQualifiedNameContext) { return visitIdentsAsTableQualifiedName((OpenSearchPPLParser.IdentsAsTableQualifiedNameContext) ctx.getChild(0)); } else { - return visitIdentifiers(Arrays.asList(ctx)); + return visitIdentifiers(List.of(ctx)); } } @@ -359,9 +395,9 @@ public UnresolvedExpression visitRightHint(OpenSearchPPLParser.RightHintContext @Override public UnresolvedExpression visitInSubqueryExpr(OpenSearchPPLParser.InSubqueryExprContext ctx) { UnresolvedExpression expr = new InSubquery( - ctx.valueExpressionList().valueExpression().stream() - .map(this::visit).collect(Collectors.toList()), - astBuilder.visitSubSearch(ctx.subSearch())); + ctx.valueExpressionList().valueExpression().stream() + .map(this::visit).collect(Collectors.toList()), + astBuilder.visitSubSearch(ctx.subSearch())); return ctx.NOT() != null ? new Not(expr) : expr; } @@ -383,10 +419,43 @@ public UnresolvedExpression visitGeoIpFunctionCall(OpenSearchPPLParser.GeoIpFunc return new GeoIp(datasource, ipAddress, properties); } - Literal pattern = (Literal) internalVisitExpression(ctx.pattern); + @Override + public UnresolvedExpression visitInExpr(OpenSearchPPLParser.InExprContext ctx) { + UnresolvedExpression expr = new In(visit(ctx.valueExpression()), + ctx.valueList().literalValue().stream().map(this::visit).collect(Collectors.toList())); + return ctx.NOT() != null ? 
new Not(expr) : expr; + } - return new Parse(ParseMethod.REGEX, sourceField, pattern, ImmutableMap.of()); -} + @Override + public UnresolvedExpression visitCidrMatchFunctionCall(OpenSearchPPLParser.CidrMatchFunctionCallContext ctx) { + return new Cidr(visit(ctx.ipAddress), visit(ctx.cidrBlock)); + } + + @Override + public UnresolvedExpression visitTimestampFunctionCall( + OpenSearchPPLParser.TimestampFunctionCallContext ctx) { + return new Function( + ctx.timestampFunction().timestampFunctionName().getText(), timestampFunctionArguments(ctx)); + } + + @Override + public UnresolvedExpression visitLambda(OpenSearchPPLParser.LambdaContext ctx) { + + List arguments = ctx.ident().stream().map(x -> this.visitIdentifiers(Collections.singletonList(x))).collect( + Collectors.toList()); + UnresolvedExpression function = visitExpression(ctx.expression()); + return new LambdaFunction(function, arguments); + } + + private List timestampFunctionArguments( + OpenSearchPPLParser.TimestampFunctionCallContext ctx) { + List args = + Arrays.asList( + new Literal(ctx.timestampFunction().simpleDateTimePart().getText(), DataType.STRING), + visitFunctionArg(ctx.timestampFunction().firstArg), + visitFunctionArg(ctx.timestampFunction().secondArg)); + return args; + } private QualifiedName visitIdentifiers(List ctx) { return new QualifiedName( diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java index 6a545f091..44610f3a4 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java @@ -56,6 +56,7 @@ public StatementBuilderContext getContext() { } public static class StatementBuilderContext { + public static final int FETCH_SIZE = 1000; private int fetchSize; public StatementBuilderContext(int fetchSize) { @@ -63,8 +64,7 @@ public StatementBuilderContext(int fetchSize) { } public static StatementBuilderContext builder() { - //todo set the default statement builder init params configurable - return new StatementBuilderContext(1000); + return new StatementBuilderContext(FETCH_SIZE); } public int getFetchSize() { diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTransformer.java similarity index 80% rename from ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java rename to ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTransformer.java index 3c367a948..9788ac1bc 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTransformer.java @@ -25,32 +25,38 @@ * * @return */ -public interface AggregatorTranslator { - +public interface AggregatorTransformer { + static Expression aggregator(org.opensearch.sql.ast.expression.AggregateFunction aggregateFunction, Expression arg) { if (BuiltinFunctionName.ofAggregation(aggregateFunction.getFuncName()).isEmpty()) throw new IllegalStateException("Unexpected value: " + aggregateFunction.getFuncName()); + boolean distinct = aggregateFunction.getDistinct(); // Additional aggregation function operators will be added here - switch 
(BuiltinFunctionName.ofAggregation(aggregateFunction.getFuncName()).get()) { + BuiltinFunctionName functionName = BuiltinFunctionName.ofAggregation(aggregateFunction.getFuncName()).get(); + switch (functionName) { case MAX: - return new UnresolvedFunction(seq("MAX"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("MAX"), seq(arg), distinct, empty(),false); case MIN: - return new UnresolvedFunction(seq("MIN"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("MIN"), seq(arg), distinct, empty(),false); + case MEAN: + return new UnresolvedFunction(seq("MEAN"), seq(arg), distinct, empty(),false); case AVG: - return new UnresolvedFunction(seq("AVG"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("AVG"), seq(arg), distinct, empty(),false); case COUNT: - return new UnresolvedFunction(seq("COUNT"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("COUNT"), seq(arg), distinct, empty(),false); case SUM: - return new UnresolvedFunction(seq("SUM"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("SUM"), seq(arg), distinct, empty(),false); + case STDDEV: + return new UnresolvedFunction(seq("STDDEV"), seq(arg), distinct, empty(),false); case STDDEV_POP: - return new UnresolvedFunction(seq("STDDEV_POP"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("STDDEV_POP"), seq(arg), distinct, empty(),false); case STDDEV_SAMP: - return new UnresolvedFunction(seq("STDDEV_SAMP"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("STDDEV_SAMP"), seq(arg), distinct, empty(),false); case PERCENTILE: - return new UnresolvedFunction(seq("PERCENTILE"), seq(arg, new Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("PERCENTILE"), seq(arg, new Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), distinct, empty(),false); case PERCENTILE_APPROX: - return new UnresolvedFunction(seq("PERCENTILE_APPROX"), seq(arg, new Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("PERCENTILE_APPROX"), seq(arg, new Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), distinct, empty(),false); } throw new IllegalStateException("Not Supported value: " + aggregateFunction.getFuncName()); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java new file mode 100644 index 000000000..e39c9ab38 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java @@ -0,0 +1,221 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.utils; + +import com.google.common.collect.ImmutableMap; +import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction; +import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction$; +import org.apache.spark.sql.catalyst.expressions.CurrentTimeZone$; +import org.apache.spark.sql.catalyst.expressions.CurrentTimestamp$; +import 
org.apache.spark.sql.catalyst.expressions.DateAddInterval$; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.Literal$; +import org.apache.spark.sql.catalyst.expressions.TimestampAdd$; +import org.apache.spark.sql.catalyst.expressions.TimestampDiff$; +import org.apache.spark.sql.catalyst.expressions.ToUTCTimestamp$; +import org.apache.spark.sql.catalyst.expressions.UnaryMinus$; +import org.opensearch.sql.ast.expression.IntervalUnit; +import org.opensearch.sql.expression.function.BuiltinFunctionName; +import scala.Option; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADD; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADDDATE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.DATEDIFF; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.DATE_ADD; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.DATE_SUB; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.DAY_OF_MONTH; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.COALESCE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON_ARRAY; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON_ARRAY_LENGTH; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON_EXTRACT; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON_KEYS; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON_OBJECT; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON_VALID; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.SUBTRACT; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.MULTIPLY; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.DIVIDE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.MODULUS; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.DAY_OF_WEEK; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.DAY_OF_YEAR; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.HOUR_OF_DAY; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NOT_NULL; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NULL; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.LENGTH; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.LOCALTIME; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.MINUTE_OF_HOUR; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.MONTH_OF_YEAR; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.SECOND_OF_MINUTE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.SUBDATE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.SYSDATE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.TIMESTAMPADD; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.TIMESTAMPDIFF; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.TRIM; +import static 
org.opensearch.sql.expression.function.BuiltinFunctionName.UTC_TIMESTAMP; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.WEEK; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.WEEK_OF_YEAR; +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; +import static scala.Option.empty; + +public interface BuiltinFunctionTransformer { + + /** + * The name mapping between PPL builtin functions to Spark builtin functions. + * This is only used for the built-in functions between PPL and Spark with different names. + * If the built-in function names are the same in PPL and Spark, add it to {@link BuiltinFunctionName} only. + */ + static final Map SPARK_BUILTIN_FUNCTION_NAME_MAPPING + = ImmutableMap.builder() + // arithmetic operators + .put(ADD, "+") + .put(SUBTRACT, "-") + .put(MULTIPLY, "*") + .put(DIVIDE, "/") + .put(MODULUS, "%") + // time functions + .put(DAY_OF_WEEK, "dayofweek") + .put(DAY_OF_MONTH, "dayofmonth") + .put(DAY_OF_YEAR, "dayofyear") + .put(WEEK_OF_YEAR, "weekofyear") + .put(WEEK, "weekofyear") + .put(MONTH_OF_YEAR, "month") + .put(HOUR_OF_DAY, "hour") + .put(MINUTE_OF_HOUR, "minute") + .put(SECOND_OF_MINUTE, "second") + .put(SUBDATE, "date_sub") // only maps subdate(date, days) + .put(ADDDATE, "date_add") // only maps adddate(date, days) + .put(DATEDIFF, "datediff") + .put(LOCALTIME, "localtimestamp") + .put(SYSDATE, "now") + // condition functions + .put(IS_NULL, "isnull") + .put(IS_NOT_NULL, "isnotnull") + .put(BuiltinFunctionName.ISPRESENT, "isnotnull") + .put(COALESCE, "coalesce") + .put(LENGTH, "length") + .put(TRIM, "trim") + // json functions + .put(JSON_KEYS, "json_object_keys") + .put(JSON_EXTRACT, "get_json_object") + .build(); + + /** + * The name mapping between PPL builtin functions to Spark builtin functions. 
+ */ + static final Map, Expression>> PPL_TO_SPARK_FUNC_MAPPING + = ImmutableMap., Expression>>builder() + // json functions + .put( + JSON_ARRAY, + args -> { + return UnresolvedFunction$.MODULE$.apply("array", seq(args), false); + }) + .put( + JSON_OBJECT, + args -> { + return UnresolvedFunction$.MODULE$.apply("named_struct", seq(args), false); + }) + .put( + JSON_ARRAY_LENGTH, + args -> { + // Check if the input is an array (from json_array()) or a JSON string + if (args.get(0) instanceof UnresolvedFunction) { + // Input is a JSON array + return UnresolvedFunction$.MODULE$.apply("json_array_length", + seq(UnresolvedFunction$.MODULE$.apply("to_json", seq(args), false)), false); + } else { + // Input is a JSON string + return UnresolvedFunction$.MODULE$.apply("json_array_length", seq(args.get(0)), false); + } + }) + .put( + JSON, + args -> { + // Check if the input is a named_struct (from json_object()) or a JSON string + if (args.get(0) instanceof UnresolvedFunction) { + return UnresolvedFunction$.MODULE$.apply("to_json", seq(args.get(0)), false); + } else { + return UnresolvedFunction$.MODULE$.apply("get_json_object", + seq(args.get(0), Literal$.MODULE$.apply("$")), false); + } + }) + .put( + JSON_VALID, + args -> { + return UnresolvedFunction$.MODULE$.apply("isnotnull", + seq(UnresolvedFunction$.MODULE$.apply("get_json_object", + seq(args.get(0), Literal$.MODULE$.apply("$")), false)), false); + }) + .put( + DATE_ADD, + args -> { + return DateAddInterval$.MODULE$.apply(args.get(0), args.get(1), Option.empty(), false); + }) + .put( + DATE_SUB, + args -> { + return DateAddInterval$.MODULE$.apply(args.get(0), UnaryMinus$.MODULE$.apply(args.get(1), true), Option.empty(), true); + }) + .put( + TIMESTAMPADD, + args -> { + return TimestampAdd$.MODULE$.apply(args.get(0).toString(), args.get(1), args.get(2), Option.empty()); + }) + .put( + TIMESTAMPDIFF, + args -> { + return TimestampDiff$.MODULE$.apply(args.get(0).toString(), args.get(1), args.get(2), Option.empty()); + }) + .put( + UTC_TIMESTAMP, + args -> { + return ToUTCTimestamp$.MODULE$.apply(CurrentTimestamp$.MODULE$.apply(), CurrentTimeZone$.MODULE$.apply()); + }) + .build(); + + static Expression builtinFunction(org.opensearch.sql.ast.expression.Function function, List args) { + if (BuiltinFunctionName.of(function.getFuncName()).isEmpty()) { + // TODO change it when UDF is supported + // TODO should we support more functions which are not PPL builtin functions. 
E.g Spark builtin functions + throw new UnsupportedOperationException(function.getFuncName() + " is not a builtin function of PPL"); + } else { + BuiltinFunctionName builtin = BuiltinFunctionName.of(function.getFuncName()).get(); + String name = SPARK_BUILTIN_FUNCTION_NAME_MAPPING.get(builtin); + if (name != null) { + // there is a Spark builtin function mapping with the PPL builtin function + return new UnresolvedFunction(seq(name), seq(args), false, empty(),false); + } + Function, Expression> alternative = PPL_TO_SPARK_FUNC_MAPPING.get(builtin); + if (alternative != null) { + return alternative.apply(args); + } + name = builtin.getName().getFunctionName(); + return new UnresolvedFunction(seq(name), seq(args), false, empty(),false); + } + } + + static Expression[] createIntervalArgs(IntervalUnit unit, Expression value) { + Expression[] args = new Expression[7]; + Arrays.fill(args, Literal$.MODULE$.apply(0)); + switch (unit) { + case YEAR: args[0] = value; break; + case MONTH: args[1] = value; break; + case WEEK: args[2] = value; break; + case DAY: args[3] = value; break; + case HOUR: args[4] = value; break; + case MINUTE: args[5] = value; break; + case SECOND: args[6] = value; break; + default: + throw new IllegalArgumentException("Unsupported Interval unit: " + unit); + } + return args; + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java deleted file mode 100644 index d817305a9..000000000 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.ppl.utils; - -import com.google.common.collect.ImmutableMap; -import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction; -import org.apache.spark.sql.catalyst.expressions.Expression; -import org.opensearch.sql.expression.function.BuiltinFunctionName; - -import java.util.List; -import java.util.Map; - -import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADD; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADDDATE; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.DATEDIFF; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.DAY_OF_MONTH; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.COALESCE; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.SUBTRACT; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.MULTIPLY; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.DIVIDE; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.MODULUS; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.DAY_OF_WEEK; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.DAY_OF_YEAR; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.HOUR_OF_DAY; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NOT_NULL; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NULL; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.LENGTH; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.LOCALTIME; -import static 
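// Editor's sketch (not part of this change set): the expression tree the JSON_VALID lambda in
// PPL_TO_SPARK_FUNC_MAPPING builds for `json_valid(doc)` when the input is a plain JSON string —
// validity is checked as isnotnull(get_json_object(doc, '$')). The column name `doc` is illustrative.
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction}
import org.apache.spark.sql.catalyst.expressions.Literal

val jsonValid = UnresolvedFunction(
  Seq("isnotnull"),
  Seq(
    UnresolvedFunction(
      Seq("get_json_object"),
      Seq(UnresolvedAttribute("doc"), Literal("$")),
      isDistinct = false)),
  isDistinct = false)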
org.opensearch.sql.expression.function.BuiltinFunctionName.MINUTE_OF_HOUR; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.MONTH_OF_YEAR; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.SECOND_OF_MINUTE; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.SUBDATE; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.TRIM; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.WEEK; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.WEEK_OF_YEAR; -import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; -import static scala.Option.empty; - -public interface BuiltinFunctionTranslator { - - /** - * The name mapping between PPL builtin functions to Spark builtin functions. - */ - static final Map SPARK_BUILTIN_FUNCTION_NAME_MAPPING - = new ImmutableMap.Builder() - // arithmetic operators - .put(ADD, "+") - .put(SUBTRACT, "-") - .put(MULTIPLY, "*") - .put(DIVIDE, "/") - .put(MODULUS, "%") - // time functions - .put(DAY_OF_WEEK, "dayofweek") - .put(DAY_OF_MONTH, "dayofmonth") - .put(DAY_OF_YEAR, "dayofyear") - .put(WEEK_OF_YEAR, "weekofyear") - .put(WEEK, "weekofyear") - .put(MONTH_OF_YEAR, "month") - .put(HOUR_OF_DAY, "hour") - .put(MINUTE_OF_HOUR, "minute") - .put(SECOND_OF_MINUTE, "second") - .put(SUBDATE, "date_sub") // only maps subdate(date, days) - .put(ADDDATE, "date_add") // only maps adddate(date, days) - .put(DATEDIFF, "datediff") - .put(LOCALTIME, "localtimestamp") - //condition functions - .put(IS_NULL, "isnull") - .put(IS_NOT_NULL, "isnotnull") - .put(BuiltinFunctionName.ISPRESENT, "isnotnull") - .put(COALESCE, "coalesce") - .put(LENGTH, "length") - .put(TRIM, "trim") - .build(); - - static Expression builtinFunction(org.opensearch.sql.ast.expression.Function function, List args) { - if (BuiltinFunctionName.of(function.getFuncName()).isEmpty()) { - // TODO change it when UDF is supported - // TODO should we support more functions which are not PPL builtin functions. E.g Spark builtin functions - throw new UnsupportedOperationException(function.getFuncName() + " is not a builtin function of PPL"); - } else { - BuiltinFunctionName builtin = BuiltinFunctionName.of(function.getFuncName()).get(); - String name = SPARK_BUILTIN_FUNCTION_NAME_MAPPING - .getOrDefault(builtin, builtin.getName().getFunctionName()); - return new UnresolvedFunction(seq(name), seq(args), false, empty(),false); - } - } -} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java index 4345b0897..e4defad52 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java @@ -14,6 +14,7 @@ import org.apache.spark.sql.types.FloatType$; import org.apache.spark.sql.types.IntegerType$; import org.apache.spark.sql.types.LongType$; +import org.apache.spark.sql.types.NullType$; import org.apache.spark.sql.types.ShortType$; import org.apache.spark.sql.types.StringType$; import org.apache.spark.unsafe.types.UTF8String; @@ -41,6 +42,7 @@ public interface DataTypeTransformer { static Seq seq(T... 
elements) { return seq(List.of(elements)); } + static Seq seq(List list) { return asScalaBufferConverter(list).asScala().seq(); } @@ -63,6 +65,8 @@ static DataType translate(org.opensearch.sql.ast.expression.DataType source) { return ShortType$.MODULE$; case BYTE: return ByteType$.MODULE$; + case UNDEFINED: + return NullType$.MODULE$; default: return StringType$.MODULE$; } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DedupeTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DedupeTransformer.java index 0866ca7e9..a34fc0184 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DedupeTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DedupeTransformer.java @@ -16,8 +16,8 @@ import org.apache.spark.sql.catalyst.plans.logical.Union; import org.apache.spark.sql.types.DataTypes; import org.opensearch.sql.ast.tree.Dedupe; +import org.opensearch.sql.ppl.CatalystExpressionVisitor; import org.opensearch.sql.ppl.CatalystPlanContext; -import org.opensearch.sql.ppl.CatalystQueryPlanVisitor.ExpressionAnalyzer; import scala.collection.Seq; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; @@ -38,7 +38,7 @@ public interface DedupeTransformer { static LogicalPlan retainOneDuplicateEventAndKeepEmpty( Dedupe node, Seq dedupeFields, - ExpressionAnalyzer expressionAnalyzer, + CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { context.apply(p -> { Expression isNullExpr = buildIsNullFilterExpression(node, expressionAnalyzer, context); @@ -63,7 +63,7 @@ static LogicalPlan retainOneDuplicateEventAndKeepEmpty( static LogicalPlan retainOneDuplicateEvent( Dedupe node, Seq dedupeFields, - ExpressionAnalyzer expressionAnalyzer, + CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { Expression isNotNullExpr = buildIsNotNullFilterExpression(node, expressionAnalyzer, context); context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Filter(isNotNullExpr, p)); @@ -87,7 +87,7 @@ static LogicalPlan retainOneDuplicateEvent( static LogicalPlan retainMultipleDuplicateEventsAndKeepEmpty( Dedupe node, Integer allowedDuplication, - ExpressionAnalyzer expressionAnalyzer, + CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { context.apply(p -> { // Build isnull Filter for right @@ -137,7 +137,7 @@ static LogicalPlan retainMultipleDuplicateEventsAndKeepEmpty( static LogicalPlan retainMultipleDuplicateEvents( Dedupe node, Integer allowedDuplication, - ExpressionAnalyzer expressionAnalyzer, + CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { // Build isnotnull Filter Expression isNotNullExpr = buildIsNotNullFilterExpression(node, expressionAnalyzer, context); @@ -163,7 +163,7 @@ static LogicalPlan retainMultipleDuplicateEvents( return context.apply(p -> new DataFrameDropColumns(seq(rowNumber.toAttribute()), p)); } - private static Expression buildIsNotNullFilterExpression(Dedupe node, ExpressionAnalyzer expressionAnalyzer, CatalystPlanContext context) { + private static Expression buildIsNotNullFilterExpression(Dedupe node, CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { node.getFields().forEach(field -> expressionAnalyzer.analyze(field, context)); Seq isNotNullExpressions = context.retainAllNamedParseExpressions( @@ -180,7 +180,7 @@ private static Expression buildIsNotNullFilterExpression(Dedupe node, Expression return isNotNullExpr; } - private static 
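// Editor's sketch (not part of this change set): the effect of the new UNDEFINED -> NullType
// branch in DataTypeTransformer — a PPL literal carrying the UNDEFINED type now translates to a
// Catalyst null of NullType instead of falling through to the StringType default.
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.NullType

val typedNull = Literal(null, NullType)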
Expression buildIsNullFilterExpression(Dedupe node, ExpressionAnalyzer expressionAnalyzer, CatalystPlanContext context) { + private static Expression buildIsNullFilterExpression(Dedupe node, CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { node.getFields().forEach(field -> expressionAnalyzer.analyze(field, context)); Seq isNullExpressions = context.retainAllNamedParseExpressions( diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java new file mode 100644 index 000000000..dd8f01874 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java @@ -0,0 +1,253 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.utils; + +import org.apache.spark.sql.catalyst.AliasIdentifier; +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute; +import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction; +import org.apache.spark.sql.catalyst.expressions.Alias; +import org.apache.spark.sql.catalyst.expressions.Alias$; +import org.apache.spark.sql.catalyst.expressions.Ascending$; +import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct; +import org.apache.spark.sql.catalyst.expressions.Descending$; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.Literal; +import org.apache.spark.sql.catalyst.expressions.NamedExpression; +import org.apache.spark.sql.catalyst.expressions.ScalarSubquery; +import org.apache.spark.sql.catalyst.expressions.ScalarSubquery$; +import org.apache.spark.sql.catalyst.expressions.SortOrder; +import org.apache.spark.sql.catalyst.expressions.Subtract; +import org.apache.spark.sql.catalyst.plans.logical.Aggregate; +import org.apache.spark.sql.catalyst.plans.logical.GlobalLimit; +import org.apache.spark.sql.catalyst.plans.logical.LocalLimit; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.catalyst.plans.logical.Project; +import org.apache.spark.sql.catalyst.plans.logical.Sort; +import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias; +import org.apache.spark.sql.types.DataTypes; +import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.tree.FieldSummary; +import org.opensearch.sql.expression.function.BuiltinFunctionName; +import org.opensearch.sql.ppl.CatalystPlanContext; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static org.apache.spark.sql.types.DataTypes.IntegerType; +import static org.apache.spark.sql.types.DataTypes.StringType; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.AVG; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.COUNT; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.MAX; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.MEAN; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.MIN; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.STDDEV; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.TYPEOF; +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; +import static scala.Option.empty; + +public interface 
FieldSummaryTransformer { + + String TOP_VALUES = "TopValues"; + String NULLS = "Nulls"; + String FIELD = "Field"; + + /** + * translate the command into the aggregate statement group by the column name + */ + static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext context) { + List> aggBranches = fieldSummary.getIncludeFields().stream() + .filter(field -> field instanceof org.opensearch.sql.ast.expression.Field ) + .map(field -> { + Literal fieldNameLiteral = Literal.create(((org.opensearch.sql.ast.expression.Field)field).getField().toString(), StringType); + UnresolvedAttribute fieldLiteral = new UnresolvedAttribute(seq(((org.opensearch.sql.ast.expression.Field)field).getField().getParts())); + context.withProjectedFields(Collections.singletonList(field)); + + // Alias for the field name as Field + Alias fieldNameAlias = Alias$.MODULE$.apply(fieldNameLiteral, + FIELD, + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + //Alias for the count(field) as Count + UnresolvedFunction count = new UnresolvedFunction(seq(COUNT.name()), seq(fieldLiteral), false, empty(), false); + Alias countAlias = Alias$.MODULE$.apply(count, + COUNT.name(), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + //Alias for the count(DISTINCT field) as CountDistinct + UnresolvedFunction countDistinct = new UnresolvedFunction(seq(COUNT.name()), seq(fieldLiteral), true, empty(), false); + Alias distinctCountAlias = Alias$.MODULE$.apply(countDistinct, + "DISTINCT", + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + //Alias for the MAX(field) as MAX + UnresolvedFunction max = new UnresolvedFunction(seq(MAX.name()), seq(fieldLiteral), false, empty(), false); + Alias maxAlias = Alias$.MODULE$.apply(max, + MAX.name(), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + //Alias for the MIN(field) as Min + UnresolvedFunction min = new UnresolvedFunction(seq(MIN.name()), seq(fieldLiteral), false, empty(), false); + Alias minAlias = Alias$.MODULE$.apply(min, + MIN.name(), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + //Alias for the AVG(field) as Avg + Alias avgAlias = getAggMethodAlias(AVG, fieldSummary, fieldLiteral); + + //Alias for the MEAN(field) as Mean + Alias meanAlias = getAggMethodAlias(MEAN, fieldSummary, fieldLiteral); + + //Alias for the STDDEV(field) as Stddev + Alias stddevAlias = getAggMethodAlias(STDDEV, fieldSummary, fieldLiteral); + + // Alias COUNT(*) - COUNT(column2) AS Nulls + UnresolvedFunction countStar = new UnresolvedFunction(seq(COUNT.name()), seq(Literal.create(1, IntegerType)), false, empty(), false); + Alias nonNullAlias = Alias$.MODULE$.apply( + new Subtract(countStar, count), + NULLS, + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + + //Alias for the typeOf(field) as Type + UnresolvedFunction typeOf = new UnresolvedFunction(seq(TYPEOF.name()), seq(fieldLiteral), false, empty(), false); + Alias typeOfAlias = Alias$.MODULE$.apply(typeOf, + TYPEOF.name(), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + //Aggregation + return (Function) p -> + new Aggregate(seq(typeOfAlias), seq(fieldNameAlias, countAlias, distinctCountAlias, minAlias, maxAlias, avgAlias, meanAlias, stddevAlias, nonNullAlias, typeOfAlias), p); + }).collect(Collectors.toList()); + + return context.applyBranches(aggBranches); + } + + /** + * Alias for aggregate function (if isIncludeNull use COALESCE to replace nulls with zeros) + */ + private static Alias getAggMethodAlias(BuiltinFunctionName method, 
FieldSummary fieldSummary, UnresolvedAttribute fieldLiteral) { + UnresolvedFunction avg = new UnresolvedFunction(seq(method.name()), seq(fieldLiteral), false, empty(), false); + Alias avgAlias = Alias$.MODULE$.apply(avg, + method.name(), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + if (fieldSummary.isIncludeNull()) { + UnresolvedFunction coalesceExpr = new UnresolvedFunction( + seq("COALESCE"), + seq(fieldLiteral, Literal.create(0, DataTypes.IntegerType)), + false, + empty(), + false + ); + avg = new UnresolvedFunction(seq(method.name()), seq(coalesceExpr), false, empty(), false); + avgAlias = Alias$.MODULE$.apply(avg, + method.name(), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + } + return avgAlias; + } + + /** + * top values sub-query + */ + private static Alias topValuesSubQueryAlias(FieldSummary fieldSummary, CatalystPlanContext context, UnresolvedAttribute fieldLiteral, UnresolvedFunction count) { + int topValues = 5;// this value should come from the FieldSummary definition + CreateNamedStruct structExpr = new CreateNamedStruct(seq( + fieldLiteral, + count + )); + // Alias COLLECT_LIST(STRUCT(field, COUNT(field))) AS top_values + UnresolvedFunction collectList = new UnresolvedFunction( + seq("COLLECT_LIST"), + seq(structExpr), + false, + empty(), + !fieldSummary.isIncludeNull() + ); + Alias topValuesAlias = Alias$.MODULE$.apply( + collectList, + TOP_VALUES, + NamedExpression.newExprId(), + seq(), + empty(), + seq() + ); + Project subQueryProject = new Project(seq(topValuesAlias), buildTopValueSubQuery(topValues, fieldLiteral, context)); + ScalarSubquery scalarSubquery = ScalarSubquery$.MODULE$.apply( + subQueryProject, + seq(new ArrayList()), + NamedExpression.newExprId(), + seq(new ArrayList()), + empty(), + empty()); + + return Alias$.MODULE$.apply( + scalarSubquery, + TOP_VALUES, + NamedExpression.newExprId(), + seq(), + empty(), + seq() + ); + } + + /** + * inner top values query + * ----------------------------------------------------- + * : : +- 'Project [unresolvedalias('COLLECT_LIST(struct(status_code, count_status)), None)] + * : : +- 'GlobalLimit 5 + * : : +- 'LocalLimit 5 + * : : +- 'Sort ['count_status DESC NULLS LAST], true + * : : +- 'Aggregate ['status_code], ['status_code, 'COUNT(1) AS count_status#27] + * : : +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false + */ + private static LogicalPlan buildTopValueSubQuery(int topValues, UnresolvedAttribute fieldLiteral, CatalystPlanContext context) { + //Alias for the count(field) as Count + UnresolvedFunction countFunc = new UnresolvedFunction(seq(COUNT.name()), seq(fieldLiteral), false, empty(), false); + Alias countAlias = Alias$.MODULE$.apply(countFunc, + COUNT.name(), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + Aggregate aggregate = new Aggregate(seq(fieldLiteral), seq(countAlias), context.getPlan()); + Project project = new Project(seq(fieldLiteral, countAlias), aggregate); + SortOrder sortOrder = new SortOrder(countAlias, Descending$.MODULE$, Ascending$.MODULE$.defaultNullOrdering(), seq()); + Sort sort = new Sort(seq(sortOrder), true, project); + GlobalLimit limit = new GlobalLimit(Literal.create(topValues, IntegerType), new LocalLimit(Literal.create(topValues, IntegerType), sort)); + return new SubqueryAlias(new AliasIdentifier(TOP_VALUES + "_subquery"), limit); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/JavaToScalaTransformer.java 
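// Editor's sketch (not part of this change set): the "Nulls" column FieldSummaryTransformer
// derives above as COUNT(1) - COUNT(field), shown for an illustrative field name `status_code`.
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction}
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, Subtract}
import org.apache.spark.sql.types.IntegerType

val field     = UnresolvedAttribute("status_code")
val countStar = UnresolvedFunction(Seq("COUNT"), Seq(Literal.create(1, IntegerType)), isDistinct = false)
val countCol  = UnresolvedFunction(Seq("COUNT"), Seq(field), isDistinct = false)
val nulls     = Alias(Subtract(countStar, countCol), "Nulls")()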
b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/JavaToScalaTransformer.java new file mode 100644 index 000000000..40246d7c9 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/JavaToScalaTransformer.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.utils; + + +import scala.PartialFunction; +import scala.runtime.AbstractPartialFunction; + +public interface JavaToScalaTransformer { + static PartialFunction toPartialFunction( + java.util.function.Predicate isDefinedAt, + java.util.function.Function apply) { + return new AbstractPartialFunction() { + @Override + public boolean isDefinedAt(T t) { + return isDefinedAt.test(t); + } + + @Override + public T apply(T t) { + if (isDefinedAt.test(t)) return apply.apply(t); + else return t; + } + }; + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/JoinSpecTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/JoinSpecTransformer.java index 8a6bafc53..f6f59c009 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/JoinSpecTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/JoinSpecTransformer.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.sql.ppl.utils; import org.apache.spark.sql.catalyst.expressions.Expression; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/LookupTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/LookupTransformer.java index 58ef15ea9..3673d96d6 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/LookupTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/LookupTransformer.java @@ -15,6 +15,7 @@ import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.tree.Lookup; +import org.opensearch.sql.ppl.CatalystExpressionVisitor; import org.opensearch.sql.ppl.CatalystPlanContext; import org.opensearch.sql.ppl.CatalystQueryPlanVisitor; import scala.Option; @@ -32,7 +33,7 @@ public interface LookupTransformer { /** lookup mapping fields + input fields*/ static List buildLookupRelationProjectList( Lookup node, - CatalystQueryPlanVisitor.ExpressionAnalyzer expressionAnalyzer, + CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { List inputFields = new ArrayList<>(node.getInputFieldList()); if (inputFields.isEmpty()) { @@ -45,7 +46,7 @@ static List buildLookupRelationProjectList( static List buildProjectListFromFields( List fields, - CatalystQueryPlanVisitor.ExpressionAnalyzer expressionAnalyzer, + CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { return fields.stream().map(field -> expressionAnalyzer.visitField(field, context)) .map(NamedExpression.class::cast) @@ -54,7 +55,7 @@ static List buildProjectListFromFields( static Expression buildLookupMappingCondition( Lookup node, - CatalystQueryPlanVisitor.ExpressionAnalyzer expressionAnalyzer, + CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { // only equi-join conditions are accepted in lookup command List equiConditions = new ArrayList<>(); @@ -81,7 +82,7 @@ static Expression buildLookupMappingCondition( static List buildOutputProjectList( Lookup node, 
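// Editor's sketch (not part of this change set): behaviour of the PartialFunction returned by
// JavaToScalaTransformer.toPartialFunction above — values matching the predicate are rewritten,
// everything else passes through unchanged. Plain-String example, purely illustrative, and it
// assumes the helper's generic <T> signature as implied by its body.
import org.opensearch.sql.ppl.utils.JavaToScalaTransformer

val upperCaseShort = JavaToScalaTransformer.toPartialFunction[String](
  s => s.length < 5,
  s => s.toUpperCase)

upperCaseShort("ppl")      // "PPL"      -- predicate holds, the function is applied
upperCaseShort("catalyst") // "catalyst" -- predicate fails, the input is returned as-is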
Lookup.OutputStrategy strategy, - CatalystQueryPlanVisitor.ExpressionAnalyzer expressionAnalyzer, + CatalystExpressionVisitor expressionAnalyzer, CatalystPlanContext context) { List outputProjectList = new ArrayList<>(); for (Map.Entry entry : node.getOutputCandidateMap().entrySet()) { diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseStrategy.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseTransformer.java similarity index 97% rename from ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseStrategy.java rename to ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseTransformer.java index 6cdb2f6b2..eed7db228 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseStrategy.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseTransformer.java @@ -1,9 +1,13 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.sql.ppl.utils; import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.expressions.NamedExpression; -import org.apache.spark.sql.catalyst.expressions.RegExpExtract; import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; import org.opensearch.sql.ast.expression.AllFields; import org.opensearch.sql.ast.expression.Field; @@ -22,7 +26,7 @@ import static org.apache.spark.sql.types.DataTypes.StringType; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; -public interface ParseStrategy { +public interface ParseTransformer { /** * transform the parse/grok/patterns command into a standard catalyst RegExpExtract expression * Since spark's RegExpExtract cant accept actual regExp group name we need to translate the group's name into its corresponding index diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/RelationUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/RelationUtils.java index 7be7f1f45..1dc7b9878 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/RelationUtils.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/RelationUtils.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.sql.ppl.utils; import org.apache.spark.sql.catalyst.TableIdentifier; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/SortUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/SortUtils.java index 83603b031..803daea8b 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/SortUtils.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/SortUtils.java @@ -38,7 +38,7 @@ static SortOrder getSortDirection(Sort node, NamedExpression expression) { .findAny(); return field.map(value -> sortOrder((Expression) expression, - (Boolean) value.getFieldArgs().get(0).getValue().getValue())) + isSortedAscending(value))) .orElse(null); } @@ -51,4 +51,8 @@ static SortOrder sortOrder(Expression expression, boolean ascending) { seq(new ArrayList()) ); } + + static boolean isSortedAscending(Field field) { + return (Boolean) field.getFieldArgs().get(0).getValue().getValue(); + } } \ No newline at end of file diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java 
b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java new file mode 100644 index 000000000..67603ccc7 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java @@ -0,0 +1,87 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.utils; + +import org.apache.spark.sql.catalyst.expressions.*; +import org.opensearch.sql.ast.expression.AggregateFunction; +import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.ast.tree.Trendline; +import org.opensearch.sql.expression.function.BuiltinFunctionName; +import org.opensearch.sql.ppl.CatalystExpressionVisitor; +import org.opensearch.sql.ppl.CatalystPlanContext; +import scala.Option; +import scala.Tuple2; + +import java.util.List; +import java.util.stream.Collectors; + +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; + +public interface TrendlineCatalystUtils { + + static List visitTrendlineComputations(CatalystExpressionVisitor expressionVisitor, List computations, CatalystPlanContext context) { + return computations.stream() + .map(computation -> visitTrendlineComputation(expressionVisitor, computation, context)) + .collect(Collectors.toList()); + } + + static NamedExpression visitTrendlineComputation(CatalystExpressionVisitor expressionVisitor, Trendline.TrendlineComputation node, CatalystPlanContext context) { + //window lower boundary + expressionVisitor.visitLiteral(new Literal(Math.negateExact(node.getNumberOfDataPoints() - 1), DataType.INTEGER), context); + Expression windowLowerBoundary = context.popNamedParseExpressions().get(); + + //window definition + WindowSpecDefinition windowDefinition = new WindowSpecDefinition( + seq(), + seq(), + new SpecifiedWindowFrame(RowFrame$.MODULE$, windowLowerBoundary, CurrentRow$.MODULE$)); + + if (node.getComputationType() == Trendline.TrendlineType.SMA) { + //calculate avg value of the data field + expressionVisitor.visitAggregateFunction(new AggregateFunction(BuiltinFunctionName.AVG.name(), node.getDataField()), context); + Expression avgFunction = context.popNamedParseExpressions().get(); + + //sma window + WindowExpression sma = new WindowExpression( + avgFunction, + windowDefinition); + + CaseWhen smaOrNull = trendlineOrNullWhenThereAreTooFewDataPoints(expressionVisitor, sma, node, context); + + return org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(smaOrNull, + node.getAlias(), + NamedExpression.newExprId(), + seq(new java.util.ArrayList()), + Option.empty(), + seq(new java.util.ArrayList())); + } else { + throw new IllegalArgumentException(node.getComputationType()+" is not supported"); + } + } + + private static CaseWhen trendlineOrNullWhenThereAreTooFewDataPoints(CatalystExpressionVisitor expressionVisitor, WindowExpression trendlineWindow, Trendline.TrendlineComputation node, CatalystPlanContext context) { + //required number of data points + expressionVisitor.visitLiteral(new Literal(node.getNumberOfDataPoints(), DataType.INTEGER), context); + Expression requiredNumberOfDataPoints = context.popNamedParseExpressions().get(); + + //count data points function + expressionVisitor.visitAggregateFunction(new AggregateFunction(BuiltinFunctionName.COUNT.name(), new Literal(1, DataType.INTEGER)), context); + Expression countDataPointsFunction = context.popNamedParseExpressions().get(); + //count data points window + WindowExpression 
countDataPointsWindow = new WindowExpression( + countDataPointsFunction, + trendlineWindow.windowSpec()); + + expressionVisitor.visitLiteral(new Literal(null, DataType.NULL), context); + Expression nullLiteral = context.popNamedParseExpressions().get(); + Tuple2 nullWhenNumberOfDataPointsLessThenRequired = new Tuple2<>( + new LessThan(countDataPointsWindow, requiredNumberOfDataPoints), + nullLiteral + ); + return new CaseWhen(seq(nullWhenNumberOfDataPointsLessThenRequired), Option.apply(trendlineWindow)); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/WindowSpecTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/WindowSpecTransformer.java index 0e6ba2a1d..e6dd12032 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/WindowSpecTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/WindowSpecTransformer.java @@ -5,6 +5,7 @@ package org.opensearch.sql.ppl.utils; +import org.apache.spark.sql.catalyst.expressions.Alias; import org.apache.spark.sql.catalyst.expressions.CurrentRow$; import org.apache.spark.sql.catalyst.expressions.Divide; import org.apache.spark.sql.catalyst.expressions.Expression; @@ -16,6 +17,7 @@ import org.apache.spark.sql.catalyst.expressions.SortOrder; import org.apache.spark.sql.catalyst.expressions.SpecifiedWindowFrame; import org.apache.spark.sql.catalyst.expressions.TimeWindow; +import org.apache.spark.sql.catalyst.expressions.UnboundedFollowing$; import org.apache.spark.sql.catalyst.expressions.UnboundedPreceding$; import org.apache.spark.sql.catalyst.expressions.WindowExpression; import org.apache.spark.sql.catalyst.expressions.WindowSpecDefinition; @@ -79,4 +81,21 @@ static NamedExpression buildRowNumber(Seq partitionSpec, Seq())); } + + static NamedExpression buildAggregateWindowFunction(Expression aggregator, Seq partitionSpec, Seq orderSpec) { + Alias aggregatorAlias = (Alias) aggregator; + WindowExpression aggWindowExpression = new WindowExpression( + aggregatorAlias.child(), + new WindowSpecDefinition( + partitionSpec, + orderSpec, + new SpecifiedWindowFrame(RowFrame$.MODULE$, UnboundedPreceding$.MODULE$, UnboundedFollowing$.MODULE$))); + return org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply( + aggWindowExpression, + aggregatorAlias.name(), + NamedExpression.newExprId(), + seq(new ArrayList()), + Option.empty(), + seq(new ArrayList())); + } } diff --git a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/FlattenGenerator.scala b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/FlattenGenerator.scala new file mode 100644 index 000000000..23b545826 --- /dev/null +++ b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/FlattenGenerator.scala @@ -0,0 +1,33 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.flint.spark + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.{CollectionGenerator, CreateArray, Expression, GenericInternalRow, Inline, UnaryExpression} +import org.apache.spark.sql.types.{ArrayType, StructType} + +class FlattenGenerator(override val child: Expression) + extends Inline(child) + with CollectionGenerator { + + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + case st: StructType => TypeCheckResult.TypeCheckSuccess + case _ => super.checkInputDataTypes() + } + + override def elementSchema: StructType = 
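// Editor's sketch (not part of this change set): the moving-average window that
// visitTrendlineComputation builds for a 3-data-point SMA — AVG over the current row and the two
// preceding rows, with the surrounding CaseWhen yielding null until 3 rows are available.
// The column name `price` is illustrative only.
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction}
import org.apache.spark.sql.catalyst.expressions.{CurrentRow, Literal, RowFrame, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition}

val lowerBound = Literal(-2) // Math.negateExact(numberOfDataPoints - 1) for 3 points
val windowSpec = WindowSpecDefinition(Nil, Nil, SpecifiedWindowFrame(RowFrame, lowerBound, CurrentRow))
val sma = WindowExpression(
  UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("price")), isDistinct = false),
  windowSpec)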
child.dataType match { + case st: StructType => st + case _ => super.elementSchema + } + + override protected def withNewChildInternal(newChild: Expression): FlattenGenerator = { + newChild.dataType match { + case ArrayType(st: StructType, _) => new FlattenGenerator(newChild) + case st: StructType => withNewChildInternal(CreateArray(Seq(newChild), false)) + case _ => + throw new IllegalArgumentException(s"Unexpected input type ${newChild.dataType}") + } + } +} diff --git a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLSyntaxParser.scala b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLSyntaxParser.scala index c435af53d..ed498e98b 100644 --- a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLSyntaxParser.scala +++ b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLSyntaxParser.scala @@ -29,11 +29,8 @@ class PPLSyntaxParser extends Parser { object PlaneUtils { def plan(parser: PPLSyntaxParser, query: String): Statement = { - val astExpressionBuilder = new AstExpressionBuilder() - val astBuilder = new AstBuilder(astExpressionBuilder, query) - astExpressionBuilder.setAstBuilder(astBuilder) - val builder = - new AstStatementBuilder(astBuilder, AstStatementBuilder.StatementBuilderContext.builder()) - builder.visit(parser.parse(query)) + new AstStatementBuilder( + new AstBuilder(query), + AstStatementBuilder.StatementBuilderContext.builder()).visit(parser.parse(query)) } } diff --git a/ppl-spark-integration/src/test/java/org/opensearch/sql/expression/function/SerializableUdfTest.java b/ppl-spark-integration/src/test/java/org/opensearch/sql/expression/function/SerializableUdfTest.java new file mode 100644 index 000000000..3d3940730 --- /dev/null +++ b/ppl-spark-integration/src/test/java/org/opensearch/sql/expression/function/SerializableUdfTest.java @@ -0,0 +1,61 @@ +package org.opensearch.sql.expression.function; + +import org.junit.Assert; +import org.junit.Test; + +public class SerializableUdfTest { + + @Test(expected = RuntimeException.class) + public void cidrNullIpTest() { + SerializableUdf.cidrFunction.apply(null, "192.168.0.0/24"); + } + + @Test(expected = RuntimeException.class) + public void cidrEmptyIpTest() { + SerializableUdf.cidrFunction.apply("", "192.168.0.0/24"); + } + + @Test(expected = RuntimeException.class) + public void cidrNullCidrTest() { + SerializableUdf.cidrFunction.apply("192.168.0.0", null); + } + + @Test(expected = RuntimeException.class) + public void cidrEmptyCidrTest() { + SerializableUdf.cidrFunction.apply("192.168.0.0", ""); + } + + @Test(expected = RuntimeException.class) + public void cidrInvalidIpTest() { + SerializableUdf.cidrFunction.apply("xxx", "192.168.0.0/24"); + } + + @Test(expected = RuntimeException.class) + public void cidrInvalidCidrTest() { + SerializableUdf.cidrFunction.apply("192.168.0.0", "xxx"); + } + + @Test(expected = RuntimeException.class) + public void cirdMixedIpVersionTest() { + SerializableUdf.cidrFunction.apply("2001:0db8:85a3:0000:0000:8a2e:0370:7334", "192.168.0.0/24"); + SerializableUdf.cidrFunction.apply("192.168.0.0", "2001:db8::/324"); + } + + @Test(expected = RuntimeException.class) + public void cirdMixedIpVersionTestV6V4() { + SerializableUdf.cidrFunction.apply("2001:0db8:85a3:0000:0000:8a2e:0370:7334", "192.168.0.0/24"); + } + + @Test(expected = RuntimeException.class) + public void cirdMixedIpVersionTestV4V6() { + SerializableUdf.cidrFunction.apply("192.168.0.0", "2001:db8::/324"); + } + + @Test + public void cidrBasicTest() { + 
Assert.assertTrue(SerializableUdf.cidrFunction.apply("192.168.0.0", "192.168.0.0/24")); + Assert.assertFalse(SerializableUdf.cidrFunction.apply("10.10.0.0", "192.168.0.0/24")); + Assert.assertTrue(SerializableUdf.cidrFunction.apply("2001:0db8:85a3:0000:0000:8a2e:0370:7334", "2001:db8::/32")); + Assert.assertFalse(SerializableUdf.cidrFunction.apply("2001:0db7:85a3:0000:0000:8a2e:0370:7334", "2001:0db8::/32")); + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala index 03d7f0ab0..9946bff6a 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala @@ -959,4 +959,58 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + + test("test count() as the last aggregator in stats clause") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table | eval a = 1 | stats sum(a) as sum, avg(a) as avg, count() as cnt"), + context) + val tableRelation = UnresolvedRelation(Seq("table")) + val eval = Project(Seq(UnresolvedStar(None), Alias(Literal(1), "a")()), tableRelation) + val sum = + Alias( + UnresolvedFunction(Seq("SUM"), Seq(UnresolvedAttribute("a")), isDistinct = false), + "sum")() + val avg = + Alias( + UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("a")), isDistinct = false), + "avg")() + val count = + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false), + "cnt")() + val aggregate = Aggregate(Seq.empty, Seq(sum, avg, count), eval) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test count() as the last aggregator in stats by clause") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table | eval a = 1 | stats sum(a) as sum, avg(a) as avg, count() as cnt by country"), + context) + val tableRelation = UnresolvedRelation(Seq("table")) + val eval = Project(Seq(UnresolvedStar(None), Alias(Literal(1), "a")()), tableRelation) + val sum = + Alias( + UnresolvedFunction(Seq("SUM"), Seq(UnresolvedAttribute("a")), isDistinct = false), + "sum")() + val avg = + Alias( + UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("a")), isDistinct = false), + "avg")() + val count = + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false), + "cnt")() + val grouping = + Alias(UnresolvedAttribute("country"), "country")() + val aggregate = Aggregate(Seq(grouping), Seq(sum, avg, count, grouping), eval) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala index 96176982e..2a569dbdf 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala +++ 
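// Editor's sketch (not part of this change set): direct use of the serializable CIDR UDF that the
// tests above exercise, called from Scala with the same two-argument apply; it returns true when
// the address falls inside the given block, for both IPv4 and IPv6.
import org.opensearch.sql.expression.function.SerializableUdf

val inside  = SerializableUdf.cidrFunction.apply("192.168.0.1", "192.168.0.0/24") // true
val outside = SerializableUdf.cidrFunction.apply("10.10.0.1", "192.168.0.0/24")   // false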
b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala @@ -13,7 +13,7 @@ import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Descending, GreaterThan, Literal, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, GreaterThan, Literal, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.command.DescribeTableCommand @@ -354,4 +354,46 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite thrown.getMessage === "[Field(field=A, fieldArgs=[]), Field(field=B, fieldArgs=[])] can't be resolved") } + + test("test line comment should pass without exceptions") { + val context = new CatalystPlanContext + planTransformer.visit(plan(pplParser, "source=t a=1 b=2 //this is a comment"), context) + planTransformer.visit(plan(pplParser, "source=t a=1 b=2 // this is a comment "), context) + planTransformer.visit( + plan( + pplParser, + """ + | // test is a new line comment + | source=t a=1 b=2 // test is a line comment at the end of ppl command + | | fields a,b // this is line comment inner ppl command + | ////this is a new line comment + |""".stripMargin), + context) + } + + test("test block comment should pass without exceptions") { + val context = new CatalystPlanContext + planTransformer.visit(plan(pplParser, "source=t a=1 b=2 /*block comment*/"), context) + planTransformer.visit(plan(pplParser, "source=t a=1 b=2 /* block comment */"), context) + planTransformer.visit( + plan( + pplParser, + """ + | /* + | * This is a + | * multiple + | * line block + | * comment + | */ + | search /* block comment */ source=t /* block comment */ a=1 b=2 + | | /* + | This is a + | multiple + | line + | block + | comment */ fields a,b /* block comment */ + | /* block comment */ + |""".stripMargin), + context) + } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBetweenExpressionTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBetweenExpressionTranslatorTestSuite.scala new file mode 100644 index 000000000..6defcb766 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBetweenExpressionTranslatorTestSuite.scala @@ -0,0 +1,55 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{And, GreaterThanOrEqual, LessThanOrEqual, Literal} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ + +class PPLLogicalPlanBetweenExpressionTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new 
CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test between expression") { + // if successful build ppl logical plan and translate to catalyst logical plan + val context = new CatalystPlanContext + val logPlan = { + planTransformer.visit( + plan( + pplParser, + "source = table | where datetime_field between '2024-09-10' and '2024-09-15'"), + context) + } + // SQL: SELECT * FROM table WHERE datetime_field BETWEEN '2024-09-10' AND '2024-09-15' + val star = Seq(UnresolvedStar(None)) + + val datetime_field = UnresolvedAttribute("datetime_field") + val tableRelation = UnresolvedRelation(Seq("table")) + + val lowerBound = Literal("2024-09-10") + val upperBound = Literal("2024-09-15") + val betweenCondition = And( + GreaterThanOrEqual(datetime_field, lowerBound), + LessThanOrEqual(datetime_field, upperBound)) + + val filterPlan = Filter(betweenCondition, tableRelation) + val expectedPlan = Project(star, filterPlan) + + comparePlans(expectedPlan, logPlan, false) + } + +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanCryptographicFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanCryptographicFunctionsTranslatorTestSuite.scala new file mode 100644 index 000000000..a3f163de9 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanCryptographicFunctionsTranslatorTestSuite.scala @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Not} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project} + +class PPLLogicalPlanCryptographicFunctionsTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test md5") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = md5(b)"), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("md5", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + comparePlans(expectedPlan, logPlan, false) + } + + test("test sha1") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = sha1(b)"), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("sha1", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = 
Project(projectList, filterPlan) + comparePlans(expectedPlan, logPlan, false) + } + + test("test sha2") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = sha2(b,256)"), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("sha2", seq(UnresolvedAttribute("b"), Literal(256)), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + comparePlans(expectedPlan, logPlan, false) + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanDateTimeFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanDateTimeFunctionsTranslatorTestSuite.scala new file mode 100644 index 000000000..308b038bb --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanDateTimeFunctionsTranslatorTestSuite.scala @@ -0,0 +1,231 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentTimestamp, CurrentTimeZone, DateAddInterval, Literal, MakeInterval, NamedExpression, TimestampAdd, TimestampDiff, ToUTCTimestamp, UnaryMinus} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.Project + +class PPLLogicalPlanDateTimeFunctionsTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test DATE_ADD") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval a = DATE_ADD(DATE('2020-08-26'), INTERVAL 2 DAY)"), + context) + + val table = UnresolvedRelation(Seq("t")) + val evalProjectList: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias( + DateAddInterval( + UnresolvedFunction("date", Seq(Literal("2020-08-26")), isDistinct = false), + MakeInterval( + Literal(0), + Literal(0), + Literal(0), + Literal(2), + Literal(0), + Literal(0), + Literal(0), + failOnError = true)), + "a")()) + val eval = Project(evalProjectList, table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), eval) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test DATE_ADD for year") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval a = DATE_ADD(DATE('2020-08-26'), INTERVAL 2 YEAR)"), + context) + + val table = UnresolvedRelation(Seq("t")) + val evalProjectList: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias( + DateAddInterval( + UnresolvedFunction("date", Seq(Literal("2020-08-26")), isDistinct = false), + MakeInterval( + Literal(2), + Literal(0), + Literal(0), + Literal(0), + Literal(0), + Literal(0), + Literal(0), + failOnError = true)), + "a")()) + val eval = Project(evalProjectList, table) + val expectedPlan = 
Project(Seq(UnresolvedStar(None)), eval) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test DATE_SUB") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval a = DATE_SUB(DATE('2020-08-26'), INTERVAL 2 DAY)"), + context) + + val table = UnresolvedRelation(Seq("t")) + val evalProjectList: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias( + DateAddInterval( + UnresolvedFunction("date", Seq(Literal("2020-08-26")), isDistinct = false), + UnaryMinus( + MakeInterval( + Literal(0), + Literal(0), + Literal(0), + Literal(2), + Literal(0), + Literal(0), + Literal(0), + failOnError = true), + failOnError = true)), + "a")()) + val eval = Project(evalProjectList, table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), eval) + assert(compareByString(expectedPlan) === compareByString(logPlan)) + } + + test("test TIMESTAMPADD") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval a = TIMESTAMPADD(DAY, 17, '2000-01-01 00:00:00')"), + context) + + val table = UnresolvedRelation(Seq("t")) + val evalProjectList: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias(TimestampAdd("DAY", Literal(17), Literal("2000-01-01 00:00:00")), "a")()) + val eval = Project(evalProjectList, table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), eval) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test TIMESTAMPADD with timestamp") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | eval a = TIMESTAMPADD(DAY, 17, TIMESTAMP('2000-01-01 00:00:00'))"), + context) + + val table = UnresolvedRelation(Seq("t")) + val evalProjectList: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias( + TimestampAdd( + "DAY", + Literal(17), + UnresolvedFunction( + "timestamp", + Seq(Literal("2000-01-01 00:00:00")), + isDistinct = false)), + "a")()) + val eval = Project(evalProjectList, table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), eval) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test TIMESTAMPDIFF") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | eval a = TIMESTAMPDIFF(YEAR, '1997-01-01 00:00:00', '2001-03-06 00:00:00')"), + context) + + val table = UnresolvedRelation(Seq("t")) + val evalProjectList: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias( + TimestampDiff("YEAR", Literal("1997-01-01 00:00:00"), Literal("2001-03-06 00:00:00")), + "a")()) + val eval = Project(evalProjectList, table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), eval) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test TIMESTAMPDIFF with timestamp") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | eval a = TIMESTAMPDIFF(YEAR, TIMESTAMP('1997-01-01 00:00:00'), TIMESTAMP('2001-03-06 00:00:00'))"), + context) + + val table = UnresolvedRelation(Seq("t")) + val evalProjectList: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias( + TimestampDiff( + "YEAR", + UnresolvedFunction( + "timestamp", + Seq(Literal("1997-01-01 00:00:00")), + isDistinct = false), + UnresolvedFunction( + "timestamp", + Seq(Literal("2001-03-06 00:00:00")), + isDistinct = false)), + "a")()) + val eval = Project(evalProjectList, table) + val expectedPlan = 
Project(Seq(UnresolvedStar(None)), eval) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test UTC_TIMESTAMP") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=t | eval a = UTC_TIMESTAMP()"), context) + + val table = UnresolvedRelation(Seq("t")) + val evalProjectList: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias(ToUTCTimestamp(CurrentTimestamp(), CurrentTimeZone()), "a")()) + val eval = Project(evalProjectList, table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), eval) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test CURRENT_TIMEZONE") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=t | eval a = CURRENT_TIMEZONE()"), context) + + val table = UnresolvedRelation(Seq("t")) + val evalProjectList: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias(UnresolvedFunction("current_timezone", Seq.empty, isDistinct = false), "a")()) + val eval = Project(evalProjectList, table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), eval) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEvalTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEvalTranslatorTestSuite.scala index b8cc9776d..2a828339c 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEvalTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEvalTranslatorTestSuite.scala @@ -12,7 +12,7 @@ import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, Descending, ExprId, Literal, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Alias, Descending, ExprId, In, Literal, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Project, Sort} @@ -204,4 +204,17 @@ class PPLLogicalPlanEvalTranslatorTestSuite val expectedPlan = Project(projectList, UnresolvedRelation(Seq("t"))) comparePlans(expectedPlan, logPlan, checkAnalysis = false) } + + test("test IN expr in eval") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval in = a in ('Hello', 'World') | fields in"), + context) + + val in = Alias(In(UnresolvedAttribute("a"), Seq(Literal("Hello"), Literal("World"))), "in")() + val eval = Project(Seq(UnresolvedStar(None), in), UnresolvedRelation(Seq("t"))) + val expectedPlan = Project(Seq(UnresolvedAttribute("in")), eval) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEventstatsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEventstatsTranslatorTestSuite.scala new file mode 100644 index 000000000..53bb65950 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEventstatsTranslatorTestSuite.scala @@ -0,0 +1,256 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 
+ */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Divide, Floor, Literal, Multiply, RowFrame, SpecifiedWindowFrame, UnboundedFollowing, UnboundedPreceding, WindowExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Window} + +class PPLLogicalPlanEventstatsTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test eventstats avg") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source = table | eventstats avg(age)"), context) + + val table = UnresolvedRelation(Seq("table")) + val windowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + Nil, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgWindowExprAlias = Alias(windowExpression, "avg(age)")() + val windowPlan = Window(Seq(avgWindowExprAlias), Nil, Nil, table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan) + comparePlans(expectedPlan, logPlan, false) + } + + test("test eventstats avg, max, min, count") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = table | eventstats avg(age) as avg_age, max(age) as max_age, min(age) as min_age, count(age) as count"), + context) + + val table = UnresolvedRelation(Seq("table")) + val avgWindowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + Nil, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgWindowExprAlias = Alias(avgWindowExpression, "avg_age")() + + val maxWindowExpression = WindowExpression( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + Nil, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val maxWindowExprAlias = Alias(maxWindowExpression, "max_age")() + + val minWindowExpression = WindowExpression( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + Nil, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val minWindowExprAlias = Alias(minWindowExpression, "min_age")() + + val countWindowExpression = WindowExpression( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + Nil, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val countWindowExprAlias = Alias(countWindowExpression, "count")() + val windowPlan = Window( + Seq(avgWindowExprAlias, maxWindowExprAlias, minWindowExprAlias, countWindowExprAlias), + Nil, + Nil, + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + 
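+ // In the BY-clause tests below, the grouping keys are expected to reappear as the window partition spec of each aggregate.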
test("test eventstats avg by country") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source = table | eventstats avg(age) by country"), + context) + + val table = UnresolvedRelation(Seq("table")) + val partitionSpec = Seq(Alias(UnresolvedAttribute("country"), "country")()) + val windowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgWindowExprAlias = Alias(windowExpression, "avg(age)")() + val windowPlan = Window(Seq(avgWindowExprAlias), partitionSpec, Nil, table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test eventstats avg, max, min, count by country") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = table | eventstats avg(age) as avg_age, max(age) as max_age, min(age) as min_age, count(age) as count by country"), + context) + + val table = UnresolvedRelation(Seq("table")) + val partitionSpec = Seq(Alias(UnresolvedAttribute("country"), "country")()) + val avgWindowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgWindowExprAlias = Alias(avgWindowExpression, "avg_age")() + + val maxWindowExpression = WindowExpression( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val maxWindowExprAlias = Alias(maxWindowExpression, "max_age")() + + val minWindowExpression = WindowExpression( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val minWindowExprAlias = Alias(minWindowExpression, "min_age")() + + val countWindowExpression = WindowExpression( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val countWindowExprAlias = Alias(countWindowExpression, "count")() + val windowPlan = Window( + Seq(avgWindowExprAlias, maxWindowExprAlias, minWindowExprAlias, countWindowExprAlias), + partitionSpec, + Nil, + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test eventstats avg, max, min, count by span") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = table | eventstats avg(age) as avg_age, max(age) as max_age, min(age) as min_age, count(age) as count by span(age, 10) as age_span"), + context) + + val table = UnresolvedRelation(Seq("table")) + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val partitionSpec = Seq(span) + val windowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + 
SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgWindowExprAlias = Alias(windowExpression, "avg_age")() + + val maxWindowExpression = WindowExpression( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val maxWindowExprAlias = Alias(maxWindowExpression, "max_age")() + + val minWindowExpression = WindowExpression( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val minWindowExprAlias = Alias(minWindowExpression, "min_age")() + + val countWindowExpression = WindowExpression( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val countWindowExprAlias = Alias(countWindowExpression, "count")() + val windowPlan = Window( + Seq(avgWindowExprAlias, maxWindowExprAlias, minWindowExprAlias, countWindowExprAlias), + partitionSpec, + Nil, + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test multiple eventstats") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = table | eventstats avg(age) as avg_age by state, country | eventstats avg(avg_age) as avg_state_age by country"), + context) + + val partitionSpec = Seq( + Alias(UnresolvedAttribute("state"), "state")(), + Alias(UnresolvedAttribute("country"), "country")()) + val table = UnresolvedRelation(Seq("table")) + val avgAgeWindowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgAgeWindowExprAlias = Alias(avgAgeWindowExpression, "avg_age")() + val windowPlan1 = Window(Seq(avgAgeWindowExprAlias), partitionSpec, Nil, table) + + val countryPartitionSpec = Seq(Alias(UnresolvedAttribute("country"), "country")()) + val avgStateAgeWindowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("avg_age")), isDistinct = false), + WindowSpecDefinition( + countryPartitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgStateAgeWindowExprAlias = Alias(avgStateAgeWindowExpression, "avg_state_age")() + val windowPlan2 = + Window(Seq(avgStateAgeWindowExprAlias), countryPartitionSpec, Nil, windowPlan1) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan2) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryTranslatorTestSuite.scala new file mode 100644 index 000000000..c14e1f6cf --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryTranslatorTestSuite.scala @@ -0,0 +1,709 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import 
org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Literal, NamedExpression, Not, Subtract} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Project, Union} + +class PPLLogicalPlanFieldSummaryTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test fieldsummary with single field includefields(status_code) & nulls=false") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source = t | fieldsummary includefields= status_code nulls=false"), + context) + + // Define the table + val table = UnresolvedRelation(Seq("t")) + + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction("MEAN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction("STDDEV", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + aggregateExpressions, + table) + val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan) + // Compare the two plans + comparePlans(expectedPlan, logPlan, false) + } + + test("test fieldsummary with single field includefields(status_code) & nulls=true") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source = t | fieldsummary includefields= status_code nulls=true"), + context) + + // Define the table + val table = UnresolvedRelation(Seq("t")) + + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT")(), + Alias( 
+ UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + aggregateExpressions, + table) + val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan) + // Compare the two plans + comparePlans(expectedPlan, logPlan, false) + } + + test( + "test fieldsummary with single field includefields(status_code) & nulls=true with a where filter ") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = t | where status_code != 200 | fieldsummary includefields= status_code nulls=true"), + context) + + // Define the table + val table = UnresolvedRelation(Seq("t")) + + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", 
Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()) + + val filterCondition = Not(EqualTo(UnresolvedAttribute("status_code"), Literal(200))) + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + aggregateExpressions, + Filter(filterCondition, table)) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregatePlan) + + // Compare the two plans + comparePlans(expectedPlan, logPlan, false) + } + + test( + "test fieldsummary with single field includefields(status_code) & nulls=false with a where filter ") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = t | where status_code != 200 | fieldsummary includefields= status_code nulls=false"), + context) + + // Define the table + val table = UnresolvedRelation(Seq("t")) + + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction("MEAN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction("STDDEV", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()) + + val filterCondition = Not(EqualTo(UnresolvedAttribute("status_code"), Literal(200))) + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + aggregateExpressions, + Filter(filterCondition, table)) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregatePlan) + + // Compare the two plans + comparePlans(expectedPlan, logPlan, false) + } + + test( + "test fieldsummary with single field includefields(id, status_code, request_path) & nulls=true") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = t | fieldsummary includefields= id, status_code, request_path nulls=true"), + context) + + // Define the table + val table = UnresolvedRelation(Seq("t")) + + // Aggregate with functions applied to status_code + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregateIdPlan = Aggregate( + Seq( + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("id"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", 
Seq(UnresolvedAttribute("id")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("id"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("id"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("id"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF")()), + table) + val idProj = Project(seq(UnresolvedStar(None)), aggregateIdPlan) + + // Aggregate with functions applied to status_code + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregateStatusCodePlan = Aggregate( + Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "TYPEOF")()), + table) + val statusCodeProj = Project(seq(UnresolvedStar(None)), aggregateStatusCodePlan) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + Seq( + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("request_path"), "Field")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct 
= true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("request_path"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("request_path"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("request_path"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF")()), + table) + val requestPathProj = Project(seq(UnresolvedStar(None)), aggregatePlan) + + val expectedPlan = Union(seq(idProj, statusCodeProj, requestPathProj), true, true) + // Compare the two plans + comparePlans(expectedPlan, logPlan, false) + } + + test( + "test fieldsummary with single field includefields(id, status_code, request_path) & nulls=false") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = t | fieldsummary includefields= id, status_code, request_path nulls=false"), + context) + + // Define the table + val table = UnresolvedRelation(Seq("t")) + + // Aggregate with functions applied to status_code + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregateIdPlan = Aggregate( + Seq( + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("id"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("id")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction("MEAN", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction("STDDEV", Seq(UnresolvedAttribute("id")), isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF")()), + table) + val idProj = Project(seq(UnresolvedStar(None)), aggregateIdPlan) + + // Aggregate with functions applied to status_code + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregateStatusCodePlan = Aggregate( + Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + Seq( + 
Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction("MEAN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "TYPEOF")()), + table) + val statusCodeProj = Project(seq(UnresolvedStar(None)), aggregateStatusCodePlan) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + Seq( + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("request_path"), "Field")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF")()), + table) + val requestPathProj = Project(seq(UnresolvedStar(None)), aggregatePlan) + + val expectedPlan = Union(seq(idProj, statusCodeProj, requestPathProj), true, true) + // Compare the two plans + comparePlans(expectedPlan, logPlan, false) + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFiltersTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFiltersTranslatorTestSuite.scala index 20809db95..fe9304e22 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFiltersTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFiltersTranslatorTestSuite.scala @@ -11,7 +11,7 @@ import 
org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{And, Ascending, EqualTo, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Not, Or, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{And, Ascending, EqualTo, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual, Literal, Not, Or, SortOrder} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ @@ -233,4 +233,15 @@ class PPLLogicalPlanFiltersTranslatorTestSuite val expectedPlan = Project(Seq(UnresolvedStar(None)), filter) comparePlans(expectedPlan, logPlan, false) } + + test("test IN expr in filter") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=t | where a in ('Hello', 'World')"), context) + + val in = In(UnresolvedAttribute("a"), Seq(Literal("Hello"), Literal("World"))) + val filter = Filter(in, UnresolvedRelation(Seq("t"))) + val expectedPlan = Project(Seq(UnresolvedStar(None)), filter) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFlattenCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFlattenCommandTranslatorTestSuite.scala new file mode 100644 index 000000000..58a6c04b3 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFlattenCommandTranslatorTestSuite.scala @@ -0,0 +1,156 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.FlattenGenerator +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Descending, GeneratorOuter, Literal, NullsLast, RegExpExtract, SortOrder} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, DataFrameDropColumns, Generate, GlobalLimit, LocalLimit, Project, Sort} +import org.apache.spark.sql.types.IntegerType + +class PPLLogicalPlanFlattenCommandTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test flatten only field") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=relation | flatten field_with_array"), + context) + + val relation = UnresolvedRelation(Seq("relation")) + val flattenGenerator = new FlattenGenerator(UnresolvedAttribute("field_with_array")) + val outerGenerator = GeneratorOuter(flattenGenerator) + val generate = Generate(outerGenerator, seq(), true, None, seq(), relation) + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("field_with_array")), generate) + val expectedPlan = 
Project(seq(UnresolvedStar(None)), dropSourceColumn) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test flatten and stats") { + val context = new CatalystPlanContext + val query = + "source = relation | fields state, company, employee | flatten employee | fields state, company, salary | stats max(salary) as max by state, company" + val logPlan = + planTransformer.visit(plan(pplParser, query), context) + val table = UnresolvedRelation(Seq("relation")) + val projectStateCompanyEmployee = + Project( + Seq( + UnresolvedAttribute("state"), + UnresolvedAttribute("company"), + UnresolvedAttribute("employee")), + table) + val generate = Generate( + GeneratorOuter(new FlattenGenerator(UnresolvedAttribute("employee"))), + seq(), + true, + None, + seq(), + projectStateCompanyEmployee) + val dropSourceColumn = DataFrameDropColumns(Seq(UnresolvedAttribute("employee")), generate) + val projectStateCompanySalary = Project( + Seq( + UnresolvedAttribute("state"), + UnresolvedAttribute("company"), + UnresolvedAttribute("salary")), + dropSourceColumn) + val average = Alias( + UnresolvedFunction(seq("MAX"), seq(UnresolvedAttribute("salary")), false, None, false), + "max")() + val state = Alias(UnresolvedAttribute("state"), "state")() + val company = Alias(UnresolvedAttribute("company"), "company")() + val groupingState = Alias(UnresolvedAttribute("state"), "state")() + val groupingCompany = Alias(UnresolvedAttribute("company"), "company")() + val aggregate = Aggregate( + Seq(groupingState, groupingCompany), + Seq(average, state, company), + projectStateCompanySalary) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test flatten and eval") { + val context = new CatalystPlanContext + val query = "source = relation | flatten employee | eval bonus = salary * 3" + val logPlan = planTransformer.visit(plan(pplParser, query), context) + val table = UnresolvedRelation(Seq("relation")) + val generate = Generate( + GeneratorOuter(new FlattenGenerator(UnresolvedAttribute("employee"))), + seq(), + true, + None, + seq(), + table) + val dropSourceColumn = DataFrameDropColumns(Seq(UnresolvedAttribute("employee")), generate) + val bonusProject = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "*", + Seq(UnresolvedAttribute("salary"), Literal(3, IntegerType)), + isDistinct = false), + "bonus")()), + dropSourceColumn) + val expectedPlan = Project(Seq(UnresolvedStar(None)), bonusProject) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test flatten and parse and flatten") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=relation | flatten employee | parse description '(?<email>.+@.+)' | flatten roles"), + context) + val table = UnresolvedRelation(Seq("relation")) + val generateEmployee = Generate( + GeneratorOuter(new FlattenGenerator(UnresolvedAttribute("employee"))), + seq(), + true, + None, + seq(), + table) + val dropSourceColumnEmployee = + DataFrameDropColumns(Seq(UnresolvedAttribute("employee")), generateEmployee) + val emailAlias = + Alias( + RegExpExtract(UnresolvedAttribute("description"), Literal("(?<email>.+@.+)"), Literal(1)), + "email")() + val parseProject = Project( + Seq(UnresolvedAttribute("description"), emailAlias, UnresolvedStar(None)), + dropSourceColumnEmployee) + val generateRoles = Generate( + GeneratorOuter(new FlattenGenerator(UnresolvedAttribute("roles"))), + seq(), + true, +
None, + seq(), + parseProject) + val dropSourceColumnRoles = + DataFrameDropColumns(Seq(UnresolvedAttribute("roles")), generateRoles) + val expectedPlan = Project(Seq(UnresolvedStar(None)), dropSourceColumnRoles) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala index 58c1a8d12..3ceff7735 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala @@ -11,9 +11,9 @@ import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, And, Descending, EqualTo, LessThan, Literal, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Descending, EqualTo, GreaterThan, LessThan, Literal, Not, SortOrder} import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Join, JoinHint, Project, Sort, SubqueryAlias} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, Join, JoinHint, LocalLimit, Project, Sort, SubqueryAlias} class PPLLogicalPlanJoinTranslatorTestSuite extends SparkFunSuite @@ -341,4 +341,228 @@ class PPLLogicalPlanJoinTranslatorTestSuite val expectedPlan = Project(Seq(UnresolvedStar(None)), sort) comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } + + test("test inner join with relation subquery") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = $testTable1| JOIN left = l right = r ON l.id = r.id + | [ + | source = $testTable2 + | | where id > 10 and name = 'abc' + | | fields id, name + | | sort id + | | head 10 + | ] + | | stats count(id) as cnt by type + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val leftPlan = SubqueryAlias("l", table1) + val rightSubquery = + GlobalLimit( + Literal(10), + LocalLimit( + Literal(10), + Sort( + Seq(SortOrder(UnresolvedAttribute("id"), Ascending)), + global = true, + Project( + Seq(UnresolvedAttribute("id"), UnresolvedAttribute("name")), + Filter( + And( + GreaterThan(UnresolvedAttribute("id"), Literal(10)), + EqualTo(UnresolvedAttribute("name"), Literal("abc"))), + table2))))) + val rightPlan = SubqueryAlias("r", rightSubquery) + val joinCondition = EqualTo(UnresolvedAttribute("l.id"), UnresolvedAttribute("r.id")) + val joinPlan = Join(leftPlan, rightPlan, Inner, Some(joinCondition), JoinHint.NONE) + val groupingExpression = Alias(UnresolvedAttribute("type"), "type")() + val aggregateExpression = Alias( + UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedAttribute("id")), isDistinct = false), + "cnt")() + val aggPlan = + Aggregate(Seq(groupingExpression), Seq(aggregateExpression, groupingExpression), joinPlan) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggPlan) + 
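+ // checkAnalysis = false because the expected plan is built from unresolved attributes and relations.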
comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test left outer join with relation subquery") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = $testTable1| LEFT JOIN left = l right = r ON l.id = r.id + | [ + | source = $testTable2 + | | where id > 10 and name = 'abc' + | | fields id, name + | | sort id + | | head 10 + | ] + | | stats count(id) as cnt by type + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val leftPlan = SubqueryAlias("l", table1) + val rightSubquery = + GlobalLimit( + Literal(10), + LocalLimit( + Literal(10), + Sort( + Seq(SortOrder(UnresolvedAttribute("id"), Ascending)), + global = true, + Project( + Seq(UnresolvedAttribute("id"), UnresolvedAttribute("name")), + Filter( + And( + GreaterThan(UnresolvedAttribute("id"), Literal(10)), + EqualTo(UnresolvedAttribute("name"), Literal("abc"))), + table2))))) + val rightPlan = SubqueryAlias("r", rightSubquery) + val joinCondition = EqualTo(UnresolvedAttribute("l.id"), UnresolvedAttribute("r.id")) + val joinPlan = Join(leftPlan, rightPlan, LeftOuter, Some(joinCondition), JoinHint.NONE) + val groupingExpression = Alias(UnresolvedAttribute("type"), "type")() + val aggregateExpression = Alias( + UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedAttribute("id")), isDistinct = false), + "cnt")() + val aggPlan = + Aggregate(Seq(groupingExpression), Seq(aggregateExpression, groupingExpression), joinPlan) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggPlan) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with relation subquery") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = $testTable1 + | | head 10 + | | inner JOIN left = l,right = r ON l.id = r.id + | [ + | source = $testTable2 + | | where id > 10 + | ] + | | left JOIN left = l,right = r ON l.name = r.name + | [ + | source = $testTable3 + | | fields id + | ] + | | cross JOIN left = l,right = r + | [ + | source = $testTable4 + | | sort id + | ] + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val table4 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test4")) + var leftPlan = SubqueryAlias("l", GlobalLimit(Literal(10), LocalLimit(Literal(10), table1))) + var rightPlan = + SubqueryAlias("r", Filter(GreaterThan(UnresolvedAttribute("id"), Literal(10)), table2)) + val joinCondition1 = EqualTo(UnresolvedAttribute("l.id"), UnresolvedAttribute("r.id")) + val joinPlan1 = Join(leftPlan, rightPlan, Inner, Some(joinCondition1), JoinHint.NONE) + leftPlan = SubqueryAlias("l", joinPlan1) + rightPlan = SubqueryAlias("r", Project(Seq(UnresolvedAttribute("id")), table3)) + val joinCondition2 = EqualTo(UnresolvedAttribute("l.name"), UnresolvedAttribute("r.name")) + val joinPlan2 = Join(leftPlan, rightPlan, LeftOuter, Some(joinCondition2), JoinHint.NONE) + leftPlan = SubqueryAlias("l", joinPlan2) + rightPlan = SubqueryAlias( + "r", + Sort(Seq(SortOrder(UnresolvedAttribute("id"), Ascending)), global = true, table4)) + 
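+ // The cross join carries no join condition, hence None is passed to Join below.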
val joinPlan3 = Join(leftPlan, rightPlan, Cross, None, JoinHint.NONE) + val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test complex join: TPC-H Q13 with relation subquery") { + // select + // c_count, + // count(*) as custdist + // from + // ( + // select + // c_custkey, + // count(o_orderkey) as c_count + // from + // customer left outer join orders on + // c_custkey = o_custkey + // and o_comment not like '%special%requests%' + // group by + // c_custkey + // ) as c_orders + // group by + // c_count + // order by + // custdist desc, + // c_count desc + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | SEARCH source = [ + | SEARCH source = customer + | | LEFT OUTER JOIN left = c right = o ON c_custkey = o_custkey + | [ + | SEARCH source = orders + | | WHERE not like(o_comment, '%special%requests%') + | ] + | | STATS COUNT(o_orderkey) AS c_count BY c_custkey + | ] AS c_orders + | | STATS COUNT(o_orderkey) AS c_count BY c_custkey + | | STATS COUNT(1) AS custdist BY c_count + | | SORT - custdist, - c_count + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val tableC = UnresolvedRelation(Seq("customer")) + val tableO = UnresolvedRelation(Seq("orders")) + val left = SubqueryAlias("c", tableC) + val filterNot = Filter( + Not( + UnresolvedFunction( + Seq("like"), + Seq(UnresolvedAttribute("o_comment"), Literal("%special%requests%")), + isDistinct = false)), + tableO) + val right = SubqueryAlias("o", filterNot) + val joinCondition = + EqualTo(UnresolvedAttribute("o_custkey"), UnresolvedAttribute("c_custkey")) + val join = Join(left, right, LeftOuter, Some(joinCondition), JoinHint.NONE) + val groupingExpression1 = Alias(UnresolvedAttribute("c_custkey"), "c_custkey")() + val aggregateExpressions1 = + Alias( + UnresolvedFunction( + Seq("COUNT"), + Seq(UnresolvedAttribute("o_orderkey")), + isDistinct = false), + "c_count")() + val agg3 = + Aggregate(Seq(groupingExpression1), Seq(aggregateExpressions1, groupingExpression1), join) + val subqueryAlias = SubqueryAlias("c_orders", agg3) + val agg2 = + Aggregate( + Seq(groupingExpression1), + Seq(aggregateExpressions1, groupingExpression1), + subqueryAlias) + val groupingExpression2 = Alias(UnresolvedAttribute("c_count"), "c_count")() + val aggregateExpressions2 = + Alias(UnresolvedFunction(Seq("COUNT"), Seq(Literal(1)), isDistinct = false), "custdist")() + val agg1 = + Aggregate(Seq(groupingExpression2), Seq(aggregateExpressions2, groupingExpression2), agg2) + val sort = Sort( + Seq( + SortOrder(UnresolvedAttribute("custdist"), Descending), + SortOrder(UnresolvedAttribute("c_count"), Descending)), + global = true, + agg1) + val expectedPlan = Project(Seq(UnresolvedStar(None)), sort) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala new file mode 100644 index 000000000..216c0f232 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala @@ -0,0 +1,233 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan 
+import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{EqualTo, Literal} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project} + +class PPLLogicalPlanJsonFunctionsTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test json()") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, """source=t a = json('[1,2,3,{"f1":1,"f2":[5,6]},4]')"""), + context) + + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction( + "get_json_object", + Seq(Literal("""[1,2,3,{"f1":1,"f2":[5,6]},4]"""), Literal("$")), + isDistinct = false) + val filterExpr = EqualTo(UnresolvedAttribute("a"), jsonFunc) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + comparePlans(expectedPlan, logPlan, false) + } + + test("test json_object") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, """source=t a = json(json_object('key', array(1, 2, 3)))"""), + context) + + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction( + "to_json", + Seq( + UnresolvedFunction( + "named_struct", + Seq( + Literal("key"), + UnresolvedFunction( + "array", + Seq(Literal(1), Literal(2), Literal(3)), + isDistinct = false)), + isDistinct = false)), + isDistinct = false) + val filterExpr = EqualTo(UnresolvedAttribute("a"), jsonFunc) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + comparePlans(expectedPlan, logPlan, false) + } + + test("test json_array()") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, """source=t a = json_array(1, 2, 0, -1, 1.1, -0.11)"""), + context) + + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction( + "array", + Seq(Literal(1), Literal(2), Literal(0), Literal(-1), Literal(1.1), Literal(-0.11)), + isDistinct = false) + val filterExpr = EqualTo(UnresolvedAttribute("a"), jsonFunc) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + comparePlans(expectedPlan, logPlan, false) + } + + test("test json_object() and json_array()") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, """source=t a = json(json_object('key', json_array(1, 2, 3)))"""), + context) + + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction( + "to_json", + Seq( + UnresolvedFunction( + "named_struct", + Seq( + Literal("key"), + UnresolvedFunction( + "array", + Seq(Literal(1), Literal(2), Literal(3)), + isDistinct = false)), + isDistinct = false)), + isDistinct = false) + val filterExpr = EqualTo(UnresolvedAttribute("a"), jsonFunc) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = 
Project(projectList, filterPlan) + comparePlans(expectedPlan, logPlan, false) + } + + test("test json_array_length(jsonString)") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, """source=t a = json_array_length('[1,2,3]')"""), + context) + + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction("json_array_length", Seq(Literal("""[1,2,3]""")), isDistinct = false) + val filterExpr = EqualTo(UnresolvedAttribute("a"), jsonFunc) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + comparePlans(expectedPlan, logPlan, false) + } + + test("test json_array_length(json_array())") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, """source=t a = json_array_length(json_array(1,2,3))"""), + context) + + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction( + "json_array_length", + Seq( + UnresolvedFunction( + "to_json", + Seq( + UnresolvedFunction( + "array", + Seq(Literal(1), Literal(2), Literal(3)), + isDistinct = false)), + isDistinct = false)), + isDistinct = false) + val filterExpr = EqualTo(UnresolvedAttribute("a"), jsonFunc) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + comparePlans(expectedPlan, logPlan, false) + } + + test("test json_extract()") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, """source=t a = json_extract('{"a":[{"b":1},{"b":2}]}', '$.a[1].b')"""), + context) + + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction( + "get_json_object", + Seq(Literal("""{"a":[{"b":1},{"b":2}]}"""), Literal("""$.a[1].b""")), + isDistinct = false) + val filterExpr = EqualTo(UnresolvedAttribute("a"), jsonFunc) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + comparePlans(expectedPlan, logPlan, false) + } + + test("test json_keys()") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, """source=t a = json_keys('{"f1":"abc","f2":{"f3":"a","f4":"b"}}')"""), + context) + + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction( + "json_object_keys", + Seq(Literal("""{"f1":"abc","f2":{"f3":"a","f4":"b"}}""")), + isDistinct = false) + val filterExpr = EqualTo(UnresolvedAttribute("a"), jsonFunc) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + comparePlans(expectedPlan, logPlan, false) + } + + test("json_valid()") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, """source=t a = json_valid('[1,2,3,{"f1":1,"f2":[5,6]},4]')"""), + context) + + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction( + "isnotnull", + Seq( + UnresolvedFunction( + "get_json_object", + Seq(Literal("""[1,2,3,{"f1":1,"f2":[5,6]},4]"""), Literal("$")), + isDistinct = false)), + isDistinct = false) + val filterExpr = EqualTo(UnresolvedAttribute("a"), jsonFunc) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + comparePlans(expectedPlan, logPlan, false) + } +} diff --git 
a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanLambdaFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanLambdaFunctionsTranslatorTestSuite.scala new file mode 100644 index 000000000..9c3c1c8a0 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanLambdaFunctionsTranslatorTestSuite.scala @@ -0,0 +1,211 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, GreaterThan, LambdaFunction, Literal, UnresolvedNamedLambdaVariable} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.Project + +class PPLLogicalPlanLambdaFunctionsTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test forall()") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + """source=t | eval a = json_array(1, 2, 3), b = forall(a, x -> x > 0)""".stripMargin), + context) + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "a")() + val lambda = LambdaFunction( + GreaterThan(UnresolvedNamedLambdaVariable(seq("x")), Literal(0)), + Seq(UnresolvedNamedLambdaVariable(seq("x")))) + val aliasB = + Alias(UnresolvedFunction("forall", Seq(UnresolvedAttribute("a"), lambda), false), "b")() + val evalProject = Project(Seq(UnresolvedStar(None), aliasA, aliasB), table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, evalProject) + comparePlans(expectedPlan, logPlan, false) + } + + test("test exists()") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + """source=t | eval a = json_array(1, 2, 3), b = exists(a, x -> x > 0)""".stripMargin), + context) + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "a")() + val lambda = LambdaFunction( + GreaterThan(UnresolvedNamedLambdaVariable(seq("x")), Literal(0)), + Seq(UnresolvedNamedLambdaVariable(seq("x")))) + val aliasB = + Alias(UnresolvedFunction("exists", Seq(UnresolvedAttribute("a"), lambda), false), "b")() + val evalProject = Project(Seq(UnresolvedStar(None), aliasA, aliasB), table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, evalProject) + comparePlans(expectedPlan, logPlan, false) + } + + test("test filter()") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + """source=t | eval a = json_array(1, 2, 3), b = filter(a, x -> x > 0)""".stripMargin), + context) + val table =
UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "a")() + val lambda = LambdaFunction( + GreaterThan(UnresolvedNamedLambdaVariable(seq("x")), Literal(0)), + Seq(UnresolvedNamedLambdaVariable(seq("x")))) + val aliasB = + Alias(UnresolvedFunction("filter", Seq(UnresolvedAttribute("a"), lambda), false), "b")() + val evalProject = Project(Seq(UnresolvedStar(None), aliasA, aliasB), table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, evalProject) + comparePlans(expectedPlan, logPlan, false) + } + + test("test transform()") { + val context = new CatalystPlanContext + // test single argument of lambda + val logPlan = + planTransformer.visit( + plan( + pplParser, + """source=t | eval a = json_array(1, 2, 3), b = transform(a, x -> x + 1)""".stripMargin), + context) + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "a")() + val lambda = LambdaFunction( + UnresolvedFunction("+", Seq(UnresolvedNamedLambdaVariable(seq("x")), Literal(1)), false), + Seq(UnresolvedNamedLambdaVariable(seq("x")))) + val aliasB = + Alias(UnresolvedFunction("transform", Seq(UnresolvedAttribute("a"), lambda), false), "b")() + val evalProject = Project(Seq(UnresolvedStar(None), aliasA, aliasB), table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, evalProject) + comparePlans(expectedPlan, logPlan, false) + } + + test("test transform() - test binary lambda") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + """source=t | eval a = json_array(1, 2, 3), b = transform(a, (x, y) -> x + y)""".stripMargin), + context) + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "a")() + val lambda = LambdaFunction( + UnresolvedFunction( + "+", + Seq(UnresolvedNamedLambdaVariable(seq("x")), UnresolvedNamedLambdaVariable(seq("y"))), + false), + Seq(UnresolvedNamedLambdaVariable(seq("x")), UnresolvedNamedLambdaVariable(seq("y")))) + val aliasB = + Alias(UnresolvedFunction("transform", Seq(UnresolvedAttribute("a"), lambda), false), "b")() + val expectedPlan = Project( + Seq(UnresolvedStar(None)), + Project(Seq(UnresolvedStar(None), aliasA, aliasB), table)) + comparePlans(expectedPlan, logPlan, false) + } + + test("test reduce() - without finish lambda") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + """source=t | eval a = json_array(1, 2, 3), b = reduce(a, 0, (x, y) -> x + y)""".stripMargin), + context) + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "a")() + val mergeLambda = LambdaFunction( + UnresolvedFunction( + "+", + Seq(UnresolvedNamedLambdaVariable(seq("x")), UnresolvedNamedLambdaVariable(seq("y"))), + false), + Seq(UnresolvedNamedLambdaVariable(seq("x")), UnresolvedNamedLambdaVariable(seq("y")))) + val aliasB = + Alias( + UnresolvedFunction( + "reduce", + Seq(UnresolvedAttribute("a"), Literal(0), mergeLambda), + false), + "b")() + val expectedPlan = Project( + Seq(UnresolvedStar(None)), + Project(Seq(UnresolvedStar(None), aliasA, aliasB), 
table)) + comparePlans(expectedPlan, logPlan, false) + } + + test("test reduce() - with finish lambda") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + """source=t | eval a = json_array(1, 2, 3), b = reduce(a, 0, (x, y) -> x + y, x -> x * 10)""".stripMargin), + context) + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "a")() + val mergeLambda = LambdaFunction( + UnresolvedFunction( + "+", + Seq(UnresolvedNamedLambdaVariable(seq("x")), UnresolvedNamedLambdaVariable(seq("y"))), + false), + Seq(UnresolvedNamedLambdaVariable(seq("x")), UnresolvedNamedLambdaVariable(seq("y")))) + val finishLambda = LambdaFunction( + UnresolvedFunction("*", Seq(UnresolvedNamedLambdaVariable(seq("x")), Literal(10)), false), + Seq(UnresolvedNamedLambdaVariable(seq("x")))) + val aliasB = + Alias( + UnresolvedFunction( + "reduce", + Seq(UnresolvedAttribute("a"), Literal(0), mergeLambda, finishLambda), + false), + "b")() + val expectedPlan = Project( + Seq(UnresolvedStar(None)), + Project(Seq(UnresolvedStar(None), aliasA, aliasB), table)) + comparePlans(expectedPlan, logPlan, false) + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseCidrmatchTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseCidrmatchTestSuite.scala new file mode 100644 index 000000000..213f201cc --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseCidrmatchTestSuite.scala @@ -0,0 +1,155 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.expression.function.SerializableUdf +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, EqualTo, GreaterThan, Literal, NullsFirst, NullsLast, RegExpExtract, ScalaUDF, SortOrder} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types.DataTypes + +class PPLLogicalPlanParseCidrmatchTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test cidrmatch for ipv4 for 192.168.1.0/24") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | where isV6 = false and isValid = true and cidrmatch(ipAddress, '192.168.1.0/24')"), + context) + + val ipAddress = UnresolvedAttribute("ipAddress") + val cidrExpression = Literal("192.168.1.0/24") + + val filterIpv6 = EqualTo(UnresolvedAttribute("isV6"), Literal(false)) + val filterIsValid = EqualTo(UnresolvedAttribute("isValid"), Literal(true)) + val cidr = ScalaUDF( + SerializableUdf.cidrFunction, + DataTypes.BooleanType, + seq(ipAddress, cidrExpression), + seq(), + Option.empty, + 
Option.apply("cidr"), + false, + true) + + val expectedPlan = Project( + Seq(UnresolvedStar(None)), + Filter(And(And(filterIpv6, filterIsValid), cidr), UnresolvedRelation(Seq("t")))) + assert(compareByString(expectedPlan) === compareByString(logPlan)) + } + + test("test cidrmatch for ipv6 for 2003:db8::/32") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | where isV6 = true and isValid = false and cidrmatch(ipAddress, '2003:db8::/32')"), + context) + + val ipAddress = UnresolvedAttribute("ipAddress") + val cidrExpression = Literal("2003:db8::/32") + + val filterIpv6 = EqualTo(UnresolvedAttribute("isV6"), Literal(true)) + val filterIsValid = EqualTo(UnresolvedAttribute("isValid"), Literal(false)) + val cidr = ScalaUDF( + SerializableUdf.cidrFunction, + DataTypes.BooleanType, + seq(ipAddress, cidrExpression), + seq(), + Option.empty, + Option.apply("cidr"), + false, + true) + + val expectedPlan = Project( + Seq(UnresolvedStar(None)), + Filter(And(And(filterIpv6, filterIsValid), cidr), UnresolvedRelation(Seq("t")))) + assert(compareByString(expectedPlan) === compareByString(logPlan)) + } + + test("test cidrmatch for ipv6 for 2003:db8::/32 with ip field projected") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | where isV6 = true and cidrmatch(ipAddress, '2003:db8::/32') | fields ip"), + context) + + val ipAddress = UnresolvedAttribute("ipAddress") + val cidrExpression = Literal("2003:db8::/32") + + val filterIpv6 = EqualTo(UnresolvedAttribute("isV6"), Literal(true)) + val cidr = ScalaUDF( + SerializableUdf.cidrFunction, + DataTypes.BooleanType, + seq(ipAddress, cidrExpression), + seq(), + Option.empty, + Option.apply("cidr"), + false, + true) + + val expectedPlan = Project( + Seq(UnresolvedAttribute("ip")), + Filter(And(filterIpv6, cidr), UnresolvedRelation(Seq("t")))) + assert(compareByString(expectedPlan) === compareByString(logPlan)) + } + + test("test cidrmatch for ipv6 for 2003:db8::/32 with ip field bool respond for each ip") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | where isV6 = true | eval inRange = case(cidrmatch(ipAddress, '2003:db8::/32'), 'in' else 'out') | fields ip, inRange"), + context) + + val ipAddress = UnresolvedAttribute("ipAddress") + val cidrExpression = Literal("2003:db8::/32") + + val filterIpv6 = EqualTo(UnresolvedAttribute("isV6"), Literal(true)) + val filterClause = Filter(filterIpv6, UnresolvedRelation(Seq("t"))) + val cidr = ScalaUDF( + SerializableUdf.cidrFunction, + DataTypes.BooleanType, + seq(ipAddress, cidrExpression), + seq(), + Option.empty, + Option.apply("cidr"), + false, + true) + + val equalTo = EqualTo(Literal(true), cidr) + val caseFunction = CaseWhen(Seq((equalTo, Literal("in"))), Literal("out")) + val aliasStatusCategory = Alias(caseFunction, "inRange")() + val evalProjectList = Seq(UnresolvedStar(None), aliasStatusCategory) + val evalProject = Project(evalProjectList, filterClause) + + val expectedPlan = + Project(Seq(UnresolvedAttribute("ip"), UnresolvedAttribute("inRange")), evalProject) + + assert(compareByString(expectedPlan) === compareByString(logPlan)) + } + +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala new file 
mode 100644 index 000000000..d22750ee0 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala @@ -0,0 +1,135 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CaseWhen, CurrentRow, Descending, LessThan, Literal, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Project, Sort} + +class PPLLogicalPlanTrendlineCommandTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test trendline") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=relation | trendline sma(3, age)"), context) + + val table = UnresolvedRelation(Seq("relation")) + val ageField = UnresolvedAttribute("age") + val countWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val smaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val caseWhen = CaseWhen(Seq((LessThan(countWindow, Literal(3)), Literal(null))), smaWindow) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(caseWhen, "age_trendline")()) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, table)) + comparePlans(logPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline with sort") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=relation | trendline sort age sma(3, age)"), + context) + + val table = UnresolvedRelation(Seq("relation")) + val ageField = UnresolvedAttribute("age") + val sort = Sort(Seq(SortOrder(ageField, Ascending)), global = true, table) + val countWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val smaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val caseWhen = CaseWhen(Seq((LessThan(countWindow, Literal(3)), Literal(null))), smaWindow) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(caseWhen, "age_trendline")()) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sort)) + comparePlans(logPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline with sort and alias") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, 
"source=relation | trendline sort - age sma(3, age) as age_sma"), + context) + + val table = UnresolvedRelation(Seq("relation")) + val ageField = UnresolvedAttribute("age") + val sort = Sort(Seq(SortOrder(ageField, Descending)), global = true, table) + val countWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val smaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val caseWhen = CaseWhen(Seq((LessThan(countWindow, Literal(3)), Literal(null))), smaWindow) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(caseWhen, "age_sma")()) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sort)) + comparePlans(logPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline with multiple trendline sma commands") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=relation | trendline sort + age sma(2, age) as two_points_sma sma(3, age) | fields name, age, two_points_sma, age_trendline"), + context) + + val table = UnresolvedRelation(Seq("relation")) + val nameField = UnresolvedAttribute("name") + val ageField = UnresolvedAttribute("age") + val ageTwoPointsSmaField = UnresolvedAttribute("two_points_sma") + val ageTrendlineField = UnresolvedAttribute("age_trendline") + val sort = Sort(Seq(SortOrder(ageField, Ascending)), global = true, table) + val twoPointsCountWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val twoPointsSmaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val threePointsCountWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val threePointsSmaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val twoPointsCaseWhen = CaseWhen( + Seq((LessThan(twoPointsCountWindow, Literal(2)), Literal(null))), + twoPointsSmaWindow) + val threePointsCaseWhen = CaseWhen( + Seq((LessThan(threePointsCountWindow, Literal(3)), Literal(null))), + threePointsSmaWindow) + val trendlineProjectList = Seq( + UnresolvedStar(None), + Alias(twoPointsCaseWhen, "two_points_sma")(), + Alias(threePointsCaseWhen, "age_trendline")()) + val expectedPlan = Project( + Seq(nameField, ageField, ageTwoPointsSmaField, ageTrendlineField), + Project(trendlineProjectList, sort)) + comparePlans(logPlan, expectedPlan, checkAnalysis = false) + } +} diff --git a/spark-sql-application/src/main/java/org/opensearch/sql/FlintDelegatingSessionCatalog.java b/spark-sql-application/src/main/java/org/opensearch/sql/FlintDelegatingSessionCatalog.java index 6ed9fa980..40f4aee46 100644 --- a/spark-sql-application/src/main/java/org/opensearch/sql/FlintDelegatingSessionCatalog.java +++ b/spark-sql-application/src/main/java/org/opensearch/sql/FlintDelegatingSessionCatalog.java @@ -1,3 +1,8 @@ +/* 
+ * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.sql; import org.apache.spark.sql.SparkSession; diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index cdeebe663..ef0e76557 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -17,7 +17,7 @@ import com.codahale.metrics.Timer import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession, SessionStates} import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.logging.CustomLogging -import org.opensearch.flint.core.metrics.MetricConstants +import org.opensearch.flint.core.metrics.{MetricConstants, ReadWriteBytesSparkListener} import org.opensearch.flint.core.metrics.MetricsUtil.{getTimerContext, incrementCounter, registerGauge, stopTimer} import org.apache.spark.SparkConf @@ -314,7 +314,13 @@ object FlintREPL extends Logging with FlintJobExecutor { val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-query", 1) implicit val executionContext = ExecutionContext.fromExecutor(threadPool) val queryResultWriter = instantiateQueryResultWriter(spark, commandContext) - var futurePrepareQueryExecution: Future[Either[String, Unit]] = null + + val statementsExecutionManager = + instantiateStatementExecutionManager(commandContext) + + var futurePrepareQueryExecution: Future[Either[String, Unit]] = Future { + statementsExecutionManager.prepareStatementExecution() + } try { logInfo(s"""Executing session with sessionId: ${sessionId}""") @@ -324,12 +330,6 @@ object FlintREPL extends Logging with FlintJobExecutor { var lastCanPickCheckTime = 0L while (currentTimeProvider .currentEpochMillis() - lastActivityTime <= commandContext.inactivityLimitMillis && canPickUpNextStatement) { - val statementsExecutionManager = - instantiateStatementExecutionManager(commandContext) - - futurePrepareQueryExecution = Future { - statementsExecutionManager.prepareStatementExecution() - } try { val commandState = CommandState( @@ -525,12 +525,16 @@ object FlintREPL extends Logging with FlintJobExecutor { val statementTimerContext = getTimerContext( MetricConstants.STATEMENT_PROCESSING_TIME_METRIC) val (dataToWrite, returnedVerificationResult) = - processStatementOnVerification( - statementExecutionManager, - queryResultWriter, - flintStatement, - state, - context) + ReadWriteBytesSparkListener.withMetrics( + spark, + () => { + processStatementOnVerification( + statementExecutionManager, + queryResultWriter, + flintStatement, + state, + context) + }) verificationResult = returnedVerificationResult finalizeCommand( diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala index 58d868a2e..6cdbdb16d 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala @@ -5,7 +5,7 @@ package org.apache.spark.sql -import java.util.concurrent.ThreadPoolExecutor +import java.util.concurrent.{ThreadPoolExecutor, TimeUnit} import java.util.concurrent.atomic.AtomicInteger import scala.concurrent.{ExecutionContext, Future, TimeoutException} @@ -14,7 +14,7 @@ import scala.util.{Failure, Success, Try} import 
org.opensearch.flint.common.model.FlintStatement import org.opensearch.flint.common.scheduler.model.LangType -import org.opensearch.flint.core.metrics.MetricConstants +import org.opensearch.flint.core.metrics.{MetricConstants, MetricsUtil, ReadWriteBytesSparkListener} import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter import org.opensearch.flint.spark.FlintSpark @@ -70,6 +70,9 @@ case class JobOperator( val statementExecutionManager = instantiateStatementExecutionManager(commandContext, resultIndex, osClient) + val readWriteBytesSparkListener = new ReadWriteBytesSparkListener() + sparkSession.sparkContext.addSparkListener(readWriteBytesSparkListener) + val statement = new FlintStatement( "running", @@ -136,6 +139,10 @@ case class JobOperator( "", startTime)) } finally { + emitQueryExecutionTimeMetric(startTime) + readWriteBytesSparkListener.emitMetrics() + sparkSession.sparkContext.removeSparkListener(readWriteBytesSparkListener) + try { dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient)) } catch { @@ -148,11 +155,14 @@ case class JobOperator( statement.error = Some(error) statementExecutionManager.updateStatement(statement) - cleanUpResources(exceptionThrown, threadPool) + cleanUpResources(exceptionThrown, threadPool, startTime) } } - def cleanUpResources(exceptionThrown: Boolean, threadPool: ThreadPoolExecutor): Unit = { + def cleanUpResources( + exceptionThrown: Boolean, + threadPool: ThreadPoolExecutor, + startTime: Long): Unit = { val isStreaming = jobType.equalsIgnoreCase(FlintJobType.STREAMING) try { // Wait for streaming job complete if no error @@ -195,6 +205,13 @@ case class JobOperator( } } + private def emitQueryExecutionTimeMetric(startTime: Long): Unit = { + MetricsUtil + .addHistoricGauge( + MetricConstants.QUERY_EXECUTION_TIME_METRIC, + System.currentTimeMillis() - startTime) + } + def stop(): Unit = { Try { logInfo("Stopping Spark session") diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala index 432d6df11..09a1b3c1e 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala @@ -6,7 +6,7 @@ package org.apache.spark.sql import org.opensearch.flint.common.model.FlintStatement -import org.opensearch.flint.core.storage.OpenSearchUpdater +import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} import org.opensearch.search.sort.SortOrder import org.apache.spark.internal.Logging @@ -29,8 +29,8 @@ class StatementExecutionManagerImpl(commandContext: CommandContext) context("flintSessionIndexUpdater").asInstanceOf[OpenSearchUpdater] // Using one reader client within same session will cause concurrency issue. 
- // To resolve this move the reader creation and getNextStatement method to mirco-batch level - private val flintReader = createOpenSearchQueryReader() + // To resolve this, move the reader creation into the getNextStatement method at the micro-batch level + private var currentReader: Option[FlintReader] = None override def prepareStatementExecution(): Either[String, Unit] = { checkAndCreateIndex(osClient, resultIndex) @@ -39,12 +39,17 @@ class StatementExecutionManagerImpl(commandContext: CommandContext) flintSessionIndexUpdater.update(statement.statementId, FlintStatement.serialize(statement)) } override def terminateStatementExecution(): Unit = { - flintReader.close() + currentReader.foreach(_.close()) + currentReader = None } override def getNextStatement(): Option[FlintStatement] = { - if (flintReader.hasNext) { - val rawStatement = flintReader.next() + if (currentReader.isEmpty) { + currentReader = Some(createOpenSearchQueryReader()) + } + + if (currentReader.get.hasNext) { + val rawStatement = currentReader.get.next() val flintStatement = FlintStatement.deserialize(rawStatement) logInfo(s"Next statement to execute: $flintStatement") Some(flintStatement) @@ -100,7 +105,6 @@ class StatementExecutionManagerImpl(commandContext: CommandContext) | ] | } |}""".stripMargin - val flintReader = osClient.createQueryReader(sessionIndex, dsl, "submitTime", SortOrder.ASC) - flintReader + osClient.createQueryReader(sessionIndex, dsl, "submitTime", SortOrder.ASC) } } diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index 5eeccce73..07ed94bdc 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -1387,7 +1387,8 @@ class FlintREPLTest val expectedCalls = Math.ceil(inactivityLimit.toDouble / DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY).toInt - verify(mockOSClient, Mockito.atMost(expectedCalls)).getIndexMetadata(*) + verify(mockOSClient, times(1)).getIndexMetadata(*) + verify(mockOSClient, Mockito.atMost(expectedCalls)).createQueryReader(*, *, *, *) } val testCases = Table(