diff --git a/DEVELOPER_GUIDE.md b/DEVELOPER_GUIDE.md index bb8f697ec..834a2a201 100644 --- a/DEVELOPER_GUIDE.md +++ b/DEVELOPER_GUIDE.md @@ -11,6 +11,11 @@ To execute the unit tests, run the following command: ``` sbt test ``` +To run a specific unit test in SBT, use the testOnly command with the full path of the test class: +``` +sbt "; project pplSparkIntegration; test:testOnly org.opensearch.flint.spark.ppl.PPLLogicalPlanTrendlineCommandTranslatorTestSuite" +``` + ## Integration Test The integration test is defined in the `integration` directory of the project. The integration tests will automatically trigger unit tests and will only run if all unit tests pass. If you want to run the integration test for the project, you can do so by running the following command: @@ -23,6 +28,13 @@ If you get integration test failures with error message "Previous attempts to fi 3. Run `sudo ln -s $HOME/.docker/desktop/docker.sock /var/run/docker.sock` or `sudo ln -s $HOME/.docker/run/docker.sock /var/run/docker.sock` 4. If you use Docker Desktop, as an alternative of `3`, check mark the "Allow the default Docker socket to be used (requires password)" in advanced settings of Docker Desktop. +Running only a selected set of integration test suites is possible with the following command: +``` +sbt "; project integtest; it:testOnly org.opensearch.flint.spark.ppl.FlintSparkPPLTrendlineITSuite" +``` +This command runs only the specified test suite within the integtest submodule. + + ### AWS Integration Test The `aws-integration` folder contains tests for cloud server providers. For instance, test against AWS OpenSearch domain, configure the following settings. The client will use the default credential provider to access the AWS OpenSearch domain. 
``` diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md index cb50431f6..7766c3b50 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -50,6 +50,10 @@ _- **Limitation: new field added by eval command with a function cannot be dropp - `source = table | where a < 1 | fields a,b,c` - `source = table | where b != 'test' | fields a,b,c` - `source = table | where c = 'test' | fields a,b,c | head 3` +- `source = table | where c = 'test' AND a = 1 | fields a,b,c` +- `source = table | where c != 'test' OR a > 1 | fields a,b,c` +- `source = table | where (b > 1 OR a > 1) AND c != 'test' | fields a,b,c` +- `source = table | where c = 'test' NOT a > 1 | fields a,b,c` - Note: "AND" is optional - `source = table | where ispresent(b)` - `source = table | where isnull(coalesce(a, b)) | fields a,b,c | head 3` - `source = table | where isempty(a)` @@ -61,6 +65,7 @@ _- **Limitation: new field added by eval command with a function cannot be dropp - `source = table | where cidrmatch(ip, '192.169.1.0/24')` - `source = table | where cidrmatch(ipv6, '2003:db8::/32')` - `source = table | trendline sma(2, temperature) as temp_trend` +- `source = table | trendline sort timestamp wma(2, temperature) as temp_trend` #### **IP related queries** [See additional command details](functions/ppl-ip.md) diff --git a/docs/ppl-lang/ppl-trendline-command.md b/docs/ppl-lang/ppl-trendline-command.md index 393a9dd59..b466e2e8f 100644 --- a/docs/ppl-lang/ppl-trendline-command.md +++ b/docs/ppl-lang/ppl-trendline-command.md @@ -3,8 +3,7 @@ **Description** Using ``trendline`` command to calculate moving averages of fields. - -### Syntax +### Syntax - SMA (Simple Moving Average) `TRENDLINE [sort <[+|-] sort-field>] SMA(number-of-datapoints, field) [AS alias] [SMA(number-of-datapoints, field) [AS alias]]...` * [+|-]: optional. 
The plus [+] stands for ascending order and NULL/MISSING first and a minus [-] stands for descending order and NULL/MISSING last. **Default:** ascending order and NULL/MISSING first. @@ -13,8 +12,6 @@ Using ``trendline`` command to calculate moving averages of fields. * field: mandatory. the name of the field the moving average should be calculated for. * alias: optional. the name of the resulting column containing the moving average. -And the moment only the Simple Moving Average (SMA) type is supported. - It is calculated like f[i]: The value of field 'f' in the i-th data-point @@ -23,7 +20,7 @@ It is calculated like SMA(t) = (1/n) * Σ(f[i]), where i = t-n+1 to t -### Example 1: Calculate simple moving average for a timeseries of temperatures +#### Example 1: Calculate simple moving average for a timeseries of temperatures The example calculates the simple moving average over temperatures using two datapoints. @@ -41,7 +38,7 @@ PPL query: | 15| 258|2023-04-06 17:07:...| 14.5| +-----------+---------+--------------------+----------+ -### Example 2: Calculate simple moving averages for a timeseries of temperatures with sorting +#### Example 2: Calculate simple moving averages for a timeseries of temperatures with sorting The example calculates two simple moving average over temperatures using two and three datapoints sorted descending by device-id. @@ -58,3 +55,58 @@ PPL query: | 12| 1492|2023-04-06 17:07:...| 12.5| 13.0| | 12| 1492|2023-04-06 17:07:...| 12.0|12.333333333333334| +-----------+---------+--------------------+------------+------------------+ + + +### Syntax - WMA (Weighted Moving Average) +`TRENDLINE sort <[+|-] sort-field> WMA(number-of-datapoints, field) [AS alias] [WMA(number-of-datapoints, field) [AS alias]]...` + +* [+|-]: optional. The plus [+] stands for ascending order and NULL/MISSING first and a minus [-] stands for descending order and NULL/MISSING last. **Default:** ascending order and NULL/MISSING first. +* sort-field: mandatory. 
this field specifies the ordering of data points when calculating the nth_value aggregation. +* number-of-datapoints: mandatory. number of datapoints to calculate the moving average (must be greater than zero). +* field: mandatory. the name of the field the moving average should be calculated for. +* alias: optional. the name of the resulting column containing the moving average. + +It is calculated like + + f[i]: The value of field 'f' in the i-th data point + n: The number of data points in the moving window (period) + t: The current time index + w[i]: The weight assigned to the i-th data point, typically increasing for more recent points + + WMA(t) = ( Σ from i=t−n+1 to t of (w[i] * f[i]) ) / ( Σ from i=t−n+1 to t of w[i] ) + +#### Example 1: Calculate weighted moving average for a timeseries of temperatures + +The example calculates the weighted moving average over temperatures using two datapoints. + +PPL query: + + os> source=t | trendline sort timestamp wma(2, temperature) as temp_trend; + fetched rows / total rows = 5/5 + +-----------+---------+--------------------+----------+ + |temperature|device-id| timestamp|temp_trend| + +-----------+---------+--------------------+----------+ + | 12| 1492|2023-04-06 17:07:...| NULL| + | 12| 1492|2023-04-06 17:07:...| 12.0| + | 13| 256|2023-04-06 17:07:...| 12.6| + | 14| 257|2023-04-06 17:07:...| 13.6| + | 15| 258|2023-04-06 17:07:...| 14.6| + +-----------+---------+--------------------+----------+ + +#### Example 2: Calculate weighted moving averages for a timeseries of temperatures with sorting + +The example calculates two weighted moving averages over temperatures using two and three datapoints sorted descending by device-id. 
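The WMA formula above (linear weights 1..n, newest point weighted highest, divided by the weight sum n(n+1)/2) can be cross-checked with a small standalone sketch. This is an illustrative helper, not part of the codebase; the input values mirror the ages (20, 25, 30, 70) used by the FlintSparkPPLTrendlineITSuite tests in this change, which expect 26.666666666666668 and 49.166666666666664 for wma(3, age):

```python
def wma(values, n):
    """Weighted moving average with linear weights 1..n (oldest..newest),
    i.e. WMA(t) = sum(w[i] * f[i]) / sum(w[i]) over the last n points."""
    denom = n * (n + 1) // 2  # 1 + 2 + ... + n, e.g. 6 for n=3
    out = []
    for t in range(len(values)):
        if t < n - 1:
            out.append(None)  # window not yet full: NULL in the PPL output
        else:
            window = values[t - n + 1 : t + 1]
            out.append(sum(w * f for w, f in zip(range(1, n + 1), window)) / denom)
    return out

print(wma([20, 25, 30, 70], 3))
# [None, None, 26.666666666666668, 49.166666666666664]
```

The leading `None` entries correspond to the NULL rows in the example tables, where fewer than n data points precede the current row.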
+ +PPL query: + + os> source=t | trendline sort - device-id wma(2, temperature) as temp_trend_2 wma(3, temperature) as temp_trend_3; + fetched rows / total rows = 5/5 + +-----------+---------+--------------------+------------+------------------+ + |temperature|device-id| timestamp|temp_trend_2| temp_trend_3| + +-----------+---------+--------------------+------------+------------------+ + | 15| 258|2023-04-06 17:07:...| NULL| NULL| + | 14| 257|2023-04-06 17:07:...| 14.3| NULL| + | 13| 256|2023-04-06 17:07:...| 13.3| 13.6| + | 12| 1492|2023-04-06 17:07:...| 12.3| 12.6| + | 12| 1492|2023-04-06 17:07:...| 12.0| 12.16| + +-----------+---------+--------------------+------------+------------------+ diff --git a/docs/ppl-lang/ppl-where-command.md b/docs/ppl-lang/ppl-where-command.md index c954623c3..aa7d9299e 100644 --- a/docs/ppl-lang/ppl-where-command.md +++ b/docs/ppl-lang/ppl-where-command.md @@ -27,15 +27,15 @@ PPL query: ### Additional Examples #### **Filters With Logical Conditions** -``` -- `source = table | where c = 'test' AND a = 1 | fields a,b,c` -- `source = table | where c != 'test' OR a > 1 | fields a,b,c | head 1` -- `source = table | where c = 'test' NOT a > 1 | fields a,b,c` - `source = table | where a = 1 | fields a,b,c` - `source = table | where a >= 1 | fields a,b,c` - `source = table | where a < 1 | fields a,b,c` - `source = table | where b != 'test' | fields a,b,c` - `source = table | where c = 'test' | fields a,b,c | head 3` +- `source = table | where c = 'test' AND a = 1 | fields a,b,c` +- `source = table | where c != 'test' OR a > 1 | fields a,b,c` +- `source = table | where (b > 1 OR a > 1) AND c != 'test' | fields a,b,c` +- `source = table | where c = 'test' NOT a > 1 | fields a,b,c` - Note: "AND" is optional - `source = table | where ispresent(b)` - `source = table | where isnull(coalesce(a, b)) | fields a,b,c | head 3` - `source = table | where isempty(a)` @@ -45,7 +45,6 @@ PPL query: - `source = table | where b not between '2024-09-10' and 
'2025-09-10'` - Note: This returns b >= '2024-09-10' and b <= '2025-09-10' - `source = table | where cidrmatch(ip, '192.169.1.0/24')` - `source = table | where cidrmatch(ipv6, '2003:db8::/32')` - - `source = table | eval status_category = case(a >= 200 AND a < 300, 'Success', a >= 300 AND a < 400, 'Redirection', @@ -57,10 +56,8 @@ PPL query: a >= 400 AND a < 500, 'Client Error', a >= 500, 'Server Error' else 'Incorrect HTTP status code' - ) = 'Incorrect HTTP status code' - + ) = 'Incorrect HTTP status code'` - `source = table | eval factor = case(a > 15, a - 14, isnull(b), a - 7, a < 3, a + 1 else 1) | where case(factor = 2, 'even', factor = 4, 'even', factor = 6, 'even', factor = 8, 'even' else 'odd') = 'even' | stats count() by factor` -``` \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q19.ppl b/integ-test/src/integration/resources/tpch/q19.ppl index 630d63bcc..63312d2f0 100644 --- a/integ-test/src/integration/resources/tpch/q19.ppl +++ b/integ-test/src/integration/resources/tpch/q19.ppl @@ -37,25 +37,30 @@ where */ source = lineitem -| join ON p_partkey = l_partkey - and p_brand = 'Brand#12' - and p_container in ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') - and l_quantity >= 1 and l_quantity <= 1 + 10 - and p_size between 1 and 5 - and l_shipmode in ('AIR', 'AIR REG') - and l_shipinstruct = 'DELIVER IN PERSON' - OR p_partkey = l_partkey - and p_brand = 'Brand#23' - and p_container in ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') - and l_quantity >= 10 and l_quantity <= 10 + 10 - and p_size between 1 and 10 - and l_shipmode in ('AIR', 'AIR REG') - and l_shipinstruct = 'DELIVER IN PERSON' - OR p_partkey = l_partkey - and p_brand = 'Brand#34' - and p_container in ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') - and l_quantity >= 20 and l_quantity <= 20 + 10 - and p_size between 1 and 15 - and l_shipmode in ('AIR', 'AIR REG') - and l_shipinstruct = 'DELIVER IN PERSON' +| join ON + ( + p_partkey = l_partkey + and p_brand = 'Brand#12' + 
and p_container in ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') + and l_quantity >= 1 and l_quantity <= 1 + 10 + and p_size between 1 and 5 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) OR ( + p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container in ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') + and l_quantity >= 10 and l_quantity <= 10 + 10 + and p_size between 1 and 10 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) OR ( + p_partkey = l_partkey + and p_brand = 'Brand#34' + and p_container in ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') + and l_quantity >= 20 and l_quantity <= 20 + 10 + and p_size between 1 and 15 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) part \ No newline at end of file diff --git a/integ-test/src/integration/resources/tpch/q7.ppl b/integ-test/src/integration/resources/tpch/q7.ppl index ceda602b3..a6ea66d63 100644 --- a/integ-test/src/integration/resources/tpch/q7.ppl +++ b/integ-test/src/integration/resources/tpch/q7.ppl @@ -48,7 +48,7 @@ source = [ | join ON s_nationkey = n1.n_nationkey nation as n1 | join ON c_nationkey = n2.n_nationkey nation as n2 | where l_shipdate between date('1995-01-01') and date('1996-12-31') - and n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY' or n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE' + and ((n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY') or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE')) | eval supp_nation = n1.n_name, cust_nation = n2.n_name, l_year = year(l_shipdate), volume = l_extendedprice * (1 - l_discount) | fields supp_nation, cust_nation, l_year, volume ] as shipping diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala index f2d7ee844..62c735597 100644 --- 
a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala @@ -467,4 +467,96 @@ class FlintSparkPPLFiltersITSuite val expectedPlan = Project(Seq(UnresolvedAttribute("state")), filter) comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } + + test("test parenthesis in filter") { + val frame = sql(s""" + | source = $testTable | where country = 'Canada' or age > 60 and age < 25 | fields name, age, country + | """.stripMargin) + assertSameRows(Seq(Row("John", 25, "Canada"), Row("Jane", 20, "Canada")), frame) + + val frameWithParenthesis = sql(s""" + | source = $testTable | where (country = 'Canada' or age > 60) and age < 25 | fields name, age, country + | """.stripMargin) + assertSameRows(Seq(Row("Jane", 20, "Canada")), frameWithParenthesis) + + val logicalPlan: LogicalPlan = frameWithParenthesis.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filter = Filter( + And( + Or( + EqualTo(UnresolvedAttribute("country"), Literal("Canada")), + GreaterThan(UnresolvedAttribute("age"), Literal(60))), + LessThan(UnresolvedAttribute("age"), Literal(25))), + table) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("name"), + UnresolvedAttribute("age"), + UnresolvedAttribute("country")), + filter) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test complex and nested parenthesis in filter") { + val frame1 = sql(s""" + | source = $testTable | WHERE (age > 18 AND (state = 'California' OR state = 'New York')) + | """.stripMargin) + assertSameRows( + Seq( + Row("Hello", 30, "New York", "USA", 2023, 4), + Row("Jake", 70, "California", "USA", 2023, 4)), + frame1) + + val frame2 = sql(s""" + | source = $testTable | WHERE ((((age > 18) AND ((((state = 'California') OR state = 'New York')))))) + | """.stripMargin) + assertSameRows( + 
Seq( + Row("Hello", 30, "New York", "USA", 2023, 4), + Row("Jake", 70, "California", "USA", 2023, 4)), + frame2) + + val frame3 = sql(s""" + | source = $testTable | WHERE (year = 2023 AND (month BETWEEN 1 AND 6)) AND (age >= 31 OR country = 'Canada') + | """.stripMargin) + assertSameRows( + Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4), + Row("Jake", 70, "California", "USA", 2023, 4), + Row("Jane", 20, "Quebec", "Canada", 2023, 4)), + frame3) + + val frame4 = sql(s""" + | source = $testTable | WHERE ((state = 'Texas' OR state = 'California') AND (age < 30 OR (country = 'USA' AND year > 2020))) + | """.stripMargin) + assertSameRows(Seq(Row("Jake", 70, "California", "USA", 2023, 4)), frame4) + + val frame5 = sql(s""" + | source = $testTable | WHERE (LIKE(LOWER(name), 'a%') OR LIKE(LOWER(name), 'j%')) AND (LENGTH(state) > 6 OR (country = 'USA' AND age > 18)) + | """.stripMargin) + assertSameRows( + Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4), + Row("Jake", 70, "California", "USA", 2023, 4)), + frame5) + + val frame6 = sql(s""" + | source = $testTable | WHERE (age BETWEEN 25 AND 40) AND ((state IN ('California', 'New York', 'Texas') AND year = 2023) OR (country != 'USA' AND (month = 1 OR month = 12))) + | """.stripMargin) + assertSameRows(Seq(Row("Hello", 30, "New York", "USA", 2023, 4)), frame6) + + val frame7 = sql(s""" + | source = $testTable | WHERE NOT (age < 18 OR (state = 'Alaska' AND year < 2020)) AND (country = 'USA' OR (country = 'Mexico' AND month BETWEEN 6 AND 8)) + | """.stripMargin) + assertSameRows( + Seq( + Row("Jake", 70, "California", "USA", 2023, 4), + Row("Hello", 30, "New York", "USA", 2023, 4)), + frame7) + + val frame8 = sql(s""" + | source = $testTable | WHERE (NOT (year < 2020 OR age < 18)) AND ((state = 'Texas' AND month % 2 = 0) OR (country = 'Mexico' AND (year = 2023 OR (year = 2022 AND month > 6)))) + | """.stripMargin) + assertSameRows(Seq(), frame8) + } } diff --git 
a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala index bc4463537..9a8379288 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala @@ -5,9 +5,14 @@ package org.opensearch.flint.spark.ppl +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.opensearch.sql.ppl.utils.SortUtils +import org.scalatest.matchers.should.Matchers.{a, convertToAnyShouldWrapper} + import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CaseWhen, CurrentRow, Descending, LessThan, Literal, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.expressions.{Add, Alias, Ascending, CaseWhen, CurrentRow, Descending, Divide, Expression, LessThan, Literal, Multiply, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.streaming.StreamTest @@ -244,4 +249,265 @@ class FlintSparkPPLTrendlineITSuite implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) } + + test("test trendline wma command with sort field and without alias") { + val frame = sql(s""" + | source = $testTable | trendline sort + age wma(3, age) + | """.stripMargin) + + // Compare the headers + assert( + frame.columns.sameElements( + Array("name", "age", "state", "country", "year", "month", 
"age_trendline"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jane", 20, "Quebec", "Canada", 2023, 4, null), + Row("John", 25, "Ontario", "Canada", 2023, 4, null), + Row("Hello", 30, "New York", "USA", 2023, 4, 26.666666666666668), + Row("Jake", 70, "California", "USA", 2023, 4, 49.166666666666664)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Compare the logical plans + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), + getNthValueAggregation("age", "age", 3, -2)) + val wmaExpression = Divide(dividend, Literal(6)) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(wmaExpression, "age_trendline")()) + val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq) + val sortedTable = Sort( + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, + unresolvedRelation) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + + /** + * Expected logical plan: 'Project [*] +- 'Project [*, ((( ('nth_value('age, 1) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 1) + ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + ('nth_value('age, 3) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 3)) / 6) AS age_trendline#185] +- 'Sort ['age ASC NULLS FIRST], true +- + * 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false + */ + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline wma command with sort field and 
with alias") { + val frame = sql(s""" + | source = $testTable | trendline sort + age wma(3, age) as trendline_alias + | """.stripMargin) + + // Compare the headers + assert( + frame.columns.sameElements( + Array("name", "age", "state", "country", "year", "month", "trendline_alias"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jane", 20, "Quebec", "Canada", 2023, 4, null), + Row("John", 25, "Ontario", "Canada", 2023, 4, null), + Row("Hello", 30, "New York", "USA", 2023, 4, 26.666666666666668), + Row("Jake", 70, "California", "USA", 2023, 4, 49.166666666666664)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Compare the logical plans + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), + getNthValueAggregation("age", "age", 3, -2)) + val wmaExpression = Divide(dividend, Literal(6)) + val trendlineProjectList = + Seq(UnresolvedStar(None), Alias(wmaExpression, "trendline_alias")()) + val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq) + val sortedTable = Sort( + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, + unresolvedRelation) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + + /** + * 'Project [*] +- 'Project [*, ((( ('nth_value('age, 1) windowspecdefinition('age ASC NULLS + * FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + ('nth_value('age, 2) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 2)) + ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS trendline_alias#185] +- 
+ * 'Sort ['age ASC NULLS FIRST], true +- 'UnresolvedRelation [spark_catalog, default, + * flint_ppl_test], [], false + */ + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test multiple trendline wma commands") { + val frame = sql(s""" + | source = $testTable | trendline sort + age wma(2, age) as two_points_wma wma(3, age) as three_points_wma + | """.stripMargin) + + // Compare the headers + assert( + frame.columns.sameElements( + Array( + "name", + "age", + "state", + "country", + "year", + "month", + "two_points_wma", + "three_points_wma"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jane", 20, "Quebec", "Canada", 2023, 4, null, null), + Row("John", 25, "Ontario", "Canada", 2023, 4, 23.333333333333332, null), + Row("Hello", 30, "New York", "USA", 2023, 4, 28.333333333333332, 26.666666666666668), + Row("Jake", 70, "California", "USA", 2023, 4, 56.666666666666664, 49.166666666666664)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Compare the logical plans + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val dividendTwo = Add( + getNthValueAggregation("age", "age", 1, -1), + getNthValueAggregation("age", "age", 2, -1)) + val twoPointsExpression = Divide(dividendTwo, Literal(3)) + + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), + getNthValueAggregation("age", "age", 3, -2)) + val threePointsExpression = Divide(dividend, Literal(6)) + + val trendlineProjectList = Seq( + UnresolvedStar(None), + Alias(twoPointsExpression, "two_points_wma")(), + Alias(threePointsExpression, "three_points_wma")()) + val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq) + val sortedTable = Sort( + 
Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, + unresolvedRelation) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + + /** + * 'Project [*] +- 'Project [*, (( ('nth_value('age, 1) windowspecdefinition('age ASC NULLS + * FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 1) + ('nth_value('age, 2) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, + * currentrow$())) * 2)) / 3) AS two_points_wma#247, + * + * ((( ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + ('nth_value('age, 2) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 2)) + ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS three_points_wma#248] +- + * 'Sort ['age ASC NULLS FIRST], true +- 'UnresolvedRelation [spark_catalog, default, + * flint_ppl_test], [], false + */ + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline wma command on evaluated column") { + val frame = sql(s""" + | source = $testTable | eval doubled_age = age * 2 | trendline sort + age wma(2, doubled_age) as doubled_age_wma | fields name, doubled_age, doubled_age_wma + | """.stripMargin) + + // Compare the headers + assert(frame.columns.sameElements(Array("name", "doubled_age", "doubled_age_wma"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jane", 40, null), + Row("John", 50, 46.666666666666664), + Row("Hello", 60, 56.666666666666664), + Row("Jake", 140, 113.33333333333333)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Compare the logical plans + val 
logicalPlan: LogicalPlan = frame.queryExecution.logical + val dividend = Add( + getNthValueAggregation("doubled_age", "age", 1, -1), + getNthValueAggregation("doubled_age", "age", 2, -1)) + val wmaExpression = Divide(dividend, Literal(3)) + val trendlineProjectList = + Seq(UnresolvedStar(None), Alias(wmaExpression, "doubled_age_wma")()) + val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq) + val doubledAged = Alias( + UnresolvedFunction( + seq("*"), + seq(UnresolvedAttribute("age"), Literal(2)), + isDistinct = false), + "doubled_age")() + val doubleAgeProject = Project(seq(UnresolvedStar(None), doubledAged), unresolvedRelation) + val sortedTable = + Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, doubleAgeProject) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("name"), + UnresolvedAttribute("doubled_age"), + UnresolvedAttribute("doubled_age_wma")), + Project(trendlineProjectList, sortedTable)) + + /** + * 'Project ['name, 'doubled_age, 'doubled_age_wma] +- 'Project [*, (( + * ('nth_value('doubled_age, 1) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -1, currentrow$())) * 1) + ('nth_value('doubled_age, 2) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, + * currentrow$())) * 2)) / 3) AS doubled_age_wma#288] +- 'Sort ['age ASC NULLS FIRST], true +- + * 'Project [*, '`*`('age, 2) AS doubled_age#287] +- 'UnresolvedRelation [spark_catalog, + * default, flint_ppl_test], [], false + */ + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + + } + + test("test invalid wma command with negative dataPoint value") { + val exception = intercept[ParseException](sql(s""" + | source = $testTable | trendline sort + age wma(-3, age) + | """.stripMargin)) + assert(exception.getMessage contains "[PARSE_SYNTAX_ERROR] Syntax error") + } + + private def getNthValueAggregation( + dataField: String, + sortField: String, + lookBackPos: Int, + 
lookBackRange: Int): Expression = { + Multiply( + WindowExpression( + UnresolvedFunction( + "nth_value", + Seq(UnresolvedAttribute(dataField), Literal(lookBackPos)), + isDistinct = false), + WindowSpecDefinition( + Seq(), + seq(SortUtils.sortOrder(UnresolvedAttribute(sortField), true)), + SpecifiedWindowFrame(RowFrame, Literal(lookBackRange), CurrentRow))), + Literal(lookBackPos)) + } + } diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index 10b2e01b8..3ce8b6f1e 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -96,6 +96,7 @@ NULLS: 'NULLS'; //TRENDLINE KEYWORDS SMA: 'SMA'; +WMA: 'WMA'; // ARGUMENT KEYWORDS KEEPEMPTY: 'KEEPEMPTY'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 63efd8c6c..357673e73 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -267,11 +267,12 @@ trendlineCommand ; trendlineClause - : trendlineType LT_PRTHS numberOfDataPoints = integerLiteral COMMA field = fieldExpression RT_PRTHS (AS alias = qualifiedName)? + : trendlineType LT_PRTHS numberOfDataPoints = INTEGER_LITERAL COMMA field = fieldExpression RT_PRTHS (AS alias = qualifiedName)? ; trendlineType : SMA + | WMA ; kmeansCommand @@ -424,6 +425,7 @@ expression logicalExpression : NOT logicalExpression # logicalNot + | LT_PRTHS logicalExpression RT_PRTHS # parentheticLogicalExpr | comparisonExpression # comparsion | left = logicalExpression (AND)? 
right = logicalExpression # logicalAnd | left = logicalExpression OR right = logicalExpression # logicalOr diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java index 9fa1ae81d..d08e89e3b 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java @@ -62,6 +62,6 @@ public TrendlineComputation(Integer numberOfDataPoints, UnresolvedExpression dat } public enum TrendlineType { - SMA + SMA, WMA } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index f039bf47f..86970cefb 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -292,7 +292,6 @@ public enum BuiltinFunctionName { MULTIMATCHQUERY(FunctionName.of("multimatchquery")), WILDCARDQUERY(FunctionName.of("wildcardquery")), WILDCARD_QUERY(FunctionName.of("wildcard_query")), - COALESCE(FunctionName.of("coalesce")); private FunctionName name; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 00a7905f0..debd37376 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -245,7 +245,7 @@ public LogicalPlan visitTrendline(Trendline node, CatalystPlanContext context) { trendlineProjectExpressions.add(UnresolvedStar$.MODULE$.apply(Option.empty())); } - 
trendlineProjectExpressions.addAll(TrendlineCatalystUtils.visitTrendlineComputations(expressionAnalyzer, node.getComputations(), context)); + trendlineProjectExpressions.addAll(TrendlineCatalystUtils.visitTrendlineComputations(expressionAnalyzer, node.getComputations(), node.getSortByField(), context)); return context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(seq(trendlineProjectExpressions), p)); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 36d9f9577..e683a1395 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -157,6 +157,11 @@ public UnresolvedExpression visitBinaryArithmetic(OpenSearchPPLParser.BinaryArit ctx.binaryOperator.getText(), Arrays.asList(visit(ctx.left), visit(ctx.right))); } + @Override + public UnresolvedExpression visitParentheticLogicalExpr(OpenSearchPPLParser.ParentheticLogicalExprContext ctx) { + return visit(ctx.logicalExpression()); // Discard parenthesis around + } + @Override public UnresolvedExpression visitParentheticValueExpr(OpenSearchPPLParser.ParentheticValueExprContext ctx) { return visit(ctx.valueExpression()); // Discard parenthesis around diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java index 67603ccc7..647f4542e 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java @@ -5,31 +5,40 @@ package org.opensearch.sql.ppl.utils; +import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction; import 
org.apache.spark.sql.catalyst.expressions.*; -import org.opensearch.sql.ast.expression.AggregateFunction; -import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.*; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.ppl.CatalystExpressionVisitor; import org.opensearch.sql.ppl.CatalystPlanContext; +import scala.collection.mutable.Seq; import scala.Option; import scala.Tuple2; +import java.util.ArrayList; +import java.util.Collections; import java.util.List; +import java.util.Optional; import java.util.stream.Collectors; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; +import static scala.Option.empty; +import static scala.collection.JavaConverters.asScalaBufferConverter; public interface TrendlineCatalystUtils { - static List visitTrendlineComputations(CatalystExpressionVisitor expressionVisitor, List computations, CatalystPlanContext context) { + + static List visitTrendlineComputations(CatalystExpressionVisitor expressionVisitor, List computations, Optional sortField, CatalystPlanContext context) { return computations.stream() - .map(computation -> visitTrendlineComputation(expressionVisitor, computation, context)) + .map(computation -> visitTrendlineComputation(expressionVisitor, computation, sortField, context)) .collect(Collectors.toList()); } - static NamedExpression visitTrendlineComputation(CatalystExpressionVisitor expressionVisitor, Trendline.TrendlineComputation node, CatalystPlanContext context) { + + static NamedExpression visitTrendlineComputation(CatalystExpressionVisitor expressionVisitor, Trendline.TrendlineComputation node, Optional sortField, CatalystPlanContext context) { + //window lower boundary expressionVisitor.visitLiteral(new Literal(Math.negateExact(node.getNumberOfDataPoints() - 1), DataType.INTEGER), context); Expression windowLowerBoundary = 
context.popNamedParseExpressions().get(); @@ -40,26 +49,28 @@ static NamedExpression visitTrendlineComputation(CatalystExpressionVisitor expre seq(), new SpecifiedWindowFrame(RowFrame$.MODULE$, windowLowerBoundary, CurrentRow$.MODULE$)); - if (node.getComputationType() == Trendline.TrendlineType.SMA) { - //calculate avg value of the data field - expressionVisitor.visitAggregateFunction(new AggregateFunction(BuiltinFunctionName.AVG.name(), node.getDataField()), context); - Expression avgFunction = context.popNamedParseExpressions().get(); - - //sma window - WindowExpression sma = new WindowExpression( - avgFunction, - windowDefinition); - - CaseWhen smaOrNull = trendlineOrNullWhenThereAreTooFewDataPoints(expressionVisitor, sma, node, context); - - return org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(smaOrNull, - node.getAlias(), - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty(), - seq(new java.util.ArrayList())); - } else { - throw new IllegalArgumentException(node.getComputationType()+" is not supported"); + switch (node.getComputationType()) { + case SMA: + //calculate avg value of the data field + expressionVisitor.visitAggregateFunction(new AggregateFunction(BuiltinFunctionName.AVG.name(), node.getDataField()), context); + Expression avgFunction = context.popNamedParseExpressions().get(); + + //sma window + WindowExpression sma = new WindowExpression( + avgFunction, + windowDefinition); + + CaseWhen smaOrNull = trendlineOrNullWhenThereAreTooFewDataPoints(expressionVisitor, sma, node, context); + + return getAlias(node.getAlias(), smaOrNull); + case WMA: + if (sortField.isPresent()) { + return getWMAComputationExpression(expressionVisitor, node, sortField.get(), context); + } else { + throw new IllegalArgumentException(node.getComputationType()+" requires a sort field for computation"); + } + default: + throw new IllegalArgumentException(node.getComputationType()+" is not supported"); } } @@ -84,4 +95,136 @@ 
private static CaseWhen trendlineOrNullWhenThereAreTooFewDataPoints(CatalystExpr ); return new CaseWhen(seq(nullWhenNumberOfDataPointsLessThenRequired), Option.apply(trendlineWindow)); }
+
+ /**
+ * Responsible for producing a Spark Logical Plan from the given Trendline command arguments; below is a sample logical plan
+ * with configuration [dataField=salary, sortField=age, dataPoints=3]
+ * -- +- 'Project [
+ * -- (((('nth_value('salary, 1) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) +
+ * -- ('nth_value('salary, 2) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) +
+ * -- ('nth_value('salary, 3) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6)
+ * -- AS WMA#702]
+ * .
+ * And the corresponding SQL query:
+ * .
+ * SELECT name, salary,
+ * ( nth_value(salary, 1) OVER (ORDER BY age ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) * 1 +
+ * nth_value(salary, 2) OVER (ORDER BY age ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) * 2 +
+ * nth_value(salary, 3) OVER (ORDER BY age ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) * 3 ) / 6 AS WMA
+ * FROM employees
+ * ORDER BY age;
+ *
+ * @param visitor Visitor instance to process any UnresolvedExpression.
+ * @param node Trendline command's arguments.
+ * @param sortField Field used for window aggregation.
+ * @param context Context instance used to retrieve Expressions in resolved form.
+ * @return A NamedExpression instance which will calculate the WMA from the provided arguments.
+ */
+ private static NamedExpression getWMAComputationExpression(CatalystExpressionVisitor visitor,
+ Trendline.TrendlineComputation node,
+ Field sortField,
+ CatalystPlanContext context) {
+ int dataPoints = node.getNumberOfDataPoints();
+ //window lower boundary
+ Expression windowLowerBoundary = parseIntToExpression(visitor, context,
+ Math.negateExact(dataPoints - 1));
+ //window definition
+ visitor.analyze(sortField, context);
+ Expression sortDefinition = context.popNamedParseExpressions().get();
+ WindowSpecDefinition windowDefinition = getWmaCommonWindowDefinition(
+ sortDefinition,
+ SortUtils.isSortedAscending(sortField),
+ windowLowerBoundary);
+ // Divisor
+ Expression divisor = parseIntToExpression(visitor, context,
+ (dataPoints * (dataPoints + 1) / 2));
+ // Aggregation
+ Expression wmaExpression = getNthValueAggregations(visitor, node, context, windowDefinition, dataPoints)
+ .stream()
+ .reduce(Add::new)
+ .orElse(null);
+
+ return getAlias(node.getAlias(), new Divide(wmaExpression, divisor));
+ }
+
+ /**
+ * Helper method to produce an Alias Expression with the provided value and name.
+ * @param name The name for the Alias.
+ * @param expression The expression which will be evaluated.
+ * @return An Alias instance with logical plan representation of `expression AS name`.
+ */
+ private static NamedExpression getAlias(String name, Expression expression) {
+ return org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(expression,
+ name,
+ NamedExpression.newExprId(),
+ seq(Collections.emptyList()),
+ Option.empty(),
+ seq(Collections.emptyList()));
+ }
+
+ /**
+ * Helper method to retrieve an Int expression instance for logical plan composition purposes.
+ * @param expressionVisitor Visitor instance to process the incoming object.
+ * @param context Context instance to retrieve the Expression instance.
+ * @param i Target value for the expression.
+ * @return An Expression object which contains the integer value i.
+ */
+ static Expression parseIntToExpression(CatalystExpressionVisitor expressionVisitor, CatalystPlanContext context, int i) {
+ expressionVisitor.visitLiteral(new Literal(i,
+ DataType.INTEGER), context);
+ return context.popNamedParseExpressions().get();
+ }
+
+
+ /**
+ * Helper method to retrieve a WindowSpecDefinition with the provided sorting condition.
+ * `windowspecdefinition('sortField ascending NULLS FIRST, specifiedwindowframe(RowFrame, windowLowerBoundary, currentrow$())`
+ *
+ * @param sortField The field being used for the sorting operation.
+ * @param ascending The boolean instance for the sorting order.
+ * @param windowLowerBoundary The integer Expression instance which specifies the window lookbehind / lookahead.
+ * @return A WindowSpecDefinition instance which will be used to compose the WMA calculation.
+ */
+ static WindowSpecDefinition getWmaCommonWindowDefinition(Expression sortField, boolean ascending, Expression windowLowerBoundary) {
+ return new WindowSpecDefinition(
+ seq(),
+ seq(SortUtils.sortOrder(sortField, ascending)),
+ new SpecifiedWindowFrame(RowFrame$.MODULE$, windowLowerBoundary, CurrentRow$.MODULE$));
+ }
+
+ /**
+ * Produces a list of Expressions responsible for returning the appropriate lookbehind / lookahead values for the WMA calculation; a sample logical plan is listed below.
+ * (((('nth_value('salary, 1) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) +
+ *
+ * @param visitor Visitor instance to resolve Expressions.
+ * @param node Trendline command instruction.
+ * @param context Context instance to retrieve the resolved expression.
+ * @param windowDefinition The windowDefinition for the individual datapoint lookbehind / lookahead.
+ * @param dataPoints Number of data points for the WMA calculation; this will always equal the number of Expressions generated.
+ * @return A List instance which contains the expressions for the individual WMA datapoint calculations.
+ */ + private static List getNthValueAggregations(CatalystExpressionVisitor visitor, + Trendline.TrendlineComputation node, + CatalystPlanContext context, + WindowSpecDefinition windowDefinition, + int dataPoints) { + List expressions = new ArrayList<>(); + for (int i = 1; i <= dataPoints; i++) { + // Get the offset parameter + Expression offSetExpression = parseIntToExpression(visitor, context, i); + // Get the dataField in Expression + visitor.analyze(node.getDataField(), context); + Expression dataField = context.popNamedParseExpressions().get(); + // nth_value Expression + UnresolvedFunction nthValueExp = new UnresolvedFunction( + asScalaBufferConverter(List.of("nth_value")).asScala().seq(), + asScalaBufferConverter(List.of(dataField, offSetExpression)).asScala().seq(), + false, empty(), false); + + expressions.add(new Multiply( + new WindowExpression(nthValueExp, windowDefinition), offSetExpression)); + } + return expressions; + } + } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParenthesizedConditionTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParenthesizedConditionTestSuite.scala new file mode 100644 index 000000000..a70415aab --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParenthesizedConditionTestSuite.scala @@ -0,0 +1,244 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, GreaterThan, GreaterThanOrEqual, In, 
LessThan, LessThanOrEqual, Literal, Not, Or} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project} + +class PPLLogicalPlanParenthesizedConditionTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test simple nested condition") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source=employees | WHERE (age > 18 AND (state = 'California' OR state = 'New York'))"), + context) + + val table = UnresolvedRelation(Seq("employees")) + val filter = Filter( + And( + GreaterThan(UnresolvedAttribute("age"), Literal(18)), + Or( + EqualTo(UnresolvedAttribute("state"), Literal("California")), + EqualTo(UnresolvedAttribute("state"), Literal("New York")))), + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), filter) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test nested condition with duplicated parentheses") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source=employees | WHERE ((((age > 18) AND ((((state = 'California') OR state = 'New York'))))))"), + context) + + val table = UnresolvedRelation(Seq("employees")) + val filter = Filter( + And( + GreaterThan(UnresolvedAttribute("age"), Literal(18)), + Or( + EqualTo(UnresolvedAttribute("state"), Literal("California")), + EqualTo(UnresolvedAttribute("state"), Literal("New York")))), + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), filter) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test combining between function") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source=employees | WHERE (year = 2023 AND (month BETWEEN 1 AND 6)) AND (age >= 31 
OR country = 'Canada')"), + context) + + val table = UnresolvedRelation(Seq("employees")) + val betweenCondition = And( + GreaterThanOrEqual(UnresolvedAttribute("month"), Literal(1)), + LessThanOrEqual(UnresolvedAttribute("month"), Literal(6))) + val filter = Filter( + And( + And(EqualTo(UnresolvedAttribute("year"), Literal(2023)), betweenCondition), + Or( + GreaterThanOrEqual(UnresolvedAttribute("age"), Literal(31)), + EqualTo(UnresolvedAttribute("country"), Literal("Canada")))), + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), filter) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test multiple levels of nesting") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source=employees | WHERE ((state = 'Texas' OR state = 'California') AND (age < 30 OR (country = 'USA' AND year > 2020)))"), + context) + + val table = UnresolvedRelation(Seq("employees")) + val filter = Filter( + And( + Or( + EqualTo(UnresolvedAttribute("state"), Literal("Texas")), + EqualTo(UnresolvedAttribute("state"), Literal("California"))), + Or( + LessThan(UnresolvedAttribute("age"), Literal(30)), + And( + EqualTo(UnresolvedAttribute("country"), Literal("USA")), + GreaterThan(UnresolvedAttribute("year"), Literal(2020))))), + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), filter) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test with string functions") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source=employees | WHERE (LIKE(LOWER(name), 'a%') OR LIKE(LOWER(name), 'j%')) AND (LENGTH(state) > 6 OR (country = 'USA' AND age > 18))"), + context) + + val table = UnresolvedRelation(Seq("employees")) + val filter = Filter( + And( + Or( + UnresolvedFunction( + "like", + Seq( + UnresolvedFunction("lower", Seq(UnresolvedAttribute("name")), isDistinct = false), + Literal("a%")), + isDistinct = false), + 
UnresolvedFunction( + "like", + Seq( + UnresolvedFunction("lower", Seq(UnresolvedAttribute("name")), isDistinct = false), + Literal("j%")), + isDistinct = false)), + Or( + GreaterThan( + UnresolvedFunction("length", Seq(UnresolvedAttribute("state")), isDistinct = false), + Literal(6)), + And( + EqualTo(UnresolvedAttribute("country"), Literal("USA")), + GreaterThan(UnresolvedAttribute("age"), Literal(18))))), + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), filter) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test complex age ranges with nested conditions") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source=employees | WHERE (age BETWEEN 25 AND 40) AND ((state IN ('California', 'New York', 'Texas') AND year = 2023) OR (country != 'USA' AND (month = 1 OR month = 12)))"), + context) + + val table = UnresolvedRelation(Seq("employees")) + val filter = Filter( + And( + And( + GreaterThanOrEqual(UnresolvedAttribute("age"), Literal(25)), + LessThanOrEqual(UnresolvedAttribute("age"), Literal(40))), + Or( + And( + In( + UnresolvedAttribute("state"), + Seq(Literal("California"), Literal("New York"), Literal("Texas"))), + EqualTo(UnresolvedAttribute("year"), Literal(2023))), + And( + Not(EqualTo(UnresolvedAttribute("country"), Literal("USA"))), + Or( + EqualTo(UnresolvedAttribute("month"), Literal(1)), + EqualTo(UnresolvedAttribute("month"), Literal(12)))))), + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), filter) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test nested NOT conditions") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source=employees | WHERE NOT (age < 18 OR (state = 'Alaska' AND year < 2020)) AND (country = 'USA' OR (country = 'Mexico' AND month BETWEEN 6 AND 8))"), + context) + + val table = UnresolvedRelation(Seq("employees")) + val filter = Filter( 
+ And( + Not( + Or( + LessThan(UnresolvedAttribute("age"), Literal(18)), + And( + EqualTo(UnresolvedAttribute("state"), Literal("Alaska")), + LessThan(UnresolvedAttribute("year"), Literal(2020))))), + Or( + EqualTo(UnresolvedAttribute("country"), Literal("USA")), + And( + EqualTo(UnresolvedAttribute("country"), Literal("Mexico")), + And( + GreaterThanOrEqual(UnresolvedAttribute("month"), Literal(6)), + LessThanOrEqual(UnresolvedAttribute("month"), Literal(8)))))), + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), filter) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test complex boolean logic") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source=employees | WHERE (NOT (year < 2020 OR age < 18)) AND ((state = 'Texas' AND month % 2 = 0) OR (country = 'Mexico' AND (year = 2023 OR (year = 2022 AND month > 6))))"), + context) + + val table = UnresolvedRelation(Seq("employees")) + val filter = Filter( + And( + Not( + Or( + LessThan(UnresolvedAttribute("year"), Literal(2020)), + LessThan(UnresolvedAttribute("age"), Literal(18)))), + Or( + And( + EqualTo(UnresolvedAttribute("state"), Literal("Texas")), + EqualTo( + UnresolvedFunction( + "%", + Seq(UnresolvedAttribute("month"), Literal(2)), + isDistinct = false), + Literal(0))), + And( + EqualTo(UnresolvedAttribute("country"), Literal("Mexico")), + Or( + EqualTo(UnresolvedAttribute("year"), Literal(2023)), + And( + EqualTo(UnresolvedAttribute("year"), Literal(2022)), + GreaterThan(UnresolvedAttribute("month"), Literal(6))))))), + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), filter) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala 
b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala index d22750ee0..ec1775631 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala @@ -6,12 +6,15 @@ package org.opensearch.flint.spark.ppl import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.common.antlr.SyntaxCheckException import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.opensearch.sql.ppl.utils.SortUtils import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CaseWhen, CurrentRow, Descending, LessThan, Literal, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.expressions.{Add, Alias, Ascending, CaseWhen, CurrentRow, Descending, Divide, Expression, LessThan, Literal, Multiply, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Project, Sort} @@ -132,4 +135,147 @@ class PPLLogicalPlanTrendlineCommandTranslatorTestSuite Project(trendlineProjectList, sort)) comparePlans(logPlan, expectedPlan, checkAnalysis = false) } + + test("WMA - with sort") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=relation | trendline sort age wma(3, age)"), + context) + + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + 
getNthValueAggregation("age", "age", 2, -2)), + getNthValueAggregation("age", "age", 3, -2)) + val wmaExpression = Divide(dividend, Literal(6)) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(wmaExpression, "age_trendline")()) + val sortedTable = Sort( + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, + UnresolvedRelation(Seq("relation"))) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + + /** + * Expected logical plan: 'Project [*] !+- 'Project [*, ((( ('nth_value('age, 1) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 1) + ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + ('nth_value('age, 3) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 3)) / 6) AS age_trendline#0] ! +- 'Sort ['age ASC NULLS FIRST], true ! 
+- + * 'UnresolvedRelation [relation], [], false + */ + comparePlans(logPlan, expectedPlan, checkAnalysis = false) + } + + test("WMA - with sort and alias") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=relation | trendline sort age wma(3, age) as TEST_CUSTOM_COLUMN"), + context) + + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), + getNthValueAggregation("age", "age", 3, -2)) + val wmaExpression = Divide(dividend, Literal(6)) + val trendlineProjectList = + Seq(UnresolvedStar(None), Alias(wmaExpression, "TEST_CUSTOM_COLUMN")()) + val sortedTable = Sort( + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, + UnresolvedRelation(Seq("relation"))) + + /** + * Expected logical plan: 'Project [*] !+- 'Project [*, ((( ('nth_value('age, 1) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 1) + ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + ('nth_value('age, 3) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 3)) / 6) AS TEST_CUSTOM_COLUMN#0] ! +- 'Sort ['age ASC NULLS FIRST], true + * ! 
+- 'UnresolvedRelation [relation], [], false + */ + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + comparePlans(logPlan, expectedPlan, checkAnalysis = false) + + } + + test("WMA - multiple trendline commands") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=relation | trendline sort age wma(2, age) as two_points_wma wma(3, age) as three_points_wma"), + context) + + val dividendTwo = Add( + getNthValueAggregation("age", "age", 1, -1), + getNthValueAggregation("age", "age", 2, -1)) + val twoPointsExpression = Divide(dividendTwo, Literal(3)) + + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), + getNthValueAggregation("age", "age", 3, -2)) + val threePointsExpression = Divide(dividend, Literal(6)) + val trendlineProjectList = Seq( + UnresolvedStar(None), + Alias(twoPointsExpression, "two_points_wma")(), + Alias(threePointsExpression, "three_points_wma")()) + val sortedTable = Sort( + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, + UnresolvedRelation(Seq("relation"))) + + /** + * Expected logical plan: 'Project [*] +- 'Project [*, (( ('nth_value('age, 1) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, + * currentrow$())) * 1) + ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -1, currentrow$())) * 2)) / 3) AS two_points_wma#0, + * + * ((( ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + ('nth_value('age, 2) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 2)) + ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS three_points_wma#1] +- + * 'Sort ['age ASC NULLS 
FIRST], true +- 'UnresolvedRelation [relation], [], false + */ + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + comparePlans(logPlan, expectedPlan, checkAnalysis = false) + + } + + test("WMA - with negative dataPoint value") { + val context = new CatalystPlanContext + val exception = intercept[SyntaxCheckException]( + planTransformer + .visit(plan(pplParser, "source=relation | trendline sort age wma(-3, age)"), context)) + assert(exception.getMessage startsWith "Failed to parse query due to offending symbol [-]") + } + + private def getNthValueAggregation( + dataField: String, + sortField: String, + lookBackPos: Int, + lookBackRange: Int): Expression = { + Multiply( + WindowExpression( + UnresolvedFunction( + "nth_value", + Seq(UnresolvedAttribute(dataField), Literal(lookBackPos)), + isDistinct = false), + WindowSpecDefinition( + Seq(), + seq(SortUtils.sortOrder(UnresolvedAttribute(sortField), true)), + SpecifiedWindowFrame(RowFrame, Literal(lookBackRange), CurrentRow))), + Literal(lookBackPos)) + } + }
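The arithmetic that the generated WMA plan encodes (weights 1..n over an n-row look-behind window, divided by the triangular number n(n+1)/2) can be sanity-checked outside Spark. Below is a minimal plain-Java sketch mirroring the nth_value-based formula from the javadoc above; the class name `WmaSketch` and the sample data are illustrative only, not part of the change set.

```java
public class WmaSketch {
    // Weighted moving average of the last n points ending at index i,
    // mirroring sum over k = 1..n of nth_value(x, k) * k within the frame
    // ROWS BETWEEN (n-1) PRECEDING AND CURRENT ROW, divided by n(n+1)/2.
    // Returns NaN when the window is incomplete (analogous to the NULL the
    // plan would produce for too few data points).
    static double wmaAt(double[] values, int i, int n) {
        if (i < n - 1) {
            return Double.NaN; // fewer than n rows available
        }
        double divisor = n * (n + 1) / 2.0; // triangular number, e.g. 6 for n = 3
        double sum = 0;
        for (int k = 1; k <= n; k++) {
            // nth_value(values, k) within the frame is values[i - (n - 1) + (k - 1)]
            sum += values[i - n + k] * k;
        }
        return sum / divisor;
    }

    public static void main(String[] args) {
        double[] salary = {100, 200, 300, 400};
        // Frame over rows {200, 300, 400}: (200*1 + 300*2 + 400*3) / 6
        System.out.println(wmaAt(salary, 3, 3)); // ~ 333.33
    }
}
```

This matches the unit-test expectations above: a two-point WMA divides by 3 and a three-point WMA divides by 6.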