From 38defbf51cf1d87dcd3c543b550f385457e50b7c Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Thu, 14 Nov 2024 01:19:15 -0700 Subject: [PATCH] New trendline ppl command (WMA) (#872) (#907) * WMA implementation * Update test cases * Update tests * Refactor code * Addres comments * Update doc * Update example * Update readme * Update scalafmt * Update grammar rule * Address review comments * Address review comments --------- (cherry picked from commit 439cf3e1bcb5daef54be92195a100ca3b26d90cd) Signed-off-by: Andy Kwok Signed-off-by: github-actions[bot] Co-authored-by: github-actions[bot] --- DEVELOPER_GUIDE.md | 12 + docs/ppl-lang/PPL-Example-Commands.md | 1 + docs/ppl-lang/ppl-trendline-command.md | 64 ++++- .../ppl/FlintSparkPPLTrendlineITSuite.scala | 268 +++++++++++++++++- .../src/main/antlr4/OpenSearchPPLLexer.g4 | 1 + .../src/main/antlr4/OpenSearchPPLParser.g4 | 3 +- .../opensearch/sql/ast/tree/Trendline.java | 2 +- .../function/BuiltinFunctionName.java | 1 - .../sql/ppl/CatalystQueryPlanVisitor.java | 2 +- .../sql/ppl/utils/TrendlineCatalystUtils.java | 193 +++++++++++-- ...nTrendlineCommandTranslatorTestSuite.scala | 148 +++++++++- 11 files changed, 658 insertions(+), 37 deletions(-) diff --git a/DEVELOPER_GUIDE.md b/DEVELOPER_GUIDE.md index bb8f697ec..834a2a201 100644 --- a/DEVELOPER_GUIDE.md +++ b/DEVELOPER_GUIDE.md @@ -11,6 +11,11 @@ To execute the unit tests, run the following command: ``` sbt test ``` +To run a specific unit test in SBT, use the testOnly command with the full path of the test class: +``` +sbt "; project pplSparkIntegration; test:testOnly org.opensearch.flint.spark.ppl.PPLLogicalPlanTrendlineCommandTranslatorTestSuite" +``` + ## Integration Test The integration test is defined in the `integration` directory of the project. The integration tests will automatically trigger unit tests and will only run if all unit tests pass. If you want to run the integration test for the project, you can do so by running the following command: @@ -23,6 +28,13 @@ If you get integration test failures with error message "Previous attempts to fi 3. Run `sudo ln -s $HOME/.docker/desktop/docker.sock /var/run/docker.sock` or `sudo ln -s $HOME/.docker/run/docker.sock /var/run/docker.sock` 4. If you use Docker Desktop, as an alternative of `3`, check mark the "Allow the default Docker socket to be used (requires password)" in advanced settings of Docker Desktop. +Running only a selected set of integration test suites is possible with the following command: +``` +sbt "; project integtest; it:testOnly org.opensearch.flint.spark.ppl.FlintSparkPPLTrendlineITSuite" +``` +This command runs only the specified test suite within the integtest submodule. + + ### AWS Integration Test The `aws-integration` folder contains tests for cloud server providers. For instance, test against AWS OpenSearch domain, configure the following settings. The client will use the default credential provider to access the AWS OpenSearch domain. ``` diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md index 4ea564111..f26088b8a 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -61,6 +61,7 @@ _- **Limitation: new field added by eval command with a function cannot be dropp - `source = table | where cidrmatch(ip, '192.169.1.0/24')` - `source = table | where cidrmatch(ipv6, '2003:db8::/32')` - `source = table | trendline sma(2, temperature) as temp_trend` +- `source = table | trendline sort timestamp wma(2, temperature) as temp_trend` #### **IP related queries** [See additional command details](functions/ppl-ip.md) diff --git a/docs/ppl-lang/ppl-trendline-command.md b/docs/ppl-lang/ppl-trendline-command.md index 393a9dd59..b466e2e8f 100644 --- a/docs/ppl-lang/ppl-trendline-command.md +++ b/docs/ppl-lang/ppl-trendline-command.md @@ -3,8 +3,7 @@ **Description** Using ``trendline`` command to calculate moving averages of fields. - -### Syntax +### Syntax - SMA (Simple Moving Average) `TRENDLINE [sort <[+|-] sort-field>] SMA(number-of-datapoints, field) [AS alias] [SMA(number-of-datapoints, field) [AS alias]]...` * [+|-]: optional. The plus [+] stands for ascending order and NULL/MISSING first and a minus [-] stands for descending order and NULL/MISSING last. **Default:** ascending order and NULL/MISSING first. @@ -13,8 +12,6 @@ Using ``trendline`` command to calculate moving averages of fields. * field: mandatory. the name of the field the moving average should be calculated for. * alias: optional. the name of the resulting column containing the moving average. -And the moment only the Simple Moving Average (SMA) type is supported. - It is calculated like f[i]: The value of field 'f' in the i-th data-point @@ -23,7 +20,7 @@ It is calculated like SMA(t) = (1/n) * Σ(f[i]), where i = t-n+1 to t -### Example 1: Calculate simple moving average for a timeseries of temperatures +#### Example 1: Calculate simple moving average for a timeseries of temperatures The example calculates the simple moving average over temperatures using two datapoints. @@ -41,7 +38,7 @@ PPL query: | 15| 258|2023-04-06 17:07:...| 14.5| +-----------+---------+--------------------+----------+ -### Example 2: Calculate simple moving averages for a timeseries of temperatures with sorting +#### Example 2: Calculate simple moving averages for a timeseries of temperatures with sorting The example calculates two simple moving average over temperatures using two and three datapoints sorted descending by device-id. @@ -58,3 +55,58 @@ PPL query: | 12| 1492|2023-04-06 17:07:...| 12.5| 13.0| | 12| 1492|2023-04-06 17:07:...| 12.0|12.333333333333334| +-----------+---------+--------------------+------------+------------------+ + + +### Syntax - WMA (Weighted Moving Average) +`TRENDLINE sort <[+|-] sort-field> WMA(number-of-datapoints, field) [AS alias] [WMA(number-of-datapoints, field) [AS alias]]...` + +* [+|-]: optional. The plus [+] stands for ascending order and NULL/MISSING first and a minus [-] stands for descending order and NULL/MISSING last. **Default:** ascending order and NULL/MISSING first. +* sort-field: mandatory. this field specifies the ordering of data poients when calculating the nth_value aggregation. +* number-of-datapoints: mandatory. number of datapoints to calculate the moving average (must be greater than zero). +* field: mandatory. the name of the field the moving averag should be calculated for. +* alias: optional. the name of the resulting column containing the moving average. + +It is calculated like + + f[i]: The value of field 'f' in the i-th data point + n: The number of data points in the moving window (period) + t: The current time index + w[i]: The weight assigned to the i-th data point, typically increasing for more recent points + + WMA(t) = ( Σ from i=t−n+1 to t of (w[i] * f[i]) ) / ( Σ from i=t−n+1 to t of w[i] ) + +#### Example 1: Calculate weighted moving average for a timeseries of temperatures + +The example calculates the simple moving average over temperatures using two datapoints. + +PPL query: + + os> source=t | trendline sort timestamp wma(2, temperature) as temp_trend; + fetched rows / total rows = 5/5 + +-----------+---------+--------------------+----------+ + |temperature|device-id| timestamp|temp_trend| + +-----------+---------+--------------------+----------+ + | 12| 1492|2023-04-06 17:07:...| NULL| + | 12| 1492|2023-04-06 17:07:...| 12.0| + | 13| 256|2023-04-06 17:07:...| 12.6| + | 14| 257|2023-04-06 17:07:...| 13.6| + | 15| 258|2023-04-06 17:07:...| 14.6| + +-----------+---------+--------------------+----------+ + +#### Example 2: Calculate simple moving averages for a timeseries of temperatures with sorting + +The example calculates two simple moving average over temperatures using two and three datapoints sorted descending by device-id. + +PPL query: + + os> source=t | trendline sort - device-id wma(2, temperature) as temp_trend_2 wma(3, temperature) as temp_trend_3; + fetched rows / total rows = 5/5 + +-----------+---------+--------------------+------------+------------------+ + |temperature|device-id| timestamp|temp_trend_2| temp_trend_3| + +-----------+---------+--------------------+------------+------------------+ + | 15| 258|2023-04-06 17:07:...| NULL| NULL| + | 14| 257|2023-04-06 17:07:...| 14.3| NULL| + | 13| 256|2023-04-06 17:07:...| 13.3| 13.6| + | 12| 1492|2023-04-06 17:07:...| 12.3| 12.6| + | 12| 1492|2023-04-06 17:07:...| 12.0| 12.16| + +-----------+---------+--------------------+------------+------------------+ diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala index bc4463537..9a8379288 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala @@ -5,9 +5,14 @@ package org.opensearch.flint.spark.ppl +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.opensearch.sql.ppl.utils.SortUtils +import org.scalatest.matchers.should.Matchers.{a, convertToAnyShouldWrapper} + import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CaseWhen, CurrentRow, Descending, LessThan, Literal, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.expressions.{Add, Alias, Ascending, CaseWhen, CurrentRow, Descending, Divide, Expression, LessThan, Literal, Multiply, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.streaming.StreamTest @@ -244,4 +249,265 @@ class FlintSparkPPLTrendlineITSuite implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) } + + test("test trendline wma command with sort field and without alias") { + val frame = sql(s""" + | source = $testTable | trendline sort + age wma(3, age) + | """.stripMargin) + + // Compare the headers + assert( + frame.columns.sameElements( + Array("name", "age", "state", "country", "year", "month", "age_trendline"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jane", 20, "Quebec", "Canada", 2023, 4, null), + Row("John", 25, "Ontario", "Canada", 2023, 4, null), + Row("Hello", 30, "New York", "USA", 2023, 4, 26.666666666666668), + Row("Jake", 70, "California", "USA", 2023, 4, 49.166666666666664)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Compare the logical plans + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), + getNthValueAggregation("age", "age", 3, -2)) + val wmaExpression = Divide(dividend, Literal(6)) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(wmaExpression, "age_trendline")()) + val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq) + val sortedTable = Sort( + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, + unresolvedRelation) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + + /** + * Expected logical plan: 'Project [*] +- 'Project [*, ((( ('nth_value('age, 1) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 1) + ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + ('nth_value('age, 3) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 3)) / 6) AS age_trendline#185] +- 'Sort ['age ASC NULLS FIRST], true +- + * 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false + */ + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline wma command with sort field and with alias") { + val frame = sql(s""" + | source = $testTable | trendline sort + age wma(3, age) as trendline_alias + | """.stripMargin) + + // Compare the headers + assert( + frame.columns.sameElements( + Array("name", "age", "state", "country", "year", "month", "trendline_alias"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jane", 20, "Quebec", "Canada", 2023, 4, null), + Row("John", 25, "Ontario", "Canada", 2023, 4, null), + Row("Hello", 30, "New York", "USA", 2023, 4, 26.666666666666668), + Row("Jake", 70, "California", "USA", 2023, 4, 49.166666666666664)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Compare the logical plans + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), + getNthValueAggregation("age", "age", 3, -2)) + val wmaExpression = Divide(dividend, Literal(6)) + val trendlineProjectList = + Seq(UnresolvedStar(None), Alias(wmaExpression, "trendline_alias")()) + val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq) + val sortedTable = Sort( + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, + unresolvedRelation) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + + /** + * 'Project [*] +- 'Project [*, ((( ('nth_value('age, 1) windowspecdefinition('age ASC NULLS + * FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + ('nth_value('age, 2) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 2)) + ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS trendline_alias#185] +- + * 'Sort ['age ASC NULLS FIRST], true +- 'UnresolvedRelation [spark_catalog, default, + * flint_ppl_test], [], false + */ + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test multiple trendline wma commands") { + val frame = sql(s""" + | source = $testTable | trendline sort + age wma(2, age) as two_points_wma wma(3, age) as three_points_wma + | """.stripMargin) + + // Compare the headers + assert( + frame.columns.sameElements( + Array( + "name", + "age", + "state", + "country", + "year", + "month", + "two_points_wma", + "three_points_wma"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jane", 20, "Quebec", "Canada", 2023, 4, null, null), + Row("John", 25, "Ontario", "Canada", 2023, 4, 23.333333333333332, null), + Row("Hello", 30, "New York", "USA", 2023, 4, 28.333333333333332, 26.666666666666668), + Row("Jake", 70, "California", "USA", 2023, 4, 56.666666666666664, 49.166666666666664)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Compare the logical plans + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val dividendTwo = Add( + getNthValueAggregation("age", "age", 1, -1), + getNthValueAggregation("age", "age", 2, -1)) + val twoPointsExpression = Divide(dividendTwo, Literal(3)) + + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), + getNthValueAggregation("age", "age", 3, -2)) + val threePointsExpression = Divide(dividend, Literal(6)) + + val trendlineProjectList = Seq( + UnresolvedStar(None), + Alias(twoPointsExpression, "two_points_wma")(), + Alias(threePointsExpression, "three_points_wma")()) + val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq) + val sortedTable = Sort( + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, + unresolvedRelation) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + + /** + * 'Project [*] +- 'Project [*, (( ('nth_value('age, 1) windowspecdefinition('age ASC NULLS + * FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 1) + ('nth_value('age, 2) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, + * currentrow$())) * 2)) / 3) AS two_points_wma#247, + * + * ((( ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + ('nth_value('age, 2) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 2)) + ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS three_points_wma#248] +- + * 'Sort ['age ASC NULLS FIRST], true +- 'UnresolvedRelation [spark_catalog, default, + * flint_ppl_test], [], false + */ + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline wma command on evaluated column") { + val frame = sql(s""" + | source = $testTable | eval doubled_age = age * 2 | trendline sort + age wma(2, doubled_age) as doubled_age_wma | fields name, doubled_age, doubled_age_wma + | """.stripMargin) + + // Compare the headers + assert(frame.columns.sameElements(Array("name", "doubled_age", "doubled_age_wma"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jane", 40, null), + Row("John", 50, 46.666666666666664), + Row("Hello", 60, 56.666666666666664), + Row("Jake", 140, 113.33333333333333)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Compare the logical plans + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val dividend = Add( + getNthValueAggregation("doubled_age", "age", 1, -1), + getNthValueAggregation("doubled_age", "age", 2, -1)) + val wmaExpression = Divide(dividend, Literal(3)) + val trendlineProjectList = + Seq(UnresolvedStar(None), Alias(wmaExpression, "doubled_age_wma")()) + val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq) + val doubledAged = Alias( + UnresolvedFunction( + seq("*"), + seq(UnresolvedAttribute("age"), Literal(2)), + isDistinct = false), + "doubled_age")() + val doubleAgeProject = Project(seq(UnresolvedStar(None), doubledAged), unresolvedRelation) + val sortedTable = + Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, doubleAgeProject) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("name"), + UnresolvedAttribute("doubled_age"), + UnresolvedAttribute("doubled_age_wma")), + Project(trendlineProjectList, sortedTable)) + + /** + * 'Project ['name, 'doubled_age, 'doubled_age_wma] +- 'Project [*, (( + * ('nth_value('doubled_age, 1) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -1, currentrow$())) * 1) + ('nth_value('doubled_age, 2) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, + * currentrow$())) * 2)) / 3) AS doubled_age_wma#288] +- 'Sort ['age ASC NULLS FIRST], true +- + * 'Project [*, '`*`('age, 2) AS doubled_age#287] +- 'UnresolvedRelation [spark_catalog, + * default, flint_ppl_test], [], false + */ + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + + } + + test("test invalid wma command with negative dataPoint value") { + val exception = intercept[ParseException](sql(s""" + | source = $testTable | trendline sort + age wma(-3, age) + | """.stripMargin)) + assert(exception.getMessage contains "[PARSE_SYNTAX_ERROR] Syntax error") + } + + private def getNthValueAggregation( + dataField: String, + sortField: String, + lookBackPos: Int, + lookBackRange: Int): Expression = { + Multiply( + WindowExpression( + UnresolvedFunction( + "nth_value", + Seq(UnresolvedAttribute(dataField), Literal(lookBackPos)), + isDistinct = false), + WindowSpecDefinition( + Seq(), + seq(SortUtils.sortOrder(UnresolvedAttribute(sortField), true)), + SpecifiedWindowFrame(RowFrame, Literal(lookBackRange), CurrentRow))), + Literal(lookBackPos)) + } + } diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index 2c3344b3c..cb323f794 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -94,6 +94,7 @@ NULLS: 'NULLS'; //TRENDLINE KEYWORDS SMA: 'SMA'; +WMA: 'WMA'; // ARGUMENT KEYWORDS KEEPEMPTY: 'KEEPEMPTY'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 1cfd172f7..e43b71fb0 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -265,11 +265,12 @@ trendlineCommand ; trendlineClause - : trendlineType LT_PRTHS numberOfDataPoints = integerLiteral COMMA field = fieldExpression RT_PRTHS (AS alias = qualifiedName)? + : trendlineType LT_PRTHS numberOfDataPoints = INTEGER_LITERAL COMMA field = fieldExpression RT_PRTHS (AS alias = qualifiedName)? ; trendlineType : SMA + | WMA ; kmeansCommand diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java index 9fa1ae81d..d08e89e3b 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java @@ -62,6 +62,6 @@ public TrendlineComputation(Integer numberOfDataPoints, UnresolvedExpression dat } public enum TrendlineType { - SMA + SMA, WMA } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index 1959d0f6d..a8de957cb 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -291,7 +291,6 @@ public enum BuiltinFunctionName { MULTIMATCHQUERY(FunctionName.of("multimatchquery")), WILDCARDQUERY(FunctionName.of("wildcardquery")), WILDCARD_QUERY(FunctionName.of("wildcard_query")), - COALESCE(FunctionName.of("coalesce")); private FunctionName name; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index d2ee46ae6..bc6e486f4 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -247,7 +247,7 @@ public LogicalPlan visitTrendline(Trendline node, CatalystPlanContext context) { trendlineProjectExpressions.add(UnresolvedStar$.MODULE$.apply(Option.empty())); } - trendlineProjectExpressions.addAll(TrendlineCatalystUtils.visitTrendlineComputations(expressionAnalyzer, node.getComputations(), context)); + trendlineProjectExpressions.addAll(TrendlineCatalystUtils.visitTrendlineComputations(expressionAnalyzer, node.getComputations(), node.getSortByField(), context)); return context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(seq(trendlineProjectExpressions), p)); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java index 67603ccc7..647f4542e 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java @@ -5,31 +5,40 @@ package org.opensearch.sql.ppl.utils; +import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction; import org.apache.spark.sql.catalyst.expressions.*; -import org.opensearch.sql.ast.expression.AggregateFunction; -import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.*; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.ppl.CatalystExpressionVisitor; import org.opensearch.sql.ppl.CatalystPlanContext; +import scala.collection.mutable.Seq; import scala.Option; import scala.Tuple2; +import java.util.ArrayList; +import java.util.Collections; import java.util.List; +import java.util.Optional; import java.util.stream.Collectors; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; +import static scala.Option.empty; +import static scala.collection.JavaConverters.asScalaBufferConverter; public interface TrendlineCatalystUtils { - static List visitTrendlineComputations(CatalystExpressionVisitor expressionVisitor, List computations, CatalystPlanContext context) { + + static List visitTrendlineComputations(CatalystExpressionVisitor expressionVisitor, List computations, Optional sortField, CatalystPlanContext context) { return computations.stream() - .map(computation -> visitTrendlineComputation(expressionVisitor, computation, context)) + .map(computation -> visitTrendlineComputation(expressionVisitor, computation, sortField, context)) .collect(Collectors.toList()); } - static NamedExpression visitTrendlineComputation(CatalystExpressionVisitor expressionVisitor, Trendline.TrendlineComputation node, CatalystPlanContext context) { + + static NamedExpression visitTrendlineComputation(CatalystExpressionVisitor expressionVisitor, Trendline.TrendlineComputation node, Optional sortField, CatalystPlanContext context) { + //window lower boundary expressionVisitor.visitLiteral(new Literal(Math.negateExact(node.getNumberOfDataPoints() - 1), DataType.INTEGER), context); Expression windowLowerBoundary = context.popNamedParseExpressions().get(); @@ -40,26 +49,28 @@ static NamedExpression visitTrendlineComputation(CatalystExpressionVisitor expre seq(), new SpecifiedWindowFrame(RowFrame$.MODULE$, windowLowerBoundary, CurrentRow$.MODULE$)); - if (node.getComputationType() == Trendline.TrendlineType.SMA) { - //calculate avg value of the data field - expressionVisitor.visitAggregateFunction(new AggregateFunction(BuiltinFunctionName.AVG.name(), node.getDataField()), context); - Expression avgFunction = context.popNamedParseExpressions().get(); - - //sma window - WindowExpression sma = new WindowExpression( - avgFunction, - windowDefinition); - - CaseWhen smaOrNull = trendlineOrNullWhenThereAreTooFewDataPoints(expressionVisitor, sma, node, context); - - return org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(smaOrNull, - node.getAlias(), - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty(), - seq(new java.util.ArrayList())); - } else { - throw new IllegalArgumentException(node.getComputationType()+" is not supported"); + switch (node.getComputationType()) { + case SMA: + //calculate avg value of the data field + expressionVisitor.visitAggregateFunction(new AggregateFunction(BuiltinFunctionName.AVG.name(), node.getDataField()), context); + Expression avgFunction = context.popNamedParseExpressions().get(); + + //sma window + WindowExpression sma = new WindowExpression( + avgFunction, + windowDefinition); + + CaseWhen smaOrNull = trendlineOrNullWhenThereAreTooFewDataPoints(expressionVisitor, sma, node, context); + + return getAlias(node.getAlias(), smaOrNull); + case WMA: + if (sortField.isPresent()) { + return getWMAComputationExpression(expressionVisitor, node, sortField.get(), context); + } else { + throw new IllegalArgumentException(node.getComputationType()+" requires a sort field for computation"); + } + default: + throw new IllegalArgumentException(node.getComputationType()+" is not supported"); } } @@ -84,4 +95,136 @@ private static CaseWhen trendlineOrNullWhenThereAreTooFewDataPoints(CatalystExpr ); return new CaseWhen(seq(nullWhenNumberOfDataPointsLessThenRequired), Option.apply(trendlineWindow)); } + + /** + * Responsible to produce a Spark Logical Plan with given TrendLine command arguments, below is the sample logical plan + * with configuration [dataField=salary, sortField=age, dataPoints=3] + * -- +- 'Project [ + * -- (((('nth_value('salary, 1) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + + * -- ('nth_value('salary, 2) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + + * -- ('nth_value('salary, 3) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) + * -- AS WMA#702] + * . + * And the corresponded SQL query: + * . + * SELECT name, salary, + * ( nth_value(salary, 1) OVER (ORDER BY age ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) *1 + + * nth_value(salary, 2) OVER (ORDER BY age ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) *2 + + * nth_value(salary, 3) OVER (ORDER BY age ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) *3 )/6 AS WMA + * FROM employees + * ORDER BY age; + * + * @param visitor Visitor instance to process any UnresolvedExpression. + * @param node Trendline command's arguments. + * @param sortField Field used for window aggregation. + * @param context Context instance to retrieved Expression in resolved form. + * @return a NamedExpression instance which will calculate WMA with provided argument. + */ + private static NamedExpression getWMAComputationExpression(CatalystExpressionVisitor visitor, + Trendline.TrendlineComputation node, + Field sortField, + CatalystPlanContext context) { + int dataPoints = node.getNumberOfDataPoints(); + //window lower boundary + Expression windowLowerBoundary = parseIntToExpression(visitor, context, + Math.negateExact(dataPoints - 1)); + //window definition + visitor.analyze(sortField, context); + Expression sortDefinition = context.popNamedParseExpressions().get(); + WindowSpecDefinition windowDefinition = getWmaCommonWindowDefinition( + sortDefinition, + SortUtils.isSortedAscending(sortField), + windowLowerBoundary); + // Divisor + Expression divisor = parseIntToExpression(visitor, context, + (dataPoints * (dataPoints + 1) / 2)); + // Aggregation + Expression wmaExpression = getNthValueAggregations(visitor, node, context, windowDefinition, dataPoints) + .stream() + .reduce(Add::new) + .orElse(null); + + return getAlias(node.getAlias(), new Divide(wmaExpression, divisor)); + } + + /** + * Helper method to produce an Alias Expression with provide value and name. + * @param name The name for the Alias. + * @param expression The expression which will be evaluated. + * @return An Alias instance with logical plan representation of `expression AS name`. + */ + private static NamedExpression getAlias(String name, Expression expression) { + return org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(expression, + name, + NamedExpression.newExprId(), + seq(Collections.emptyList()), + Option.empty(), + seq(Collections.emptyList())); + } + + /** + * Helper method to retrieve an Int expression instance for logical plan composition purpose. + * @param expressionVisitor Visitor instance to process the incoming object. + * @param context Context instance to retrieve the Expression instance. + * @param i Target value for the expression. + * @return An expression object which contain integer value i. + */ + static Expression parseIntToExpression(CatalystExpressionVisitor expressionVisitor, CatalystPlanContext context, int i) { + expressionVisitor.visitLiteral(new Literal(i, + DataType.INTEGER), context); + return context.popNamedParseExpressions().get(); + } + + + /** + * Helper method to retrieve a WindowSpecDefinition with provided sorting condition. + * `windowspecdefinition('sortField ascending NULLS FIRST, specifiedwindowframe(RowFrame, windowLowerBoundary, currentrow$())` + * + * @param sortField The field being used for the sorting operation. + * @param ascending The boolean instance for the sorting order. + * @param windowLowerBoundary The Integer expression instance which specify the even lookbehind / lookahead. + * @return A WindowSpecDefinition instance which will be used to composite the WMA calculation. + */ + static WindowSpecDefinition getWmaCommonWindowDefinition(Expression sortField, boolean ascending, Expression windowLowerBoundary) { + return new WindowSpecDefinition( + seq(), + seq(SortUtils.sortOrder(sortField, ascending)), + new SpecifiedWindowFrame(RowFrame$.MODULE$, windowLowerBoundary, CurrentRow$.MODULE$)); + } + + /** + * To produce a list of Expressions responsible to return appropriate lookbehind / lookahead value for WMA calculation, sample logical plan listed below. + * (((('nth_value('salary, 1) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + + * + * @param visitor Visitor instance to resolve Expression. + * @param node Treeline command instruction. + * @param context Context instance to retrieve the resolved expression. + * @param windowDefinition The windowDefinition for the individual datapoint lookbehind / lookahead. + * @param dataPoints Number of data-points for WMA calculation, this will always equal to number of Expression being generated. + * @return List instance which contain the SQL statement for WMA individual datapoint's calculations. + */ + private static List getNthValueAggregations(CatalystExpressionVisitor visitor, + Trendline.TrendlineComputation node, + CatalystPlanContext context, + WindowSpecDefinition windowDefinition, + int dataPoints) { + List expressions = new ArrayList<>(); + for (int i = 1; i <= dataPoints; i++) { + // Get the offset parameter + Expression offSetExpression = parseIntToExpression(visitor, context, i); + // Get the dataField in Expression + visitor.analyze(node.getDataField(), context); + Expression dataField = context.popNamedParseExpressions().get(); + // nth_value Expression + UnresolvedFunction nthValueExp = new UnresolvedFunction( + asScalaBufferConverter(List.of("nth_value")).asScala().seq(), + asScalaBufferConverter(List.of(dataField, offSetExpression)).asScala().seq(), + false, empty(), false); + + expressions.add(new Multiply( + new WindowExpression(nthValueExp, windowDefinition), offSetExpression)); + } + return expressions; + } + } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala index d22750ee0..ec1775631 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala @@ -6,12 +6,15 @@ package org.opensearch.flint.spark.ppl import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.common.antlr.SyntaxCheckException import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.opensearch.sql.ppl.utils.SortUtils import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CaseWhen, CurrentRow, Descending, LessThan, Literal, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.expressions.{Add, Alias, Ascending, CaseWhen, CurrentRow, Descending, Divide, Expression, LessThan, Literal, Multiply, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Project, Sort} @@ -132,4 +135,147 @@ class PPLLogicalPlanTrendlineCommandTranslatorTestSuite Project(trendlineProjectList, sort)) comparePlans(logPlan, expectedPlan, checkAnalysis = false) } + + test("WMA - with sort") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=relation | trendline sort age wma(3, age)"), + context) + + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), + getNthValueAggregation("age", "age", 3, -2)) + val wmaExpression = Divide(dividend, Literal(6)) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(wmaExpression, "age_trendline")()) + val sortedTable = Sort( + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, + UnresolvedRelation(Seq("relation"))) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + + /** + * Expected logical plan: 'Project [*] !+- 'Project [*, ((( ('nth_value('age, 1) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 1) + ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + ('nth_value('age, 3) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 3)) / 6) AS age_trendline#0] ! +- 'Sort ['age ASC NULLS FIRST], true ! +- + * 'UnresolvedRelation [relation], [], false + */ + comparePlans(logPlan, expectedPlan, checkAnalysis = false) + } + + test("WMA - with sort and alias") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=relation | trendline sort age wma(3, age) as TEST_CUSTOM_COLUMN"), + context) + + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), + getNthValueAggregation("age", "age", 3, -2)) + val wmaExpression = Divide(dividend, Literal(6)) + val trendlineProjectList = + Seq(UnresolvedStar(None), Alias(wmaExpression, "TEST_CUSTOM_COLUMN")()) + val sortedTable = Sort( + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, + UnresolvedRelation(Seq("relation"))) + + /** + * Expected logical plan: 'Project [*] !+- 'Project [*, ((( ('nth_value('age, 1) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 1) + ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + ('nth_value('age, 3) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 3)) / 6) AS TEST_CUSTOM_COLUMN#0] ! +- 'Sort ['age ASC NULLS FIRST], true + * ! +- 'UnresolvedRelation [relation], [], false + */ + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + comparePlans(logPlan, expectedPlan, checkAnalysis = false) + + } + + test("WMA - multiple trendline commands") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=relation | trendline sort age wma(2, age) as two_points_wma wma(3, age) as three_points_wma"), + context) + + val dividendTwo = Add( + getNthValueAggregation("age", "age", 1, -1), + getNthValueAggregation("age", "age", 2, -1)) + val twoPointsExpression = Divide(dividendTwo, Literal(3)) + + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), + getNthValueAggregation("age", "age", 3, -2)) + val threePointsExpression = Divide(dividend, Literal(6)) + val trendlineProjectList = Seq( + UnresolvedStar(None), + Alias(twoPointsExpression, "two_points_wma")(), + Alias(threePointsExpression, "three_points_wma")()) + val sortedTable = Sort( + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, + UnresolvedRelation(Seq("relation"))) + + /** + * Expected logical plan: 'Project [*] +- 'Project [*, (( ('nth_value('age, 1) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, + * currentrow$())) * 1) + ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -1, currentrow$())) * 2)) / 3) AS two_points_wma#0, + * + * ((( ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + ('nth_value('age, 2) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 2)) + ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS three_points_wma#1] +- + * 'Sort ['age ASC NULLS FIRST], true +- 'UnresolvedRelation [relation], [], false + */ + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + comparePlans(logPlan, expectedPlan, checkAnalysis = false) + + } + + test("WMA - with negative dataPoint value") { + val context = new CatalystPlanContext + val exception = intercept[SyntaxCheckException]( + planTransformer + .visit(plan(pplParser, "source=relation | trendline sort age wma(-3, age)"), context)) + assert(exception.getMessage startsWith "Failed to parse query due to offending symbol [-]") + } + + private def getNthValueAggregation( + dataField: String, + sortField: String, + lookBackPos: Int, + lookBackRange: Int): Expression = { + Multiply( + WindowExpression( + UnresolvedFunction( + "nth_value", + Seq(UnresolvedAttribute(dataField), Literal(lookBackPos)), + isDistinct = false), + WindowSpecDefinition( + Seq(), + seq(SortUtils.sortOrder(UnresolvedAttribute(sortField), true)), + SpecifiedWindowFrame(RowFrame, Literal(lookBackRange), CurrentRow))), + Literal(lookBackPos)) + } + }