diff --git a/.github/workflows/test-and-build-workflow.yml b/.github/workflows/test-and-build-workflow.yml index f8d9bd682..e3b2b20f4 100644 --- a/.github/workflows/test-and-build-workflow.yml +++ b/.github/workflows/test-and-build-workflow.yml @@ -25,5 +25,16 @@ jobs: - name: Style check run: sbt scalafmtCheckAll + - name: Set SBT_OPTS + # Needed to extend the JVM memory size to avoid OutOfMemoryError for HTML test report + run: echo "SBT_OPTS=-Xmx2G" >> $GITHUB_ENV + - name: Integ Test run: sbt integtest/integration + + - name: Upload test report + if: always() # Ensures the artifact is saved even if tests fail + uses: actions/upload-artifact@v3 + with: + name: test-reports + path: target/test-reports # Adjust this path if necessary \ No newline at end of file diff --git a/DEVELOPER_GUIDE.md b/DEVELOPER_GUIDE.md index bb8f697ec..834a2a201 100644 --- a/DEVELOPER_GUIDE.md +++ b/DEVELOPER_GUIDE.md @@ -11,6 +11,11 @@ To execute the unit tests, run the following command: ``` sbt test ``` +To run a specific unit test in SBT, use the testOnly command with the full path of the test class: +``` +sbt "; project pplSparkIntegration; test:testOnly org.opensearch.flint.spark.ppl.PPLLogicalPlanTrendlineCommandTranslatorTestSuite" +``` + ## Integration Test The integration test is defined in the `integration` directory of the project. The integration tests will automatically trigger unit tests and will only run if all unit tests pass. If you want to run the integration test for the project, you can do so by running the following command: @@ -23,6 +28,13 @@ If you get integration test failures with error message "Previous attempts to fi 3. Run `sudo ln -s $HOME/.docker/desktop/docker.sock /var/run/docker.sock` or `sudo ln -s $HOME/.docker/run/docker.sock /var/run/docker.sock` 4. If you use Docker Desktop, as an alternative of `3`, check mark the "Allow the default Docker socket to be used (requires password)" in advanced settings of Docker Desktop. +Running only a selected set of integration test suites is possible with the following command: +``` +sbt "; project integtest; it:testOnly org.opensearch.flint.spark.ppl.FlintSparkPPLTrendlineITSuite" +``` +This command runs only the specified test suite within the integtest submodule. + + ### AWS Integration Test The `aws-integration` folder contains tests for cloud server providers. For instance, test against AWS OpenSearch domain, configure the following settings. The client will use the default credential provider to access the AWS OpenSearch domain. 
``` diff --git a/README.md b/README.md index 4c470e98b..db3790e64 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,8 @@ Please refer to the [Flint Index Reference Manual](./docs/index.md) for more inf * For additional details on Spark PPL commands project, see [PPL Project](https://github.com/orgs/opensearch-project/projects/214/views/2) +* Experiment with PPL queries on a local Spark cluster: [PPL on local Spark](docs/ppl-lang/local-spark-ppl-test-instruction.md) + ## Prerequisites Version compatibility: @@ -31,6 +33,7 @@ Version compatibility: | 0.4.0 | 11+ | 3.3.2 | 2.12.14 | 2.13+ | | 0.5.0 | 11+ | 3.5.1 | 2.12.14 | 2.17+ | | 0.6.0 | 11+ | 3.5.1 | 2.12.14 | 2.17+ | +| 0.7.0 | 11+ | 3.5.1 | 2.12.14 | 2.17+ | ## Flint Extension Usage @@ -62,7 +65,7 @@ sbt clean standaloneCosmetic/publishM2 ``` then add org.opensearch:opensearch-spark-standalone_2.12 when run spark application, for example, ``` -bin/spark-shell --packages "org.opensearch:opensearch-spark-standalone_2.12:0.6.0-SNAPSHOT" \ +bin/spark-shell --packages "org.opensearch:opensearch-spark-standalone_2.12:0.7.0-SNAPSHOT" \ --conf "spark.sql.extensions=org.opensearch.flint.spark.FlintSparkExtensions" \ --conf "spark.sql.catalog.dev=org.apache.spark.opensearch.catalog.OpenSearchCatalog" ``` @@ -74,14 +77,20 @@ To build and run this PPL in Spark, you can run (requires Java 11): ``` sbt clean sparkPPLCosmetic/publishM2 ``` -then add org.opensearch:opensearch-spark-ppl_2.12 when run spark application, for example, + +Then add org.opensearch:opensearch-spark-ppl_2.12 when running a Spark application, for example, + ``` -bin/spark-shell --packages "org.opensearch:opensearch-spark-ppl_2.12:0.6.0-SNAPSHOT" \ +bin/spark-shell --packages "org.opensearch:opensearch-spark-ppl_2.12:0.7.0-SNAPSHOT" \ --conf "spark.sql.extensions=org.opensearch.flint.spark.FlintPPLSparkExtensions" \ --conf "spark.sql.catalog.dev=org.apache.spark.opensearch.catalog.OpenSearchCatalog" ``` +### Running PPL queries on a local Spark cluster +See the PPL usage sample on a local Spark cluster: [PPL on local Spark](docs/ppl-lang/local-spark-ppl-test-instruction.md) + + ## Code of Conduct This project has adopted an [Open Source Code of Conduct](./CODE_OF_CONDUCT.md). diff --git a/build.sbt b/build.sbt index 9254d1ff1..131fb2347 100644 --- a/build.sbt +++ b/build.sbt @@ -3,6 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ import Dependencies._ +import sbtassembly.AssemblyPlugin.autoImport.ShadeRule lazy val scala212 = "2.12.14" lazy val sparkVersion = "3.5.1" @@ -21,7 +22,7 @@ val sparkMinorVersion = sparkVersion.split("\\.").take(2).mkString(".") ThisBuild / organization := "org.opensearch" -ThisBuild / version := "0.6.0-SNAPSHOT" +ThisBuild / version := "0.7.0-SNAPSHOT" ThisBuild / scalaVersion := scala212 @@ -43,7 +44,35 @@ lazy val compileScalastyle = taskKey[Unit]("compileScalastyle") // Run as part of test task. lazy val testScalastyle = taskKey[Unit]("testScalastyle") +// Explanation: +// - ThisBuild / assemblyShadeRules sets the shading rules for the entire build +// - ShadeRule.rename(...) creates a rule to rename multiple package patterns +// - "shaded.@0" means prepend "shaded."
to the original package name +// - .inAll applies the rule to all dependencies, not just direct dependencies +val packagesToShade = Seq( + "com.amazonaws.cloudwatch.**", + "com.fasterxml.jackson.core.**", + "com.fasterxml.jackson.dataformat.**", + "com.fasterxml.jackson.databind.**", + "com.sun.jna.**", + "com.thoughtworks.paranamer.**", + "javax.annotation.**", + "org.apache.commons.codec.**", + "org.apache.commons.logging.**", + "org.apache.hc.**", + "org.apache.http.**", + "org.glassfish.json.**", + "org.joda.time.**", + "org.reactivestreams.**", + "org.yaml.**", + "software.amazon.**" +) +ThisBuild / assemblyShadeRules := Seq( + ShadeRule.rename( + packagesToShade.map(_ -> "shaded.flint.@0"): _* + ).inAll +) lazy val commonSettings = Seq( javacOptions ++= Seq("-source", "11"), @@ -53,7 +82,11 @@ lazy val commonSettings = Seq( compileScalastyle := (Compile / scalastyle).toTask("").value, Compile / compile := ((Compile / compile) dependsOn compileScalastyle).value, testScalastyle := (Test / scalastyle).toTask("").value, + // Enable HTML report and output to separate folder per package + Test / testOptions += Tests.Argument(TestFrameworks.ScalaTest, "-h", s"target/test-reports/${name.value}"), Test / test := ((Test / test) dependsOn testScalastyle).value, + // Needed for HTML report + libraryDependencies += "com.vladsch.flexmark" % "flexmark-all" % "0.64.8" % "test", dependencyOverrides ++= Seq( "com.fasterxml.jackson.core" % "jackson-core" % jacksonVersion, "com.fasterxml.jackson.core" % "jackson-databind" % jacksonVersion @@ -89,6 +122,8 @@ lazy val flintCore = (project in file("flint-core")) "com.amazonaws" % "aws-java-sdk-cloudwatch" % "1.12.593" exclude("com.fasterxml.jackson.core", "jackson-databind"), "software.amazon.awssdk" % "auth-crt" % "2.28.10", + "com.fasterxml.jackson.core" % "jackson-core" % jacksonVersion, + "com.fasterxml.jackson.core" % "jackson-databind" % jacksonVersion, "org.projectlombok" % "lombok" % "1.18.30" % "provided", "org.scalactic" %% "scalactic" % "3.2.15" % "test", "org.scalatest" %% "scalatest" % "3.2.15" % "test", @@ -241,7 +276,8 @@ lazy val integtest = (project in file("integ-test")) inConfig(IntegrationTest)(Defaults.testSettings ++ Seq( IntegrationTest / javaSource := baseDirectory.value / "src/integration/java", IntegrationTest / scalaSource := baseDirectory.value / "src/integration/scala", - IntegrationTest / parallelExecution := false, + IntegrationTest / resourceDirectory := baseDirectory.value / "src/integration/resources", + IntegrationTest / parallelExecution := false, IntegrationTest / fork := true, )), inConfig(AwsIntegrationTest)(Defaults.testSettings ++ Seq( diff --git a/docs/img/spark-ui.png b/docs/img/spark-ui.png new file mode 100644 index 000000000..dc2606272 Binary files /dev/null and b/docs/img/spark-ui.png differ diff --git a/docs/index.md b/docs/index.md index e76cb387a..82c147de2 100644 --- a/docs/index.md +++ b/docs/index.md @@ -60,7 +60,7 @@ Currently, Flint metadata is only static configuration without version control a ```json { - "version": "0.6.0", + "version": "0.7.0", "name": "...", "kind": "skipping", "source": "...", @@ -698,7 +698,7 @@ For now, only single or conjunct conditions (conditions connected by AND) in WHE ### AWS EMR Spark Integration - Using execution role Flint use [DefaultAWSCredentialsProviderChain](https://docs.aws.amazon.com/AWSJavaSDK/latest/javadoc/com/amazonaws/auth/DefaultAWSCredentialsProviderChain.html). 
When running in EMR Spark, Flint use executionRole credentials ``` ---conf spark.jars.packages=org.opensearch:opensearch-spark-standalone_2.12:0.6.0-SNAPSHOT \ +--conf spark.jars.packages=org.opensearch:opensearch-spark-standalone_2.12:0.7.0-SNAPSHOT \ --conf spark.jars.repositories=https://aws.oss.sonatype.org/content/repositories/snapshots \ --conf spark.emr-serverless.driverEnv.JAVA_HOME=/usr/lib/jvm/java-17-amazon-corretto.x86_64 \ --conf spark.executorEnv.JAVA_HOME=/usr/lib/jvm/java-17-amazon-corretto.x86_64 \ @@ -740,7 +740,7 @@ Flint use [DefaultAWSCredentialsProviderChain](https://docs.aws.amazon.com/AWSJa ``` 3. Set the spark.datasource.flint.customAWSCredentialsProvider property with value as com.amazonaws.emr.AssumeRoleAWSCredentialsProvider. Set the environment variable ASSUME_ROLE_CREDENTIALS_ROLE_ARN with the ARN value of CrossAccountRoleB. ``` ---conf spark.jars.packages=org.opensearch:opensearch-spark-standalone_2.12:0.6.0-SNAPSHOT \ +--conf spark.jars.packages=org.opensearch:opensearch-spark-standalone_2.12:0.7.0-SNAPSHOT \ --conf spark.jars.repositories=https://aws.oss.sonatype.org/content/repositories/snapshots \ --conf spark.emr-serverless.driverEnv.JAVA_HOME=/usr/lib/jvm/java-17-amazon-corretto.x86_64 \ --conf spark.executorEnv.JAVA_HOME=/usr/lib/jvm/java-17-amazon-corretto.x86_64 \ diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md index e780f688d..7766c3b50 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -50,6 +50,10 @@ _- **Limitation: new field added by eval command with a function cannot be dropp - `source = table | where a < 1 | fields a,b,c` - `source = table | where b != 'test' | fields a,b,c` - `source = table | where c = 'test' | fields a,b,c | head 3` +- `source = table | where c = 'test' AND a = 1 | fields a,b,c` +- `source = table | where c != 'test' OR a > 1 | fields a,b,c` +- `source = table | where (b > 1 OR a > 1) AND c != 'test' | fields a,b,c` +- `source = table | where c = 'test' NOT a > 1 | fields a,b,c` - Note: "AND" is optional - `source = table | where ispresent(b)` - `source = table | where isnull(coalesce(a, b)) | fields a,b,c | head 3` - `source = table | where isempty(a)` @@ -61,6 +65,7 @@ _- **Limitation: new field added by eval command with a function cannot be dropp - `source = table | where cidrmatch(ip, '192.169.1.0/24')` - `source = table | where cidrmatch(ipv6, '2003:db8::/32')` - `source = table | trendline sma(2, temperature) as temp_trend` +- `source = table | trendline sort timestamp wma(2, temperature) as temp_trend` #### **IP related queries** [See additional command details](functions/ppl-ip.md) @@ -177,6 +182,7 @@ source = table | where ispresent(a) | - `source = table | stats max(c) by b` - `source = table | stats count(c) by b | head 5` - `source = table | stats distinct_count(c)` +- `source = table | stats distinct_count_approx(c)` - `source = table | stats stddev_samp(c)` - `source = table | stats stddev_pop(c)` - `source = table | stats percentile(c, 90)` @@ -202,6 +208,7 @@ source = table | where ispresent(a) | - `source = table | where a < 50 | eventstats avg(c) ` - `source = table | eventstats max(c) by b` - `source = table | eventstats count(c) by b | head 5` +- `source = table | eventstats count(c) by b | head 5` - `source = table | eventstats stddev_samp(c)` - `source = table | eventstats stddev_pop(c)` - `source = table | eventstats percentile(c, 90)` @@ -246,12 +253,15 @@ source = table | where ispresent(a) | - 
`source=accounts | rare gender` - `source=accounts | rare age by gender` +- `source=accounts | rare 5 age by gender` +- `source=accounts | rare_approx age by gender` #### **Top** [See additional command details](ppl-top-command.md) - `source=accounts | top gender` - `source=accounts | top 1 gender` +- `source=accounts | top_approx 5 gender` - `source=accounts | top 1 age by gender` #### **Parse** @@ -306,7 +316,11 @@ source = table | where ispresent(a) | - `source = table1 | left semi join left = l right = r on l.a = r.a table2` - `source = table1 | left anti join left = l right = r on l.a = r.a table2` - `source = table1 | join left = l right = r [ source = table2 | where d > 10 | head 5 ]` - +- `source = table1 | inner join on table1.a = table2.a table2 | fields table1.a, table2.a, table1.b, table1.c` (directly refer table name) +- `source = table1 | inner join on a = c table2 | fields a, b, c, d` (ignore side aliases as long as no ambiguous) +- `source = table1 as t1 | join left = l right = r on l.a = r.a table2 as t2 | fields l.a, r.a` (side alias overrides table alias) +- `source = table1 as t1 | join left = l right = r on l.a = r.a table2 as t2 | fields t1.a, t2.a` (error, side alias overrides table alias) +- `source = table1 | join left = l right = r on l.a = r.a [ source = table2 ] as s | fields l.a, s.a` (error, side alias overrides subquery alias) #### **Lookup** [See additional command details](ppl-lookup-command.md) @@ -437,8 +451,30 @@ Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table in _- **Limitation: another command usage of (relation) subquery is in `appendcols` commands which is unsupported**_ ---- -#### Experimental Commands: + +#### **fillnull** +[See additional command details](ppl-fillnull-command.md) +```sql + - `source=accounts | fillnull fields status_code=101` + - `source=accounts | fillnull fields request_path='/not_found', timestamp='*'` + - `source=accounts | fillnull using field1=101` + - `source=accounts | fillnull using field1=concat(field2, field3), field4=2*pi()*field5` + - `source=accounts | fillnull using field1=concat(field2, field3), field4=2*pi()*field5, field6 = 'N/A'` +``` + +#### **expand** +[See additional command details](ppl-expand-command.md) +```sql + - `source = table | expand field_with_array as array_list` + - `source = table | expand employee | stats max(salary) as max by state, company` + - `source = table | expand employee as worker | stats max(salary) as max by state, company` + - `source = table | expand employee as worker | eval bonus = salary * 3 | fields worker, bonus` + - `source = table | expand employee | parse description '(?.+@.+)' | fields employee, email` + - `source = table | eval array=json_array(1, 2, 3) | expand array as uid | fields name, occupation, uid` + - `source = table | expand multi_valueA as multiA | expand multi_valueB as multiB` +``` + +#### Correlation Commands: [See additional command details](ppl-correlation-command.md) ```sql @@ -450,14 +486,3 @@ _- **Limitation: another command usage of (relation) subquery is in `appendcols` > ppl-correlation-command is an experimental command - it may be removed in future versions --- -### Planned Commands: - -#### **fillnull** -[See additional command details](ppl-fillnull-command.md) -```sql - - `source=accounts | fillnull fields status_code=101` - - `source=accounts | fillnull fields request_path='/not_found', timestamp='*'` - - `source=accounts | fillnull using field1=101` - - `source=accounts | fillnull using field1=concat(field2, 
field3), field4=2*pi()*field5` - `source=accounts | fillnull using field1=concat(field2, field3), field4=2*pi()*field5, field6 = 'N/A'` -``` diff --git a/docs/ppl-lang/PPL-on-Spark.md b/docs/ppl-lang/PPL-on-Spark.md index 3b260bd37..1b057572b 100644 --- a/docs/ppl-lang/PPL-on-Spark.md +++ b/docs/ppl-lang/PPL-on-Spark.md @@ -34,7 +34,7 @@ sbt clean sparkPPLCosmetic/publishM2 ``` then add org.opensearch:opensearch-spark_2.12 when run spark application, for example, ``` -bin/spark-shell --packages "org.opensearch:opensearch-spark-ppl_2.12:0.6.0-SNAPSHOT" +bin/spark-shell --packages "org.opensearch:opensearch-spark-ppl_2.12:0.7.0-SNAPSHOT" ``` ### PPL Extension Usage @@ -46,7 +46,7 @@ spark-sql --conf "spark.sql.extensions=org.opensearch.flint.spark.FlintPPLSparkE ``` ### Running With both Flint & PPL Extensions -In order to make use of both flint and ppl extension, one can simply add both jars (`org.opensearch:opensearch-spark-ppl_2.12:0.6.0-SNAPSHOT`,`org.opensearch:opensearch-spark_2.12:0.6.0-SNAPSHOT`) to the cluster's +In order to make use of both flint and ppl extension, one can simply add both jars (`org.opensearch:opensearch-spark-ppl_2.12:0.7.0-SNAPSHOT`,`org.opensearch:opensearch-spark_2.12:0.7.0-SNAPSHOT`) to the cluster's classpath. Next need to configure both extensions : diff --git a/docs/ppl-lang/README.md b/docs/ppl-lang/README.md index d78f4c030..19e1a6ee0 100644 --- a/docs/ppl-lang/README.md +++ b/docs/ppl-lang/README.md @@ -71,6 +71,8 @@ For additional examples see the next [documentation](PPL-Example-Commands.md). - [`correlation commands`](ppl-correlation-command.md) - [`trendline commands`](ppl-trendline-command.md) + + - [`expand commands`](ppl-expand-command.md) * **Functions** @@ -92,7 +94,7 @@ For additional examples see the next [documentation](PPL-Example-Commands.md). - [`IP Address Functions`](functions/ppl-ip.md) - - [`Lambda Functions`](functions/ppl-lambda.md) + - [`Collection Functions`](functions/ppl-collection.md) --- ### PPL On Spark @@ -104,6 +106,15 @@ For additional examples see the next [documentation](PPL-Example-Commands.md). ### Example PPL Queries See samples of [PPL queries](PPL-Example-Commands.md) +--- + +### Experiment with PPL locally using a Spark cluster +See the PPL usage sample on a local Spark cluster: [PPL on local Spark](local-spark-ppl-test-instruction.md) + +--- +### TPC-H PPL Query Rewriting +See samples of [TPC-H PPL query rewriting](ppl-tpch.md) + --- ### Planned PPL Commands diff --git a/docs/ppl-lang/functions/ppl-lambda.md b/docs/ppl-lang/functions/ppl-collection.md similarity index 57% rename from docs/ppl-lang/functions/ppl-lambda.md rename to docs/ppl-lang/functions/ppl-collection.md index cdb6f9e8f..b98f5f5ca 100644 --- a/docs/ppl-lang/functions/ppl-lambda.md +++ b/docs/ppl-lang/functions/ppl-collection.md @@ -1,4 +1,56 @@ -## Lambda Functions +## PPL Collection Functions + +### `ARRAY` + +**Description** + +`array(...)` Returns an array with the given elements. + +**Argument type:** +- A \<value\> can be any kind of value such as string, number, or boolean.
+ +**Return type:** ARRAY + +Example: + + os> source=people | eval `array` = array(1, 2, 0, -1, 1.1, -0.11) + fetched rows / total rows = 1/1 + +------------------------------+ + | array | + +------------------------------+ + | [1.0,2.0,0.0,-1.0,1.1,-0.11] | + +------------------------------+ + os> source=people | eval `array` = array(true, false, true, true) + fetched rows / total rows = 1/1 + +------------------------------+ + | array | + +------------------------------+ + | [true, false, true, true] | + +------------------------------+ + + +### `ARRAY_LENGTH` + +**Description** + +`array_length(array)` Returns the number of elements in the outermost array. + +**Argument type:** ARRAY + +ARRAY or JSON_ARRAY object. + +**Return type:** INTEGER + +Example: + + os> source=people | eval `array` = array_length(array(1,2,3,4)), `empty_array` = array_length(array()) + fetched rows / total rows = 1/1 + +---------+---------------+ + | array | empty_array | + +---------+---------------+ + | 4 | 0 | + +---------+---------------+ + ### `FORALL` @@ -14,7 +66,7 @@ Returns `TRUE` if all elements in the array satisfy the lambda predicate, otherw Example: - os> source=people | eval array = json_array(1, -1, 2), result = forall(array, x -> x > 0) | fields result + os> source=people | eval array = array(1, -1, 2), result = forall(array, x -> x > 0) | fields result fetched rows / total rows = 1/1 +-----------+ | result | @@ -22,7 +74,7 @@ Example: | false | +-----------+ - os> source=people | eval array = json_array(1, 3, 2), result = forall(array, x -> x > 0) | fields result + os> source=people | eval array = array(1, 3, 2), result = forall(array, x -> x > 0) | fields result fetched rows / total rows = 1/1 +-----------+ | result | @@ -41,7 +93,7 @@ Consider constructing the following array: and perform lambda functions against the nested fields `a` or `b`. See the examples: - os> source=people | eval array = json_array(json_object("a", 1, "b", 1), json_object("a" , -1, "b", 2)), result = forall(array, x -> x.a > 0) | fields result + os> source=people | eval array = array(json_object("a", 1, "b", 1), json_object("a" , -1, "b", 2)), result = forall(array, x -> x.a > 0) | fields result fetched rows / total rows = 1/1 +-----------+ | result | @@ -49,7 +101,7 @@ and perform lambda functions against the nested fields `a` or `b`. 
See the examp | false | +-----------+ - os> source=people | eval array = json_array(json_object("a", 1, "b", 1), json_object("a" , -1, "b", 2)), result = forall(array, x -> x.b > 0) | fields result + os> source=people | eval array = array(json_object("a", 1, "b", 1), json_object("a" , -1, "b", 2)), result = forall(array, x -> x.b > 0) | fields result fetched rows / total rows = 1/1 +-----------+ | result | @@ -71,7 +123,7 @@ Returns `TRUE` if at least one element in the array satisfies the lambda predica Example: - os> source=people | eval array = json_array(1, -1, 2), result = exists(array, x -> x > 0) | fields result + os> source=people | eval array = array(1, -1, 2), result = exists(array, x -> x > 0) | fields result fetched rows / total rows = 1/1 +-----------+ | result | @@ -79,7 +131,7 @@ Example: | true | +-----------+ - os> source=people | eval array = json_array(-1, -3, -2), result = exists(array, x -> x > 0) | fields result + os> source=people | eval array = array(-1, -3, -2), result = exists(array, x -> x > 0) | fields result fetched rows / total rows = 1/1 +-----------+ | result | @@ -102,7 +154,7 @@ An ARRAY that contains all elements in the input array that satisfy the lambda p Example: - os> source=people | eval array = json_array(1, -1, 2), result = filter(array, x -> x > 0) | fields result + os> source=people | eval array = array(1, -1, 2), result = filter(array, x -> x > 0) | fields result fetched rows / total rows = 1/1 +-----------+ | result | @@ -110,7 +162,7 @@ Example: | [1, 2] | +-----------+ - os> source=people | eval array = json_array(-1, -3, -2), result = filter(array, x -> x > 0) | fields result + os> source=people | eval array = array(-1, -3, -2), result = filter(array, x -> x > 0) | fields result fetched rows / total rows = 1/1 +-----------+ | result | @@ -132,7 +184,7 @@ An ARRAY that contains the result of applying the lambda transform function to e Example: - os> source=people | eval array = json_array(1, 2, 3), result = transform(array, x -> x + 1) | fields result + os> source=people | eval array = array(1, 2, 3), result = transform(array, x -> x + 1) | fields result fetched rows / total rows = 1/1 +--------------+ | result | @@ -140,7 +192,7 @@ Example: | [2, 3, 4] | +--------------+ - os> source=people | eval array = json_array(1, 2, 3), result = transform(array, (x, i) -> x + i) | fields result + os> source=people | eval array = array(1, 2, 3), result = transform(array, (x, i) -> x + i) | fields result fetched rows / total rows = 1/1 +--------------+ | result | @@ -162,7 +214,7 @@ The final result of applying the lambda functions to the start value and the inp Example: - os> source=people | eval array = json_array(1, 2, 3), result = reduce(array, 0, (acc, x) -> acc + x) | fields result + os> source=people | eval array = array(1, 2, 3), result = reduce(array, 0, (acc, x) -> acc + x) | fields result fetched rows / total rows = 1/1 +-----------+ | result | @@ -170,7 +222,7 @@ Example: | 6 | +-----------+ - os> source=people | eval array = json_array(1, 2, 3), result = reduce(array, 10, (acc, x) -> acc + x) | fields result + os> source=people | eval array = array(1, 2, 3), result = reduce(array, 10, (acc, x) -> acc + x) | fields result fetched rows / total rows = 1/1 +-----------+ | result | @@ -178,7 +230,7 @@ Example: | 16 | +-----------+ - os> source=people | eval array = json_array(1, 2, 3), result = reduce(array, 0, (acc, x) -> acc + x, acc -> acc * 10) | fields result + os> source=people | eval array = array(1, 2, 3), result = reduce(array, 0, (acc, 
x) -> acc + x, acc -> acc * 10) | fields result fetched rows / total rows = 1/1 +-----------+ | result | diff --git a/docs/ppl-lang/functions/ppl-json.md b/docs/ppl-lang/functions/ppl-json.md index 1953e8c70..2c0c0ca67 100644 --- a/docs/ppl-lang/functions/ppl-json.md +++ b/docs/ppl-lang/functions/ppl-json.md @@ -4,11 +4,11 @@ **Description** -`json(value)` Evaluates whether a value can be parsed as JSON. Returns the json string if valid, null otherwise. +`json(value)` Evaluates whether a string can be parsed as JSON format. Returns the string value if valid, null otherwise. -**Argument type:** STRING/JSON_ARRAY/JSON_OBJECT +**Argument type:** STRING -**Return type:** STRING +**Return type:** STRING/NULL A STRING expression of a valid JSON object format. @@ -47,7 +47,7 @@ A StructType expression of a valid JSON object. Example: - os> source=people | eval result = json(json_object('key', 123.45)) | fields result + os> source=people | eval result = json_object('key', 123.45) | fields result fetched rows / total rows = 1/1 +------------------+ | result | @@ -55,7 +55,7 @@ Example: | {"key":123.45} | +------------------+ - os> source=people | eval result = json(json_object('outer', json_object('inner', 123.45))) | fields result + os> source=people | eval result = json_object('outer', json_object('inner', 123.45)) | fields result fetched rows / total rows = 1/1 +------------------------------+ | result | @@ -81,13 +81,13 @@ Example: os> source=people | eval `json_array` = json_array(1, 2, 0, -1, 1.1, -0.11) fetched rows / total rows = 1/1 - +----------------------------+ - | json_array | - +----------------------------+ - | 1.0,2.0,0.0,-1.0,1.1,-0.11 | - +----------------------------+ + +------------------------------+ + | json_array | + +------------------------------+ + | [1.0,2.0,0.0,-1.0,1.1,-0.11] | + +------------------------------+ - os> source=people | eval `json_array_object` = json(json_object("array", json_array(1, 2, 0, -1, 1.1, -0.11))) + os> source=people | eval `json_array_object` = json_object("array", json_array(1, 2, 0, -1, 1.1, -0.11)) fetched rows / total rows = 1/1 +----------------------------------------+ | json_array_object | @@ -95,15 +95,49 @@ Example: | {"array":[1.0,2.0,0.0,-1.0,1.1,-0.11]} | +----------------------------------------+ +**Limitation** + +The list of parameters of `json_array` should all be the same type. +`json_array('this', 'is', 1.1, -0.11, true, false)` throws exception. + +### `TO_JSON_STRING` + +**Description** + +`to_json_string(jsonObject)` Returns a JSON string with a given json object value. + +**Argument type:** JSON_OBJECT (Spark StructType/ArrayType) + +**Return type:** STRING + +Example: + + os> source=people | eval `json_string` = to_json_string(json_array(1, 2, 0, -1, 1.1, -0.11)) | fields json_string + fetched rows / total rows = 1/1 + +--------------------------------+ + | json_string | + +--------------------------------+ + | [1.0,2.0,0.0,-1.0,1.1,-0.11] | + +--------------------------------+ + + os> source=people | eval `json_string` = to_json_string(json_object('key', 123.45)) | fields json_string + fetched rows / total rows = 1/1 + +-----------------+ + | json_string | + +-----------------+ + | {'key', 123.45} | + +-----------------+ + + ### `JSON_ARRAY_LENGTH` **Description** -`json_array_length(jsonArray)` Returns the number of elements in the outermost JSON array. +`json_array_length(jsonArrayString)` Returns the number of elements in the outermost JSON array string. 
-**Argument type:** STRING/JSON_ARRAY +**Argument type:** STRING -A STRING expression of a valid JSON array format, or JSON_ARRAY object. +A STRING expression of a valid JSON array format. **Return type:** INTEGER @@ -119,13 +153,6 @@ Example: | 4 | 5 | null | +-----------+-----------+-------------+ - os> source=people | eval `json_array` = json_array_length(json_array(1,2,3,4)), `empty_array` = json_array_length(json_array()) - fetched rows / total rows = 1/1 - +--------------+---------------+ - | json_array | empty_array | - +--------------+---------------+ - | 4 | 0 | - +--------------+---------------+ ### `JSON_EXTRACT` @@ -235,3 +262,189 @@ Example: |------------------+---------| | 13 | null | +------------------+---------+ + +### `FORALL` + +**Description** + +`forall(json_array, lambda)` Evaluates whether a lambda predicate holds for all elements in the json_array. + +**Argument type:** ARRAY, LAMBDA + +**Return type:** BOOLEAN + +Returns `TRUE` if all elements in the array satisfy the lambda predicate, otherwise `FALSE`. + +Example: + + os> source=people | eval array = json_array(1, -1, 2), result = forall(array, x -> x > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | false | + +-----------+ + + os> source=people | eval array = json_array(1, 3, 2), result = forall(array, x -> x > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | true | + +-----------+ + +**Note:** The lambda expression can access the nested fields of the array elements. This applies to all lambda functions introduced in this document. + +Consider constructing the following array: + + array = [ + {"a":1, "b":1}, + {"a":-1, "b":2} + ] + +and perform lambda functions against the nested fields `a` or `b`. See the examples: + + os> source=people | eval array = json_array(json_object("a", 1, "b", 1), json_object("a" , -1, "b", 2)), result = forall(array, x -> x.a > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | false | + +-----------+ + + os> source=people | eval array = json_array(json_object("a", 1, "b", 1), json_object("a" , -1, "b", 2)), result = forall(array, x -> x.b > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | true | + +-----------+ + +### `EXISTS` + +**Description** + +`exists(json_array, lambda)` Evaluates whether a lambda predicate holds for one or more elements in the json_array. + +**Argument type:** ARRAY, LAMBDA + +**Return type:** BOOLEAN + +Returns `TRUE` if at least one element in the array satisfies the lambda predicate, otherwise `FALSE`. + +Example: + + os> source=people | eval array = json_array(1, -1, 2), result = exists(array, x -> x > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | true | + +-----------+ + + os> source=people | eval array = json_array(-1, -3, -2), result = exists(array, x -> x > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | false | + +-----------+ + + +### `FILTER` + +**Description** + +`filter(json_array, lambda)` Filters the input json_array using the given lambda function. + +**Argument type:** ARRAY, LAMBDA + +**Return type:** ARRAY + +An ARRAY that contains all elements in the input json_array that satisfy the lambda predicate. 
+ +Example: + + os> source=people | eval array = json_array(1, -1, 2), result = filter(array, x -> x > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | [1, 2] | + +-----------+ + + os> source=people | eval array = json_array(-1, -3, -2), result = filter(array, x -> x > 0) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | [] | + +-----------+ + +### `TRANSFORM` + +**Description** + +`transform(json_array, lambda)` Transform elements in a json_array using the lambda transform function. The second argument implies the index of the element if using binary lambda function. This is similar to a `map` in functional programming. + +**Argument type:** ARRAY, LAMBDA + +**Return type:** ARRAY + +An ARRAY that contains the result of applying the lambda transform function to each element in the input array. + +Example: + + os> source=people | eval array = json_array(1, 2, 3), result = transform(array, x -> x + 1) | fields result + fetched rows / total rows = 1/1 + +--------------+ + | result | + +--------------+ + | [2, 3, 4] | + +--------------+ + + os> source=people | eval array = json_array(1, 2, 3), result = transform(array, (x, i) -> x + i) | fields result + fetched rows / total rows = 1/1 + +--------------+ + | result | + +--------------+ + | [1, 3, 5] | + +--------------+ + +### `REDUCE` + +**Description** + +`reduce(json_array, start, merge_lambda, finish_lambda)` Applies a binary merge lambda function to a start value and all elements in the json_array, and reduces this to a single state. The final state is converted into the final result by applying a finish lambda function. + +**Argument type:** ARRAY, ANY, LAMBDA, LAMBDA + +**Return type:** ANY + +The final result of applying the lambda functions to the start value and the input json_array. + +Example: + + os> source=people | eval array = json_array(1, 2, 3), result = reduce(array, 0, (acc, x) -> acc + x) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | 6 | + +-----------+ + + os> source=people | eval array = json_array(1, 2, 3), result = reduce(array, 10, (acc, x) -> acc + x) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | 16 | + +-----------+ + + os> source=people | eval array = json_array(1, 2, 3), result = reduce(array, 0, (acc, x) -> acc + x, acc -> acc * 10) | fields result + fetched rows / total rows = 1/1 + +-----------+ + | result | + +-----------+ + | 60 | + +-----------+ diff --git a/docs/ppl-lang/local-spark-ppl-test-instruction.md b/docs/ppl-lang/local-spark-ppl-test-instruction.md new file mode 100644 index 000000000..537ac043b --- /dev/null +++ b/docs/ppl-lang/local-spark-ppl-test-instruction.md @@ -0,0 +1,336 @@ +# Testing PPL using local Spark + +## Produce the PPL artifact +The first step would be to produce the spark-ppl artifact: `sbt clean sparkPPLCosmetic/assembly` + +The resulting artifact would be located in the project's build directory: +```sql +[info] Built: ./opensearch-spark/sparkPPLCosmetic/target/scala-2.12/opensearch-spark-ppl-assembly-x.y.z-SNAPSHOT.jar +``` +## Downloading spark 3.5.3 version +Download spark from the [official website](https://spark.apache.org/downloads.html) and install locally. 
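For example, one way to fetch and unpack Spark 3.5.3 locally (a sketch that assumes the Apache archive mirror and the pre-built Hadoop 3 bundle; adjust the version and mirror to whatever you actually use):

```shell
# Download the pre-built Spark 3.5.3 distribution (assumed mirror and file name)
curl -L -O https://archive.apache.org/dist/spark/spark-3.5.3/spark-3.5.3-bin-hadoop3.tgz
# Unpack it and switch into the distribution directory
tar -xzf spark-3.5.3-bin-hadoop3.tgz
cd spark-3.5.3-bin-hadoop3
```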
+ +## Start Spark with the plugin +Once installed, run spark with the generated PPL artifact: +```shell +bin/spark-sql --jars "/PATH_TO_ARTIFACT/opensearch-spark-ppl-assembly-x.y.z-SNAPSHOT.jar" \ +--conf "spark.sql.extensions=org.opensearch.flint.spark.FlintPPLSparkExtensions" \ +--conf "spark.sql.catalog.dev=org.apache.spark.opensearch.catalog.OpenSearchCatalog" \ +--conf "spark.hadoop.hive.cli.print.header=true" + +WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable +Setting default log level to "WARN". +To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel). +WARN HiveConf: HiveConf of name hive.stats.jdbc.timeout does not exist +WARN HiveConf: HiveConf of name hive.stats.retries.wait does not exist +WARN ObjectStore: Version information not found in metastore. hive.metastore.schema.verification is not enabled so recording the schema version 2.3.0 +WARN ObjectStore: setMetaStoreSchemaVersion called but recording version is disabled: version = 2.3.0, comment = Set by MetaStore +Spark Web UI available at http://*.*.*.*:4040 +Spark master: local[*], Application Id: local-1731523264660 + +spark-sql (default)> +``` +The resulting would be a spark-sql prompt: `spark-sql (default)> ...` + +### Spark UI Html +One can also explore spark's UI portal to examine the execution jobs and how they are performing: + +![Spark-UX](../img/spark-ui.png) + + +### Configuring hive partition mode +For simpler configuration of partitioned tables, use the following non-strict mode: + +```shell +spark-sql (default)> SET hive.exec.dynamic.partition.mode = nonstrict; +``` + +--- + +# Testing PPL Commands + +In order to test ppl commands using the spark-sql command line - create and populate the following set of tables: + +## emails table +```sql +CREATE TABLE emails (name STRING, age INT, email STRING, street_address STRING, year INT, month INT) PARTITIONED BY (year, month); +INSERT INTO emails (name, age, email, street_address, year, month) VALUES ('Alice', 30, 'alice@example.com', '123 Main St, Seattle', 2023, 4), ('Bob', 55, 'bob@test.org', '456 Elm St, Portland', 2023, 5), ('Charlie', 65, 'charlie@domain.net', '789 Pine St, San Francisco', 2023, 4), ('David', 19, 'david@anotherdomain.com', '101 Maple St, New York', 2023, 5), ('Eve', 21, 'eve@examples.com', '202 Oak St, Boston', 2023, 4), ('Frank', 76, 'frank@sample.org', '303 Cedar St, Austin', 2023, 5), ('Grace', 41, 'grace@demo.net', '404 Birch St, Chicago', 2023, 4), ('Hank', 32, 'hank@demonstration.com', '505 Spruce St, Miami', 2023, 5), ('Ivy', 9, 'ivy@examples.com', '606 Fir St, Denver', 2023, 4), ('Jack', 12, 'jack@sample.net', '707 Ash St, Seattle', 2023, 5); +``` + +Now one can run the following ppl commands to test functionality: + +### Test `describe` command + +```sql +describe emails; + +col_name data_type comment +name string +age int +email string +street_address string +year int +month int +# Partition Information +# col_name data_type comment +year int +month int + +# Detailed Table Information +Catalog spark_catalog +Database default +Table emails +Owner USER +Created Time Wed Nov 13 14:45:12 MST 2024 +Last Access UNKNOWN +Created By Spark 3.5.3 +Type MANAGED +Provider hive +Table Properties [transient_lastDdlTime=1731534312] +Location file:/Users/USER/tools/spark-3.5.3-bin-hadoop3/bin/spark-warehouse/emails +Serde Library org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe +InputFormat org.apache.hadoop.mapred.TextInputFormat 
+OutputFormat org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat +Storage Properties [serialization.format=1] +Partition Provider Catalog + +Time taken: 0.128 seconds, Fetched 28 row(s) +``` + +### Test `grok` command +```sql +source=emails| grok email '.+@%{HOSTNAME:host}' | fields email, host; + +email host +hank@demonstration.com demonstration.com +bob@test.org test.org +jack@sample.net sample.net +frank@sample.org sample.org +david@anotherdomain.com anotherdomain.com +grace@demo.net demo.net +alice@example.com example.com +ivy@examples.com examples.com +eve@examples.com examples.com +charlie@domain.net domain.net + +Time taken: 0.626 seconds, Fetched 10 row(s) +``` + +```sql + source=emails| parse email '.+@(?<host>.+)' | where age > 45 | sort - age | fields age, email, host; + +age email host +76 frank@sample.org sample.org +65 charlie@domain.net domain.net +55 bob@test.org test.org + +Time taken: 1.555 seconds, Fetched 3 row(s) +``` + +### Test `grok` | `top` commands combination +```sql +source=emails| grok email '.+@%{HOSTNAME:host}' | fields email, host | top 3 host; + +count_host host +2 examples.com +1 demonstration.com +1 test.org + +Time taken: 1.274 seconds, Fetched 3 row(s) +``` + +### Test `fieldsummary` command + +```sql +source=emails| fieldsummary includefields=age, email; + +Field COUNT DISTINCT MIN MAX AVG MEAN STDDEV Nulls TYPEOF +age 10 10 9 76 36.0 36.0 22.847319317591726 0 int +email 10 10 alice@example.com jack@sample.net NULL NULL NULL 0 string + +Time taken: 1.535 seconds, Fetched 2 row(s) +``` + +### Test `trendline` command + +```sql +source=emails | sort - age | trendline sma(2, age); + +name age email street_address year month age_trendline +Frank 76 frank@sample.org 303 Cedar St, Austin 2023 5 NULL +Charlie 65 charlie@domain.net 789 Pine St, San Francisco 2023 4 70.5 +Bob 55 bob@test.org 456 Elm St, Portland 2023 5 60.0 +Grace 41 grace@demo.net 404 Birch St, Chicago 2023 4 48.0 +Hank 32 hank@demonstration.com 505 Spruce St, Miami 2023 5 36.5 +Alice 30 alice@example.com 123 Main St, Seattle 2023 4 31.0 +Eve 21 eve@examples.com 202 Oak St, Boston 2023 4 25.5 +David 19 david@anotherdomain.com 101 Maple St, New York 2023 5 20.0 +Jack 12 jack@sample.net 707 Ash St, Seattle 2023 5 15.5 +Ivy 9 ivy@examples.com 606 Fir St, Denver 2023 4 10.5 + +Time taken: 1.048 seconds, Fetched 10 row(s) +``` +### Test `expand` command + +```sql + +source=emails | eval array=json_array(1, 2 ) | expand array as uid | fields uid, name, age, email; + +uid name age email +1 Hank 32 hank@demonstration.com +2 Hank 32 hank@demonstration.com +1 Bob 55 bob@test.org +2 Bob 55 bob@test.org +1 Jack 12 jack@sample.net +2 Jack 12 jack@sample.net +1 Frank 76 frank@sample.org +2 Frank 76 frank@sample.org +1 David 19 david@anotherdomain.com +2 David 19 david@anotherdomain.com +1 Grace 41 grace@demo.net +2 Grace 41 grace@demo.net +1 Alice 30 alice@example.com +2 Alice 30 alice@example.com +1 Ivy 9 ivy@examples.com +2 Ivy 9 ivy@examples.com +1 Eve 21 eve@examples.com +2 Eve 21 eve@examples.com +1 Charlie 65 charlie@domain.net +2 Charlie 65 charlie@domain.net + +Time taken: 0.495 seconds, Fetched 20 row(s) +``` + +## nested table + +```sql +CREATE TABLE nested (int_col INT, struct_col STRUCT<field1: STRUCT<subfield: STRING>, field2: INT>, struct_col2 STRUCT<field1: STRUCT<subfield: STRING>, field2: INT>) USING JSON; +INSERT INTO nested SELECT /*+ COALESCE(1) */ * from VALUES ( 30, STRUCT(STRUCT("value1"),123), STRUCT(STRUCT("valueA"),23) ), ( 40, STRUCT(STRUCT("value5"),123), STRUCT(STRUCT("valueB"),33) ), ( 30, STRUCT(STRUCT("value4"),823),
STRUCT(STRUCT("valueC"),83) ), ( 40, STRUCT(STRUCT("value2"),456), STRUCT(STRUCT("valueD"),46) ), ( 50, STRUCT(STRUCT("value3"),789), STRUCT(STRUCT("valueE"),89) ); +``` + +### Test `flatten` command + +```sql +source=nested | flatten struct_col | flatten field1 | flatten struct_col2; + +int_col field2 subfield field1 field2 +30 123 value1 {"subfield":"valueA"} 23 +40 123 value5 {"subfield":"valueB"} 33 +30 823 value4 {"subfield":"valueC"} 83 +40 456 value2 {"subfield":"valueD"} 46 +50 789 value3 {"subfield":"valueE"} 89 +30 123 value1 {"subfield":"valueA"} 23 + +Time taken: 2.682 seconds, Fetched 6 row(s) +``` + +```sql +source=nested| where struct_col.field2 > 200 | sort - struct_col.field2 | fields int_col, struct_col.field2; + +int_col field2 +30 823 +50 789 +40 456 + +Time taken: 0.722 seconds, Fetched 3 row(s) +``` + +## array table + +```sql +CREATE TABLE arrayTable (int_col INT, multi_valueA ARRAY<STRUCT<name: STRING, value: INT>>, multi_valueB ARRAY<STRUCT<name: STRING, value: INT>>) USING JSON; +INSERT INTO arrayTable VALUES (1, array(STRUCT("1_one", 1), STRUCT(null, 11), STRUCT("1_three", null)), array(STRUCT("2_Monday", 2), null)), (2, array(STRUCT("2_Monday", 2), null), array(STRUCT("3_third", 3), STRUCT("3_4th", 4))), (3, array(STRUCT("3_third", 3), STRUCT("3_4th", 4)), array(STRUCT("1_one", 1))), (4, null, array(STRUCT("1_one", 1))); +``` + +### Test `expand` command + +```sql +source=arrayTable | expand multi_valueA as multiA | expand multi_valueB as multiB; + +int_col multiA multiB +1 {"name":"1_one","value":1} {"name":"2_Monday","value":2} +1 {"name":"1_one","value":1} NULL +1 {"name":null,"value":11} {"name":"2_Monday","value":2} +1 {"name":null,"value":11} NULL +1 {"name":"1_three","value":null} {"name":"2_Monday","value":2} +1 {"name":"1_three","value":null} NULL +2 {"name":"2_Monday","value":2} {"name":"3_third","value":3} +2 {"name":"2_Monday","value":2} {"name":"3_4th","value":4} +2 NULL {"name":"3_third","value":3} +2 NULL {"name":"3_4th","value":4} +3 {"name":"3_third","value":3} {"name":"1_one","value":1} +3 {"name":"3_4th","value":4} {"name":"1_one","value":1} + +Time taken: 0.173 seconds, Fetched 12 row(s) +``` + +### Test `expand` | `flatten` command combination + +```sql +source=arrayTable | flatten multi_valueA | expand multi_valueB; + +int_col multi_valueB name value col +1 [{"name":"2_Monday","value":2},null] 1_one 1 {"name":"2_Monday","value":2} +1 [{"name":"2_Monday","value":2},null] 1_one 1 NULL +1 [{"name":"2_Monday","value":2},null] NULL 11 {"name":"2_Monday","value":2} +1 [{"name":"2_Monday","value":2},null] NULL 11 NULL +1 [{"name":"2_Monday","value":2},null] 1_three NULL {"name":"2_Monday","value":2} +1 [{"name":"2_Monday","value":2},null] 1_three NULL NULL +2 [{"name":"3_third","value":3},{"name":"3_4th","value":4}] 2_Monday 2 {"name":"3_third","value":3} +2 [{"name":"3_third","value":3},{"name":"3_4th","value":4}] 2_Monday 2 {"name":"3_4th","value":4} +2 [{"name":"3_third","value":3},{"name":"3_4th","value":4}] NULL NULL {"name":"3_third","value":3} +2 [{"name":"3_third","value":3},{"name":"3_4th","value":4}] NULL NULL {"name":"3_4th","value":4} +3 [{"name":"1_one","value":1}] 3_third 3 {"name":"1_one","value":1} +3 [{"name":"1_one","value":1}] 3_4th 4 {"name":"1_one","value":1} +4 [{"name":"1_one","value":1}] NULL NULL {"name":"1_one","value":1} + +Time taken: 0.12 seconds, Fetched 13 row(s) +``` +### Test `fillnull` | `flatten` command combination + +```sql +source=arrayTable | flatten multi_valueA | fillnull with '1_zero' in name; + +int_col multi_valueB value name +1 [{"name":"2_Monday","value":2},null] 1
1_one +1 [{"name":"2_Monday","value":2},null] 11 1_zero +1 [{"name":"2_Monday","value":2},null] NULL 1_three +2 [{"name":"3_third","value":3},{"name":"3_4th","value":4}] 2 2_Monday +2 [{"name":"3_third","value":3},{"name":"3_4th","value":4}] NULL 1_zero +3 [{"name":"1_one","value":1}] 3 3_third +3 [{"name":"1_one","value":1}] 4 3_4th +4 [{"name":"1_one","value":1}] NULL 1_zero + +Time taken: 0.111 seconds, Fetched 8 row(s) +``` +## ip table + +```sql +CREATE TABLE ipTable ( id INT,ipAddress STRING,isV6 BOOLEAN, isValid BOOLEAN) using csv OPTIONS (header 'false',delimiter '\\t'); +INSERT INTO ipTable values (1, '127.0.0.1', false, true), (2, '192.168.1.0', false, true),(3, '192.168.1.1', false, true),(4, '192.168.2.1', false, true), (5, '192.168.2.', false, false),(6, '2001:db8::ff00:12:3455', true, true),(7, '2001:db8::ff00:12:3456', true, true),(8, '2001:db8::ff00:13:3457', true, true), (9, '2001:db8::ff00:12:', true, false); +``` + +### Test `cidr` command + +```sql +source=ipTable | where isV6 = false and isValid = true and cidrmatch(ipAddress, '192.168.1.0/24'); + +id ipAddress isV6 isValid +2 192.168.1.0 false true +3 192.168.1.1 false true + +Time taken: 0.317 seconds, Fetched 2 row(s) +``` + +```sql +source=ipTable | where isV6 = true and isValid = true and cidrmatch(ipAddress, '2001:db8::/32'); + +id ipAddress isV6 isValid +6 2001:db8::ff00:12:3455 true true +8 2001:db8::ff00:13:3457 true true +7 2001:db8::ff00:12:3456 true true + +Time taken: 0.09 seconds, Fetched 3 row(s) +``` + +--- \ No newline at end of file diff --git a/docs/ppl-lang/ppl-correlation-command.md b/docs/ppl-lang/ppl-correlation-command.md index 2e8507a14..74e04da86 100644 --- a/docs/ppl-lang/ppl-correlation-command.md +++ b/docs/ppl-lang/ppl-correlation-command.md @@ -1,4 +1,4 @@ -## PPL Correlation Command +## PPL `correlation` command > This is an experimental command - it may be removed in future versions diff --git a/docs/ppl-lang/ppl-dedup-command.md b/docs/ppl-lang/ppl-dedup-command.md index 28fe7f4a4..4e06d275e 100644 --- a/docs/ppl-lang/ppl-dedup-command.md +++ b/docs/ppl-lang/ppl-dedup-command.md @@ -1,6 +1,6 @@ -# PPL dedup command +## PPL `dedup` command -## Table of contents +### Table of contents - [Description](#description) - [Syntax](#syntax) @@ -11,11 +11,11 @@ - [Example 4: Dedup in consecutive document](#example-4-dedup-in-consecutive-document) - [Limitation](#limitation) -## Description +### Description Using `dedup` command to remove identical document defined by field from the search result. -## Syntax +### Syntax ```sql dedup [int] <field-list> [keepempty=<bool>] [consecutive=<bool>] diff --git a/docs/ppl-lang/ppl-eval-command.md b/docs/ppl-lang/ppl-eval-command.md index 1908c087c..e98d4d4f2 100644 --- a/docs/ppl-lang/ppl-eval-command.md +++ b/docs/ppl-lang/ppl-eval-command.md @@ -1,10 +1,10 @@ -# PPL `eval` command +## PPL `eval` command -## Description +### Description The ``eval`` command evaluate the expression and append the result to the search result. -## Syntax +### Syntax ```sql eval <field>=<expression> ["," <field>=<expression> ]... ``` diff --git a/docs/ppl-lang/ppl-expand-command.md b/docs/ppl-lang/ppl-expand-command.md new file mode 100644 index 000000000..144c0aafa --- /dev/null +++ b/docs/ppl-lang/ppl-expand-command.md @@ -0,0 +1,45 @@ +## PPL `expand` command + +### Description +Using `expand` command to flatten a field of type: +- `Array` +- `Map` + + +### Syntax +`expand <field> [As alias]` + +* field: to be expanded (exploded). The field must be of supported type.
+* alias: Optional. The name to use for the expanded values in place of the original field name + +### Usage Guidelines +The expand command produces a row for each element in the specified array or map field, where: +- Array elements become individual rows. +- Map key-value pairs are broken into separate rows, with each key-value represented as a row. + +- When an alias is provided, the exploded values are represented under the alias instead of the original field name. +- This can be used in combination with other commands, such as stats, eval, and parse to manipulate or extract data post-expansion. + +### Examples: +- `source = table | expand employee | stats max(salary) as max by state, company` +- `source = table | expand employee as worker | stats max(salary) as max by state, company` +- `source = table | expand employee as worker | eval bonus = salary * 3 | fields worker, bonus` +- `source = table | expand employee | parse description '(?<email>.+@.+)' | fields employee, email` +- `source = table | eval array=json_array(1, 2, 3) | expand array as uid | fields name, occupation, uid` +- `source = table | expand multi_valueA as multiA | expand multi_valueB as multiB` + +- Expand command can be used in combination with other commands such as `eval`, `stats` and more +- Using multiple expand commands will create a Cartesian product of all the internal elements within each composite array or map + +### Effective SQL push-down query +The expand command is translated into an equivalent SQL operation using LATERAL VIEW explode, allowing for efficient exploding of arrays or maps at the SQL query level. + +```sql +SELECT customer, exploded_productId +FROM table +LATERAL VIEW explode(productId) AS exploded_productId +``` +Where the `explode` command offers the following functionality: +- it is a column operation that returns a new column +- it creates a new row for every element in the exploded column +- internal `null`s are ignored as part of the exploded field (no row is created/exploded for null) diff --git a/docs/ppl-lang/ppl-fields-command.md b/docs/ppl-lang/ppl-fields-command.md index e37fc644f..4ef041ee2 100644 --- a/docs/ppl-lang/ppl-fields-command.md +++ b/docs/ppl-lang/ppl-fields-command.md @@ -1,12 +1,12 @@ ## PPL `fields` command -**Description** +### Description Using ``field`` command to keep or remove fields from the search result. -**Syntax** +### Syntax -field [+|-] <field-list> +`field [+|-] <field-list>` * index: optional. if the plus (+) is used, only the fields specified in the field list will be keep. if the minus (-) is used, all the fields specified in the field list will be removed. **Default** + * field list: mandatory. comma-delimited keep or remove fields. diff --git a/docs/ppl-lang/ppl-fieldsummary-command.md b/docs/ppl-lang/ppl-fieldsummary-command.md index 468c2046b..2015cf815 100644 --- a/docs/ppl-lang/ppl-fieldsummary-command.md +++ b/docs/ppl-lang/ppl-fieldsummary-command.md @@ -1,11 +1,11 @@ ## PPL `fieldsummary` command -**Description** +### Description Using `fieldsummary` command to : - Calculate basic statistics for each field (count, distinct count, min, max, avg, stddev, mean ) - Determine the data type of each field -**Syntax** +### Syntax `...
| fieldsummary (nulls=true/false)` diff --git a/docs/ppl-lang/ppl-grok-command.md b/docs/ppl-lang/ppl-grok-command.md index 06028109b..8d5946563 100644 --- a/docs/ppl-lang/ppl-grok-command.md +++ b/docs/ppl-lang/ppl-grok-command.md @@ -1,4 +1,4 @@ -## PPL Correlation Command +## PPL `grok` command ### Description diff --git a/docs/ppl-lang/ppl-head-command.md b/docs/ppl-lang/ppl-head-command.md index e4172b1c6..51a87db3b 100644 --- a/docs/ppl-lang/ppl-head-command.md +++ b/docs/ppl-lang/ppl-head-command.md @@ -1,4 +1,4 @@ -## PPL `head` Command +## PPL `head` command **Description** The ``head`` command returns the first N number of specified results after an optional offset in search order. diff --git a/docs/ppl-lang/ppl-join-command.md b/docs/ppl-lang/ppl-join-command.md index 525373f7c..f04f1c5c1 100644 --- a/docs/ppl-lang/ppl-join-command.md +++ b/docs/ppl-lang/ppl-join-command.md @@ -1,10 +1,115 @@ -## PPL Join Command +## PPL `join` command -## Overview +### Description -[Trace analytics](https://opensearch.org/docs/latest/observability-plugin/trace/ta-dashboards/) considered using SQL/PPL for its queries, but some graphs rely on joining two indices (span index and service map index) together which is not supported by SQL/PPL. Trace analytics was implemented with DSL + javascript, would be good if `join` being added to SQL could support this use case. +`JOIN` command combines two datasets together. The left side could be an index or results from a piped commands, the right side could be either an index or a subquery. -### Schema +### Syntax + +`[joinType] join [leftAlias] [rightAlias] [joinHints] on ` + +**joinType** +- Syntax: `[INNER] | LEFT [OUTER] | RIGHT [OUTER] | FULL [OUTER] | CROSS | [LEFT] SEMI | [LEFT] ANTI` +- Optional +- Description: The type of join to perform. The default is `INNER` if not specified. + +**leftAlias** +- Syntax: `left = ` +- Optional +- Description: The subquery alias to use with the left join side, to avoid ambiguous naming. + +**rightAlias** +- Syntax: `right = ` +- Optional +- Description: The subquery alias to use with the right join side, to avoid ambiguous naming. + +**joinHints** +- Syntax: `[hint.left.key1 = value1 hint.right.key2 = value2]` +- Optional +- Description: Zero or more space-separated join hints in the form of `Key` = `Value`. The key must start with `hint.left.` or `hint.right.` + +**joinCriteria** +- Syntax: `` +- Required +- Description: The syntax starts with `ON`. It could be any comparison expression. Generally, the join criteria looks like `.=.`. For example: `l.id = r.id`. If the join criteria contains multiple conditions, you can specify `AND` and `OR` operator between each comparison expression. For example, `l.id = r.id AND l.email = r.email AND (r.age > 65 OR r.age < 18)`. + +**right-dataset** +- Required +- Description: Right dataset could be either an index or a subquery with/without alias. 
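Putting these parameters together, a minimal sketch of the full syntax with side aliases (the `customer`/`orders` fields mirror Example 1 below; join hints are omitted):

```
source = customer
| inner join left = c right = o ON c.c_custkey = o.o_custkey orders
| fields c.c_custkey, o.o_orderstatus
```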
+ +### Example 1: two indices join + +PPL query: + + os> source=customer | join ON c_custkey = o_custkey orders + | fields c_custkey, c_nationkey, c_mktsegment, o_orderkey, o_orderstatus, o_totalprice | head 10 + fetched rows / total rows = 10/10 + +----------+-------------+-------------+------------+---------------+-------------+ + | c_custkey| c_nationkey | c_mktsegment| o_orderkey | o_orderstatus | o_totalprice| + +----------+-------------+-------------+------------+---------------+-------------+ + | 36901 | 13 | AUTOMOBILE | 1 | O | 173665.47 | + | 78002 | 10 | AUTOMOBILE | 2 | O | 46929.18 | + | 123314 | 15 | MACHINERY | 3 | F | 193846.25 | + | 136777 | 10 | HOUSEHOLD | 4 | O | 32151.78 | + | 44485 | 20 | FURNITURE | 5 | F | 144659.2 | + | 55624 | 7 | AUTOMOBILE | 6 | F | 58749.59 | + | 39136 | 5 | FURNITURE | 7 | O | 252004.18 | + | 130057 | 9 | FURNITURE | 32 | O | 208660.75 | + | 66958 | 18 | MACHINERY | 33 | F | 163243.98 | + | 61001 | 3 | FURNITURE | 34 | O | 58949.67 | + +----------+-------------+-------------+------------+---------------+-------------+ + +### Example 2: three indices join + +PPL query: + + os> source=customer | join ON c_custkey = o_custkey orders | join ON c_nationkey = n_nationkey nation + | fields c_custkey, c_mktsegment, o_orderkey, o_orderstatus, o_totalprice, n_name | head 10 + fetched rows / total rows = 10/10 + +----------+-------------+------------+---------------+-------------+--------------+ + | c_custkey| c_mktsegment| o_orderkey | o_orderstatus | o_totalprice| n_name | + +----------+-------------+------------+---------------+-------------+--------------+ + | 36901 | AUTOMOBILE | 1 | O | 173665.47 | JORDAN | + | 78002 | AUTOMOBILE | 2 | O | 46929.18 | IRAN | + | 123314 | MACHINERY | 3 | F | 193846.25 | MOROCCO | + | 136777 | HOUSEHOLD | 4 | O | 32151.78 | IRAN | + | 44485 | FURNITURE | 5 | F | 144659.2 | SAUDI ARABIA | + | 55624 | AUTOMOBILE | 6 | F | 58749.59 | GERMANY | + | 39136 | FURNITURE | 7 | O | 252004.18 | ETHIOPIA | + | 130057 | FURNITURE | 32 | O | 208660.75 | INDONESIA | + | 66958 | MACHINERY | 33 | F | 163243.98 | CHINA | + | 61001 | FURNITURE | 34 | O | 58949.67 | CANADA | + +----------+-------------+------------+---------------+-------------+--------------+ + +### Example 3: join a subquery in right side + +PPL query: + + os>source=supplier| join right = revenue0 ON s_suppkey = supplier_no + [ + source=lineitem | where l_shipdate >= date('1996-01-01') AND l_shipdate < date_add(date('1996-01-01'), interval 3 month) + | eval supplier_no = l_suppkey | stats sum(l_extendedprice * (1 - l_discount)) as total_revenue by supplier_no + ] + | fields s_name, s_phone, total_revenue, supplier_no | head 10 + fetched rows / total rows = 10/10 + +---------------------+----------------+-------------------+-------------+ + | s_name | s_phone | total_revenue | supplier_no | + +---------------------+----------------+-------------------+-------------+ + | Supplier#000007747 | 24-911-546-3505| 636204.0279 | 7747 | + | Supplier#000007748 | 29-535-184-2277| 538311.8099 | 7748 | + | Supplier#000007749 | 18-225-478-7489| 743462.4473000001 | 7749 | + | Supplier#000007750 | 28-680-484-7044| 616828.2220999999 | 7750 | + | Supplier#000007751 | 20-990-606-7343| 1092975.1925 | 7751 | + | Supplier#000007752 | 12-936-258-6650| 1090399.9666 | 7752 | + | Supplier#000007753 | 22-394-329-1153| 777130.7457000001 | 7753 | + | Supplier#000007754 | 26-941-591-5320| 866600.0501 | 7754 | + | Supplier#000007755 | 32-138-467-4225| 702256.7030000001 | 7755 | + | Supplier#000007756 
| 29-860-205-8019| 1304979.0511999999| 7756 | + +---------------------+----------------+-------------------+-------------+ + +### Example 4: complex example in OTEL + +**Schema** There will be at least 2 indices, `otel-v1-apm-span-*` (large) and `otel-v1-apm-service-map` (small). @@ -30,153 +135,47 @@ Relevant fields from indices: Full schemas are defined in data-prepper repo: [`otel-v1-apm-span-*`](https://github.com/opensearch-project/data-prepper/blob/04dd7bd18977294800cf4b77d7f01914def75f23/docs/schemas/trace-analytics/otel-v1-apm-span-index-template.md), [`otel-v1-apm-service-map`](https://github.com/opensearch-project/data-prepper/blob/4e5f83814c4a0eed2a1ca9bab0693b9e32240c97/docs/schemas/trace-analytics/otel-v1-apm-service-map-index-template.md) -### Requirement - -Support `join` to calculate the following: +**Requirement** For each service, join span index on service map index to calculate metrics under different type of filters. ![image](https://user-images.githubusercontent.com/28062824/194170062-f0dd1d57-c5eb-44db-95e0-6b3b4e52f25a.png) -This sample query calculates latency when filtered by trace group `client_cancel_order` for the `order` service. I only have a subquery example, don't have the join version of the query.. - -```sql -SELECT avg(durationInNanos) -FROM `otel-v1-apm-span-000001` t1 -WHERE t1.serviceName = `order` - AND ((t1.name in - (SELECT target.resource - FROM `otel-v1-apm-service-map` - WHERE serviceName = `order` - AND traceGroupName = `client_cancel_order`) - AND t1.parentSpanId != NULL) - OR (t1.parentSpanId = NULL - AND t1.name = `client_cancel_order`)) - AND t1.traceId in - (SELECT traceId - FROM `otel-v1-apm-span-000001` - WHERE serviceName = `order`) -``` -## Migrate to PPL - -### Syntax of Join Command - -```sql -SEARCH source= -| -| [joinType] JOIN - leftAlias - rightAlias - [joinHints] - ON joinCriteria - -| -``` -**joinType** -- Syntax: `[INNER] | LEFT [OUTER] | RIGHT [OUTER] | FULL [OUTER] | CROSS | [LEFT] SEMI | [LEFT] ANTI` -- Optional -- Description: The type of join to perform. The default is `INNER` if not specified. +This sample query calculates latency when filtered by trace group `client_cancel_order` for the `order` service. I only have a subquery example, don't have the join version of the query. -**leftAlias** -- Syntax: `left = ` -- Required -- Description: The subquery alias to use with the left join side, to avoid ambiguous naming. - -**rightAlias** -- Syntax: `right = ` -- Required -- Description: The subquery alias to use with the right join side, to avoid ambiguous naming. - -**joinHints** -- Syntax: `[hint.left.key1 = value1 hint.right.key2 = value2]` -- Optional -- Description: Zero or more space-separated join hints in the form of `Key` = `Value`. The key must start with `hint.left.` or `hint.right.` - -**joinCriteria** -- Syntax: `` -- Required -- Description: The syntax starts with `ON`. It could be any comparison expression. Generally, the join criteria looks like `.=.`. For example: `l.id = r.id`. If the join criteria contains multiple conditions, you can specify `AND` and `OR` operator between each comparison expression. For example, `l.id = r.id AND l.email = r.email AND (r.age > 65 OR r.age < 18)`. - -**right-table** -- Required -- Description: The index or table name of join right-side. Sub-search is unsupported in join right side for now. 
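The left/right aliases described above exist purely to disambiguate field references, which mirrors how DataFrame aliases are used in Spark. The sketch below is illustrative only: it is not code generated by the PPL extension, and the `otel_spans` view name is a stand-in for however the span index is exposed as a table.

```scala
// Spark-shell sketch: a self join where aliases play the role of PPL's `left=t1 right=t2`.
// Assumes `spark` is in scope and that a view named `otel_spans` (hypothetical) exists.
import org.apache.spark.sql.functions.col

val spans = spark.table("otel_spans")

val joined = spans.alias("t1")
  .join(
    spans.alias("t2"),
    col("t1.traceId") === col("t2.traceId") && col("t2.serviceName") === "order",
    "inner")

// Qualified names stay unambiguous on both sides of the condition, which is exactly
// what the leftAlias/rightAlias options are for.
joined.select(col("t1.name"), col("t1.parentSpanId"), col("t1.durationInNanos")).show(5)
```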
- -### Rewriting -```sql -SEARCH source=otel-v1-apm-span-000001 +PPL query: +``` +source=otel-v1-apm-span-000001 | WHERE serviceName = 'order' | JOIN left=t1 right=t2 ON t1.traceId = t2.traceId AND t2.serviceName = 'order' - otel-v1-apm-span-000001 -- self inner join -| EVAL s_name = t1.name -- rename to avoid ambiguous -| EVAL s_parentSpanId = t1.parentSpanId -- RENAME command would be better when it is supported -| EVAL s_durationInNanos = t1.durationInNanos -| FIELDS s_name, s_parentSpanId, s_durationInNanos -- reduce colunms in join + otel-v1-apm-span-000001 // self inner join +| RENAME s_name as t1.name +| RENAME s_parentSpanId as t1.parentSpanId +| RENAME s_durationInNanos as t1.durationInNanos +| FIELDS s_name, s_parentSpanId, s_durationInNanos // reduce colunms in join | LEFT JOIN left=s1 right=t3 ON s_name = t3.target.resource AND t3.serviceName = 'order' AND t3.traceGroupName = 'client_cancel_order' otel-v1-apm-service-map | WHERE (s_parentSpanId IS NOT NULL OR (s_parentSpanId IS NULL AND s_name = 'client_cancel_order')) -| STATS avg(s_durationInNanos) -- no need to add alias if there is no ambiguous -``` - - -### More examples - -Migration from SQL query (TPC-H Q13): -```sql -SELECT c_count, COUNT(*) AS custdist -FROM - ( SELECT c_custkey, COUNT(o_orderkey) c_count - FROM customer LEFT OUTER JOIN orders ON c_custkey = o_custkey - AND o_comment NOT LIKE '%unusual%packages%' - GROUP BY c_custkey - ) AS c_orders -GROUP BY c_count -ORDER BY custdist DESC, c_count DESC; -``` -Rewritten by PPL Join query: -```sql -SEARCH source=customer -| FIELDS c_custkey -| LEFT OUTER JOIN left = c, right = o - ON c.c_custkey = o.o_custkey AND o_comment NOT LIKE '%unusual%packages%' - orders -| STATS count(o_orderkey) AS c_count BY c.c_custkey -| STATS count(1) AS custdist BY c_count -| SORT - custdist, - c_count -``` -_- **Limitation: sub-searches is unsupported in join right side**_ - -If sub-searches is supported, above ppl query could be rewritten as: -```sql -SEARCH source=customer -| FIELDS c_custkey -| LEFT OUTER JOIN left = c, right = o ON c.c_custkey = o.o_custkey - [ - SEARCH source=orders - | WHERE o_comment NOT LIKE '%unusual%packages%' - | FIELDS o_orderkey, o_custkey - ] -| STATS count(o_orderkey) AS c_count BY c.c_custkey -| STATS count(1) AS custdist BY c_count -| SORT - custdist, - c_count +| STATS avg(s_durationInNanos) ``` ### Comparison with [Correlation](ppl-correlation-command) A primary difference between `correlate` and `join` is that both sides of `correlate` are tables, but both sides of `join` are subqueries. For example: -```sql +``` source = testTable1 - | where country = 'Canada' OR country = 'England' - | eval cname = lower(name) - | fields cname, country, year, month - | inner join left=l, right=r - ON l.cname = r.name AND l.country = r.country AND l.year = 2023 AND r.month = 4 - testTable2s +| where country = 'Canada' OR country = 'England' +| eval cname = lower(name) +| fields cname, country, year, month +| inner join left=l right=r + ON l.cname = r.name AND l.country = r.country AND l.year = 2023 AND r.month = 4 + testTable2s ``` The subquery alias `l` does not represent the `testTable1` table itself. 
Instead, it represents the subquery: -```sql +``` source = testTable1 | where country = 'Canada' OR country = 'England' | eval cname = lower(name) diff --git a/docs/ppl-lang/ppl-lookup-command.md b/docs/ppl-lang/ppl-lookup-command.md index 1b8350533..87cf34bac 100644 --- a/docs/ppl-lang/ppl-lookup-command.md +++ b/docs/ppl-lang/ppl-lookup-command.md @@ -1,20 +1,18 @@ -## PPL Lookup Command +## PPL `lookup` command -## Overview +### Description Lookup command enriches your search data by adding or replacing data from a lookup index (dimension table). You can extend fields of an index with values from a dimension table, append or replace values when lookup condition is matched. As an alternative of [Join command](ppl-join-command), lookup command is more suitable for enriching the source data with a static dataset. -### Syntax of Lookup Command +### Syntax -```sql -SEARCH source= -| -| LOOKUP ( [AS ])... - [(REPLACE | APPEND) ( [AS ])...] -| ``` +LOOKUP ( [AS ])... + [(REPLACE | APPEND) ( [AS ])...] +``` + **lookupIndex** - Required - Description: the name of lookup index (dimension table) @@ -44,26 +42,49 @@ SEARCH source= - Description: If you specify REPLACE, matched values in \ field overwrite the values in result. If you specify APPEND, matched values in \ field only append to the missing values in result. ### Usage -> LOOKUP id AS cid REPLACE mail AS email
-> LOOKUP name REPLACE mail AS email
-> LOOKUP id AS cid, name APPEND address, mail AS email
-> LOOKUP id
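The REPLACE/APPEND semantics described above can be pictured as a left outer join followed by a `coalesce`. The DataFrame sketch below is only an analogy, not the plugin's actual translation; it loosely follows the `LOOKUP id AS cid REPLACE mail AS email` usage, and the sample rows are made up.

```scala
// Spark-shell sketch of LOOKUP semantics. Assumes `spark` is in scope.
import org.apache.spark.sql.functions.{coalesce, col}
import spark.implicits._

// Source data has (cid, email); the lookup index has (id, mail).
val source = Seq((1, "old@a.com"), (2, null), (3, "keep@c.com")).toDF("cid", "email")
val lookup = Seq((1, "new@a.com"), (2, "new@b.com")).toDF("id", "mail")

val joined = source.join(lookup, source("cid") === lookup("id"), "left")

// REPLACE: a matched lookup value overwrites the source value; unmatched rows keep theirs.
val replaced = joined.select(col("cid"), coalesce(col("mail"), col("email")).as("email"))

// APPEND: the lookup value is used only where the source value is missing.
val appended = joined.select(col("cid"), coalesce(col("email"), col("mail")).as("email"))

replaced.show()
appended.show()
```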
- -### Example -```sql -SEARCH source= -| WHERE orderType = 'Cancelled' -| LOOKUP account_list, mkt_id AS mkt_code REPLACE amount, account_name AS name -| STATS count(mkt_code), avg(amount) BY name -``` -```sql -SEARCH source= -| DEDUP market_id -| EVAL category=replace(category, "-", ".") -| EVAL category=ltrim(category, "dvp.") -| LOOKUP bounce_category category AS category APPEND classification -``` -```sql -SEARCH source= -| LOOKUP bounce_category category -``` +- `LOOKUP id AS cid REPLACE mail AS email` +- `LOOKUP name REPLACE mail AS email` +- `LOOKUP id AS cid, name APPEND address, mail AS email` +- `LOOKUP id` + +### Examples 1: replace + +PPL query: + + os>source=people | LOOKUP work_info uid AS id REPLACE department | head 10 + fetched rows / total rows = 10/10 + +------+-----------+-------------+-----------+--------+------------------+ + | id | name | occupation | country | salary | department | + +------+-----------+-------------+-----------+--------+------------------+ + | 1000 | Daniel | Teacher | Canada | 56486 | CUSTOMER_SERVICE | + | 1001 | Joseph | Lawyer | Denmark | 135943 | FINANCE | + | 1002 | David | Artist | Finland | 60391 | DATA | + | 1003 | Charlotte | Lawyer | Denmark | 42173 | LEGAL | + | 1004 | Isabella | Veterinarian| Australia | 117699 | MARKETING | + | 1005 | Lily | Engineer | Italy | 37526 | IT | + | 1006 | Emily | Dentist | Denmark | 125340 | MARKETING | + | 1007 | James | Lawyer | Germany | 56532 | LEGAL | + | 1008 | Lucas | Lawyer | Japan | 87782 | DATA | + | 1009 | Sophia | Architect | Sweden | 37597 | MARKETING | + +------+-----------+-------------+-----------+--------+------------------+ + +### Examples 2: append + +PPL query: + + os>source=people| LOOKUP work_info uid AS ID, name APPEND department | where isnotnull(department) | head 10 + fetched rows / total rows = 10/10 + +------+---------+-------------+-------------+--------+------------+ + | id | name | occupation | country | salary | department | + +------+---------+-------------+-------------+--------+------------+ + | 1018 | Emma | Architect | USA | 72400 | IT | + | 1032 | James | Pilot | Netherlands | 71698 | SALES | + | 1043 | Jane | Nurse | Brazil | 45016 | FINANCE | + | 1046 | Joseph | Pharmacist | Mexico | 109152 | OPERATIONS | + | 1064 | Joseph | Electrician | New Zealand | 50253 | LEGAL | + | 1090 | Matthew | Psychologist| Germany | 73396 | DATA | + | 1103 | Emily | Electrician | Switzerland | 98391 | DATA | + | 1114 | Jake | Nurse | Denmark | 53418 | SALES | + | 1115 | Sofia | Engineer | Mexico | 64829 | OPERATIONS | + | 1122 | Oliver | Scientist | Netherlands | 31146 | DATA | + +------+---------+-------------+-------------+--------+------------+ diff --git a/docs/ppl-lang/ppl-parse-command.md b/docs/ppl-lang/ppl-parse-command.md index 10be21cc0..0e000756e 100644 --- a/docs/ppl-lang/ppl-parse-command.md +++ b/docs/ppl-lang/ppl-parse-command.md @@ -1,4 +1,4 @@ -## PPL Parse Command +## PPL `parse` command ### Description diff --git a/docs/ppl-lang/ppl-rare-command.md b/docs/ppl-lang/ppl-rare-command.md index 5645382f8..93967e6fe 100644 --- a/docs/ppl-lang/ppl-rare-command.md +++ b/docs/ppl-lang/ppl-rare-command.md @@ -1,15 +1,18 @@ -## PPL rare Command +## PPL `rare` command -**Description** -Using ``rare`` command to find the least common tuple of values of all fields in the field list. +### Description +Using `rare` command to find the least common tuple of values of all fields in the field list. 
**Note**: A maximum of 10 results is returned for each distinct tuple of values of the group-by fields. -**Syntax** -`rare [by-clause]` +### Syntax +`rare [N] [by-clause]` +`rare_approx [N] [by-clause]` +* N: number of results to return. **Default**: 10 * field-list: mandatory. comma-delimited list of field names. * by-clause: optional. one or more fields to group the results by. +* rare_approx: approximate count of the rare (n) fields by using estimated [cardinality by HyperLogLog++ algorithm](https://spark.apache.org/docs/3.5.2/sql-ref-functions-builtin.html). ### Example 1: Find the least common values in a field @@ -19,6 +22,8 @@ The example finds least common gender of all the accounts. PPL query: os> source=accounts | rare gender; + os> source=accounts | rare_approx 10 gender; + os> source=accounts | rare_approx gender; fetched rows / total rows = 2/2 +----------+ | gender | @@ -34,7 +39,8 @@ The example finds least common age of all the accounts group by gender. PPL query: - os> source=accounts | rare age by gender; + os> source=accounts | rare 5 age by gender; + os> source=accounts | rare_approx 5 age by gender; fetched rows / total rows = 4/4 +----------+-------+ | gender | age | diff --git a/docs/ppl-lang/ppl-search-command.md b/docs/ppl-lang/ppl-search-command.md index bccfd04f0..6e1cf0e50 100644 --- a/docs/ppl-lang/ppl-search-command.md +++ b/docs/ppl-lang/ppl-search-command.md @@ -1,7 +1,7 @@ ## PPL `search` command ### Description -Using ``search`` command to retrieve document from the index. ``search`` command could be only used as the first command in the PPL query. +Using `search` command to retrieve document from the index. `search` command could be only used as the first command in the PPL query. ### Syntax diff --git a/docs/ppl-lang/ppl-sort-command.md b/docs/ppl-lang/ppl-sort-command.md index c3bf304d7..dd9b4b33d 100644 --- a/docs/ppl-lang/ppl-sort-command.md +++ b/docs/ppl-lang/ppl-sort-command.md @@ -1,7 +1,7 @@ -## PPL `sort`command +## PPL `sort` command ### Description -Using ``sort`` command to sorts all the search result by the specified fields. +Using `sort` command to sorts all the search result by the specified fields. ### Syntax diff --git a/docs/ppl-lang/ppl-stats-command.md b/docs/ppl-lang/ppl-stats-command.md index 552f83e46..a73800b26 100644 --- a/docs/ppl-lang/ppl-stats-command.md +++ b/docs/ppl-lang/ppl-stats-command.md @@ -1,7 +1,7 @@ ## PPL `stats` command ### Description -Using ``stats`` command to calculate the aggregation from search result. +Using `stats` command to calculate the aggregation from search result. ### NULL/MISSING values handling: diff --git a/docs/ppl-lang/ppl-subquery-command.md b/docs/ppl-lang/ppl-subquery-command.md index c4a0c337c..766b37130 100644 --- a/docs/ppl-lang/ppl-subquery-command.md +++ b/docs/ppl-lang/ppl-subquery-command.md @@ -1,27 +1,27 @@ -## PPL SubQuery Commands: +## PPL `subquery` command -### Syntax -The subquery command should be implemented using a clean, logical syntax that integrates with existing PPL structure. +### Description +The subquery commands contain 4 types: `InSubquery`, `ExistsSubquery`, `ScalarSubquery` and `RelationSubquery`. +`InSubquery`, `ExistsSubquery` and `ScalarSubquery` are subquery expressions, their common usage is in Where clause(`where `) and Search filter(`search source=* `). 
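For readers more familiar with SQL, the three expression kinds line up with ordinary SQL subqueries. The spark-shell sketch below is illustrative only (the `orders`/`returns` tables and their columns are hypothetical) and is not output of the PPL extension.

```scala
// Assumes `spark` is in scope and that `orders` / `returns` exist as tables or views.

// InSubquery: filter the outer relation by values produced by an inner query.
spark.sql("""
  SELECT * FROM orders
  WHERE order_id IN (SELECT order_id FROM returns WHERE return_reason = 'damaged')
""").show()

// ExistsSubquery: keep outer rows for which a correlated inner row exists.
spark.sql("""
  SELECT * FROM orders o
  WHERE EXISTS (SELECT 1 FROM returns r WHERE r.order_id = o.order_id)
""").show()

// ScalarSubquery: compare a column against the single value the inner query returns.
spark.sql("""
  SELECT * FROM orders
  WHERE total_price > (SELECT avg(total_price) FROM orders)
""").show()
```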
-```sql -source=logs | where field in [ subquery source=events | where condition | fields field ] +For example, a subquery expression could be used in boolean expression: ``` - -In this example, the primary search (`source=logs`) is filtered by results from the subquery (`source=events`). - -The subquery command should allow nested queries to be as complex as necessary, supporting multiple levels of nesting. - -Example: - -```sql - source=logs | where id in [ subquery source=users | where user in [ subquery source=actions | where action="login" | fields user] | fields uid ] +| where orders.order_id in [ source=returns | where return_reason="damaged" | field order_id ] ``` +The `orders.order_id in [ source=... ]` is a ``. -For additional info See [Issue](https://github.com/opensearch-project/opensearch-spark/issues/661) - ---- +But `RelationSubquery` is not a subquery expression, it is a subquery plan. +[Recall the join command doc](ppl-join-command.md), the example is a subquery/subsearch **plan**, rather than a **expression**. -### InSubquery usage +### Syntax +- `where [not] in [ source=... | ... | ... ]` (InSubquery) +- `where [not] exists [ source=... | ... | ... ]` (ExistsSubquery) +- `where = [ source=... | ... | ... ]` (ScalarSubquery) +- `source=[ source= ...]` (RelationSubquery) +- `| join ON condition [ source= ]` (RelationSubquery in join right side) + +### Usage +InSubquery: - `source = outer | where a in [ source = inner | fields b ]` - `source = outer | where (a) in [ source = inner | fields b ]` - `source = outer | where (a,b,c) in [ source = inner | fields d,e,f ]` @@ -33,92 +33,9 @@ For additional info See [Issue](https://github.com/opensearch-project/opensearch - `source = outer | where a in [ source = inner1 | where b not in [ source = inner2 | fields c ] | fields b ]` (nested) - `source = table1 | inner join left = l right = r on l.a = r.a AND r.a in [ source = inner | fields d ] | fields l.a, r.a, b, c` (as join filter) -**_SQL Migration examples with IN-Subquery PPL:_** -1. 
tpch q4 (in-subquery with aggregation) -```sql -select - o_orderpriority, - count(*) as order_count -from - orders -where - o_orderdate >= date '1993-07-01' - and o_orderdate < date '1993-07-01' + interval '3' month - and o_orderkey in ( - select - l_orderkey - from - lineitem - where l_commitdate < l_receiptdate - ) -group by - o_orderpriority -order by - o_orderpriority -``` -Rewritten by PPL InSubquery query: -```sql -source = orders -| where o_orderdate >= "1993-07-01" and o_orderdate < "1993-10-01" and o_orderkey IN - [ source = lineitem - | where l_commitdate < l_receiptdate - | fields l_orderkey - ] -| stats count(1) as order_count by o_orderpriority -| sort o_orderpriority -| fields o_orderpriority, order_count -``` -2.tpch q20 (nested in-subquery) -```sql -select - s_name, - s_address -from - supplier, - nation -where - s_suppkey in ( - select - ps_suppkey - from - partsupp - where - ps_partkey in ( - select - p_partkey - from - part - where - p_name like 'forest%' - ) - ) - and s_nationkey = n_nationkey - and n_name = 'CANADA' -order by - s_name -``` -Rewritten by PPL InSubquery query: -```sql -source = supplier -| where s_suppkey IN [ - source = partsupp - | where ps_partkey IN [ - source = part - | where like(p_name, "forest%") - | fields p_partkey - ] - | fields ps_suppkey - ] -| inner join left=l right=r on s_nationkey = n_nationkey and n_name = 'CANADA' - nation -| sort s_name -``` ---- - -### ExistsSubquery usage - -Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table inner, `e`, `f` are fields of table inner2 +ExistsSubquery: +(Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table inner, `e`, `f` are fields of table inner2) - `source = outer | where exists [ source = inner | where a = c ]` - `source = outer | where not exists [ source = inner | where a = c ]` - `source = outer | where exists [ source = inner | where a = c and b = d ]` @@ -132,48 +49,9 @@ Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table in - `source = outer | where not exists [ source = inner | where c > 10 ]` (uncorrelated exists) - `source = outer | where exists [ source = inner ] | eval l = "nonEmpty" | fields l` (special uncorrelated exists) -**_SQL Migration examples with Exists-Subquery PPL:_** - -tpch q4 (exists subquery with aggregation) -```sql -select - o_orderpriority, - count(*) as order_count -from - orders -where - o_orderdate >= date '1993-07-01' - and o_orderdate < date '1993-07-01' + interval '3' month - and exists ( - select - l_orderkey - from - lineitem - where l_orderkey = o_orderkey - and l_commitdate < l_receiptdate - ) -group by - o_orderpriority -order by - o_orderpriority -``` -Rewritten by PPL ExistsSubquery query: -```sql -source = orders -| where o_orderdate >= "1993-07-01" and o_orderdate < "1993-10-01" - and exists [ - source = lineitem - | where l_orderkey = o_orderkey and l_commitdate < l_receiptdate - ] -| stats count(1) as order_count by o_orderpriority -| sort o_orderpriority -| fields o_orderpriority, order_count -``` ---- - -### ScalarSubquery usage +ScalarSubquery: -Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table inner, `e`, `f` are fields of table nested +(Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table inner, `e`, `f` are fields of table nested) **Uncorrelated scalar subquery in Select** - `source = outer | eval m = [ source = inner | stats max(c) ] | fields m, a` @@ -203,146 +81,102 @@ Assumptions: `a`, `b` are fields of table 
outer, `c`, `d` are fields of table in - `source = outer | where a = [ source = inner | stats max(c) | sort c ] OR b = [ source = inner | where c = 1 | stats min(d) | sort d ]` - `source = outer | where a = [ source = inner | where c = [ source = nested | stats max(e) by f | sort f ] | stats max(d) by c | sort c | head 1 ]` -_SQL Migration examples with Scalar-Subquery PPL:_ -Example 1 -```sql -SELECT * -FROM outer -WHERE a = (SELECT max(c) - FROM inner1 - WHERE c = (SELECT max(e) - FROM inner2 - GROUP BY f - ORDER BY f - ) - GROUP BY c - ORDER BY c - LIMIT 1) -``` -Rewritten by PPL ScalarSubquery query: -```sql -source = spark_catalog.default.outer -| where a = [ - source = spark_catalog.default.inner1 - | where c = [ - source = spark_catalog.default.inner2 - | stats max(e) by f - | sort f - ] - | stats max(d) by c - | sort c - | head 1 - ] -``` -Example 2 -```sql -SELECT * FROM outer -WHERE a = (SELECT max(c) - FROM inner - ORDER BY c) -OR b = (SELECT min(d) - FROM inner - WHERE c = 1 - ORDER BY d) -``` -Rewritten by PPL ScalarSubquery query: -```sql -source = spark_catalog.default.outer -| where a = [ - source = spark_catalog.default.inner | stats max(c) | sort c - ] OR b = [ - source = spark_catalog.default.inner | where c = 1 | stats min(d) | sort d - ] -``` ---- - -### (Relation) Subquery -`InSubquery`, `ExistsSubquery` and `ScalarSubquery` are all subquery expressions. But `RelationSubquery` is not a subquery expression, it is a subquery plan which is common used in Join or From clause. - -- `source = table1 | join left = l right = r [ source = table2 | where d > 10 | head 5 ]` (subquery in join right side) +RelationSubquery: +- `source = table1 | join left = l right = r on condition [ source = table2 | where d > 10 | head 5 ]` (subquery in join right side) - `source = [ source = table1 | join left = l right = r [ source = table2 | where d > 10 | head 5 ] | stats count(a) by b ] as outer | head 1` -**_SQL Migration examples with Subquery PPL:_** - -tpch q13 -```sql -select - c_count, - count(*) as custdist -from - ( - select - c_custkey, - count(o_orderkey) as c_count - from - customer left outer join orders on - c_custkey = o_custkey - and o_comment not like '%special%requests%' - group by - c_custkey - ) as c_orders -group by - c_count -order by - custdist desc, - c_count desc -``` -Rewritten by PPL (Relation) Subquery: -```sql -SEARCH source = [ - SEARCH source = customer - | LEFT OUTER JOIN left = c right = o ON c_custkey = o_custkey - [ - SEARCH source = orders - | WHERE not like(o_comment, '%special%requests%') - ] - | STATS COUNT(o_orderkey) AS c_count BY c_custkey -] AS c_orders -| STATS COUNT(o_orderkey) AS c_count BY c_custkey -| STATS COUNT(1) AS custdist BY c_count -| SORT - custdist, - c_count -``` ---- +### Examples 1: TPC-H q20 + +InSubquery and ScalarSubquery + +PPL query: + + os> source=supplier + | join ON s_nationkey = n_nationkey nation + | where n_name = 'CANADA' + and s_suppkey in [ // InSubquery + source = partsupp + | where ps_partkey in [ // InSubquery + source = part + | where like(p_name, 'forest%') + | fields p_partkey + ] + and ps_availqty > [ // ScalarSubquery + source = lineitem + | where l_partkey = ps_partkey + and l_suppkey = ps_suppkey + and l_shipdate >= date('1994-01-01') + and l_shipdate < date_add(date('1994-01-01'), interval 1 year) + | stats sum(l_quantity) as sum_l_quantity + | eval half_sum_l_quantity = 0.5 * sum_l_quantity + | fields half_sum_l_quantity + ] + | fields ps_suppkey + ] + | fields s_suppkey, s_name, s_phone, s_acctbal, n_name 
| head 10 + fetched rows / total rows = 10/10 + +-----------+---------------------+----------------+----------+---------+ + | s_suppkey | s_name | s_phone | s_acctbal| n_name | + +-----------+---------------------+----------------+----------+---------+ + | 8243 | Supplier#000008243 | 13-707-547-1386| 9067.07 | CANADA | + | 736 | Supplier#000000736 | 13-681-806-8650| 5700.83 | CANADA | + | 9032 | Supplier#000009032 | 13-441-662-5539| 3982.32 | CANADA | + | 3201 | Supplier#000003201 | 13-600-413-7165| 3799.41 | CANADA | + | 3849 | Supplier#000003849 | 13-582-965-9117| 52.33 | CANADA | + | 5505 | Supplier#000005505 | 13-531-190-6523| 2023.4 | CANADA | + | 5195 | Supplier#000005195 | 13-622-661-2956| 3717.34 | CANADA | + | 9753 | Supplier#000009753 | 13-724-256-7877| 4406.93 | CANADA | + | 7135 | Supplier#000007135 | 13-367-994-6705| 4950.29 | CANADA | + | 5256 | Supplier#000005256 | 13-180-538-8836| 5624.79 | CANADA | + +-----------+---------------------+----------------+----------+---------+ + + +### Examples 2: TPC-H q22 + +RelationSubquery, ScalarSubquery and ExistsSubquery + +PPL query: + + os> source = [ // RelationSubquery + source = customer + | where substring(c_phone, 1, 2) in ('13', '31', '23', '29', '30', '18', '17') + and c_acctbal > [ // ScalarSubquery + source = customer + | where c_acctbal > 0.00 + and substring(c_phone, 1, 2) in ('13', '31', '23', '29', '30', '18', '17') + | stats avg(c_acctbal) + ] + and not exists [ // ExistsSubquery + source = orders + | where o_custkey = c_custkey + ] + | eval cntrycode = substring(c_phone, 1, 2) + | fields cntrycode, c_acctbal + ] as custsale + | stats count() as numcust, sum(c_acctbal) as totacctbal by cntrycode + | sort cntrycode + fetched rows / total rows = 10/10 + +---------+--------------------+------------+ + | numcust | totacctbal | cntrycode | + +---------+--------------------+------------+ + | 888 | 6737713.989999999 | 13 | + | 861 | 6460573.72 | 17 | + | 964 | 7236687.4 | 18 | + | 892 | 6701457.950000001 | 23 | + | 948 | 7158866.630000001 | 29 | + | 909 | 6808436.129999999 | 30 | + | 922 | 6806670.179999999 | 31 | + +---------+--------------------+------------+ ### Additional Context -`InSubquery`, `ExistsSubquery` and `ScalarSubquery` as subquery expressions, their common usage is in `where` clause and `search filter`. - -Where command: -``` -| where | ... -``` -Search filter: -``` -search source=* | ... -``` -A subquery expression could be used in boolean expression, for example - -```sql -| where orders.order_id in [ source=returns | where return_reason="damaged" | field order_id ] -``` - -The `orders.order_id in [ source=... ]` is a ``. - -In general, we name this kind of subquery clause the `InSubquery` expression, it is a ``. - -**Subquery with Different Join Types** +#### RelationSubquery -In issue description is a `ScalarSubquery`: - -```sql -source=employees -| join source=sales on employees.employee_id = sales.employee_id -| where sales.sale_amount > [ source=targets | where target_met="true" | fields target_value ] +RelationSubquery is plan instead of expression, for example ``` - -But `RelationSubquery` is not a subquery expression, it is a subquery plan. -[Recall the join command doc](ppl-join-command.md), the example is a subquery/subsearch **plan**, rather than a **expression**. 
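One quick way to see the plan-versus-expression distinction is to look at how Spark parses the two shapes. The spark-shell sketch below is illustrative only; the `orders` table and its columns are hypothetical.

```scala
// Assumes `spark` is in scope (e.g. spark-shell).

// A scalar subquery is an *expression*: it appears inside the Filter condition.
val exprPlan = spark.sessionState.sqlParser.parsePlan(
  "SELECT * FROM orders WHERE total_price > (SELECT avg(total_price) FROM orders)")
println(exprPlan.treeString)

// A sub-search used as a source (or join side) is a *plan*: the inner query becomes a
// child node (a subquery alias) rather than an expression.
val relPlan = spark.sessionState.sqlParser.parsePlan(
  "SELECT * FROM (SELECT order_id FROM orders WHERE total_price > 100) t")
println(relPlan.treeString)
```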
- -```sql -SEARCH source=customer +source=customer | FIELDS c_custkey -| LEFT OUTER JOIN left = c, right = o ON c.c_custkey = o.o_custkey +| LEFT OUTER JOIN left = c right = o ON c.c_custkey = o.o_custkey [ SEARCH source=orders | WHERE o_comment NOT LIKE '%unusual%packages%' @@ -351,7 +185,7 @@ SEARCH source=customer | STATS ... ``` simply into -```sql +``` SEARCH | LEFT OUTER JOIN ON [ @@ -359,21 +193,14 @@ SEARCH ] | STATS ... ``` -Apply the syntax here and simply into - -```sql -search | left join on [ search ... ] -``` - -The `[ search ...]` is not a `expression`, it's `plan`, similar to the `relation` plan -**Uncorrelated Subquery** +#### Uncorrelated Subquery An uncorrelated subquery is independent of the outer query. It is executed once, and the result is used by the outer query. It's **less common** when using `ExistsSubquery` because `ExistsSubquery` typically checks for the presence of rows that are dependent on the outer query’s row. There is a very special exists subquery which highlight by `(special uncorrelated exists)`: -```sql +``` SELECT 'nonEmpty' FROM outer WHERE EXISTS ( @@ -382,7 +209,7 @@ FROM outer ); ``` Rewritten by PPL ExistsSubquery query: -```sql +``` source = outer | where exists [ source = inner @@ -392,11 +219,11 @@ source = outer ``` This query just print "nonEmpty" if the inner table is not empty. -**Table alias in subquery** +#### Table alias in subquery Table alias is useful in query which contains a subquery, for example -```sql +``` select a, ( select sum(b) from catalog.schema.table1 as t1 diff --git a/docs/ppl-lang/ppl-top-command.md b/docs/ppl-lang/ppl-top-command.md index 4ba56f692..2bacdba50 100644 --- a/docs/ppl-lang/ppl-top-command.md +++ b/docs/ppl-lang/ppl-top-command.md @@ -1,16 +1,17 @@ -## PPL top Command +## PPL `top` command -**Description** +### Description Using ``top`` command to find the most common tuple of values of all fields in the field list. ### Syntax `top [N] [by-clause]` +`top_approx [N] [by-clause]` * N: number of results to return. **Default**: 10 * field-list: mandatory. comma-delimited list of field names. * by-clause: optional. one or more fields to group the results by. - +* top_approx: approximate count of the (n) top fields by using estimated [cardinality by HyperLogLog++ algorithm](https://spark.apache.org/docs/3.5.2/sql-ref-functions-builtin.html). ### Example 1: Find the most common values in a field @@ -19,6 +20,7 @@ The example finds most common gender of all the accounts. PPL query: os> source=accounts | top gender; + os> source=accounts | top_approx gender; fetched rows / total rows = 2/2 +----------+ | gender | @@ -33,7 +35,7 @@ The example finds most common gender of all the accounts. PPL query: - os> source=accounts | top 1 gender; + os> source=accounts | top_approx 1 gender; fetched rows / total rows = 1/1 +----------+ | gender | @@ -48,6 +50,7 @@ The example finds most common age of all the accounts group by gender. PPL query: os> source=accounts | top 1 age by gender; + os> source=accounts | top_approx 1 age by gender; fetched rows / total rows = 2/2 +----------+-------+ | gender | age | diff --git a/docs/ppl-lang/ppl-tpch.md b/docs/ppl-lang/ppl-tpch.md new file mode 100644 index 000000000..ef5846ce0 --- /dev/null +++ b/docs/ppl-lang/ppl-tpch.md @@ -0,0 +1,102 @@ +## TPC-H Benchmark + +TPC-H is a decision support benchmark designed to evaluate the performance of database systems in handling complex business-oriented queries and concurrent data modifications. 
The benchmark utilizes a dataset that is broadly representative of various industries, making it widely applicable. TPC-H simulates a decision support environment where large volumes of data are analyzed, intricate queries are executed, and critical business questions are answered. + +### Test PPL Queries + +TPC-H 22 test query statements: [TPCH-Query-PPL](https://github.com/opensearch-project/opensearch-spark/blob/main/integ-test/src/integration/resources/tpch) + +### Data Preparation + +#### Option 1 - from PyPi + +``` +# Create the virtual environment +python3 -m venv .venv + +# Activate the virtual environment +. .venv/bin/activate + +pip install tpch-datagen +``` + +#### Option 2 - from source + +``` +git clone https://github.com/gizmodata/tpch-datagen + +cd tpch-datagen + +# Create the virtual environment +python3 -m venv .venv + +# Activate the virtual environment +. .venv/bin/activate + +# Upgrade pip, setuptools, and wheel +pip install --upgrade pip setuptools wheel + +# Install TPC-H Datagen - in editable mode with client and dev dependencies +pip install --editable .[dev] +``` + +#### Usage + +Here are the options for the tpch-datagen command: +``` +tpch-datagen --help +Usage: tpch-datagen [OPTIONS] + +Options: + --version / --no-version Prints the TPC-H Datagen package version and + exits. [required] + --scale-factor INTEGER The TPC-H Scale Factor to use for data + generation. + --data-directory TEXT The target output data directory to put the + files into [default: data; required] + --work-directory TEXT The work directory to use for data + generation. [default: /tmp; required] + --overwrite / --no-overwrite Can we overwrite the target directory if it + already exists... [default: no-overwrite; + required] + --num-chunks INTEGER The number of chunks that will be generated + - more chunks equals smaller memory + requirements, but more files generated. + [default: 10; required] + --num-processes INTEGER The maximum number of processes for the + multi-processing pool to use for data + generation. [default: 10; required] + --duckdb-threads INTEGER The number of DuckDB threads to use for data + generation (within each job process). + [default: 1; required] + --per-thread-output / --no-per-thread-output + Controls whether to write the output to a + single file or multiple files (for each + process). [default: per-thread-output; + required] + --compression-method [none|snappy|gzip|zstd] + The compression method to use for the + parquet files generated. [default: zstd; + required] + --file-size-bytes TEXT The target file size for the parquet files + generated. [default: 100m; required] + --help Show this message and exit. +``` + +### Generate 1 GB data with zstd (by default) compression + +``` +tpch-datagen --scale-factor 1 +``` + +### Generate 10 GB data with snappy compression + +``` +tpch-datagen --scale-factor 10 --compression-method snappy +``` + +### Query Test + +All TPC-H PPL Queries located in `integ-test/src/integration/resources/tpch` folder. + +To test all queries, run `org.opensearch.flint.spark.ppl.tpch.TPCHQueryITSuite`. \ No newline at end of file diff --git a/docs/ppl-lang/ppl-trendline-command.md b/docs/ppl-lang/ppl-trendline-command.md index 393a9dd59..b2be172cd 100644 --- a/docs/ppl-lang/ppl-trendline-command.md +++ b/docs/ppl-lang/ppl-trendline-command.md @@ -1,10 +1,9 @@ -## PPL trendline Command +## PPL `trendline` command -**Description** -Using ``trendline`` command to calculate moving averages of fields. 
+### Description
+Using `trendline` command to calculate moving averages of fields.
-
-### Syntax
+### Syntax - SMA (Simple Moving Average)
`TRENDLINE [sort <[+|-] sort-field>] SMA(number-of-datapoints, field) [AS alias] [SMA(number-of-datapoints, field) [AS alias]]...`
* [+|-]: optional. The plus [+] stands for ascending order and NULL/MISSING first and a minus [-] stands for descending order and NULL/MISSING last. **Default:** ascending order and NULL/MISSING first.
@@ -13,8 +12,6 @@ Using ``trendline`` command to calculate moving averages of fields.
* field: mandatory. the name of the field the moving average should be calculated for.
* alias: optional. the name of the resulting column containing the moving average.
-And the moment only the Simple Moving Average (SMA) type is supported.
-
It is calculated like
    f[i]: The value of field 'f' in the i-th data-point
@@ -23,7 +20,7 @@ It is calculated like
    SMA(t) = (1/n) * Σ(f[i]), where i = t-n+1 to t
-### Example 1: Calculate simple moving average for a timeseries of temperatures
+#### Example 1: Calculate simple moving average for a timeseries of temperatures
The example calculates the simple moving average over temperatures using two datapoints.
@@ -41,7 +38,7 @@ PPL query:
| 15| 258|2023-04-06 17:07:...| 14.5|
+-----------+---------+--------------------+----------+
-### Example 2: Calculate simple moving averages for a timeseries of temperatures with sorting
+#### Example 2: Calculate simple moving averages for a timeseries of temperatures with sorting
The example calculates two simple moving average over temperatures using two and three datapoints sorted descending by device-id.
@@ -58,3 +55,58 @@ PPL query:
| 12| 1492|2023-04-06 17:07:...| 12.5| 13.0|
| 12| 1492|2023-04-06 17:07:...| 12.0|12.333333333333334|
+-----------+---------+--------------------+------------+------------------+
+
+
+### Syntax - WMA (Weighted Moving Average)
+`TRENDLINE sort <[+|-] sort-field> WMA(number-of-datapoints, field) [AS alias] [WMA(number-of-datapoints, field) [AS alias]]...`
+
+* [+|-]: optional. The plus [+] stands for ascending order and NULL/MISSING first and a minus [-] stands for descending order and NULL/MISSING last. **Default:** ascending order and NULL/MISSING first.
+* sort-field: mandatory. this field specifies the ordering of data points when calculating the nth_value aggregation.
+* number-of-datapoints: mandatory. number of datapoints to calculate the moving average (must be greater than zero).
+* field: mandatory. the name of the field the moving average should be calculated for.
+* alias: optional. the name of the resulting column containing the moving average.
+
+It is calculated like
+
+    f[i]: The value of field 'f' in the i-th data point
+    n: The number of data points in the moving window (period)
+    t: The current time index
+    w[i]: The weight assigned to the i-th data point, typically increasing for more recent points
+
+    WMA(t) = ( Σ from i=t−n+1 to t of (w[i] * f[i]) ) / ( Σ from i=t−n+1 to t of w[i] )
+
+#### Example 1: Calculate weighted moving average for a timeseries of temperatures
+
+The example calculates the weighted moving average over temperatures using two datapoints.
+
+PPL query:
+
+    os> source=t | trendline sort timestamp wma(2, temperature) as temp_trend;
+    fetched rows / total rows = 5/5
+    +-----------+---------+--------------------+----------+
+    |temperature|device-id| timestamp|temp_trend|
+    +-----------+---------+--------------------+----------+
+    | 12| 1492|2023-04-06 17:07:...| NULL|
+    | 12| 1492|2023-04-06 17:07:...| 12.0|
+    | 13| 256|2023-04-06 17:07:...| 12.6|
+    | 14| 257|2023-04-06 17:07:...| 13.6|
+    | 15| 258|2023-04-06 17:07:...| 14.6|
+    +-----------+---------+--------------------+----------+
+
+#### Example 2: Calculate weighted moving averages for a timeseries of temperatures with sorting
+
+The example calculates two weighted moving averages over temperatures using two and three datapoints sorted descending by device-id.
+
+PPL query:
+
+    os> source=t | trendline sort - device-id wma(2, temperature) as temp_trend_2 wma(3, temperature) as temp_trend_3;
+    fetched rows / total rows = 5/5
+    +-----------+---------+--------------------+------------+------------------+
+    |temperature|device-id| timestamp|temp_trend_2| temp_trend_3|
+    +-----------+---------+--------------------+------------+------------------+
+    | 15| 258|2023-04-06 17:07:...| NULL| NULL|
+    | 14| 257|2023-04-06 17:07:...| 14.3| NULL|
+    | 13| 256|2023-04-06 17:07:...| 13.3| 13.6|
+    | 12| 1492|2023-04-06 17:07:...| 12.3| 12.6|
+    | 12| 1492|2023-04-06 17:07:...| 12.0| 12.16|
+    +-----------+---------+--------------------+------------+------------------+
diff --git a/docs/ppl-lang/ppl-where-command.md b/docs/ppl-lang/ppl-where-command.md
index c954623c3..ec676ab62 100644
--- a/docs/ppl-lang/ppl-where-command.md
+++ b/docs/ppl-lang/ppl-where-command.md
@@ -1,4 +1,4 @@
-## PPL where Command
+## PPL `where` command
### Description
The ``where`` command bool-expression to filter the search result. The ``where`` command only return the result when bool-expression evaluated to true.
@@ -27,15 +27,15 @@ PPL query: ### Additional Examples #### **Filters With Logical Conditions** -``` -- `source = table | where c = 'test' AND a = 1 | fields a,b,c` -- `source = table | where c != 'test' OR a > 1 | fields a,b,c | head 1` -- `source = table | where c = 'test' NOT a > 1 | fields a,b,c` - `source = table | where a = 1 | fields a,b,c` - `source = table | where a >= 1 | fields a,b,c` - `source = table | where a < 1 | fields a,b,c` - `source = table | where b != 'test' | fields a,b,c` - `source = table | where c = 'test' | fields a,b,c | head 3` +- `source = table | where c = 'test' AND a = 1 | fields a,b,c` +- `source = table | where c != 'test' OR a > 1 | fields a,b,c` +- `source = table | where (b > 1 OR a > 1) AND c != 'test' | fields a,b,c` +- `source = table | where c = 'test' NOT a > 1 | fields a,b,c` - Note: "AND" is optional - `source = table | where ispresent(b)` - `source = table | where isnull(coalesce(a, b)) | fields a,b,c | head 3` - `source = table | where isempty(a)` @@ -45,7 +45,6 @@ PPL query: - `source = table | where b not between '2024-09-10' and '2025-09-10'` - Note: This returns b >= '2024-09-10' and b <= '2025-09-10' - `source = table | where cidrmatch(ip, '192.169.1.0/24')` - `source = table | where cidrmatch(ipv6, '2003:db8::/32')` - - `source = table | eval status_category = case(a >= 200 AND a < 300, 'Success', a >= 300 AND a < 400, 'Redirection', @@ -57,10 +56,8 @@ PPL query: a >= 400 AND a < 500, 'Client Error', a >= 500, 'Server Error' else 'Incorrect HTTP status code' - ) = 'Incorrect HTTP status code' - + ) = 'Incorrect HTTP status code'` - `source = table | eval factor = case(a > 15, a - 14, isnull(b), a - 7, a < 3, a + 1 else 1) | where case(factor = 2, 'even', factor = 4, 'even', factor = 6, 'even', factor = 8, 'even' else 'odd') = 'even' | stats count() by factor` -``` \ No newline at end of file diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/exception/UnrecoverableException.scala b/flint-commons/src/main/scala/org/apache/spark/sql/exception/UnrecoverableException.scala new file mode 100644 index 000000000..c23178f00 --- /dev/null +++ b/flint-commons/src/main/scala/org/apache/spark/sql/exception/UnrecoverableException.scala @@ -0,0 +1,25 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql.exception + +/** + * Represents an unrecoverable exception in session management and statement execution. This + * exception is used for errors that cannot be handled or recovered from. 
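+ *
+ * Illustrative usage (example message and cause only):
+ * {{{
+ *   throw UnrecoverableException("failed to update session state", new IllegalStateException("state conflict"))
+ * }}}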
+ */ +class UnrecoverableException private (message: String, cause: Throwable) + extends RuntimeException(message, cause) { + + def this(cause: Throwable) = + this(cause.getMessage, cause) +} + +object UnrecoverableException { + def apply(cause: Throwable): UnrecoverableException = + new UnrecoverableException(cause) + + def apply(message: String, cause: Throwable): UnrecoverableException = + new UnrecoverableException(message, cause) +} diff --git a/flint-commons/src/main/scala/org/opensearch/flint/common/FlintVersion.scala b/flint-commons/src/main/scala/org/opensearch/flint/common/FlintVersion.scala index 1203ea7ef..53574b770 100644 --- a/flint-commons/src/main/scala/org/opensearch/flint/common/FlintVersion.scala +++ b/flint-commons/src/main/scala/org/opensearch/flint/common/FlintVersion.scala @@ -20,6 +20,7 @@ object FlintVersion { val V_0_4_0: FlintVersion = FlintVersion("0.4.0") val V_0_5_0: FlintVersion = FlintVersion("0.5.0") val V_0_6_0: FlintVersion = FlintVersion("0.6.0") + val V_0_7_0: FlintVersion = FlintVersion("0.7.0") - def current(): FlintVersion = V_0_6_0 + def current(): FlintVersion = V_0_7_0 } diff --git a/flint-commons/src/main/scala/org/opensearch/flint/common/model/InteractiveSession.scala b/flint-commons/src/main/scala/org/opensearch/flint/common/model/InteractiveSession.scala index 915d5e229..16c9747d9 100644 --- a/flint-commons/src/main/scala/org/opensearch/flint/common/model/InteractiveSession.scala +++ b/flint-commons/src/main/scala/org/opensearch/flint/common/model/InteractiveSession.scala @@ -52,7 +52,7 @@ class InteractiveSession( val lastUpdateTime: Long, val jobStartTime: Long = 0, val excludedJobIds: Seq[String] = Seq.empty[String], - val error: Option[String] = None, + var error: Option[String] = None, sessionContext: Map[String, Any] = Map.empty[String, Any]) extends ContextualDataStore with Logging { @@ -72,7 +72,7 @@ class InteractiveSession( val excludedJobIdsStr = excludedJobIds.mkString("[", ", ", "]") val errorStr = error.getOrElse("None") // Does not include context, which could contain sensitive information. - s"FlintInstance(applicationId=$applicationId, jobId=$jobId, sessionId=$sessionId, state=$state, " + + s"InteractiveSession(applicationId=$applicationId, jobId=$jobId, sessionId=$sessionId, state=$state, " + s"lastUpdateTime=$lastUpdateTime, jobStartTime=$jobStartTime, excludedJobIds=$excludedJobIdsStr, error=$errorStr)" } } diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java index ef4d01652..79e70b8c2 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java @@ -70,6 +70,11 @@ public final class MetricConstants { */ public static final String REPL_PROCESSING_TIME_METRIC = "session.processingTime"; + /** + * Metric name for counting the number of queries executed per session. + */ + public static final String REPL_QUERY_COUNT_METRIC = "session.query.count"; + /** * Prefix for metrics related to the request metadata read operations. 
*/ @@ -135,6 +140,17 @@ public final class MetricConstants { */ public static final String QUERY_EXECUTION_TIME_METRIC = "query.execution.processingTime"; + /** + * Metric for query count of each query type (DROP/VACUUM/ALTER/REFRESH/CREATE INDEX) + */ + public static final String QUERY_DROP_COUNT_METRIC = "query.drop.count"; + public static final String QUERY_VACUUM_COUNT_METRIC = "query.vacuum.count"; + public static final String QUERY_ALTER_COUNT_METRIC = "query.alter.count"; + public static final String QUERY_REFRESH_COUNT_METRIC = "query.refresh.count"; + public static final String QUERY_CREATE_INDEX_COUNT_METRIC = "query.createIndex.count"; + public static final String QUERY_CREATE_INDEX_AUTO_REFRESH_COUNT_METRIC = "query.createIndex.autoRefresh.count"; + public static final String QUERY_CREATE_INDEX_MANUAL_REFRESH_COUNT_METRIC = "query.createIndex.manualRefresh.count"; + /** * Metric for tracking the total bytes read from input */ @@ -155,6 +171,21 @@ public final class MetricConstants { */ public static final String OUTPUT_TOTAL_RECORDS_WRITTEN = "output.totalRecordsWritten.count"; + /** + * Metric group related to skipping indices, such as create success and failure + */ + public static final String CREATE_SKIPPING_INDICES = "query.execution.index.skipping"; + + /** + * Metric group related to covering indices, such as create success and failure + */ + public static final String CREATE_COVERING_INDICES = "query.execution.index.covering"; + + /** + * Metric group related to materialized view indices, such as create success and failure + */ + public static final String CREATE_MV_INDICES = "query.execution.index.mv"; + /** * Metric for tracking the latency of checkpoint deletion */ @@ -175,6 +206,16 @@ public final class MetricConstants { */ public static final String INITIAL_CONDITION_CHECK_FAILED_PREFIX = "initialConditionCheck.failed."; + /** + * Metric for tracking the JVM GC time per task + */ + public static final String TASK_JVM_GC_TIME_METRIC = "task.jvmGCTime.count"; + + /** + * Metric for tracking the total JVM GC time for query + */ + public static final String TOTAL_JVM_GC_TIME_METRIC = "query.totalJvmGCTime.count"; + private MetricConstants() { // Private constructor to prevent instantiation } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/metrics/ReadWriteBytesSparkListener.scala b/flint-core/src/main/scala/org/opensearch/flint/core/metrics/MetricsSparkListener.scala similarity index 74% rename from flint-core/src/main/scala/org/opensearch/flint/core/metrics/ReadWriteBytesSparkListener.scala rename to flint-core/src/main/scala/org/opensearch/flint/core/metrics/MetricsSparkListener.scala index bfafd3eb3..2ee941260 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/metrics/ReadWriteBytesSparkListener.scala +++ b/flint-core/src/main/scala/org/opensearch/flint/core/metrics/MetricsSparkListener.scala @@ -6,17 +6,18 @@ package org.opensearch.flint.core.metrics import org.apache.spark.internal.Logging -import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} +import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorMetricsUpdate, SparkListenerTaskEnd} import org.apache.spark.sql.SparkSession /** - * Collect and emit bytesRead/Written and recordsRead/Written metrics + * Collect and emit metrics by listening spark events */ -class ReadWriteBytesSparkListener extends SparkListener with Logging { +class MetricsSparkListener extends SparkListener with Logging { var bytesRead: Long = 0 var recordsRead: Long = 0 var 
bytesWritten: Long = 0 var recordsWritten: Long = 0 + var totalJvmGcTime: Long = 0 override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { val inputMetrics = taskEnd.taskMetrics.inputMetrics @@ -31,21 +32,28 @@ class ReadWriteBytesSparkListener extends SparkListener with Logging { recordsRead += inputMetrics.recordsRead bytesWritten += outputMetrics.bytesWritten recordsWritten += outputMetrics.recordsWritten + totalJvmGcTime += taskEnd.taskMetrics.jvmGCTime + + MetricsUtil.addHistoricGauge( + MetricConstants.TASK_JVM_GC_TIME_METRIC, + taskEnd.taskMetrics.jvmGCTime) } def emitMetrics(): Unit = { logInfo(s"Input: totalBytesRead=${bytesRead}, totalRecordsRead=${recordsRead}") logInfo(s"Output: totalBytesWritten=${bytesWritten}, totalRecordsWritten=${recordsWritten}") + logInfo(s"totalJvmGcTime=${totalJvmGcTime}") MetricsUtil.addHistoricGauge(MetricConstants.INPUT_TOTAL_BYTES_READ, bytesRead) MetricsUtil.addHistoricGauge(MetricConstants.INPUT_TOTAL_RECORDS_READ, recordsRead) MetricsUtil.addHistoricGauge(MetricConstants.OUTPUT_TOTAL_BYTES_WRITTEN, bytesWritten) MetricsUtil.addHistoricGauge(MetricConstants.OUTPUT_TOTAL_RECORDS_WRITTEN, recordsWritten) + MetricsUtil.addHistoricGauge(MetricConstants.TOTAL_JVM_GC_TIME_METRIC, totalJvmGcTime) } } -object ReadWriteBytesSparkListener { +object MetricsSparkListener { def withMetrics[T](spark: SparkSession, lambda: () => T): T = { - val listener = new ReadWriteBytesSparkListener() + val listener = new MetricsSparkListener() spark.sparkContext.addSparkListener(listener) val result = lambda() diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java index da22e3751..2bc097bba 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java @@ -16,6 +16,8 @@ import org.opensearch.flint.core.FlintClient; import org.opensearch.flint.core.FlintOptions; import org.opensearch.flint.core.IRestHighLevelClient; +import org.opensearch.flint.core.metrics.MetricConstants; +import org.opensearch.flint.core.metrics.MetricsUtil; import scala.Option; import java.io.IOException; @@ -40,7 +42,13 @@ public FlintOpenSearchClient(FlintOptions options) { @Override public void createIndex(String indexName, FlintMetadata metadata) { LOG.info("Creating Flint index " + indexName + " with metadata " + metadata); - createIndex(indexName, FlintOpenSearchIndexMetadataService.serialize(metadata, false), metadata.indexSettings()); + try { + createIndex(indexName, FlintOpenSearchIndexMetadataService.serialize(metadata, false), metadata.indexSettings()); + emitIndexCreationSuccessMetric(metadata.kind()); + } catch (IllegalStateException ex) { + emitIndexCreationFailureMetric(metadata.kind()); + throw ex; + } } protected void createIndex(String indexName, String mapping, Option settings) { @@ -122,4 +130,28 @@ public IRestHighLevelClient createClient() { private String sanitizeIndexName(String indexName) { return OpenSearchClientUtils.sanitizeIndexName(indexName); } + + private void emitIndexCreationSuccessMetric(String indexKind) { + emitIndexCreationMetric(indexKind, "success"); + } + + private void emitIndexCreationFailureMetric(String indexKind) { + emitIndexCreationMetric(indexKind, "failed"); + } + + private void emitIndexCreationMetric(String indexKind, String status) { + switch (indexKind) { + case "skipping": + 
MetricsUtil.addHistoricGauge(String.format("%s.%s.count", MetricConstants.CREATE_SKIPPING_INDICES, status), 1); + break; + case "covering": + MetricsUtil.addHistoricGauge(String.format("%s.%s.count", MetricConstants.CREATE_COVERING_INDICES, status), 1); + break; + case "mv": + MetricsUtil.addHistoricGauge(String.format("%s.%s.count", MetricConstants.CREATE_MV_INDICES, status), 1); + break; + default: + break; + } + } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala index ca659550d..3a12b63fe 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala @@ -14,7 +14,7 @@ import org.opensearch.flint.common.metadata.FlintMetadata import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.COVERING_INDEX_TYPE import org.opensearch.flint.spark.mv.FlintSparkMaterializedView -import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.{getSourceTablesFromMetadata, MV_INDEX_TYPE} import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.SKIPPING_INDEX_TYPE import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind @@ -141,9 +141,9 @@ object FlintSparkIndexFactory extends Logging { } private def getMvSourceTables(spark: SparkSession, metadata: FlintMetadata): Array[String] = { - val sourceTables = getArrayString(metadata.properties, "sourceTables") + val sourceTables = getSourceTablesFromMetadata(metadata) if (sourceTables.isEmpty) { - FlintSparkMaterializedView.extractSourceTableNames(spark, metadata.source) + FlintSparkMaterializedView.extractSourceTablesFromQuery(spark, metadata.source) } else { sourceTables } @@ -161,12 +161,4 @@ object FlintSparkIndexFactory extends Logging { Some(value.asInstanceOf[String]) } } - - private def getArrayString(map: java.util.Map[String, AnyRef], key: String): Array[String] = { - map.get(key) match { - case list: java.util.ArrayList[_] => - list.toArray.map(_.toString) - case _ => Array.empty[String] - } - } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCache.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCache.scala index e1c0f318c..86267c881 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCache.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCache.scala @@ -10,7 +10,7 @@ import scala.collection.JavaConverters.mapAsScalaMapConverter import org.opensearch.flint.common.metadata.FlintMetadata import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry import org.opensearch.flint.spark.FlintSparkIndexOptions -import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.{getSourceTablesFromMetadata, MV_INDEX_TYPE} import org.opensearch.flint.spark.scheduler.util.IntervalSchedulerParser /** @@ -61,12 +61,7 @@ object FlintMetadataCache { None } val sourceTables = metadata.kind match { - case 
MV_INDEX_TYPE => - metadata.properties.get("sourceTables") match { - case list: java.util.ArrayList[_] => - list.toArray.map(_.toString) - case _ => Array.empty[String] - } + case MV_INDEX_TYPE => getSourceTablesFromMetadata(metadata) case _ => Array(metadata.source) } val lastRefreshTime: Option[Long] = metadata.latestLogEntry.flatMap { entry => diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriter.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriter.scala index 2bc373792..f6fc0ba6f 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriter.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriter.scala @@ -38,13 +38,15 @@ class FlintOpenSearchMetadataCacheWriter(options: FlintOptions) .isInstanceOf[FlintOpenSearchIndexMetadataService] override def updateMetadataCache(indexName: String, metadata: FlintMetadata): Unit = { - logInfo(s"Updating metadata cache for $indexName"); + logInfo(s"Updating metadata cache for $indexName with $metadata"); val osIndexName = OpenSearchClientUtils.sanitizeIndexName(indexName) var client: IRestHighLevelClient = null try { client = OpenSearchClientUtils.createClient(options) val request = new PutMappingRequest(osIndexName) - request.source(serialize(metadata), XContentType.JSON) + val serialized = serialize(metadata) + logInfo(s"Serialized: $serialized") + request.source(serialized, XContentType.JSON) client.updateIndexMapping(request, RequestOptions.DEFAULT) } catch { case e: Exception => diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala index aecfc99df..d5c450e7e 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala @@ -7,7 +7,7 @@ package org.opensearch.flint.spark.mv import java.util.Locale -import scala.collection.JavaConverters.mapAsJavaMapConverter +import scala.collection.JavaConverters._ import scala.collection.convert.ImplicitConversions.`map AsScala` import org.opensearch.flint.common.metadata.FlintMetadata @@ -18,6 +18,7 @@ import org.opensearch.flint.spark.FlintSparkIndexOptions.empty import org.opensearch.flint.spark.function.TumbleFunction import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.{getFlintIndexName, MV_INDEX_TYPE} +import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedRelation} @@ -64,10 +65,14 @@ case class FlintSparkMaterializedView( }.toArray val schema = generateSchema(outputSchema).asJava + // Convert Scala Array to Java ArrayList for consistency with OpenSearch JSON parsing. + // OpenSearch uses Jackson, which deserializes JSON arrays into ArrayLists. 
+ val sourceTablesProperty = new java.util.ArrayList[String](sourceTables.toSeq.asJava) + metadataBuilder(this) .name(mvName) .source(query) - .addProperty("sourceTables", sourceTables) + .addProperty("sourceTables", sourceTablesProperty) .indexedColumns(indexColumnMaps) .schema(schema) .build() @@ -133,8 +138,14 @@ case class FlintSparkMaterializedView( // Assume first aggregate item must be time column val winFunc = winFuncs.head - val timeCol = winFunc.arguments.head.asInstanceOf[Attribute] - Some(agg, timeCol) + val timeCol = winFunc.arguments.head + timeCol match { + case attr: Attribute => + Some(agg, attr) + case _ => + throw new IllegalArgumentException( + s"Tumble function only supports simple timestamp column, but found: $timeCol") + } } private def isWindowingFunction(func: UnresolvedFunction): Boolean = { @@ -147,7 +158,7 @@ case class FlintSparkMaterializedView( } } -object FlintSparkMaterializedView { +object FlintSparkMaterializedView extends Logging { /** MV index type name */ val MV_INDEX_TYPE = "mv" @@ -179,13 +190,40 @@ object FlintSparkMaterializedView { * @return * source table names */ - def extractSourceTableNames(spark: SparkSession, query: String): Array[String] = { - spark.sessionState.sqlParser + def extractSourceTablesFromQuery(spark: SparkSession, query: String): Array[String] = { + logInfo(s"Extracting source tables from query $query") + val sourceTables = spark.sessionState.sqlParser .parsePlan(query) .collect { case relation: UnresolvedRelation => qualifyTableName(spark, relation.tableName) } .toArray + logInfo(s"Extracted tables: [${sourceTables.mkString(", ")}]") + sourceTables + } + + /** + * Get source tables from Flint metadata properties field. + * + * @param metadata + * Flint metadata + * @return + * source table names + */ + def getSourceTablesFromMetadata(metadata: FlintMetadata): Array[String] = { + logInfo(s"Getting source tables from metadata $metadata") + val sourceTables = metadata.properties.get("sourceTables") + sourceTables match { + case list: java.util.ArrayList[_] => + logInfo(s"sourceTables is [${list.asScala.mkString(", ")}]") + list.toArray.map(_.toString) + case null => + logInfo("sourceTables property does not exist") + Array.empty[String] + case _ => + logInfo(s"sourceTables has unexpected type: ${sourceTables.getClass.getName}") + Array.empty[String] + } } /** Builder class for MV build */ @@ -217,7 +255,7 @@ object FlintSparkMaterializedView { */ def query(query: String): Builder = { this.query = query - this.sourceTables = extractSourceTableNames(flint.spark, query) + this.sourceTables = extractSourceTablesFromQuery(flint.spark, query) this } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/AutoIndexRefresh.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/AutoIndexRefresh.scala index bedeeba54..ba605d3bf 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/AutoIndexRefresh.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/AutoIndexRefresh.scala @@ -7,7 +7,7 @@ package org.opensearch.flint.spark.refresh import java.util.Collections -import org.opensearch.flint.core.metrics.ReadWriteBytesSparkListener +import org.opensearch.flint.core.metrics.MetricsSparkListener import org.opensearch.flint.spark.{FlintSparkIndex, FlintSparkIndexOptions, FlintSparkValidationHelper} import org.opensearch.flint.spark.FlintSparkIndex.{quotedTableName, StreamingRefresh} import 
org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.RefreshMode.{AUTO, RefreshMode} @@ -68,7 +68,7 @@ class AutoIndexRefresh(indexName: String, index: FlintSparkIndex) // Flint index has specialized logic and capability for incremental refresh case refresh: StreamingRefresh => logInfo("Start refreshing index in streaming style") - val job = ReadWriteBytesSparkListener.withMetrics( + val job = MetricsSparkListener.withMetrics( spark, () => refresh diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/IndexMetricHelper.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/IndexMetricHelper.scala new file mode 100644 index 000000000..45b439ff0 --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/IndexMetricHelper.scala @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.sql + +import org.opensearch.flint.core.metrics.{MetricConstants, MetricsUtil} + +trait IndexMetricHelper { + def emitCreateIndexMetric(autoRefresh: Boolean): Unit = { + MetricsUtil.incrementCounter(MetricConstants.QUERY_CREATE_INDEX_COUNT_METRIC) + if (autoRefresh) { + MetricsUtil.incrementCounter(MetricConstants.QUERY_CREATE_INDEX_AUTO_REFRESH_COUNT_METRIC) + } else { + MetricsUtil.incrementCounter(MetricConstants.QUERY_CREATE_INDEX_MANUAL_REFRESH_COUNT_METRIC) + } + } + + def emitRefreshIndexMetric(): Unit = { + MetricsUtil.incrementCounter(MetricConstants.QUERY_REFRESH_COUNT_METRIC) + } + + def emitAlterIndexMetric(): Unit = { + MetricsUtil.incrementCounter(MetricConstants.QUERY_ALTER_COUNT_METRIC) + } + + def emitDropIndexMetric(): Unit = { + MetricsUtil.incrementCounter(MetricConstants.QUERY_DROP_COUNT_METRIC) + } + + def emitVacuumIndexMetric(): Unit = { + MetricsUtil.incrementCounter(MetricConstants.QUERY_VACUUM_COUNT_METRIC) + } +} diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/covering/FlintSparkCoveringIndexAstBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/covering/FlintSparkCoveringIndexAstBuilder.scala index 4a8f9018e..35a780020 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/covering/FlintSparkCoveringIndexAstBuilder.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/covering/FlintSparkCoveringIndexAstBuilder.scala @@ -6,9 +6,10 @@ package org.opensearch.flint.spark.sql.covering import org.antlr.v4.runtime.tree.RuleNode +import org.opensearch.flint.core.metrics.{MetricConstants, MetricsUtil} import org.opensearch.flint.spark.FlintSpark import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex -import org.opensearch.flint.spark.sql.{FlintSparkSqlCommand, FlintSparkSqlExtensionsVisitor, SparkSqlAstBuilder} +import org.opensearch.flint.spark.sql.{FlintSparkSqlCommand, FlintSparkSqlExtensionsVisitor, IndexMetricHelper, SparkSqlAstBuilder} import org.opensearch.flint.spark.sql.FlintSparkSqlAstBuilder.{getFullTableName, getSqlText} import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser._ @@ -20,7 +21,9 @@ import org.apache.spark.sql.types.StringType /** * Flint Spark AST builder that builds Spark command for Flint covering index statement. 
*/ -trait FlintSparkCoveringIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[AnyRef] { +trait FlintSparkCoveringIndexAstBuilder + extends FlintSparkSqlExtensionsVisitor[AnyRef] + with IndexMetricHelper { self: SparkSqlAstBuilder => override def visitCreateCoveringIndexStatement( @@ -49,6 +52,8 @@ trait FlintSparkCoveringIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[A .options(indexOptions, indexName) .create(ignoreIfExists) + emitCreateIndexMetric(indexOptions.autoRefresh()) + // Trigger auto refresh if enabled and not using external scheduler if (indexOptions .autoRefresh() && !indexBuilder.isExternalSchedulerEnabled()) { @@ -62,6 +67,7 @@ trait FlintSparkCoveringIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[A override def visitRefreshCoveringIndexStatement( ctx: RefreshCoveringIndexStatementContext): Command = { FlintSparkSqlCommand() { flint => + MetricsUtil.incrementCounter(MetricConstants.QUERY_REFRESH_COUNT_METRIC) val flintIndexName = getFlintIndexName(flint, ctx.indexName, ctx.tableName) flint.refreshIndex(flintIndexName) Seq.empty @@ -107,6 +113,7 @@ trait FlintSparkCoveringIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[A override def visitAlterCoveringIndexStatement( ctx: AlterCoveringIndexStatementContext): Command = { FlintSparkSqlCommand() { flint => + emitAlterIndexMetric() val indexName = getFlintIndexName(flint, ctx.indexName, ctx.tableName) val indexOptions = visitPropertyList(ctx.propertyList()) val index = flint @@ -121,6 +128,7 @@ trait FlintSparkCoveringIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[A override def visitDropCoveringIndexStatement( ctx: DropCoveringIndexStatementContext): Command = { FlintSparkSqlCommand() { flint => + emitDropIndexMetric() val flintIndexName = getFlintIndexName(flint, ctx.indexName, ctx.tableName) flint.deleteIndex(flintIndexName) Seq.empty @@ -130,6 +138,7 @@ trait FlintSparkCoveringIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[A override def visitVacuumCoveringIndexStatement( ctx: VacuumCoveringIndexStatementContext): Command = { FlintSparkSqlCommand() { flint => + emitVacuumIndexMetric() val flintIndexName = getFlintIndexName(flint, ctx.indexName, ctx.tableName) flint.vacuumIndex(flintIndexName) Seq.empty diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/mv/FlintSparkMaterializedViewAstBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/mv/FlintSparkMaterializedViewAstBuilder.scala index 8f3aa9917..9c8d2da0b 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/mv/FlintSparkMaterializedViewAstBuilder.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/mv/FlintSparkMaterializedViewAstBuilder.scala @@ -10,7 +10,7 @@ import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable` import org.antlr.v4.runtime.tree.RuleNode import org.opensearch.flint.spark.FlintSpark import org.opensearch.flint.spark.mv.FlintSparkMaterializedView -import org.opensearch.flint.spark.sql.{FlintSparkSqlCommand, FlintSparkSqlExtensionsVisitor, SparkSqlAstBuilder} +import org.opensearch.flint.spark.sql.{FlintSparkSqlCommand, FlintSparkSqlExtensionsVisitor, IndexMetricHelper, SparkSqlAstBuilder} import org.opensearch.flint.spark.sql.FlintSparkSqlAstBuilder.{getFullTableName, getSqlText, IndexBelongsTo} import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser._ @@ -22,7 +22,9 @@ import org.apache.spark.sql.types.StringType /** * Flint Spark AST builder that builds 
Spark command for Flint materialized view statement. */ -trait FlintSparkMaterializedViewAstBuilder extends FlintSparkSqlExtensionsVisitor[AnyRef] { +trait FlintSparkMaterializedViewAstBuilder + extends FlintSparkSqlExtensionsVisitor[AnyRef] + with IndexMetricHelper { self: SparkSqlAstBuilder => override def visitCreateMaterializedViewStatement( @@ -40,6 +42,8 @@ trait FlintSparkMaterializedViewAstBuilder extends FlintSparkSqlExtensionsVisito val indexOptions = visitPropertyList(ctx.propertyList()) val flintIndexName = getFlintIndexName(flint, ctx.mvName) + emitCreateIndexMetric(indexOptions.autoRefresh()) + mvBuilder .options(indexOptions, flintIndexName) .create(ignoreIfExists) @@ -56,6 +60,7 @@ trait FlintSparkMaterializedViewAstBuilder extends FlintSparkSqlExtensionsVisito override def visitRefreshMaterializedViewStatement( ctx: RefreshMaterializedViewStatementContext): Command = { FlintSparkSqlCommand() { flint => + emitRefreshIndexMetric() val flintIndexName = getFlintIndexName(flint, ctx.mvName) flint.refreshIndex(flintIndexName) Seq.empty @@ -106,6 +111,7 @@ trait FlintSparkMaterializedViewAstBuilder extends FlintSparkSqlExtensionsVisito override def visitAlterMaterializedViewStatement( ctx: AlterMaterializedViewStatementContext): Command = { FlintSparkSqlCommand() { flint => + emitAlterIndexMetric() val indexName = getFlintIndexName(flint, ctx.mvName) val indexOptions = visitPropertyList(ctx.propertyList()) val index = flint @@ -120,6 +126,7 @@ trait FlintSparkMaterializedViewAstBuilder extends FlintSparkSqlExtensionsVisito override def visitDropMaterializedViewStatement( ctx: DropMaterializedViewStatementContext): Command = { FlintSparkSqlCommand() { flint => + emitDropIndexMetric() flint.deleteIndex(getFlintIndexName(flint, ctx.mvName)) Seq.empty } @@ -128,6 +135,7 @@ trait FlintSparkMaterializedViewAstBuilder extends FlintSparkSqlExtensionsVisito override def visitVacuumMaterializedViewStatement( ctx: VacuumMaterializedViewStatementContext): Command = { FlintSparkSqlCommand() { flint => + emitVacuumIndexMetric() flint.vacuumIndex(getFlintIndexName(flint, ctx.mvName)) Seq.empty } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/skipping/FlintSparkSkippingIndexAstBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/skipping/FlintSparkSkippingIndexAstBuilder.scala index 67f6bc3d4..9ed06f6b0 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/skipping/FlintSparkSkippingIndexAstBuilder.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/skipping/FlintSparkSkippingIndexAstBuilder.scala @@ -14,7 +14,7 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{BLOOM_FILTER, MIN_MAX, PARTITION, VALUE_SET} import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy.VALUE_SET_MAX_SIZE_KEY -import org.opensearch.flint.spark.sql.{FlintSparkSqlCommand, FlintSparkSqlExtensionsVisitor, SparkSqlAstBuilder} +import org.opensearch.flint.spark.sql.{FlintSparkSqlCommand, FlintSparkSqlExtensionsVisitor, IndexMetricHelper, SparkSqlAstBuilder} import org.opensearch.flint.spark.sql.FlintSparkSqlAstBuilder.{getFullTableName, getSqlText} import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser._ @@ -26,7 +26,9 @@ import org.apache.spark.sql.types.StringType /** * Flint Spark 
AST builder that builds Spark command for Flint skipping index statement. */ -trait FlintSparkSkippingIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[AnyRef] { +trait FlintSparkSkippingIndexAstBuilder + extends FlintSparkSqlExtensionsVisitor[AnyRef] + with IndexMetricHelper { self: SparkSqlAstBuilder => override def visitCreateSkippingIndexStatement( @@ -73,6 +75,8 @@ trait FlintSparkSkippingIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[A val indexOptions = visitPropertyList(ctx.propertyList()) val indexName = getSkippingIndexName(flint, ctx.tableName) + emitCreateIndexMetric(indexOptions.autoRefresh()) + indexBuilder .options(indexOptions, indexName) .create(ignoreIfExists) @@ -88,6 +92,7 @@ trait FlintSparkSkippingIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[A override def visitRefreshSkippingIndexStatement( ctx: RefreshSkippingIndexStatementContext): Command = FlintSparkSqlCommand() { flint => + emitRefreshIndexMetric() val indexName = getSkippingIndexName(flint, ctx.tableName) flint.refreshIndex(indexName) Seq.empty @@ -115,6 +120,7 @@ trait FlintSparkSkippingIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[A override def visitAlterSkippingIndexStatement( ctx: AlterSkippingIndexStatementContext): Command = { FlintSparkSqlCommand() { flint => + emitAlterIndexMetric() val indexName = getSkippingIndexName(flint, ctx.tableName) val indexOptions = visitPropertyList(ctx.propertyList()) val index = flint @@ -142,6 +148,7 @@ trait FlintSparkSkippingIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[A override def visitDropSkippingIndexStatement(ctx: DropSkippingIndexStatementContext): Command = FlintSparkSqlCommand() { flint => + emitDropIndexMetric() val indexName = getSkippingIndexName(flint, ctx.tableName) flint.deleteIndex(indexName) Seq.empty @@ -150,6 +157,7 @@ trait FlintSparkSkippingIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[A override def visitVacuumSkippingIndexStatement( ctx: VacuumSkippingIndexStatementContext): Command = { FlintSparkSqlCommand() { flint => + emitVacuumIndexMetric() val indexName = getSkippingIndexName(flint, ctx.tableName) flint.vacuumIndex(indexName) Seq.empty diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala index 1c9a9e83c..78d2eb09e 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala @@ -64,7 +64,9 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { metadata.kind shouldBe MV_INDEX_TYPE metadata.source shouldBe "SELECT 1" metadata.properties should contain key "sourceTables" - metadata.properties.get("sourceTables").asInstanceOf[Array[String]] should have size 0 + metadata.properties + .get("sourceTables") + .asInstanceOf[java.util.ArrayList[String]] should have size 0 metadata.indexedColumns shouldBe Array( Map("columnName" -> "test_col", "columnType" -> "integer").asJava) metadata.schema shouldBe Map("test_col" -> Map("type" -> "integer").asJava).asJava diff --git a/integ-test/script/README.md b/integ-test/script/README.md new file mode 100644 index 000000000..7ce0c6886 --- /dev/null +++ b/integ-test/script/README.md @@ -0,0 +1,158 @@ +# Sanity Test Script + +### Description +This Python script executes test queries from a CSV file using an 
asynchronous query API and generates comprehensive test reports. + +The script produces two report types: +1. An Excel report with detailed test information for each query +2. A JSON report containing both a test result overview and query-specific details + +Apart from these basic features, it also provides some advanced functionality: +1. Concurrent query execution (note: the async query service has session limits, so keep the number of worker threads moderate even though the script reuses session IDs) +2. Configurable query timeout with periodic status checks and automatic cancellation if the timeout is exceeded +3. Flexible row selection, by specifying the start and end rows of the input CSV file +4. Expected status validation when expected_status is present in the CSV +5. Ability to generate partial reports if testing is interrupted + +### Usage +To use this script, you need to have Python **3.6** or higher installed. It also requires the following Python libraries: +```shell +pip install requests pandas openpyxl +``` + +After installing the required libraries, you can run the script with the following command line parameters in your shell: +```shell +python SanityTest.py --base-url ${URL_ADDRESS} --username *** --password *** --datasource ${DATASOURCE_NAME} --input-csv test_cases.csv --output-file test_report --max-workers 2 --check-interval 10 --timeout 600 +``` +Replace the placeholders with your actual values: URL_ADDRESS, DATASOURCE_NAME, and the USERNAME and PASSWORD used to authenticate to your endpoint. + +For more details on the command line parameters, see the help output: +```shell +python SanityTest.py --help + +usage: SanityTest.py [-h] --base-url BASE_URL --username USERNAME --password PASSWORD --datasource DATASOURCE --input-csv INPUT_CSV + --output-file OUTPUT_FILE [--max-workers MAX_WORKERS] [--check-interval CHECK_INTERVAL] [--timeout TIMEOUT] + [--start-row START_ROW] [--end-row END_ROW] [--log-level LOG_LEVEL] + +Run tests from a CSV file and generate a report. + +options: + -h, --help show this help message and exit + --base-url BASE_URL Base URL of the service + --username USERNAME Username for authentication + --password PASSWORD Password for authentication + --datasource DATASOURCE + Datasource name + --input-csv INPUT_CSV + Path to the CSV file containing test queries + --output-file OUTPUT_FILE + Path to the output report file + --max-workers MAX_WORKERS + optional, Maximum number of worker threads (default: 2) + --check-interval CHECK_INTERVAL + optional, Check interval in seconds (default: 5) + --timeout TIMEOUT optional, Timeout in seconds (default: 600) + --start-row START_ROW + optional, The start row of the query to run, start from 1 + --end-row END_ROW optional, The end row of the query to run, not included + --log-level LOG_LEVEL + optional, Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL, default: INFO) +``` + +### Input CSV File +As described above, the input CSV file must contain at least a `query` column. It also supports an optional `expected_status` column: the script checks the actual status against the expected status and adds a `check_status` column with the result -- TRUE means the status check passed; FALSE means it failed. + +We also provide a sample input CSV file, `test_cases.csv`, for reference. It includes all the sanity test cases we currently have for Flint.
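+ +For example, the first few rows of the bundled `test_cases.csv` look like this: +```csv +query,expected_status +describe myglue_test.default.http_logs,FAILED +describe `myglue_test`.`default`.`http_logs`,FAILED +"source = myglue_test.default.http_logs | dedup 1 status | fields @timestamp, clientip, status | head 10",SUCCESS +```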
+ +**TODO**: document the prerequisite data for the test cases and the ingestion process + +### Report Explanation +The generated report consists of two files: + +#### Excel Report +The Excel report provides the test result details for each query, including the query name (currently the sequence number in the input CSV file), the query itself, the expected status, the actual status, and whether the actual status satisfies the expected status. + +It provides an error message if the query execution failed; otherwise it provides the query execution result and leaves the error column empty. + +It also provides the query_id, session_id and start/end time for each query, which can be used to debug the query execution in Flint. + +An example of the Excel report: + +| query_name | query | expected_status | status | check_status | error | result | Duration (s) | query_id | session_id | Start Time | End Time | +|------------|-------|-----------------|--------|--------------|-------|--------|--------------|----------|------------|------------|----------| +| 1 | describe myglue_test.default.http_logs | SUCCESS | SUCCESS | TRUE | | {'status': 'SUCCESS', 'schema': [{...}, ...], 'datarows': [[...], ...], 'total': 31, 'size': 31} | 37.51 | SHFEVWxDNnZjem15Z2x1ZV90ZXN0 | RkgzZm0xNlA5MG15Z2x1ZV90ZXN0 | 2024-11-07 13:34:10 | 2024-11-07 13:34:47 | +| 2 | source = myglue_test.default.http_logs \| dedup status CONSECUTIVE=true | SUCCESS | FAILED | FALSE | {"Message":"Fail to run query. Cause: Consecutive deduplication is not supported"} | | 39.53 | dVNlaVVxOFZrZW15Z2x1ZV90ZXN0 | ZGU2MllVYmI4dG15Z2x1ZV90ZXN0 | 2024-11-07 13:34:10 | 2024-11-07 13:34:49 | +| 3 | source = myglue_test.default.http_logs \| eval res = json_keys(json('{"account_number":1,"balance":39225,"age":32,"gender":"M"}')) \| head 1 \| fields res | SUCCESS | SUCCESS | TRUE | | {'status': 'SUCCESS', 'schema': [{'name': 'res', 'type': 'array'}], 'datarows': [[['account_number', 'balance', 'age', 'gender']]], 'total': 1, 'size': 1} | 12.77 | WHQxaXlVSGtGUm15Z2x1ZV90ZXN0 | RkgzZm0xNlA5MG15Z2x1ZV90ZXN0 | 2024-11-07 13:34:47 | 2024-11-07 13:38:45 | +| ... | ... | ... | ... | ... | | | ... | ... | ... | ... | ... | + + +#### JSON Report +The JSON report provides the same information as the Excel report, but in JSON format. Additionally, it includes a statistical summary of the test results at the beginning of the report.
+ +An example of JSON report: +```json +{ + "summary": { + "total_queries": 115, + "successful_queries": 110, + "failed_queries": 3, + "submit_failed_queries": 0, + "timeout_queries": 2, + "execution_time": 16793.223807 + }, + "detailed_results": [ + { + "query_name": 1, + "query": "source = myglue_test.default.http_logs | stats avg(size)", + "query_id": "eFZmTlpTa3EyTW15Z2x1ZV90ZXN0", + "session_id": "bFJDMWxzb2NVUm15Z2x1ZV90ZXN0", + "status": "SUCCESS", + "error": "", + "result": { + "status": "SUCCESS", + "schema": [ + { + "name": "avg(size)", + "type": "double" + } + ], + "datarows": [ + [ + 4654.305710913499 + ] + ], + "total": 1, + "size": 1 + }, + "duration": 170.621145, + "start_time": "2024-11-07 14:56:13.869226", + "end_time": "2024-11-07 14:59:04.490371" + }, + { + "query_name": 2, + "query": "source = myglue_test.default.http_logs | eval res = json_keys(json(\u2018{\"teacher\":\"Alice\",\"student\":[{\"name\":\"Bob\",\"rank\":1},{\"name\":\"Charlie\",\"rank\":2}]}')) | head 1 | fields res", + "query_id": "bjF4Y1VnbXdFYm15Z2x1ZV90ZXN0", + "session_id": "c3pvU1V6OW8xM215Z2x1ZV90ZXN0", + "status": "FAILED", + "error": "{\"Message\":\"Syntax error: \\n[PARSE_SYNTAX_ERROR] Syntax error at or near 'source'.(line 1, pos 0)\\n\\n== SQL ==\\nsource = myglue_test.default.http_logs | eval res = json_keys(json(\u2018{\\\"teacher\\\":\\\"Alice\\\",\\\"student\\\":[{\\\"name\\\":\\\"Bob\\\",\\\"rank\\\":1},{\\\"name\\\":\\\"Charlie\\\",\\\"rank\\\":2}]}')) | head 1 | fields res\\n^^^\\n\"}", + "result": null, + "duration": 14.051738, + "start_time": "2024-11-07 14:59:18.699335", + "end_time": "2024-11-07 14:59:32.751073" + }, + { + "query_name": 2, + "query": "source = myglue_test.default.http_logs | eval col1 = size, col2 = clientip | stats avg(col1) by col2", + "query_id": "azVyMFFORnBFRW15Z2x1ZV90ZXN0", + "session_id": "VWF0SEtrNWM3bm15Z2x1ZV90ZXN0", + "status": "TIMEOUT", + "error": "Query execution exceeded 600 seconds with last status: running", + "result": null, + "duration": 673.710946, + "start_time": "2024-11-07 14:45:00.157875", + "end_time": "2024-11-07 14:56:13.868821" + }, + ... 
+ ] +} +``` diff --git a/integ-test/script/SanityTest.py b/integ-test/script/SanityTest.py new file mode 100644 index 000000000..eb97752b4 --- /dev/null +++ b/integ-test/script/SanityTest.py @@ -0,0 +1,292 @@ +""" +Copyright OpenSearch Contributors +SPDX-License-Identifier: Apache-2.0 +""" + +import signal +import sys +import requests +import json +import csv +import time +import logging +from datetime import datetime +import pandas as pd +import argparse +from requests.auth import HTTPBasicAuth +from concurrent.futures import ThreadPoolExecutor, as_completed +import threading + +""" +Environment: python3 + +Example to use this script: + +python SanityTest.py --base-url ${URL_ADDRESS} --username *** --password *** --datasource ${DATASOURCE_NAME} --input-csv test_queries.csv --output-file test_report --max-workers 2 --check-interval 10 --timeout 600 + +The input file test_queries.csv should contain column: `query` + +For more details, please use command: + +python SanityTest.py --help + +""" + +class FlintTester: + def __init__(self, base_url, username, password, datasource, max_workers, check_interval, timeout, output_file, start_row, end_row, log_level): + self.base_url = base_url + self.auth = HTTPBasicAuth(username, password) + self.datasource = datasource + self.headers = { 'Content-Type': 'application/json' } + self.max_workers = max_workers + self.check_interval = check_interval + self.timeout = timeout + self.output_file = output_file + self.start = start_row - 1 if start_row else None + self.end = end_row - 1 if end_row else None + self.log_level = log_level + self.max_attempts = (int)(timeout / check_interval) + self.logger = self._setup_logger() + self.executor = ThreadPoolExecutor(max_workers=self.max_workers) + self.thread_local = threading.local() + self.test_results = [] + + def _setup_logger(self): + logger = logging.getLogger('FlintTester') + logger.setLevel(self.log_level) + + fh = logging.FileHandler('flint_test.log') + fh.setLevel(self.log_level) + + ch = logging.StreamHandler() + ch.setLevel(self.log_level) + + formatter = logging.Formatter( + '%(asctime)s - %(threadName)s - %(levelname)s - %(message)s' + ) + fh.setFormatter(formatter) + ch.setFormatter(formatter) + + logger.addHandler(fh) + logger.addHandler(ch) + + return logger + + + def get_session_id(self): + if not hasattr(self.thread_local, 'session_id'): + self.thread_local.session_id = "empty_session_id" + self.logger.debug(f"get session id {self.thread_local.session_id}") + return self.thread_local.session_id + + def set_session_id(self, session_id): + """Reuse the session id for the same thread""" + self.logger.debug(f"set session id {session_id}") + self.thread_local.session_id = session_id + + # Call submit API to submit the query + def submit_query(self, query, session_id="Empty"): + url = f"{self.base_url}/_plugins/_async_query" + payload = { + "datasource": self.datasource, + "lang": "ppl", + "query": query, + "sessionId": session_id + } + self.logger.debug(f"Submit query with payload: {payload}") + response_json = None + try: + response = requests.post(url, auth=self.auth, json=payload, headers=self.headers) + response_json = response.json() + response.raise_for_status() + return response_json + except Exception as e: + return {"error": f"{str(e)}, got response {response_json}"} + + # Call get API to check the query status + def get_query_result(self, query_id): + url = f"{self.base_url}/_plugins/_async_query/{query_id}" + response_json = None + try: + response = requests.get(url, auth=self.auth) + 
response_json = response.json() + response.raise_for_status() + return response_json + except Exception as e: + return {"status": "FAILED", "error": f"{str(e)}, got response {response_json}"} + + # Call delete API to cancel the query + def cancel_query(self, query_id): + url = f"{self.base_url}/_plugins/_async_query/{query_id}" + response_json = None + try: + response = requests.delete(url, auth=self.auth) + response_json = response.json() + response.raise_for_status() + self.logger.info(f"Cancelled query [{query_id}] with info {response.json()}") + return response_json + except Exception as e: + self.logger.warning(f"Cancel query [{query_id}] error: {str(e)}, got response {response_json}") + + # Run the test and return the result + def run_test(self, query, seq_id, expected_status): + self.logger.info(f"Starting test: {seq_id}, {query}") + start_time = datetime.now() + pre_session_id = self.get_session_id() + submit_result = self.submit_query(query, pre_session_id) + if "error" in submit_result: + self.logger.warning(f"Submit error: {submit_result}") + return { + "query_name": seq_id, + "query": query, + "expected_status": expected_status, + "status": "SUBMIT_FAILED", + "check_status": "SUBMIT_FAILED" == expected_status if expected_status else None, + "error": submit_result["error"], + "duration": 0, + "start_time": start_time, + "end_time": datetime.now() + } + + query_id = submit_result["queryId"] + session_id = submit_result["sessionId"] + self.logger.info(f"Submit return: {submit_result}") + if (session_id != pre_session_id): + self.logger.info(f"Update session id from {pre_session_id} to {session_id}") + self.set_session_id(session_id) + + test_result = self.check_query_status(query_id) + end_time = datetime.now() + duration = (end_time - start_time).total_seconds() + + return { + "query_name": seq_id, + "query": query, + "query_id": query_id, + "session_id": session_id, + "expected_status": expected_status, + "status": test_result["status"], + "check_status": test_result["status"] == expected_status if expected_status else None, + "error": test_result.get("error", ""), + "result": test_result if test_result["status"] == "SUCCESS" else None, + "duration": duration, + "start_time": start_time, + "end_time": end_time + } + + # Check the status of the query periodically until it is completed or failed or exceeded the timeout + def check_query_status(self, query_id): + for attempt in range(self.max_attempts): + time.sleep(self.check_interval) + result = self.get_query_result(query_id) + + if result["status"] == "FAILED" or result["status"] == "SUCCESS": + return result + + # Cancel the query if it exceeds the timeout + self.cancel_query(query_id) + return { + "status": "TIMEOUT", + "error": "Query execution exceeded " + str(self.timeout) + " seconds with last status: " + result["status"], + } + + def run_tests_from_csv(self, csv_file): + with open(csv_file, 'r') as f: + reader = csv.DictReader(f) + queries = [(row['query'], i, row.get('expected_status', None)) for i, row in enumerate(reader, start=1) if row['query'].strip()] + + # Filtering queries based on start and end + queries = queries[self.start:self.end] + + # Parallel execution + futures = [self.executor.submit(self.run_test, query, seq_id, expected_status) for query, seq_id, expected_status in queries] + for future in as_completed(futures): + result = future.result() + self.logger.info(f"Completed test: {result['query_name']}, {result['query']}, got result status: {result['status']}") +
self.test_results.append(result) + + def generate_report(self): + self.logger.info("Generating report...") + total_queries = len(self.test_results) + successful_queries = sum(1 for r in self.test_results if r['status'] == 'SUCCESS') + failed_queries = sum(1 for r in self.test_results if r['status'] == 'FAILED') + submit_failed_queries = sum(1 for r in self.test_results if r['status'] == 'SUBMIT_FAILED') + timeout_queries = sum(1 for r in self.test_results if r['status'] == 'TIMEOUT') + + # Create report + report = { + "summary": { + "total_queries": total_queries, + "successful_queries": successful_queries, + "failed_queries": failed_queries, + "submit_failed_queries": submit_failed_queries, + "timeout_queries": timeout_queries, + "execution_time": sum(r['duration'] for r in self.test_results) + }, + "detailed_results": self.test_results + } + + # Save report to JSON file + with open(f"{self.output_file}.json", 'w') as f: + json.dump(report, f, indent=2, default=str) + + # Save results to Excel file + df = pd.DataFrame(self.test_results) + df.to_excel(f"{self.output_file}.xlsx", index=False) + + self.logger.info(f"Generated report in {self.output_file}.xlsx and {self.output_file}.json") + +def signal_handler(sig, frame, tester): + print(f"Signal {sig} received, generating report...") + try: + tester.executor.shutdown(wait=False, cancel_futures=True) + tester.generate_report() + finally: + sys.exit(0) + +def main(): + # Parse command line arguments + parser = argparse.ArgumentParser(description="Run tests from a CSV file and generate a report.") + parser.add_argument("--base-url", required=True, help="Base URL of the service") + parser.add_argument("--username", required=True, help="Username for authentication") + parser.add_argument("--password", required=True, help="Password for authentication") + parser.add_argument("--datasource", required=True, help="Datasource name") + parser.add_argument("--input-csv", required=True, help="Path to the CSV file containing test queries") + parser.add_argument("--output-file", required=True, help="Path to the output report file") + parser.add_argument("--max-workers", type=int, default=2, help="optional, Maximum number of worker threads (default: 2)") + parser.add_argument("--check-interval", type=int, default=5, help="optional, Check interval in seconds (default: 5)") + parser.add_argument("--timeout", type=int, default=600, help="optional, Timeout in seconds (default: 600)") + parser.add_argument("--start-row", type=int, default=None, help="optional, The start row of the query to run, start from 1") + parser.add_argument("--end-row", type=int, default=None, help="optional, The end row of the query to run, not included") + parser.add_argument("--log-level", default="INFO", help="optional, Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL, default: INFO)") + + args = parser.parse_args() + + tester = FlintTester( + base_url=args.base_url, + username=args.username, + password=args.password, + datasource=args.datasource, + max_workers=args.max_workers, + check_interval=args.check_interval, + timeout=args.timeout, + output_file=args.output_file, + start_row=args.start_row, + end_row=args.end_row, + log_level=args.log_level, + ) + + # Register signal handlers to generate report on interrupt + signal.signal(signal.SIGINT, lambda sig, frame: signal_handler(sig, frame, tester)) + signal.signal(signal.SIGTERM, lambda sig, frame: signal_handler(sig, frame, tester)) + + # Running tests + tester.run_tests_from_csv(args.input_csv) + + # Generate report +
tester.generate_report() + +if __name__ == "__main__": + main() diff --git a/integ-test/script/test_cases.csv b/integ-test/script/test_cases.csv new file mode 100644 index 000000000..7df05f5a3 --- /dev/null +++ b/integ-test/script/test_cases.csv @@ -0,0 +1,567 @@ +query,expected_status +describe myglue_test.default.http_logs,FAILED +describe `myglue_test`.`default`.`http_logs`,FAILED +"source = myglue_test.default.http_logs | dedup 1 status | fields @timestamp, clientip, status, size | head 10",SUCCESS +"source = myglue_test.default.http_logs | dedup status, size | head 10",SUCCESS +source = myglue_test.default.http_logs | dedup 1 status keepempty=true | head 10,SUCCESS +"source = myglue_test.default.http_logs | dedup status, size keepempty=true | head 10",SUCCESS +source = myglue_test.default.http_logs | dedup 2 status | head 10,SUCCESS +"source = myglue_test.default.http_logs | dedup 2 status, size | head 10",SUCCESS +"source = myglue_test.default.http_logs | dedup 2 status, size keepempty=true | head 10",SUCCESS +source = myglue_test.default.http_logs | dedup status CONSECUTIVE=true | fields status,FAILED +"source = myglue_test.default.http_logs | dedup 2 status, size CONSECUTIVE=true | fields status",FAILED +"source = myglue_test.default.http_logs | sort stat | fields @timestamp, clientip, status | head 10",SUCCESS +"source = myglue_test.default.http_logs | fields @timestamp, notexisted | head 10",FAILED +"source = myglue_test.default.nested | fields int_col, struct_col.field1, struct_col2.field1 | head 10",FAILED +"source = myglue_test.default.nested | where struct_col2.field1.subfield > 'valueA' | sort int_col | fields int_col, struct_col.field1.subfield, struct_col2.field1.subfield",FAILED +"source = myglue_test.default.http_logs | fields - @timestamp, clientip, status | head 10",SUCCESS +"source = myglue_test.default.http_logs | eval new_time = @timestamp, new_clientip = clientip | fields - new_time, new_clientip, status | head 10",SUCCESS +source = myglue_test.default.http_logs | eval new_clientip = lower(clientip) | fields - new_clientip | head 10,SUCCESS +"source = myglue_test.default.http_logs | fields + @timestamp, clientip, status | fields - clientip, status | head 10",SUCCESS +"source = myglue_test.default.http_logs | fields - clientip, status | fields + @timestamp, clientip, status| head 10",SUCCESS +source = myglue_test.default.http_logs | where status = 200 | head 10,SUCCESS +source = myglue_test.default.http_logs | where status != 200 | head 10,SUCCESS +source = myglue_test.default.http_logs | where size > 0 | head 10,SUCCESS +source = myglue_test.default.http_logs | where size <= 0 | head 10,SUCCESS +source = myglue_test.default.http_logs | where clientip = '236.14.2.0' | head 10,SUCCESS +source = myglue_test.default.http_logs | where size > 0 AND status = 200 OR clientip = '236.14.2.0' | head 100,SUCCESS +"source = myglue_test.default.http_logs | where size <= 0 AND like(request, 'GET%') | head 10",SUCCESS +source = myglue_test.default.http_logs status = 200 | head 10,SUCCESS +source = myglue_test.default.http_logs size > 0 AND status = 200 OR clientip = '236.14.2.0' | head 100,SUCCESS +"source = myglue_test.default.http_logs size <= 0 AND like(request, 'GET%') | head 10",SUCCESS +"source = myglue_test.default.http_logs substring(clientip, 5, 2) = ""12"" | head 10",SUCCESS +source = myglue_test.default.http_logs | where isempty(size),FAILED +source = myglue_test.default.http_logs | where ispresent(size),FAILED +source = myglue_test.default.http_logs | where 
isnull(size) | head 10,SUCCESS +source = myglue_test.default.http_logs | where isnotnull(size) | head 10,SUCCESS +"source = myglue_test.default.http_logs | where isnotnull(coalesce(size, status)) | head 10",FAILED +"source = myglue_test.default.http_logs | where like(request, 'GET%') | head 10",SUCCESS +"source = myglue_test.default.http_logs | where like(request, '%bordeaux%') | head 10",SUCCESS +"source = myglue_test.default.http_logs | where substring(clientip, 5, 2) = ""12"" | head 10",SUCCESS +"source = myglue_test.default.http_logs | where lower(request) = ""get /images/backnews.gif http/1.0"" | head 10",SUCCESS +source = myglue_test.default.http_logs | where length(request) = 38 | head 10,SUCCESS +"source = myglue_test.default.http_logs | where case(status = 200, 'success' else 'failed') = 'success' | head 10",FAILED +"source = myglue_test.default.http_logs | eval h = ""Hello"", w = ""World"" | head 10",SUCCESS +"source = myglue_test.default.http_logs | eval @h = ""Hello"" | eval @w = ""World"" | fields @timestamp, @h, @w",SUCCESS +source = myglue_test.default.http_logs | eval newF = clientip | head 10,SUCCESS +"source = myglue_test.default.http_logs | eval newF = clientip | fields clientip, newF | head 10",SUCCESS +"source = myglue_test.default.http_logs | eval f = size | where f > 1 | sort f | fields size, clientip, status | head 10",SUCCESS +"source = myglue_test.default.http_logs | eval f = status * 2 | eval h = f * 2 | fields status, f, h | head 10",SUCCESS +"source = myglue_test.default.http_logs | eval f = size * 2, h = status | stats sum(f) by h",SUCCESS +"source = myglue_test.default.http_logs | eval f = UPPER(request) | eval h = 40 | fields f, h | head 10",SUCCESS +"source = myglue_test.default.http_logs | eval request = ""test"" | fields request | head 10",FAILED +source = myglue_test.default.http_logs | eval size = abs(size) | where size < 500,FAILED +"source = myglue_test.default.http_logs | eval status_string = case(status = 200, 'success' else 'failed') | head 10",FAILED +"source = myglue_test.default.http_logs | eval n = now() | eval t = unix_timestamp(@timestamp) | fields n, t | head 10",SUCCESS +source = myglue_test.default.http_logs | eval e = isempty(size) | eval p = ispresent(size) | head 10,FAILED +"source = myglue_test.default.http_logs | eval c = coalesce(size, status) | head 10",FAILED +source = myglue_test.default.http_logs | eval c = coalesce(request) | head 10,FAILED +source = myglue_test.default.http_logs | eval col1 = ln(size) | eval col2 = unix_timestamp(@timestamp) | sort - col1 | head 10,SUCCESS +"source = myglue_test.default.http_logs | eval col1 = 1 | sort col1 | head 4 | eval col2 = 2 | sort - col2 | sort - size | head 2 | fields @timestamp, clientip, col2",SUCCESS +"source = myglue_test.default.mini_http_logs | eval stat = status | where stat > 300 | sort stat | fields @timestamp,clientip,status | head 5",SUCCESS +"source = myglue_test.default.http_logs | eval col1 = size, col2 = clientip | stats avg(col1) by col2",SUCCESS +source = myglue_test.default.http_logs | stats avg(size) by clientip,SUCCESS +"source = myglue_test.default.http_logs | eval new_request = upper(request) | eval compound_field = concat('Hello ', if(like(new_request, '%bordeaux%'), 'World', clientip)) | fields new_request, compound_field | head 10",SUCCESS +source = myglue_test.default.http_logs | stats avg(size),SUCCESS +source = myglue_test.default.nested | stats max(int_col) by struct_col.field2,SUCCESS +source = myglue_test.default.nested | stats 
distinct_count(int_col),SUCCESS +source = myglue_test.default.nested | stats stddev_samp(int_col),SUCCESS +source = myglue_test.default.nested | stats stddev_pop(int_col),SUCCESS +source = myglue_test.default.nested | stats percentile(int_col),SUCCESS +source = myglue_test.default.nested | stats percentile_approx(int_col),SUCCESS +source = myglue_test.default.mini_http_logs | stats stddev_samp(status),SUCCESS +"source = myglue_test.default.mini_http_logs | where stats > 200 | stats percentile_approx(status, 99)",SUCCESS +"source = myglue_test.default.nested | stats count(int_col) by span(struct_col.field2, 10) as a_span",SUCCESS +"source = myglue_test.default.nested | stats avg(int_col) by span(struct_col.field2, 10) as a_span, struct_col2.field2",SUCCESS +"source = myglue_test.default.http_logs | stats sum(size) by span(@timestamp, 1d) as age_size_per_day | sort - age_size_per_day | head 10",SUCCESS +"source = myglue_test.default.http_logs | stats distinct_count(clientip) by span(@timestamp, 1d) as age_size_per_day | sort - age_size_per_day | head 10",SUCCESS +"source = myglue_test.default.http_logs | stats avg(size) as avg_size by status, year | stats avg(avg_size) as avg_avg_size by year",SUCCESS +"source = myglue_test.default.http_logs | stats avg(size) as avg_size by status, year, month | stats avg(avg_size) as avg_avg_size by year, month | stats avg(avg_avg_size) as avg_avg_avg_size by year",SUCCESS +"source = myglue_test.default.nested | stats avg(int_col) as avg_int by struct_col.field2, struct_col2.field2 | stats avg(avg_int) as avg_avg_int by struct_col2.field2",FAILED +"source = myglue_test.default.nested | stats avg(int_col) as avg_int by struct_col.field2, struct_col2.field2 | eval new_col = avg_int | stats avg(avg_int) as avg_avg_int by new_col",SUCCESS +source = myglue_test.default.nested | rare int_col,SUCCESS +source = myglue_test.default.nested | rare int_col by struct_col.field2,SUCCESS +source = myglue_test.default.http_logs | rare request,SUCCESS +source = myglue_test.default.http_logs | where status > 300 | rare request by status,SUCCESS +source = myglue_test.default.http_logs | rare clientip,SUCCESS +source = myglue_test.default.http_logs | where status > 300 | rare clientip,SUCCESS +source = myglue_test.default.http_logs | where status > 300 | rare clientip by day,SUCCESS +source = myglue_test.default.nested | top int_col by struct_col.field2,SUCCESS +source = myglue_test.default.nested | top 1 int_col by struct_col.field2,SUCCESS +source = myglue_test.default.nested | top 2 int_col by struct_col.field2,SUCCESS +source = myglue_test.default.nested | top int_col,SUCCESS +source = myglue_test.default.http_logs | inner join left=l right=r on l.status = r.int_col myglue_test.default.nested | head 10,FAILED +"source = myglue_test.default.http_logs | parse request 'GET /(?<domain>[a-zA-Z]+)/.*' | fields request, domain | head 10",SUCCESS +source = myglue_test.default.http_logs | parse request 'GET /(?<domain>[a-zA-Z]+)/.*' | top 1 domain,SUCCESS +source = myglue_test.default.http_logs | parse request 'GET /(?<domain>[a-zA-Z]+)/.*' | stats count() by domain,SUCCESS +"source = myglue_test.default.http_logs | parse request 'GET /(?<domain>[a-zA-Z]+)/.*' | eval a = 1 | fields a, domain | head 10",SUCCESS +"source = myglue_test.default.http_logs | parse request 'GET /(?<domain>[a-zA-Z]+)/.*' | where size > 0 | sort - size | fields size, domain | head 10",SUCCESS +"source = myglue_test.default.http_logs | parse request 'GET /(?<domain>[a-zA-Z]+)/(?<picName>[a-zA-Z]+)/.*' | where domain = 'english' | sort - picName | fields domain, 
picName | head 10",SUCCESS +source = myglue_test.default.http_logs | patterns request | fields patterns_field | head 10,SUCCESS +source = myglue_test.default.http_logs | patterns request | where size > 0 | fields patterns_field | head 10,SUCCESS +"source = myglue_test.default.http_logs | patterns new_field='no_letter' pattern='[a-zA-Z]' request | fields request, no_letter | head 10",SUCCESS +source = myglue_test.default.http_logs | patterns new_field='no_letter' pattern='[a-zA-Z]' request | stats count() by no_letter,SUCCESS +"source = myglue_test.default.http_logs | patterns new_field='status' pattern='[a-zA-Z]' request | fields request, status | head 10",FAILED +source = myglue_test.default.http_logs | rename @timestamp as timestamp | head 10,FAILED +source = myglue_test.default.http_logs | sort size | head 10,SUCCESS +source = myglue_test.default.http_logs | sort + size | head 10,SUCCESS +source = myglue_test.default.http_logs | sort - size | head 10,SUCCESS +"source = myglue_test.default.http_logs | sort + size, + @timestamp | head 10",SUCCESS +"source = myglue_test.default.http_logs | sort - size, - @timestamp | head 10",SUCCESS +"source = myglue_test.default.http_logs | sort - size, @timestamp | head 10",SUCCESS +"source = myglue_test.default.http_logs | eval c1 = upper(request) | eval c2 = concat('Hello ', if(like(c1, '%bordeaux%'), 'World', clientip)) | eval c3 = length(request) | eval c4 = ltrim(request) | eval c5 = rtrim(request) | eval c6 = substring(clientip, 5, 2) | eval c7 = trim(request) | eval c8 = upper(request) | eval c9 = position('bordeaux' IN request) | eval c10 = replace(request, 'GET', 'GGG') | fields c1, c2, c3, c4, c5, c6, c7, c8, c9, c10 | head 10",SUCCESS +"source = myglue_test.default.http_logs | eval c1 = unix_timestamp(@timestamp) | eval c2 = now() | eval c3 = +DAY_OF_WEEK(@timestamp) | eval c4 = +DAY_OF_MONTH(@timestamp) | eval c5 = +DAY_OF_YEAR(@timestamp) | eval c6 = +WEEK_OF_YEAR(@timestamp) | eval c7 = +WEEK(@timestamp) | eval c8 = +MONTH_OF_YEAR(@timestamp) | eval c9 = +HOUR_OF_DAY(@timestamp) | eval c10 = +MINUTE_OF_HOUR(@timestamp) | eval c11 = +SECOND_OF_MINUTE(@timestamp) | eval c12 = +LOCALTIME() | fields c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12 | head 10",SUCCESS +"source=myglue_test.default.people | eval c1 = adddate(@timestamp, 1) | fields c1 | head 10",SUCCESS +"source=myglue_test.default.people | eval c2 = subdate(@timestamp, 1) | fields c2 | head 10",SUCCESS +source=myglue_test.default.people | eval c1 = date_add(@timestamp INTERVAL 1 DAY) | fields c1 | head 10,SUCCESS +source=myglue_test.default.people | eval c1 = date_sub(@timestamp INTERVAL 1 DAY) | fields c1 | head 10,SUCCESS +source=myglue_test.default.people | eval `CURDATE()` = CURDATE() | fields `CURDATE()`,SUCCESS +source=myglue_test.default.people | eval `CURRENT_DATE()` = CURRENT_DATE() | fields `CURRENT_DATE()`,SUCCESS +source=myglue_test.default.people | eval `CURRENT_TIMESTAMP()` = CURRENT_TIMESTAMP() | fields `CURRENT_TIMESTAMP()`,SUCCESS +source=myglue_test.default.people | eval `DATE('2020-08-26')` = DATE('2020-08-26') | fields `DATE('2020-08-26')`,SUCCESS +source=myglue_test.default.people | eval `DATE(TIMESTAMP('2020-08-26 13:49:00'))` = DATE(TIMESTAMP('2020-08-26 13:49:00')) | fields `DATE(TIMESTAMP('2020-08-26 13:49:00'))`,SUCCESS +source=myglue_test.default.people | eval `DATE('2020-08-26 13:49')` = DATE('2020-08-26 13:49') | fields `DATE('2020-08-26 13:49')`,SUCCESS +"source=myglue_test.default.people | eval `DATE_FORMAT('1998-01-31 13:14:15.012345', 
'HH:mm:ss.SSSSSS')` = DATE_FORMAT('1998-01-31 13:14:15.012345', 'HH:mm:ss.SSSSSS'), `DATE_FORMAT(TIMESTAMP('1998-01-31 13:14:15.012345'), 'yyyy-MMM-dd hh:mm:ss a')` = DATE_FORMAT(TIMESTAMP('1998-01-31 13:14:15.012345'), 'yyyy-MMM-dd hh:mm:ss a') | fields `DATE_FORMAT('1998-01-31 13:14:15.012345', 'HH:mm:ss.SSSSSS')`, `DATE_FORMAT(TIMESTAMP('1998-01-31 13:14:15.012345'), 'yyyy-MMM-dd hh:mm:ss a')`",SUCCESS +"source=myglue_test.default.people | eval `'2000-01-02' - '2000-01-01'` = DATEDIFF(TIMESTAMP('2000-01-02 00:00:00'), TIMESTAMP('2000-01-01 23:59:59')), `'2001-02-01' - '2004-01-01'` = DATEDIFF(DATE('2001-02-01'), TIMESTAMP('2004-01-01 00:00:00')) | fields `'2000-01-02' - '2000-01-01'`, `'2001-02-01' - '2004-01-01'`", +source=myglue_test.default.people | eval `DAY(DATE('2020-08-26'))` = DAY(DATE('2020-08-26')) | fields `DAY(DATE('2020-08-26'))`, +source=myglue_test.default.people | eval `DAYNAME(DATE('2020-08-26'))` = DAYNAME(DATE('2020-08-26')) | fields `DAYNAME(DATE('2020-08-26'))`,FAILED +source=myglue_test.default.people | eval `CURRENT_TIMEZONE()` = CURRENT_TIMEZONE() | fields `CURRENT_TIMEZONE()`,SUCCESS +source=myglue_test.default.people | eval `UTC_TIMESTAMP()` = UTC_TIMESTAMP() | fields `UTC_TIMESTAMP()`,SUCCESS +"source=myglue_test.default.people | eval `TIMESTAMPDIFF(YEAR, '1997-01-01 00:00:00', '2001-03-06 00:00:00')` = TIMESTAMPDIFF(YEAR, '1997-01-01 00:00:00', '2001-03-06 00:00:00') | eval `TIMESTAMPDIFF(SECOND, timestamp('1997-01-01 00:00:23'), timestamp('1997-01-01 00:00:00'))` = TIMESTAMPDIFF(SECOND, timestamp('1997-01-01 00:00:23'), timestamp('1997-01-01 00:00:00')) | fields `TIMESTAMPDIFF(YEAR, '1997-01-01 00:00:00', '2001-03-06 00:00:00')`, `TIMESTAMPDIFF(SECOND, timestamp('1997-01-01 00:00:23'), timestamp('1997-01-01 00:00:00'))`",SUCCESS +"source=myglue_test.default.people | eval `TIMESTAMPADD(DAY, 17, '2000-01-01 00:00:00')` = TIMESTAMPADD(DAY, 17, '2000-01-01 00:00:00') | eval `TIMESTAMPADD(QUARTER, -1, '2000-01-01 00:00:00')` = TIMESTAMPADD(QUARTER, -1, '2000-01-01 00:00:00') | fields `TIMESTAMPADD(DAY, 17, '2000-01-01 00:00:00')`, `TIMESTAMPADD(QUARTER, -1, '2000-01-01 00:00:00')`",SUCCESS + source = myglue_test.default.http_logs | stats count(),SUCCESS +"source = myglue_test.default.http_logs | stats avg(size) as c1, max(size) as c2, min(size) as c3, sum(size) as c4, percentile(size, 50) as c5, stddev_pop(size) as c6, stddev_samp(size) as c7, distinct_count(size) as c8",SUCCESS +"source = myglue_test.default.http_logs | eval c1 = abs(size) | eval c2 = ceil(size) | eval c3 = floor(size) | eval c4 = sqrt(size) | eval c5 = ln(size) | eval c6 = pow(size, 2) | eval c7 = mod(size, 2) | fields c1, c2, c3, c4, c5, c6, c7 | head 10",SUCCESS +"source = myglue_test.default.http_logs | eval c1 = isnull(request) | eval c2 = isnotnull(request) | eval c3 = ifnull(request, +""Unknown"") | eval c4 = nullif(request, +""Unknown"") | eval c5 = isnull(size) | eval c6 = if(like(request, '%bordeaux%'), 'hello', 'world') | fields c1, c2, c3, c4, c5, c6 | head 10",SUCCESS +/* this is block comment */ source = myglue_test.tpch_csv.orders | head 1 // this is line comment,SUCCESS +"/* test in tpch q16, q18, q20 */ source = myglue_test.tpch_csv.orders | head 1 // add source=xx to avoid failure in automation",SUCCESS +"/* test in tpch q4, q21, q22 */ source = myglue_test.tpch_csv.orders | head 1",SUCCESS +"/* test in tpch q2, q11, q15, q17, q20, q22 */ source = myglue_test.tpch_csv.orders | head 1",SUCCESS +"/* test in tpch q7, q8, q9, q13, q15, q22 */ source = myglue_test.tpch_csv.orders | 
head 1",SUCCESS +/* lots of inner join tests in tpch */ source = myglue_test.tpch_csv.orders | head 1,SUCCESS +/* left join test in tpch q13 */ source = myglue_test.tpch_csv.orders | head 1,SUCCESS +"source = myglue_test.tpch_csv.orders + | right outer join ON c_custkey = o_custkey AND not like(o_comment, '%special%requests%') + myglue_test.tpch_csv.customer +| stats count(o_orderkey) as c_count by c_custkey +| sort - c_count",SUCCESS +"source = myglue_test.tpch_csv.orders + | full outer join ON c_custkey = o_custkey AND not like(o_comment, '%special%requests%') + myglue_test.tpch_csv.customer +| stats count(o_orderkey) as c_count by c_custkey +| sort - c_count",SUCCESS +"source = myglue_test.tpch_csv.customer +| semi join ON c_custkey = o_custkey myglue_test.tpch_csv.orders +| where c_mktsegment = 'BUILDING' + | sort - c_custkey +| head 10",SUCCESS +"source = myglue_test.tpch_csv.customer +| anti join ON c_custkey = o_custkey myglue_test.tpch_csv.orders +| where c_mktsegment = 'BUILDING' + | sort - c_custkey +| head 10",SUCCESS +"source = myglue_test.tpch_csv.supplier +| where like(s_comment, '%Customer%Complaints%') +| join ON s_nationkey > n_nationkey [ source = myglue_test.tpch_csv.nation | where n_name = 'SAUDI ARABIA' ] +| sort - s_name +| head 10",SUCCESS +"source = myglue_test.tpch_csv.supplier +| where like(s_comment, '%Customer%Complaints%') +| join [ source = myglue_test.tpch_csv.nation | where n_name = 'SAUDI ARABIA' ] +| sort - s_name +| head 10",SUCCESS +source=myglue_test.default.people | LOOKUP myglue_test.default.work_info uid AS id REPLACE department | stats distinct_count(department),SUCCESS +source = myglue_test.default.people| LOOKUP myglue_test.default.work_info uid AS id APPEND department | stats distinct_count(department),SUCCESS +source = myglue_test.default.people| LOOKUP myglue_test.default.work_info uid AS id REPLACE department AS country | stats distinct_count(country),SUCCESS +source = myglue_test.default.people| LOOKUP myglue_test.default.work_info uid AS id APPEND department AS country | stats distinct_count(country),SUCCESS +"source = myglue_test.default.people| LOOKUP myglue_test.default.work_info uID AS id, name REPLACE department | stats distinct_count(department)",SUCCESS +"source = myglue_test.default.people| LOOKUP myglue_test.default.work_info uid AS ID, name APPEND department | stats distinct_count(department)",SUCCESS +"source = myglue_test.default.people| LOOKUP myglue_test.default.work_info uID AS id, name | head 10",SUCCESS +"source = myglue_test.default.people | eval major = occupation | fields id, name, major, country, salary | LOOKUP myglue_test.default.work_info name REPLACE occupation AS major | stats distinct_count(major)",SUCCESS +"source = myglue_test.default.people | eval major = occupation | fields id, name, major, country, salary | LOOKUP myglue_test.default.work_info name APPEND occupation AS major | stats distinct_count(major)",SUCCESS +"source = myglue_test.default.http_logs | eval res = json('{""account_number"":1,""balance"":39225,""age"":32,""gender"":""M""}') | head 1 | fields res",SUCCESS +"source = myglue_test.default.http_logs | eval res = json('{""f1"":""abc"",""f2"":{""f3"":""a"",""f4"":""b""}}') | head 1 | fields res",SUCCESS +"source = myglue_test.default.http_logs | eval res = json('[1,2,3,{""f1"":1,""f2"":[5,6]},4]') | head 1 | fields res",SUCCESS +source = myglue_test.default.http_logs | eval res = json('[]') | head 1 | fields res,SUCCESS +"source = myglue_test.default.http_logs | eval res = 
json(‘{""teacher"":""Alice"",""student"":[{""name"":""Bob"",""rank"":1},{""name"":""Charlie"",""rank"":2}]}') | head 1 | fields res",SUCCESS +"source = myglue_test.default.http_logs | eval res = json('{""invalid"": ""json""') | head 1 | fields res",SUCCESS +"source = myglue_test.default.http_logs | eval res = json('[1,2,3]') | head 1 | fields res",SUCCESS +"source = myglue_test.default.http_logs | eval res = json(‘[1,2') | head 1 | fields res",SUCCESS +source = myglue_test.default.http_logs | eval res = json('[invalid json]') | head 1 | fields res,SUCCESS +source = myglue_test.default.http_logs | eval res = json('invalid json') | head 1 | fields res,SUCCESS +source = myglue_test.default.http_logs | eval res = json(null) | head 1 | fields res,SUCCESS +"source = myglue_test.default.http_logs | eval res = json_array('this', 'is', 'a', 'string', 'array') | head 1 | fields res",SUCCESS +source = myglue_test.default.http_logs | eval res = json_array() | head 1 | fields res,SUCCESS +"source = myglue_test.default.http_logs | eval res = json_array(1, 2, 0, -1, 1.1, -0.11) | head 1 | fields res",SUCCESS +"source = myglue_test.default.http_logs | eval res = json_array('this', 'is', 1.1, -0.11, true, false) | head 1 | fields res",SUCCESS +"source = myglue_test.default.http_logs | eval res = to_json_string(json_array(1,2,0,-1,1.1,-0.11)) | head 1 | fields res",SUCCESS +"source = myglue_test.default.http_logs | eval res = array_length(json_array(1,2,0,-1,1.1,-0.11)) | head 1 | fields res",SUCCESS +source = myglue_test.default.http_logs | eval res = array_length(json_array()) | head 1 | fields res,SUCCESS +source = myglue_test.default.http_logs | eval res = json_array_length('[]') | head 1 | fields res,SUCCESS +"source = myglue_test.default.http_logs | eval res = json_array_length('[1,2,3,{""f1"":1,""f2"":[5,6]},4]') | head 1 | fields res",SUCCESS +"source = myglue_test.default.http_logs | eval res = json_array_length('{\""key\"": 1}') | head 1 | fields res",SUCCESS +"source = myglue_test.default.http_logs | eval res = json_array_length('[1,2') | head 1 | fields res",SUCCESS +"source = myglue_test.default.http_logs | eval res = to_json_string(json_object('key', 'string_value')) | head 1 | fields res",SUCCESS +"source = myglue_test.default.http_logs | eval res = to_json_string(json_object('key', 123.45)) | head 1 | fields res",SUCCESS +"source = myglue_test.default.http_logs | eval res = to_json_string(json_object('key', true)) | head 1 | fields res",SUCCESS +"source = myglue_test.default.http_logs | eval res = to_json_string(json_object(""a"", 1, ""b"", 2, ""c"", 3)) | head 1 | fields res",SUCCESS +"source = myglue_test.default.http_logs | eval res = to_json_string(json_object('key', array())) | head 1 | fields res",SUCCESS +"source = myglue_test.default.http_logs | eval res = to_json_string(json_object('key', array(1, 2, 3))) | head 1 | fields res",SUCCESS +"source = myglue_test.default.http_logs | eval res = to_json_string(json_object('outer', json_object('inner', 123.45))) | head 1 | fields res",SUCCESS +"source = myglue_test.default.http_logs | eval res = to_json_string(json_object(""array"", json_array(1,2,0,-1,1.1,-0.11))) | head 1 | fields res",SUCCESS +"source = myglue_test.default.http_logs | where json_valid(('{""account_number"":1,""balance"":39225,""age"":32,""gender"":""M""}') | head 1",SUCCESS +"source = myglue_test.default.http_logs | where not json_valid(('{""account_number"":1,""balance"":39225,""age"":32,""gender"":""M""}') | head 1",SUCCESS +"source = myglue_test.default.http_logs | 
eval res = json_keys(json('{""account_number"":1,""balance"":39225,""age"":32,""gender"":""M""}')) | head 1 | fields res",SUCCESS +"source = myglue_test.default.http_logs | eval res = json_keys(json('{""f1"":""abc"",""f2"":{""f3"":""a"",""f4"":""b""}}')) | head 1 | fields res",SUCCESS +"source = myglue_test.default.http_logs | eval res = json_keys(json('[1,2,3,{""f1"":1,""f2"":[5,6]},4]')) | head 1 | fields res",SUCCESS +source = myglue_test.default.http_logs | eval res = json_keys(json('[]')) | head 1 | fields res,SUCCESS +"source = myglue_test.default.http_logs | eval res = json_keys(json(‘{""teacher"":""Alice"",""student"":[{""name"":""Bob"",""rank"":1},{""name"":""Charlie"",""rank"":2}]}')) | head 1 | fields res",SUCCESS +"source = myglue_test.default.http_logs | eval res = json_keys(json('{""invalid"": ""json""')) | head 1 | fields res",SUCCESS +"source = myglue_test.default.http_logs | eval res = json_keys(json('[1,2,3]')) | head 1 | fields res",SUCCESS +"source = myglue_test.default.http_logs | eval res = json_keys(json('[1,2')) | head 1 | fields res",SUCCESS +source = myglue_test.default.http_logs | eval res = json_keys(json('[invalid json]')) | head 1 | fields res,SUCCESS +source = myglue_test.default.http_logs | eval res = json_keys(json('invalid json')) | head 1 | fields res,SUCCESS +source = myglue_test.default.http_logs | eval res = json_keys(json(null)) | head 1 | fields res,SUCCESS +"source = myglue_test.default.http_logs | eval res = json_extract('{""teacher"":""Alice"",""student"":[{""name"":""Bob"",""rank"":1},{""name"":""Charlie"",""rank"":2}]}', '$') | head 1 | fields res",SUCCESS +"source = myglue_test.default.http_logs | eval res = json_extract('{""teacher"":""Alice"",""student"":[{""name"":""Bob"",""rank"":1},{""name"":""Charlie"",""rank"":2}]}', '$.teacher') | head 1 | fields res",SUCCESS +"source = myglue_test.default.http_logs | eval res = json_extract('{""teacher"":""Alice"",""student"":[{""name"":""Bob"",""rank"":1},{""name"":""Charlie"",""rank"":2}]}', '$.student') | head 1 | fields res",SUCCESS +"source = myglue_test.default.http_logs | eval res = json_extract('{""teacher"":""Alice"",""student"":[{""name"":""Bob"",""rank"":1},{""name"":""Charlie"",""rank"":2}]}', '$.student[*]') | head 1 | fields res",SUCCESS +"source = myglue_test.default.http_logs | eval res = json_extract('{""teacher"":""Alice"",""student"":[{""name"":""Bob"",""rank"":1},{""name"":""Charlie"",""rank"":2}]}', '$.student[0]') | head 1 | fields res",SUCCESS +"source = myglue_test.default.http_logs | eval res = json_extract('{""teacher"":""Alice"",""student"":[{""name"":""Bob"",""rank"":1},{""name"":""Charlie"",""rank"":2}]}', '$.student[*].name') | head 1 | fields res",SUCCESS +"source = myglue_test.default.http_logs | eval res = json_extract('{""teacher"":""Alice"",""student"":[{""name"":""Bob"",""rank"":1},{""name"":""Charlie"",""rank"":2}]}', '$.student[1].name') | head 1 | fields res",SUCCESS +"source = myglue_test.default.http_logs | eval res = json_extract('{""teacher"":""Alice"",""student"":[{""name"":""Bob"",""rank"":1},{""name"":""Charlie"",""rank"":2}]}', '$.student[0].not_exist_key') | head 1 | fields res",SUCCESS +"source = myglue_test.default.http_logs | eval res = json_extract('{""teacher"":""Alice"",""student"":[{""name"":""Bob"",""rank"":1},{""name"":""Charlie"",""rank"":2}]}', '$.student[10]') | head 1 | fields res",SUCCESS +"source = myglue_test.default.people | eval array = json_array(1,2,0,-1,1.1,-0.11), result = forall(array, x -> x > 0) | head 1 | fields result",SUCCESS 
+"source = myglue_test.default.people | eval array = json_array(1,2,0,-1,1.1,-0.11), result = forall(array, x -> x > -10) | head 1 | fields result",SUCCESS +"source = myglue_test.default.people | eval array = json_array(json_object(""a"",1,""b"",-1),json_object(""a"",-1,""b"",-1)), result = forall(array, x -> x.a > 0) | head 1 | fields result",SUCCESS +"source = myglue_test.default.people | eval array = json_array(json_object(""a"",1,""b"",-1),json_object(""a"",-1,""b"",-1)), result = exists(array, x -> x.b < 0) | head 1 | fields result",SUCCESS +"source = myglue_test.default.people | eval array = json_array(1,2,0,-1,1.1,-0.11), result = exists(array, x -> x > 0) | head 1 | fields result",SUCCESS +"source = myglue_test.default.people | eval array = json_array(1,2,0,-1,1.1,-0.11), result = exists(array, x -> x > 10) | head 1 | fields result",SUCCESS +"source = myglue_test.default.people | eval array = json_array(1,2,0,-1,1.1,-0.11), result = filter(array, x -> x > 0) | head 1 | fields result",SUCCESS +"source = myglue_test.default.people | eval array = json_array(1,2,0,-1,1.1,-0.11), result = filter(array, x -> x > 10) | head 1 | fields result",SUCCESS +"source = myglue_test.default.people | eval array = json_array(1,2,3), result = transform(array, x -> x + 1) | head 1 | fields result",SUCCESS +"source = myglue_test.default.people | eval array = json_array(1,2,3), result = transform(array, (x, y) -> x + y) | head 1 | fields result",SUCCESS +"source = myglue_test.default.people | eval array = json_array(1,2,3), result = reduce(array, 0, (acc, x) -> acc + x) | head 1 | fields result",SUCCESS +"source = myglue_test.default.people | eval array = json_array(1,2,3), result = reduce(array, 0, (acc, x) -> acc + x, acc -> acc * 10) | head 1 | fields result",SUCCESS +source=myglue_test.default.people | eval age = salary | eventstats avg(age) | sort id | head 10,SUCCESS +"source=myglue_test.default.people | eval age = salary | eventstats avg(age) as avg_age, max(age) as max_age, min(age) as min_age, count(age) as count | sort id | head 10",SUCCESS +source=myglue_test.default.people | eventstats avg(salary) by country | sort id | head 10,SUCCESS +"source=myglue_test.default.people | eval age = salary | eventstats avg(age) as avg_age, max(age) as max_age, min(age) as min_age, count(age) as count by country | sort id | head 10",SUCCESS +"source=myglue_test.default.people | eval age = salary | eventstats avg(age) as avg_age, max(age) as max_age, min(age) as min_age, count(age) as count +by span(age, 10) | sort id | head 10",SUCCESS +"source=myglue_test.default.people | eval age = salary | eventstats avg(age) as avg_age, max(age) as max_age, min(age) as min_age, count(age) as count by span(age, 10) as age_span, country | sort id | head 10",SUCCESS +"source=myglue_test.default.people | where country != 'USA' | eventstats stddev_samp(salary), stddev_pop(salary), percentile_approx(salary, 60) by span(salary, 1000) as salary_span | sort id | head 10",SUCCESS +"source=myglue_test.default.people | eval age = salary | eventstats avg(age) as avg_age by occupation, country | eventstats avg(avg_age) as avg_state_age by country | sort id | head 10",SUCCESS +"source=myglue_test.default.people | eventstats distinct_count(salary) by span(salary, 1000) as age_span",FAILED +"source = myglue_test.tpch_csv.lineitem +| where l_shipdate <= subdate(date('1998-12-01'), 90) +| stats sum(l_quantity) as sum_qty, + sum(l_extendedprice) as sum_base_price, + sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, + 
sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, + avg(l_quantity) as avg_qty, + avg(l_extendedprice) as avg_price, + avg(l_discount) as avg_disc, + count() as count_order + by l_returnflag, l_linestatus +| sort l_returnflag, l_linestatus",SUCCESS +"source = myglue_test.tpch_csv.part +| join ON p_partkey = ps_partkey myglue_test.tpch_csv.partsupp +| join ON s_suppkey = ps_suppkey myglue_test.tpch_csv.supplier +| join ON s_nationkey = n_nationkey myglue_test.tpch_csv.nation +| join ON n_regionkey = r_regionkey myglue_test.tpch_csv.region +| where p_size = 15 AND like(p_type, '%BRASS') AND r_name = 'EUROPE' AND ps_supplycost = [ + source = myglue_test.tpch_csv.partsupp + | join ON s_suppkey = ps_suppkey myglue_test.tpch_csv.supplier + | join ON s_nationkey = n_nationkey myglue_test.tpch_csv.nation + | join ON n_regionkey = r_regionkey myglue_test.tpch_csv.region + | where r_name = 'EUROPE' + | stats MIN(ps_supplycost) + ] +| sort - s_acctbal, n_name, s_name, p_partkey +| head 100",SUCCESS +"source = myglue_test.tpch_csv.customer +| join ON c_custkey = o_custkey myglue_test.tpch_csv.orders +| join ON l_orderkey = o_orderkey myglue_test.tpch_csv.lineitem +| where c_mktsegment = 'BUILDING' AND o_orderdate < date('1995-03-15') AND l_shipdate > date('1995-03-15') +| stats sum(l_extendedprice * (1 - l_discount)) as revenue by l_orderkey, o_orderdate, o_shippriority + | sort - revenue, o_orderdate +| head 10",SUCCESS +"source = myglue_test.tpch_csv.orders +| where o_orderdate >= date('1993-07-01') + and o_orderdate < date_add(date('1993-07-01'), interval 3 month) + and exists [ + source = myglue_test.tpch_csv.lineitem + | where l_orderkey = o_orderkey and l_commitdate < l_receiptdate + ] +| stats count() as order_count by o_orderpriority +| sort o_orderpriority",SUCCESS +"source = myglue_test.tpch_csv.customer +| join ON c_custkey = o_custkey myglue_test.tpch_csv.orders +| join ON l_orderkey = o_orderkey myglue_test.tpch_csv.lineitem +| join ON l_suppkey = s_suppkey AND c_nationkey = s_nationkey myglue_test.tpch_csv.supplier +| join ON s_nationkey = n_nationkey myglue_test.tpch_csv.nation +| join ON n_regionkey = r_regionkey myglue_test.tpch_csv.region +| where r_name = 'ASIA' AND o_orderdate >= date('1994-01-01') AND o_orderdate < date_add(date('1994-01-01'), interval 1 year) +| stats sum(l_extendedprice * (1 - l_discount)) as revenue by n_name +| sort - revenue",SUCCESS +"source = myglue_test.tpch_csv.lineitem +| where l_shipdate >= date('1994-01-01') + and l_shipdate < adddate(date('1994-01-01'), 365) + and l_discount between .06 - 0.01 and .06 + 0.01 + and l_quantity < 24 +| stats sum(l_extendedprice * l_discount) as revenue",SUCCESS +"source = [ + source = myglue_test.tpch_csv.supplier + | join ON s_suppkey = l_suppkey myglue_test.tpch_csv.lineitem + | join ON o_orderkey = l_orderkey myglue_test.tpch_csv.orders + | join ON c_custkey = o_custkey myglue_test.tpch_csv.customer + | join ON s_nationkey = n1.n_nationkey myglue_test.tpch_csv.nation as n1 + | join ON c_nationkey = n2.n_nationkey myglue_test.tpch_csv.nation as n2 + | where l_shipdate between date('1995-01-01') and date('1996-12-31') + and n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY' or n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE' + | eval supp_nation = n1.n_name, cust_nation = n2.n_name, l_year = year(l_shipdate), volume = l_extendedprice * (1 - l_discount) + | fields supp_nation, cust_nation, l_year, volume + ] as shipping +| stats sum(volume) as revenue by supp_nation, cust_nation, l_year +| sort supp_nation, 
cust_nation, l_year",SUCCESS +"source = [ + source = myglue_test.tpch_csv.part + | join ON p_partkey = l_partkey myglue_test.tpch_csv.lineitem + | join ON s_suppkey = l_suppkey myglue_test.tpch_csv.supplier + | join ON l_orderkey = o_orderkey myglue_test.tpch_csv.orders + | join ON o_custkey = c_custkey myglue_test.tpch_csv.customer + | join ON c_nationkey = n1.n_nationkey myglue_test.tpch_csv.nation as n1 + | join ON s_nationkey = n2.n_nationkey myglue_test.tpch_csv.nation as n2 + | join ON n1.n_regionkey = r_regionkey myglue_test.tpch_csv.region + | where r_name = 'AMERICA' AND p_type = 'ECONOMY ANODIZED STEEL' + and o_orderdate between date('1995-01-01') and date('1996-12-31') + | eval o_year = year(o_orderdate) + | eval volume = l_extendedprice * (1 - l_discount) + | eval nation = n2.n_name + | fields o_year, volume, nation + ] as all_nations +| stats sum(case(nation = 'BRAZIL', volume else 0)) as sum_case, sum(volume) as sum_volume by o_year +| eval mkt_share = sum_case / sum_volume +| fields mkt_share, o_year +| sort o_year",SUCCESS +"source = [ + source = myglue_test.tpch_csv.part + | join ON p_partkey = l_partkey myglue_test.tpch_csv.lineitem + | join ON s_suppkey = l_suppkey myglue_test.tpch_csv.supplier + | join ON ps_partkey = l_partkey and ps_suppkey = l_suppkey myglue_test.tpch_csv.partsupp + | join ON o_orderkey = l_orderkey myglue_test.tpch_csv.orders + | join ON s_nationkey = n_nationkey myglue_test.tpch_csv.nation + | where like(p_name, '%green%') + | eval nation = n_name + | eval o_year = year(o_orderdate) + | eval amount = l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity + | fields nation, o_year, amount + ] as profit +| stats sum(amount) as sum_profit by nation, o_year +| sort nation, - o_year",SUCCESS +"source = myglue_test.tpch_csv.customer +| join ON c_custkey = o_custkey myglue_test.tpch_csv.orders +| join ON l_orderkey = o_orderkey myglue_test.tpch_csv.lineitem +| join ON c_nationkey = n_nationkey myglue_test.tpch_csv.nation +| where o_orderdate >= date('1993-10-01') + AND o_orderdate < date_add(date('1993-10-01'), interval 3 month) + AND l_returnflag = 'R' +| stats sum(l_extendedprice * (1 - l_discount)) as revenue by c_custkey, c_name, c_acctbal, c_phone, n_name, c_address, c_comment +| sort - revenue +| head 20",SUCCESS +"source = myglue_test.tpch_csv.partsupp +| join ON ps_suppkey = s_suppkey myglue_test.tpch_csv.supplier +| join ON s_nationkey = n_nationkey myglue_test.tpch_csv.nation +| where n_name = 'GERMANY' +| stats sum(ps_supplycost * ps_availqty) as value by ps_partkey +| where value > [ + source = myglue_test.tpch_csv.partsupp + | join ON ps_suppkey = s_suppkey myglue_test.tpch_csv.supplier + | join ON s_nationkey = n_nationkey myglue_test.tpch_csv.nation + | where n_name = 'GERMANY' + | stats sum(ps_supplycost * ps_availqty) as check + | eval threshold = check * 0.0001000000 + | fields threshold + ] +| sort - value",SUCCESS +"source = myglue_test.tpch_csv.orders +| join ON o_orderkey = l_orderkey myglue_test.tpch_csv.lineitem +| where l_commitdate < l_receiptdate + and l_shipdate < l_commitdate + and l_shipmode in ('MAIL', 'SHIP') + and l_receiptdate >= date('1994-01-01') + and l_receiptdate < date_add(date('1994-01-01'), interval 1 year) +| stats sum(case(o_orderpriority = '1-URGENT' or o_orderpriority = '2-HIGH', 1 else 0)) as high_line_count, + sum(case(o_orderpriority != '1-URGENT' and o_orderpriority != '2-HIGH', 1 else 0)) as low_line_countby + by l_shipmode +| sort l_shipmode",SUCCESS +"source = [ + source = 
myglue_test.tpch_csv.customer + | left outer join ON c_custkey = o_custkey AND not like(o_comment, '%special%requests%') + myglue_test.tpch_csv.orders + | stats count(o_orderkey) as c_count by c_custkey + ] as c_orders +| stats count() as custdist by c_count +| sort - custdist, - c_count",SUCCESS +"source = myglue_test.tpch_csv.lineitem +| join ON l_partkey = p_partkey + AND l_shipdate >= date('1995-09-01') + AND l_shipdate < date_add(date('1995-09-01'), interval 1 month) + myglue_test.tpch_csv.part +| stats sum(case(like(p_type, 'PROMO%'), l_extendedprice * (1 - l_discount) else 0)) as sum1, + sum(l_extendedprice * (1 - l_discount)) as sum2 +| eval promo_revenue = 100.00 * sum1 / sum2 // Stats and Eval commands can combine when issues/819 resolved +| fields promo_revenue",SUCCESS +"source = myglue_test.tpch_csv.supplier +| join right = revenue0 ON s_suppkey = supplier_no [ + source = myglue_test.tpch_csv.lineitem + | where l_shipdate >= date('1996-01-01') AND l_shipdate < date_add(date('1996-01-01'), interval 3 month) + | eval supplier_no = l_suppkey + | stats sum(l_extendedprice * (1 - l_discount)) as total_revenue by supplier_no + ] +| where total_revenue = [ + source = [ + source = myglue_test.tpch_csv.lineitem + | where l_shipdate >= date('1996-01-01') AND l_shipdate < date_add(date('1996-01-01'), interval 3 month) + | eval supplier_no = l_suppkey + | stats sum(l_extendedprice * (1 - l_discount)) as total_revenue by supplier_no + ] + | stats max(total_revenue) + ] +| sort s_suppkey +| fields s_suppkey, s_name, s_address, s_phone, total_revenue",SUCCESS +"source = myglue_test.tpch_csv.partsupp +| join ON p_partkey = ps_partkey myglue_test.tpch_csv.part +| where p_brand != 'Brand#45' + and not like(p_type, 'MEDIUM POLISHED%') + and p_size in (49, 14, 23, 45, 19, 3, 36, 9) + and ps_suppkey not in [ + source = myglue_test.tpch_csv.supplier + | where like(s_comment, '%Customer%Complaints%') + | fields s_suppkey + ] +| stats distinct_count(ps_suppkey) as supplier_cnt by p_brand, p_type, p_size +| sort - supplier_cnt, p_brand, p_type, p_size",SUCCESS +"source = myglue_test.tpch_csv.lineitem +| join ON p_partkey = l_partkey myglue_test.tpch_csv.part +| where p_brand = 'Brand#23' + and p_container = 'MED BOX' + and l_quantity < [ + source = myglue_test.tpch_csv.lineitem + | where l_partkey = p_partkey + | stats avg(l_quantity) as avg + | eval `0.2 * avg` = 0.2 * avg + | fields `0.2 * avg` + ] +| stats sum(l_extendedprice) as sum +| eval avg_yearly = sum / 7.0 +| fields avg_yearly",SUCCESS +"source = myglue_test.tpch_csv.customer +| join ON c_custkey = o_custkey myglue_test.tpch_csv.orders +| join ON o_orderkey = l_orderkey myglue_test.tpch_csv.lineitem +| where o_orderkey in [ + source = myglue_test.tpch_csv.lineitem + | stats sum(l_quantity) as sum by l_orderkey + | where sum > 300 + | fields l_orderkey + ] +| stats sum(l_quantity) by c_name, c_custkey, o_orderkey, o_orderdate, o_totalprice +| sort - o_totalprice, o_orderdate +| head 100",SUCCESS +"source = myglue_test.tpch_csv.lineitem +| join ON p_partkey = l_partkey + and p_brand = 'Brand#12' + and p_container in ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') + and l_quantity >= 1 and l_quantity <= 1 + 10 + and p_size between 1 and 5 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + OR p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container in ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') + and l_quantity >= 10 and l_quantity <= 10 + 10 + and p_size between 1 and 10 + and l_shipmode in ('AIR', 'AIR 
REG') + and l_shipinstruct = 'DELIVER IN PERSON' + OR p_partkey = l_partkey + and p_brand = 'Brand#34' + and p_container in ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') + and l_quantity >= 20 and l_quantity <= 20 + 10 + and p_size between 1 and 15 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + myglue_test.tpch_csv.part",SUCCESS +"source = myglue_test.tpch_csv.supplier +| join ON s_nationkey = n_nationkey myglue_test.tpch_csv.nation +| where n_name = 'CANADA' + and s_suppkey in [ + source = myglue_test.tpch_csv.partsupp + | where ps_partkey in [ + source = myglue_test.tpch_csv.part + | where like(p_name, 'forest%') + | fields p_partkey + ] + and ps_availqty > [ + source = myglue_test.tpch_csv.lineitem + | where l_partkey = ps_partkey + and l_suppkey = ps_suppkey + and l_shipdate >= date('1994-01-01') + and l_shipdate < date_add(date('1994-01-01'), interval 1 year) + | stats sum(l_quantity) as sum_l_quantity + | eval half_sum_l_quantity = 0.5 * sum_l_quantity + | fields half_sum_l_quantity + ] + | fields ps_suppkey + ]",SUCCESS +"source = myglue_test.tpch_csv.supplier +| join ON s_suppkey = l1.l_suppkey myglue_test.tpch_csv.lineitem as l1 +| join ON o_orderkey = l1.l_orderkey myglue_test.tpch_csv.orders +| join ON s_nationkey = n_nationkey myglue_test.tpch_csv.nation +| where o_orderstatus = 'F' + and l1.l_receiptdate > l1.l_commitdate + and exists [ + source = myglue_test.tpch_csv.lineitem as l2 + | where l2.l_orderkey = l1.l_orderkey + and l2.l_suppkey != l1.l_suppkey + ] + and not exists [ + source = myglue_test.tpch_csv.lineitem as l3 + | where l3.l_orderkey = l1.l_orderkey + and l3.l_suppkey != l1.l_suppkey + and l3.l_receiptdate > l3.l_commitdate + ] + and n_name = 'SAUDI ARABIA' +| stats count() as numwait by s_name +| sort - numwait, s_name +| head 100",SUCCESS +"source = [ + source = myglue_test.tpch_csv.customer + | where substring(c_phone, 1, 2) in ('13', '31', '23', '29', '30', '18', '17') + and c_acctbal > [ + source = myglue_test.tpch_csv.customer + | where c_acctbal > 0.00 + and substring(c_phone, 1, 2) in ('13', '31', '23', '29', '30', '18', '17') + | stats avg(c_acctbal) + ] + and not exists [ + source = myglue_test.tpch_csv.orders + | where o_custkey = c_custkey + ] + | eval cntrycode = substring(c_phone, 1, 2) + | fields cntrycode, c_acctbal + ] as custsale +| stats count() as numcust, sum(c_acctbal) as totacctbal by cntrycode +| sort cntrycode",SUCCESS diff --git a/integ-test/src/integration/resources/tpch/q1.ppl b/integ-test/src/integration/resources/tpch/q1.ppl new file mode 100644 index 000000000..885ce35c6 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q1.ppl @@ -0,0 +1,35 @@ +/* +select + l_returnflag, + l_linestatus, + sum(l_quantity) as sum_qty, + sum(l_extendedprice) as sum_base_price, + sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, + sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, + avg(l_quantity) as avg_qty, + avg(l_extendedprice) as avg_price, + avg(l_discount) as avg_disc, + count(*) as count_order +from + lineitem +where + l_shipdate <= date '1998-12-01' - interval '90' day +group by + l_returnflag, + l_linestatus +order by + l_returnflag, + l_linestatus +*/ + +source = lineitem +| where l_shipdate <= subdate(date('1998-12-01'), 90) +| stats sum(l_quantity) as sum_qty, + sum(l_extendedprice) as sum_base_price, + sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, + sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, + avg(l_quantity) as avg_qty, 
avg(l_extendedprice) as avg_price, + avg(l_discount) as avg_disc, + count() as count_order + by l_returnflag, l_linestatus +| sort l_returnflag, l_linestatus \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q10.ppl b/integ-test/src/integration/resources/tpch/q10.ppl new file mode 100644 index 000000000..10a050785 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q10.ppl @@ -0,0 +1,45 @@ +/* +select + c_custkey, + c_name, + sum(l_extendedprice * (1 - l_discount)) as revenue, + c_acctbal, + n_name, + c_address, + c_phone, + c_comment +from + customer, + orders, + lineitem, + nation +where + c_custkey = o_custkey + and l_orderkey = o_orderkey + and o_orderdate >= date '1993-10-01' + and o_orderdate < date '1993-10-01' + interval '3' month + and l_returnflag = 'R' + and c_nationkey = n_nationkey +group by + c_custkey, + c_name, + c_acctbal, + c_phone, + n_name, + c_address, + c_comment +order by + revenue desc +limit 20 +*/ + +source = customer +| join ON c_custkey = o_custkey orders +| join ON l_orderkey = o_orderkey lineitem +| join ON c_nationkey = n_nationkey nation +| where o_orderdate >= date('1993-10-01') + AND o_orderdate < date_add(date('1993-10-01'), interval 3 month) + AND l_returnflag = 'R' +| stats sum(l_extendedprice * (1 - l_discount)) as revenue by c_custkey, c_name, c_acctbal, c_phone, n_name, c_address, c_comment +| sort - revenue +| head 20 \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q11.ppl b/integ-test/src/integration/resources/tpch/q11.ppl new file mode 100644 index 000000000..3a55d986e --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q11.ppl @@ -0,0 +1,45 @@ +/* +select + ps_partkey, + sum(ps_supplycost * ps_availqty) as value +from + partsupp, + supplier, + nation +where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'GERMANY' +group by + ps_partkey having + sum(ps_supplycost * ps_availqty) > ( + select + sum(ps_supplycost * ps_availqty) * 0.0001000000 + from + partsupp, + supplier, + nation + where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'GERMANY' + ) +order by + value desc +*/ + +source = partsupp +| join ON ps_suppkey = s_suppkey supplier +| join ON s_nationkey = n_nationkey nation +| where n_name = 'GERMANY' +| stats sum(ps_supplycost * ps_availqty) as value by ps_partkey +| where value > [ + source = partsupp + | join ON ps_suppkey = s_suppkey supplier + | join ON s_nationkey = n_nationkey nation + | where n_name = 'GERMANY' + | stats sum(ps_supplycost * ps_availqty) as check + | eval threshold = check * 0.0001000000 + | fields threshold + ] +| sort - value \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q12.ppl b/integ-test/src/integration/resources/tpch/q12.ppl new file mode 100644 index 000000000..79672d844 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q12.ppl @@ -0,0 +1,42 @@ +/* +select + l_shipmode, + sum(case + when o_orderpriority = '1-URGENT' + or o_orderpriority = '2-HIGH' + then 1 + else 0 + end) as high_line_count, + sum(case + when o_orderpriority <> '1-URGENT' + and o_orderpriority <> '2-HIGH' + then 1 + else 0 + end) as low_line_count +from + orders, + lineitem +where + o_orderkey = l_orderkey + and l_shipmode in ('MAIL', 'SHIP') + and l_commitdate < l_receiptdate + and l_shipdate < l_commitdate + and l_receiptdate >= date '1994-01-01' + and l_receiptdate < date '1994-01-01' + interval '1' year +group by + l_shipmode +order by + l_shipmode +*/ + 
+source = orders +| join ON o_orderkey = l_orderkey lineitem +| where l_commitdate < l_receiptdate + and l_shipdate < l_commitdate + and l_shipmode in ('MAIL', 'SHIP') + and l_receiptdate >= date('1994-01-01') + and l_receiptdate < date_add(date('1994-01-01'), interval 1 year) +| stats sum(case(o_orderpriority = '1-URGENT' or o_orderpriority = '2-HIGH', 1 else 0)) as high_line_count, + sum(case(o_orderpriority != '1-URGENT' and o_orderpriority != '2-HIGH', 1 else 0)) as low_line_countby + by l_shipmode +| sort l_shipmode \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q13.ppl b/integ-test/src/integration/resources/tpch/q13.ppl new file mode 100644 index 000000000..6e77c9b0a --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q13.ppl @@ -0,0 +1,31 @@ +/* +select + c_count, + count(*) as custdist +from + ( + select + c_custkey, + count(o_orderkey) as c_count + from + customer left outer join orders on + c_custkey = o_custkey + and o_comment not like '%special%requests%' + group by + c_custkey + ) as c_orders +group by + c_count +order by + custdist desc, + c_count desc +*/ + +source = [ + source = customer + | left outer join ON c_custkey = o_custkey AND not like(o_comment, '%special%requests%') + orders + | stats count(o_orderkey) as c_count by c_custkey + ] as c_orders +| stats count() as custdist by c_count +| sort - custdist, - c_count \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q14.ppl b/integ-test/src/integration/resources/tpch/q14.ppl new file mode 100644 index 000000000..553f1e549 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q14.ppl @@ -0,0 +1,25 @@ +/* +select + 100.00 * sum(case + when p_type like 'PROMO%' + then l_extendedprice * (1 - l_discount) + else 0 + end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue +from + lineitem, + part +where + l_partkey = p_partkey + and l_shipdate >= date '1995-09-01' + and l_shipdate < date '1995-09-01' + interval '1' month +*/ + +source = lineitem +| join ON l_partkey = p_partkey + AND l_shipdate >= date('1995-09-01') + AND l_shipdate < date_add(date('1995-09-01'), interval 1 month) + part +| stats sum(case(like(p_type, 'PROMO%'), l_extendedprice * (1 - l_discount) else 0)) as sum1, + sum(l_extendedprice * (1 - l_discount)) as sum2 +| eval promo_revenue = 100.00 * sum1 / sum2 // Stats and Eval commands can combine when issues/819 resolved +| fields promo_revenue \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q15.ppl b/integ-test/src/integration/resources/tpch/q15.ppl new file mode 100644 index 000000000..96f5ecea2 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q15.ppl @@ -0,0 +1,52 @@ +/* +with revenue0 as + (select + l_suppkey as supplier_no, + sum(l_extendedprice * (1 - l_discount)) as total_revenue + from + lineitem + where + l_shipdate >= date '1996-01-01' + and l_shipdate < date '1996-01-01' + interval '3' month + group by + l_suppkey) +select + s_suppkey, + s_name, + s_address, + s_phone, + total_revenue +from + supplier, + revenue0 +where + s_suppkey = supplier_no + and total_revenue = ( + select + max(total_revenue) + from + revenue0 + ) +order by + s_suppkey +*/ + +// CTE is unsupported in PPL +source = supplier +| join right = revenue0 ON s_suppkey = supplier_no [ + source = lineitem + | where l_shipdate >= date('1996-01-01') AND l_shipdate < date_add(date('1996-01-01'), interval 3 month) + | eval supplier_no = l_suppkey + | stats sum(l_extendedprice * (1 - l_discount)) as 
total_revenue by supplier_no + ] +| where total_revenue = [ + source = [ + source = lineitem + | where l_shipdate >= date('1996-01-01') AND l_shipdate < date_add(date('1996-01-01'), interval 3 month) + | eval supplier_no = l_suppkey + | stats sum(l_extendedprice * (1 - l_discount)) as total_revenue by supplier_no + ] + | stats max(total_revenue) + ] +| sort s_suppkey +| fields s_suppkey, s_name, s_address, s_phone, total_revenue diff --git a/integ-test/src/integration/resources/tpch/q16.ppl b/integ-test/src/integration/resources/tpch/q16.ppl new file mode 100644 index 000000000..4c5765f04 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q16.ppl @@ -0,0 +1,45 @@ +/* +select + p_brand, + p_type, + p_size, + count(distinct ps_suppkey) as supplier_cnt +from + partsupp, + part +where + p_partkey = ps_partkey + and p_brand <> 'Brand#45' + and p_type not like 'MEDIUM POLISHED%' + and p_size in (49, 14, 23, 45, 19, 3, 36, 9) + and ps_suppkey not in ( + select + s_suppkey + from + supplier + where + s_comment like '%Customer%Complaints%' + ) +group by + p_brand, + p_type, + p_size +order by + supplier_cnt desc, + p_brand, + p_type, + p_size +*/ + +source = partsupp +| join ON p_partkey = ps_partkey part +| where p_brand != 'Brand#45' + and not like(p_type, 'MEDIUM POLISHED%') + and p_size in (49, 14, 23, 45, 19, 3, 36, 9) + and ps_suppkey not in [ + source = supplier + | where like(s_comment, '%Customer%Complaints%') + | fields s_suppkey + ] +| stats distinct_count(ps_suppkey) as supplier_cnt by p_brand, p_type, p_size +| sort - supplier_cnt, p_brand, p_type, p_size \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q17.ppl b/integ-test/src/integration/resources/tpch/q17.ppl new file mode 100644 index 000000000..994b7ee18 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q17.ppl @@ -0,0 +1,34 @@ +/* +select + sum(l_extendedprice) / 7.0 as avg_yearly +from + lineitem, + part +where + p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container = 'MED BOX' + and l_quantity < ( + select + 0.2 * avg(l_quantity) + from + lineitem + where + l_partkey = p_partkey + ) +*/ + +source = lineitem +| join ON p_partkey = l_partkey part +| where p_brand = 'Brand#23' + and p_container = 'MED BOX' + and l_quantity < [ + source = lineitem + | where l_partkey = p_partkey + | stats avg(l_quantity) as avg + | eval `0.2 * avg` = 0.2 * avg // Stats and Eval commands can combine when issues/819 resolved + | fields `0.2 * avg` + ] +| stats sum(l_extendedprice) as sum +| eval avg_yearly = sum / 7.0 // Stats and Eval commands can combine when issues/819 resolved +| fields avg_yearly \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q18.ppl b/integ-test/src/integration/resources/tpch/q18.ppl new file mode 100644 index 000000000..1dab3d473 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q18.ppl @@ -0,0 +1,48 @@ +/* +select + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice, + sum(l_quantity) +from + customer, + orders, + lineitem +where + o_orderkey in ( + select + l_orderkey + from + lineitem + group by + l_orderkey having + sum(l_quantity) > 300 + ) + and c_custkey = o_custkey + and o_orderkey = l_orderkey +group by + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice +order by + o_totalprice desc, + o_orderdate +limit 100 +*/ + +source = customer +| join ON c_custkey = o_custkey orders +| join ON o_orderkey = l_orderkey lineitem +| where o_orderkey in [ + source = lineitem + | stats 
sum(l_quantity) as sum by l_orderkey + | where sum > 300 + | fields l_orderkey + ] +| stats sum(l_quantity) by c_name, c_custkey, o_orderkey, o_orderdate, o_totalprice +| sort - o_totalprice, o_orderdate +| head 100 \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q19.ppl b/integ-test/src/integration/resources/tpch/q19.ppl new file mode 100644 index 000000000..63312d2f0 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q19.ppl @@ -0,0 +1,66 @@ +/* +select + sum(l_extendedprice* (1 - l_discount)) as revenue +from + lineitem, + part +where + ( + p_partkey = l_partkey + and p_brand = 'Brand#12' + and p_container in ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') + and l_quantity >= 1 and l_quantity <= 1 + 10 + and p_size between 1 and 5 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) + or + ( + p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container in ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') + and l_quantity >= 10 and l_quantity <= 10 + 10 + and p_size between 1 and 10 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) + or + ( + p_partkey = l_partkey + and p_brand = 'Brand#34' + and p_container in ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') + and l_quantity >= 20 and l_quantity <= 20 + 10 + and p_size between 1 and 15 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) +*/ + +source = lineitem +| join ON + ( + p_partkey = l_partkey + and p_brand = 'Brand#12' + and p_container in ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') + and l_quantity >= 1 and l_quantity <= 1 + 10 + and p_size between 1 and 5 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) OR ( + p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container in ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') + and l_quantity >= 10 and l_quantity <= 10 + 10 + and p_size between 1 and 10 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) OR ( + p_partkey = l_partkey + and p_brand = 'Brand#34' + and p_container in ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') + and l_quantity >= 20 and l_quantity <= 20 + 10 + and p_size between 1 and 15 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) + part \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q2.ppl b/integ-test/src/integration/resources/tpch/q2.ppl new file mode 100644 index 000000000..aa95d9d14 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q2.ppl @@ -0,0 +1,62 @@ +/* +select + s_acctbal, + s_name, + n_name, + p_partkey, + p_mfgr, + s_address, + s_phone, + s_comment +from + part, + supplier, + partsupp, + nation, + region +where + p_partkey = ps_partkey + and s_suppkey = ps_suppkey + and p_size = 15 + and p_type like '%BRASS' + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'EUROPE' + and ps_supplycost = ( + select + min(ps_supplycost) + from + partsupp, + supplier, + nation, + region + where + p_partkey = ps_partkey + and s_suppkey = ps_suppkey + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'EUROPE' + ) +order by + s_acctbal desc, + n_name, + s_name, + p_partkey +limit 100 +*/ + +source = part +| join ON p_partkey = ps_partkey partsupp +| join ON s_suppkey = ps_suppkey supplier +| join ON s_nationkey = n_nationkey nation +| join ON n_regionkey = r_regionkey region +| where p_size = 15 AND like(p_type, '%BRASS') AND 
r_name = 'EUROPE' AND ps_supplycost = [ + source = partsupp + | join ON s_suppkey = ps_suppkey supplier + | join ON s_nationkey = n_nationkey nation + | join ON n_regionkey = r_regionkey region + | where r_name = 'EUROPE' + | stats MIN(ps_supplycost) + ] +| sort - s_acctbal, n_name, s_name, p_partkey +| head 100 \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q20.ppl b/integ-test/src/integration/resources/tpch/q20.ppl new file mode 100644 index 000000000..08bd21277 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q20.ppl @@ -0,0 +1,62 @@ +/* +select + s_name, + s_address +from + supplier, + nation +where + s_suppkey in ( + select + ps_suppkey + from + partsupp + where + ps_partkey in ( + select + p_partkey + from + part + where + p_name like 'forest%' + ) + and ps_availqty > ( + select + 0.5 * sum(l_quantity) + from + lineitem + where + l_partkey = ps_partkey + and l_suppkey = ps_suppkey + and l_shipdate >= date '1994-01-01' + and l_shipdate < date '1994-01-01' + interval '1' year + ) + ) + and s_nationkey = n_nationkey + and n_name = 'CANADA' +order by + s_name +*/ + +source = supplier +| join ON s_nationkey = n_nationkey nation +| where n_name = 'CANADA' + and s_suppkey in [ + source = partsupp + | where ps_partkey in [ + source = part + | where like(p_name, 'forest%') + | fields p_partkey + ] + and ps_availqty > [ + source = lineitem + | where l_partkey = ps_partkey + and l_suppkey = ps_suppkey + and l_shipdate >= date('1994-01-01') + and l_shipdate < date_add(date('1994-01-01'), interval 1 year) + | stats sum(l_quantity) as sum_l_quantity + | eval half_sum_l_quantity = 0.5 * sum_l_quantity // Stats and Eval commands can combine when issues/819 resolved + | fields half_sum_l_quantity + ] + | fields ps_suppkey + ] \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q21.ppl b/integ-test/src/integration/resources/tpch/q21.ppl new file mode 100644 index 000000000..0eb7149f6 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q21.ppl @@ -0,0 +1,64 @@ +/* +select + s_name, + count(*) as numwait +from + supplier, + lineitem l1, + orders, + nation +where + s_suppkey = l1.l_suppkey + and o_orderkey = l1.l_orderkey + and o_orderstatus = 'F' + and l1.l_receiptdate > l1.l_commitdate + and exists ( + select + * + from + lineitem l2 + where + l2.l_orderkey = l1.l_orderkey + and l2.l_suppkey <> l1.l_suppkey + ) + and not exists ( + select + * + from + lineitem l3 + where + l3.l_orderkey = l1.l_orderkey + and l3.l_suppkey <> l1.l_suppkey + and l3.l_receiptdate > l3.l_commitdate + ) + and s_nationkey = n_nationkey + and n_name = 'SAUDI ARABIA' +group by + s_name +order by + numwait desc, + s_name +limit 100 +*/ + +source = supplier +| join ON s_suppkey = l1.l_suppkey lineitem as l1 +| join ON o_orderkey = l1.l_orderkey orders +| join ON s_nationkey = n_nationkey nation +| where o_orderstatus = 'F' + and l1.l_receiptdate > l1.l_commitdate + and exists [ + source = lineitem as l2 + | where l2.l_orderkey = l1.l_orderkey + and l2.l_suppkey != l1.l_suppkey + ] + and not exists [ + source = lineitem as l3 + | where l3.l_orderkey = l1.l_orderkey + and l3.l_suppkey != l1.l_suppkey + and l3.l_receiptdate > l3.l_commitdate + ] + and n_name = 'SAUDI ARABIA' +| stats count() as numwait by s_name +| sort - numwait, s_name +| head 100 \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q22.ppl b/integ-test/src/integration/resources/tpch/q22.ppl new file mode 100644 index 000000000..811308cb0 --- 
/dev/null +++ b/integ-test/src/integration/resources/tpch/q22.ppl @@ -0,0 +1,58 @@ +/* +select + cntrycode, + count(*) as numcust, + sum(c_acctbal) as totacctbal +from + ( + select + substring(c_phone, 1, 2) as cntrycode, + c_acctbal + from + customer + where + substring(c_phone, 1, 2) in + ('13', '31', '23', '29', '30', '18', '17') + and c_acctbal > ( + select + avg(c_acctbal) + from + customer + where + c_acctbal > 0.00 + and substring(c_phone, 1, 2) in + ('13', '31', '23', '29', '30', '18', '17') + ) + and not exists ( + select + * + from + orders + where + o_custkey = c_custkey + ) + ) as custsale +group by + cntrycode +order by + cntrycode +*/ + +source = [ + source = customer + | where substring(c_phone, 1, 2) in ('13', '31', '23', '29', '30', '18', '17') + and c_acctbal > [ + source = customer + | where c_acctbal > 0.00 + and substring(c_phone, 1, 2) in ('13', '31', '23', '29', '30', '18', '17') + | stats avg(c_acctbal) + ] + and not exists [ + source = orders + | where o_custkey = c_custkey + ] + | eval cntrycode = substring(c_phone, 1, 2) + | fields cntrycode, c_acctbal + ] as custsale +| stats count() as numcust, sum(c_acctbal) as totacctbal by cntrycode +| sort cntrycode \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q3.ppl b/integ-test/src/integration/resources/tpch/q3.ppl new file mode 100644 index 000000000..0ece358ab --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q3.ppl @@ -0,0 +1,33 @@ +/* +select + l_orderkey, + sum(l_extendedprice * (1 - l_discount)) as revenue, + o_orderdate, + o_shippriority +from + customer, + orders, + lineitem +where + c_mktsegment = 'BUILDING' + and c_custkey = o_custkey + and l_orderkey = o_orderkey + and o_orderdate < date '1995-03-15' + and l_shipdate > date '1995-03-15' +group by + l_orderkey, + o_orderdate, + o_shippriority +order by + revenue desc, + o_orderdate +limit 10 +*/ + +source = customer +| join ON c_custkey = o_custkey orders +| join ON l_orderkey = o_orderkey lineitem +| where c_mktsegment = 'BUILDING' AND o_orderdate < date('1995-03-15') AND l_shipdate > date('1995-03-15') +| stats sum(l_extendedprice * (1 - l_discount)) as revenue by l_orderkey, o_orderdate, o_shippriority +| sort - revenue, o_orderdate +| head 10 \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q4.ppl b/integ-test/src/integration/resources/tpch/q4.ppl new file mode 100644 index 000000000..cc01bda7d --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q4.ppl @@ -0,0 +1,33 @@ +/* +select + o_orderpriority, + count(*) as order_count +from + orders +where + o_orderdate >= date '1993-07-01' + and o_orderdate < date '1993-07-01' + interval '3' month + and exists ( + select + * + from + lineitem + where + l_orderkey = o_orderkey + and l_commitdate < l_receiptdate + ) +group by + o_orderpriority +order by + o_orderpriority +*/ + +source = orders +| where o_orderdate >= date('1993-07-01') + and o_orderdate < date_add(date('1993-07-01'), interval 3 month) + and exists [ + source = lineitem + | where l_orderkey = o_orderkey and l_commitdate < l_receiptdate + ] +| stats count() as order_count by o_orderpriority +| sort o_orderpriority \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q5.ppl b/integ-test/src/integration/resources/tpch/q5.ppl new file mode 100644 index 000000000..4761b0365 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q5.ppl @@ -0,0 +1,36 @@ +/* +select + n_name, + sum(l_extendedprice * (1 - l_discount)) as revenue 
+from + customer, + orders, + lineitem, + supplier, + nation, + region +where + c_custkey = o_custkey + and l_orderkey = o_orderkey + and l_suppkey = s_suppkey + and c_nationkey = s_nationkey + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'ASIA' + and o_orderdate >= date '1994-01-01' + and o_orderdate < date '1994-01-01' + interval '1' year +group by + n_name +order by + revenue desc +*/ + +source = customer +| join ON c_custkey = o_custkey orders +| join ON l_orderkey = o_orderkey lineitem +| join ON l_suppkey = s_suppkey AND c_nationkey = s_nationkey supplier +| join ON s_nationkey = n_nationkey nation +| join ON n_regionkey = r_regionkey region +| where r_name = 'ASIA' AND o_orderdate >= date('1994-01-01') AND o_orderdate < date_add(date('1994-01-01'), interval 1 year) +| stats sum(l_extendedprice * (1 - l_discount)) as revenue by n_name +| sort - revenue \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q6.ppl b/integ-test/src/integration/resources/tpch/q6.ppl new file mode 100644 index 000000000..6a77877c3 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q6.ppl @@ -0,0 +1,18 @@ +/* +select + sum(l_extendedprice * l_discount) as revenue +from + lineitem +where + l_shipdate >= date '1994-01-01' + and l_shipdate < date '1994-01-01' + interval '1' year + and l_discount between .06 - 0.01 and .06 + 0.01 + and l_quantity < 24 +*/ + +source = lineitem +| where l_shipdate >= date('1994-01-01') + and l_shipdate < adddate(date('1994-01-01'), 365) + and l_discount between .06 - 0.01 and .06 + 0.01 + and l_quantity < 24 +| stats sum(l_extendedprice * l_discount) as revenue \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q7.ppl b/integ-test/src/integration/resources/tpch/q7.ppl new file mode 100644 index 000000000..a6ea66d63 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q7.ppl @@ -0,0 +1,56 @@ +/* +select + supp_nation, + cust_nation, + l_year, + sum(volume) as revenue +from + ( + select + n1.n_name as supp_nation, + n2.n_name as cust_nation, + year(l_shipdate) as l_year, + l_extendedprice * (1 - l_discount) as volume + from + supplier, + lineitem, + orders, + customer, + nation n1, + nation n2 + where + s_suppkey = l_suppkey + and o_orderkey = l_orderkey + and c_custkey = o_custkey + and s_nationkey = n1.n_nationkey + and c_nationkey = n2.n_nationkey + and ( + (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY') + or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE') + ) + and l_shipdate between date '1995-01-01' and date '1996-12-31' + ) as shipping +group by + supp_nation, + cust_nation, + l_year +order by + supp_nation, + cust_nation, + l_year +*/ + +source = [ + source = supplier + | join ON s_suppkey = l_suppkey lineitem + | join ON o_orderkey = l_orderkey orders + | join ON c_custkey = o_custkey customer + | join ON s_nationkey = n1.n_nationkey nation as n1 + | join ON c_nationkey = n2.n_nationkey nation as n2 + | where l_shipdate between date('1995-01-01') and date('1996-12-31') + and ((n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY') or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE')) + | eval supp_nation = n1.n_name, cust_nation = n2.n_name, l_year = year(l_shipdate), volume = l_extendedprice * (1 - l_discount) + | fields supp_nation, cust_nation, l_year, volume + ] as shipping +| stats sum(volume) as revenue by supp_nation, cust_nation, l_year +| sort supp_nation, cust_nation, l_year \ No newline at end of file diff --git 
a/integ-test/src/integration/resources/tpch/q8.ppl b/integ-test/src/integration/resources/tpch/q8.ppl new file mode 100644 index 000000000..a73c7f7c3 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q8.ppl @@ -0,0 +1,60 @@ +/* +select + o_year, + sum(case + when nation = 'BRAZIL' then volume + else 0 + end) / sum(volume) as mkt_share +from + ( + select + year(o_orderdate) as o_year, + l_extendedprice * (1 - l_discount) as volume, + n2.n_name as nation + from + part, + supplier, + lineitem, + orders, + customer, + nation n1, + nation n2, + region + where + p_partkey = l_partkey + and s_suppkey = l_suppkey + and l_orderkey = o_orderkey + and o_custkey = c_custkey + and c_nationkey = n1.n_nationkey + and n1.n_regionkey = r_regionkey + and r_name = 'AMERICA' + and s_nationkey = n2.n_nationkey + and o_orderdate between date '1995-01-01' and date '1996-12-31' + and p_type = 'ECONOMY ANODIZED STEEL' + ) as all_nations +group by + o_year +order by + o_year +*/ + +source = [ + source = part + | join ON p_partkey = l_partkey lineitem + | join ON s_suppkey = l_suppkey supplier + | join ON l_orderkey = o_orderkey orders + | join ON o_custkey = c_custkey customer + | join ON c_nationkey = n1.n_nationkey nation as n1 + | join ON s_nationkey = n2.n_nationkey nation as n2 + | join ON n1.n_regionkey = r_regionkey region + | where r_name = 'AMERICA' AND p_type = 'ECONOMY ANODIZED STEEL' + and o_orderdate between date('1995-01-01') and date('1996-12-31') + | eval o_year = year(o_orderdate) + | eval volume = l_extendedprice * (1 - l_discount) + | eval nation = n2.n_name + | fields o_year, volume, nation + ] as all_nations +| stats sum(case(nation = 'BRAZIL', volume else 0)) as sum_case, sum(volume) as sum_volume by o_year +| eval mkt_share = sum_case / sum_volume +| fields mkt_share, o_year +| sort o_year \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q9.ppl b/integ-test/src/integration/resources/tpch/q9.ppl new file mode 100644 index 000000000..7692afd74 --- /dev/null +++ b/integ-test/src/integration/resources/tpch/q9.ppl @@ -0,0 +1,50 @@ +/* +select + nation, + o_year, + sum(amount) as sum_profit +from + ( + select + n_name as nation, + year(o_orderdate) as o_year, + l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity as amount + from + part, + supplier, + lineitem, + partsupp, + orders, + nation + where + s_suppkey = l_suppkey + and ps_suppkey = l_suppkey + and ps_partkey = l_partkey + and p_partkey = l_partkey + and o_orderkey = l_orderkey + and s_nationkey = n_nationkey + and p_name like '%green%' + ) as profit +group by + nation, + o_year +order by + nation, + o_year desc +*/ + +source = [ + source = part + | join ON p_partkey = l_partkey lineitem + | join ON s_suppkey = l_suppkey supplier + | join ON ps_partkey = l_partkey and ps_suppkey = l_suppkey partsupp + | join ON o_orderkey = l_orderkey orders + | join ON s_nationkey = n_nationkey nation + | where like(p_name, '%green%') + | eval nation = n_name + | eval o_year = year(o_orderdate) + | eval amount = l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity + | fields nation, o_year, amount + ] as profit +| stats sum(amount) as sum_profit by nation, o_year +| sort nation, - o_year \ No newline at end of file diff --git a/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala b/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala index 1ddfa540b..51bcf8e40 100644 --- 
a/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala +++ b/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala @@ -17,15 +17,49 @@ import org.opensearch.OpenSearchStatusException import org.opensearch.flint.OpenSearchSuite import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession} import org.opensearch.flint.core.{FlintClient, FlintOptions} -import org.opensearch.flint.core.storage.{FlintOpenSearchClient, FlintReader, OpenSearchUpdater} -import org.opensearch.search.sort.SortOrder +import org.opensearch.flint.core.storage.{FlintOpenSearchClient, OpenSearchUpdater} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.FlintREPLConfConstants.DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY -import org.apache.spark.sql.flint.config.FlintSparkConf.{DATA_SOURCE_NAME, EXCLUDE_JOB_IDS, HOST_ENDPOINT, HOST_PORT, JOB_TYPE, REFRESH_POLICY, REPL_INACTIVITY_TIMEOUT_MILLIS, REQUEST_INDEX, SESSION_ID} +import org.apache.spark.sql.exception.UnrecoverableException +import org.apache.spark.sql.flint.config.FlintSparkConf.{CUSTOM_STATEMENT_MANAGER, DATA_SOURCE_NAME, EXCLUDE_JOB_IDS, HOST_ENDPOINT, HOST_PORT, JOB_TYPE, REFRESH_POLICY, REPL_INACTIVITY_TIMEOUT_MILLIS, REQUEST_INDEX, SESSION_ID} import org.apache.spark.sql.util.MockEnvironment import org.apache.spark.util.ThreadUtils +/** + * A StatementExecutionManagerImpl that throws UnrecoverableException during statement execution. + * Used for testing error handling in FlintREPL. + */ +class FailingStatementExecutionManager( + private var spark: SparkSession, + private var sessionId: String) + extends StatementExecutionManager { + + def this() = { + this(null, null) + } + + override def prepareStatementExecution(): Either[String, Unit] = { + throw UnrecoverableException(new RuntimeException("Simulated execution failure")) + } + + override def executeStatement(statement: FlintStatement): DataFrame = { + throw UnrecoverableException(new RuntimeException("Simulated execution failure")) + } + + override def getNextStatement(): Option[FlintStatement] = { + throw UnrecoverableException(new RuntimeException("Simulated execution failure")) + } + + override def updateStatement(statement: FlintStatement): Unit = { + throw UnrecoverableException(new RuntimeException("Simulated execution failure")) + } + + override def terminateStatementExecution(): Unit = { + throw UnrecoverableException(new RuntimeException("Simulated execution failure")) + } +} + class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { var flintClient: FlintClient = _ @@ -584,6 +618,27 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { } } + test("REPL should handle unrecoverable exception from statement execution") { + // Note: This test sharing system property with other test cases so cannot run alone + System.setProperty( + CUSTOM_STATEMENT_MANAGER.key, + "org.apache.spark.sql.FailingStatementExecutionManager") + try { + createSession(jobRunId, "") + FlintREPL.main(Array(resultIndex)) + fail("The REPL should throw an unrecoverable exception, but it succeeded instead.") + } catch { + case ex: UnrecoverableException => + assert( + ex.getMessage.contains("Simulated execution failure"), + s"Unexpected exception message: ${ex.getMessage}") + case ex: Throwable => + fail(s"Unexpected exception type: ${ex.getClass} with message: ${ex.getMessage}") + } finally { + System.setProperty(CUSTOM_STATEMENT_MANAGER.key, "") + } + } + /** * JSON does not support raw newlines (\n) in 
string values. All newlines must be escaped or * removed when inside a JSON string. The same goes for tab characters, which should be diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala index fc77faaea..cf0347820 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala @@ -18,7 +18,7 @@ import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.storage.{FlintOpenSearchIndexMetadataService, OpenSearchClientUtils} import org.opensearch.flint.spark.FlintSparkIndex.quotedTableName import org.opensearch.flint.spark.mv.FlintSparkMaterializedView -import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.{extractSourceTableNames, getFlintIndexName} +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.{extractSourceTablesFromQuery, getFlintIndexName, getSourceTablesFromMetadata, MV_INDEX_TYPE} import org.opensearch.flint.spark.scheduler.OpenSearchAsyncQueryScheduler import org.scalatest.matchers.must.Matchers._ import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper @@ -65,14 +65,76 @@ class FlintSparkMaterializedViewITSuite extends FlintSparkSuite { | FROM spark_catalog.default.`table/3` | INNER JOIN spark_catalog.default.`table.4` |""".stripMargin - extractSourceTableNames(flint.spark, testComplexQuery) should contain theSameElementsAs + extractSourceTablesFromQuery(flint.spark, testComplexQuery) should contain theSameElementsAs Array( "spark_catalog.default.table1", "spark_catalog.default.table2", "spark_catalog.default.`table/3`", "spark_catalog.default.`table.4`") - extractSourceTableNames(flint.spark, "SELECT 1") should have size 0 + extractSourceTablesFromQuery(flint.spark, "SELECT 1") should have size 0 + } + + test("get source table names from index metadata successfully") { + val mv = FlintSparkMaterializedView( + "spark_catalog.default.mv", + s"SELECT 1 FROM $testTable", + Array(testTable), + Map("1" -> "integer")) + val metadata = mv.metadata() + getSourceTablesFromMetadata(metadata) should contain theSameElementsAs Array(testTable) + } + + test("get source table names from deserialized metadata successfully") { + val metadata = FlintOpenSearchIndexMetadataService.deserialize(s""" { + | "_meta": { + | "kind": "$MV_INDEX_TYPE", + | "properties": { + | "sourceTables": [ + | "$testTable" + | ] + | } + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin) + getSourceTablesFromMetadata(metadata) should contain theSameElementsAs Array(testTable) + } + + test("get empty source tables from invalid field in metadata") { + val metadataWrongType = FlintOpenSearchIndexMetadataService.deserialize(s""" { + | "_meta": { + | "kind": "$MV_INDEX_TYPE", + | "properties": { + | "sourceTables": "$testTable" + | } + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin) + val metadataMissingField = FlintOpenSearchIndexMetadataService.deserialize(s""" { + | "_meta": { + | "kind": "$MV_INDEX_TYPE", + | "properties": { } + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin) + + getSourceTablesFromMetadata(metadataWrongType) shouldBe empty + getSourceTablesFromMetadata(metadataMissingField) shouldBe empty } 
test("create materialized view with metadata successfully") { diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala index 9e75078d2..ae2e53090 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala @@ -448,5 +448,80 @@ class FlintSparkMaterializedViewSqlITSuite extends FlintSparkSuite { } } + test("tumble function should raise error for non-simple time column") { + val httpLogs = s"$catalogName.default.mv_test_tumble" + withTable(httpLogs) { + createTableHttpLog(httpLogs) + + withTempDir { checkpointDir => + val ex = the[IllegalStateException] thrownBy { + sql(s""" + | CREATE MATERIALIZED VIEW `$catalogName`.`default`.`mv_test_metrics` + | AS + | SELECT + | window.start AS startTime, + | COUNT(*) AS count + | FROM $httpLogs + | GROUP BY + | TUMBLE(CAST(timestamp AS TIMESTAMP), '10 Minute') + | WITH ( + | auto_refresh = true, + | checkpoint_location = '${checkpointDir.getAbsolutePath}', + | watermark_delay = '1 Second' + | ) + |""".stripMargin) + } + ex.getCause should have message + "Tumble function only supports simple timestamp column, but found: cast('timestamp as timestamp)" + } + } + } + + test("tumble function should succeed with casted time column within subquery") { + val httpLogs = s"$catalogName.default.mv_test_tumble" + withTable(httpLogs) { + createTableHttpLog(httpLogs) + + withTempDir { checkpointDir => + sql(s""" + | CREATE MATERIALIZED VIEW `$catalogName`.`default`.`mv_test_metrics` + | AS + | SELECT + | window.start AS startTime, + | COUNT(*) AS count + | FROM ( + | SELECT CAST(timestamp AS TIMESTAMP) AS time + | FROM $httpLogs + | ) + | GROUP BY + | TUMBLE(time, '10 Minute') + | WITH ( + | auto_refresh = true, + | checkpoint_location = '${checkpointDir.getAbsolutePath}', + | watermark_delay = '1 Second' + | ) + |""".stripMargin) + + // Wait for streaming job complete current micro batch + val job = spark.streams.active.find(_.name == testFlintIndex) + job shouldBe defined + failAfter(streamingTimeout) { + job.get.processAllAvailable() + } + + checkAnswer( + flint.queryIndex(testFlintIndex).select("startTime", "count"), + Seq( + Row(timestamp("2023-10-01 10:00:00"), 2), + Row(timestamp("2023-10-01 10:10:00"), 2) + /* + * The last row is pending to fire upon watermark + * Row(timestamp("2023-10-01 10:20:00"), 2) + */ + )) + } + } + } + private def timestamp(ts: String): Timestamp = Timestamp.valueOf(ts) } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala index c53eee548..68d370791 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala @@ -559,6 +559,28 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit |""".stripMargin) } + protected def createMultiColumnArrayTable(testTable: String): Unit = { + // CSV doesn't support struct field + sql(s""" + | CREATE TABLE $testTable + | ( + | int_col INT, + | multi_valueA Array>, + | multi_valueB Array> + | ) + | USING JSON + |""".stripMargin) + + sql(s""" + | INSERT INTO $testTable + | VALUES + | ( 1, 
array(STRUCT("1_one", 1), STRUCT(null, 11), STRUCT("1_three", null)), array(STRUCT("2_Monday", 2), null) ), + | ( 2, array(STRUCT("2_Monday", 2), null) , array(STRUCT("3_third", 3), STRUCT("3_4th", 4)) ), + | ( 3, array(STRUCT("3_third", 3), STRUCT("3_4th", 4)) , array(STRUCT("1_one", 1))), + | ( 4, null, array(STRUCT("1_one", 1))) + |""".stripMargin) + } + protected def createTableIssue112(testTable: String): Unit = { sql(s""" | CREATE TABLE $testTable ( diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriterITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriterITSuite.scala index c0d253fd3..5b4dd0208 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriterITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriterITSuite.scala @@ -18,6 +18,7 @@ import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.storage.{FlintOpenSearchClient, FlintOpenSearchIndexMetadataService} import org.opensearch.flint.spark.{FlintSparkIndexOptions, FlintSparkSuite} import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.COVERING_INDEX_TYPE +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, SKIPPING_INDEX_TYPE} import org.scalatest.Entry @@ -161,12 +162,29 @@ class FlintOpenSearchMetadataCacheWriterITSuite extends FlintSparkSuite with Mat val properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties properties .get("sourceTables") - .asInstanceOf[List[String]] - .toArray should contain theSameElementsAs Array(testTable) + .asInstanceOf[java.util.ArrayList[String]] should contain theSameElementsAs Array( + testTable) } } - test(s"write metadata cache to materialized view index mappings with source tables") { + test("write metadata cache with source tables from index metadata") { + val mv = FlintSparkMaterializedView( + "spark_catalog.default.mv", + s"SELECT 1 FROM $testTable", + Array(testTable), + Map("1" -> "integer")) + val metadata = mv.metadata().copy(latestLogEntry = Some(flintMetadataLogEntry)) + + flintClient.createIndex(testFlintIndex, metadata) + flintMetadataCacheWriter.updateMetadataCache(testFlintIndex, metadata) + + val properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties + properties + .get("sourceTables") + .asInstanceOf[java.util.ArrayList[String]] should contain theSameElementsAs Array(testTable) + } + + test("write metadata cache with source tables from deserialized metadata") { val testTable2 = "spark_catalog.default.metadatacache_test2" val content = s""" { @@ -194,8 +212,9 @@ class FlintOpenSearchMetadataCacheWriterITSuite extends FlintSparkSuite with Mat val properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties properties .get("sourceTables") - .asInstanceOf[List[String]] - .toArray should contain theSameElementsAs Array(testTable, testTable2) + .asInstanceOf[java.util.ArrayList[String]] should contain theSameElementsAs Array( + testTable, + testTable2) } test("write metadata cache to index mappings with refresh interval") { diff --git 
a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala index 0bebca9b0..aa96d0991 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala @@ -494,4 +494,43 @@ class FlintSparkPPLAggregationWithSpanITSuite // Compare the two plans comparePlans(expectedPlan, logicalPlan, false) } + + test( + "create ppl simple distinct count age by span of interval of 10 years query with state filter test using approximation") { + val frame = sql(s""" + | source = $testTable | where state != 'Quebec' | stats distinct_count_approx(age) by span(age, 10) as age_span + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(1, 70L), Row(1, 30L), Row(1, 20L)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](1)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val stateField = UnresolvedAttribute("state") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(ageField), isDistinct = true), + "distinct_count_approx(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val filterExpr = Not(EqualTo(stateField, Literal("Quebec"))) + val filterPlan = Filter(filterExpr, table) + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), filterPlan) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + comparePlans(expectedPlan, logicalPlan, false) + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala index bcfe22764..2275c775c 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala @@ -835,6 +835,43 @@ class FlintSparkPPLAggregationsITSuite comparePlans(expectedPlan, logicalPlan, false) } + test("create ppl simple country distinct_count using approximation ") { + val frame = sql(s""" + | source = $testTable| stats distinct_count_approx(country) + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + + // Define the expected results + val expectedResults: Array[Row] = Array(Row(2L)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) + assert( + results.sorted.sameElements(expectedResults.sorted), + s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}") + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical 
plan + val star = Seq(UnresolvedStar(None)) + val countryField = UnresolvedAttribute("country") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(countryField), isDistinct = true), + "distinct_count_approx(country)")() + + val aggregatePlan = + Aggregate(Seq.empty, Seq(aggregateExpressions), table) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + comparePlans(expectedPlan, logicalPlan, false) + } + test("create ppl simple age distinct_count group by country query test with sort") { val frame = sql(s""" | source = $testTable | stats distinct_count(age) by country | sort country @@ -881,6 +918,53 @@ class FlintSparkPPLAggregationsITSuite s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}") } + test( + "create ppl simple age distinct_count group by country query test with sort using approximation") { + val frame = sql(s""" + | source = $testTable | stats distinct_count_approx(age) by country | sort country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(2L, "Canada"), Row(2L, "USA")) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) + assert( + results.sorted.sameElements(expectedResults.sorted), + s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}") + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(ageField), isDistinct = true), + "distinct_count_approx(age)")() + val productAlias = Alias(countryField, "country")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("country"), Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + + // Compare the two plans + assert( + compareByString(expectedPlan) === compareByString(logicalPlan), + s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}") + } + test("create ppl simple age distinct_count group by country with state filter query test") { val frame = sql(s""" | source = $testTable | where state != 'Ontario' | stats distinct_count(age) by country @@ -920,6 +1004,46 @@ class FlintSparkPPLAggregationsITSuite assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } + test( + "create ppl simple age distinct_count group by country with state filter query test using approximation") { + val frame = sql(s""" + | source = $testTable | where state != 'Ontario' | stats distinct_count_approx(age) by country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(1L, "Canada"), Row(2L, "USA")) + + // Compare the results + implicit val rowOrdering: 
Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val stateField = UnresolvedAttribute("state") + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val filterExpr = Not(EqualTo(stateField, Literal("Ontario"))) + val filterPlan = Filter(filterExpr, table) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(ageField), isDistinct = true), + "distinct_count_approx(age)")() + val productAlias = Alias(countryField, "country")() + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + test("two-level stats") { val frame = sql(s""" | source = $testTable| stats avg(age) as avg_age by state, country | stats avg(avg_age) as avg_state_age by country diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala index cbc4308b0..300b44b5a 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala @@ -541,11 +541,6 @@ class FlintSparkPPLBasicITSuite | """.stripMargin)) assert(ex.getMessage().contains("TABLE_OR_VIEW_NOT_FOUND")) } - val t7 = "spark_catalog.default.flint_ppl_test7.log" - val ex = intercept[IllegalArgumentException](sql(s""" - | source = $t7| head 2 - | """.stripMargin)) - assert(ex.getMessage().contains("Invalid table name")) } test("test describe backtick table names and name contains '.'") { @@ -564,11 +559,6 @@ class FlintSparkPPLBasicITSuite | """.stripMargin)) assert(ex.getMessage().contains("TABLE_OR_VIEW_NOT_FOUND")) } - val t7 = "spark_catalog.default.flint_ppl_test7.log" - val ex = intercept[IllegalArgumentException](sql(s""" - | describe $t7 - | """.stripMargin)) - assert(ex.getMessage().contains("Invalid table name")) } test("test explain backtick table names and name contains '.'") { @@ -590,11 +580,95 @@ class FlintSparkPPLBasicITSuite Project(Seq(UnresolvedStar(None)), relation), ExplainMode.fromString("extended")) comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + // TODO Do not support 4+ parts table identifier in future (may be reverted this PR in 0.8.0) + test("test table name with more than 3 parts") { val t7 = "spark_catalog.default.flint_ppl_test7.log" - val ex = intercept[IllegalArgumentException](sql(s""" - | explain extended | source = $t7 - | """.stripMargin)) - assert(ex.getMessage().contains("Invalid table name")) + val t4Parts = "`spark_catalog`.default.`startTime:1,endTime:2`.`this(is:['a/name'])`" + val t5Parts = + "`spark_catalog`.default.`startTime:1,endTime:2`.`this(is:['sub/name'])`.`this(is:['sub-sub/name'])`" + + Seq(t7, t4Parts, t5Parts).foreach { table => + val ex = intercept[AnalysisException](sql(s""" + | source = $table| head 2 + | """.stripMargin)) + // Expected 
since V2SessionCatalog only supports 3 parts + assert( + ex.getMessage() + .contains( + "[REQUIRES_SINGLE_PART_NAMESPACE] spark_catalog requires a single-part namespace")) + } + + Seq(t7, t4Parts, t5Parts).foreach { table => + val ex = intercept[AnalysisException](sql(s""" + | describe $table + | """.stripMargin)) + assert(ex.getMessage().contains("TABLE_OR_VIEW_NOT_FOUND")) + } + } + + test("Search multiple tables - translated into union call with fields") { + val frame = sql(s""" + | source = $t1, $t2 + | """.stripMargin) + assertSameRows( + Seq( + Row("Hello", 30, "New York", "USA", 2023, 4), + Row("Hello", 30, "New York", "USA", 2023, 4), + Row("Jake", 70, "California", "USA", 2023, 4), + Row("Jake", 70, "California", "USA", 2023, 4), + Row("Jane", 20, "Quebec", "Canada", 2023, 4), + Row("Jane", 20, "Quebec", "Canada", 2023, 4), + Row("John", 25, "Ontario", "Canada", 2023, 4), + Row("John", 25, "Ontario", "Canada", 2023, 4)), + frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + + val allFields1 = UnresolvedStar(None) + val allFields2 = UnresolvedStar(None) + + val projectedTable1 = Project(Seq(allFields1), table1) + val projectedTable2 = Project(Seq(allFields2), table2) + + val expectedPlan = + Union(Seq(projectedTable1, projectedTable2), byName = true, allowMissingCol = true) + + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("Search multiple tables - with table alias") { + val frame = sql(s""" + | source = $t1, $t2 as t | where t.country = "USA" + | """.stripMargin) + assertSameRows( + Seq( + Row("Hello", 30, "New York", "USA", 2023, 4), + Row("Hello", 30, "New York", "USA", 2023, 4), + Row("Jake", 70, "California", "USA", 2023, 4), + Row("Jake", 70, "California", "USA", 2023, 4)), + frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + + val plan1 = Filter( + EqualTo(UnresolvedAttribute("t.country"), Literal("USA")), + SubqueryAlias("t", table1)) + val plan2 = Filter( + EqualTo(UnresolvedAttribute("t.country"), Literal("USA")), + SubqueryAlias("t", table2)) + + val projectedTable1 = Project(Seq(UnresolvedStar(None)), plan1) + val projectedTable2 = Project(Seq(UnresolvedStar(None)), plan2) + + val expectedPlan = + Union(Seq(projectedTable1, projectedTable2), byName = true, allowMissingCol = true) + + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExpandITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExpandITSuite.scala new file mode 100644 index 000000000..f0404bf7b --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExpandITSuite.scala @@ -0,0 +1,255 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.flint.spark.ppl + +import java.nio.file.Files + +import scala.collection.mutable + +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import 
org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Explode, GeneratorOuter, Literal, Or} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLExpandITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + private val testTable = "flint_ppl_test" + private val occupationTable = "spark_catalog.default.flint_ppl_flat_table_test" + private val structNestedTable = "spark_catalog.default.flint_ppl_struct_nested_test" + private val structTable = "spark_catalog.default.flint_ppl_struct_test" + private val multiValueTable = "spark_catalog.default.flint_ppl_multi_value_test" + private val multiArraysTable = "spark_catalog.default.flint_ppl_multi_array_test" + private val tempFile = Files.createTempFile("jsonTestData", ".json") + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createNestedJsonContentTable(tempFile, testTable) + createOccupationTable(occupationTable) + createStructNestedTable(structNestedTable) + createStructTable(structTable) + createMultiValueStructTable(multiValueTable) + createMultiColumnArrayTable(multiArraysTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + override def afterAll(): Unit = { + super.afterAll() + Files.deleteIfExists(tempFile) + } + + test("expand for eval field of an array") { + val frame = sql( + s""" source = $occupationTable | eval array=json_array(1, 2, 3) | expand array as uid | fields name, occupation, uid + """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row("Jake", "Engineer", 1), + Row("Jake", "Engineer", 2), + Row("Jake", "Engineer", 3), + Row("Hello", "Artist", 1), + Row("Hello", "Artist", 2), + Row("Hello", "Artist", 3), + Row("John", "Doctor", 1), + Row("John", "Doctor", 2), + Row("John", "Doctor", 3), + Row("David", "Doctor", 1), + Row("David", "Doctor", 2), + Row("David", "Doctor", 3), + Row("David", "Unemployed", 1), + Row("David", "Unemployed", 2), + Row("David", "Unemployed", 3), + Row("Jane", "Scientist", 1), + Row("Jane", "Scientist", 2), + Row("Jane", "Scientist", 3)) + + // Compare the results + assert(results.toSet == expectedResults.toSet) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // expected plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_flat_table_test")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "array")() + val project = Project(seq(UnresolvedStar(None), aliasA), table) + val generate = Generate( + Explode(UnresolvedAttribute("array")), + seq(), + false, + None, + seq(UnresolvedAttribute("uid")), + project) + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("array")), generate) + val expectedPlan = Project( + seq( + UnresolvedAttribute("name"), + UnresolvedAttribute("occupation"), + UnresolvedAttribute("uid")), + dropSourceColumn) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("expand for structs") { + val frame = sql( + s""" source = $multiValueTable | expand multi_value AS exploded_multi_value | fields exploded_multi_value + """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(Row("1_one", 1)), + 
Row(Row(null, 11)), + Row(Row("1_three", null)), + Row(Row("2_Monday", 2)), + Row(null), + Row(Row("3_third", 3)), + Row(Row("3_4th", 4)), + Row(null)) + // Compare the results + assert(results.toSet == expectedResults.toSet) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // expected plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_multi_value_test")) + val generate = Generate( + Explode(UnresolvedAttribute("multi_value")), + seq(), + outer = false, + None, + seq(UnresolvedAttribute("exploded_multi_value")), + table) + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("multi_value")), generate) + val expectedPlan = Project(Seq(UnresolvedAttribute("exploded_multi_value")), dropSourceColumn) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("expand for array of structs") { + val frame = sql(s""" + | source = $testTable + | | where country = 'England' or country = 'Poland' + | | expand bridges + | | fields bridges + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(mutable.WrappedArray.make(Array(Row(801, "Tower Bridge"), Row(928, "London Bridge")))), + Row(mutable.WrappedArray.make(Array(Row(801, "Tower Bridge"), Row(928, "London Bridge")))) + // Row(null)) -> in case of outerGenerator = GeneratorOuter(Explode(UnresolvedAttribute("bridges"))) it will include the `null` row + ) + + // Compare the results + assert(results.toSet == expectedResults.toSet) + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("flint_ppl_test")) + val filter = Filter( + Or( + EqualTo(UnresolvedAttribute("country"), Literal("England")), + EqualTo(UnresolvedAttribute("country"), Literal("Poland"))), + table) + val generate = + Generate(Explode(UnresolvedAttribute("bridges")), seq(), outer = false, None, seq(), filter) + val expectedPlan = Project(Seq(UnresolvedAttribute("bridges")), generate) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("expand for array of structs with alias") { + val frame = sql(s""" + | source = $testTable + | | where country = 'England' + | | expand bridges as britishBridges + | | fields britishBridges + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(Row(801, "Tower Bridge")), + Row(Row(928, "London Bridge")), + Row(Row(801, "Tower Bridge")), + Row(Row(928, "London Bridge"))) + // Compare the results + assert(results.toSet == expectedResults.toSet) + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("flint_ppl_test")) + val filter = Filter(EqualTo(UnresolvedAttribute("country"), Literal("England")), table) + val generate = Generate( + Explode(UnresolvedAttribute("bridges")), + seq(), + outer = false, + None, + seq(UnresolvedAttribute("britishBridges")), + filter) + val dropColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("bridges")), generate) + val expectedPlan = Project(Seq(UnresolvedAttribute("britishBridges")), dropColumn) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("expand multi columns array table") { + val frame = sql(s""" + | source = $multiArraysTable + | | expand multi_valueA as multiA + | | expand multi_valueB as multiB + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(1, Row("1_one", 1), Row("2_Monday", 2)), + Row(1, 
Row("1_one", 1), null), + Row(1, Row(null, 11), Row("2_Monday", 2)), + Row(1, Row(null, 11), null), + Row(1, Row("1_three", null), Row("2_Monday", 2)), + Row(1, Row("1_three", null), null), + Row(2, Row("2_Monday", 2), Row("3_third", 3)), + Row(2, Row("2_Monday", 2), Row("3_4th", 4)), + Row(2, null, Row("3_third", 3)), + Row(2, null, Row("3_4th", 4)), + Row(3, Row("3_third", 3), Row("1_one", 1)), + Row(3, Row("3_4th", 4), Row("1_one", 1))) + // Compare the results + assert(results.toSet == expectedResults.toSet) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_multi_array_test")) + val generatorA = Explode(UnresolvedAttribute("multi_valueA")) + val generateA = + Generate(generatorA, seq(), false, None, seq(UnresolvedAttribute("multiA")), table) + val dropSourceColumnA = + DataFrameDropColumns(Seq(UnresolvedAttribute("multi_valueA")), generateA) + val generatorB = Explode(UnresolvedAttribute("multi_valueB")) + val generateB = Generate( + generatorB, + seq(), + false, + None, + seq(UnresolvedAttribute("multiB")), + dropSourceColumnA) + val dropSourceColumnB = + DataFrameDropColumns(Seq(UnresolvedAttribute("multi_valueB")), generateB) + val expectedPlan = Project(seq(UnresolvedStar(None)), dropSourceColumnB) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + + } +} diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala index f2d7ee844..62c735597 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala @@ -467,4 +467,96 @@ class FlintSparkPPLFiltersITSuite val expectedPlan = Project(Seq(UnresolvedAttribute("state")), filter) comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } + + test("test parenthesis in filter") { + val frame = sql(s""" + | source = $testTable | where country = 'Canada' or age > 60 and age < 25 | fields name, age, country + | """.stripMargin) + assertSameRows(Seq(Row("John", 25, "Canada"), Row("Jane", 20, "Canada")), frame) + + val frameWithParenthesis = sql(s""" + | source = $testTable | where (country = 'Canada' or age > 60) and age < 25 | fields name, age, country + | """.stripMargin) + assertSameRows(Seq(Row("Jane", 20, "Canada")), frameWithParenthesis) + + val logicalPlan: LogicalPlan = frameWithParenthesis.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filter = Filter( + And( + Or( + EqualTo(UnresolvedAttribute("country"), Literal("Canada")), + GreaterThan(UnresolvedAttribute("age"), Literal(60))), + LessThan(UnresolvedAttribute("age"), Literal(25))), + table) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("name"), + UnresolvedAttribute("age"), + UnresolvedAttribute("country")), + filter) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test complex and nested parenthesis in filter") { + val frame1 = sql(s""" + | source = $testTable | WHERE (age > 18 AND (state = 'California' OR state = 'New York')) + | """.stripMargin) + assertSameRows( + Seq( + Row("Hello", 30, "New York", "USA", 2023, 4), + Row("Jake", 70, "California", "USA", 2023, 4)), + frame1) + + val frame2 = sql(s""" + | source = $testTable | WHERE ((((age > 18) AND ((((state = 
'California') OR state = 'New York')))))) + | """.stripMargin) + assertSameRows( + Seq( + Row("Hello", 30, "New York", "USA", 2023, 4), + Row("Jake", 70, "California", "USA", 2023, 4)), + frame2) + + val frame3 = sql(s""" + | source = $testTable | WHERE (year = 2023 AND (month BETWEEN 1 AND 6)) AND (age >= 31 OR country = 'Canada') + | """.stripMargin) + assertSameRows( + Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4), + Row("Jake", 70, "California", "USA", 2023, 4), + Row("Jane", 20, "Quebec", "Canada", 2023, 4)), + frame3) + + val frame4 = sql(s""" + | source = $testTable | WHERE ((state = 'Texas' OR state = 'California') AND (age < 30 OR (country = 'USA' AND year > 2020))) + | """.stripMargin) + assertSameRows(Seq(Row("Jake", 70, "California", "USA", 2023, 4)), frame4) + + val frame5 = sql(s""" + | source = $testTable | WHERE (LIKE(LOWER(name), 'a%') OR LIKE(LOWER(name), 'j%')) AND (LENGTH(state) > 6 OR (country = 'USA' AND age > 18)) + | """.stripMargin) + assertSameRows( + Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4), + Row("Jake", 70, "California", "USA", 2023, 4)), + frame5) + + val frame6 = sql(s""" + | source = $testTable | WHERE (age BETWEEN 25 AND 40) AND ((state IN ('California', 'New York', 'Texas') AND year = 2023) OR (country != 'USA' AND (month = 1 OR month = 12))) + | """.stripMargin) + assertSameRows(Seq(Row("Hello", 30, "New York", "USA", 2023, 4)), frame6) + + val frame7 = sql(s""" + | source = $testTable | WHERE NOT (age < 18 OR (state = 'Alaska' AND year < 2020)) AND (country = 'USA' OR (country = 'Mexico' AND month BETWEEN 6 AND 8)) + | """.stripMargin) + assertSameRows( + Seq( + Row("Jake", 70, "California", "USA", 2023, 4), + Row("Hello", 30, "New York", "USA", 2023, 4)), + frame7) + + val frame8 = sql(s""" + | source = $testTable | WHERE (NOT (year < 2020 OR age < 18)) AND ((state = 'Texas' AND month % 2 = 0) OR (country = 'Mexico' AND (year = 2023 OR (year = 2022 AND month > 6)))) + | """.stripMargin) + assertSameRows(Seq(), frame8) + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJoinITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJoinITSuite.scala index 00e55d50a..3127325c8 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJoinITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJoinITSuite.scala @@ -5,7 +5,7 @@ package org.opensearch.flint.spark.ppl -import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Divide, EqualTo, Floor, GreaterThan, LessThan, Literal, Multiply, Or, SortOrder} import org.apache.spark.sql.catalyst.plans.{Cross, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter} @@ -924,4 +924,271 @@ class FlintSparkPPLJoinITSuite s }.size == 13) } + + test("test multiple joins without table aliases") { + val frame = sql(s""" + | source = $testTable1 + | | JOIN ON $testTable1.name = $testTable2.name $testTable2 + | | JOIN ON $testTable2.name = $testTable3.name $testTable3 + | | fields $testTable1.name, $testTable2.name, $testTable3.name + | """.stripMargin) + assertSameRows( + Array( + Row("Jake", "Jake", "Jake"), + Row("Hello", "Hello", "Hello"), + Row("John", "John", "John"), + Row("David", "David", "David"), + 
Row("David", "David", "David"), + Row("Jane", "Jane", "Jane")), + frame) + + val logicalPlan = frame.queryExecution.logical + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val joinPlan1 = Join( + table1, + table2, + Inner, + Some( + EqualTo( + UnresolvedAttribute(s"$testTable1.name"), + UnresolvedAttribute(s"$testTable2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + table3, + Inner, + Some( + EqualTo( + UnresolvedAttribute(s"$testTable2.name"), + UnresolvedAttribute(s"$testTable3.name"))), + JoinHint.NONE) + val expectedPlan = Project( + Seq( + UnresolvedAttribute(s"$testTable1.name"), + UnresolvedAttribute(s"$testTable2.name"), + UnresolvedAttribute(s"$testTable3.name")), + joinPlan2) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with part subquery aliases") { + val frame = sql(s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 + | | JOIN right = t3 ON t1.name = t3.name $testTable3 + | | fields t1.name, t2.name, t3.name + | """.stripMargin) + assertSameRows( + Array( + Row("Jake", "Jake", "Jake"), + Row("Hello", "Hello", "Hello"), + Row("John", "John", "John"), + Row("David", "David", "David"), + Row("David", "David", "David"), + Row("Jane", "Jane", "Jane")), + frame) + + val logicalPlan = frame.queryExecution.logical + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t3.name"))), + JoinHint.NONE) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("t1.name"), + UnresolvedAttribute("t2.name"), + UnresolvedAttribute("t3.name")), + joinPlan2) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with self join 1") { + val frame = sql(s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 + | | JOIN right = t3 ON t1.name = t3.name $testTable3 + | | JOIN right = t4 ON t1.name = t4.name $testTable1 + | | fields t1.name, t2.name, t3.name, t4.name + | """.stripMargin) + assertSameRows( + Array( + Row("Jake", "Jake", "Jake", "Jake"), + Row("Hello", "Hello", "Hello", "Hello"), + Row("John", "John", "John", "John"), + Row("David", "David", "David", "David"), + Row("David", "David", "David", "David"), + Row("Jane", "Jane", "Jane", "Jane")), + frame) + + val logicalPlan = frame.queryExecution.logical + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val 
joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t3.name"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + SubqueryAlias("t4", table1), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t4.name"))), + JoinHint.NONE) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("t1.name"), + UnresolvedAttribute("t2.name"), + UnresolvedAttribute("t3.name"), + UnresolvedAttribute("t4.name")), + joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with self join 2") { + val frame = sql(s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 + | | JOIN right = t3 ON t1.name = t3.name $testTable3 + | | JOIN ON t1.name = t4.name + | [ + | source = $testTable1 + | ] as t4 + | | fields t1.name, t2.name, t3.name, t4.name + | """.stripMargin) + assertSameRows( + Array( + Row("Jake", "Jake", "Jake", "Jake"), + Row("Hello", "Hello", "Hello", "Hello"), + Row("John", "John", "John", "John"), + Row("David", "David", "David", "David"), + Row("David", "David", "David", "David"), + Row("Jane", "Jane", "Jane", "Jane")), + frame) + + val logicalPlan = frame.queryExecution.logical + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t3.name"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + SubqueryAlias("t4", table1), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t4.name"))), + JoinHint.NONE) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("t1.name"), + UnresolvedAttribute("t2.name"), + UnresolvedAttribute("t3.name"), + UnresolvedAttribute("t4.name")), + joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("check access the reference by aliases") { + var frame = sql(s""" + | source = $testTable1 + | | JOIN left = t1 ON t1.name = t2.name $testTable2 as t2 + | | fields t1.name, t2.name + | """.stripMargin) + assert(frame.collect().length > 0) + + frame = sql(s""" + | source = $testTable1 as t1 + | | JOIN ON t1.name = t2.name $testTable2 as t2 + | | fields t1.name, t2.name + | """.stripMargin) + assert(frame.collect().length > 0) + + frame = sql(s""" + | source = $testTable1 + | | JOIN left = t1 ON t1.name = t2.name [ source = $testTable2 ] as t2 + | | fields t1.name, t2.name + | """.stripMargin) + assert(frame.collect().length > 0) + + frame = sql(s""" + | source = $testTable1 + | | JOIN left = t1 ON t1.name = t2.name [ source = $testTable2 as t2 ] + | | fields t1.name, t2.name + | """.stripMargin) + assert(frame.collect().length > 0) + } + + test("access the reference by override aliases should throw exception") { + var ex = intercept[AnalysisException](sql(s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 as tt + | | fields tt.name + | """.stripMargin)) + assert(ex.getMessage.contains("`tt`.`name` cannot be resolved")) + + ex = 
intercept[AnalysisException](sql(s""" + | source = $testTable1 as tt + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 + | | fields tt.name + | """.stripMargin)) + assert(ex.getMessage.contains("`tt`.`name` cannot be resolved")) + + ex = intercept[AnalysisException](sql(s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name [ source = $testTable2 as tt ] + | | fields tt.name + | """.stripMargin)) + assert(ex.getMessage.contains("`tt`.`name` cannot be resolved")) + + ex = intercept[AnalysisException](sql(s""" + | source = $testTable1 + | | JOIN left = t1 ON t1.name = t2.name [ source = $testTable2 as tt ] as t2 + | | fields tt.name + | """.stripMargin)) + assert(ex.getMessage.contains("`tt`.`name` cannot be resolved")) + + ex = intercept[AnalysisException](sql(s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name [ source = $testTable2 ] as tt + | | fields tt.name + | """.stripMargin)) + assert(ex.getMessage.contains("`tt`.`name` cannot be resolved")) + + ex = intercept[AnalysisException](sql(s""" + | source = $testTable1 as tt + | | JOIN left = t1 ON t1.name = t2.name $testTable2 as t2 + | | fields tt.name + | """.stripMargin)) + assert(ex.getMessage.contains("`tt`.`name` cannot be resolved")) + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJsonFunctionITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJsonFunctionITSuite.scala index 7cc0a221d..fca758101 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJsonFunctionITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJsonFunctionITSuite.scala @@ -163,30 +163,32 @@ class FlintSparkPPLJsonFunctionITSuite assert(ex.getMessage().contains("should all be the same type")) } - test("test json_array() with json()") { + test("test json_array() with to_json_tring()") { val frame = sql(s""" - | source = $testTable | eval result = json(json_array(1,2,0,-1,1.1,-0.11)) | head 1 | fields result + | source = $testTable | eval result = to_json_string(json_array(1,2,0,-1,1.1,-0.11)) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""[1.0,2.0,0.0,-1.0,1.1,-0.11]""")), frame) } - test("test json_array_length()") { + test("test array_length()") { var frame = sql(s""" - | source = $testTable | eval result = json_array_length(json_array('this', 'is', 'a', 'string', 'array')) | head 1 | fields result - | """.stripMargin) + | source = $testTable| eval result = array_length(json_array('this', 'is', 'a', 'string', 'array')) | head 1 | fields result + | """.stripMargin) assertSameRows(Seq(Row(5)), frame) frame = sql(s""" - | source = $testTable | eval result = json_array_length(json_array(1, 2, 0, -1, 1.1, -0.11)) | head 1 | fields result - | """.stripMargin) + | source = $testTable| eval result = array_length(json_array(1, 2, 0, -1, 1.1, -0.11)) | head 1 | fields result + | """.stripMargin) assertSameRows(Seq(Row(6)), frame) frame = sql(s""" - | source = $testTable | eval result = json_array_length(json_array()) | head 1 | fields result - | """.stripMargin) + | source = $testTable| eval result = array_length(json_array()) | head 1 | fields result + | """.stripMargin) assertSameRows(Seq(Row(0)), frame) + } - frame = sql(s""" + test("test json_array_length()") { + var frame = sql(s""" | source = $testTable | eval result = json_array_length('[]') | head 1 | fields result | """.stripMargin) 
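    // The array_length() tests above operate on arrays built with json_array(), while the
    // json_array_length() tests kept below cover raw JSON strings such as '[]'; the
    // assertion that follows checks that an empty JSON array string yields length 0.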
assertSameRows(Seq(Row(0)), frame) @@ -211,24 +213,24 @@ class FlintSparkPPLJsonFunctionITSuite test("test json_object()") { // test value is a string var frame = sql(s""" - | source = $testTable| eval result = json(json_object('key', 'string_value')) | head 1 | fields result + | source = $testTable| eval result = to_json_string(json_object('key', 'string_value')) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""{"key":"string_value"}""")), frame) // test value is a number frame = sql(s""" - | source = $testTable| eval result = json(json_object('key', 123.45)) | head 1 | fields result + | source = $testTable| eval result = to_json_string(json_object('key', 123.45)) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""{"key":123.45}""")), frame) // test value is a boolean frame = sql(s""" - | source = $testTable| eval result = json(json_object('key', true)) | head 1 | fields result + | source = $testTable| eval result = to_json_string(json_object('key', true)) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""{"key":true}""")), frame) frame = sql(s""" - | source = $testTable| eval result = json(json_object("a", 1, "b", 2, "c", 3)) | head 1 | fields result + | source = $testTable| eval result = to_json_string(json_object("a", 1, "b", 2, "c", 3)) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""{"a":1,"b":2,"c":3}""")), frame) } @@ -236,13 +238,13 @@ class FlintSparkPPLJsonFunctionITSuite test("test json_object() and json_array()") { // test value is an empty array var frame = sql(s""" - | source = $testTable| eval result = json(json_object('key', array())) | head 1 | fields result + | source = $testTable| eval result = to_json_string(json_object('key', array())) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""{"key":[]}""")), frame) // test value is an array frame = sql(s""" - | source = $testTable| eval result = json(json_object('key', array(1, 2, 3))) | head 1 | fields result + | source = $testTable| eval result = to_json_string(json_object('key', array(1, 2, 3))) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""{"key":[1,2,3]}""")), frame) @@ -272,14 +274,14 @@ class FlintSparkPPLJsonFunctionITSuite test("test json_object() nested") { val frame = sql(s""" - | source = $testTable | eval result = json(json_object('outer', json_object('inner', 123.45))) | head 1 | fields result + | source = $testTable | eval result = to_json_string(json_object('outer', json_object('inner', 123.45))) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""{"outer":{"inner":123.45}}""")), frame) } test("test json_object(), json_array() and json()") { val frame = sql(s""" - | source = $testTable | eval result = json(json_object("array", json_array(1,2,0,-1,1.1,-0.11))) | head 1 | fields result + | source = $testTable | eval result = to_json_string(json_object("array", json_array(1,2,0,-1,1.1,-0.11))) | head 1 | fields result | """.stripMargin) assertSameRows(Seq(Row("""{"array":[1.0,2.0,0.0,-1.0,1.1,-0.11]}""")), frame) } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala index f10b6e2f5..4a1633035 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala 
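The hunk below adds coverage for the new `rare_approx` and `top_approx` commands. Their expected logical plans mirror the existing `rare`/`top` plans but swap the `COUNT` aggregate for `APPROX_COUNT_DISTINCT`, keeping the same group-by, sort, and limit shape. As a rough illustration of the query shape those plans describe (a sketch only, assuming the suite's `spark` session; the PPL planner builds the Catalyst plan directly rather than going through SQL):

```
// Approximate SQL equivalent of `source = flint_ppl_test | rare_approx address` as described
// by the expected plan below: APPROX_COUNT_DISTINCT aliased as count_address, grouped by
// address and sorted ascending. top_approx sorts descending instead, and a leading number
// such as `top_approx 3` adds a GlobalLimit/LocalLimit pair (a LIMIT in SQL terms).
val rareApproxShape = spark.sql(
  """
    |SELECT approx_count_distinct(address) AS count_address, address
    |FROM spark_catalog.default.flint_ppl_test
    |GROUP BY address
    |ORDER BY count_address ASC
    |""".stripMargin)
```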
@@ -84,6 +84,48 @@ class FlintSparkPPLTopAndRareITSuite comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } + test("create ppl rare address field query test with approximation") { + val frame = sql(s""" + | source = $testTable| rare_approx address + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 3) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val addressField = UnresolvedAttribute("address") + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(addressField), isDistinct = false), + "count_address")(), + addressField) + val aggregatePlan = + Aggregate( + Seq(addressField), + aggregateExpressions, + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction( + Seq("APPROX_COUNT_DISTINCT"), + Seq(addressField), + isDistinct = false), + "count_address")(), + Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project(projectList, sortedPlan) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + test("create ppl rare address by age field query test") { val frame = sql(s""" | source = $testTable| rare address by age @@ -132,6 +174,104 @@ class FlintSparkPPLTopAndRareITSuite comparePlans(expectedPlan, logicalPlan, false) } + test("create ppl rare 3 address by age field query test") { + val frame = sql(s""" + | source = $testTable| rare 3 address by age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 3) + + val expectedRow = Row(1, "Vancouver", 60) + assert( + results.head == expectedRow, + s"Expected least frequent result to be $expectedRow, but got ${results.head}") + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val addressField = UnresolvedAttribute("address") + val ageField = UnresolvedAttribute("age") + val ageAlias = Alias(ageField, "age")() + + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + + val countExpr = Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count_address")() + + val aggregateExpressions = Seq(countExpr, addressField, ageAlias) + val aggregatePlan = + Aggregate( + Seq(addressField, ageAlias), + aggregateExpressions, + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) + + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count_address")(), + Ascending)), + global = true, + aggregatePlan) + + val planWithLimit = + GlobalLimit(Literal(3), LocalLimit(Literal(3), sortedPlan)) + val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) + comparePlans(expectedPlan, logicalPlan, false) + } + + test("create ppl rare 3 address by age field query test with approximation") { + val frame = sql(s""" + | source = $testTable| rare_approx 3 address by age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 3) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val addressField = UnresolvedAttribute("address") + val ageField = UnresolvedAttribute("age") + val ageAlias 
= Alias(ageField, "age")() + + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + + val countExpr = Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(addressField), isDistinct = false), + "count_address")() + + val aggregateExpressions = Seq(countExpr, addressField, ageAlias) + val aggregatePlan = + Aggregate( + Seq(addressField, ageAlias), + aggregateExpressions, + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) + + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction( + Seq("APPROX_COUNT_DISTINCT"), + Seq(addressField), + isDistinct = false), + "count_address")(), + Ascending)), + global = true, + aggregatePlan) + + val planWithLimit = + GlobalLimit(Literal(3), LocalLimit(Literal(3), sortedPlan)) + val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) + comparePlans(expectedPlan, logicalPlan, false) + } + test("create ppl top address field query test") { val frame = sql(s""" | source = $testTable| top address @@ -179,6 +319,48 @@ class FlintSparkPPLTopAndRareITSuite comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } + test("create ppl top address field query test with approximation") { + val frame = sql(s""" + | source = $testTable| top_approx address + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 3) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val addressField = UnresolvedAttribute("address") + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(addressField), isDistinct = false), + "count_address")(), + addressField) + val aggregatePlan = + Aggregate( + Seq(addressField), + aggregateExpressions, + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction( + Seq("APPROX_COUNT_DISTINCT"), + Seq(addressField), + isDistinct = false), + "count_address")(), + Descending)), + global = true, + aggregatePlan) + val expectedPlan = Project(projectList, sortedPlan) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + test("create ppl top 3 countries query test") { val frame = sql(s""" | source = $newTestTable| top 3 country @@ -226,6 +408,48 @@ class FlintSparkPPLTopAndRareITSuite comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } + test("create ppl top 3 countries query test with approximation") { + val frame = sql(s""" + | source = $newTestTable| top_approx 3 country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 3) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val countryField = UnresolvedAttribute("country") + val countExpr = Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(countryField), isDistinct = false), + "count_country")() + val aggregateExpressions = Seq(countExpr, countryField) + val aggregatePlan = + Aggregate( + Seq(countryField), + aggregateExpressions, + UnresolvedRelation(Seq("spark_catalog", "default", "new_flint_ppl_test"))) + + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction( + Seq("APPROX_COUNT_DISTINCT"), + Seq(countryField), + isDistinct = false), + "count_country")(), + 
Descending)), + global = true, + aggregatePlan) + + val planWithLimit = + GlobalLimit(Literal(3), LocalLimit(Literal(3), sortedPlan)) + val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + test("create ppl top 2 countries by occupation field query test") { val frame = sql(s""" | source = $newTestTable| top 3 country by occupation @@ -277,4 +501,50 @@ class FlintSparkPPLTopAndRareITSuite comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } + + test("create ppl top 2 countries by occupation field query test with approximation") { + val frame = sql(s""" + | source = $newTestTable| top_approx 3 country by occupation + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 3) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val countryField = UnresolvedAttribute("country") + val occupationField = UnresolvedAttribute("occupation") + val occupationFieldAlias = Alias(occupationField, "occupation")() + + val countExpr = Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(countryField), isDistinct = false), + "count_country")() + val aggregateExpressions = Seq(countExpr, countryField, occupationFieldAlias) + val aggregatePlan = + Aggregate( + Seq(countryField, occupationFieldAlias), + aggregateExpressions, + UnresolvedRelation(Seq("spark_catalog", "default", "new_flint_ppl_test"))) + + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction( + Seq("APPROX_COUNT_DISTINCT"), + Seq(countryField), + isDistinct = false), + "count_country")(), + Descending)), + global = true, + aggregatePlan) + + val planWithLimit = + GlobalLimit(Literal(3), LocalLimit(Literal(3), sortedPlan)) + val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala index bc4463537..9a8379288 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala @@ -5,9 +5,14 @@ package org.opensearch.flint.spark.ppl +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.opensearch.sql.ppl.utils.SortUtils +import org.scalatest.matchers.should.Matchers.{a, convertToAnyShouldWrapper} + import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CaseWhen, CurrentRow, Descending, LessThan, Literal, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.expressions.{Add, Alias, Ascending, CaseWhen, CurrentRow, Descending, Divide, Expression, LessThan, Literal, Multiply, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.streaming.StreamTest @@ -244,4 +249,265 @@ class FlintSparkPPLTrendlineITSuite implicit val rowOrdering: 
Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) } + + test("test trendline wma command with sort field and without alias") { + val frame = sql(s""" + | source = $testTable | trendline sort + age wma(3, age) + | """.stripMargin) + + // Compare the headers + assert( + frame.columns.sameElements( + Array("name", "age", "state", "country", "year", "month", "age_trendline"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jane", 20, "Quebec", "Canada", 2023, 4, null), + Row("John", 25, "Ontario", "Canada", 2023, 4, null), + Row("Hello", 30, "New York", "USA", 2023, 4, 26.666666666666668), + Row("Jake", 70, "California", "USA", 2023, 4, 49.166666666666664)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Compare the logical plans + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), + getNthValueAggregation("age", "age", 3, -2)) + val wmaExpression = Divide(dividend, Literal(6)) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(wmaExpression, "age_trendline")()) + val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq) + val sortedTable = Sort( + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, + unresolvedRelation) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + + /** + * Expected logical plan: 'Project [*] +- 'Project [*, ((( ('nth_value('age, 1) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 1) + ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + ('nth_value('age, 3) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 3)) / 6) AS age_trendline#185] +- 'Sort ['age ASC NULLS FIRST], true +- + * 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false + */ + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline wma command with sort field and with alias") { + val frame = sql(s""" + | source = $testTable | trendline sort + age wma(3, age) as trendline_alias + | """.stripMargin) + + // Compare the headers + assert( + frame.columns.sameElements( + Array("name", "age", "state", "country", "year", "month", "trendline_alias"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jane", 20, "Quebec", "Canada", 2023, 4, null), + Row("John", 25, "Ontario", "Canada", 2023, 4, null), + Row("Hello", 30, "New York", "USA", 2023, 4, 26.666666666666668), + Row("Jake", 70, "California", "USA", 2023, 4, 49.166666666666664)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Compare the logical plans + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), + getNthValueAggregation("age", "age", 3, -2)) + val wmaExpression = 
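+      // descriptive note: the three WMA weights (1, 2, 3) sum to 6, hence the division by Literal(6) below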
Divide(dividend, Literal(6)) + val trendlineProjectList = + Seq(UnresolvedStar(None), Alias(wmaExpression, "trendline_alias")()) + val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq) + val sortedTable = Sort( + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, + unresolvedRelation) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + + /** + * 'Project [*] +- 'Project [*, ((( ('nth_value('age, 1) windowspecdefinition('age ASC NULLS + * FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + ('nth_value('age, 2) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 2)) + ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS trendline_alias#185] +- + * 'Sort ['age ASC NULLS FIRST], true +- 'UnresolvedRelation [spark_catalog, default, + * flint_ppl_test], [], false + */ + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test multiple trendline wma commands") { + val frame = sql(s""" + | source = $testTable | trendline sort + age wma(2, age) as two_points_wma wma(3, age) as three_points_wma + | """.stripMargin) + + // Compare the headers + assert( + frame.columns.sameElements( + Array( + "name", + "age", + "state", + "country", + "year", + "month", + "two_points_wma", + "three_points_wma"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jane", 20, "Quebec", "Canada", 2023, 4, null, null), + Row("John", 25, "Ontario", "Canada", 2023, 4, 23.333333333333332, null), + Row("Hello", 30, "New York", "USA", 2023, 4, 28.333333333333332, 26.666666666666668), + Row("Jake", 70, "California", "USA", 2023, 4, 56.666666666666664, 49.166666666666664)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Compare the logical plans + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val dividendTwo = Add( + getNthValueAggregation("age", "age", 1, -1), + getNthValueAggregation("age", "age", 2, -1)) + val twoPointsExpression = Divide(dividendTwo, Literal(3)) + + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), + getNthValueAggregation("age", "age", 3, -2)) + val threePointsExpression = Divide(dividend, Literal(6)) + + val trendlineProjectList = Seq( + UnresolvedStar(None), + Alias(twoPointsExpression, "two_points_wma")(), + Alias(threePointsExpression, "three_points_wma")()) + val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq) + val sortedTable = Sort( + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, + unresolvedRelation) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + + /** + * 'Project [*] +- 'Project [*, (( ('nth_value('age, 1) windowspecdefinition('age ASC NULLS + * FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 1) + ('nth_value('age, 2) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, + * currentrow$())) * 2)) / 3) AS two_points_wma#247, + * + * ((( ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + ('nth_value('age, 2) + * 
windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 2)) + ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS three_points_wma#248] +- + * 'Sort ['age ASC NULLS FIRST], true +- 'UnresolvedRelation [spark_catalog, default, + * flint_ppl_test], [], false + */ + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline wma command on evaluated column") { + val frame = sql(s""" + | source = $testTable | eval doubled_age = age * 2 | trendline sort + age wma(2, doubled_age) as doubled_age_wma | fields name, doubled_age, doubled_age_wma + | """.stripMargin) + + // Compare the headers + assert(frame.columns.sameElements(Array("name", "doubled_age", "doubled_age_wma"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jane", 40, null), + Row("John", 50, 46.666666666666664), + Row("Hello", 60, 56.666666666666664), + Row("Jake", 140, 113.33333333333333)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Compare the logical plans + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val dividend = Add( + getNthValueAggregation("doubled_age", "age", 1, -1), + getNthValueAggregation("doubled_age", "age", 2, -1)) + val wmaExpression = Divide(dividend, Literal(3)) + val trendlineProjectList = + Seq(UnresolvedStar(None), Alias(wmaExpression, "doubled_age_wma")()) + val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq) + val doubledAged = Alias( + UnresolvedFunction( + seq("*"), + seq(UnresolvedAttribute("age"), Literal(2)), + isDistinct = false), + "doubled_age")() + val doubleAgeProject = Project(seq(UnresolvedStar(None), doubledAged), unresolvedRelation) + val sortedTable = + Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, doubleAgeProject) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("name"), + UnresolvedAttribute("doubled_age"), + UnresolvedAttribute("doubled_age_wma")), + Project(trendlineProjectList, sortedTable)) + + /** + * 'Project ['name, 'doubled_age, 'doubled_age_wma] +- 'Project [*, (( + * ('nth_value('doubled_age, 1) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -1, currentrow$())) * 1) + ('nth_value('doubled_age, 2) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, + * currentrow$())) * 2)) / 3) AS doubled_age_wma#288] +- 'Sort ['age ASC NULLS FIRST], true +- + * 'Project [*, '`*`('age, 2) AS doubled_age#287] +- 'UnresolvedRelation [spark_catalog, + * default, flint_ppl_test], [], false + */ + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + + } + + test("test invalid wma command with negative dataPoint value") { + val exception = intercept[ParseException](sql(s""" + | source = $testTable | trendline sort + age wma(-3, age) + | """.stripMargin)) + assert(exception.getMessage contains "[PARSE_SYNTAX_ERROR] Syntax error") + } + + private def getNthValueAggregation( + dataField: String, + sortField: String, + lookBackPos: Int, + lookBackRange: Int): Expression = { + Multiply( + WindowExpression( + UnresolvedFunction( + "nth_value", + Seq(UnresolvedAttribute(dataField), Literal(lookBackPos)), + isDistinct = false), + WindowSpecDefinition( + Seq(), + 
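+          // descriptive note: empty partitionSpec, so the nth_value window runs over the whole input ordered by the trendline sort field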
seq(SortUtils.sortOrder(UnresolvedAttribute(sortField), true)), + SpecifiedWindowFrame(RowFrame, Literal(lookBackRange), CurrentRow))), + Literal(lookBackPos)) + } + } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/tpch/TPCHQueryBase.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/tpch/TPCHQueryBase.scala new file mode 100644 index 000000000..fb14210e9 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/tpch/TPCHQueryBase.scala @@ -0,0 +1,177 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl.tpch + +import org.opensearch.flint.spark.ppl.FlintPPLSuite + +import org.apache.spark.SparkConf +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.expressions.codegen.{ByteCodeStats, CodeFormatter, CodeGenerator} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_SECOND +import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec} +import org.apache.spark.sql.internal.SQLConf + +trait TPCHQueryBase extends FlintPPLSuite { + + override protected def sparkConf: SparkConf = { + super.sparkConf.set(SQLConf.MAX_TO_STRING_FIELDS.key, Int.MaxValue.toString) + } + + override def beforeAll(): Unit = { + super.beforeAll() + RuleExecutor.resetMetrics() + CodeGenerator.resetCompileTime() + WholeStageCodegenExec.resetCodeGenTime() + tpchCreateTable.values.foreach { ppl => + sql(ppl) + } + } + + override def afterAll(): Unit = { + try { + tpchCreateTable.keys.foreach { tableName => + spark.sessionState.catalog.dropTable(TableIdentifier(tableName), true, true) + } + // For debugging dump some statistics about how much time was spent in various optimizer rules + // code generation, and compilation. 
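+      // RuleExecutor.dumpTimeSpent() reports the per-rule timings accumulated since resetMetrics() in beforeAll()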
+ logWarning(RuleExecutor.dumpTimeSpent()) + val codeGenTime = WholeStageCodegenExec.codeGenTime.toDouble / NANOS_PER_SECOND + val compileTime = CodeGenerator.compileTime.toDouble / NANOS_PER_SECOND + val codegenInfo = + s""" + |=== Metrics of Whole-stage Codegen === + |Total code generation time: $codeGenTime seconds + |Total compile time: $compileTime seconds + """.stripMargin + logWarning(codegenInfo) + spark.sessionState.catalog.reset() + } finally { + super.afterAll() + } + } + + def checkGeneratedCode(plan: SparkPlan, checkMethodCodeSize: Boolean = true): Unit = { + val codegenSubtrees = new collection.mutable.HashSet[WholeStageCodegenExec]() + + def findSubtrees(plan: SparkPlan): Unit = { + plan foreach { + case s: WholeStageCodegenExec => + codegenSubtrees += s + case s => + s.subqueries.foreach(findSubtrees) + } + } + + findSubtrees(plan) + codegenSubtrees.toSeq.foreach { subtree => + val code = subtree.doCodeGen()._2 + val (_, ByteCodeStats(maxMethodCodeSize, _, _)) = + try { + // Just check the generated code can be properly compiled + CodeGenerator.compile(code) + } catch { + case e: Exception => + val msg = + s""" + |failed to compile: + |Subtree: + |$subtree + |Generated code: + |${CodeFormatter.format(code)} + """.stripMargin + throw new Exception(msg, e) + } + + assert( + !checkMethodCodeSize || + maxMethodCodeSize <= CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT, + s"too long generated codes found in the WholeStageCodegenExec subtree (id=${subtree.id}) " + + s"and JIT optimization might not work:\n${subtree.treeString}") + } + } + + val tpchCreateTable = Map( + "orders" -> + """ + |CREATE TABLE `orders` ( + |`o_orderkey` BIGINT, `o_custkey` BIGINT, `o_orderstatus` STRING, + |`o_totalprice` DECIMAL(10,0), `o_orderdate` DATE, `o_orderpriority` STRING, + |`o_clerk` STRING, `o_shippriority` INT, `o_comment` STRING) + |USING parquet + """.stripMargin, + "nation" -> + """ + |CREATE TABLE `nation` ( + |`n_nationkey` BIGINT, `n_name` STRING, `n_regionkey` BIGINT, `n_comment` STRING) + |USING parquet + """.stripMargin, + "region" -> + """ + |CREATE TABLE `region` ( + |`r_regionkey` BIGINT, `r_name` STRING, `r_comment` STRING) + |USING parquet + """.stripMargin, + "part" -> + """ + |CREATE TABLE `part` (`p_partkey` BIGINT, `p_name` STRING, `p_mfgr` STRING, + |`p_brand` STRING, `p_type` STRING, `p_size` INT, `p_container` STRING, + |`p_retailprice` DECIMAL(10,0), `p_comment` STRING) + |USING parquet + """.stripMargin, + "partsupp" -> + """ + |CREATE TABLE `partsupp` (`ps_partkey` BIGINT, `ps_suppkey` BIGINT, + |`ps_availqty` INT, `ps_supplycost` DECIMAL(10,0), `ps_comment` STRING) + |USING parquet + """.stripMargin, + "customer" -> + """ + |CREATE TABLE `customer` (`c_custkey` BIGINT, `c_name` STRING, `c_address` STRING, + |`c_nationkey` BIGINT, `c_phone` STRING, `c_acctbal` DECIMAL(10,0), + |`c_mktsegment` STRING, `c_comment` STRING) + |USING parquet + """.stripMargin, + "supplier" -> + """ + |CREATE TABLE `supplier` (`s_suppkey` BIGINT, `s_name` STRING, `s_address` STRING, + |`s_nationkey` BIGINT, `s_phone` STRING, `s_acctbal` DECIMAL(10,0), `s_comment` STRING) + |USING parquet + """.stripMargin, + "lineitem" -> + """ + |CREATE TABLE `lineitem` (`l_orderkey` BIGINT, `l_partkey` BIGINT, `l_suppkey` BIGINT, + |`l_linenumber` INT, `l_quantity` DECIMAL(10,0), `l_extendedprice` DECIMAL(10,0), + |`l_discount` DECIMAL(10,0), `l_tax` DECIMAL(10,0), `l_returnflag` STRING, + |`l_linestatus` STRING, `l_shipdate` DATE, `l_commitdate` DATE, `l_receiptdate` DATE, + |`l_shipinstruct` STRING, 
`l_shipmode` STRING, `l_comment` STRING) + |USING parquet + """.stripMargin) + + val tpchQueries = Seq( + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15", + "q16", + "q17", + "q18", + "q19", + "q20", + "q21", + "q22") +} diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/tpch/TPCHQueryITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/tpch/TPCHQueryITSuite.scala new file mode 100644 index 000000000..1b9681618 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/tpch/TPCHQueryITSuite.scala @@ -0,0 +1,43 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl.tpch + +import org.opensearch.flint.spark.ppl.LogicalPlanTestUtils + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.util.resourceToString +import org.apache.spark.sql.streaming.StreamTest + +class TPCHQueryITSuite + extends QueryTest + with LogicalPlanTestUtils + with TPCHQueryBase + with StreamTest { + + override def beforeAll(): Unit = { + super.beforeAll() + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + tpchQueries.foreach { name => + val queryString = resourceToString( + s"tpch/$name.ppl", + classLoader = Thread.currentThread().getContextClassLoader) + test(name) { + // check the plans can be properly generated + val plan = sql(queryString).queryExecution.executedPlan + checkGeneratedCode(plan) + } + } +} diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index a6c2f06ff..34d408fb0 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -23,7 +23,9 @@ DEDUP: 'DEDUP'; SORT: 'SORT'; EVAL: 'EVAL'; HEAD: 'HEAD'; +TOP_APPROX: 'TOP_APPROX'; TOP: 'TOP'; +RARE_APPROX: 'RARE_APPROX'; RARE: 'RARE'; PARSE: 'PARSE'; METHOD: 'METHOD'; @@ -37,6 +39,7 @@ KMEANS: 'KMEANS'; AD: 'AD'; ML: 'ML'; FILLNULL: 'FILLNULL'; +EXPAND: 'EXPAND'; FLATTEN: 'FLATTEN'; TRENDLINE: 'TRENDLINE'; @@ -93,6 +96,7 @@ NULLS: 'NULLS'; //TRENDLINE KEYWORDS SMA: 'SMA'; +WMA: 'WMA'; // ARGUMENT KEYWORDS KEEPEMPTY: 'KEEPEMPTY'; @@ -215,6 +219,7 @@ BIT_XOR_OP: '^'; AVG: 'AVG'; COUNT: 'COUNT'; DISTINCT_COUNT: 'DISTINCT_COUNT'; +DISTINCT_COUNT_APPROX: 'DISTINCT_COUNT_APPROX'; ESTDC: 'ESTDC'; ESTDC_ERROR: 'ESTDC_ERROR'; MAX: 'MAX'; @@ -378,6 +383,7 @@ JSON: 'JSON'; JSON_OBJECT: 'JSON_OBJECT'; JSON_ARRAY: 'JSON_ARRAY'; JSON_ARRAY_LENGTH: 'JSON_ARRAY_LENGTH'; +TO_JSON_STRING: 'TO_JSON_STRING'; JSON_EXTRACT: 'JSON_EXTRACT'; JSON_KEYS: 'JSON_KEYS'; JSON_VALID: 'JSON_VALID'; @@ -393,6 +399,7 @@ JSON_VALID: 'JSON_VALID'; // COLLECTION FUNCTIONS ARRAY: 'ARRAY'; +ARRAY_LENGTH: 'ARRAY_LENGTH'; // LAMBDA FUNCTIONS //EXISTS: 'EXISTS'; @@ -409,6 +416,9 @@ ISPRESENT: 'ISPRESENT'; BETWEEN: 'BETWEEN'; CIDRMATCH: 'CIDRMATCH'; +// Geo Loction +GEOIP: 'GEOIP'; + // FLOWCONTROL FUNCTIONS IFNULL: 'IFNULL'; NULLIF: 'NULLIF'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index c5c017be6..0a4f548d8 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 
@@ -54,6 +54,7 @@ commands | fillnullCommand | fieldsummaryCommand | flattenCommand + | expandCommand | trendlineCommand ; @@ -75,13 +76,16 @@ commandName | SORT | HEAD | TOP + | TOP_APPROX | RARE + | RARE_APPROX | EVAL | GROK | PARSE | PATTERNS | LOOKUP | RENAME + | EXPAND | FILLNULL | FIELDSUMMARY | FLATTEN @@ -178,11 +182,11 @@ headCommand ; topCommand - : TOP (number = integerLiteral)? fieldList (byClause)? + : (TOP | TOP_APPROX) (number = integerLiteral)? fieldList (byClause)? ; rareCommand - : RARE fieldList (byClause)? + : (RARE | RARE_APPROX) (number = integerLiteral)? fieldList (byClause)? ; grokCommand @@ -250,6 +254,10 @@ fillnullCommand : expression ; +expandCommand + : EXPAND fieldExpression (AS alias = qualifiedName)? + ; + flattenCommand : FLATTEN fieldExpression ; @@ -259,11 +267,12 @@ trendlineCommand ; trendlineClause - : trendlineType LT_PRTHS numberOfDataPoints = integerLiteral COMMA field = fieldExpression RT_PRTHS (AS alias = qualifiedName)? + : trendlineType LT_PRTHS numberOfDataPoints = INTEGER_LITERAL COMMA field = fieldExpression RT_PRTHS (AS alias = qualifiedName)? ; trendlineType : SMA + | WMA ; kmeansCommand @@ -339,7 +348,7 @@ joinType ; sideAlias - : LEFT EQUAL leftAlias = ident COMMA? RIGHT EQUAL rightAlias = ident + : (LEFT EQUAL leftAlias = ident)? COMMA? (RIGHT EQUAL rightAlias = ident)? ; joinCriteria @@ -394,7 +403,7 @@ statsAggTerm statsFunction : statsFunctionName LT_PRTHS valueExpression RT_PRTHS # statsFunctionCall | COUNT LT_PRTHS RT_PRTHS # countAllFunctionCall - | (DISTINCT_COUNT | DC) LT_PRTHS valueExpression RT_PRTHS # distinctCountFunctionCall + | (DISTINCT_COUNT | DC | DISTINCT_COUNT_APPROX) LT_PRTHS valueExpression RT_PRTHS # distinctCountFunctionCall | percentileFunctionName = (PERCENTILE | PERCENTILE_APPROX) LT_PRTHS valueExpression COMMA percent = integerLiteral RT_PRTHS # percentileFunctionCall ; @@ -416,6 +425,7 @@ expression logicalExpression : NOT logicalExpression # logicalNot + | LT_PRTHS logicalExpression RT_PRTHS # parentheticLogicalExpr | comparisonExpression # comparsion | left = logicalExpression (AND)? right = logicalExpression # logicalAnd | left = logicalExpression OR right = logicalExpression # logicalOr @@ -441,6 +451,7 @@ valueExpression | positionFunction # positionFunctionCall | caseFunction # caseExpr | timestampFunction # timestampFunctionCall + | geoipFunction # geoFunctionCall | LT_PRTHS valueExpression RT_PRTHS # parentheticValueExpr | LT_SQR_PRTHS subSearch RT_SQR_PRTHS # scalarSubqueryExpr | ident ARROW expression # lambda @@ -538,6 +549,11 @@ dataTypeFunctionCall : CAST LT_PRTHS expression AS convertedDataType RT_PRTHS ; +// geoip function +geoipFunction + : GEOIP LT_PRTHS (datasource = functionArg COMMA)? ipAddress = functionArg (COMMA properties = stringLiteral)? 
RT_PRTHS + ; + // boolean functions booleanFunctionCall : conditionFunctionBase LT_PRTHS functionArgs RT_PRTHS @@ -571,6 +587,7 @@ evalFunctionName | cryptographicFunctionName | jsonFunctionName | collectionFunctionName + | geoipFunctionName | lambdaFunctionName ; @@ -861,6 +878,7 @@ jsonFunctionName | JSON_OBJECT | JSON_ARRAY | JSON_ARRAY_LENGTH + | TO_JSON_STRING | JSON_EXTRACT | JSON_KEYS | JSON_VALID @@ -877,6 +895,7 @@ jsonFunctionName collectionFunctionName : ARRAY + | ARRAY_LENGTH ; lambdaFunctionName @@ -887,6 +906,10 @@ lambdaFunctionName | REDUCE ; +geoipFunctionName + : GEOIP + ; + positionFunctionName : POSITION ; @@ -1135,6 +1158,7 @@ keywordsCanBeId // AGGREGATIONS | statsFunctionName | DISTINCT_COUNT + | DISTINCT_COUNT_APPROX | PERCENTILE | PERCENTILE_APPROX | ESTDC diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index c7d61b2e2..597e9e8cc 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -112,6 +112,10 @@ public T visitFilter(Filter node, C context) { return visitChildren(node, context); } + public T visitExpand(Expand node, C context) { + return visitChildren(node, context); + } + public T visitLookup(Lookup node, C context) { return visitChildren(node, context); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/CountedAggregation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/CountedAggregation.java new file mode 100644 index 000000000..9a4aa5d7d --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/CountedAggregation.java @@ -0,0 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.sql.ast.tree; + +import org.opensearch.sql.ast.expression.Literal; + +import java.util.Optional; + +/** + * marker interface for numeric based count aggregation (specific number of returned results) + */ +public interface CountedAggregation { + Optional getResults(); +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/DescribeRelation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/DescribeRelation.java index b513d01bf..dd9947329 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/DescribeRelation.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/DescribeRelation.java @@ -8,12 +8,14 @@ import lombok.ToString; import org.opensearch.sql.ast.expression.UnresolvedExpression; +import java.util.Collections; + /** * Extend Relation to describe the table itself */ @ToString public class DescribeRelation extends Relation{ public DescribeRelation(UnresolvedExpression tableName) { - super(tableName); + super(Collections.singletonList(tableName)); } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Expand.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Expand.java new file mode 100644 index 000000000..0e164ccd7 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Expand.java @@ -0,0 +1,44 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import 
org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; +import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.UnresolvedAttribute; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +import java.util.List; +import java.util.Optional; + +/** Logical plan node of Expand */ +@RequiredArgsConstructor +public class Expand extends UnresolvedPlan { + private UnresolvedPlan child; + + @Getter + private final Field field; + @Getter + private final Optional alias; + + @Override + public Expand attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public List getChild() { + return child == null ? List.of() : List.of(child); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitExpand(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Flatten.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Flatten.java index e31fbb6e3..9c57d2adf 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Flatten.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Flatten.java @@ -14,7 +14,7 @@ public class Flatten extends UnresolvedPlan { private UnresolvedPlan child; @Getter - private final Field fieldToBeFlattened; + private final Field field; @Override public UnresolvedPlan attach(UnresolvedPlan child) { @@ -26,7 +26,7 @@ public UnresolvedPlan attach(UnresolvedPlan child) { public List getChild() { return child == null ? List.of() : List.of(child); } - + @Override public T accept(AbstractNodeVisitor nodeVisitor, C context) { return nodeVisitor.visitFlatten(this, context); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Join.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Join.java index 89f787d34..176902911 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Join.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Join.java @@ -25,15 +25,15 @@ public class Join extends UnresolvedPlan { private UnresolvedPlan left; private final UnresolvedPlan right; - private final String leftAlias; - private final String rightAlias; + private final Optional leftAlias; + private final Optional rightAlias; private final JoinType joinType; private final Optional joinCondition; private final JoinHint joinHint; @Override public UnresolvedPlan attach(UnresolvedPlan child) { - this.left = new SubqueryAlias(leftAlias, child); + this.left = leftAlias.isEmpty() ? child : new SubqueryAlias(leftAlias.get(), child); return this; } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/RareAggregation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/RareAggregation.java index d5a637f3d..8e454685a 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/RareAggregation.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/RareAggregation.java @@ -6,21 +6,29 @@ package org.opensearch.sql.ast.tree; import lombok.EqualsAndHashCode; +import lombok.Getter; import lombok.ToString; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.UnresolvedExpression; import java.util.Collections; import java.util.List; +import java.util.Optional; /** Logical plan node of Rare (Aggregation) command, the interface for building aggregation actions in queries. 
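+ * Implements CountedAggregation so the optional result limit now accepted by rareCommand (e.g. rare_approx N) can be carried on the node.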
*/ @ToString +@Getter @EqualsAndHashCode(callSuper = true) -public class RareAggregation extends Aggregation { +public class RareAggregation extends Aggregation implements CountedAggregation{ + private final Optional results; + /** Aggregation Constructor without span and argument. */ public RareAggregation( + Optional results, List aggExprList, List sortExprList, List groupExprList) { super(aggExprList, sortExprList, groupExprList, null, Collections.emptyList()); + this.results = results; } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java index 1b30a7998..d8ea104a4 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java @@ -6,53 +6,34 @@ package org.opensearch.sql.ast.tree; import com.google.common.collect.ImmutableList; -import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; -import lombok.Setter; import lombok.ToString; import org.opensearch.sql.ast.AbstractNodeVisitor; import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.expression.UnresolvedExpression; -import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; /** Logical plan node of Relation, the interface for building the searching sources. */ -@AllArgsConstructor @ToString +@Getter @EqualsAndHashCode(callSuper = false) @RequiredArgsConstructor public class Relation extends UnresolvedPlan { private static final String COMMA = ","; - private final List tableName; - - public Relation(UnresolvedExpression tableName) { - this(tableName, null); - } - - public Relation(UnresolvedExpression tableName, String alias) { - this.tableName = Arrays.asList(tableName); - this.alias = alias; - } - - /** Optional alias name for the relation. */ - @Setter @Getter private String alias; - - /** - * Return table name. - * - * @return table name - */ - public List getTableName() { - return tableName.stream().map(Object::toString).collect(Collectors.toList()); - } + // A relation could contain more than one table/index names, such as + // source=account1, account2 + // source=`account1`,`account2` + // source=`account*` + // They translated into union call with fields. + private final List tableNames; public List getQualifiedNames() { - return tableName.stream().map(t -> (QualifiedName) t).collect(Collectors.toList()); + return tableNames.stream().map(t -> (QualifiedName) t).collect(Collectors.toList()); } /** @@ -63,11 +44,11 @@ public List getQualifiedNames() { * @return TableQualifiedName. 
*/ public QualifiedName getTableQualifiedName() { - if (tableName.size() == 1) { - return (QualifiedName) tableName.get(0); + if (tableNames.size() == 1) { + return (QualifiedName) tableNames.get(0); } else { return new QualifiedName( - tableName.stream() + tableNames.stream() .map(UnresolvedExpression::toString) .collect(Collectors.joining(COMMA))); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/SubqueryAlias.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/SubqueryAlias.java index 29c3d4b90..ba66cca80 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/SubqueryAlias.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/SubqueryAlias.java @@ -6,19 +6,14 @@ package org.opensearch.sql.ast.tree; import com.google.common.collect.ImmutableList; -import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; -import lombok.RequiredArgsConstructor; import lombok.ToString; import org.opensearch.sql.ast.AbstractNodeVisitor; import java.util.List; -import java.util.Objects; -@AllArgsConstructor @EqualsAndHashCode(callSuper = false) -@RequiredArgsConstructor @ToString public class SubqueryAlias extends UnresolvedPlan { @Getter private final String alias; @@ -32,6 +27,11 @@ public SubqueryAlias(UnresolvedPlan child, String suffix) { this.child = child; } + public SubqueryAlias(String alias, UnresolvedPlan child) { + this.alias = alias; + this.child = child; + } + public List getChild() { return ImmutableList.of(child); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/TopAggregation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/TopAggregation.java index e87a3b0b0..90aac5838 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/TopAggregation.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/TopAggregation.java @@ -20,7 +20,7 @@ @ToString @Getter @EqualsAndHashCode(callSuper = true) -public class TopAggregation extends Aggregation { +public class TopAggregation extends Aggregation implements CountedAggregation { private final Optional results; /** Aggregation Constructor without span and argument. 
*/ diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java index 9fa1ae81d..d08e89e3b 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java @@ -62,6 +62,6 @@ public TrendlineComputation(Integer numberOfDataPoints, UnresolvedExpression dat } public enum TrendlineType { - SMA + SMA, WMA } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/DefaultPatterns.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/DefaultPatterns.java new file mode 100644 index 000000000..411542fb4 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/DefaultPatterns.java @@ -0,0 +1,89 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.sql.common.grok; + +import java.util.Map; + +public interface DefaultPatterns { + + /** + * populate map with default patterns as they appear under the '/resources/patterns/*' resource folder + */ + static Map withDefaultPatterns(Map patterns) { + patterns.put("PATH" , "(?:%{UNIXPATH}|%{WINPATH})"); + patterns.put("MONTH" , "\\b(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\\b"); + patterns.put("TZ" , "(?:[PMCE][SD]T|UTC)"); + patterns.put("DATESTAMP_OTHER" , "%{DAY} %{MONTH} %{MONTHDAY} %{TIME} %{TZ} %{YEAR}"); + patterns.put("HTTPDATE" , "%{MONTHDAY}/%{MONTH}/%{YEAR}:%{TIME} %{INT}"); + patterns.put("HOST" , "%{HOSTNAME:UNWANTED}"); + patterns.put("DATESTAMP_EVENTLOG" , "%{YEAR}%{MONTHNUM2}%{MONTHDAY}%{HOUR}%{MINUTE}%{SECOND}"); + patterns.put("MESSAGESLOG" , "%{SYSLOGBASE} %{DATA}"); + patterns.put("WINDOWSMAC" , "(?:(?:[A-Fa-f0-9]{2}-){5}[A-Fa-f0-9]{2})"); + patterns.put("YEAR" , "(?>\\d\\d){1,2}"); + patterns.put("POSINT" , "\\b(?:[1-9][0-9]*)\\b"); + patterns.put("USERNAME" , "[a-zA-Z0-9._-]+"); + patterns.put("MINUTE" , "(?:[0-5][0-9])"); + patterns.put("UUID" , "[A-Fa-f0-9]{8}-(?:[A-Fa-f0-9]{4}-){3}[A-Fa-f0-9]{12}"); + patterns.put("DATE_US" , "%{MONTHNUM}[/-]%{MONTHDAY}[/-]%{YEAR}"); + patterns.put("LOGLEVEL" , "([A|a]lert|ALERT|[T|t]race|TRACE|[D|d]ebug|DEBUG|[N|n]otice|NOTICE|[I|i]nfo|INFO|[W|w]arn?(?:ing)?|WARN?(?:ING)?|[E|e]rr?(?:or)?|ERR?(?:OR)?|[C|c]rit?(?:ical)?|CRIT?(?:ICAL)?|[F|f]atal|FATAL|[S|s]evere|SEVERE|EMERG(?:ENCY)?|[Ee]merg(?:ency)?)"); + patterns.put("WINPATH" , "(?>[A-Za-z]+:|\\)(?:\\[^\\?*]*)+"); + patterns.put("NUMBER" , "(?:%{BASE10NUM:UNWANTED})"); + patterns.put("WORD" , "\\b\\w+\\b"); + patterns.put("QS" , "%{QUOTEDSTRING:UNWANTED}"); + patterns.put("TIMESTAMP_ISO8601" , "%{YEAR}-%{MONTHNUM}-%{MONTHDAY}[T ]%{HOUR}:?%{MINUTE}(?::?%{SECOND})?%{ISO8601_TIMEZONE}?"); + patterns.put("MONTHNUM" , "(?:0?[1-9]|1[0-2])"); + patterns.put("NOTSPACE" , "\\S+"); + patterns.put("IPV6" , 
"((([0-9A-Fa-f]{1,4}:){7}([0-9A-Fa-f]{1,4}|:))|(([0-9A-Fa-f]{1,4}:){6}(:[0-9A-Fa-f]{1,4}|((25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)(\\.(25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)){3})|:))|(([0-9A-Fa-f]{1,4}:){5}(((:[0-9A-Fa-f]{1,4}){1,2})|:((25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)(\\.(25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)){3})|:))|(([0-9A-Fa-f]{1,4}:){4}(((:[0-9A-Fa-f]{1,4}){1,3})|((:[0-9A-Fa-f]{1,4})?:((25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)(\\.(25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)){3}))|:))|(([0-9A-Fa-f]{1,4}:){3}(((:[0-9A-Fa-f]{1,4}){1,4})|((:[0-9A-Fa-f]{1,4}){0,2}:((25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)(\\.(25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)){3}))|:))|(([0-9A-Fa-f]{1,4}:){2}(((:[0-9A-Fa-f]{1,4}){1,5})|((:[0-9A-Fa-f]{1,4}){0,3}:((25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)(\\.(25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)){3}))|:))|(([0-9A-Fa-f]{1,4}:){1}(((:[0-9A-Fa-f]{1,4}){1,6})|((:[0-9A-Fa-f]{1,4}){0,4}:((25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)(\\.(25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)){3}))|:))|(:(((:[0-9A-Fa-f]{1,4}){1,7})|((:[0-9A-Fa-f]{1,4}){0,5}:((25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)(\\.(25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)){3}))|:)))(%.+)?"); + patterns.put("IPV4" , "(?[+-]?(?:(?:[0-9]+(?:\\.[0-9]+)?)|(?:\\.[0-9]+)))"); + patterns.put("NONNEGINT" , "\\b(?:[0-9]+)\\b"); + patterns.put("DATESTAMP_RFC822" , "%{DAY} %{MONTH} %{MONTHDAY} %{YEAR} %{TIME} %{TZ}"); + patterns.put("URI" , "%{URIPROTO}://(?:%{USER}(?::[^@]*)?@)?(?:%{URIHOST})?(?:%{URIPATHPARAM})?"); + patterns.put("INT" , "(?:[+-]?(?:[0-9]+))"); + patterns.put("SPACE" , "\\s*"); + patterns.put("GREEDYDATA" , ".*"); + patterns.put("ISO8601_SECOND" , "(?:%{SECOND}|60)"); + patterns.put("UNIXPATH" , "(?>/(?>[\\w_%!$@:.,~-]+|\\.)*)+"); + patterns.put("TTY" , "(?:/dev/(pts|tty([pq])?)(\\w+)?/?(?:[0-9]+))"); + patterns.put("COMBINEDAPACHELOG" , "%{COMMONAPACHELOG} %{QS:referrer} %{QS:agent}"); + patterns.put("URIPROTO" , "[A-Za-z]+(\\+[A-Za-z+]+)?"); + patterns.put("HOSTPORT" , "(?:%{IPORHOST}:%{POSINT:PORT})"); + patterns.put("SYSLOGPROG" , "%{PROG:program}(?:\\[%{POSINT:pid}\\])?"); + patterns.put("SYSLOGBASE" , "%{SYSLOGTIMESTAMP:timestamp} (?:%{SYSLOGFACILITY} )?%{SYSLOGHOST:logsource} %{SYSLOGPROG}:"); + patterns.put("SYSLOGFACILITY" , "<%{NONNEGINT:facility}.%{NONNEGINT:priority}>"); + patterns.put("DATESTAMP" , "%{DATE}[- ]%{TIME}"); + patterns.put("TIME" , "(?!<[0-9])%{HOUR}:%{MINUTE}(?::%{SECOND})(?![0-9])"); + patterns.put("USER" , "%{USERNAME:UNWANTED}"); + patterns.put("COMMONMAC" , "(?:(?:[A-Fa-f0-9]{2}:){5}[A-Fa-f0-9]{2})"); + patterns.put("IPORHOST" , "(?:%{HOSTNAME:UNWANTED}|%{IP:UNWANTED})"); + patterns.put("BASE16NUM" , "(?(?\"(?>\\\\.|[^\\\\\"]+)+\"|\"\"|(?>'(?>\\\\.|[^\\\\']+)+')|''|(?>`(?>\\\\.|[^\\\\`]+)+`)|``))"); + patterns.put("DAY" , "(?:Mon(?:day)?|Tue(?:sday)?|Wed(?:nesday)?|Thu(?:rsday)?|Fri(?:day)?|Sat(?:urday)?|Sun(?:day)?)"); + patterns.put("ISO8601_TIMEZONE" , "(?:Z|[+-]%{HOUR}(?::?%{MINUTE}))"); + patterns.put("PROG" , "(?:[\\w._/%-]+)"); + return patterns; + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/GrokCompiler.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/GrokCompiler.java index 7d51038cd..b9dd2df83 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/GrokCompiler.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/GrokCompiler.java @@ -26,14 +26,12 @@ import java.util.regex.Pattern; import static java.lang.String.format; +import static 
org.opensearch.sql.common.grok.DefaultPatterns.withDefaultPatterns; public class GrokCompiler implements Serializable { - - // We don't want \n and commented line - private static final Pattern patternLinePattern = Pattern.compile("^([A-z0-9_]+)\\s+(.*)$"); - + /** {@code Grok} patterns definitions. */ - private final Map grokPatternDefinitions = new HashMap<>(); + private final Map grokPatternDefinitions = withDefaultPatterns(new HashMap<>()); private GrokCompiler() {} @@ -41,76 +39,6 @@ public static GrokCompiler newInstance() { return new GrokCompiler(); } - public Map getPatternDefinitions() { - return grokPatternDefinitions; - } - - /** - * Registers a new pattern definition. - * - * @param name : Pattern Name - * @param pattern : Regular expression Or {@code Grok} pattern - * @throws GrokException runtime expt - */ - public void register(String name, String pattern) { - name = Objects.requireNonNull(name).trim(); - pattern = Objects.requireNonNull(pattern).trim(); - - if (!name.isEmpty() && !pattern.isEmpty()) { - grokPatternDefinitions.put(name, pattern); - } - } - - /** Registers multiple pattern definitions. */ - public void register(Map patternDefinitions) { - Objects.requireNonNull(patternDefinitions); - patternDefinitions.forEach(this::register); - } - - /** - * Registers multiple pattern definitions from a given inputStream, and decoded as a UTF-8 source. - */ - public void register(InputStream input) throws IOException { - register(input, StandardCharsets.UTF_8); - } - - /** Registers multiple pattern definitions from a given inputStream. */ - public void register(InputStream input, Charset charset) throws IOException { - try (BufferedReader in = new BufferedReader(new InputStreamReader(input, charset))) { - in.lines() - .map(patternLinePattern::matcher) - .filter(Matcher::matches) - .forEach(m -> register(m.group(1), m.group(2))); - } - } - - /** Registers multiple pattern definitions from a given Reader. */ - public void register(Reader input) throws IOException { - new BufferedReader(input) - .lines() - .map(patternLinePattern::matcher) - .filter(Matcher::matches) - .forEach(m -> register(m.group(1), m.group(2))); - } - - public void registerDefaultPatterns() { - registerPatternFromClasspath("/patterns/patterns"); - } - - public void registerPatternFromClasspath(String path) throws GrokException { - registerPatternFromClasspath(path, StandardCharsets.UTF_8); - } - - /** registerPatternFromClasspath. */ - public void registerPatternFromClasspath(String path, Charset charset) throws GrokException { - final InputStream inputStream = this.getClass().getResourceAsStream(path); - try (Reader reader = new InputStreamReader(inputStream, charset)) { - register(reader); - } catch (IOException e) { - throw new GrokException(e.getMessage(), e); - } - } - /** Compiles a given Grok pattern and returns a Grok object which can parse the pattern. 
*/ public Grok compile(String pattern) throws IllegalArgumentException { return compile(pattern, false); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index 13b5c20ef..86970cefb 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -185,6 +185,7 @@ public enum BuiltinFunctionName { NESTED(FunctionName.of("nested")), PERCENTILE(FunctionName.of("percentile")), PERCENTILE_APPROX(FunctionName.of("percentile_approx")), + APPROX_COUNT_DISTINCT(FunctionName.of("approx_count_distinct")), /** Text Functions. */ ASCII(FunctionName.of("ascii")), @@ -213,6 +214,7 @@ public enum BuiltinFunctionName { JSON_OBJECT(FunctionName.of("json_object")), JSON_ARRAY(FunctionName.of("json_array")), JSON_ARRAY_LENGTH(FunctionName.of("json_array_length")), + TO_JSON_STRING(FunctionName.of("to_json_string")), JSON_EXTRACT(FunctionName.of("json_extract")), JSON_KEYS(FunctionName.of("json_keys")), JSON_VALID(FunctionName.of("json_valid")), @@ -228,6 +230,7 @@ public enum BuiltinFunctionName { /** COLLECTION Functions **/ ARRAY(FunctionName.of("array")), + ARRAY_LENGTH(FunctionName.of("array_length")), /** LAMBDA Functions **/ ARRAY_FORALL(FunctionName.of("forall")), @@ -289,7 +292,6 @@ public enum BuiltinFunctionName { MULTIMATCHQUERY(FunctionName.of("multimatchquery")), WILDCARDQUERY(FunctionName.of("wildcardquery")), WILDCARD_QUERY(FunctionName.of("wildcard_query")), - COALESCE(FunctionName.of("coalesce")); private FunctionName name; @@ -330,6 +332,7 @@ public FunctionName getName() { .put("take", BuiltinFunctionName.TAKE) .put("percentile", BuiltinFunctionName.PERCENTILE) .put("percentile_approx", BuiltinFunctionName.PERCENTILE_APPROX) + .put("approx_count_distinct", BuiltinFunctionName.APPROX_COUNT_DISTINCT) .build(); public static Optional of(String str) { diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java index fe8fdb0bd..2aef24e50 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java @@ -85,6 +85,11 @@ public Expression analyze(UnresolvedExpression unresolved, CatalystPlanContext c return unresolved.accept(this, context); } + /** This method is only for analyze the join condition expression */ + public Expression analyzeJoinCondition(UnresolvedExpression unresolved, CatalystPlanContext context) { + return context.resolveJoinCondition(unresolved, this::analyze); + } + @Override public Expression visitLiteral(Literal node, CatalystPlanContext context) { return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Literal( @@ -172,6 +177,11 @@ public Expression visitCompare(Compare node, CatalystPlanContext context) { @Override public Expression visitQualifiedName(QualifiedName node, CatalystPlanContext context) { + // When the qualified name is part of join condition, for example: table1.id = table2.id + // findRelation(context.traversalContext() only returns relation table1 which cause table2.id fail to resolve + if (context.isResolvingJoinCondition()) { + return 
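+            // leave the attribute unresolved (e.g. table2.id) so Spark's analyzer can bind it against either join side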
context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts()))); + } List relation = findRelation(context.traversalContext()); if (!relation.isEmpty()) { Optional resolveField = resolveField(relation, node, context.getRelations()); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java index 61762f616..1621e65d5 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java @@ -5,6 +5,7 @@ package org.opensearch.sql.ppl; +import lombok.Getter; import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; import org.apache.spark.sql.catalyst.expressions.AttributeReference; import org.apache.spark.sql.catalyst.expressions.Expression; @@ -25,6 +26,7 @@ import java.util.Stack; import java.util.function.BiFunction; import java.util.function.Function; +import java.util.function.UnaryOperator; import java.util.stream.Collectors; import static java.util.Collections.emptyList; @@ -39,19 +41,19 @@ public class CatalystPlanContext { /** * Catalyst relations list **/ - private List projectedFields = new ArrayList<>(); + @Getter private List projectedFields = new ArrayList<>(); /** * Catalyst relations list **/ - private List relations = new ArrayList<>(); + @Getter private List relations = new ArrayList<>(); /** * Catalyst SubqueryAlias list **/ - private List subqueryAlias = new ArrayList<>(); + @Getter private List subqueryAlias = new ArrayList<>(); /** * Catalyst evolving logical plan **/ - private Stack planBranches = new Stack<>(); + @Getter private Stack planBranches = new Stack<>(); /** * The current traversal context the visitor is going threw */ @@ -60,28 +62,12 @@ public class CatalystPlanContext { /** * NamedExpression contextual parameters **/ - private final Stack namedParseExpressions = new Stack<>(); + @Getter private final Stack namedParseExpressions = new Stack<>(); /** * Grouping NamedExpression contextual parameters **/ - private final Stack groupingParseExpressions = new Stack<>(); - - public Stack getPlanBranches() { - return planBranches; - } - - public List getRelations() { - return relations; - } - - public List getSubqueryAlias() { - return subqueryAlias; - } - - public List getProjectedFields() { - return projectedFields; - } + @Getter private final Stack groupingParseExpressions = new Stack<>(); public LogicalPlan getPlan() { if (this.planBranches.isEmpty()) return null; @@ -101,10 +87,6 @@ public Stack traversalContext() { return planTraversalContext; } - public Stack getNamedParseExpressions() { - return namedParseExpressions; - } - public void setNamedParseExpressions(Stack namedParseExpressions) { this.namedParseExpressions.clear(); this.namedParseExpressions.addAll(namedParseExpressions); @@ -114,10 +96,6 @@ public Optional popNamedParseExpressions() { return namedParseExpressions.isEmpty() ? 
Optional.empty() : Optional.of(namedParseExpressions.pop()); } - public Stack getGroupingParseExpressions() { - return groupingParseExpressions; - } - /** * define new field * @@ -154,13 +132,13 @@ public LogicalPlan withProjectedFields(List projectedField this.projectedFields.addAll(projectedFields); return getPlan(); } - + public LogicalPlan applyBranches(List> plans) { plans.forEach(plan -> with(plan.apply(planBranches.get(0)))); planBranches.remove(0); return getPlan(); - } - + } + /** * append plan with evolving plans branches * @@ -210,7 +188,7 @@ public LogicalPlan reduce(BiFunction tran return result; }).orElse(getPlan())); } - + /** * apply for each plan with the given function * @@ -288,4 +266,21 @@ public static Optional findRelation(LogicalPlan plan) { return Optional.empty(); } + @Getter private boolean isResolvingJoinCondition = false; + + /** + * Resolve the join condition with the given function. + * A flag will be set to true ahead expression resolving, then false after resolving. + * @param expr + * @param transformFunction + * @return + */ + public Expression resolveJoinCondition( + UnresolvedExpression expr, + BiFunction transformFunction) { + isResolvingJoinCondition = true; + Expression result = transformFunction.apply(expr, this); + isResolvingJoinCondition = false; + return result; + } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 669459fba..000c16b92 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -11,15 +11,9 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; import org.apache.spark.sql.catalyst.expressions.Ascending$; import org.apache.spark.sql.catalyst.expressions.Descending$; +import org.apache.spark.sql.catalyst.expressions.Explode; import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.expressions.GeneratorOuter; -import org.apache.spark.sql.catalyst.expressions.In$; -import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual; -import org.apache.spark.sql.catalyst.expressions.InSubquery$; -import org.apache.spark.sql.catalyst.expressions.LessThan; -import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual; -import org.apache.spark.sql.catalyst.expressions.ListQuery$; -import org.apache.spark.sql.catalyst.expressions.MakeInterval$; import org.apache.spark.sql.catalyst.expressions.NamedExpression; import org.apache.spark.sql.catalyst.expressions.SortDirection; import org.apache.spark.sql.catalyst.expressions.SortOrder; @@ -37,6 +31,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap; import org.opensearch.flint.spark.FlattenGenerator; import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; import org.opensearch.sql.ast.expression.Alias; import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.Field; @@ -52,6 +47,7 @@ import org.opensearch.sql.ast.statement.Statement; import org.opensearch.sql.ast.tree.Aggregation; import org.opensearch.sql.ast.tree.Correlation; +import org.opensearch.sql.ast.tree.CountedAggregation; import org.opensearch.sql.ast.tree.Dedupe; import org.opensearch.sql.ast.tree.DescribeRelation; import org.opensearch.sql.ast.tree.Eval; @@ -71,7 +67,6 @@ import org.opensearch.sql.ast.tree.Rename; 
import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.ast.tree.SubqueryAlias; -import org.opensearch.sql.ast.tree.TopAggregation; import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.ast.tree.Window; import org.opensearch.sql.common.antlr.SyntaxCheckException; @@ -89,10 +84,12 @@ import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.function.BiConsumer; import java.util.stream.Collectors; import static java.util.Collections.emptyList; import static java.util.List.of; +import static org.opensearch.sql.ppl.CatalystPlanContext.findRelation; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; import static org.opensearch.sql.ppl.utils.DedupeTransformer.retainMultipleDuplicateEvents; import static org.opensearch.sql.ppl.utils.DedupeTransformer.retainMultipleDuplicateEventsAndKeepEmpty; @@ -130,6 +127,10 @@ public LogicalPlan visitQuery(Query node, CatalystPlanContext context) { return node.getPlan().accept(this, context); } + public LogicalPlan visitFirstChild(Node node, CatalystPlanContext context) { + return node.getChild().get(0).accept(this, context); + } + @Override public LogicalPlan visitExplain(Explain node, CatalystPlanContext context) { node.getStatement().accept(this, context); @@ -138,6 +139,7 @@ public LogicalPlan visitExplain(Explain node, CatalystPlanContext context) { @Override public LogicalPlan visitRelation(Relation node, CatalystPlanContext context) { + //relations doesnt have a visitFirstChild call since its the leaf of the AST tree if (node instanceof DescribeRelation) { TableIdentifier identifier = getTableIdentifier(node.getTableQualifiedName()); return context.with( @@ -149,15 +151,19 @@ public LogicalPlan visitRelation(Relation node, CatalystPlanContext context) { } //regular sql algebraic relations node.getQualifiedNames().forEach(q -> - // Resolving the qualifiedName which is composed of a datasource.schema.table - context.withRelation(new UnresolvedRelation(getTableIdentifier(q).nameParts(), CaseInsensitiveStringMap.empty(), false)) + // TODO Do not support 4+ parts table identifier in future (may be reverted this PR in 0.8.0) + // node.getQualifiedNames.getParts().size() > 3 + // A Spark TableIdentifier should only contain 3 parts: tableName, databaseName and catalogName. 
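+            // (a four-part name, e.g. a hypothetical `cat`.`db`.`tbl`.`suffix`, has no direct TableIdentifier representation)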
+ // If the qualifiedName has more than 3 parts, + // we merge all parts from 3 to last parts into the tableName as one whole + context.withRelation(new UnresolvedRelation(seq(q.getParts()), CaseInsensitiveStringMap.empty(), false)) ); return context.getPlan(); } @Override public LogicalPlan visitFilter(Filter node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + visitFirstChild(node, context); return context.apply(p -> { Expression conditionExpression = visitExpression(node.getCondition(), context); Optional innerConditionExpression = context.popNamedParseExpressions(); @@ -171,8 +177,7 @@ public LogicalPlan visitFilter(Filter node, CatalystPlanContext context) { */ @Override public LogicalPlan visitLookup(Lookup node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); - + visitFirstChild(node, context); return context.apply( searchSide -> { LogicalPlan lookupTable = node.getLookupRelation().accept(this, context); Expression lookupCondition = buildLookupMappingCondition(node, expressionAnalyzer, context); @@ -228,8 +233,7 @@ public LogicalPlan visitLookup(Lookup node, CatalystPlanContext context) { @Override public LogicalPlan visitTrendline(Trendline node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); - + visitFirstChild(node, context); node.getSortByField() .ifPresent(sortField -> { Expression sortFieldExpression = visitExpression(sortField, context); @@ -245,14 +249,14 @@ public LogicalPlan visitTrendline(Trendline node, CatalystPlanContext context) { trendlineProjectExpressions.add(UnresolvedStar$.MODULE$.apply(Option.empty())); } - trendlineProjectExpressions.addAll(TrendlineCatalystUtils.visitTrendlineComputations(expressionAnalyzer, node.getComputations(), context)); + trendlineProjectExpressions.addAll(TrendlineCatalystUtils.visitTrendlineComputations(expressionAnalyzer, node.getComputations(), node.getSortByField(), context)); return context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(seq(trendlineProjectExpressions), p)); } @Override public LogicalPlan visitCorrelation(Correlation node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + visitFirstChild(node, context); context.reduce((left, right) -> { visitFieldList(node.getFieldsList().stream().map(Field::new).collect(Collectors.toList()), context); Seq fields = context.retainAllNamedParseExpressions(e -> e); @@ -270,10 +274,11 @@ public LogicalPlan visitCorrelation(Correlation node, CatalystPlanContext contex @Override public LogicalPlan visitJoin(Join node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + visitFirstChild(node, context); return context.apply(left -> { LogicalPlan right = node.getRight().accept(this, context); - Optional joinCondition = node.getJoinCondition().map(c -> visitExpression(c, context)); + Optional joinCondition = node.getJoinCondition() + .map(c -> expressionAnalyzer.analyzeJoinCondition(c, context)); context.retainAllNamedParseExpressions(p -> p); context.retainAllPlans(p -> p); return join(left, right, node.getJoinType(), joinCondition, node.getJoinHint()); @@ -282,7 +287,7 @@ public LogicalPlan visitJoin(Join node, CatalystPlanContext context) { @Override public LogicalPlan visitSubqueryAlias(SubqueryAlias node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + visitFirstChild(node, context); return context.apply(p -> { var alias = 
org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias$.MODULE$.apply(node.getAlias(), p); context.withSubqueryAlias(alias); @@ -293,7 +298,7 @@ public LogicalPlan visitSubqueryAlias(SubqueryAlias node, CatalystPlanContext co @Override public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + visitFirstChild(node, context); List aggsExpList = visitExpressionList(node.getAggExprList(), context); List groupExpList = visitExpressionList(node.getGroupExprList(), context); if (!groupExpList.isEmpty()) { @@ -324,9 +329,9 @@ public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContex context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Sort(sortElements, true, logicalPlan)); } //visit TopAggregation results limit - if ((node instanceof TopAggregation) && ((TopAggregation) node).getResults().isPresent()) { + if ((node instanceof CountedAggregation) && ((CountedAggregation) node).getResults().isPresent()) { context.apply(p -> (LogicalPlan) Limit.apply(new org.apache.spark.sql.catalyst.expressions.Literal( - ((TopAggregation) node).getResults().get().getValue(), org.apache.spark.sql.types.DataTypes.IntegerType), p)); + ((CountedAggregation) node).getResults().get().getValue(), org.apache.spark.sql.types.DataTypes.IntegerType), p)); } return logicalPlan; } @@ -339,7 +344,7 @@ private static LogicalPlan extractedAggregation(CatalystPlanContext context) { @Override public LogicalPlan visitWindow(Window node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + visitFirstChild(node, context); List windowFunctionExpList = visitExpressionList(node.getWindowFunctionList(), context); Seq windowFunctionExpressions = context.retainAllNamedParseExpressions(p -> p); List partitionExpList = visitExpressionList(node.getPartExprList(), context); @@ -369,10 +374,11 @@ public LogicalPlan visitAlias(Alias node, CatalystPlanContext context) { @Override public LogicalPlan visitProject(Project node, CatalystPlanContext context) { + // update the plan's context prior to visiting the node's children if (node.isExcluded()) { List intersect = context.getProjectedFields().stream() - .filter(node.getProjectList()::contains) - .collect(Collectors.toList()); + .filter(node.getProjectList()::contains) + .collect(Collectors.toList()); if (!intersect.isEmpty()) { // Fields in parent projection, but they have been excluded in child.
For example, // source=t | fields - A, B | fields A, B, C will throw "[Field A, Field B] can't be resolved" @@ -381,7 +387,7 @@ public LogicalPlan visitProject(Project node, CatalystPlanContext context) { } else { context.withProjectedFields(node.getProjectList()); } - LogicalPlan child = node.getChild().get(0).accept(this, context); + LogicalPlan child = visitFirstChild(node, context); visitExpressionList(node.getProjectList(), context); // Create a projection list from the existing expressions @@ -402,7 +408,7 @@ public LogicalPlan visitProject(Project node, CatalystPlanContext context) { @Override public LogicalPlan visitSort(Sort node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + visitFirstChild(node, context); visitFieldList(node.getSortList(), context); Seq sortElements = context.retainAllNamedParseExpressions(exp -> SortUtils.getSortDirection(node, (NamedExpression) exp)); return context.apply(p -> (LogicalPlan) new org.apache.spark.sql.catalyst.plans.logical.Sort(sortElements, true, p)); @@ -410,20 +416,20 @@ public LogicalPlan visitSort(Sort node, CatalystPlanContext context) { @Override public LogicalPlan visitHead(Head node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + visitFirstChild(node, context); return context.apply(p -> (LogicalPlan) Limit.apply(new org.apache.spark.sql.catalyst.expressions.Literal( node.getSize(), DataTypes.IntegerType), p)); } @Override public LogicalPlan visitFieldSummary(FieldSummary fieldSummary, CatalystPlanContext context) { - fieldSummary.getChild().get(0).accept(this, context); + visitFirstChild(fieldSummary, context); return FieldSummaryTransformer.translate(fieldSummary, context); } @Override public LogicalPlan visitFillNull(FillNull fillNull, CatalystPlanContext context) { - fillNull.getChild().get(0).accept(this, context); + visitFirstChild(fillNull, context); List aliases = new ArrayList<>(); for(FillNull.NullableFieldFill nullableFieldFill : fillNull.getNullableFieldFills()) { Field field = nullableFieldFill.getNullableFieldReference(); @@ -454,18 +460,39 @@ public LogicalPlan visitFillNull(FillNull fillNull, CatalystPlanContext context) @Override public LogicalPlan visitFlatten(Flatten flatten, CatalystPlanContext context) { - flatten.getChild().get(0).accept(this, context); + visitFirstChild(flatten, context); if (context.getNamedParseExpressions().isEmpty()) { // Create an UnresolvedStar for all-fields projection context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); } - Expression field = visitExpression(flatten.getFieldToBeFlattened(), context); + Expression field = visitExpression(flatten.getField(), context); context.retainAllNamedParseExpressions(p -> (NamedExpression) p); FlattenGenerator flattenGenerator = new FlattenGenerator(field); context.apply(p -> new Generate(new GeneratorOuter(flattenGenerator), seq(), true, (Option) None$.MODULE$, seq(), p)); return context.apply(logicalPlan -> DataFrameDropColumns$.MODULE$.apply(seq(field), logicalPlan)); } + @Override + public LogicalPlan visitExpand(org.opensearch.sql.ast.tree.Expand node, CatalystPlanContext context) { + visitFirstChild(node, context); + if (context.getNamedParseExpressions().isEmpty()) { + // Create an UnresolvedStar for all-fields projection + context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); + } + Expression field = visitExpression(node.getField(), context); + Optional alias = node.getAlias().map(aliasNode -> 
visitExpression(aliasNode, context)); + context.retainAllNamedParseExpressions(p -> (NamedExpression) p); + Explode explodeGenerator = new Explode(field); + scala.collection.mutable.Seq outputs = alias.isEmpty() ? seq() : seq(alias.get()); + if(alias.isEmpty()) + return context.apply(p -> new Generate(explodeGenerator, seq(), false, (Option) None$.MODULE$, outputs, p)); + else { + //in case an alias does appear - remove the original field from the returning columns + context.apply(p -> new Generate(explodeGenerator, seq(), false, (Option) None$.MODULE$, outputs, p)); + return context.apply(logicalPlan -> DataFrameDropColumns$.MODULE$.apply(seq(field), logicalPlan)); + } + } + private void visitFieldList(List fieldList, CatalystPlanContext context) { fieldList.forEach(field -> visitExpression(field, context)); } @@ -483,7 +510,7 @@ private Expression visitExpression(UnresolvedExpression expression, CatalystPlan @Override public LogicalPlan visitParse(Parse node, CatalystPlanContext context) { - LogicalPlan child = node.getChild().get(0).accept(this, context); + visitFirstChild(node, context); Expression sourceField = visitExpression(node.getSourceField(), context); ParseMethod parseMethod = node.getParseMethod(); java.util.Map arguments = node.getArguments(); @@ -493,7 +520,7 @@ public LogicalPlan visitParse(Parse node, CatalystPlanContext context) { @Override public LogicalPlan visitRename(Rename node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + visitFirstChild(node, context); if (context.getNamedParseExpressions().isEmpty()) { // Create an UnresolvedStar for all-fields projection context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.empty())); @@ -510,7 +537,7 @@ public LogicalPlan visitRename(Rename node, CatalystPlanContext context) { @Override public LogicalPlan visitEval(Eval node, CatalystPlanContext context) { - LogicalPlan child = node.getChild().get(0).accept(this, context); + visitFirstChild(node, context); List aliases = new ArrayList<>(); List letExpressions = node.getExpressionList(); for (Let let : letExpressions) { @@ -524,8 +551,7 @@ public LogicalPlan visitEval(Eval node, CatalystPlanContext context) { List expressionList = visitExpressionList(aliases, context); Seq projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p); // build the plan with the projection step - child = context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p)); - return child; + return context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p)); } @Override @@ -550,7 +576,7 @@ public LogicalPlan visitWindowFunction(WindowFunction node, CatalystPlanContext @Override public LogicalPlan visitDedupe(Dedupe node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + visitFirstChild(node, context); List options = node.getOptions(); Integer allowedDuplication = (Integer) options.get(0).getValue().getValue(); Boolean keepEmpty = (Boolean) options.get(1).getValue().getValue(); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 4e6b1f131..7d1cc072b 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -131,6 +131,12 @@ public UnresolvedPlan 
visitWhereCommand(OpenSearchPPLParser.WhereCommandContext return new Filter(internalVisitExpression(ctx.logicalExpression())); } + @Override + public UnresolvedPlan visitExpandCommand(OpenSearchPPLParser.ExpandCommandContext ctx) { + return new Expand((Field) internalVisitExpression(ctx.fieldExpression()), + ctx.alias!=null ? Optional.of(internalVisitExpression(ctx.alias)) : Optional.empty()); + } + @Override public UnresolvedPlan visitCorrelateCommand(OpenSearchPPLParser.CorrelateCommandContext ctx) { return new Correlation(ctx.correlationType().getText(), @@ -155,14 +161,25 @@ public UnresolvedPlan visitJoinCommand(OpenSearchPPLParser.JoinCommandContext ct joinType = Join.JoinType.CROSS; } Join.JoinHint joinHint = getJoinHint(ctx.joinHintList()); - String leftAlias = ctx.sideAlias().leftAlias.getText(); - String rightAlias = ctx.sideAlias().rightAlias.getText(); + Optional leftAlias = ctx.sideAlias().leftAlias != null ? Optional.of(ctx.sideAlias().leftAlias.getText()) : Optional.empty(); + Optional rightAlias = Optional.empty(); if (ctx.tableOrSubqueryClause().alias != null) { - // left and right aliases are required in join syntax. Setting by 'AS' causes ambiguous - throw new SyntaxCheckException("'AS' is not allowed in right subquery, use right= instead"); + rightAlias = Optional.of(ctx.tableOrSubqueryClause().alias.getText()); + } + if (ctx.sideAlias().rightAlias != null) { + rightAlias = Optional.of(ctx.sideAlias().rightAlias.getText()); } + UnresolvedPlan rightRelation = visit(ctx.tableOrSubqueryClause()); - UnresolvedPlan right = new SubqueryAlias(rightAlias, rightRelation); + // Add a SubqueryAlias to the right plan when the right alias is present and no duplicated alias existing in right. + UnresolvedPlan right; + if (rightAlias.isEmpty() || + (rightRelation instanceof SubqueryAlias && + rightAlias.get().equals(((SubqueryAlias) rightRelation).getAlias()))) { + right = rightRelation; + } else { + right = new SubqueryAlias(rightAlias.get(), rightRelation); + } Optional joinCondition = ctx.joinCriteria() == null ? Optional.empty() : Optional.of(expressionBuilder.visitJoinCriteria(ctx.joinCriteria())); @@ -370,7 +387,7 @@ public UnresolvedPlan visitPatternsCommand(OpenSearchPPLParser.PatternsCommandCo /** Lookup command */ @Override public UnresolvedPlan visitLookupCommand(OpenSearchPPLParser.LookupCommandContext ctx) { - Relation lookupRelation = new Relation(this.internalVisitExpression(ctx.tableSource())); + Relation lookupRelation = new Relation(Collections.singletonList(this.internalVisitExpression(ctx.tableSource()))); Lookup.OutputStrategy strategy = ctx.APPEND() != null ? Lookup.OutputStrategy.APPEND : Lookup.OutputStrategy.REPLACE; java.util.Map lookupMappingList = buildLookupPair(ctx.lookupMappingList().lookupPair()); @@ -415,8 +432,9 @@ private Trendline.TrendlineComputation toTrendlineComputation(OpenSearchPPLParse public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) { ImmutableList.Builder aggListBuilder = new ImmutableList.Builder<>(); ImmutableList.Builder groupListBuilder = new ImmutableList.Builder<>(); + String funcName = ctx.TOP_APPROX() != null ? 
"approx_count_distinct" : "count"; ctx.fieldList().fieldExpression().forEach(field -> { - UnresolvedExpression aggExpression = new AggregateFunction("count",internalVisitExpression(field), + AggregateFunction aggExpression = new AggregateFunction(funcName,internalVisitExpression(field), Collections.singletonList(new Argument("countParam", new Literal(1, DataType.INTEGER)))); String name = field.qualifiedName().getText(); Alias alias = new Alias("count_"+name, aggExpression); @@ -441,14 +459,12 @@ public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) .collect(Collectors.toList())) .orElse(emptyList()) ); - UnresolvedExpression unresolvedPlan = (ctx.number != null ? internalVisitExpression(ctx.number) : null); - TopAggregation aggregation = - new TopAggregation( - Optional.ofNullable((Literal) unresolvedPlan), + UnresolvedExpression expectedResults = (ctx.number != null ? internalVisitExpression(ctx.number) : null); + return new TopAggregation( + Optional.ofNullable((Literal) expectedResults), aggListBuilder.build(), aggListBuilder.build(), groupListBuilder.build()); - return aggregation; } /** Fieldsummary command. */ @@ -462,8 +478,9 @@ public UnresolvedPlan visitFieldsummaryCommand(OpenSearchPPLParser.FieldsummaryC public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ctx) { ImmutableList.Builder aggListBuilder = new ImmutableList.Builder<>(); ImmutableList.Builder groupListBuilder = new ImmutableList.Builder<>(); + String funcName = ctx.RARE_APPROX() != null ? "approx_count_distinct" : "count"; ctx.fieldList().fieldExpression().forEach(field -> { - UnresolvedExpression aggExpression = new AggregateFunction("count",internalVisitExpression(field), + AggregateFunction aggExpression = new AggregateFunction(funcName,internalVisitExpression(field), Collections.singletonList(new Argument("countParam", new Literal(1, DataType.INTEGER)))); String name = field.qualifiedName().getText(); Alias alias = new Alias("count_"+name, aggExpression); @@ -488,12 +505,12 @@ public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ct .collect(Collectors.toList())) .orElse(emptyList()) ); - RareAggregation aggregation = - new RareAggregation( + UnresolvedExpression expectedResults = (ctx.number != null ? internalVisitExpression(ctx.number) : null); + return new RareAggregation( + Optional.ofNullable((Literal) expectedResults), aggListBuilder.build(), aggListBuilder.build(), groupListBuilder.build()); - return aggregation; } @Override @@ -509,9 +526,8 @@ public UnresolvedPlan visitTableOrSubqueryClause(OpenSearchPPLParser.TableOrSubq @Override public UnresolvedPlan visitTableSourceClause(OpenSearchPPLParser.TableSourceClauseContext ctx) { - return ctx.alias == null - ? new Relation(ctx.tableSource().stream().map(this::internalVisitExpression).collect(Collectors.toList())) - : new Relation(ctx.tableSource().stream().map(this::internalVisitExpression).collect(Collectors.toList()), ctx.alias.getText()); + Relation relation = new Relation(ctx.tableSource().stream().map(this::internalVisitExpression).collect(Collectors.toList())); + return ctx.alias != null ? 
new SubqueryAlias(ctx.alias.getText(), relation) : relation; } @Override diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index bf029c49c..e758db7ef 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -159,6 +159,11 @@ public UnresolvedExpression visitBinaryArithmetic(OpenSearchPPLParser.BinaryArit ctx.binaryOperator.getText(), Arrays.asList(visit(ctx.left), visit(ctx.right))); } + @Override + public UnresolvedExpression visitParentheticLogicalExpr(OpenSearchPPLParser.ParentheticLogicalExprContext ctx) { + return visit(ctx.logicalExpression()); // Discard parenthesis around + } + @Override public UnresolvedExpression visitParentheticValueExpr(OpenSearchPPLParser.ParentheticValueExprContext ctx) { return visit(ctx.valueExpression()); // Discard parenthesis around @@ -213,7 +218,8 @@ public UnresolvedExpression visitCountAllFunctionCall(OpenSearchPPLParser.CountA @Override public UnresolvedExpression visitDistinctCountFunctionCall(OpenSearchPPLParser.DistinctCountFunctionCallContext ctx) { - return new AggregateFunction("count", visit(ctx.valueExpression()), true); + String funcName = ctx.DISTINCT_COUNT_APPROX()!=null ? "approx_count_distinct" :"count"; + return new AggregateFunction(funcName, visit(ctx.valueExpression()), true); } @Override diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTransformer.java index 9788ac1bc..c06f37aa3 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTransformer.java @@ -57,6 +57,8 @@ static Expression aggregator(org.opensearch.sql.ast.expression.AggregateFunction return new UnresolvedFunction(seq("PERCENTILE"), seq(arg, new Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), distinct, empty(),false); case PERCENTILE_APPROX: return new UnresolvedFunction(seq("PERCENTILE_APPROX"), seq(arg, new Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), distinct, empty(),false); + case APPROX_COUNT_DISTINCT: + return new UnresolvedFunction(seq("APPROX_COUNT_DISTINCT"), seq(arg), distinct, empty(),false); } throw new IllegalStateException("Not Supported value: " + aggregateFunction.getFuncName()); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java index e39c9ab38..0a4f19b53 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java @@ -26,8 +26,11 @@ import java.util.Map; import java.util.function.Function; +import static org.opensearch.flint.spark.ppl.OpenSearchPPLLexer.DISTINCT_COUNT_APPROX; import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADD; import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADDDATE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.APPROX_COUNT_DISTINCT; +import 
static org.opensearch.sql.expression.function.BuiltinFunctionName.ARRAY_LENGTH; import static org.opensearch.sql.expression.function.BuiltinFunctionName.DATEDIFF; import static org.opensearch.sql.expression.function.BuiltinFunctionName.DATE_ADD; import static org.opensearch.sql.expression.function.BuiltinFunctionName.DATE_SUB; @@ -58,6 +61,7 @@ import static org.opensearch.sql.expression.function.BuiltinFunctionName.SYSDATE; import static org.opensearch.sql.expression.function.BuiltinFunctionName.TIMESTAMPADD; import static org.opensearch.sql.expression.function.BuiltinFunctionName.TIMESTAMPDIFF; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.TO_JSON_STRING; import static org.opensearch.sql.expression.function.BuiltinFunctionName.TRIM; import static org.opensearch.sql.expression.function.BuiltinFunctionName.UTC_TIMESTAMP; import static org.opensearch.sql.expression.function.BuiltinFunctionName.WEEK; @@ -102,9 +106,12 @@ public interface BuiltinFunctionTransformer { .put(COALESCE, "coalesce") .put(LENGTH, "length") .put(TRIM, "trim") + .put(ARRAY_LENGTH, "array_size") // json functions + .put(TO_JSON_STRING, "to_json") .put(JSON_KEYS, "json_object_keys") .put(JSON_EXTRACT, "get_json_object") + .put(APPROX_COUNT_DISTINCT, "approx_count_distinct") .build(); /** @@ -126,26 +133,12 @@ public interface BuiltinFunctionTransformer { .put( JSON_ARRAY_LENGTH, args -> { - // Check if the input is an array (from json_array()) or a JSON string - if (args.get(0) instanceof UnresolvedFunction) { - // Input is a JSON array - return UnresolvedFunction$.MODULE$.apply("json_array_length", - seq(UnresolvedFunction$.MODULE$.apply("to_json", seq(args), false)), false); - } else { - // Input is a JSON string - return UnresolvedFunction$.MODULE$.apply("json_array_length", seq(args.get(0)), false); - } + return UnresolvedFunction$.MODULE$.apply("json_array_length", seq(args.get(0)), false); }) .put( JSON, args -> { - // Check if the input is a named_struct (from json_object()) or a JSON string - if (args.get(0) instanceof UnresolvedFunction) { - return UnresolvedFunction$.MODULE$.apply("to_json", seq(args.get(0)), false); - } else { - return UnresolvedFunction$.MODULE$.apply("get_json_object", - seq(args.get(0), Literal$.MODULE$.apply("$")), false); - } + return UnresolvedFunction$.MODULE$.apply("get_json_object", seq(args.get(0), Literal$.MODULE$.apply("$")), false); }) .put( JSON_VALID, diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseUtils.java index a463767f0..6a4d4b032 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseUtils.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseUtils.java @@ -138,10 +138,6 @@ public static String extractPattern(String patterns, List columns) { public static class GrokExpression { private static final GrokCompiler grokCompiler = GrokCompiler.newInstance(); - static { - grokCompiler.registerDefaultPatterns(); - } - public static Expression getRegExpCommand(Expression sourceField, org.apache.spark.sql.catalyst.expressions.Literal patternLiteral, org.apache.spark.sql.catalyst.expressions.Literal groupIndexLiteral) { return new RegExpExtract(sourceField, patternLiteral, groupIndexLiteral); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/RelationUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/RelationUtils.java index 
1dc7b9878..f959fe199 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/RelationUtils.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/RelationUtils.java @@ -53,8 +53,15 @@ static TableIdentifier getTableIdentifier(QualifiedName qualifiedName) { Option$.MODULE$.apply(qualifiedName.getParts().get(1)), Option$.MODULE$.apply(qualifiedName.getParts().get(0))); } else { - throw new IllegalArgumentException("Invalid table name: " + qualifiedName - + " Syntax: [ database_name. ] table_name"); + // TODO Do not support 4+ parts table identifier in future (may be reverted this PR in 0.8.0) + // qualifiedName.getParts().size() > 3 + // A Spark TableIdentifier should only contain 3 parts: tableName, databaseName and catalogName. + // If the qualifiedName has more than 3 parts, + // we merge all parts from 3 to last parts into the tableName as one whole + identifier = new TableIdentifier( + String.join(".", qualifiedName.getParts().subList(2, qualifiedName.getParts().size())), + Option$.MODULE$.apply(qualifiedName.getParts().get(1)), + Option$.MODULE$.apply(qualifiedName.getParts().get(0))); } return identifier; } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java index 67603ccc7..647f4542e 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java @@ -5,31 +5,40 @@ package org.opensearch.sql.ppl.utils; +import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction; import org.apache.spark.sql.catalyst.expressions.*; -import org.opensearch.sql.ast.expression.AggregateFunction; -import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.*; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.ppl.CatalystExpressionVisitor; import org.opensearch.sql.ppl.CatalystPlanContext; +import scala.collection.mutable.Seq; import scala.Option; import scala.Tuple2; +import java.util.ArrayList; +import java.util.Collections; import java.util.List; +import java.util.Optional; import java.util.stream.Collectors; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; +import static scala.Option.empty; +import static scala.collection.JavaConverters.asScalaBufferConverter; public interface TrendlineCatalystUtils { - static List visitTrendlineComputations(CatalystExpressionVisitor expressionVisitor, List computations, CatalystPlanContext context) { + + static List visitTrendlineComputations(CatalystExpressionVisitor expressionVisitor, List computations, Optional sortField, CatalystPlanContext context) { return computations.stream() - .map(computation -> visitTrendlineComputation(expressionVisitor, computation, context)) + .map(computation -> visitTrendlineComputation(expressionVisitor, computation, sortField, context)) .collect(Collectors.toList()); } - static NamedExpression visitTrendlineComputation(CatalystExpressionVisitor expressionVisitor, Trendline.TrendlineComputation node, CatalystPlanContext context) { + + static NamedExpression visitTrendlineComputation(CatalystExpressionVisitor expressionVisitor, Trendline.TrendlineComputation node, Optional sortField, CatalystPlanContext context) { + 
//window lower boundary expressionVisitor.visitLiteral(new Literal(Math.negateExact(node.getNumberOfDataPoints() - 1), DataType.INTEGER), context); Expression windowLowerBoundary = context.popNamedParseExpressions().get(); @@ -40,26 +49,28 @@ static NamedExpression visitTrendlineComputation(CatalystExpressionVisitor expre seq(), new SpecifiedWindowFrame(RowFrame$.MODULE$, windowLowerBoundary, CurrentRow$.MODULE$)); - if (node.getComputationType() == Trendline.TrendlineType.SMA) { - //calculate avg value of the data field - expressionVisitor.visitAggregateFunction(new AggregateFunction(BuiltinFunctionName.AVG.name(), node.getDataField()), context); - Expression avgFunction = context.popNamedParseExpressions().get(); - - //sma window - WindowExpression sma = new WindowExpression( - avgFunction, - windowDefinition); - - CaseWhen smaOrNull = trendlineOrNullWhenThereAreTooFewDataPoints(expressionVisitor, sma, node, context); - - return org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(smaOrNull, - node.getAlias(), - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty(), - seq(new java.util.ArrayList())); - } else { - throw new IllegalArgumentException(node.getComputationType()+" is not supported"); + switch (node.getComputationType()) { + case SMA: + //calculate avg value of the data field + expressionVisitor.visitAggregateFunction(new AggregateFunction(BuiltinFunctionName.AVG.name(), node.getDataField()), context); + Expression avgFunction = context.popNamedParseExpressions().get(); + + //sma window + WindowExpression sma = new WindowExpression( + avgFunction, + windowDefinition); + + CaseWhen smaOrNull = trendlineOrNullWhenThereAreTooFewDataPoints(expressionVisitor, sma, node, context); + + return getAlias(node.getAlias(), smaOrNull); + case WMA: + if (sortField.isPresent()) { + return getWMAComputationExpression(expressionVisitor, node, sortField.get(), context); + } else { + throw new IllegalArgumentException(node.getComputationType()+" requires a sort field for computation"); + } + default: + throw new IllegalArgumentException(node.getComputationType()+" is not supported"); } } @@ -84,4 +95,136 @@ private static CaseWhen trendlineOrNullWhenThereAreTooFewDataPoints(CatalystExpr ); return new CaseWhen(seq(nullWhenNumberOfDataPointsLessThenRequired), Option.apply(trendlineWindow)); } + + /** + * Responsible for producing a Spark logical plan from the given Trendline command arguments; below is a sample logical plan + * with configuration [dataField=salary, sortField=age, dataPoints=3] + * -- +- 'Project [ + * -- (((('nth_value('salary, 1) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + + * -- ('nth_value('salary, 2) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + + * -- ('nth_value('salary, 3) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) + * -- AS WMA#702] + * . + * And the corresponding SQL query: + * . + * SELECT name, salary, + * ( nth_value(salary, 1) OVER (ORDER BY age ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) *1 + + * nth_value(salary, 2) OVER (ORDER BY age ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) *2 + + * nth_value(salary, 3) OVER (ORDER BY age ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) *3 )/6 AS WMA + * FROM employees + * ORDER BY age; + * + * @param visitor Visitor instance to process any UnresolvedExpression.
+ * @param node Trendline command's arguments. + * @param sortField Field used for window aggregation. + * @param context Context instance to retrieve the Expression in resolved form. + * @return a NamedExpression instance which will calculate the WMA with the provided arguments. + */ + private static NamedExpression getWMAComputationExpression(CatalystExpressionVisitor visitor, + Trendline.TrendlineComputation node, + Field sortField, + CatalystPlanContext context) { + int dataPoints = node.getNumberOfDataPoints(); + //window lower boundary + Expression windowLowerBoundary = parseIntToExpression(visitor, context, + Math.negateExact(dataPoints - 1)); + //window definition + visitor.analyze(sortField, context); + Expression sortDefinition = context.popNamedParseExpressions().get(); + WindowSpecDefinition windowDefinition = getWmaCommonWindowDefinition( + sortDefinition, + SortUtils.isSortedAscending(sortField), + windowLowerBoundary); + // Divisor + Expression divisor = parseIntToExpression(visitor, context, + (dataPoints * (dataPoints + 1) / 2)); + // Aggregation + Expression wmaExpression = getNthValueAggregations(visitor, node, context, windowDefinition, dataPoints) + .stream() + .reduce(Add::new) + .orElse(null); + + return getAlias(node.getAlias(), new Divide(wmaExpression, divisor)); + } + + /** + * Helper method to produce an Alias Expression with the provided value and name. + * @param name The name for the Alias. + * @param expression The expression which will be evaluated. + * @return An Alias instance with logical plan representation of `expression AS name`. + */ + private static NamedExpression getAlias(String name, Expression expression) { + return org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(expression, + name, + NamedExpression.newExprId(), + seq(Collections.emptyList()), + Option.empty(), + seq(Collections.emptyList())); + } + + /** + * Helper method to retrieve an Int expression instance for logical plan composition purposes. + * @param expressionVisitor Visitor instance to process the incoming object. + * @param context Context instance to retrieve the Expression instance. + * @param i Target value for the expression. + * @return An expression object which contains the integer value i. + */ + static Expression parseIntToExpression(CatalystExpressionVisitor expressionVisitor, CatalystPlanContext context, int i) { + expressionVisitor.visitLiteral(new Literal(i, + DataType.INTEGER), context); + return context.popNamedParseExpressions().get(); + } + + + /** + * Helper method to retrieve a WindowSpecDefinition with the provided sorting condition. + * `windowspecdefinition('sortField ascending NULLS FIRST, specifiedwindowframe(RowFrame, windowLowerBoundary, currentrow$())` + * + * @param sortField The field being used for the sorting operation. + * @param ascending The boolean instance for the sorting order. + * @param windowLowerBoundary The Integer expression instance which specifies the lookbehind / lookahead. + * @return A WindowSpecDefinition instance which will be used to compose the WMA calculation.
+ */ + static WindowSpecDefinition getWmaCommonWindowDefinition(Expression sortField, boolean ascending, Expression windowLowerBoundary) { + return new WindowSpecDefinition( + seq(), + seq(SortUtils.sortOrder(sortField, ascending)), + new SpecifiedWindowFrame(RowFrame$.MODULE$, windowLowerBoundary, CurrentRow$.MODULE$)); + } + + /** + * Produces a list of Expressions responsible for returning the appropriate lookbehind / lookahead values for the WMA calculation; a sample logical plan is listed below. + * (((('nth_value('salary, 1) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + + * + * @param visitor Visitor instance to resolve Expression. + * @param node Trendline command instruction. + * @param context Context instance to retrieve the resolved expression. + * @param windowDefinition The windowDefinition for the individual datapoint lookbehind / lookahead. + * @param dataPoints Number of data points for the WMA calculation; this will always equal the number of Expressions being generated. + * @return List instance which contains the expressions for the WMA calculations of the individual datapoints. + */ + private static List getNthValueAggregations(CatalystExpressionVisitor visitor, + Trendline.TrendlineComputation node, + CatalystPlanContext context, + WindowSpecDefinition windowDefinition, + int dataPoints) { + List expressions = new ArrayList<>(); + for (int i = 1; i <= dataPoints; i++) { + // Get the offset parameter + Expression offSetExpression = parseIntToExpression(visitor, context, i); + // Get the dataField in Expression + visitor.analyze(node.getDataField(), context); + Expression dataField = context.popNamedParseExpressions().get(); + // nth_value Expression + UnresolvedFunction nthValueExp = new UnresolvedFunction( + asScalaBufferConverter(List.of("nth_value")).asScala().seq(), + asScalaBufferConverter(List.of(dataField, offSetExpression)).asScala().seq(), + false, empty(), false); + + expressions.add(new Multiply( + new WindowExpression(nthValueExp, windowDefinition), offSetExpression)); + } + return expressions; + } + } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala index 9946bff6a..42cc7ed10 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala @@ -754,6 +754,34 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + test("test approx distinct count product group by brand sorted") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table | stats distinct_count_approx(product) by brand | sort brand"), + context) + val star = Seq(UnresolvedStar(None)) + val brandField = UnresolvedAttribute("brand") + val productField = UnresolvedAttribute("product") + val tableRelation = UnresolvedRelation(Seq("table")) + + val groupByAttributes = Seq(Alias(brandField, "brand")()) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(productField), isDistinct = true), + "distinct_count_approx(product)")() + val brandAlias = Alias(brandField, "brand")() + + val
aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, brandAlias), tableRelation) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(brandField, Ascending)), global = true, aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + test("test distinct count product with alias and filter") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( @@ -803,6 +831,34 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + test( + "test distinct count age by span of interval of 10 years query with sort using approximation ") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table | stats distinct_count_approx(age) by span(age, 10) as age_span | sort age"), + context) + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val tableRelation = UnresolvedRelation(Seq("table")) + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(ageField), isDistinct = true), + "distinct_count_approx(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), tableRelation) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + test("test distinct count status by week window and group by status with limit") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( @@ -838,6 +894,42 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + test( + "test distinct count status by week window and group by status with limit using approximation") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table | stats distinct_count_approx(status) by span(@timestamp, 1w) as status_count_by_week, status | head 100"), + context) + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val status = Alias(UnresolvedAttribute("status"), "status")() + val statusCount = UnresolvedAttribute("status") + val table = UnresolvedRelation(Seq("table")) + + val windowExpression = Alias( + TimeWindow( + UnresolvedAttribute("`@timestamp`"), + TimeWindow.parseExpression(Literal("1 week")), + TimeWindow.parseExpression(Literal("1 week")), + 0), + "status_count_by_week")() + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(statusCount), isDistinct = true), + "distinct_count_approx(status)")() + val aggregatePlan = Aggregate( + Seq(status, windowExpression), + Seq(aggregateExpressions, status, windowExpression), + table) + val planWithLimit = GlobalLimit(Literal(100), LocalLimit(Literal(100), aggregatePlan)) + val expectedPlan = Project(star, planWithLimit) + // Compare the two plans + comparePlans(expectedPlan, logPlan, false) + } + test("multiple stats - test average price and average age") { val context = new CatalystPlanContext val logPlan = diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala 
b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala index 2a569dbdf..1f081bd72 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala @@ -13,7 +13,7 @@ import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, GreaterThan, Literal, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, EqualTo, GreaterThan, Literal, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.command.DescribeTableCommand @@ -27,7 +27,8 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite private val planTransformer = new CatalystQueryPlanVisitor() private val pplParser = new PPLSyntaxParser() - test("test error describe clause") { + // TODO Do not support 4+ parts table identifier in future (may be reverted this PR in 0.8.0) + ignore("test error describe clause") { val context = new CatalystPlanContext val thrown = intercept[IllegalArgumentException] { planTransformer.visit(plan(pplParser, "describe t.b.c.d"), context) @@ -50,6 +51,69 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + // TODO Do not support 4+ parts table identifier in future (may be reverted this PR in 0.8.0) + test("test describe with backticks and more then 3 parts") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "describe `t`.b.`c.d`.`e.f`"), context) + + val expectedPlan = DescribeTableCommand( + TableIdentifier("c.d.e.f", Option("b"), Option("t")), + Map.empty[String, String].empty, + isExtended = true, + output = DescribeRelation.getOutputAttrs) + comparePlans(expectedPlan, logPlan, false) + } + + test("test read table with backticks and more then 3 parts") { + val context = new CatalystPlanContext + val logPlan = { + planTransformer.visit(plan(pplParser, "source=`t`.b.`c.d`.`e.f`"), context) + } + + val table = UnresolvedRelation(Seq("t", "b", "c.d", "e.f")) + val expectedPlan = Project(Seq(UnresolvedStar(None)), table) + comparePlans(expectedPlan, logPlan, false) + } + + test("test describe with complex backticks and more then 3 parts") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "describe `_Basic`.default.`startTime:0,endTime:1`.`logGroups(logGroupIdentifier:['hello/service_log'])`"), + context) + + val expectedPlan = DescribeTableCommand( + TableIdentifier( + "startTime:0,endTime:1.logGroups(logGroupIdentifier:['hello/service_log'])", + Option("default"), + Option("_Basic")), + Map.empty[String, String].empty, + isExtended = true, + output = DescribeRelation.getOutputAttrs) + comparePlans(expectedPlan, logPlan, false) + } + + test("test read complex table with backticks and more then 3 parts") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + 
"source=`_Basic`.default.`startTime:0,endTime:1`.`123.logGroups(logGroupIdentifier:['hello.world/service_log'])`"), + context) + val table = UnresolvedRelation( + Seq( + "_Basic", + "default", + "startTime:0,endTime:1", + "123.logGroups(logGroupIdentifier:['hello.world/service_log'])")) + val expectedPlan = Project(Seq(UnresolvedStar(None)), table) + comparePlans(expectedPlan, logPlan, false) + } + test("test describe FQN table clause") { val context = new CatalystPlanContext val logPlan = @@ -292,6 +356,44 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + test("Search multiple tables - with table alias") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + """ + | source=table1, table2, table3 as t + | | where t.name = 'Molly' + |""".stripMargin), + context) + + val table1 = UnresolvedRelation(Seq("table1")) + val table2 = UnresolvedRelation(Seq("table2")) + val table3 = UnresolvedRelation(Seq("table3")) + val star = UnresolvedStar(None) + val plan1 = Project( + Seq(star), + Filter( + EqualTo(UnresolvedAttribute("t.name"), Literal("Molly")), + SubqueryAlias("t", table1))) + val plan2 = Project( + Seq(star), + Filter( + EqualTo(UnresolvedAttribute("t.name"), Literal("Molly")), + SubqueryAlias("t", table2))) + val plan3 = Project( + Seq(star), + Filter( + EqualTo(UnresolvedAttribute("t.name"), Literal("Molly")), + SubqueryAlias("t", table3))) + + val expectedPlan = + Union(Seq(plan1, plan2, plan3), byName = true, allowMissingCol = true) + + comparePlans(expectedPlan, logPlan, false) + } + test("test fields + field list") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanExpandCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanExpandCommandTranslatorTestSuite.scala new file mode 100644 index 000000000..2acaac529 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanExpandCommandTranslatorTestSuite.scala @@ -0,0 +1,281 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.FlattenGenerator +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Explode, GeneratorOuter, Literal, RegExpExtract} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, DataFrameDropColumns, Generate, Project} +import org.apache.spark.sql.types.IntegerType + +class PPLLogicalPlanExpandCommandTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test expand only field") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=relation | expand field_with_array"), context) + + val relation 
= UnresolvedRelation(Seq("relation")) + val generator = Explode(UnresolvedAttribute("field_with_array")) + val generate = Generate(generator, seq(), false, None, seq(), relation) + val expectedPlan = Project(seq(UnresolvedStar(None)), generate) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("expand multi columns array table") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + s""" + | source = table + | | expand multi_valueA as multiA + | | expand multi_valueB as multiB + | """.stripMargin), + context) + + val relation = UnresolvedRelation(Seq("table")) + val generatorA = Explode(UnresolvedAttribute("multi_valueA")) + val generateA = + Generate(generatorA, seq(), false, None, seq(UnresolvedAttribute("multiA")), relation) + val dropSourceColumnA = + DataFrameDropColumns(Seq(UnresolvedAttribute("multi_valueA")), generateA) + val generatorB = Explode(UnresolvedAttribute("multi_valueB")) + val generateB = Generate( + generatorB, + seq(), + false, + None, + seq(UnresolvedAttribute("multiB")), + dropSourceColumnA) + val dropSourceColumnB = + DataFrameDropColumns(Seq(UnresolvedAttribute("multi_valueB")), generateB) + val expectedPlan = Project(seq(UnresolvedStar(None)), dropSourceColumnB) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test expand on array field which is eval array=json_array") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = table | eval array=json_array(1, 2, 3) | expand array as uid | fields uid"), + context) + + val relation = UnresolvedRelation(Seq("table")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "array")() + val project = Project(seq(UnresolvedStar(None), aliasA), relation) + val generate = Generate( + Explode(UnresolvedAttribute("array")), + seq(), + false, + None, + seq(UnresolvedAttribute("uid")), + project) + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("array")), generate) + val expectedPlan = Project(seq(UnresolvedAttribute("uid")), dropSourceColumn) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test expand only field with alias") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=relation | expand field_with_array as array_list "), + context) + + val relation = UnresolvedRelation(Seq("relation")) + val generate = Generate( + Explode(UnresolvedAttribute("field_with_array")), + seq(), + false, + None, + seq(UnresolvedAttribute("array_list")), + relation) + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("field_with_array")), generate) + val expectedPlan = Project(seq(UnresolvedStar(None)), dropSourceColumn) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test expand and stats") { + val context = new CatalystPlanContext + val query = + "source = table | expand employee | stats max(salary) as max by state, company" + val logPlan = + planTransformer.visit(plan(pplParser, query), context) + val table = UnresolvedRelation(Seq("table")) + val generate = + Generate(Explode(UnresolvedAttribute("employee")), seq(), false, None, seq(), table) + val average = Alias( + UnresolvedFunction(seq("MAX"), seq(UnresolvedAttribute("salary")), false, None, false), + "max")() + val state = Alias(UnresolvedAttribute("state"), "state")() + val company = 
Alias(UnresolvedAttribute("company"), "company")() + val groupingState = Alias(UnresolvedAttribute("state"), "state")() + val groupingCompany = Alias(UnresolvedAttribute("company"), "company")() + val aggregate = + Aggregate(Seq(groupingState, groupingCompany), Seq(average, state, company), generate) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test expand and stats with alias") { + val context = new CatalystPlanContext + val query = + "source = table | expand employee as workers | stats max(salary) as max by state, company" + val logPlan = + planTransformer.visit(plan(pplParser, query), context) + val table = UnresolvedRelation(Seq("table")) + val generate = Generate( + Explode(UnresolvedAttribute("employee")), + seq(), + false, + None, + seq(UnresolvedAttribute("workers")), + table) + val dropSourceColumn = DataFrameDropColumns(Seq(UnresolvedAttribute("employee")), generate) + val average = Alias( + UnresolvedFunction(seq("MAX"), seq(UnresolvedAttribute("salary")), false, None, false), + "max")() + val state = Alias(UnresolvedAttribute("state"), "state")() + val company = Alias(UnresolvedAttribute("company"), "company")() + val groupingState = Alias(UnresolvedAttribute("state"), "state")() + val groupingCompany = Alias(UnresolvedAttribute("company"), "company")() + val aggregate = Aggregate( + Seq(groupingState, groupingCompany), + Seq(average, state, company), + dropSourceColumn) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test expand and eval") { + val context = new CatalystPlanContext + val query = "source = table | expand employee | eval bonus = salary * 3" + val logPlan = planTransformer.visit(plan(pplParser, query), context) + val table = UnresolvedRelation(Seq("table")) + val generate = + Generate(Explode(UnresolvedAttribute("employee")), seq(), false, None, seq(), table) + val bonusProject = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "*", + Seq(UnresolvedAttribute("salary"), Literal(3, IntegerType)), + isDistinct = false), + "bonus")()), + generate) + val expectedPlan = Project(Seq(UnresolvedStar(None)), bonusProject) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test expand and eval with fields and alias") { + val context = new CatalystPlanContext + val query = + "source = table | expand employee as worker | eval bonus = salary * 3 | fields worker, bonus " + val logPlan = planTransformer.visit(plan(pplParser, query), context) + val table = UnresolvedRelation(Seq("table")) + val generate = Generate( + Explode(UnresolvedAttribute("employee")), + seq(), + false, + None, + seq(UnresolvedAttribute("worker")), + table) + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("employee")), generate) + val bonusProject = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "*", + Seq(UnresolvedAttribute("salary"), Literal(3, IntegerType)), + isDistinct = false), + "bonus")()), + dropSourceColumn) + val expectedPlan = + Project(Seq(UnresolvedAttribute("worker"), UnresolvedAttribute("bonus")), bonusProject) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test expand and parse and fields") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=table | expand employee | parse description '(?.+@.+)' | fields employee, 
email"), + context) + val table = UnresolvedRelation(Seq("table")) + val generator = + Generate(Explode(UnresolvedAttribute("employee")), seq(), false, None, seq(), table) + val emailAlias = + Alias( + RegExpExtract(UnresolvedAttribute("description"), Literal("(?.+@.+)"), Literal(1)), + "email")() + val parseProject = Project( + Seq(UnresolvedAttribute("description"), emailAlias, UnresolvedStar(None)), + generator) + val expectedPlan = + Project(Seq(UnresolvedAttribute("employee"), UnresolvedAttribute("email")), parseProject) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test expand and parse and flatten ") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=relation | expand employee | parse description '(?.+@.+)' | flatten roles "), + context) + val table = UnresolvedRelation(Seq("relation")) + val generateEmployee = + Generate(Explode(UnresolvedAttribute("employee")), seq(), false, None, seq(), table) + val emailAlias = + Alias( + RegExpExtract(UnresolvedAttribute("description"), Literal("(?.+@.+)"), Literal(1)), + "email")() + val parseProject = Project( + Seq(UnresolvedAttribute("description"), emailAlias, UnresolvedStar(None)), + generateEmployee) + val generateRoles = Generate( + GeneratorOuter(new FlattenGenerator(UnresolvedAttribute("roles"))), + seq(), + true, + None, + seq(), + parseProject) + val dropSourceColumnRoles = + DataFrameDropColumns(Seq(UnresolvedAttribute("roles")), generateRoles) + val expectedPlan = Project(Seq(UnresolvedStar(None)), dropSourceColumnRoles) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanGrokTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanGrokTranslatorTestSuite.scala index f33a4a66b..91da923de 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanGrokTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanGrokTranslatorTestSuite.scala @@ -30,7 +30,6 @@ class PPLLogicalPlanGrokTranslatorTestSuite test("test grok email & host expressions") { val grokCompiler = GrokCompiler.newInstance - grokCompiler.registerDefaultPatterns() /* Grok pattern to compile, here httpd logs */ /* Grok pattern to compile, here httpd logs */ val grok = grokCompiler.compile(".+@%{HOSTNAME:host}") diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala index 3ceff7735..f4ed397e3 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala @@ -271,9 +271,9 @@ class PPLLogicalPlanJoinTranslatorTestSuite pplParser, s""" | source = $testTable1 - | | inner JOIN left = l,right = r ON l.id = r.id $testTable2 - | | left JOIN left = l,right = r ON l.name = r.name $testTable3 - | | cross JOIN left = l,right = r $testTable4 + | | inner JOIN left = l right = r ON l.id = r.id $testTable2 + | | left JOIN left = l right = r ON l.name = r.name $testTable3 + | | cross JOIN left = l right = r $testTable4 | """.stripMargin) val logicalPlan = planTransformer.visit(logPlan, 
context) val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) @@ -443,17 +443,17 @@ class PPLLogicalPlanJoinTranslatorTestSuite s""" | source = $testTable1 | | head 10 - | | inner JOIN left = l,right = r ON l.id = r.id + | | inner JOIN left = l right = r ON l.id = r.id | [ | source = $testTable2 | | where id > 10 | ] - | | left JOIN left = l,right = r ON l.name = r.name + | | left JOIN left = l right = r ON l.name = r.name | [ | source = $testTable3 | | fields id | ] - | | cross JOIN left = l,right = r + | | cross JOIN left = l right = r | [ | source = $testTable4 | | sort id @@ -565,4 +565,284 @@ class PPLLogicalPlanJoinTranslatorTestSuite val expectedPlan = Project(Seq(UnresolvedStar(None)), sort) comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } + + test("test multiple joins with table alias") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = table1 as t1 + | | JOIN ON t1.id = t2.id + | [ + | source = table2 as t2 + | ] + | | JOIN ON t2.id = t3.id + | [ + | source = table3 as t3 + | ] + | | JOIN ON t3.id = t4.id + | [ + | source = table4 as t4 + | ] + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("table1")) + val table2 = UnresolvedRelation(Seq("table2")) + val table3 = UnresolvedRelation(Seq("table3")) + val table4 = UnresolvedRelation(Seq("table4")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.id"), UnresolvedAttribute("t2.id"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + Some(EqualTo(UnresolvedAttribute("t2.id"), UnresolvedAttribute("t3.id"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + SubqueryAlias("t4", table4), + Inner, + Some(EqualTo(UnresolvedAttribute("t3.id"), UnresolvedAttribute("t4.id"))), + JoinHint.NONE) + val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with table and subquery alias") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = table1 as t1 + | | JOIN left = l right = r ON t1.id = t2.id + | [ + | source = table2 as t2 + | ] + | | JOIN left = l right = r ON t2.id = t3.id + | [ + | source = table3 as t3 + | ] + | | JOIN left = l right = r ON t3.id = t4.id + | [ + | source = table4 as t4 + | ] + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("table1")) + val table2 = UnresolvedRelation(Seq("table2")) + val table3 = UnresolvedRelation(Seq("table3")) + val table4 = UnresolvedRelation(Seq("table4")) + val joinPlan1 = Join( + SubqueryAlias("l", SubqueryAlias("t1", table1)), + SubqueryAlias("r", SubqueryAlias("t2", table2)), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.id"), UnresolvedAttribute("t2.id"))), + JoinHint.NONE) + val joinPlan2 = Join( + SubqueryAlias("l", joinPlan1), + SubqueryAlias("r", SubqueryAlias("t3", table3)), + Inner, + Some(EqualTo(UnresolvedAttribute("t2.id"), UnresolvedAttribute("t3.id"))), + JoinHint.NONE) + val joinPlan3 = Join( + SubqueryAlias("l", joinPlan2), + SubqueryAlias("r", SubqueryAlias("t4", table4)), + Inner, + Some(EqualTo(UnresolvedAttribute("t3.id"), UnresolvedAttribute("t4.id"))), + JoinHint.NONE) + val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan3) + 
comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins without table aliases") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = table1 + | | JOIN ON table1.id = table2.id table2 + | | JOIN ON table1.id = table3.id table3 + | | JOIN ON table2.id = table4.id table4 + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("table1")) + val table2 = UnresolvedRelation(Seq("table2")) + val table3 = UnresolvedRelation(Seq("table3")) + val table4 = UnresolvedRelation(Seq("table4")) + val joinPlan1 = Join( + table1, + table2, + Inner, + Some(EqualTo(UnresolvedAttribute("table1.id"), UnresolvedAttribute("table2.id"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + table3, + Inner, + Some(EqualTo(UnresolvedAttribute("table1.id"), UnresolvedAttribute("table3.id"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + table4, + Inner, + Some(EqualTo(UnresolvedAttribute("table2.id"), UnresolvedAttribute("table4.id"))), + JoinHint.NONE) + val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with part subquery aliases") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = table1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name table2 + | | JOIN right = t3 ON t1.name = t3.name table3 + | | JOIN right = t4 ON t2.name = t4.name table4 + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("table1")) + val table2 = UnresolvedRelation(Seq("table2")) + val table3 = UnresolvedRelation(Seq("table3")) + val table4 = UnresolvedRelation(Seq("table4")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t3.name"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + SubqueryAlias("t4", table4), + Inner, + Some(EqualTo(UnresolvedAttribute("t2.name"), UnresolvedAttribute("t4.name"))), + JoinHint.NONE) + val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with self join 1") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 + | | JOIN right = t3 ON t1.name = t3.name $testTable3 + | | JOIN right = t4 ON t1.name = t4.name $testTable1 + | | fields t1.name, t2.name, t3.name, t4.name + | """.stripMargin) + + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + 
Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t3.name"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + SubqueryAlias("t4", table1), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t4.name"))), + JoinHint.NONE) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("t1.name"), + UnresolvedAttribute("t2.name"), + UnresolvedAttribute("t3.name"), + UnresolvedAttribute("t4.name")), + joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with self join 2") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 + | | JOIN right = t3 ON t1.name = t3.name $testTable3 + | | JOIN ON t1.name = t4.name + | [ + | source = $testTable1 + | ] as t4 + | | fields t1.name, t2.name, t3.name, t4.name + | """.stripMargin) + + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t3.name"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + SubqueryAlias("t4", table1), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t4.name"))), + JoinHint.NONE) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("t1.name"), + UnresolvedAttribute("t2.name"), + UnresolvedAttribute("t3.name"), + UnresolvedAttribute("t4.name")), + joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test side alias will override the subquery alias") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name [ source = $testTable2 as ttt ] as tt + | | fields t1.name, t2.name + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", SubqueryAlias("tt", SubqueryAlias("ttt", table2))), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val expectedPlan = + Project(Seq(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name")), joinPlan1) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala index 216c0f232..6193bc43f 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala +++ 
b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala @@ -48,7 +48,7 @@ class PPLLogicalPlanJsonFunctionsTranslatorTestSuite val context = new CatalystPlanContext val logPlan = planTransformer.visit( - plan(pplParser, """source=t a = json(json_object('key', array(1, 2, 3)))"""), + plan(pplParser, """source=t a = to_json_string(json_object('key', array(1, 2, 3)))"""), context) val table = UnresolvedRelation(Seq("t")) @@ -97,7 +97,9 @@ class PPLLogicalPlanJsonFunctionsTranslatorTestSuite val context = new CatalystPlanContext val logPlan = planTransformer.visit( - plan(pplParser, """source=t a = json(json_object('key', json_array(1, 2, 3)))"""), + plan( + pplParser, + """source=t a = to_json_string(json_object('key', json_array(1, 2, 3)))"""), context) val table = UnresolvedRelation(Seq("t")) @@ -139,25 +141,21 @@ class PPLLogicalPlanJsonFunctionsTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } - test("test json_array_length(json_array())") { + test("test array_length(json_array())") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( - plan(pplParser, """source=t a = json_array_length(json_array(1,2,3))"""), + plan(pplParser, """source=t a = array_length(json_array(1,2,3))"""), context) val table = UnresolvedRelation(Seq("t")) val jsonFunc = UnresolvedFunction( - "json_array_length", + "array_size", Seq( UnresolvedFunction( - "to_json", - Seq( - UnresolvedFunction( - "array", - Seq(Literal(1), Literal(2), Literal(3)), - isDistinct = false)), + "array", + Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false)), isDistinct = false) val filterExpr = EqualTo(UnresolvedAttribute("a"), jsonFunc) diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParenthesizedConditionTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParenthesizedConditionTestSuite.scala new file mode 100644 index 000000000..a70415aab --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParenthesizedConditionTestSuite.scala @@ -0,0 +1,244 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual, Literal, Not, Or} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project} + +class PPLLogicalPlanParenthesizedConditionTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test simple nested condition") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source=employees | WHERE (age > 18 AND (state = 'California' OR state = 'New York'))"), + context) + + val table = UnresolvedRelation(Seq("employees")) + val filter = Filter( + And( + 
GreaterThan(UnresolvedAttribute("age"), Literal(18)), + Or( + EqualTo(UnresolvedAttribute("state"), Literal("California")), + EqualTo(UnresolvedAttribute("state"), Literal("New York")))), + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), filter) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test nested condition with duplicated parentheses") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source=employees | WHERE ((((age > 18) AND ((((state = 'California') OR state = 'New York'))))))"), + context) + + val table = UnresolvedRelation(Seq("employees")) + val filter = Filter( + And( + GreaterThan(UnresolvedAttribute("age"), Literal(18)), + Or( + EqualTo(UnresolvedAttribute("state"), Literal("California")), + EqualTo(UnresolvedAttribute("state"), Literal("New York")))), + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), filter) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test combining between function") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source=employees | WHERE (year = 2023 AND (month BETWEEN 1 AND 6)) AND (age >= 31 OR country = 'Canada')"), + context) + + val table = UnresolvedRelation(Seq("employees")) + val betweenCondition = And( + GreaterThanOrEqual(UnresolvedAttribute("month"), Literal(1)), + LessThanOrEqual(UnresolvedAttribute("month"), Literal(6))) + val filter = Filter( + And( + And(EqualTo(UnresolvedAttribute("year"), Literal(2023)), betweenCondition), + Or( + GreaterThanOrEqual(UnresolvedAttribute("age"), Literal(31)), + EqualTo(UnresolvedAttribute("country"), Literal("Canada")))), + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), filter) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test multiple levels of nesting") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source=employees | WHERE ((state = 'Texas' OR state = 'California') AND (age < 30 OR (country = 'USA' AND year > 2020)))"), + context) + + val table = UnresolvedRelation(Seq("employees")) + val filter = Filter( + And( + Or( + EqualTo(UnresolvedAttribute("state"), Literal("Texas")), + EqualTo(UnresolvedAttribute("state"), Literal("California"))), + Or( + LessThan(UnresolvedAttribute("age"), Literal(30)), + And( + EqualTo(UnresolvedAttribute("country"), Literal("USA")), + GreaterThan(UnresolvedAttribute("year"), Literal(2020))))), + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), filter) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test with string functions") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source=employees | WHERE (LIKE(LOWER(name), 'a%') OR LIKE(LOWER(name), 'j%')) AND (LENGTH(state) > 6 OR (country = 'USA' AND age > 18))"), + context) + + val table = UnresolvedRelation(Seq("employees")) + val filter = Filter( + And( + Or( + UnresolvedFunction( + "like", + Seq( + UnresolvedFunction("lower", Seq(UnresolvedAttribute("name")), isDistinct = false), + Literal("a%")), + isDistinct = false), + UnresolvedFunction( + "like", + Seq( + UnresolvedFunction("lower", Seq(UnresolvedAttribute("name")), isDistinct = false), + Literal("j%")), + isDistinct = false)), + Or( + GreaterThan( + UnresolvedFunction("length", Seq(UnresolvedAttribute("state")), isDistinct = false), + Literal(6)), + And( + 
EqualTo(UnresolvedAttribute("country"), Literal("USA")), + GreaterThan(UnresolvedAttribute("age"), Literal(18))))), + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), filter) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test complex age ranges with nested conditions") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source=employees | WHERE (age BETWEEN 25 AND 40) AND ((state IN ('California', 'New York', 'Texas') AND year = 2023) OR (country != 'USA' AND (month = 1 OR month = 12)))"), + context) + + val table = UnresolvedRelation(Seq("employees")) + val filter = Filter( + And( + And( + GreaterThanOrEqual(UnresolvedAttribute("age"), Literal(25)), + LessThanOrEqual(UnresolvedAttribute("age"), Literal(40))), + Or( + And( + In( + UnresolvedAttribute("state"), + Seq(Literal("California"), Literal("New York"), Literal("Texas"))), + EqualTo(UnresolvedAttribute("year"), Literal(2023))), + And( + Not(EqualTo(UnresolvedAttribute("country"), Literal("USA"))), + Or( + EqualTo(UnresolvedAttribute("month"), Literal(1)), + EqualTo(UnresolvedAttribute("month"), Literal(12)))))), + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), filter) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test nested NOT conditions") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source=employees | WHERE NOT (age < 18 OR (state = 'Alaska' AND year < 2020)) AND (country = 'USA' OR (country = 'Mexico' AND month BETWEEN 6 AND 8))"), + context) + + val table = UnresolvedRelation(Seq("employees")) + val filter = Filter( + And( + Not( + Or( + LessThan(UnresolvedAttribute("age"), Literal(18)), + And( + EqualTo(UnresolvedAttribute("state"), Literal("Alaska")), + LessThan(UnresolvedAttribute("year"), Literal(2020))))), + Or( + EqualTo(UnresolvedAttribute("country"), Literal("USA")), + And( + EqualTo(UnresolvedAttribute("country"), Literal("Mexico")), + And( + GreaterThanOrEqual(UnresolvedAttribute("month"), Literal(6)), + LessThanOrEqual(UnresolvedAttribute("month"), Literal(8)))))), + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), filter) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test complex boolean logic") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source=employees | WHERE (NOT (year < 2020 OR age < 18)) AND ((state = 'Texas' AND month % 2 = 0) OR (country = 'Mexico' AND (year = 2023 OR (year = 2022 AND month > 6))))"), + context) + + val table = UnresolvedRelation(Seq("employees")) + val filter = Filter( + And( + Not( + Or( + LessThan(UnresolvedAttribute("year"), Literal(2020)), + LessThan(UnresolvedAttribute("age"), Literal(18)))), + Or( + And( + EqualTo(UnresolvedAttribute("state"), Literal("Texas")), + EqualTo( + UnresolvedFunction( + "%", + Seq(UnresolvedAttribute("month"), Literal(2)), + isDistinct = false), + Literal(0))), + And( + EqualTo(UnresolvedAttribute("country"), Literal("Mexico")), + Or( + EqualTo(UnresolvedAttribute("year"), Literal(2023)), + And( + EqualTo(UnresolvedAttribute("year"), Literal(2022)), + GreaterThan(UnresolvedAttribute("month"), Literal(6))))))), + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), filter) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } +} diff --git 
a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala index 792a2dee6..106cba93a 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala @@ -59,6 +59,42 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, checkAnalysis = false) } + test("test simple rare command with a single field approximation") { + // if successful build ppl logical plan and translate to catalyst logical plan + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=accounts | rare_approx address"), context) + val addressField = UnresolvedAttribute("address") + val tableRelation = UnresolvedRelation(Seq("accounts")) + + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(addressField), isDistinct = false), + "count_address")(), + addressField) + + val aggregatePlan = + Aggregate(Seq(addressField), aggregateExpressions, tableRelation) + + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction( + Seq("APPROX_COUNT_DISTINCT"), + Seq(addressField), + isDistinct = false), + "count_address")(), + Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project(projectList, sortedPlan) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + test("test simple rare command with a by field test") { // if successful build ppl logical plan and translate to catalyst logical plan val context = new CatalystPlanContext diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala index d22750ee0..ec1775631 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala @@ -6,12 +6,15 @@ package org.opensearch.flint.spark.ppl import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.common.antlr.SyntaxCheckException import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.opensearch.sql.ppl.utils.SortUtils import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CaseWhen, CurrentRow, Descending, LessThan, Literal, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.expressions.{Add, Alias, Ascending, CaseWhen, CurrentRow, Descending, Divide, Expression, LessThan, Literal, Multiply, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition} import 
org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Project, Sort} @@ -132,4 +135,147 @@ class PPLLogicalPlanTrendlineCommandTranslatorTestSuite Project(trendlineProjectList, sort)) comparePlans(logPlan, expectedPlan, checkAnalysis = false) } + + test("WMA - with sort") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=relation | trendline sort age wma(3, age)"), + context) + + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), + getNthValueAggregation("age", "age", 3, -2)) + val wmaExpression = Divide(dividend, Literal(6)) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(wmaExpression, "age_trendline")()) + val sortedTable = Sort( + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, + UnresolvedRelation(Seq("relation"))) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + + /** + * Expected logical plan: 'Project [*] !+- 'Project [*, ((( ('nth_value('age, 1) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 1) + ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + ('nth_value('age, 3) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 3)) / 6) AS age_trendline#0] ! +- 'Sort ['age ASC NULLS FIRST], true ! +- + * 'UnresolvedRelation [relation], [], false + */ + comparePlans(logPlan, expectedPlan, checkAnalysis = false) + } + + test("WMA - with sort and alias") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=relation | trendline sort age wma(3, age) as TEST_CUSTOM_COLUMN"), + context) + + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), + getNthValueAggregation("age", "age", 3, -2)) + val wmaExpression = Divide(dividend, Literal(6)) + val trendlineProjectList = + Seq(UnresolvedStar(None), Alias(wmaExpression, "TEST_CUSTOM_COLUMN")()) + val sortedTable = Sort( + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, + UnresolvedRelation(Seq("relation"))) + + /** + * Expected logical plan: 'Project [*] !+- 'Project [*, ((( ('nth_value('age, 1) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 1) + ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + ('nth_value('age, 3) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 3)) / 6) AS TEST_CUSTOM_COLUMN#0] ! +- 'Sort ['age ASC NULLS FIRST], true + * ! 
+- 'UnresolvedRelation [relation], [], false + */ + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + comparePlans(logPlan, expectedPlan, checkAnalysis = false) + + } + + test("WMA - multiple trendline commands") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=relation | trendline sort age wma(2, age) as two_points_wma wma(3, age) as three_points_wma"), + context) + + val dividendTwo = Add( + getNthValueAggregation("age", "age", 1, -1), + getNthValueAggregation("age", "age", 2, -1)) + val twoPointsExpression = Divide(dividendTwo, Literal(3)) + + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), + getNthValueAggregation("age", "age", 3, -2)) + val threePointsExpression = Divide(dividend, Literal(6)) + val trendlineProjectList = Seq( + UnresolvedStar(None), + Alias(twoPointsExpression, "two_points_wma")(), + Alias(threePointsExpression, "three_points_wma")()) + val sortedTable = Sort( + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, + UnresolvedRelation(Seq("relation"))) + + /** + * Expected logical plan: 'Project [*] +- 'Project [*, (( ('nth_value('age, 1) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, + * currentrow$())) * 1) + ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -1, currentrow$())) * 2)) / 3) AS two_points_wma#0, + * + * ((( ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + ('nth_value('age, 2) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 2)) + ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS three_points_wma#1] +- + * 'Sort ['age ASC NULLS FIRST], true +- 'UnresolvedRelation [relation], [], false + */ + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + comparePlans(logPlan, expectedPlan, checkAnalysis = false) + + } + + test("WMA - with negative dataPoint value") { + val context = new CatalystPlanContext + val exception = intercept[SyntaxCheckException]( + planTransformer + .visit(plan(pplParser, "source=relation | trendline sort age wma(-3, age)"), context)) + assert(exception.getMessage startsWith "Failed to parse query due to offending symbol [-]") + } + + private def getNthValueAggregation( + dataField: String, + sortField: String, + lookBackPos: Int, + lookBackRange: Int): Expression = { + Multiply( + WindowExpression( + UnresolvedFunction( + "nth_value", + Seq(UnresolvedAttribute(dataField), Literal(lookBackPos)), + isDistinct = false), + WindowSpecDefinition( + Seq(), + seq(SortUtils.sortOrder(UnresolvedAttribute(sortField), true)), + SpecifiedWindowFrame(RowFrame, Literal(lookBackRange), CurrentRow))), + Literal(lookBackPos)) + } + } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala index c076f9974..63c120a2c 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -10,7 +10,6 @@ import java.util.Locale import 
com.amazonaws.services.glue.model.{AccessDeniedException, AWSGlueException} import com.amazonaws.services.s3.model.AmazonS3Exception import com.fasterxml.jackson.databind.ObjectMapper -import com.fasterxml.jackson.module.scala.DefaultScalaModule import org.apache.commons.text.StringEscapeUtils.unescapeJava import org.opensearch.common.Strings import org.opensearch.flint.core.IRestHighLevelClient @@ -23,6 +22,7 @@ import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkConfConstants.{DEFAULT_SQL_EXTENSIONS, SQL_EXTENSIONS_KEY} import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.exception.UnrecoverableException import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.flint.config.FlintSparkConf.REFRESH_POLICY import org.apache.spark.sql.types._ @@ -45,13 +45,13 @@ trait FlintJobExecutor { this: Logging => val mapper = new ObjectMapper() - mapper.registerModule(DefaultScalaModule) + val throwableHandler = new ThrowableHandler() var currentTimeProvider: TimeProvider = new RealTimeProvider() var threadPoolFactory: ThreadPoolFactory = new DefaultThreadPoolFactory() var environmentProvider: EnvironmentProvider = new RealEnvironment() var enableHiveSupport: Boolean = true - // termiante JVM in the presence non-deamon thread before exiting + // terminate JVM in the presence non-daemon thread before exiting var terminateJVM = true // The enabled setting, which can be applied only to the top-level mapping definition and to object fields, @@ -437,39 +437,42 @@ trait FlintJobExecutor { } private def handleQueryException( - e: Exception, + t: Throwable, messagePrefix: String, errorSource: Option[String] = None, statusCode: Option[Int] = None): String = { - val errorMessage = s"$messagePrefix: ${e.getMessage}" - val errorDetails = Map("Message" -> errorMessage) ++ - errorSource.map("ErrorSource" -> _) ++ - statusCode.map(code => "StatusCode" -> code.toString) + throwableHandler.setThrowable(t) + + val errorMessage = s"$messagePrefix: ${t.getMessage}" + val errorDetails = new java.util.LinkedHashMap[String, String]() + errorDetails.put("Message", errorMessage) + errorSource.foreach(es => errorDetails.put("ErrorSource", es)) + statusCode.foreach(code => errorDetails.put("StatusCode", code.toString)) val errorJson = mapper.writeValueAsString(errorDetails) // CustomLogging will call log4j logger.error() underneath statusCode match { case Some(code) => - CustomLogging.logError(new OperationMessage(errorMessage, code), e) + CustomLogging.logError(new OperationMessage(errorMessage, code), t) case None => - CustomLogging.logError(errorMessage, e) + CustomLogging.logError(errorMessage, t) } errorJson } - def getRootCause(e: Throwable): Throwable = { - if (e.getCause == null) e - else getRootCause(e.getCause) + def getRootCause(t: Throwable): Throwable = { + if (t.getCause == null) t + else getRootCause(t.getCause) } /** * This method converts query exception into error string, which then persist to query result * metadata */ - def processQueryException(ex: Exception): String = { - getRootCause(ex) match { + def processQueryException(throwable: Throwable): String = { + getRootCause(throwable) match { case r: ParseException => handleQueryException(r, ExceptionMessages.SyntaxErrorPrefix) case r: AmazonS3Exception => @@ -496,15 +499,15 @@ trait FlintJobExecutor { handleQueryException(r, ExceptionMessages.QueryAnalysisErrorPrefix) case r: SparkException => handleQueryException(r, 
ExceptionMessages.SparkExceptionErrorPrefix) - case r: Exception => - val rootCauseClassName = r.getClass.getName - val errMsg = r.getMessage + case t: Throwable => + val rootCauseClassName = t.getClass.getName + val errMsg = t.getMessage if (rootCauseClassName == "org.apache.hadoop.hive.metastore.api.MetaException" && errMsg.contains("com.amazonaws.services.glue.model.AccessDeniedException")) { val e = new SecurityException(ExceptionMessages.GlueAccessDeniedMessage) handleQueryException(e, ExceptionMessages.QueryRunErrorPrefix) } else { - handleQueryException(r, ExceptionMessages.QueryRunErrorPrefix) + handleQueryException(t, ExceptionMessages.QueryRunErrorPrefix) } } } @@ -533,6 +536,14 @@ trait FlintJobExecutor { throw t } + def checkAndThrowUnrecoverableExceptions(): Unit = { + throwableHandler.exceptionThrown.foreach { + case e: UnrecoverableException => + throw e + case _ => // Do nothing for other types of exceptions + } + } + def instantiate[T](defaultConstructor: => T, className: String, args: Any*): T = { if (Strings.isNullOrEmpty(className)) { defaultConstructor @@ -552,5 +563,4 @@ trait FlintJobExecutor { } } } - } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index ef0e76557..6d7dcc0e7 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -17,7 +17,7 @@ import com.codahale.metrics.Timer import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession, SessionStates} import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.logging.CustomLogging -import org.opensearch.flint.core.metrics.{MetricConstants, ReadWriteBytesSparkListener} +import org.opensearch.flint.core.metrics.{MetricConstants, MetricsSparkListener, MetricsUtil} import org.opensearch.flint.core.metrics.MetricsUtil.{getTimerContext, incrementCounter, registerGauge, stopTimer} import org.apache.spark.SparkConf @@ -57,6 +57,7 @@ object FlintREPL extends Logging with FlintJobExecutor { private val sessionRunningCount = new AtomicInteger(0) private val statementRunningCount = new AtomicInteger(0) + private var queryCountMetric = 0 def main(args: Array[String]) { val (queryOption, resultIndexOption) = parseArgs(args) @@ -186,9 +187,9 @@ object FlintREPL extends Logging with FlintJobExecutor { } recordSessionSuccess(sessionTimerContext) } catch { - case e: Exception => + case t: Throwable => handleSessionError( - e, + t, applicationId, jobId, sessionId, @@ -203,6 +204,10 @@ object FlintREPL extends Logging with FlintJobExecutor { stopTimer(sessionTimerContext) spark.stop() + // After handling any exceptions from stopping the Spark session, + // check if there's a stored exception and throw it if it's an UnrecoverableException + checkAndThrowUnrecoverableExceptions() + // Check for non-daemon threads that may prevent the driver from shutting down. // Non-daemon threads other than the main thread indicate that the driver is still processing tasks, // which may be due to unresolved bugs in dependencies or threads not being properly shut down. 
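For readers following the error-handling changes in these hunks: failures are funneled through a shared handler, recorded, and only rethrown after cleanup when they are unrecoverable. The Scala sketch below models that flow in isolation. It is an editorial illustration under assumed names (SimpleThrowableHandler, UnrecoverableFailure, checkAndThrowUnrecoverable), not the project's ThrowableHandler, UnrecoverableException, or FlintJobExecutor API, which appear elsewhere in this diff.

// Minimal, self-contained sketch of the record-then-rethrow pattern; illustrative stand-ins only.
final case class UnrecoverableFailure(cause: Throwable) extends RuntimeException(cause)

class SimpleThrowableHandler {
  private var stored: Option[Throwable] = None

  // Remember the failure; the real handler also logs it via CustomLogging.
  def recordThrowable(message: String, t: Throwable): Unit = {
    println(s"recorded: $message (${t.getClass.getSimpleName})")
    stored = Some(t)
  }

  def hasException: Boolean = stored.isDefined

  // Called after cleanup (e.g. after the Spark session is stopped): rethrow only
  // unrecoverable failures so the driver exits non-zero and the job is marked failed.
  def checkAndThrowUnrecoverable(): Unit = stored.foreach {
    case e: UnrecoverableFailure => throw e
    case _ => // recoverable errors were already persisted to the result index; do nothing
  }
}

object ThrowableHandlerFlowSketch extends App {
  val handler = new SimpleThrowableHandler
  try {
    throw UnrecoverableFailure(new RuntimeException("OpenSearch cluster unresponsive"))
  } catch {
    case t: Throwable => handler.recordThrowable("Query loop execution failed.", t)
  }
  // ... update session state, flush metrics, stop the Spark session ...
  handler.checkAndThrowUnrecoverable() // surfaces the stored UnrecoverableFailure
}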
@@ -355,6 +360,11 @@ object FlintREPL extends Logging with FlintJobExecutor { verificationResult = updatedVerificationResult canPickUpNextStatement = updatedCanPickUpNextStatement lastCanPickCheckTime = updatedLastCanPickCheckTime + } catch { + case t: Throwable => + // Record and rethrow in query loop + throwableHandler.recordThrowable(s"Query loop execution failed.", t) + throw t } finally { statementsExecutionManager.terminateStatementExecution() } @@ -365,6 +375,7 @@ object FlintREPL extends Logging with FlintJobExecutor { if (threadPool != null) { threadPoolFactory.shutdownThreadPool(threadPool) } + MetricsUtil.addHistoricGauge(MetricConstants.REPL_QUERY_COUNT_METRIC, queryCountMetric) } } @@ -410,32 +421,40 @@ object FlintREPL extends Logging with FlintJobExecutor { error = error, excludedJobIds = excludedJobIds)) logInfo(s"Current session: ${sessionDetails}") - logInfo(s"State is: ${sessionDetails.state}") sessionDetails.state = state - logInfo(s"State is: ${sessionDetails.state}") + sessionDetails.error = error sessionManager.updateSessionDetails(sessionDetails, updateMode = UPSERT) + logInfo(s"Updated session: ${sessionDetails}") sessionDetails } def handleSessionError( - e: Exception, + t: Throwable, applicationId: String, jobId: String, sessionId: String, sessionManager: SessionManager, jobStartTime: Long, sessionTimerContext: Timer.Context): Unit = { - val error = s"Session error: ${e.getMessage}" - CustomLogging.logError(error, e) + val error = s"Session error: ${t.getMessage}" + throwableHandler.recordThrowable(error, t) + + try { + refreshSessionState( + applicationId, + jobId, + sessionId, + sessionManager, + jobStartTime, + SessionStates.FAIL, + Some(error)) + } catch { + case t: Throwable => + throwableHandler.recordThrowable( + s"Failed to update session state. 
Original error: $error", + t) + } - refreshSessionState( - applicationId, - jobId, - sessionId, - sessionManager, - jobStartTime, - SessionStates.FAIL, - Some(e.getMessage)) recordSessionFailed(sessionTimerContext) } @@ -483,8 +502,8 @@ object FlintREPL extends Logging with FlintJobExecutor { startTime) } - def processQueryException(ex: Exception, flintStatement: FlintStatement): String = { - val error = super.processQueryException(ex) + def processQueryException(t: Throwable, flintStatement: FlintStatement): String = { + val error = super.processQueryException(t) flintStatement.fail() flintStatement.error = Some(error) error @@ -521,11 +540,12 @@ object FlintREPL extends Logging with FlintJobExecutor { flintStatement.running() statementExecutionManager.updateStatement(flintStatement) statementRunningCount.incrementAndGet() + queryCountMetric += 1 val statementTimerContext = getTimerContext( MetricConstants.STATEMENT_PROCESSING_TIME_METRIC) val (dataToWrite, returnedVerificationResult) = - ReadWriteBytesSparkListener.withMetrics( + MetricsSparkListener.withMetrics( spark, () => { processStatementOnVerification( @@ -578,11 +598,13 @@ object FlintREPL extends Logging with FlintJobExecutor { } catch { // e.g., maybe due to authentication service connection issue // or invalid catalog (e.g., we are operating on data not defined in provided data source) - case e: Exception => - val error = s"""Fail to write result of ${flintStatement}, cause: ${e.getMessage}""" - CustomLogging.logError(error, e) + case e: Throwable => + throwableHandler.recordThrowable( + s"""Fail to write result of ${flintStatement}, cause: ${e.getMessage}""", + e) flintStatement.fail() } finally { + logInfo(s"command complete: $flintStatement") statementExecutionManager.updateStatement(flintStatement) recordStatementStateChange(flintStatement, statementTimerContext) } @@ -668,8 +690,8 @@ object FlintREPL extends Logging with FlintJobExecutor { flintStatement, sessionId, startTime)) - case e: Exception => - val error = processQueryException(e, flintStatement) + case t: Throwable => + val error = processQueryException(t, flintStatement) Some( handleCommandFailureAndGetFailedData( applicationId, @@ -744,7 +766,7 @@ object FlintREPL extends Logging with FlintJobExecutor { startTime)) case NonFatal(e) => val error = s"An unexpected error occurred: ${e.getMessage}" - CustomLogging.logError(error, e) + throwableHandler.recordThrowable(error, e) dataToWrite = Some( handleCommandFailureAndGetFailedData( applicationId, @@ -783,7 +805,6 @@ object FlintREPL extends Logging with FlintJobExecutor { queryWaitTimeMillis) } - logInfo(s"command complete: $flintStatement") (dataToWrite, verificationResult) } @@ -855,7 +876,8 @@ object FlintREPL extends Logging with FlintJobExecutor { } } } catch { - case e: Exception => logError(s"Failed to update session state for $sessionId", e) + case t: Throwable => + throwableHandler.recordThrowable(s"Failed to update session state for $sessionId", t) } } } @@ -894,10 +916,10 @@ object FlintREPL extends Logging with FlintJobExecutor { MetricConstants.REQUEST_METADATA_HEARTBEAT_FAILED_METRIC ) // Record heartbeat failure metric // maybe due to invalid sequence number or primary term - case e: Exception => - CustomLogging.logWarning( + case t: Throwable => + throwableHandler.recordThrowable( s"""Fail to update the last update time of the flint instance ${sessionId}""", - e) + t) incrementCounter( MetricConstants.REQUEST_METADATA_HEARTBEAT_FAILED_METRIC ) // Record heartbeat failure metric @@ -945,8 +967,10 
@@ object FlintREPL extends Logging with FlintJobExecutor { } } catch { // still proceed since we are not sure what happened (e.g., OpenSearch cluster may be unresponsive) - case e: Exception => - CustomLogging.logError(s"""Fail to find id ${sessionId} from session index.""", e) + case t: Throwable => + throwableHandler.recordThrowable( + s"""Fail to find id ${sessionId} from session index.""", + t) true } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala index 6cdbdb16d..27b0be84f 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala @@ -14,7 +14,7 @@ import scala.util.{Failure, Success, Try} import org.opensearch.flint.common.model.FlintStatement import org.opensearch.flint.common.scheduler.model.LangType -import org.opensearch.flint.core.metrics.{MetricConstants, MetricsUtil, ReadWriteBytesSparkListener} +import org.opensearch.flint.core.metrics.{MetricConstants, MetricsSparkListener, MetricsUtil} import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter import org.opensearch.flint.spark.FlintSpark @@ -70,7 +70,7 @@ case class JobOperator( val statementExecutionManager = instantiateStatementExecutionManager(commandContext, resultIndex, osClient) - val readWriteBytesSparkListener = new ReadWriteBytesSparkListener() + val readWriteBytesSparkListener = new MetricsSparkListener() sparkSession.sparkContext.addSparkListener(readWriteBytesSparkListener) val statement = @@ -82,9 +82,6 @@ case class JobOperator( LangType.SQL, currentTimeProvider.currentEpochMillis()) - var exceptionThrown = true - var error: String = null - try { val futurePrepareQueryExecution = Future { statementExecutionManager.prepareStatementExecution() @@ -94,7 +91,7 @@ case class JobOperator( ThreadUtils.awaitResult(futurePrepareQueryExecution, Duration(1, MINUTES)) match { case Right(_) => data case Left(err) => - error = err + throwableHandler.setError(err) constructErrorDF( applicationId, jobId, @@ -107,11 +104,9 @@ case class JobOperator( "", startTime) }) - exceptionThrown = false } catch { case e: TimeoutException => - error = s"Preparation for query execution timed out" - logError(error, e) + throwableHandler.recordThrowable(s"Preparation for query execution timed out", e) dataToWrite = Some( constructErrorDF( applicationId, @@ -119,13 +114,13 @@ case class JobOperator( sparkSession, dataSource, "TIMEOUT", - error, + throwableHandler.error, queryId, query, "", startTime)) - case e: Exception => - val error = processQueryException(e) + case t: Throwable => + val error = processQueryException(t) dataToWrite = Some( constructErrorDF( applicationId, @@ -146,27 +141,32 @@ case class JobOperator( try { dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient)) } catch { - case e: Exception => - exceptionThrown = true - error = s"Failed to write to result index. originalError='${error}'" - logError(error, e) + case t: Throwable => + throwableHandler.recordThrowable( + s"Failed to write to result index. 
originalError='${throwableHandler.error}'", + t) } - if (exceptionThrown) statement.fail() else statement.complete() - statement.error = Some(error) - statementExecutionManager.updateStatement(statement) + if (throwableHandler.hasException) statement.fail() else statement.complete() + statement.error = Some(throwableHandler.error) - cleanUpResources(exceptionThrown, threadPool, startTime) + try { + statementExecutionManager.updateStatement(statement) + } catch { + case t: Throwable => + throwableHandler.recordThrowable( + s"Failed to update statement. originalError='${throwableHandler.error}'", + t) + } + + cleanUpResources(threadPool) } } - def cleanUpResources( - exceptionThrown: Boolean, - threadPool: ThreadPoolExecutor, - startTime: Long): Unit = { + def cleanUpResources(threadPool: ThreadPoolExecutor): Unit = { val isStreaming = jobType.equalsIgnoreCase(FlintJobType.STREAMING) try { // Wait for streaming job complete if no error - if (!exceptionThrown && isStreaming) { + if (!throwableHandler.hasException && isStreaming) { // Clean Spark shuffle data after each microBatch. sparkSession.streams.addListener(new ShuffleCleaner(sparkSession)) // Await index monitor before the main thread terminates @@ -174,7 +174,7 @@ case class JobOperator( } else { logInfo(s""" | Skip streaming job await due to conditions not met: - | - exceptionThrown: $exceptionThrown + | - exceptionThrown: ${throwableHandler.hasException} | - streaming: $isStreaming | - activeStreams: ${sparkSession.streams.active.mkString(",")} |""".stripMargin) @@ -190,7 +190,7 @@ case class JobOperator( } catch { case e: Exception => logError("Fail to close threadpool", e) } - recordStreamingCompletionStatus(exceptionThrown) + recordStreamingCompletionStatus(throwableHandler.hasException) // Check for non-daemon threads that may prevent the driver from shutting down. // Non-daemon threads other than the main thread indicate that the driver is still processing tasks, @@ -219,8 +219,13 @@ case class JobOperator( logInfo("Stopped Spark session") } match { case Success(_) => - case Failure(e) => logError("unexpected error while stopping spark session", e) + case Failure(e) => + throwableHandler.recordThrowable("unexpected error while stopping spark session", e) } + + // After handling any exceptions from stopping the Spark session, + // check if there's a stored exception and throw it if it's an UnrecoverableException + checkAndThrowUnrecoverableExceptions() } /** diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/util/ThrowableHandler.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/ThrowableHandler.scala new file mode 100644 index 000000000..01c90bdd4 --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/ThrowableHandler.scala @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql.util + +import org.opensearch.flint.core.logging.CustomLogging + +/** + * Handles and manages exceptions and error messages during each emr job run. Provides methods to + * set, retrieve, and reset exception information. 
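+ *
+ * Illustrative usage (editorial sketch, not part of the original change; `runQuery()` and
+ * `statement` are hypothetical placeholders, while the handler members are defined below):
+ * {{{
+ *   val handler = new ThrowableHandler()
+ *   try {
+ *     runQuery()
+ *   } catch {
+ *     case t: Throwable => handler.recordThrowable("Failed to execute query", t)
+ *   }
+ *   if (handler.hasException) statement.fail() else statement.complete()
+ *   statement.error = Some(handler.error)
+ * }}}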
+ */ +class ThrowableHandler { + private var _throwableOption: Option[Throwable] = None + private var _error: String = _ + + def exceptionThrown: Option[Throwable] = _throwableOption + def error: String = _error + + def recordThrowable(err: String, t: Throwable): Unit = { + _error = err + _throwableOption = Some(t) + CustomLogging.logError(err, t) + } + + def setError(err: String): Unit = { + _error = err + } + + def setThrowable(t: Throwable): Unit = { + _throwableOption = Some(t) + } + + def reset(): Unit = { + _throwableOption = None + _error = null + } + + def hasException: Boolean = _throwableOption.isDefined +} diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index 07ed94bdc..7edb0d4c3 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -33,9 +33,11 @@ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.scheduler.SparkListenerApplicationEnd import org.apache.spark.sql.FlintREPL.PreShutdownListener import org.apache.spark.sql.FlintREPLConfConstants.DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY +import org.apache.spark.sql.SessionUpdateMode.SessionUpdateMode import org.apache.spark.sql.SparkConfConstants.{DEFAULT_SQL_EXTENSIONS, SQL_EXTENSIONS_KEY} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.exception.UnrecoverableException import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.types.{LongType, NullType, StringType, StructField, StructType} import org.apache.spark.sql.util.{DefaultThreadPoolFactory, MockThreadPoolFactory, MockTimeProvider, RealTimeProvider, ShutdownHookManagerTrait} @@ -195,19 +197,44 @@ class FlintREPLTest scheduledFutureRaw }) - // Invoke the method FlintREPL.createHeartBeatUpdater(sessionId, sessionManager, threadPool) - // Verifications verify(sessionManager, atLeastOnce()).recordHeartbeat(sessionId) + FlintREPL.throwableHandler.hasException shouldBe false } - test("PreShutdownListener updates FlintInstance if conditions are met") { + test("createHeartBeatUpdater should handle unrecoverable exception") { + val threadPool = mock[ScheduledExecutorService] + val scheduledFutureRaw = mock[ScheduledFuture[_]] + val sessionManager = mock[SessionManager] + val sessionId = "session1" + + FlintREPL.throwableHandler.reset() + val unrecoverableException = + UnrecoverableException(new RuntimeException("Unrecoverable error")) + when(sessionManager.recordHeartbeat(sessionId)) + .thenThrow(unrecoverableException) + + when(threadPool + .scheduleAtFixedRate(any[Runnable], *, *, eqTo(java.util.concurrent.TimeUnit.MILLISECONDS))) + .thenAnswer((invocation: InvocationOnMock) => { + val runnable = invocation.getArgument[Runnable](0) + runnable.run() + scheduledFutureRaw + }) + + FlintREPL.createHeartBeatUpdater(sessionId, sessionManager, threadPool) + + FlintREPL.throwableHandler.exceptionThrown shouldBe Some(unrecoverableException) + } + + test("PreShutdownListener updates InteractiveSession if conditions are met") { // Mock dependencies val sessionId = "testSessionId" val timerContext = mock[Timer.Context] val sessionManager = mock[SessionManager] + FlintREPL.throwableHandler.reset() val interactiveSession = new InteractiveSession( "app123", "job123", @@ -227,6 +254,28 @@ class 
FlintREPLTest interactiveSession.state shouldBe SessionStates.DEAD } + test("PreShutdownListener handles unrecoverable exception from sessionManager") { + val sessionId = "testSessionId" + val timerContext = mock[Timer.Context] + val sessionManager = mock[SessionManager] + + FlintREPL.throwableHandler.reset() + val unrecoverableException = + UnrecoverableException(new RuntimeException("Unrecoverable database error")) + when(sessionManager.getSessionDetails(sessionId)) + .thenThrow(unrecoverableException) + + val listener = new PreShutdownListener(sessionId, sessionManager, timerContext) + + listener.onApplicationEnd(SparkListenerApplicationEnd(System.currentTimeMillis())) + + FlintREPL.throwableHandler.exceptionThrown shouldBe Some(unrecoverableException) + FlintREPL.throwableHandler.error shouldBe s"Failed to update session state for $sessionId" + + verify(sessionManager, never()) + .updateSessionDetails(any[InteractiveSession], any[SessionUpdateMode]) + } + test("Test super.constructErrorDF should construct dataframe properly") { // Define expected dataframe val dataSourceName = "myGlueS3" @@ -463,6 +512,29 @@ class FlintREPLTest assert(result) } + test("test canPickNextStatement: sessionManager throws unrecoverableException") { + val sessionId = "session123" + val jobId = "jobABC" + val sessionIndex = "sessionIndex" + val mockSparkSession = mock[SparkSession] + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")) + .thenReturn(sessionIndex) + + FlintREPL.throwableHandler.reset() + val sessionManager = mock[SessionManager] + val unrecoverableException = + UnrecoverableException(new RuntimeException("OpenSearch cluster unresponsive")) + when(sessionManager.getSessionDetails(sessionId)) + .thenThrow(unrecoverableException) + + val result = FlintREPL.canPickNextStatement(sessionId, sessionManager, jobId) + + assert(result) + FlintREPL.throwableHandler.exceptionThrown shouldBe Some(unrecoverableException) + } + test( "test canPickNextStatement: Doc Exists and excludeJobIds is a Single String Not Matching JobId") { val sessionId = "session123" @@ -521,6 +593,7 @@ class FlintREPLTest verify(mockFlintStatement).error = Some(expectedError) assert(result == expectedError) + FlintREPL.throwableHandler.exceptionThrown shouldBe Some(exception) } test("processQueryException should handle MetaException with AccessDeniedException properly") { @@ -665,8 +738,6 @@ class FlintREPLTest override val osClient: OSClient = mockOSClient } - val queryResultWriter = mock[QueryResultWriter] - val commandContext = CommandContext( applicationId, jobId, @@ -1026,6 +1097,87 @@ class FlintREPLTest assert(!result) // Expecting false as the job proceeds normally } + test("handleSessionError handles unrecoverable exception") { + val sessionManager = mock[SessionManager] + val timerContext = mock[Timer.Context] + val applicationId = "app123" + val jobId = "job123" + val sessionId = "session123" + val jobStartTime = System.currentTimeMillis() + + FlintREPL.throwableHandler.reset() + val unrecoverableException = + UnrecoverableException(new RuntimeException("Unrecoverable error")) + val interactiveSession = new InteractiveSession( + applicationId, + jobId, + sessionId, + SessionStates.RUNNING, + System.currentTimeMillis(), + System.currentTimeMillis() - 10000) + when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) + + FlintREPL.handleSessionError( + unrecoverableException, + 
applicationId, + jobId, + sessionId, + sessionManager, + jobStartTime, + timerContext) + + FlintREPL.throwableHandler.exceptionThrown shouldBe Some(unrecoverableException) + + verify(sessionManager).updateSessionDetails( + argThat { (session: InteractiveSession) => + session.applicationId == applicationId && + session.jobId == jobId && + session.sessionId == sessionId && + session.state == SessionStates.FAIL && + session.error.contains(s"Session error: ${unrecoverableException.getMessage}") + }, + any[SessionUpdateMode]) + + verify(timerContext).stop() + } + + test("handleSessionError handles exception during refreshSessionState") { + val sessionManager = mock[SessionManager] + val timerContext = mock[Timer.Context] + val applicationId = "app123" + val jobId = "job123" + val sessionId = "session123" + val jobStartTime = System.currentTimeMillis() + + FlintREPL.throwableHandler.reset() + val initialException = UnrecoverableException(new RuntimeException("Unrecoverable error")) + val refreshException = + UnrecoverableException(new RuntimeException("Failed to refresh session state")) + + val interactiveSession = new InteractiveSession( + applicationId, + jobId, + sessionId, + SessionStates.RUNNING, + System.currentTimeMillis(), + System.currentTimeMillis() - 10000) + when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) + when(sessionManager.updateSessionDetails(any[InteractiveSession], any[SessionUpdateMode])) + .thenThrow(refreshException) + + FlintREPL.handleSessionError( + initialException, + applicationId, + jobId, + sessionId, + sessionManager, + jobStartTime, + timerContext) + + FlintREPL.throwableHandler.exceptionThrown shouldBe Some(refreshException) + verify(timerContext).stop() + } + test("queryLoop continue until inactivity limit is reached") { val resultIndex = "testResultIndex" val dataSource = "testDataSource" @@ -1064,7 +1216,6 @@ class FlintREPLTest val sessionManager = new SessionManagerImpl(spark, Some(resultIndex)) { override val osClient: OSClient = mockOSClient } - val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( applicationId, @@ -1133,7 +1284,6 @@ class FlintREPLTest val sessionManager = new SessionManagerImpl(spark, Some(resultIndex)) { override val osClient: OSClient = mockOSClient } - val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( applicationId, @@ -1198,7 +1348,6 @@ class FlintREPLTest val sessionManager = new SessionManagerImpl(spark, Some(resultIndex)) { override val osClient: OSClient = mockOSClient } - val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( applicationId, @@ -1255,7 +1404,8 @@ class FlintREPLTest mockOSClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) .thenReturn(mockReader) // Simulate an exception thrown when hasNext is called - when(mockReader.hasNext).thenThrow(new RuntimeException("Test exception")) + val unrecoverableException = UnrecoverableException(new RuntimeException("Test exception")) + when(mockReader.hasNext).thenThrow(unrecoverableException) when(mockOSClient.doesIndexExist(*)).thenReturn(true) when(mockOSClient.getIndexMetadata(*)).thenReturn(FlintREPL.resultIndexMapping) @@ -1268,7 +1418,6 @@ class FlintREPLTest val sessionManager = new SessionManagerImpl(spark, Some(resultIndex)) { override val osClient: OSClient = mockOSClient } - val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( applicationId, @@ -1287,13 +1436,15 @@ 
class FlintREPLTest // Mocking ThreadUtils to track the shutdown call val mockThreadPool = mock[ScheduledExecutorService] FlintREPL.threadPoolFactory = new MockThreadPoolFactory(mockThreadPool) + FlintREPL.throwableHandler.reset() - intercept[RuntimeException] { + intercept[UnrecoverableException] { FlintREPL.queryLoop(commandContext) } // Verify if the shutdown method was called on the thread pool verify(mockThreadPool).shutdown() + FlintREPL.throwableHandler.exceptionThrown shouldBe Some(unrecoverableException) } finally { // Stop the SparkSession spark.stop() @@ -1436,7 +1587,6 @@ class FlintREPLTest val sessionManager = new SessionManagerImpl(spark, Some(resultIndex)) { override val osClient: OSClient = mockOSClient } - val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( applicationId,